Skip to content

Commit 40862c0

Browse files
Support Ace Step 1.5 XL model. (Comfy-Org#13317)
1 parent 50076f3 commit 40862c0

2 files changed

Lines changed: 20 additions & 7 deletions

File tree

comfy/ldm/ace/ace_step15.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -611,6 +611,7 @@ def __init__(
611611
intermediate_size,
612612
patch_size,
613613
audio_acoustic_hidden_dim,
614+
condition_dim=None,
614615
layer_types=None,
615616
sliding_window=128,
616617
rms_norm_eps=1e-6,
@@ -640,7 +641,7 @@ def __init__(
640641

641642
self.time_embed = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
642643
self.time_embed_r = TimestepEmbedding(256, hidden_size, dtype=dtype, device=device, operations=operations)
643-
self.condition_embedder = Linear(hidden_size, hidden_size, dtype=dtype, device=device)
644+
self.condition_embedder = Linear(condition_dim, hidden_size, dtype=dtype, device=device)
644645

645646
if layer_types is None:
646647
layer_types = ["full_attention"] * num_layers
@@ -1035,6 +1036,9 @@ def __init__(
10351036
fsq_dim=2048,
10361037
fsq_levels=[8, 8, 8, 5, 5, 5],
10371038
fsq_input_num_quantizers=1,
1039+
encoder_hidden_size=2048,
1040+
encoder_intermediate_size=6144,
1041+
encoder_num_heads=16,
10381042
audio_model=None,
10391043
dtype=None,
10401044
device=None,
@@ -1054,24 +1058,24 @@ def __init__(
10541058

10551059
self.decoder = AceStepDiTModel(
10561060
in_channels, hidden_size, num_dit_layers, num_heads, num_kv_heads, head_dim,
1057-
intermediate_size, patch_size, audio_acoustic_hidden_dim,
1061+
intermediate_size, patch_size, audio_acoustic_hidden_dim, condition_dim=encoder_hidden_size,
10581062
layer_types=layer_types, sliding_window=sliding_window, rms_norm_eps=rms_norm_eps,
10591063
dtype=dtype, device=device, operations=operations
10601064
)
10611065
self.encoder = AceStepConditionEncoder(
1062-
text_hidden_dim, timbre_hidden_dim, hidden_size, num_lyric_layers, num_timbre_layers,
1063-
num_heads, num_kv_heads, head_dim, intermediate_size, rms_norm_eps,
1066+
text_hidden_dim, timbre_hidden_dim, encoder_hidden_size, num_lyric_layers, num_timbre_layers,
1067+
encoder_num_heads, num_kv_heads, head_dim, encoder_intermediate_size, rms_norm_eps,
10641068
dtype=dtype, device=device, operations=operations
10651069
)
10661070
self.tokenizer = AceStepAudioTokenizer(
1067-
audio_acoustic_hidden_dim, hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
1071+
audio_acoustic_hidden_dim, encoder_hidden_size, pool_window_size, fsq_dim=fsq_dim, fsq_levels=fsq_levels, fsq_input_num_quantizers=fsq_input_num_quantizers, num_layers=num_tokenizer_layers, head_dim=head_dim, rms_norm_eps=rms_norm_eps,
10681072
dtype=dtype, device=device, operations=operations
10691073
)
10701074
self.detokenizer = AudioTokenDetokenizer(
1071-
hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
1075+
encoder_hidden_size, pool_window_size, audio_acoustic_hidden_dim, num_layers=2, head_dim=head_dim,
10721076
dtype=dtype, device=device, operations=operations
10731077
)
1074-
self.null_condition_emb = nn.Parameter(torch.empty(1, 1, hidden_size, dtype=dtype, device=device))
1078+
self.null_condition_emb = nn.Parameter(torch.empty(1, 1, encoder_hidden_size, dtype=dtype, device=device))
10751079

10761080
def prepare_condition(
10771081
self,

comfy/model_detection.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,15 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
696696
if '{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix) in state_dict_keys:
697697
dit_config = {}
698698
dit_config["audio_model"] = "ace1.5"
699+
head_dim = 128
700+
dit_config["hidden_size"] = state_dict['{}decoder.layers.0.self_attn_norm.weight'.format(key_prefix)].shape[0]
701+
dit_config["intermediate_size"] = state_dict['{}decoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0]
702+
dit_config["num_heads"] = state_dict['{}decoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim
703+
704+
dit_config["encoder_hidden_size"] = state_dict['{}encoder.lyric_encoder.layers.0.input_layernorm.weight'.format(key_prefix)].shape[0]
705+
dit_config["encoder_num_heads"] = state_dict['{}encoder.lyric_encoder.layers.0.self_attn.q_proj.weight'.format(key_prefix)].shape[0] // head_dim
706+
dit_config["encoder_intermediate_size"] = state_dict['{}encoder.lyric_encoder.layers.0.mlp.gate_proj.weight'.format(key_prefix)].shape[0]
707+
dit_config["num_dit_layers"] = count_blocks(state_dict_keys, '{}decoder.layers.'.format(key_prefix) + '{}.')
699708
return dit_config
700709

701710
if '{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix) in state_dict_keys: # RT-DETR_v4

0 commit comments

Comments
 (0)