@@ -1781,6 +1781,57 @@ def clip_target(self, state_dict={}):
17811781 return supported_models_base .ClipTarget (comfy .text_encoders .ernie .ErnieTokenizer , comfy .text_encoders .ernie .te (** hunyuan_detect ))
17821782
17831783
1784- models = [LotusD , Stable_Zero123 , SD15_instructpix2pix , SD15 , SD20 , SD21UnclipL , SD21UnclipH , SDXL_instructpix2pix , SDXLRefiner , SDXL , SSD1B , KOALA_700M , KOALA_1B , Segmind_Vega , SD_X4Upscaler , Stable_Cascade_C , Stable_Cascade_B , SV3D_u , SV3D_p , SD3 , StableAudio , AuraFlow , PixArtAlpha , PixArtSigma , HunyuanDiT , HunyuanDiT1 , FluxInpaint , Flux , LongCatImage , FluxSchnell , GenmoMochi , LTXV , LTXAV , HunyuanVideo15_SR_Distilled , HunyuanVideo15 , HunyuanImage21Refiner , HunyuanImage21 , HunyuanVideoSkyreelsI2V , HunyuanVideoI2V , HunyuanVideo , CosmosT2V , CosmosI2V , CosmosT2IPredict2 , CosmosI2VPredict2 , ZImagePixelSpace , ZImage , Lumina2 , WAN22_T2V , WAN21_T2V , WAN21_I2V , WAN21_FunControl2V , WAN21_Vace , WAN21_Camera , WAN22_Camera , WAN22_S2V , WAN21_HuMo , WAN22_Animate , WAN21_FlowRVS , WAN21_SCAIL , Hunyuan3Dv2mini , Hunyuan3Dv2 , Hunyuan3Dv2_1 , HiDream , Chroma , ChromaRadiance , ACEStep , ACEStep15 , Omnigen2 , QwenImage , Flux2 , Kandinsky5Image , Kandinsky5 , Anima , RT_DETR_v4 , ErnieImage ]
1784+ class SAM3 (supported_models_base .BASE ):
1785+ unet_config = {"image_model" : "SAM3" }
1786+ supported_inference_dtypes = [torch .float16 , torch .bfloat16 , torch .float32 ]
1787+ text_encoder_key_prefix = ["detector.backbone.language_backbone." ]
1788+ unet_extra_prefix = ""
1789+
1790+ def process_clip_state_dict (self , state_dict ):
1791+ clip_keys = getattr (self , "_clip_stash" , {})
1792+ clip_keys = utils .state_dict_prefix_replace (clip_keys , {"detector.backbone.language_backbone." : "" , "backbone.language_backbone." : "" }, filter_keys = True )
1793+ clip_keys = utils .clip_text_transformers_convert (clip_keys , "encoder." , "sam3_clip.transformer." )
1794+ return {k : v for k , v in clip_keys .items () if not k .startswith ("resizer." )}
1795+
1796+ def process_unet_state_dict (self , state_dict ):
1797+ self ._clip_stash = {k : state_dict .pop (k ) for k in list (state_dict .keys ()) if "language_backbone" in k and "resizer" not in k }
1798+ # SAM3.1: remap tracker.model.* -> tracker.*
1799+ for k in list (state_dict .keys ()):
1800+ if k .startswith ("tracker.model." ):
1801+ state_dict ["tracker." + k [len ("tracker.model." ):]] = state_dict .pop (k )
1802+ # SAM3.1: remove per-block freqs_cis buffers (computed dynamically)
1803+ for k in [k for k in list (state_dict .keys ()) if ".attn.freqs_cis" in k ]:
1804+ state_dict .pop (k )
1805+ # Split fused QKV projections
1806+ for k in [k for k in list (state_dict .keys ()) if k .endswith ((".in_proj_weight" , ".in_proj_bias" ))]:
1807+ t = state_dict .pop (k )
1808+ base , suffix = k .rsplit (".in_proj_" , 1 )
1809+ s = ".weight" if suffix == "weight" else ".bias"
1810+ d = t .shape [0 ] // 3
1811+ state_dict [base + ".q_proj" + s ] = t [:d ]
1812+ state_dict [base + ".k_proj" + s ] = t [d :2 * d ]
1813+ state_dict [base + ".v_proj" + s ] = t [2 * d :]
1814+ # Remap tracker SAM decoder transformer key names to match sam.py TwoWayTransformer
1815+ for k in list (state_dict .keys ()):
1816+ if "sam_mask_decoder.transformer." not in k :
1817+ continue
1818+ new_k = k .replace (".mlp.lin1." , ".mlp.0." ).replace (".mlp.lin2." , ".mlp.2." ).replace (".norm_final_attn." , ".norm_final." )
1819+ if new_k != k :
1820+ state_dict [new_k ] = state_dict .pop (k )
1821+ return state_dict
1822+
1823+ def get_model (self , state_dict , prefix = "" , device = None ):
1824+ return model_base .SAM3 (self , device = device )
1825+
1826+ def clip_target (self , state_dict = {}):
1827+ import comfy .text_encoders .sam3_clip
1828+ return supported_models_base .ClipTarget (comfy .text_encoders .sam3_clip .SAM3TokenizerWrapper , comfy .text_encoders .sam3_clip .SAM3ClipModelWrapper )
1829+
1830+
1831+ class SAM31 (SAM3 ):
1832+ unet_config = {"image_model" : "SAM31" }
1833+
1834+
1835+ models = [LotusD , Stable_Zero123 , SD15_instructpix2pix , SD15 , SD20 , SD21UnclipL , SD21UnclipH , SDXL_instructpix2pix , SDXLRefiner , SDXL , SSD1B , KOALA_700M , KOALA_1B , Segmind_Vega , SD_X4Upscaler , Stable_Cascade_C , Stable_Cascade_B , SV3D_u , SV3D_p , SD3 , StableAudio , AuraFlow , PixArtAlpha , PixArtSigma , HunyuanDiT , HunyuanDiT1 , FluxInpaint , Flux , LongCatImage , FluxSchnell , GenmoMochi , LTXV , LTXAV , HunyuanVideo15_SR_Distilled , HunyuanVideo15 , HunyuanImage21Refiner , HunyuanImage21 , HunyuanVideoSkyreelsI2V , HunyuanVideoI2V , HunyuanVideo , CosmosT2V , CosmosI2V , CosmosT2IPredict2 , CosmosI2VPredict2 , ZImagePixelSpace , ZImage , Lumina2 , WAN22_T2V , WAN21_T2V , WAN21_I2V , WAN21_FunControl2V , WAN21_Vace , WAN21_Camera , WAN22_Camera , WAN22_S2V , WAN21_HuMo , WAN22_Animate , WAN21_FlowRVS , WAN21_SCAIL , Hunyuan3Dv2mini , Hunyuan3Dv2 , Hunyuan3Dv2_1 , HiDream , Chroma , ChromaRadiance , ACEStep , ACEStep15 , Omnigen2 , QwenImage , Flux2 , Kandinsky5Image , Kandinsky5 , Anima , RT_DETR_v4 , ErnieImage , SAM3 , SAM31 ]
17851836
17861837models += [SVD_img2vid ]
0 commit comments