|
| 1 | +from typing_extensions import override |
| 2 | + |
| 3 | +import torch |
| 4 | +from comfy.ldm.rt_detr.rtdetr_v4 import COCO_CLASSES |
| 5 | +import comfy.model_management |
| 6 | +import comfy.utils |
| 7 | +from comfy_api.latest import ComfyExtension, io |
| 8 | +from torchvision.transforms import ToPILImage, ToTensor |
| 9 | +from PIL import ImageDraw, ImageFont |
| 10 | + |
| 11 | + |
class RTDETR_detect(io.ComfyNode):
    """Node that runs an RT-DETR model on an image batch and emits bounding boxes."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, its inputs and its single bounding-box output."""
        return io.Schema(
            node_id="RTDETR_detect",
            display_name="RT-DETR Detect",
            category="detection/",
            search_aliases=["bbox", "bounding box", "object detection", "coco"],
            inputs=[
                io.Model.Input("model", display_name="model"),
                io.Image.Input("image", display_name="image"),
                io.Float.Input("threshold", display_name="threshold", default=0.5),
                io.Combo.Input("class_name", options=["all"] + COCO_CLASSES, default="all", tooltip="Filter detections by class. Set to 'all' to disable filtering."),
                io.Int.Input("max_detections", display_name="max_detections", default=100, tooltip="Maximum number of detections to return per image. In order of descending confidence score."),
            ],
            outputs=[
                io.BoundingBox.Output("bboxes")],
        )

    @classmethod
    def execute(cls, model, image, threshold, class_name, max_detections) -> io.NodeOutput:
        """Detect objects in every image of the batch.

        Args:
            model: Model wrapper; ``model.model.diffusion_model`` is called
                with the resized batch and the original ``(W, H)`` size.
            image: Tensor unpacked as ``(B, H, W, C)``.
            threshold: Detections whose score is not strictly greater than
                this value are discarded.
            class_name: COCO class name to keep, or ``"all"`` to keep
                every class.
            max_detections: Upper bound on detections kept per image.
                Negative values are treated as ``0``.

        Returns:
            ``io.NodeOutput`` wrapping one list per image; each list holds
            dicts with ``x``, ``y``, ``width``, ``height``, ``label`` and
            ``score`` keys, sorted by descending score.
        """
        _, H, W, _ = image.shape

        # Channels-first copy resized to the fixed 640x640 network input.
        image_in = comfy.utils.common_upscale(image.movedim(-1, 1), 640, 640, "bilinear", crop="disabled")

        comfy.model_management.load_model_gpu(model)
        # NOTE(review): (W, H) is presumably used by the model to scale boxes
        # back to the original image size -- confirm against the model code.
        results = model.model.diffusion_model(image_in, (W, H))  # list of B dicts

        # A negative limit would turn the slice below into "drop the last N"
        # rather than "keep at most N", so clamp it to zero.
        limit = max(0, max_detections)
        wanted = None if class_name == "all" else class_name

        all_bbox_dicts = []
        for det in results:
            keep = det['scores'] > threshold
            boxes = det['boxes'][keep].cpu()
            labels = det['labels'][keep].cpu()
            scores = det['scores'][keep].cpu()

            bbox_dicts = []
            for box, label, score in zip(boxes, labels, scores):
                # Resolve the class name once; it is used for both the
                # filter and the output dict.
                label_name = COCO_CLASSES[int(label)]
                if wanted is not None and label_name != wanted:
                    continue
                # Boxes are read as (x1, y1, x2, y2) and emitted as
                # top-left corner plus width/height.
                bbox_dicts.append({
                    "x": float(box[0]),
                    "y": float(box[1]),
                    "width": float(box[2] - box[0]),
                    "height": float(box[3] - box[1]),
                    "label": label_name,
                    "score": float(score),
                })

            bbox_dicts.sort(key=lambda d: d["score"], reverse=True)
            all_bbox_dicts.append(bbox_dicts[:limit])

        return io.NodeOutput(all_bbox_dicts)
