Skip to content

Commit 428c323

Browse files
authored
[Partner Nodes] new OpenAI Image node with DynamicCombo and Autogrow (Comfy-Org#13838)
1 parent 46063aa commit 428c323

1 file changed

Lines changed: 313 additions & 0 deletions

File tree

comfy_api_nodes/nodes_openai.py

Lines changed: 313 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ApiEndpoint,
2828
download_url_to_bytesio,
2929
downscale_image_tensor,
30+
get_number_of_images,
3031
poll_op,
3132
sync_op,
3233
tensor_to_base64_string,
@@ -372,6 +373,7 @@ def define_schema(cls):
372373
display_name="OpenAI GPT Image 2",
373374
category="api node/image/OpenAI",
374375
description="Generates images synchronously via OpenAI's GPT Image endpoint.",
376+
is_deprecated=True,
375377
inputs=[
376378
IO.String.Input(
377379
"prompt",
@@ -640,6 +642,316 @@ async def execute(
640642
return IO.NodeOutput(await validate_and_cast_response(response))
641643

642644

645+
def _gpt_image_shared_inputs():
    """Build the widgets that every GPT Image model option has in common.

    A fresh list is created on each call, holding in order:

    * ``quality`` - combo with ``low`` / ``medium`` / ``high`` (default ``low``);
    * ``images`` - auto-growing group of image slots named ``image_1`` .. ``image_16``
      (no slot is required, since ``min=0``);
    * ``mask`` - optional mask input.
    """
    quality_input = IO.Combo.Input(
        "quality",
        default="low",
        options=["low", "medium", "high"],
        tooltip="Image quality, affects cost and generation time.",
    )

    # Sixteen numbered slots, matching the "Up to 16 images" wording of the tooltip.
    slot_names = [f"image_{slot}" for slot in range(1, 17)]
    reference_images_input = IO.Autogrow.Input(
        "images",
        template=IO.Autogrow.TemplateNames(
            IO.Image.Input("image"),
            names=slot_names,
            min=0,
        ),
        tooltip="Optional reference image(s) for image editing. Up to 16 images.",
    )

    mask_input = IO.Mask.Input(
        "mask",
        optional=True,
        tooltip="Optional mask for inpainting (white areas will be replaced). "
        "Requires exactly one reference image.",
    )

    return [quality_input, reference_images_input, mask_input]
670+
671+
672+
def _gpt_image_legacy_model_inputs():
    """Build the per-model widget set used by the gpt-image-1 / gpt-image-1.5 options.

    The returned list starts with a ``size`` combo (``auto`` plus three fixed
    resolutions) and a ``background`` combo that also offers ``transparent``,
    followed by the widgets from ``_gpt_image_shared_inputs()``. A new list is
    built on every call.
    """
    size_input = IO.Combo.Input(
        "size",
        default="auto",
        options=["auto", "1024x1024", "1024x1536", "1536x1024"],
        tooltip="Image size.",
    )
    background_input = IO.Combo.Input(
        "background",
        default="auto",
        options=["auto", "opaque", "transparent"],
        tooltip="Return image with or without background.",
    )
    return [size_input, background_input, *_gpt_image_shared_inputs()]
689+
690+
691+
def _gpt_image_validate_custom_size(custom_width: int, custom_height: int) -> str:
    """Validate a custom resolution and return it as a ``"<width>x<height>"`` string.

    Args:
        custom_width: Requested width in pixels.
        custom_height: Requested height in pixels.

    Returns:
        The size formatted as ``f"{custom_width}x{custom_height}"``.

    Raises:
        ValueError: If either edge is not positive, is not a multiple of 16, the
            longest edge exceeds 3840, the aspect ratio exceeds 3:1, or the total
            pixel count is outside 655,360 .. 8,294,400.
    """
    # Guard first: a zero edge would otherwise pass the modulo check below and
    # crash with ZeroDivisionError in the aspect-ratio computation.
    if custom_width <= 0 or custom_height <= 0:
        raise ValueError(
            f"Custom width and height must be positive, got {custom_width}x{custom_height}"
        )
    if custom_width % 16 != 0 or custom_height % 16 != 0:
        raise ValueError(
            f"Custom width and height must be multiples of 16, got {custom_width}x{custom_height}"
        )
    longest_edge = max(custom_width, custom_height)
    if longest_edge > 3840:
        raise ValueError(
            f"Custom resolution max edge must be <= 3840, got {custom_width}x{custom_height}"
        )
    if longest_edge / min(custom_width, custom_height) > 3:
        raise ValueError(
            f"Custom resolution aspect ratio must not exceed 3:1, got {custom_width}x{custom_height}"
        )
    total_pixels = custom_width * custom_height
    if not 655_360 <= total_pixels <= 8_294_400:
        raise ValueError(
            f"Custom resolution total pixels must be between 655,360 and 8,294,400, got {total_pixels}"
        )
    return f"{custom_width}x{custom_height}"


def _gpt_image_tensor_to_png(image: torch.Tensor) -> BytesIO:
    """Downscale a single-image tensor to at most 2048*2048 pixels and PNG-encode it.

    Args:
        image: Tensor holding one image with a leading batch dimension of 1.
            Values are scaled by 255 and cast to ``uint8``, so they are presumably
            expected in ``[0, 1]`` - TODO confirm against upstream nodes.

    Returns:
        An in-memory PNG, rewound to position 0 so it can be uploaded directly.
    """
    scaled = downscale_image_tensor(image, total_pixels=2048 * 2048).squeeze()
    pixels = (scaled.numpy() * 255).astype(np.uint8)
    buffer = BytesIO()
    Image.fromarray(pixels).save(buffer, format="PNG")
    buffer.seek(0)
    return buffer


class OpenAIGPTImageNodeV2(IO.ComfyNode):
    """API node that generates or edits images through OpenAI's GPT Image endpoint.

    The ``model`` input is a dynamic combo: each model option carries its own
    size/background widgets plus the shared quality, reference-image and mask
    inputs. When at least one reference image is connected the request goes to
    the image *edits* endpoint, otherwise to the *generations* endpoint.
    """

    @classmethod
    def define_schema(cls):
        """Return the node schema: inputs, outputs, hidden auth fields and price badge."""
        return IO.Schema(
            node_id="OpenAIGPTImageNodeV2",
            display_name="OpenAI GPT Image 2",
            category="api node/image/OpenAI",
            description="Generates images via OpenAI's GPT Image endpoint.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    default="",
                    multiline=True,
                    tooltip="Text prompt for GPT Image",
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "gpt-image-2",
                            [
                                IO.Combo.Input(
                                    "size",
                                    default="auto",
                                    options=[
                                        "auto",
                                        "1024x1024",
                                        "1024x1536",
                                        "1536x1024",
                                        "2048x2048",
                                        "2048x1152",
                                        "1152x2048",
                                        "3840x2160",
                                        "2160x3840",
                                        "Custom",
                                    ],
                                    tooltip="Image size. Select 'Custom' to use the custom width and height.",
                                ),
                                IO.Int.Input(
                                    "custom_width",
                                    default=1024,
                                    min=1024,
                                    max=3840,
                                    step=16,
                                    tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
                                ),
                                IO.Int.Input(
                                    "custom_height",
                                    default=1024,
                                    min=1024,
                                    max=3840,
                                    step=16,
                                    tooltip="Used only when `size` is 'Custom'. Must be a multiple of 16.",
                                ),
                                # Unlike the legacy models, no "transparent" option is offered here.
                                IO.Combo.Input(
                                    "background",
                                    default="auto",
                                    options=["auto", "opaque"],
                                    tooltip="Return image with or without background.",
                                ),
                                *_gpt_image_shared_inputs(),
                            ],
                        ),
                        IO.DynamicCombo.Option("gpt-image-1.5", _gpt_image_legacy_model_inputs()),
                        IO.DynamicCombo.Option("gpt-image-1", _gpt_image_legacy_model_inputs()),
                    ],
                ),
                IO.Int.Input(
                    "n",
                    default=1,
                    min=1,
                    max=8,
                    step=1,
                    tooltip="How many images to generate",
                    display_mode=IO.NumberDisplay.number,
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    step=1,
                    display_mode=IO.NumberDisplay.number,
                    control_after_generate=True,
                    tooltip="not implemented yet in backend",
                ),
            ],
            outputs=[IO.Image.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            # Per-image [min, max] USD ranges keyed by model and quality; multiplied by `n`
            # (and labelled "/Run") when more than one image is requested.
            price_badge=IO.PriceBadge(
                depends_on=IO.PriceBadgeDepends(widgets=["model", "model.quality", "n"]),
                expr="""
            (
              $ranges := {
                "gpt-image-1": {
                  "low": [0.011, 0.02],
                  "medium": [0.042, 0.07],
                  "high": [0.167, 0.25]
                },
                "gpt-image-1.5": {
                  "low": [0.009, 0.02],
                  "medium": [0.034, 0.062],
                  "high": [0.133, 0.22]
                },
                "gpt-image-2": {
                  "low": [0.0048, 0.019],
                  "medium": [0.041, 0.168],
                  "high": [0.165, 0.67]
                }
              };
              $range := $lookup($lookup($ranges, widgets.model), $lookup(widgets, "model.quality"));
              $nRaw := widgets.n;
              $n := ($nRaw != null and $nRaw != 0) ? $nRaw : 1;
              ($n = 1)
                ? {"type":"range_usd","min_usd": $range[0], "max_usd": $range[1], "format": {"approximate": true}}
                : {
                    "type":"range_usd",
                    "min_usd": $range[0] * $n,
                    "max_usd": $range[1] * $n,
                    "format": { "suffix": "/Run", "approximate": true }
                  }
            )
            """,
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        n: int,
        seed: int,
    ) -> IO.NodeOutput:
        """Generate images, or edit the supplied reference images, with a GPT Image model.

        Args:
            prompt: Text prompt; validated with ``validate_string`` before any request.
            model: Values of the dynamic ``model`` combo. Must contain ``model``,
                ``size``, ``background`` and ``quality``; may contain
                ``custom_width``, ``custom_height``, ``images`` (mapping of slot
                name to image tensor or ``None``) and ``mask``.
            n: Number of images to request.
            seed: Not used by this method (the widget tooltip marks it as not yet
                implemented in the backend).

        Returns:
            The node output wrapping the validated and converted API response.

        Raises:
            ValueError: If a mask is given without exactly one reference image, the
                mask and image sizes differ, the custom size is invalid, or the
                model id is unknown.
        """
        validate_string(prompt, strip_whitespace=False)

        model_id = model["model"]
        size = model["size"]
        background = model["background"]
        quality = model["quality"]

        images_dict = model.get("images") or {}
        # Unconnected slots arrive as None and are ignored.
        image_tensors = [t for t in images_dict.values() if t is not None]
        n_images = sum(get_number_of_images(t) for t in image_tensors)
        mask = model.get("mask")

        if mask is not None and n_images == 0:
            raise ValueError("Cannot use a mask without an input image")

        if size == "Custom":
            size = _gpt_image_validate_custom_size(
                model.get("custom_width", 1024),
                model.get("custom_height", 1024),
            )

        price_extractors = {
            "gpt-image-1": calculate_tokens_price_image_1,
            "gpt-image-1.5": calculate_tokens_price_image_1_5,
            "gpt-image-2": calculate_tokens_price_image_2_0,
        }
        try:
            price_extractor = price_extractors[model_id]
        except KeyError:
            raise ValueError(f"Unknown model: {model_id}") from None

        # Both endpoints take the same set of request fields.
        request_fields = {
            "model": model_id,
            "prompt": prompt,
            "quality": quality,
            "background": background,
            "n": n,
            "size": size,
            "moderation": "low",
        }

        if not image_tensors:
            response = await sync_op(
                cls,
                ApiEndpoint(path="/proxy/openai/images/generations", method="POST"),
                response_model=OpenAIImageGenerationResponse,
                data=OpenAIImageGenerationRequest(**request_fields),
                price_extractor=price_extractor,
            )
            return IO.NodeOutput(await validate_and_cast_response(response))

        # Split batched inputs so that every entry holds exactly one image
        # with a leading batch dimension of 1.
        flat: list[torch.Tensor] = []
        for tensor in image_tensors:
            if len(tensor.shape) == 4:
                flat.extend(tensor[i : i + 1] for i in range(tensor.shape[0]))
            else:
                flat.append(tensor.unsqueeze(0))

        # Validate the mask before spending time on downscaling and PNG encoding.
        if mask is not None:
            if len(flat) != 1:
                raise ValueError("Cannot use a mask with multiple images")
            if mask.shape[1:] != flat[0].shape[1:-1]:
                raise ValueError("Mask and Image must be the same size")

        # A single image is sent as "image"; several are sent as repeated "image[]" parts.
        field_name = "image" if len(flat) == 1 else "image[]"
        files = [
            (field_name, (f"image_{i}.png", _gpt_image_tensor_to_png(single_image), "image/png"))
            for i, single_image in enumerate(flat)
        ]

        if mask is not None:
            _, height, width = mask.shape
            # The mask is uploaded as an RGBA image whose alpha channel is the
            # inverted mask, so mask value 1 becomes fully transparent (alpha 0).
            rgba_mask = torch.zeros(height, width, 4, device="cpu")
            rgba_mask[:, :, 3] = 1 - mask.squeeze().cpu()
            files.append(
                ("mask", ("mask.png", _gpt_image_tensor_to_png(rgba_mask.unsqueeze(0)), "image/png"))
            )

        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/openai/images/edits", method="POST"),
            response_model=OpenAIImageGenerationResponse,
            data=OpenAIImageEditRequest(**request_fields),
            content_type="multipart/form-data",
            files=files,
            price_extractor=price_extractor,
        )
        return IO.NodeOutput(await validate_and_cast_response(response))
953+
954+
643955
class OpenAIChatNode(IO.ComfyNode):
644956
"""
645957
Node to generate text responses from an OpenAI model.
@@ -999,6 +1311,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
9991311
OpenAIDalle2,
10001312
OpenAIDalle3,
10011313
OpenAIGPTImage1,
1314+
OpenAIGPTImageNodeV2,
10021315
OpenAIChatNode,
10031316
OpenAIInputFiles,
10041317
OpenAIChatConfig,

0 commit comments

Comments
 (0)