@@ -41,6 +41,8 @@ pub struct PipelineConfig {
4141 pub max_seq_len : usize ,
4242 /// Default speaker for CustomVoice models (e.g., "vivian", "ryan").
4343 pub default_speaker : Option < String > ,
44+ /// Whether the model is a CustomVoice model (requires speaker prompt format).
45+ pub is_custom_voice : bool ,
4446}
4547
4648impl Default for PipelineConfig {
@@ -52,6 +54,7 @@ impl Default for PipelineConfig {
5254 chunk_tokens : 10 ,
5355 max_seq_len : 4096 , // Max sequence length for audio generation
5456 default_speaker : None ,
57+ is_custom_voice : false ,
5558 }
5659 }
5760}
@@ -220,9 +223,16 @@ impl TtsPipeline {
220223 AcousticBackend :: Mock
221224 } ;
222225
223- let config = PipelineConfig :: neural ( ) ;
226+ // Check if this is a CustomVoice model by looking for spk_id in config.json
227+ let is_custom_voice = Self :: detect_custom_voice_model ( talker_dir) ;
224228
225- info ! ( "Pipeline created with {:?} acoustic backend" , acoustic) ;
229+ let mut config = PipelineConfig :: neural ( ) ;
230+ config. is_custom_voice = is_custom_voice;
231+
232+ info ! (
233+ "Pipeline created with {:?} acoustic backend, is_custom_voice={}" ,
234+ acoustic, is_custom_voice
235+ ) ;
226236
227237 Ok ( Self {
228238 normalizer : Normalizer :: new ( ) ,
@@ -235,6 +245,27 @@ impl TtsPipeline {
235245 } )
236246 }
237247
248+ /// Detect if this is a CustomVoice model by checking config.json for spk_id.
249+ fn detect_custom_voice_model ( model_dir : & Path ) -> bool {
250+ let config_path = model_dir. join ( "config.json" ) ;
251+ if let Ok ( content) = std:: fs:: read_to_string ( & config_path) {
252+ if let Ok ( config) = serde_json:: from_str :: < serde_json:: Value > ( & content) {
253+ // CustomVoice models have talker_config.spk_id object with speaker mappings
254+ if let Some ( spk_id) = config
255+ . get ( "talker_config" )
256+ . and_then ( |t| t. get ( "spk_id" ) )
257+ . and_then ( |s| s. as_object ( ) )
258+ {
259+ if !spk_id. is_empty ( ) {
260+ info ! ( "Detected CustomVoice model with {} speakers" , spk_id. len( ) ) ;
261+ return true ;
262+ }
263+ }
264+ }
265+ }
266+ false
267+ }
268+
238269 /// Try to load CodePredictor from the same weights file as the main model.
239270 fn try_load_code_predictor (
240271 weights_path : & Path ,
@@ -454,8 +485,9 @@ impl TtsPipeline {
454485 seed : None ,
455486 } ;
456487
457- // Minimum tokens based on text length
458- let min_tokens = ( text_tokens. len ( ) * 5 ) . max ( 20 ) ;
488+ // Match Python SDK: min_new_tokens = 2
489+ // This allows EOS early if model decides the text is complete
490+ let min_tokens = 2 ;
459491
460492 // Generate using embeddings
461493 if let Some ( cp) = code_predictor {
@@ -701,24 +733,20 @@ impl TtsPipeline {
701733 "Combined embeddings built (non_streaming format)"
702734 ) ;
703735
704- // Configure sampling - use greedy for debugging to compare with reference
736+ // Configure sampling - match Python SDK parameters
705737 // Python SDK: temperature=0.9, top_p=1.0, top_k=50, repetition_penalty=1.05
706- // TODO: Make this configurable, use temp=0 for greedy comparison
707738 let sampling_config = SamplingConfig {
708- temperature : 0.0 , // Greedy for debugging
739+ temperature : 0.9 ,
709740 top_p : 1.0 ,
710741 top_k : 50 ,
711- repetition_penalty : 1.0 , // No penalty for greedy
742+ repetition_penalty : 1.05 ,
712743 seed : None ,
713744 } ;
714745
715- // min_new_tokens based on text length: ~5-10 audio tokens per text token
716- let min_tokens = ( text_tokens. len ( ) * 5 ) . max ( 20 ) ;
717- info ! (
718- "Setting min_new_tokens={} based on {} text tokens" ,
719- min_tokens,
720- text_tokens. len( )
721- ) ;
746+ // Match Python SDK: min_new_tokens = 2
747+ // This allows EOS early if model decides the text is complete
748+ let min_tokens = 2 ;
749+ info ! ( "min_new_tokens={} (matching Python SDK)" , min_tokens) ;
722750
723751 // ========== COMPUTE trailing_text_hidden ==========
724752 // Python SDK (modeling_qwen3_tts.py:2230-2232):
@@ -985,8 +1013,9 @@ impl TtsPipeline {
9851013 seed : None ,
9861014 } ;
9871015
988- // min_new_tokens based on text length
989- let min_tokens = ( text_tokens. len ( ) * 5 ) . max ( 20 ) ;
1016+ // Match Python SDK: min_new_tokens = 2
1017+ // This allows EOS early if model decides the text is complete
1018+ let min_tokens = 2 ;
9901019
9911020 // If no CodePredictor, generate only zeroth codebook
9921021 let Some ( cp) = code_predictor else {
@@ -1108,11 +1137,21 @@ impl TtsPipeline {
11081137 model,
11091138 code_predictor,
11101139 } => {
1111- // Use CustomVoice format if speaker is provided or if we have speaker configured
1112- let use_speaker = speaker. is_some ( ) || self . config . default_speaker . is_some ( ) ;
1140+ // Use CustomVoice format if:
1141+ // 1. Speaker is explicitly provided, OR
1142+ // 2. We have a default speaker configured, OR
1143+ // 3. The model is a CustomVoice model (even without speaker, needs proper prompt format)
1144+ let use_speaker_format = speaker. is_some ( )
1145+ || self . config . default_speaker . is_some ( )
1146+ || self . config . is_custom_voice ;
1147+
11131148 let actual_speaker = speaker. or ( self . config . default_speaker . as_deref ( ) ) ;
11141149
1115- if use_speaker {
1150+ if use_speaker_format {
1151+ info ! (
1152+ "Using CustomVoice format: speaker={:?}, is_custom_voice={}" ,
1153+ actual_speaker, self . config. is_custom_voice
1154+ ) ;
11161155 self . generate_acoustic_with_speaker (
11171156 model,
11181157 code_predictor. as_deref ( ) ,
@@ -1122,6 +1161,7 @@ impl TtsPipeline {
11221161 max_tokens,
11231162 )
11241163 } else {
1164+ info ! ( "Using simple format (non-CustomVoice model)" ) ;
11251165 self . generate_acoustic_neural (
11261166 model,
11271167 code_predictor. as_deref ( ) ,
0 commit comments