@@ -12,6 +12,9 @@ use serde_json::Value;
1212pub struct AcousticModelConfig {
1313 /// Hidden dimension size.
1414 pub hidden_size : usize ,
15+ /// Text embedding dimension (before projection).
16+ /// In Qwen3-TTS this is 2048, which gets projected to hidden_size (1024).
17+ pub embedding_dim : usize ,
1518 /// Number of attention heads.
1619 pub num_attention_heads : usize ,
1720 /// Number of key-value heads (for GQA).
@@ -89,57 +92,86 @@ impl AcousticModelConfig {
8992 /// Parse configuration from JSON string (HuggingFace config.json format).
9093 ///
9194 /// Maps HuggingFace field names to our config structure.
95+ /// Supports both flat configs and Qwen3-TTS nested configs (with talker_config).
9296 pub fn from_json ( json : & str ) -> Result < Self , String > {
9397 let v: Value =
9498 serde_json:: from_str ( json) . map_err ( |e| format ! ( "failed to parse JSON: {e}" ) ) ?;
9599
96- // Helper to extract values with defaults
100+ // Check if this is a nested Qwen3-TTS config with talker_config
101+ let talker = v. get ( "talker_config" ) ;
102+
103+ // Helper to extract values with defaults, checking talker_config first
97104 let get_usize = |key : & str , default : usize | -> usize {
98- v. get ( key)
105+ // First try talker_config, then root
106+ talker
107+ . and_then ( |t| t. get ( key) )
108+ . or_else ( || v. get ( key) )
99109 . and_then ( |v| v. as_u64 ( ) )
100110 . map ( |v| v as usize )
101111 . unwrap_or ( default)
102112 } ;
103113
104114 let get_f64 = |key : & str , default : f64 | -> f64 {
105- v. get ( key) . and_then ( |v| v. as_f64 ( ) ) . unwrap_or ( default)
115+ talker
116+ . and_then ( |t| t. get ( key) )
117+ . or_else ( || v. get ( key) )
118+ . and_then ( |v| v. as_f64 ( ) )
119+ . unwrap_or ( default)
106120 } ;
107121
108122 let get_u32 = |key : & str , default : u32 | -> u32 {
109- v. get ( key)
123+ talker
124+ . and_then ( |t| t. get ( key) )
125+ . or_else ( || v. get ( key) )
110126 . and_then ( |v| v. as_u64 ( ) )
111127 . map ( |v| v as u32 )
112128 . unwrap_or ( default)
113129 } ;
114130
115131 // Map HuggingFace field names to our config
116- // HuggingFace uses: hidden_size, num_hidden_layers, num_attention_heads , etc.
132+ // Qwen3-TTS uses: talker_config. hidden_size, talker_config. num_hidden_layers, etc.
117133 let hidden_size = get_usize ( "hidden_size" , 1024 ) ;
134+ // Qwen3-TTS uses text_hidden_size=2048 which gets projected to hidden_size=1024
135+ let embedding_dim = get_usize ( "text_hidden_size" , 2048 ) ;
118136 let num_attention_heads = get_usize ( "num_attention_heads" , 16 ) ;
119137 let num_kv_heads = get_usize ( "num_key_value_heads" , 8 ) ;
120138 let num_layers = get_usize ( "num_hidden_layers" , 28 ) ;
121139 let intermediate_size = get_usize ( "intermediate_size" , 3072 ) ;
122140 let head_dim = get_usize ( "head_dim" , hidden_size / num_attention_heads) ;
123141
124- let text_vocab_size = get_usize ( "vocab_size" , 151936 ) ;
125- let codec_vocab_size = get_usize ( "codec_vocab_size" , 3072 ) ;
142+ // Vocabulary sizes - use text_vocab_size for text, vocab_size for codec
143+ let text_vocab_size = get_usize ( "text_vocab_size" , 151936 ) ;
144+ let codec_vocab_size = get_usize ( "vocab_size" , 3072 ) ;
126145 let num_code_groups = get_usize ( "num_code_groups" , 16 ) ;
127146 let codebook_size = get_usize ( "codebook_size" , 2048 ) ;
128147
129148 let max_position_embeddings = get_usize ( "max_position_embeddings" , 32768 ) ;
130149 let rope_theta = get_f64 ( "rope_theta" , 1_000_000.0 ) ;
131150 let rms_norm_eps = get_f64 ( "rms_norm_eps" , 1e-6 ) ;
132151
133- // Special tokens
134- let tts_bos_token_id = get_u32 ( "tts_bos_token_id" , 151672 ) ;
135- let tts_eos_token_id = get_u32 ( "tts_eos_token_id" , 151673 ) ;
136- let tts_pad_token_id = get_u32 ( "tts_pad_token_id" , 151671 ) ;
152+ // Special tokens - from root level in Qwen3-TTS
153+ let tts_bos_token_id = v
154+ . get ( "tts_bos_token_id" )
155+ . and_then ( |v| v. as_u64 ( ) )
156+ . map ( |v| v as u32 )
157+ . unwrap_or ( 151672 ) ;
158+ let tts_eos_token_id = v
159+ . get ( "tts_eos_token_id" )
160+ . and_then ( |v| v. as_u64 ( ) )
161+ . map ( |v| v as u32 )
162+ . unwrap_or ( 151673 ) ;
163+ let tts_pad_token_id = v
164+ . get ( "tts_pad_token_id" )
165+ . and_then ( |v| v. as_u64 ( ) )
166+ . map ( |v| v as u32 )
167+ . unwrap_or ( 151671 ) ;
137168 let codec_bos_id = get_u32 ( "codec_bos_id" , 2149 ) ;
138169 let codec_eos_id = get_u32 ( "codec_eos_token_id" , 2150 ) ;
139170 let codec_pad_id = get_u32 ( "codec_pad_id" , 2148 ) ;
140171
141172 Ok ( Self {
142173 hidden_size,
174+ embedding_dim,
143175 num_attention_heads,
144176 num_kv_heads,
145177 num_layers,
@@ -166,11 +198,12 @@ impl AcousticModelConfig {
166198 Self {
167199 // Talker dimensions
168200 hidden_size : 1024 ,
201+ embedding_dim : 2048 , // text_embedding is [vocab_size, 2048]
169202 num_attention_heads : 16 ,
170203 num_kv_heads : 8 ,
171204 num_layers : 28 ,
172205 intermediate_size : 3072 ,
173- head_dim : 128 , // 1024 / 8 per spec, but config says 128
206+ head_dim : 128 , // num_heads * head_dim = 16 * 128 = 2048 for Q
174207
175208 // Vocabulary
176209 text_vocab_size : 151936 ,
@@ -198,6 +231,7 @@ impl AcousticModelConfig {
198231 Self {
199232 // Talker dimensions (estimated, needs verification)
200233 hidden_size : 2048 ,
234+ embedding_dim : 4096 , // estimated
201235 num_attention_heads : 16 ,
202236 num_kv_heads : 8 ,
203237 num_layers : 28 ,
@@ -229,6 +263,7 @@ impl AcousticModelConfig {
229263 pub fn tiny ( ) -> Self {
230264 Self {
231265 hidden_size : 64 ,
266+ embedding_dim : 64 , // same as hidden_size for testing
232267 num_attention_heads : 4 ,
233268 num_kv_heads : 2 ,
234269 num_layers : 2 ,
@@ -261,6 +296,7 @@ impl AcousticModelConfig {
261296 pub fn legacy ( ) -> Self {
262297 Self {
263298 hidden_size : 2048 ,
299+ embedding_dim : 2048 ,
264300 num_attention_heads : 16 ,
265301 num_kv_heads : 4 ,
266302 num_layers : 24 ,
@@ -408,30 +444,35 @@ mod tests {
408444
409445 #[ test]
410446 fn test_from_json ( ) {
447+ // Test with Qwen3-TTS style config (nested talker_config)
411448 let json = r#"{
412- "hidden_size": 512,
413- "num_attention_heads": 8,
414- "num_key_value_heads": 4,
415- "num_hidden_layers": 12,
416- "intermediate_size": 2048,
417- "vocab_size": 50000,
418- "codec_vocab_size": 1024,
419- "num_code_groups": 8,
420- "codebook_size": 512,
421- "max_position_embeddings": 4096,
422- "rope_theta": 100000.0,
423- "rms_norm_eps": 1e-5,
424449 "tts_bos_token_id": 100,
425450 "tts_eos_token_id": 101,
426451 "tts_pad_token_id": 0,
427- "codec_bos_id": 10,
428- "codec_eos_token_id": 11,
429- "codec_pad_id": 0
452+ "talker_config": {
453+ "hidden_size": 512,
454+ "text_hidden_size": 1024,
455+ "num_attention_heads": 8,
456+ "num_key_value_heads": 4,
457+ "num_hidden_layers": 12,
458+ "intermediate_size": 2048,
459+ "text_vocab_size": 50000,
460+ "vocab_size": 1024,
461+ "num_code_groups": 8,
462+ "codebook_size": 512,
463+ "max_position_embeddings": 4096,
464+ "rope_theta": 100000.0,
465+ "rms_norm_eps": 1e-5,
466+ "codec_bos_id": 10,
467+ "codec_eos_token_id": 11,
468+ "codec_pad_id": 0
469+ }
430470 }"# ;
431471
432472 let config = AcousticModelConfig :: from_json ( json) . unwrap ( ) ;
433473
434474 assert_eq ! ( config. hidden_size, 512 ) ;
475+ assert_eq ! ( config. embedding_dim, 1024 ) ;
435476 assert_eq ! ( config. num_attention_heads, 8 ) ;
436477 assert_eq ! ( config. num_kv_heads, 4 ) ;
437478 assert_eq ! ( config. num_layers, 12 ) ;
0 commit comments