Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions comfy/latent_formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def __init__(self):

self.latent_rgb_factors_bias = [-0.0329, -0.0718, -0.0851]
self.latent_rgb_factors_reshape = lambda t: t.reshape(t.shape[0], 32, 2, 2, t.shape[-2], t.shape[-1]).permute(0, 1, 4, 2, 5, 3).reshape(t.shape[0], 32, t.shape[-2] * 2, t.shape[-1] * 2)
self.taesd_decoder_name = "taef2_decoder"

def process_in(self, latent):
return latent
Expand Down Expand Up @@ -783,3 +784,10 @@ class ZImagePixelSpace(ChromaRadiance):
No VAE encoding/decoding — the model operates directly on RGB pixels.
"""
pass

class CogVideoX(LatentFormat):
latent_channels = 16
latent_dimensions = 3

def __init__(self):
self.scale_factor = 1.15258426
Empty file added comfy/ldm/cogvideo/__init__.py
Empty file.
573 changes: 573 additions & 0 deletions comfy/ldm/cogvideo/model.py

Large diffs are not rendered by default.

566 changes: 566 additions & 0 deletions comfy/ldm/cogvideo/vae.py

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions comfy/model_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
import comfy.ldm.kandinsky5.model
import comfy.ldm.anima.model
import comfy.ldm.ace.ace_step15
import comfy.ldm.cogvideo.model
import comfy.ldm.rt_detr.rtdetr_v4
import comfy.ldm.ernie.model
import comfy.ldm.sam3.detector
Expand Down Expand Up @@ -81,6 +82,7 @@ class ModelType(Enum):
IMG_TO_IMG = 9
FLOW_COSMOS = 10
IMG_TO_IMG_FLOW = 11
V_PREDICTION_DDPM = 12


def model_sampling(model_config, model_type):
Expand Down Expand Up @@ -115,6 +117,8 @@ def model_sampling(model_config, model_type):
s = comfy.model_sampling.ModelSamplingCosmosRFlow
elif model_type == ModelType.IMG_TO_IMG_FLOW:
c = comfy.model_sampling.IMG_TO_IMG_FLOW
elif model_type == ModelType.V_PREDICTION_DDPM:
c = comfy.model_sampling.V_PREDICTION_DDPM

class ModelSampling(s, c):
pass
Expand Down Expand Up @@ -1979,3 +1983,59 @@ def extra_conds(self, **kwargs):
class SAM3(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.sam3.detector.SAM3Model)

class CogVideoX(BaseModel):
def __init__(self, model_config, model_type=ModelType.V_PREDICTION_DDPM, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.cogvideo.model.CogVideoXTransformer3DModel)
self.image_to_video = image_to_video

def concat_cond(self, **kwargs):
noise = kwargs.get("noise", None)
# Detect extra channels needed (e.g. 32 - 16 = 16 for ref latent)
extra_channels = self.diffusion_model.in_channels - noise.shape[1]
if extra_channels == 0:
return None

image = kwargs.get("concat_latent_image", None)
device = kwargs["device"]

if image is None:
shape = list(noise.shape)
shape[1] = extra_channels
return torch.zeros(shape, dtype=noise.dtype, layout=noise.layout, device=noise.device)

latent_dim = self.latent_format.latent_channels
image = utils.common_upscale(image.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")

if noise.ndim == 5 and image.ndim == 5:
if image.shape[-3] < noise.shape[-3]:
image = torch.nn.functional.pad(image, (0, 0, 0, 0, 0, noise.shape[-3] - image.shape[-3]), "constant", 0)
elif image.shape[-3] > noise.shape[-3]:
image = image[:, :, :noise.shape[-3]]

for i in range(0, image.shape[1], latent_dim):
image[:, i:i + latent_dim] = self.process_latent_in(image[:, i:i + latent_dim])
image = utils.resize_to_batch_size(image, noise.shape[0])

if image.shape[1] > extra_channels:
image = image[:, :extra_channels]
elif image.shape[1] < extra_channels:
repeats = extra_channels // image.shape[1]
remainder = extra_channels % image.shape[1]
parts = [image] * repeats
if remainder > 0:
parts.append(image[:, :remainder])
image = torch.cat(parts, dim=1)

return image

def extra_conds(self, **kwargs):
out = super().extra_conds(**kwargs)
# OFS embedding (CogVideoX 1.5 I2V), default 2.0 as used by SparkVSR
if self.diffusion_model.ofs_proj_dim is not None:
ofs = kwargs.get("ofs", None)
if ofs is None:
noise = kwargs.get("noise", None)
ofs = torch.full((noise.shape[0],), 2.0, device=noise.device, dtype=noise.dtype)
out['ofs'] = comfy.conds.CONDRegular(ofs)
return out
48 changes: 48 additions & 0 deletions comfy/model_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,54 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):

return dit_config

if '{}blocks.0.norm1.linear.weight'.format(key_prefix) in state_dict_keys: # CogVideoX
dit_config = {}
dit_config["image_model"] = "cogvideox"

# Extract config from weight shapes
norm1_weight = state_dict['{}blocks.0.norm1.linear.weight'.format(key_prefix)]
time_embed_dim = norm1_weight.shape[1]
dim = norm1_weight.shape[0] // 6

dit_config["num_attention_heads"] = dim // 64
dit_config["attention_head_dim"] = 64
dit_config["time_embed_dim"] = time_embed_dim
dit_config["num_layers"] = count_blocks(state_dict_keys, '{}blocks.'.format(key_prefix) + '{}.')

# Detect in_channels from patch_embed
patch_proj_key = '{}patch_embed.proj.weight'.format(key_prefix)
if patch_proj_key in state_dict_keys:
w = state_dict[patch_proj_key]
if w.ndim == 4:
# Conv2d: [out, in, kh, kw] — CogVideoX 1.0
dit_config["in_channels"] = w.shape[1]
dit_config["patch_size"] = w.shape[2]
elif w.ndim == 2:
# Linear: [out, in_channels * patch_size * patch_size * patch_size_t] — CogVideoX 1.5
dit_config["patch_size"] = 2
dit_config["patch_size_t"] = 2
dit_config["in_channels"] = w.shape[1] // (2 * 2 * 2) # 256 // 8 = 32

text_proj_key = '{}patch_embed.text_proj.weight'.format(key_prefix)
if text_proj_key in state_dict_keys:
dit_config["text_embed_dim"] = state_dict[text_proj_key].shape[1]

# Detect OFS embedding
ofs_key = '{}ofs_embedding_linear_1.weight'.format(key_prefix)
if ofs_key in state_dict_keys:
dit_config["ofs_embed_dim"] = state_dict[ofs_key].shape[1]

# Detect positional embedding type
pos_key = '{}patch_embed.pos_embedding'.format(key_prefix)
if pos_key in state_dict_keys:
dit_config["use_learned_positional_embeddings"] = True
dit_config["use_rotary_positional_embeddings"] = False
else:
dit_config["use_learned_positional_embeddings"] = False
dit_config["use_rotary_positional_embeddings"] = True

return dit_config

if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
dit_config = {}
dit_config["image_model"] = "wan2.1"
Expand Down
24 changes: 24 additions & 0 deletions comfy/model_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,30 @@ def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input * self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) - model_output * sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2) ** 0.5

class V_PREDICTION_DDPM:
"""CogVideoX v-prediction: model receives raw x_t (unscaled), predicts velocity v.
x_0 = sqrt(alpha) * x_t - sqrt(1-alpha) * v
= x_t / sqrt(sigma^2 + 1) - v * sigma / sqrt(sigma^2 + 1)
"""
def calculate_input(self, sigma, noise):
return noise

def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
return model_input / (sigma ** 2 + 1.0) ** 0.5 - model_output * sigma / (sigma ** 2 + 1.0) ** 0.5

def noise_scaling(self, sigma, noise, latent_image, max_denoise=False):
sigma = reshape_sigma(sigma, noise.ndim)
if max_denoise:
noise = noise * torch.sqrt(1.0 + sigma ** 2.0)
else:
noise = noise * sigma
noise += latent_image
return noise

def inverse_noise_scaling(self, sigma, latent):
return latent

class EDM(V_PREDICTION):
def calculate_denoised(self, sigma, model_output, model_input):
sigma = reshape_sigma(sigma, model_output.ndim)
Expand Down
17 changes: 16 additions & 1 deletion comfy/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import comfy.ldm.wan.vae2_2
import comfy.ldm.hunyuan3d.vae
import comfy.ldm.ace.vae.music_dcae_pipeline
import comfy.ldm.cogvideo.vae
import comfy.ldm.hunyuan_video.vae
import comfy.ldm.mmaudio.vae.autoencoder
import comfy.pixel_space_convert
Expand Down Expand Up @@ -478,7 +479,10 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)
encoder_config={'target': "comfy.ldm.modules.diffusionmodules.model.Encoder", 'params': encoder_config},
decoder_config={'target': "comfy.ldm.modules.temporal_ae.VideoDecoder", 'params': decoder_config})
elif "taesd_decoder.1.weight" in sd:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
if isinstance(metadata, dict) and "tae_latent_channels" in metadata:
self.latent_channels = metadata["tae_latent_channels"]
else:
self.latent_channels = sd["taesd_decoder.1.weight"].shape[1]
self.first_stage_model = comfy.taesd.taesd.TAESD(latent_channels=self.latent_channels)
elif "vquantizer.codebook.weight" in sd: #VQGan: stage a of stable cascade
self.first_stage_model = StageA()
Expand Down Expand Up @@ -652,6 +656,17 @@ def __init__(self, sd=None, device=None, config=None, dtype=None, metadata=None)

self.memory_used_encode = lambda shape, dtype: (1400 * 9 * shape[-2] * shape[-1]) * model_management.dtype_size(dtype)
self.memory_used_decode = lambda shape, dtype: (3600 * 4 * shape[-2] * shape[-1] * 16 * 16) * model_management.dtype_size(dtype)
elif "decoder.conv_in.conv.weight" in sd and "decoder.mid_block.resnets.0.norm1.norm_layer.weight" in sd: # CogVideoX VAE
self.upscale_ratio = (lambda a: max(0, a * 4 - 3), 8, 8)
self.upscale_index_formula = (4, 8, 8)
self.downscale_ratio = (lambda a: max(0, math.floor((a + 3) / 4)), 8, 8)
self.downscale_index_formula = (4, 8, 8)
self.latent_dim = 3
self.latent_channels = sd["encoder.conv_out.conv.weight"].shape[0] // 2
self.first_stage_model = comfy.ldm.cogvideo.vae.AutoencoderKLCogVideoX(latent_channels=self.latent_channels)
self.memory_used_decode = lambda shape, dtype: (2800 * max(2, ((shape[2] - 1) * 4) + 1) * shape[3] * shape[4] * (8 * 8)) * model_management.dtype_size(dtype)
self.memory_used_encode = lambda shape, dtype: (1400 * max(1, shape[2]) * shape[3] * shape[4]) * model_management.dtype_size(dtype)
self.working_dtypes = [torch.bfloat16, torch.float16, torch.float32]
elif "decoder.conv_in.conv.weight" in sd:
ddconfig = {'double_z': True, 'z_channels': 4, 'resolution': 256, 'in_channels': 3, 'out_ch': 3, 'ch': 128, 'ch_mult': [1, 2, 4, 4], 'num_res_blocks': 2, 'attn_resolutions': [], 'dropout': 0.0}
ddconfig["conv3d"] = True
Expand Down
49 changes: 48 additions & 1 deletion comfy/supported_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import comfy.text_encoders.ace15
import comfy.text_encoders.longcat_image
import comfy.text_encoders.ernie
import comfy.text_encoders.cogvideo

from . import supported_models_base
from . import latent_formats
Expand Down Expand Up @@ -1832,6 +1833,52 @@ class SAM31(SAM3):
unet_config = {"image_model": "SAM31"}


models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31]
class CogVideoX_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "cogvideox",
}

sampling_settings = {
"linear_start": 0.00085,
"linear_end": 0.012,
"beta_schedule": "linear",
"zsnr": True,
}

unet_extra_config = {}
latent_format = latent_formats.CogVideoX

supported_inference_dtypes = [torch.bfloat16, torch.float16, torch.float32]

vae_key_prefix = ["vae."]
text_encoder_key_prefix = ["text_encoders."]

def get_model(self, state_dict, prefix="", device=None):
# CogVideoX 1.5 (patch_size_t=2) has different training base dimensions for RoPE
if self.unet_config.get("patch_size_t") is not None:
self.unet_config.setdefault("sample_height", 96)
self.unet_config.setdefault("sample_width", 170)
self.unet_config.setdefault("sample_frames", 81)
out = model_base.CogVideoX(self, device=device)
return out

def clip_target(self, state_dict={}):
return supported_models_base.ClipTarget(comfy.text_encoders.cogvideo.CogVideoXT5Tokenizer, comfy.text_encoders.sd3_clip.T5XXLModel)

class CogVideoX_I2V(CogVideoX_T2V):
unet_config = {
"image_model": "cogvideox",
"in_channels": 32,
}

def get_model(self, state_dict, prefix="", device=None):
if self.unet_config.get("patch_size_t") is not None:
self.unet_config.setdefault("sample_height", 96)
self.unet_config.setdefault("sample_width", 170)
self.unet_config.setdefault("sample_frames", 81)
out = model_base.CogVideoX(self, image_to_video=True, device=device)
return out

models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima, RT_DETR_v4, ErnieImage, SAM3, SAM31, CogVideoX_I2V, CogVideoX_T2V]

models += [SVD_img2vid]
Loading
Loading