Skip to content

Commit dbcd2c8

Browse files
committed
feat: добавить автоопределение CustomVoice моделей и tts-app крейт
- Добавить is_custom_voice флаг в PipelineConfig
- Реализовать detect_custom_voice_model() для парсинга config.json
- Синхронизировать параметры сэмплирования с Python SDK (temp=0.9, rep_penalty=1.05)
- Установить min_new_tokens=2 как в Python SDK
- Улучшить логику выбора формата промпта для CustomVoice
- Добавить новый крейт tts-app в workspace
1 parent d4f9209 commit dbcd2c8

File tree

18 files changed

+6395
-20
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ members = [
99
"crates/runtime",
1010
"crates/tts-cli",
1111
"crates/tts-server",
12+
"crates/tts-app",
1213
]
1314

1415
[workspace.package]

crates/runtime/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ metrics-exporter-prometheus.workspace = true
3434

3535
# Serialization
3636
serde.workspace = true
37+
serde_json.workspace = true
3738
toml.workspace = true
3839

3940
# Utilities

crates/runtime/src/pipeline.rs

Lines changed: 60 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ pub struct PipelineConfig {
4141
pub max_seq_len: usize,
4242
/// Default speaker for CustomVoice models (e.g., "vivian", "ryan").
4343
pub default_speaker: Option<String>,
44+
/// Whether the model is a CustomVoice model (requires speaker prompt format).
45+
pub is_custom_voice: bool,
4446
}
4547

4648
impl Default for PipelineConfig {
@@ -52,6 +54,7 @@ impl Default for PipelineConfig {
5254
chunk_tokens: 10,
5355
max_seq_len: 4096, // Max sequence length for audio generation
5456
default_speaker: None,
57+
is_custom_voice: false,
5558
}
5659
}
5760
}
@@ -220,9 +223,16 @@ impl TtsPipeline {
220223
AcousticBackend::Mock
221224
};
222225

223-
let config = PipelineConfig::neural();
226+
// Check if this is a CustomVoice model by looking for spk_id in config.json
227+
let is_custom_voice = Self::detect_custom_voice_model(talker_dir);
224228

225-
info!("Pipeline created with {:?} acoustic backend", acoustic);
229+
let mut config = PipelineConfig::neural();
230+
config.is_custom_voice = is_custom_voice;
231+
232+
info!(
233+
"Pipeline created with {:?} acoustic backend, is_custom_voice={}",
234+
acoustic, is_custom_voice
235+
);
226236

227237
Ok(Self {
228238
normalizer: Normalizer::new(),
@@ -235,6 +245,27 @@ impl TtsPipeline {
235245
})
236246
}
237247

