Skip to content

Commit 4e92b7d

Browse files
committed
feat(acoustic-model): adapt architecture for Qwen3-TTS
- Add separate text_embedding and codec_embedding
- Implement TextProjection (embedding_dim → hidden_size)
- Add optional QK-Norm to Attention
- Use linear_no_bias for attention/MLP projections
- Support nested talker_config in the JSON parser
- Add forward_mixed for correct handling of mixed sequences
1 parent fae26d2 commit 4e92b7d

5 files changed

Lines changed: 350 additions & 94 deletions

File tree

crates/acoustic-model/src/code_predictor.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
//! 16 codebooks for ultra-low-latency streaming.
66
77
use candle_core::{DType, Device, IndexOp, Result, Tensor};
8-
use candle_nn::{Embedding, Module, VarBuilder, embedding, linear};
8+
use candle_nn::{Embedding, Module, VarBuilder, embedding};
99
use tracing::{debug, info, instrument};
1010

1111
use crate::config::CodePredictorConfig;
@@ -69,6 +69,7 @@ impl CodePredictor {
6969
// Create a temporary AcousticModelConfig for transformer blocks
7070
let acoustic_config = crate::config::AcousticModelConfig {
7171
hidden_size: config.hidden_size,
72+
embedding_dim: config.hidden_size, // Same as hidden_size for code predictor
7273
num_attention_heads: config.num_attention_heads,
7374
num_kv_heads: config.num_kv_heads,
7475
num_layers: config.num_layers,
@@ -104,10 +105,11 @@ impl CodePredictor {
104105
let norm = RmsNorm::new(config.hidden_size, config.rms_norm_eps, vb_model.pp("norm"))?;
105106

106107
// Separate LM head for each residual codebook (groups 1 to num_code_groups-1)
108+
// Qwen3-TTS uses linear_no_bias for lm_heads
107109
let num_residual_groups = config.num_code_groups - 1;
108110
let mut lm_heads = Vec::with_capacity(num_residual_groups);
109111
for i in 0..num_residual_groups {
110-
let head = linear(
112+
let head = candle_nn::linear_no_bias(
111113
config.hidden_size,
112114
config.codebook_size,
113115
vb_model.pp(format!("lm_head.{i}")),

crates/acoustic-model/src/config.rs

Lines changed: 68 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ use serde_json::Value;
1212
pub struct AcousticModelConfig {
1313
/// Hidden dimension size.
1414
pub hidden_size: usize,
15+
/// Text embedding dimension (before projection).
16+
/// In Qwen3-TTS this is 2048, which gets projected to hidden_size (1024).
17+
pub embedding_dim: usize,
1518
/// Number of attention heads.
1619
pub num_attention_heads: usize,
1720
/// Number of key-value heads (for GQA).
@@ -89,57 +92,86 @@ impl AcousticModelConfig {
8992
/// Parse configuration from JSON string (HuggingFace config.json format).
9093
///
9194
/// Maps HuggingFace field names to our config structure.
95+
/// Supports both flat configs and Qwen3-TTS nested configs (with talker_config).
9296
pub fn from_json(json: &str) -> Result<Self, String> {
9397
let v: Value =
9498
serde_json::from_str(json).map_err(|e| format!("failed to parse JSON: {e}"))?;
9599

96-
// Helper to extract values with defaults
100+
// Check if this is a nested Qwen3-TTS config with talker_config
101+
let talker = v.get("talker_config");
102+
103+
// Helper to extract values with defaults, checking talker_config first
97104
let get_usize = |key: &str, default: usize| -> usize {
98-
v.get(key)
105+
// First try talker_config, then root
106+
talker
107+
.and_then(|t| t.get(key))
108+
.or_else(|| v.get(key))
99109
.and_then(|v| v.as_u64())
100110
.map(|v| v as usize)
101111
.unwrap_or(default)
102112
};
103113

104114
let get_f64 = |key: &str, default: f64| -> f64 {
105-
v.get(key).and_then(|v| v.as_f64()).unwrap_or(default)
115+
talker
116+
.and_then(|t| t.get(key))
117+
.or_else(|| v.get(key))
118+
.and_then(|v| v.as_f64())
119+
.unwrap_or(default)
106120
};
107121

108122
let get_u32 = |key: &str, default: u32| -> u32 {
109-
v.get(key)
123+
talker
124+
.and_then(|t| t.get(key))
125+
.or_else(|| v.get(key))
110126
.and_then(|v| v.as_u64())
111127
.map(|v| v as u32)
112128
.unwrap_or(default)
113129
};
114130

115131
// Map HuggingFace field names to our config
116-
// HuggingFace uses: hidden_size, num_hidden_layers, num_attention_heads, etc.
132+
// Qwen3-TTS uses: talker_config.hidden_size, talker_config.num_hidden_layers, etc.
117133
let hidden_size = get_usize("hidden_size", 1024);
134+
// Qwen3-TTS uses text_hidden_size=2048 which gets projected to hidden_size=1024
135+
let embedding_dim = get_usize("text_hidden_size", 2048);
118136
let num_attention_heads = get_usize("num_attention_heads", 16);
119137
let num_kv_heads = get_usize("num_key_value_heads", 8);
120138
let num_layers = get_usize("num_hidden_layers", 28);
121139
let intermediate_size = get_usize("intermediate_size", 3072);
122140
let head_dim = get_usize("head_dim", hidden_size / num_attention_heads);
123141

124-
let text_vocab_size = get_usize("vocab_size", 151936);
125-
let codec_vocab_size = get_usize("codec_vocab_size", 3072);
142+
// Vocabulary sizes - use text_vocab_size for text, vocab_size for codec
143+
let text_vocab_size = get_usize("text_vocab_size", 151936);
144+
let codec_vocab_size = get_usize("vocab_size", 3072);
126145
let num_code_groups = get_usize("num_code_groups", 16);
127146
let codebook_size = get_usize("codebook_size", 2048);
128147

129148
let max_position_embeddings = get_usize("max_position_embeddings", 32768);
130149
let rope_theta = get_f64("rope_theta", 1_000_000.0);
131150
let rms_norm_eps = get_f64("rms_norm_eps", 1e-6);
132151

133-
// Special tokens
134-
let tts_bos_token_id = get_u32("tts_bos_token_id", 151672);
135-
let tts_eos_token_id = get_u32("tts_eos_token_id", 151673);
136-
let tts_pad_token_id = get_u32("tts_pad_token_id", 151671);
152+
// Special tokens - from root level in Qwen3-TTS
153+
let tts_bos_token_id = v
154+
.get("tts_bos_token_id")
155+
.and_then(|v| v.as_u64())
156+
.map(|v| v as u32)
157+
.unwrap_or(151672);
158+
let tts_eos_token_id = v
159+
.get("tts_eos_token_id")
160+
.and_then(|v| v.as_u64())
161+
.map(|v| v as u32)
162+
.unwrap_or(151673);
163+
let tts_pad_token_id = v
164+
.get("tts_pad_token_id")
165+
.and_then(|v| v.as_u64())
166+
.map(|v| v as u32)
167+
.unwrap_or(151671);
137168
let codec_bos_id = get_u32("codec_bos_id", 2149);
138169
let codec_eos_id = get_u32("codec_eos_token_id", 2150);
139170
let codec_pad_id = get_u32("codec_pad_id", 2148);
140171

141172
Ok(Self {
142173
hidden_size,
174+
embedding_dim,
143175
num_attention_heads,
144176
num_kv_heads,
145177
num_layers,
@@ -166,11 +198,12 @@ impl AcousticModelConfig {
166198
Self {
167199
// Talker dimensions
168200
hidden_size: 1024,
201+
embedding_dim: 2048, // text_embedding is [vocab_size, 2048]
169202
num_attention_heads: 16,
170203
num_kv_heads: 8,
171204
num_layers: 28,
172205
intermediate_size: 3072,
173-
head_dim: 128, // 1024 / 8 per spec, but config says 128
206+
head_dim: 128, // num_heads * head_dim = 16 * 128 = 2048 for Q
174207

175208
// Vocabulary
176209
text_vocab_size: 151936,
@@ -198,6 +231,7 @@ impl AcousticModelConfig {
198231
Self {
199232
// Talker dimensions (estimated, needs verification)
200233
hidden_size: 2048,
234+
embedding_dim: 4096, // estimated
201235
num_attention_heads: 16,
202236
num_kv_heads: 8,
203237
num_layers: 28,
@@ -229,6 +263,7 @@ impl AcousticModelConfig {
229263
pub fn tiny() -> Self {
230264
Self {
231265
hidden_size: 64,
266+
embedding_dim: 64, // same as hidden_size for testing
232267
num_attention_heads: 4,
233268
num_kv_heads: 2,
234269
num_layers: 2,
@@ -261,6 +296,7 @@ impl AcousticModelConfig {
261296
pub fn legacy() -> Self {
262297
Self {
263298
hidden_size: 2048,
299+
embedding_dim: 2048,
264300
num_attention_heads: 16,
265301
num_kv_heads: 4,
266302
num_layers: 24,
@@ -408,30 +444,35 @@ mod tests {
408444

409445
#[test]
410446
fn test_from_json() {
447+
// Test with Qwen3-TTS style config (nested talker_config)
411448
let json = r#"{
412-
"hidden_size": 512,
413-
"num_attention_heads": 8,
414-
"num_key_value_heads": 4,
415-
"num_hidden_layers": 12,
416-
"intermediate_size": 2048,
417-
"vocab_size": 50000,
418-
"codec_vocab_size": 1024,
419-
"num_code_groups": 8,
420-
"codebook_size": 512,
421-
"max_position_embeddings": 4096,
422-
"rope_theta": 100000.0,
423-
"rms_norm_eps": 1e-5,
424449
"tts_bos_token_id": 100,
425450
"tts_eos_token_id": 101,
426451
"tts_pad_token_id": 0,
427-
"codec_bos_id": 10,
428-
"codec_eos_token_id": 11,
429-
"codec_pad_id": 0
452+
"talker_config": {
453+
"hidden_size": 512,
454+
"text_hidden_size": 1024,
455+
"num_attention_heads": 8,
456+
"num_key_value_heads": 4,
457+
"num_hidden_layers": 12,
458+
"intermediate_size": 2048,
459+
"text_vocab_size": 50000,
460+
"vocab_size": 1024,
461+
"num_code_groups": 8,
462+
"codebook_size": 512,
463+
"max_position_embeddings": 4096,
464+
"rope_theta": 100000.0,
465+
"rms_norm_eps": 1e-5,
466+
"codec_bos_id": 10,
467+
"codec_eos_token_id": 11,
468+
"codec_pad_id": 0
469+
}
430470
}"#;
431471

432472
let config = AcousticModelConfig::from_json(json).unwrap();
433473

434474
assert_eq!(config.hidden_size, 512);
475+
assert_eq!(config.embedding_dim, 1024);
435476
assert_eq!(config.num_attention_heads, 8);
436477
assert_eq!(config.num_kv_heads, 4);
437478
assert_eq!(config.num_layers, 12);

0 commit comments

Comments (0)