1717"""
1818
1919import torch
20- import nodes
20+ from typing_extensions import override
21+
2122import comfy .utils
23+ import nodes
24+ from comfy_api .latest import ComfyExtension , io
2225
2326
24- class StableCascade_EmptyLatentImage :
25- def __init__ (self , device = "cpu" ):
26- self .device = device
27+ class StableCascade_EmptyLatentImage (io .ComfyNode ):
28+ @classmethod
29+ def define_schema (cls ):
30+ return io .Schema (
31+ node_id = "StableCascade_EmptyLatentImage" ,
32+ category = "latent/stable_cascade" ,
33+ inputs = [
34+ io .Int .Input ("width" , default = 1024 , min = 256 , max = nodes .MAX_RESOLUTION , step = 8 ),
35+ io .Int .Input ("height" , default = 1024 , min = 256 , max = nodes .MAX_RESOLUTION , step = 8 ),
36+ io .Int .Input ("compression" , default = 42 , min = 4 , max = 128 , step = 1 ),
37+ io .Int .Input ("batch_size" , default = 1 , min = 1 , max = 4096 ),
38+ ],
39+ outputs = [
40+ io .Latent .Output (display_name = "stage_c" ),
41+ io .Latent .Output (display_name = "stage_b" ),
42+ ],
43+ )
2744
2845 @classmethod
29- def INPUT_TYPES (s ):
30- return {"required" : {
31- "width" : ("INT" , {"default" : 1024 , "min" : 256 , "max" : nodes .MAX_RESOLUTION , "step" : 8 }),
32- "height" : ("INT" , {"default" : 1024 , "min" : 256 , "max" : nodes .MAX_RESOLUTION , "step" : 8 }),
33- "compression" : ("INT" , {"default" : 42 , "min" : 4 , "max" : 128 , "step" : 1 }),
34- "batch_size" : ("INT" , {"default" : 1 , "min" : 1 , "max" : 4096 })
35- }}
36- RETURN_TYPES = ("LATENT" , "LATENT" )
37- RETURN_NAMES = ("stage_c" , "stage_b" )
38- FUNCTION = "generate"
39-
40- CATEGORY = "latent/stable_cascade"
41-
42- def generate (self , width , height , compression , batch_size = 1 ):
46+ def execute (cls , width , height , compression , batch_size = 1 ):
4347 c_latent = torch .zeros ([batch_size , 16 , height // compression , width // compression ])
4448 b_latent = torch .zeros ([batch_size , 4 , height // 4 , width // 4 ])
45- return ({
49+ return io . NodeOutput ({
4650 "samples" : c_latent ,
4751 }, {
4852 "samples" : b_latent ,
4953 })
5054
51- class StableCascade_StageC_VAEEncode :
52- def __init__ (self , device = "cpu" ):
53- self .device = device
5455
56+ class StableCascade_StageC_VAEEncode (io .ComfyNode ):
5557 @classmethod
56- def INPUT_TYPES (s ):
57- return {"required" : {
58- "image" : ("IMAGE" ,),
59- "vae" : ("VAE" , ),
60- "compression" : ("INT" , {"default" : 42 , "min" : 4 , "max" : 128 , "step" : 1 }),
61- }}
62- RETURN_TYPES = ("LATENT" , "LATENT" )
63- RETURN_NAMES = ("stage_c" , "stage_b" )
64- FUNCTION = "generate"
65-
66- CATEGORY = "latent/stable_cascade"
67-
68- def generate (self , image , vae , compression ):
58+ def define_schema (cls ):
59+ return io .Schema (
60+ node_id = "StableCascade_StageC_VAEEncode" ,
61+ category = "latent/stable_cascade" ,
62+ inputs = [
63+ io .Image .Input ("image" ),
64+ io .Vae .Input ("vae" ),
65+ io .Int .Input ("compression" , default = 42 , min = 4 , max = 128 , step = 1 ),
66+ ],
67+ outputs = [
68+ io .Latent .Output (display_name = "stage_c" ),
69+ io .Latent .Output (display_name = "stage_b" ),
70+ ],
71+ )
72+
73+ @classmethod
74+ def execute (cls , image , vae , compression ):
6975 width = image .shape [- 2 ]
7076 height = image .shape [- 3 ]
7177 out_width = (width // compression ) * vae .downscale_ratio
@@ -75,67 +81,82 @@ def generate(self, image, vae, compression):
7581
7682 c_latent = vae .encode (s [:,:,:,:3 ])
7783 b_latent = torch .zeros ([c_latent .shape [0 ], 4 , (height // 8 ) * 2 , (width // 8 ) * 2 ])
78- return ({
84+ return io . NodeOutput ({
7985 "samples" : c_latent ,
8086 }, {
8187 "samples" : b_latent ,
8288 })
8389
84- class StableCascade_StageB_Conditioning :
85- @classmethod
86- def INPUT_TYPES (s ):
87- return {"required" : { "conditioning" : ("CONDITIONING" ,),
88- "stage_c" : ("LATENT" ,),
89- }}
90- RETURN_TYPES = ("CONDITIONING" ,)
91-
92- FUNCTION = "set_prior"
9390
94- CATEGORY = "conditioning/stable_cascade"
91+ class StableCascade_StageB_Conditioning (io .ComfyNode ):
92+ @classmethod
93+ def define_schema (cls ):
94+ return io .Schema (
95+ node_id = "StableCascade_StageB_Conditioning" ,
96+ category = "conditioning/stable_cascade" ,
97+ inputs = [
98+ io .Conditioning .Input ("conditioning" ),
99+ io .Latent .Input ("stage_c" ),
100+ ],
101+ outputs = [
102+ io .Conditioning .Output (),
103+ ],
104+ )
95105
96- def set_prior (self , conditioning , stage_c ):
106+ @classmethod
107+ def execute (cls , conditioning , stage_c ):
97108 c = []
98109 for t in conditioning :
99110 d = t [1 ].copy ()
100- d [' stable_cascade_prior' ] = stage_c [' samples' ]
111+ d [" stable_cascade_prior" ] = stage_c [" samples" ]
101112 n = [t [0 ], d ]
102113 c .append (n )
103- return ( c , )
114+ return io . NodeOutput ( c )
104115
105- class StableCascade_SuperResolutionControlnet :
106- def __init__ (self , device = "cpu" ):
107- self .device = device
108116
117+ class StableCascade_SuperResolutionControlnet (io .ComfyNode ):
109118 @classmethod
110- def INPUT_TYPES (s ):
111- return {"required" : {
112- "image" : ("IMAGE" ,),
113- "vae" : ("VAE" , ),
114- }}
115- RETURN_TYPES = ("IMAGE" , "LATENT" , "LATENT" )
116- RETURN_NAMES = ("controlnet_input" , "stage_c" , "stage_b" )
117- FUNCTION = "generate"
118-
119- EXPERIMENTAL = True
120- CATEGORY = "_for_testing/stable_cascade"
121-
122- def generate (self , image , vae ):
119+ def define_schema (cls ):
120+ return io .Schema (
121+ node_id = "StableCascade_SuperResolutionControlnet" ,
122+ category = "_for_testing/stable_cascade" ,
123+ is_experimental = True ,
124+ inputs = [
125+ io .Image .Input ("image" ),
126+ io .Vae .Input ("vae" ),
127+ ],
128+ outputs = [
129+ io .Image .Output (display_name = "controlnet_input" ),
130+ io .Latent .Output (display_name = "stage_c" ),
131+ io .Latent .Output (display_name = "stage_b" ),
132+ ],
133+ )
134+
135+ @classmethod
136+ def execute (cls , image , vae ):
123137 width = image .shape [- 2 ]
124138 height = image .shape [- 3 ]
125139 batch_size = image .shape [0 ]
126140 controlnet_input = vae .encode (image [:,:,:,:3 ]).movedim (1 , - 1 )
127141
128142 c_latent = torch .zeros ([batch_size , 16 , height // 16 , width // 16 ])
129143 b_latent = torch .zeros ([batch_size , 4 , height // 2 , width // 2 ])
130- return (controlnet_input , {
144+ return io . NodeOutput (controlnet_input , {
131145 "samples" : c_latent ,
132146 }, {
133147 "samples" : b_latent ,
134148 })
135149
136- NODE_CLASS_MAPPINGS = {
137- "StableCascade_EmptyLatentImage" : StableCascade_EmptyLatentImage ,
138- "StableCascade_StageB_Conditioning" : StableCascade_StageB_Conditioning ,
139- "StableCascade_StageC_VAEEncode" : StableCascade_StageC_VAEEncode ,
140- "StableCascade_SuperResolutionControlnet" : StableCascade_SuperResolutionControlnet ,
141- }
150+
151+ class StableCascadeExtension (ComfyExtension ):
152+ @override
153+ async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
154+ return [
155+ StableCascade_EmptyLatentImage ,
156+ StableCascade_StageB_Conditioning ,
157+ StableCascade_StageC_VAEEncode ,
158+ StableCascade_SuperResolutionControlnet ,
159+ ]
160+
161+ async def comfy_entrypoint () -> StableCascadeExtension :
162+ return StableCascadeExtension ()
0 commit comments