248+
/// Detect if this is a CustomVoice model by checking config.json for spk_id.
249+
fn detect_custom_voice_model(model_dir: &Path) -> bool {
250+
let config_path = model_dir.join("config.json");
251+
if let Ok(content) = std::fs::read_to_string(&config_path) {
252+
if let Ok(config) = serde_json::from_str::<serde_json::Value>(&content) {
253+
// CustomVoice models have talker_config.spk_id object with speaker mappings
254+
if let Some(spk_id) = config
255+
.get("talker_config")
256+
.and_then(|t| t.get("spk_id"))
257+
.and_then(|s| s.as_object())
258+
{
259+
if !spk_id.is_empty() {
260+
info!("Detected CustomVoice model with {} speakers", spk_id.len());
261+
return true;
262+
}
263+
}
264+
}
265+
}
266+
false
267+
}
268+
238269
/// Try to load CodePredictor from the same weights file as the main model.
239270
fn try_load_code_predictor(
240271
weights_path: &Path,
@@ -454,8 +485,9 @@ impl TtsPipeline {
454485
seed: None,
455486
};
456487

457-
// Minimum tokens based on text length
458-
let min_tokens = (text_tokens.len() * 5).max(20);
488+
// Match Python SDK: min_new_tokens = 2
489+
// This allows EOS early if model decides the text is complete
490+
let min_tokens = 2;
459491

460492
// Generate using embeddings
461493
if let Some(cp) = code_predictor {
@@ -701,24 +733,20 @@ impl TtsPipeline {
701733
"Combined embeddings built (non_streaming format)"
702734
);
703735

704-
// Configure sampling - use greedy for debugging to compare with reference
736+
// Configure sampling - match Python SDK parameters
705737
// Python SDK: temperature=0.9, top_p=1.0, top_k=50, repetition_penalty=1.05
706-
// TODO: Make this configurable, use temp=0 for greedy comparison
707738
let sampling_config = SamplingConfig {
708-
temperature: 0.0, // Greedy for debugging
739+
temperature: 0.9,
709740
top_p: 1.0,
710741
top_k: 50,
711-
repetition_penalty: 1.0, // No penalty for greedy
742+
repetition_penalty: 1.05,
712743
seed: None,
713744
};
714745

715-
// min_new_tokens based on text length: ~5-10 audio tokens per text token
716-
let min_tokens = (text_tokens.len() * 5).max(20);
717-
info!(
718-
"Setting min_new_tokens={} based on {} text tokens",
719-
min_tokens,
720-
text_tokens.len()
721-
);
746+
// Match Python SDK: min_new_tokens = 2
747+
// This allows EOS early if model decides the text is complete
748+
let min_tokens = 2;
749+
info!("min_new_tokens={} (matching Python SDK)", min_tokens);
722750

723751
// ========== COMPUTE trailing_text_hidden ==========
724752
// Python SDK (modeling_qwen3_tts.py:2230-2232):
@@ -985,8 +1013,9 @@ impl TtsPipeline {
9851013
seed: None,
9861014
};
9871015

988-
// min_new_tokens based on text length
989-
let min_tokens = (text_tokens.len() * 5).max(20);
1016+
// Match Python SDK: min_new_tokens = 2
1017+
// This allows EOS early if model decides the text is complete
1018+
let min_tokens = 2;
9901019

9911020
// If no CodePredictor, generate only zeroth codebook
9921021
let Some(cp) = code_predictor else {
@@ -1108,11 +1137,21 @@ impl TtsPipeline {
11081137
model,
11091138
code_predictor,
11101139
} => {
1111-
// Use CustomVoice format if speaker is provided or if we have speaker configured
1112-
let use_speaker = speaker.is_some() || self.config.default_speaker.is_some();
1140+
// Use CustomVoice format if:
1141+
// 1. Speaker is explicitly provided, OR
1142+
// 2. We have a default speaker configured, OR
1143+
// 3. The model is a CustomVoice model (even without speaker, needs proper prompt format)
1144+
let use_speaker_format = speaker.is_some()
1145+
|| self.config.default_speaker.is_some()
1146+
|| self.config.is_custom_voice;
1147+
11131148
let actual_speaker = speaker.or(self.config.default_speaker.as_deref());
11141149

1115-
if use_speaker {
1150+
if use_speaker_format {
1151+
info!(
1152+
"Using CustomVoice format: speaker={:?}, is_custom_voice={}",
1153+
actual_speaker, self.config.is_custom_voice
1154+
);
11161155
self.generate_acoustic_with_speaker(
11171156
model,
11181157
code_predictor.as_deref(),
@@ -1122,6 +1161,7 @@ impl TtsPipeline {
11221161
max_tokens,
11231162
)
11241163
} else {
1164+
info!("Using simple format (non-CustomVoice model)");
11251165
self.generate_acoustic_neural(
11261166
model,
11271167
code_predictor.as_deref(),

crates/tts-app/Cargo.toml

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
[package]
2+
name = "tts-app"
3+
version = "0.1.0"
4+
description = "Qwen3-TTS Desktop Application"
5+
authors = ["RustTTS Team"]
6+
license = "MIT"
7+
edition = "2021"
8+
rust-version = "1.70"
9+
10+
[lib]
11+
name = "tts_app_lib"
12+
crate-type = ["staticlib", "cdylib", "rlib"]
13+
14+
[build-dependencies]
15+
tauri-build = { version = "2", features = [] }
16+
17+
[dependencies]
18+
# Tauri
19+
tauri = { version = "2", features = [] }
20+
tauri-plugin-shell = "2"
21+
22+
# Our TTS runtime
23+
runtime = { path = "../runtime" }
24+
tts-core = { path = "../tts-core" }
25+
audio-codec-12hz = { path = "../audio-codec-12hz" }
26+
27+
# Async runtime
28+
tokio = { version = "1", features = ["rt-multi-thread", "sync", "macros"] }
29+
30+
# Audio playback
31+
rodio = { version = "0.19", default-features = false, features = ["wav"] }
32+
33+
# Serialization
34+
serde = { version = "1", features = ["derive"] }
35+
serde_json = "1"
36+
37+
# Error handling
38+
anyhow = "1"
39+
thiserror = "1"
40+
41+
# Logging
42+
tracing = "0.1"
43+
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
44+
45+
# Tensor operations (for device selection)
46+
candle-core = "0.8"
47+
48+
# System directories
49+
dirs = "5"
50+
51+
[features]
52+
default = ["custom-protocol"]
53+
custom-protocol = ["tauri/custom-protocol"]

crates/tts-app/build.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
fn main() {
2+
tauri_build::build()
3+
}

crates/tts-app/gen/schemas/acl-manifests.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
{}

0 commit comments

Comments (0)