| 64 | + |
| 65 | + |
class DrawBBoxes(io.ComfyNode):
    """Node that renders bounding boxes onto an image batch or a blank canvas."""

    @classmethod
    def define_schema(cls):
        """Declare the node id, its optional image input, bbox input and image output."""
        return io.Schema(
            node_id="DrawBBoxes",
            display_name="Draw BBoxes",
            category="detection/",
            search_aliases=["bbox", "bounding box", "object detection", "rt_detr", "visualize detections", "coco"],
            inputs=[
                io.Image.Input("image", optional=True),
                io.BoundingBox.Input("bboxes", force_input=True),
            ],
            outputs=[
                io.Image.Output("out_image"),
            ],
        )

    @classmethod
    def execute(cls, bboxes, image=None) -> io.NodeOutput:
        """Draw ``bboxes`` on ``image`` and return the annotated batch.

        Args:
            bboxes: A single bbox dict, a flat list of bbox dicts, or a list
                of per-frame lists of bbox dicts. Each dict is read for
                ``x``, ``y``, ``width``, ``height`` and optionally ``label``
                and ``score``.
            image: Optional tensor indexed as ``(B, H, W, C)``. When omitted,
                a black canvas with one frame per bbox list is created,
                sized to the largest box extent (640 when there are no boxes).

        Returns:
            ``io.NodeOutput`` wrapping the annotated images concatenated
            along the batch dimension.
        """
        # Normalise to list[list[dict]].
        if isinstance(bboxes, dict):
            bboxes = [[bboxes]]
        elif not isinstance(bboxes, list) or not bboxes:
            bboxes = [[]]
        elif isinstance(bboxes[0], dict):
            bboxes = [bboxes]  # flat list -> same detections for every image

        if image is None:
            # Without an image the bbox frames define the batch size. Take it
            # before any fitting so multi-frame input is not cut to one frame.
            B = len(bboxes)
            max_w = max((int(d["x"] + d["width"]) for frame in bboxes for d in frame), default=640)
            max_h = max((int(d["y"] + d["height"]) for frame in bboxes for d in frame), default=640)
            # Degenerate boxes could yield a zero or negative extent, which is
            # not a valid canvas size.
            image = torch.zeros((B, max(max_h, 1), max(max_w, 1), 3), dtype=torch.float32)
        else:
            # Fit the bbox frames to the image batch: broadcast a single
            # frame, pad missing frames with no detections, drop extras.
            B = image.shape[0]
            if len(bboxes) == 1:
                bboxes = bboxes * B
            bboxes = (bboxes + [[]] * B)[:B]

        to_pil = ToPILImage()
        to_tensor = ToTensor()

        all_out_images = []
        for i in range(B):
            detections = bboxes[i]
            # ToPILImage is given channels-first data, hence the movedim.
            img = to_pil(image[i].movedim(-1, 0))
            if detections:
                boxes = torch.tensor([[d["x"], d["y"], d["x"] + d["width"], d["y"] + d["height"]] for d in detections])
                # Labels outside COCO_CLASSES become None: the box is still
                # drawn, but without a caption.
                labels = [d.get("label") if d.get("label") in COCO_CLASSES else None for d in detections]
                scores = torch.tensor([d.get("score", 1.0) for d in detections])
                img = cls.draw_detections(img, boxes, labels, scores)
            all_out_images.append(to_tensor(img).unsqueeze(0).movedim(1, -1))

        out_images = torch.cat(all_out_images, dim=0).to(comfy.model_management.intermediate_device())
        return io.NodeOutput(out_images)

    @classmethod
    def draw_detections(cls, img, boxes, labels, scores):
        """Draw rectangles and ``"label score"`` captions onto ``img`` in place.

        Args:
            img: PIL image to draw on.
            boxes: Iterable of tensors, each holding ``(x1, y1, x2, y2)``.
            labels: Iterable of COCO class names, or ``None`` for no caption.
            scores: Iterable of scalar tensors, one per box.

        Returns:
            The same ``img`` object, after drawing.
        """
        draw = ImageDraw.Draw(img)
        try:
            font = ImageFont.truetype('arial.ttf', 16)
        except Exception:
            # Best effort: fall back to Pillow's built-in font when Arial
            # cannot be loaded.
            font = ImageFont.load_default()
        colors = [(255,0,0),(0,200,0),(0,0,255),(255,165,0),(128,0,128),
                  (0,255,255),(255,20,147),(100,149,237)]
        # Ascending score order, so higher-scoring boxes are drawn last.
        for box, label, score in sorted(zip(boxes, labels, scores), key=lambda x: x[2].item()):
            x1, y1, x2, y2 = box.tolist()
            # Order the corners: a negative width/height would otherwise put
            # the second corner before the first, which recent Pillow
            # versions reject with ValueError.
            x1, x2 = min(x1, x2), max(x1, x2)
            y1, y2 = min(y1, y2), max(y1, y2)
            color_idx = COCO_CLASSES.index(label) if label is not None else 0
            c = colors[color_idx % len(colors)]
            draw.rectangle([x1, y1, x2, y2], outline=c, width=3)
            if label is not None:
                draw.text((x1 + 2, y1 + 2), f'{label} {score:.2f}', fill=c, font=font)
        return img
| 142 | + |
| 143 | + |
class RTDETRExtension(ComfyExtension):
    """Extension that exposes the RT-DETR detection and bbox-drawing nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes provided by this extension."""
        node_classes: list[type[io.ComfyNode]] = [RTDETR_detect, DrawBBoxes]
        return node_classes
| 151 | + |
| 152 | + |
async def comfy_entrypoint() -> RTDETRExtension:
    """Create and return a new ``RTDETRExtension`` instance."""
    extension = RTDETRExtension()
    return extension
0 commit comments