Skip to content

Commit 749d5b4

Browse files
authored
feat: SAM (segment anything) 3.1 support (CORE-34) (Comfy-Org#13408)
1 parent e988df7 commit 749d5b4

9 files changed

Lines changed: 3502 additions & 1 deletion

File tree

comfy/ldm/sam3/detector.py

Lines changed: 596 additions & 0 deletions
Large diffs are not rendered by default.

comfy/ldm/sam3/sam.py

Lines changed: 425 additions & 0 deletions
Large diffs are not rendered by default.

comfy/ldm/sam3/tracker.py

Lines changed: 1785 additions & 0 deletions
Large diffs are not rendered by default.

comfy/model_base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
import comfy.ldm.ace.ace_step15
5555
import comfy.ldm.rt_detr.rtdetr_v4
5656
import comfy.ldm.ernie.model
57+
import comfy.ldm.sam3.detector
5758

5859
import comfy.model_management
5960
import comfy.patcher_extension
@@ -1974,3 +1975,7 @@ def extra_conds(self, **kwargs):
19741975
if cross_attn is not None:
19751976
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
19761977
return out
1978+
1979+
class SAM3(BaseModel):
1980+
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
1981+
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)

comfy/model_detection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,14 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
718718
dit_config["image_model"] = "ernie"
719719
return dit_config
720720

721+
if 'detector.backbone.vision_backbone.trunk.blocks.0.attn.qkv.weight' in state_dict_keys: # SAM3 / SAM3.1
722+
if 'detector.transformer.decoder.query_embed.weight' in state_dict_keys:
723+
dit_config = {}
724+
dit_config["image_model"] = "SAM3"
725+
if 'detector.backbone.vision_backbone.propagation_convs.0.conv_1x1.weight' in state_dict_keys:
726+
dit_config["image_model"] = "SAM31"
727+
return dit_config
728+
721729
if '{}input_blocks.0.0.weight'.format(key_prefix) not in state_dict_keys:
722730
return None
723731

@@ -873,6 +881,10 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
873881
return model_config
874882

875883
def unet_prefix_from_state_dict(state_dict):
884+
# SAM3: detector.* and tracker.* at top level, no common prefix
885+
if any(k.startswith("detector.") for k in state_dict) and any(k.startswith("tracker.") for k in state_dict):
886+
return ""
887+
876888
candidates = ["model.diffusion_model.", #ldm/sgm models
877889
"model.model.", #audio models
878890
"net.", #cosmos

comfy/supported_models.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1781,6 +1781,57 @@ def clip_target(self, state_dict={}):
17811781
return supported_models_base.ClipTarget(comfy.text_encoders.ernie.ErnieTokenizer, comfy.text_encoders.ernie.te(**hunyuan_detect))
17821782

17831783

1784-
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage]
1784+
class SAM3(supported_models_base.BASE):
1785+
unet_config = {"image_model": "SAM3"}
1786+
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
1787+
text_encoder_key_prefix = ["detector.backbone.language_backbone."]
1788+
unet_extra_prefix = ""
1789+
1790+
def process_clip_state_dict(self, state_dict):
1791+
clip_keys = getattr(self, "_clip_stash", {})
1792+
clip_keys = utils.state_dict_prefix_replace(clip_keys, {"detector.backbone.language_backbone.": "", "backbone.language_backbone.": ""}, filter_keys=True)
1793+
clip_keys = utils.clip_text_transformers_convert(clip_keys, "encoder.", "sam3_clip.transformer.")
1794+
return {k: v for k, v in clip_keys.items() if not k.startswith("resizer.")}
1795+
1796+
def process_unet_state_dict(self, state_dict):
1797+
self._clip_stash = {k: state_dict.pop(k) for k in list(state_dict.keys()) if "language_backbone" in k and "resizer" not in k}
1798+
# SAM3.1: remap tracker.model.* -> tracker.*
1799+
for k in list(state_dict.keys()):
1800+
if k.startswith("tracker.model."):
1801+
state_dict["tracker." + k[len("tracker.model."):]] = state_dict.pop(k)
1802+
# SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
1803+
for k in [k for k in list(state_dict.keys()) if ".attn.freqs_cis" in k]:
1804+
state_dict.pop(k)
1805+
# Split fused QKV projections
1806+
for k in [k for k in list(state_dict.keys()) if k.endswith((".in_proj_weight", ".in_proj_bias"))]:
1807+
t = state_dict.pop(k)
1808+
base, suffix = k.rsplit(".in_proj_", 1)
1809+
s = ".weight" if suffix == "weight" else ".bias"
1810+
d = t.shape[0] // 3
1811+
state_dict[base + ".q_proj" + s] = t[:d]
1812+
state_dict[base + ".k_proj" + s] = t[d:2*d]
1813+
state_dict[base + ".v_proj" + s] = t[2*d:]
1814+
# Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
1815+
for k in list(state_dict.keys()):
1816+
if "sam_mask_decoder.transformer." not in k:
1817+
continue
1818+
new_k = k.replace(".mlp.lin1.", ".mlp.0.").replace(".mlp.lin2.", ".mlp.2.").replace(".norm_final_attn.", ".norm_final.")
1819+
if new_k != k:
1820+
state_dict[new_k] = state_dict.pop(k)
1821+
return state_dict
1822+
1823+
def get_model(self, state_dict, prefix="", device=None):
1824+
return model_base.SAM3(self, device=device)
1825+
1826+
def clip_target(self, state_dict={}):
1827+
import comfy.text_encoders.sam3_clip
1828+
return supported_models_base.ClipTarget(comfy.text_encoders.sam3_clip.SAM3TokenizerWrapper, comfy.text_encoders.sam3_clip.SAM3ClipModelWrapper)
1829+
1830+
1831+
class SAM31(SAM3):
1832+
unet_config = {"image_model": "SAM31"}
1833+
1834+
1835+
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
17851836

