@@ -611,6 +611,7 @@ def __init__(
611611 intermediate_size ,
612612 patch_size ,
613613 audio_acoustic_hidden_dim ,
614+ condition_dim = None ,
614615 layer_types = None ,
615616 sliding_window = 128 ,
616617 rms_norm_eps = 1e-6 ,
@@ -640,7 +641,7 @@ def __init__(
640641
641642 self .time_embed = TimestepEmbedding (256 , hidden_size , dtype = dtype , device = device , operations = operations )
642643 self .time_embed_r = TimestepEmbedding (256 , hidden_size , dtype = dtype , device = device , operations = operations )
643- self .condition_embedder = Linear (hidden_size , hidden_size , dtype = dtype , device = device )
644+ self .condition_embedder = Linear (condition_dim , hidden_size , dtype = dtype , device = device )
644645
645646 if layer_types is None :
646647 layer_types = ["full_attention" ] * num_layers
@@ -1035,6 +1036,9 @@ def __init__(
10351036 fsq_dim = 2048 ,
10361037 fsq_levels = [8 , 8 , 8 , 5 , 5 , 5 ],
10371038 fsq_input_num_quantizers = 1 ,
1039+ encoder_hidden_size = 2048 ,
1040+ encoder_intermediate_size = 6144 ,
1041+ encoder_num_heads = 16 ,
10381042 audio_model = None ,
10391043 dtype = None ,
10401044 device = None ,
@@ -1054,24 +1058,24 @@ def __init__(
10541058
10551059 self .decoder = AceStepDiTModel (
10561060 in_channels , hidden_size , num_dit_layers , num_heads , num_kv_heads , head_dim ,
1057- intermediate_size , patch_size , audio_acoustic_hidden_dim ,
1061+ intermediate_size , patch_size , audio_acoustic_hidden_dim , condition_dim = encoder_hidden_size ,
10581062 layer_types = layer_types , sliding_window = sliding_window , rms_norm_eps = rms_norm_eps ,
10591063 dtype = dtype , device = device , operations = operations
10601064 )
10611065 self .encoder = AceStepConditionEncoder (
1062- text_hidden_dim , timbre_hidden_dim , hidden_size , num_lyric_layers , num_timbre_layers ,
1063- num_heads , num_kv_heads , head_dim , intermediate_size , rms_norm_eps ,
1066+ text_hidden_dim , timbre_hidden_dim , encoder_hidden_size , num_lyric_layers , num_timbre_layers ,
1067+ encoder_num_heads , num_kv_heads , head_dim , encoder_intermediate_size , rms_norm_eps ,
10641068 dtype = dtype , device = device , operations = operations
10651069 )
10661070 self .tokenizer = AceStepAudioTokenizer (
1067- audio_acoustic_hidden_dim , hidden_size , pool_window_size , fsq_dim = fsq_dim , fsq_levels = fsq_levels , fsq_input_num_quantizers = fsq_input_num_quantizers , num_layers = num_tokenizer_layers , head_dim = head_dim , rms_norm_eps = rms_norm_eps ,
1071+ audio_acoustic_hidden_dim , encoder_hidden_size , pool_window_size , fsq_dim = fsq_dim , fsq_levels = fsq_levels , fsq_input_num_quantizers = fsq_input_num_quantizers , num_layers = num_tokenizer_layers , head_dim = head_dim , rms_norm_eps = rms_norm_eps ,
10681072 dtype = dtype , device = device , operations = operations
10691073 )
10701074 self .detokenizer = AudioTokenDetokenizer (
1071- hidden_size , pool_window_size , audio_acoustic_hidden_dim , num_layers = 2 , head_dim = head_dim ,
1075+ encoder_hidden_size , pool_window_size , audio_acoustic_hidden_dim , num_layers = 2 , head_dim = head_dim ,
10721076 dtype = dtype , device = device , operations = operations
10731077 )
1074- self .null_condition_emb = nn .Parameter (torch .empty (1 , 1 , hidden_size , dtype = dtype , device = device ))
1078+ self .null_condition_emb = nn .Parameter (torch .empty (1 , 1 , encoder_hidden_size , dtype = dtype , device = device ))
10751079
10761080 def prepare_condition (
10771081 self ,
0 commit comments