Skip to content

Commit a500f1e

Browse files
authored
CORE-13 feat: Support RT-DETRv4 detection model (Comfy-Org#12748)
1 parent 3f77450 commit a500f1e

7 files changed

Lines changed: 922 additions & 3 deletions

File tree

comfy/ldm/rt_detr/rtdetr_v4.py

Lines changed: 725 additions & 0 deletions
Large diffs are not rendered by default.

comfy/model_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
import comfy.ldm.kandinsky5.model
5353
import comfy.ldm.anima.model
5454
import comfy.ldm.ace.ace_step15
55+
import comfy.ldm.rt_detr.rtdetr_v4
5556

5657
import comfy.model_management
5758
import comfy.patcher_extension
@@ -1957,3 +1958,7 @@ def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
19571958

19581959
def concat_cond(self, **kwargs):
19591960
return None
1961+
1962+
class RT_DETR_v4(BaseModel):
1963+
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
1964+
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)

comfy/model_detection.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -698,6 +698,12 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
698698
dit_config["audio_model"] = "ace1.5"
699699
return dit_config
700700

701+
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4
702+
dit_config = {}
703+
dit_config["image_model"] = "RT_DETR_v4"
704+
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
705+
return dit_config
706+
701707
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
702708
return None
703709

comfy/supported_models.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1734,6 +1734,21 @@ def clip_target(self, state_dict={}):
17341734
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
17351735
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
17361736

1737-
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
1737+
1738+
class RT_DETR_v4(supported_models_base.BASE):
1739+
unet_config = {
1740+
"image_model": "RT_DETR_v4",
1741+
}
1742+
1743+
supported_inference_dtypes = [torch.float16, torch.float32]
1744+
1745+
def get_model(self, state_dict, prefix="", device=None):
1746+
out = model_base.RT_DETR_v4(self, device=device)
1747+
return out
1748+
1749+
def clip_target(self, state_dict={}):
1750+
return None
1751+
1752+
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4]
17381753

17391754
models += [SVD_img2vid]

comfy_extras/nodes_rtdetr.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
from typing_extensions import override
2+
3+
import torch
4+
from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES
5+
import comfy.model_management
6+
import comfy.utils
7+
from comfy_api.latest import ComfyExtension, io
8+
from torchvision.transforms import ToPILImage, ToTensor
9+
from PIL import ImageDraw, ImageFont
10+
11+
12+
class RTDETR_detect(io.ComfyNode):
13+
@classmethod
14+
def define_schema(cls):
15+
return io.Schema(
16+
node_id="RTDETR_detect",
17+
display_name="RT-DETR Detect",
18+
category="detection/",
19+
search_aliases=["bbox", "bounding box", "object detection", "coco"],
20+
inputs=[
21+
io.Model.Input("model", display_name="model"),
22+
io.Image.Input("image", display_name="image"),
23+
io.Float.Input("threshold", display_name="threshold", default=0.5),
24+
io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."),
25+
io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."),
26+
],
27+
outputs=[
28+
io.BoundingBox.Output("bboxes")],
29+
)
30+
31+
@classmethod
32+
def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
33+
B, H, W, C = image.shape
34+
35+
image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")
36+
37+
comfy.model_management.load_model_gpu(model)
38+
results = model.model.diffusion_model(image_in, (W, H)) # list of B dicts
39+
40+
all_bbox_dicts = []
41+
42+
for det in results:
43+
keep = det['scores'] > threshold
44+
boxes = det['boxes'][keep].cpu()
45+
labels = det['labels'][keep].cpu()
46+
scores = det['scores'][keep].cpu()
47+
48+
bbox_dicts = [
49+
{
50+
"x": float(box[0]),
51+
"y": float(box[1]),
52+
"width": float(box[2] - box[0]),
53+
"height": float(box[3] - box[1]),
54+
"label": COCO_CLASSES[int(label)],
55+
"score": float(score)
56+
}
57+
for box, label, score in zip(boxes, labels, scores)
58+
if class_name == "all" or COCO_CLASSES[int(label)] == class_name
59+
]
60+
bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
61+
all_bbox_dicts.append(bbox_dicts[:max_detections])
62+
63+
return io.NodeOutput(all_bbox_dicts)
64+
65+
66+
class DrawBBoxes(io.ComfyNode):
67+
@classmethod
68+
def define_schema(cls):
69+
return io.Schema(
70+
node_id="DrawBBoxes",
71+
display_name="Draw BBoxes",
72+
category="detection/",
73+
search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
74+
inputs=[
75+
io.Image.Input("image", optional=True),
76+
io.BoundingBox.Input("bboxes", force_input=True),
77+
],
78+
outputs=[
79+
io.Image.Output("out_image"),
80+
],
81+
)
82+
83+
@classmethod
84+
def execute(cls, bboxes, image=None) -> io.NodeOutput:
85+
# Normalise to list[list[dict]], then fit to batch size B.
86+
B = image.shape[0] if image is not None else 1
87+
if isinstance(bboxes, dict):
88+
bboxes = [[bboxes]]
89+
elif not isinstance(bboxes, list) or not bboxes:
90+
bboxes = [[]]
91+
elif isinstance(bboxes[0], dict):
92+
bboxes = [bboxes] # flat list → same detections for every image
93+
94+
if len(bboxes) == 1:
95+
bboxes = bboxes * B
96+
bboxes = (bboxes + [[]] * B)[:B]
97+
98+
if image is None:
99+
B = len(bboxes)
100+
max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640)
101+
max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640)
102+
image = torch.zeros((B, max_h, max_w, 3), dtype=torch.float32)
103+
104+
all_out_images = []
105+
for i in range(B):
106+
detections = bboxes[i]
107+
if detections:
108+
boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections])
109+
labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections]
110+
scores = torch.tensor([d.get("score", 1.0) for d in detections])
111+
else:
112+
boxes = torch.zeros((0, 4))
113+
labels = []
114+
scores = torch.zeros((0,))
115+
116+
pil_image = image[i].movedim(-1, 0)
117+
img = ToPILImage()(pil_image)
118+
if detections:
119+
img = cls.draw_detections(img, boxes, labels, scores)
120+
all_out_images.append(ToTensor()(img).unsqueeze(0).movedim(1, -1))
121+
122+
out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device())
123+
return io.NodeOutput(out_images)
124+
125+
@classmethod
126+
def draw_detections(cls, img, boxes, labels, scores):
127+
draw = ImageDraw.Draw(img)
128+
try:
129+
font = ImageFont.truetype('arial.ttf', 16)
130+
except Exception:
131+
font = ImageFont.load_default()
132+
colors = [(255,0,0),(0,200,0),(0,0,255),(255,165,0),(128,0,128),
133+
(0,255,255),(255,20,147),(100,149,237)]
134+
for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()):
135+
x1, y1, x2, y2 = box.tolist()
136+
color_idx = COCO_CLASSES.index(label) if label is not None else 0
137+
c = colors[color_idx % len(colors)]
138+
draw.rectangle([x1, y1, x2, y2], outline=c, width=3)
139+
if label is not None:
140+
draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font)
141+
return img
142+
143+
144+
class RTDETRExtension(ComfyExtension):
145+
@override
146+
async def get_node_list(self) -> list[type[io.ComfyNode]]:
147+
return [
148+
RTDETR_detect,
149+
DrawBBoxes,
150+
]
151+
152+
153+
async def comfy_entrypoint() -> RTDETRExtension:
154+
return RTDETRExtension()