17861837
models += [SVD_img2vid]

comfy/text_encoders/sam3_clip.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
import re
2+
from comfy import sd1_clip
3+
4+
SAM3_CLIP_CONFIG = {
5+
"architectures": ["CLIPTextModel"],
6+
"hidden_act": "quick_gelu",
7+
"hidden_size": 1024,
8+
"intermediate_size": 4096,
9+
"num_attention_heads": 16,
10+
"num_hidden_layers": 24,
11+
"max_position_embeddings": 32,
12+
"projection_dim": 512,
13+
"vocab_size": 49408,
14+
"layer_norm_eps": 1e-5,
15+
"eos_token_id": 49407,
16+
}
17+
18+
19+
class SAM3ClipModel(sd1_clip.SDClipModel):
20+
def __init__(self, device="cpu", dtype=None, model_options={}):
21+
super().__init__(device=device, dtype=dtype, max_length=32, layer="last", textmodel_json_config=SAM3_CLIP_CONFIG, special_tokens={"start": 49406, "end": 49407, "pad": 0}, return_projected_pooled=False, return_attention_masks=True, enable_attention_masks=True, model_options=model_options)
22+
23+
24+
class SAM3Tokenizer(sd1_clip.SDTokenizer):
25+
def __init__(self, embedding_directory=None, tokenizer_data={}):
26+
super().__init__(max_length=32, pad_with_end=False, pad_token=0, embedding_directory=embedding_directory, embedding_size=1024, embedding_key="sam3_clip", tokenizer_data=tokenizer_data)
27+
self.disable_weights = True
28+
29+
30+
def _parse_prompts(text):
31+
"""Split comma-separated prompts with optional :N max detections per category"""
32+
text = text.replace("(", "").replace(")", "")
33+
parts = [p.strip() for p in text.split(",") if p.strip()]
34+
result = []
35+
for part in parts:
36+
m = re.match(r'^(.+?)\s*:\s*([\d.]+)\s*$', part)
37+
if m:
38+
text_part = m.group(1).strip()
39+
val = m.group(2)
40+
max_det = max(1, round(float(val)))
41+
result.append((text_part, max_det))
42+
else:
43+
result.append((part, 1))
44+
return result
45+
46+
47+
class SAM3TokenizerWrapper(sd1_clip.SD1Tokenizer):
48+
def __init__(self, embedding_directory=None, tokenizer_data={}):
49+
super().__init__(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data, clip_name="l", tokenizer=SAM3Tokenizer, name="sam3_clip")
50+
51+
def tokenize_with_weights(self, text: str, return_word_ids=False, **kwargs):
52+
parsed = _parse_prompts(text)
53+
if len(parsed) <= 1 and (not parsed or parsed[0][1] == 1):
54+
return super().tokenize_with_weights(text, return_word_ids, **kwargs)
55+
# Tokenize each prompt part separately, store per-part batches and metadata
56+
inner = getattr(self, self.clip)
57+
per_prompt = []
58+
for prompt_text, max_det in parsed:
59+
batches = inner.tokenize_with_weights(prompt_text, return_word_ids, **kwargs)
60+
per_prompt.append((batches, max_det))
61+
# Main output uses first prompt's tokens (for compatibility)
62+
out = {self.clip_name: per_prompt[0][0], "sam3_per_prompt": per_prompt}
63+
return out
64+
65+
66+
class SAM3ClipModelWrapper(sd1_clip.SD1ClipModel):
67+
def __init__(self, device="cpu", dtype=None, model_options={}, **kwargs):
68+
super().__init__(device=device, dtype=dtype, model_options=model_options, clip_name="l", clip_model=SAM3ClipModel, name="sam3_clip")
69+
70+
def encode_token_weights(self, token_weight_pairs):
71+
per_prompt = token_weight_pairs.pop("sam3_per_prompt", None)
72+
if per_prompt is None:
73+
return super().encode_token_weights(token_weight_pairs)
74+
75+
# Encode each prompt separately, pack into extra dict
76+
inner = getattr(self, self.clip)
77+
multi_cond = []
78+
first_pooled = None
79+
for batches, max_det in per_prompt:
80+
out = inner.encode_token_weights(batches)
81+
cond, pooled = out[0], out[1]
82+
extra = out[2] if len(out) > 2 else {}
83+
if first_pooled is None:
84+
first_pooled = pooled
85+
multi_cond.append({
86+
"cond": cond,
87+
"attention_mask": extra.get("attention_mask"),
88+
"max_detections": max_det,
89+
})
90+
91+
# Return first prompt as main (for non-SAM3 consumers), all prompts in metadata
92+
main = multi_cond[0]
93+
main_extra = {}
94+
if main["attention_mask"] is not None:
95+
main_extra["attention_mask"] = main["attention_mask"]
96+
main_extra["sam3_multi_cond"] = multi_cond
97+
return (main["cond"], first_pooled, main_extra)

0 commit comments

Comments
 (0)