Skip to content

Commit 164a9d4

Browse files
authored
[Partner Nodes] add ByteDance Seed LLM node (Comfy-Org#13919)
Signed-off-by: bigcat88 <bigcat88@icloud.com>
1 parent 16f862f commit 164a9d4

2 files changed

Lines changed: 372 additions & 0 deletions

File tree

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Pydantic models for BytePlus ModelArk Responses API.
2+
3+
See: https://docs.byteplus.com/en/docs/ModelArk/1585128 (request)
4+
https://docs.byteplus.com/en/docs/ModelArk/1783703 (response)
5+
"""
6+
7+
from typing import Literal
8+
9+
from pydantic import BaseModel, Field
10+
11+
12+
class BytePlusInputText(BaseModel):
13+
type: Literal["input_text"] = "input_text"
14+
text: str = Field(...)
15+
16+
17+
class BytePlusInputImage(BaseModel):
18+
type: Literal["input_image"] = "input_image"
19+
image_url: str = Field(..., description="Image URL or `data:image/...;base64,...` payload")
20+
detail: str = Field("auto", description="One of high, low, auto")
21+
22+
23+
class BytePlusInputVideo(BaseModel):
24+
type: Literal["input_video"] = "input_video"
25+
video_url: str = Field(..., description="Video URL or `data:video/...;base64,...` payload")
26+
fps: float | None = Field(None, ge=0.2, le=5.0)
27+
28+
29+
BytePlusMessageContent = BytePlusInputText | BytePlusInputImage | BytePlusInputVideo
30+
31+
32+
class BytePlusInputMessage(BaseModel):
33+
type: Literal["message"] = "message"
34+
role: str = Field(..., description="One of user, system, assistant, developer")
35+
content: list[BytePlusMessageContent] = Field(...)
36+
37+
38+
class BytePlusResponseCreateRequest(BaseModel):
39+
model: str = Field(...)
40+
input: list[BytePlusInputMessage] = Field(...)
41+
instructions: str | None = Field(None)
42+
max_output_tokens: int | None = Field(None, ge=1)
43+
temperature: float | None = Field(None, ge=0.0, le=2.0)
44+
store: bool | None = Field(False)
45+
stream: bool | None = Field(False)
46+
47+
48+
class BytePlusOutputText(BaseModel):
49+
type: Literal["output_text"] = "output_text"
50+
text: str = Field(...)
51+
52+
53+
class BytePlusOutputRefusal(BaseModel):
54+
type: Literal["refusal"] = "refusal"
55+
refusal: str = Field(...)
56+
57+
58+
class BytePlusOutputContent(BaseModel):
59+
type: str = Field(...)
60+
text: str | None = Field(None)
61+
refusal: str | None = Field(None)
62+
63+
64+
class BytePlusOutputMessage(BaseModel):
65+
type: str = Field(...)
66+
id: str | None = Field(None)
67+
role: str | None = Field(None)
68+
status: str | None = Field(None)
69+
content: list[BytePlusOutputContent] | None = Field(None)
70+
71+
72+
class BytePlusInputTokensDetails(BaseModel):
73+
cached_tokens: int | None = Field(None)
74+
75+
76+
class BytePlusOutputTokensDetails(BaseModel):
77+
reasoning_tokens: int | None = Field(None)
78+
79+
80+
class BytePlusResponseUsage(BaseModel):
81+
input_tokens: int | None = Field(None)
82+
output_tokens: int | None = Field(None)
83+
total_tokens: int | None = Field(None)
84+
input_tokens_details: BytePlusInputTokensDetails | None = Field(None)
85+
output_tokens_details: BytePlusOutputTokensDetails | None = Field(None)
86+
87+
88+
class BytePlusResponseError(BaseModel):
89+
code: str = Field(...)
90+
message: str = Field(...)
91+
92+
93+
class BytePlusResponseObject(BaseModel):
94+
id: str | None = Field(None)
95+
object: str | None = Field(None)
96+
created_at: int | None = Field(None)
97+
model: str | None = Field(None)
98+
status: str | None = Field(None)
99+
error: BytePlusResponseError | None = Field(None)
100+
output: list[BytePlusOutputMessage] | None = Field(None)
101+
usage: BytePlusResponseUsage | None = Field(None)
Lines changed: 271 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,271 @@
1+
"""API Nodes for ByteDance Seed LLM via the BytePlus ModelArk Responses API.
2+
3+
See: https://docs.byteplus.com/en/docs/ModelArk/1585128
4+
"""
5+
6+
from typing_extensions import override
7+
8+
from comfy_api.latest import IO, ComfyExtension, Input
9+
from comfy_api_nodes.apis.bytedance_llm import (
10+
BytePlusInputImage,
11+
BytePlusInputMessage,
12+
BytePlusInputText,
13+
BytePlusInputVideo,
14+
BytePlusMessageContent,
15+
BytePlusResponseCreateRequest,
16+
BytePlusResponseObject,
17+
)
18+
from comfy_api_nodes.util import (
19+
ApiEndpoint,
20+
get_number_of_images,
21+
sync_op,
22+
upload_images_to_comfyapi,
23+
upload_video_to_comfyapi,
24+
validate_string,
25+
)
26+
27+
BYTEPLUS_RESPONSES_ENDPOINT = "/proxy/byteplus/api/v3/responses"
28+
SEED_MAX_IMAGES = 20
29+
SEED_MAX_VIDEOS = 4
30+
31+
SEED_MODELS: dict[str, str] = {
32+
"Seed 2.0 Pro": "seed-2-0-pro-260328",
33+
"Seed 2.0 Lite": "seed-2-0-lite-260228",
34+
"Seed 2.0 Mini": "seed-2-0-mini-260215",
35+
}
36+
37+
# USD per 1M tokens: (input, cache_hit_input, output)
38+
_SEED_PRICES_PER_MILLION: dict[str, tuple[float, float, float]] = {
39+
"seed-2-0-pro-260328": (0.50, 0.10, 3.00),
40+
"seed-2-0-lite-260228": (0.25, 0.05, 2.00),
41+
"seed-2-0-mini-260215": (0.10, 0.02, 0.40),
42+
}
43+
44+
45+
def _seed_model_inputs(max_images: int = SEED_MAX_IMAGES, max_videos: int = SEED_MAX_VIDEOS):
46+
return [
47+
IO.Autogrow.Input(
48+
"images",
49+
template=IO.Autogrow.TemplateNames(
50+
IO.Image.Input("image"),
51+
names=[f"image_{i}" for i in range(1, max_images + 1)],
52+
min=0,
53+
),
54+
tooltip=f"Optional image(s) to use as context for the model. Up to {max_images} images.",
55+
),
56+
IO.Autogrow.Input(
57+
"videos",
58+
template=IO.Autogrow.TemplateNames(
59+
IO.Video.Input("video"),
60+
names=[f"video_{i}" for i in range(1, max_videos + 1)],
61+
min=0,
62+
),
63+
tooltip=f"Optional video(s) to use as context for the model. Up to {max_videos} videos.",
64+
),
65+
IO.Float.Input(
66+
"temperature",
67+
default=1.0,
68+
min=0.0,
69+
max=2.0,
70+
step=0.01,
71+
tooltip="Controls randomness. 0.0 is deterministic, higher values are more random.",
72+
advanced=True,
73+
),
74+
]
75+
76+
77+
def _calculate_price(model_id: str, response: BytePlusResponseObject) -> float | None:
78+
"""Compute approximate USD price from response usage."""
79+
if not response.usage:
80+
return None
81+
rates = _SEED_PRICES_PER_MILLION.get(model_id)
82+
if rates is None:
83+
return None
84+
input_rate, cache_hit_rate, output_rate = rates
85+
input_tokens = response.usage.input_tokens or 0
86+
output_tokens = response.usage.output_tokens or 0
87+
cached = 0
88+
if response.usage.input_tokens_details:
89+
cached = response.usage.input_tokens_details.cached_tokens or 0
90+
fresh_input = max(0, input_tokens - cached)
91+
total = fresh_input * input_rate + cached * cache_hit_rate + output_tokens * output_rate
92+
return total / 1_000_000.0
93+
94+
95+
def _get_text_from_response(response: BytePlusResponseObject) -> str:
96+
"""Extract concatenated text from all assistant message output_text blocks."""
97+
if not response.output:
98+
return ""
99+
chunks: list[str] = []
100+
for item in response.output:
101+
if item.type != "message" or not item.content:
102+
continue
103+
for block in item.content:
104+
if block.type == "output_text" and block.text:
105+
chunks.append(block.text)
106+
elif block.type == "refusal" and block.refusal:
107+
raise ValueError(f"Model refused to respond: {block.refusal}")
108+
return "\n".join(chunks)
109+
110+
111+
async def _build_image_content_blocks(
112+
cls: type[IO.ComfyNode],
113+
image_tensors: list[Input.Image],
114+
) -> list[BytePlusInputImage]:
115+
urls = await upload_images_to_comfyapi(
116+
cls,
117+
image_tensors,
118+
max_images=SEED_MAX_IMAGES,
119+
wait_label="Uploading reference images",
120+
)
121+
return [BytePlusInputImage(image_url=url) for url in urls]
122+
123+
124+
async def _build_video_content_blocks(
125+
cls: type[IO.ComfyNode],
126+
videos: list[Input.Video],
127+
) -> list[BytePlusInputVideo]:
128+
blocks: list[BytePlusInputVideo] = []
129+
total = len(videos)
130+
for idx, video in enumerate(videos):
131+
label = "Uploading reference video"
132+
if total > 1:
133+
label = f"{label} ({idx + 1}/{total})"
134+
url = await upload_video_to_comfyapi(cls, video, wait_label=label)
135+
blocks.append(BytePlusInputVideo(video_url=url))
136+
return blocks
137+
138+
139+
class ByteDanceSeedNode(IO.ComfyNode):
140+
"""Generate text responses from a ByteDance Seed 2.0 model."""
141+
142+
@classmethod
143+
def define_schema(cls):
144+
return IO.Schema(
145+
node_id="ByteDanceSeedNode",
146+
display_name="ByteDance Seed",
147+
category="api node/text/ByteDance",
148+
essentials_category="Text Generation",
149+
description="Generate text responses with ByteDance's Seed 2.0 models. "
150+
"Provide a text prompt and optionally one or more images or videos for multimodal context.",
151+
inputs=[
152+
IO.String.Input(
153+
"prompt",
154+
multiline=True,
155+
default="",
156+
tooltip="Text input to the model.",
157+
),
158+
IO.DynamicCombo.Input(
159+
"model",
160+
options=[IO.DynamicCombo.Option(label, _seed_model_inputs()) for label in SEED_MODELS],
161+
tooltip="The Seed model used to generate the response.",
162+
),
163+
IO.Int.Input(
164+
"seed",
165+
default=0,
166+
min=0,
167+
max=2147483647,
168+
control_after_generate=True,
169+
tooltip="Seed controls whether the node should re-run; "
170+
"results are non-deterministic regardless of seed.",
171+
),
172+
IO.String.Input(
173+
"system_prompt",
174+
multiline=True,
175+
default="",
176+
optional=True,
177+
advanced=True,
178+
tooltip="Foundational instructions that dictate the model's behavior.",
179+
),
180+
],
181+
outputs=[IO.String.Output()],
182+
hidden=[
183+
IO.Hidden.auth_token_comfy_org,
184+
IO.Hidden.api_key_comfy_org,
185+
IO.Hidden.unique_id,
186+
],
187+
is_api_node=True,
188+
price_badge=IO.PriceBadge(
189+
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
190+
expr="""
191+
(
192+
$m := widgets.model;
193+
$contains($m, "mini") ? {
194+
"type": "list_usd",
195+
"usd": [0.00025, 0.0009],
196+
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
197+
}
198+
: $contains($m, "lite") ? {
199+
"type": "list_usd",
200+
"usd": [0.0003, 0.002],
201+
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
202+
}
203+
: $contains($m, "pro") ? {
204+
"type": "list_usd",
205+
"usd": [0.0005, 0.003],
206+
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
207+
}
208+
: {"type":"text", "text":"Token-based"}
209+
)
210+
""",
211+
),
212+
)
213+
214+
@classmethod
215+
async def execute(
216+
cls,
217+
prompt: str,
218+
model: dict,
219+
seed: int,
220+
system_prompt: str = "",
221+
) -> IO.NodeOutput:
222+
validate_string(prompt, strip_whitespace=True, min_length=1)
223+
model_label = model["model"]
224+
temperature = model["temperature"]
225+
model_id = SEED_MODELS[model_label]
226+
227+
image_tensors: list[Input.Image] = [t for t in (model.get("images") or {}).values() if t is not None]
228+
if sum(get_number_of_images(t) for t in image_tensors) > SEED_MAX_IMAGES:
229+
raise ValueError(f"Up to {SEED_MAX_IMAGES} images are supported per request.")
230+
231+
video_inputs: list[Input.Video] = [v for v in (model.get("videos") or {}).values() if v is not None]
232+
if len(video_inputs) > SEED_MAX_VIDEOS:
233+
raise ValueError(f"Up to {SEED_MAX_VIDEOS} videos are supported per request.")
234+
235+
content: list[BytePlusMessageContent] = []
236+
if image_tensors:
237+
content.extend(await _build_image_content_blocks(cls, image_tensors))
238+
if video_inputs:
239+
content.extend(await _build_video_content_blocks(cls, video_inputs))
240+
content.append(BytePlusInputText(text=prompt))
241+
242+
response = await sync_op(
243+
cls,
244+
ApiEndpoint(path=BYTEPLUS_RESPONSES_ENDPOINT, method="POST"),
245+
response_model=BytePlusResponseObject,
246+
data=BytePlusResponseCreateRequest(
247+
model=model_id,
248+
input=[BytePlusInputMessage(role="user", content=content)],
249+
instructions=system_prompt or None,
250+
temperature=temperature,
251+
store=False,
252+
stream=False,
253+
),
254+
price_extractor=lambda r: _calculate_price(model_id, r),
255+
)
256+
if response.error:
257+
raise ValueError(f"Seed API error ({response.error.code}): {response.error.message}")
258+
result = _get_text_from_response(response)
259+
if not result:
260+
raise ValueError("Empty response from Seed model.")
261+
return IO.NodeOutput(result)
262+
263+
264+
class ByteDanceLLMExtension(ComfyExtension):
265+
@override
266+
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
267+
return [ByteDanceSeedNode]
268+
269+
270+
async def comfy_entrypoint() -> ByteDanceLLMExtension:
271+
return ByteDanceLLMExtension()

0 commit comments

Comments
 (0)