comfy_extras/nodes_sdpose.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -661,14 +661,15 @@ def define_schema(cls):
661661
io.Int.Input("output_width", default=512, min=64, max=4096, step=8, tooltip="Width each crop is resized to."),
662662
io.Int.Input("output_height", default=512, min=64, max=4096, step=8, tooltip="Height each crop is resized to."),
663663
io.Int.Input("padding", default=0, min=0, max=1024, step=1, tooltip="Extra padding in pixels added on each side of the bbox before cropping."),
664+
io.Combo.Input("keep_aspect", options=["stretch", "pad"], default="stretch", tooltip="Whether to stretch the crop to fit the output size, or pad with black pixels to preserve aspect ratio."),
664665
],
665666
outputs=[
666667
io.Image.Output(tooltip="All crops stacked into a single image batch."),
667668
],
668669
)
669670

670671
@classmethod
671-
def execute(cls, image, bboxes, output_width, output_height, padding) -> io.NodeOutput:
672+
def execute(cls, image, bboxes, output_width, output_height, padding, keep_aspect="stretch") -> io.NodeOutput:
672673
total_frames = image.shape[0]
673674
img_h = image.shape[1]
674675
img_w = image.shape[2]
@@ -716,7 +717,19 @@ def execute(cls, image, bboxes, output_width, output_height, padding) -> io.Node
716717
x1, y1, x2, y2 = fb_x1, fb_y1, fb_x2, fb_y2
717718

718719
crop_chw = frame_chw[:, :, y1:y2, x1:x2] # (1, C, crop_h, crop_w)
719-
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
720+
721+
if keep_aspect == "pad":
722+
crop_h, crop_w = y2 - y1, x2 - x1
723+
scale = min(output_width / crop_w, output_height / crop_h)
724+
scaled_w = int(round(crop_w * scale))
725+
scaled_h = int(round(crop_h * scale))
726+
scaled = comfy.utils.common_upscale(crop_chw, scaled_w, scaled_h, upscale_method="bilinear", crop="disabled")
727+
pad_left = (output_width - scaled_w) // 2
728+
pad_top = (output_height - scaled_h) // 2
729+
resized = torch.zeros(1, num_ch, output_height, output_width, dtype=image.dtype, device=image.device)
730+
resized[:, :, pad_top:pad_top + scaled_h, pad_left:pad_left + scaled_w] = scaled
731+
else: # "stretch"
732+
resized = comfy.utils.common_upscale(crop_chw, output_width, output_height, upscale_method="bilinear", crop="disabled")
720733
crops.append(resized)
721734

722735
if not crops:

nodes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2457,6 +2457,7 @@ async def init_builtin_extra_nodes():
24572457
"nodes_number_convert.py",
24582458
"nodes_painter.py",
24592459
"nodes_curve.py",
2460+
"nodes_rtdetr.py"
24602461
]
24612462

24622463
import_failed = []

0 commit comments

Comments
 (0)