Skip to content

Commit f949094

Browse files
authored
convert Stable Cascade nodes to V3 schema (Comfy-Org#9373)
1 parent 4449e14 commit f949094

1 file changed

Lines changed: 93 additions & 72 deletions

File tree

comfy_extras/nodes_stable_cascade.py

Lines changed: 93 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -17,55 +17,61 @@
1717
"""
1818

1919
import torch
20-
import nodes
20+
from typing_extensions import override
21+
2122
import comfy.utils
23+
import nodes
24+
from comfy_api.latest import ComfyExtension, io
2225

2326

24-
class StableCascade_EmptyLatentImage:
25-
def __init__(self, device="cpu"):
26-
self.device = device
27+
class StableCascade_EmptyLatentImage(io.ComfyNode):
28+
@classmethod
29+
def define_schema(cls):
30+
return io.Schema(
31+
node_id="StableCascade_EmptyLatentImage",
32+
category="latent/stable_cascade",
33+
inputs=[
34+
io.Int.Input("width", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
35+
io.Int.Input("height", default=1024, min=256, max=nodes.MAX_RESOLUTION, step=8),
36+
io.Int.Input("compression", default=42, min=4, max=128, step=1),
37+
io.Int.Input("batch_size", default=1, min=1, max=4096),
38+
],
39+
outputs=[
40+
io.Latent.Output(display_name="stage_c"),
41+
io.Latent.Output(display_name="stage_b"),
42+
],
43+
)
2744

2845
@classmethod
29-
def INPUT_TYPES(s):
30-
return {"required": {
31-
"width": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
32-
"height": ("INT", {"default": 1024, "min": 256, "max": nodes.MAX_RESOLUTION, "step": 8}),
33-
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
34-
"batch_size": ("INT", {"default": 1, "min": 1, "max": 4096})
35-
}}
36-
RETURN_TYPES = ("LATENT", "LATENT")
37-
RETURN_NAMES = ("stage_c", "stage_b")
38-
FUNCTION = "generate"
39-
40-
CATEGORY = "latent/stable_cascade"
41-
42-
def generate(self, width, height, compression, batch_size=1):
46+
def execute(cls, width, height, compression, batch_size=1):
4347
c_latent = torch.zeros([batch_size, 16, height // compression, width // compression])
4448
b_latent = torch.zeros([batch_size, 4, height // 4, width // 4])
45-
return ({
49+
return io.NodeOutput({
4650
"samples": c_latent,
4751
}, {
4852
"samples": b_latent,
4953
})
5054

51-
class StableCascade_StageC_VAEEncode:
52-
def __init__(self, device="cpu"):
53-
self.device = device
5455

56+
class StableCascade_StageC_VAEEncode(io.ComfyNode):
5557
@classmethod
56-
def INPUT_TYPES(s):
57-
return {"required": {
58-
"image": ("IMAGE",),
59-
"vae": ("VAE", ),
60-
"compression": ("INT", {"default": 42, "min": 4, "max": 128, "step": 1}),
61-
}}
62-
RETURN_TYPES = ("LATENT", "LATENT")
63-
RETURN_NAMES = ("stage_c", "stage_b")
64-
FUNCTION = "generate"
65-
66-
CATEGORY = "latent/stable_cascade"
67-
68-
def generate(self, image, vae, compression):
58+
def define_schema(cls):
59+
return io.Schema(
60+
node_id="StableCascade_StageC_VAEEncode",
61+
category="latent/stable_cascade",
62+
inputs=[
63+
io.Image.Input("image"),
64+
io.Vae.Input("vae"),
65+
io.Int.Input("compression", default=42, min=4, max=128, step=1),
66+
],
67+
outputs=[
68+
io.Latent.Output(display_name="stage_c"),
69+
io.Latent.Output(display_name="stage_b"),
70+
],
71+
)
72+
73+
@classmethod
74+
def execute(cls, image, vae, compression):
6975
width = image.shape[-2]
7076
height = image.shape[-3]
7177
out_width = (width // compression) * vae.downscale_ratio
@@ -75,67 +81,82 @@ def generate(self, image, vae, compression):
7581

7682
c_latent = vae.encode(s[:,:,:,:3])
7783
b_latent = torch.zeros([c_latent.shape[0], 4, (height // 8) * 2, (width // 8) * 2])
78-
return ({
84+
return io.NodeOutput({
7985
"samples": c_latent,
8086
}, {
8187
"samples": b_latent,
8288
})
8389

84-
class StableCascade_StageB_Conditioning:
85-
@classmethod
86-
def INPUT_TYPES(s):
87-
return {"required": { "conditioning": ("CONDITIONING",),
88-
"stage_c": ("LATENT",),
89-
}}
90-
RETURN_TYPES = ("CONDITIONING",)
91-
92-
FUNCTION = "set_prior"
9390

94-
CATEGORY = "conditioning/stable_cascade"
91+
class StableCascade_StageB_Conditioning(io.ComfyNode):
92+
@classmethod
93+
def define_schema(cls):
94+
return io.Schema(
95+
node_id="StableCascade_StageB_Conditioning",
96+
category="conditioning/stable_cascade",
97+
inputs=[
98+
io.Conditioning.Input("conditioning"),
99+
io.Latent.Input("stage_c"),
100+
],
101+
outputs=[
102+
io.Conditioning.Output(),
103+
],
104+
)
95105

96-
def set_prior(self, conditioning, stage_c):
106+
@classmethod
107+
def execute(cls, conditioning, stage_c):
97108
c = []
98109
for t in conditioning:
99110
d = t[1].copy()
100-
d['stable_cascade_prior'] = stage_c['samples']
111+
d["stable_cascade_prior"] = stage_c["samples"]
101112
n = [t[0], d]
102113
c.append(n)
103-
return (c, )
114+
return io.NodeOutput(c)
104115

105-
class StableCascade_SuperResolutionControlnet:
106-
def __init__(self, device="cpu"):
107-
self.device = device
108116

117+
class StableCascade_SuperResolutionControlnet(io.ComfyNode):
109118
@classmethod
110-
def INPUT_TYPES(s):
111-
return {"required": {
112-
"image": ("IMAGE",),
113-
"vae": ("VAE", ),
114-
}}
115-
RETURN_TYPES = ("IMAGE", "LATENT", "LATENT")
116-
RETURN_NAMES = ("controlnet_input", "stage_c", "stage_b")
117-
FUNCTION = "generate"
118-
119-
EXPERIMENTAL = True
120-
CATEGORY = "_for_testing/stable_cascade"
121-
122-
def generate(self, image, vae):
119+
def define_schema(cls):
120+
return io.Schema(
121+
node_id="StableCascade_SuperResolutionControlnet",
122+
category="_for_testing/stable_cascade",
123+
is_experimental=True,
124+
inputs=[
125+
io.Image.Input("image"),
126+
io.Vae.Input("vae"),
127+
],
128+
outputs=[
129+
io.Image.Output(display_name="controlnet_input"),
130+
io.Latent.Output(display_name="stage_c"),
131+
io.Latent.Output(display_name="stage_b"),
132+
],
133+
)
134+
135+
@classmethod
136+
def execute(cls, image, vae):
123137
width = image.shape[-2]
124138
height = image.shape[-3]
125139
batch_size = image.shape[0]
126140
controlnet_input = vae.encode(image[:,:,:,:3]).movedim(1, -1)
127141

128142
c_latent = torch.zeros([batch_size, 16, height // 16, width // 16])
129143
b_latent = torch.zeros([batch_size, 4, height // 2, width // 2])
130-
return (controlnet_input, {
144+
return io.NodeOutput(controlnet_input, {
131145
"samples": c_latent,
132146
}, {
133147
"samples": b_latent,
134148
})
135149

136-
NODE_CLASS_MAPPINGS = {
137-
"StableCascade_EmptyLatentImage": StableCascade_EmptyLatentImage,
138-
"StableCascade_StageB_Conditioning": StableCascade_StageB_Conditioning,
139-
"StableCascade_StageC_VAEEncode": StableCascade_StageC_VAEEncode,
140-
"StableCascade_SuperResolutionControlnet": StableCascade_SuperResolutionControlnet,
141-
}
150+
151+
class StableCascadeExtension(ComfyExtension):
152+
@override
153+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
154+
return [
155+
StableCascade_EmptyLatentImage,
156+
StableCascade_StageB_Conditioning,
157+
StableCascade_StageC_VAEEncode,
158+
StableCascade_SuperResolutionControlnet,
159+
]
160+
161+
async def comfy_entrypoint() -> StableCascadeExtension:
162+
return StableCascadeExtension()

0 commit comments

Comments
 (0)