diff --git a/src/om1_speech/main.py b/src/om1_speech/main.py index e43f2ce..f3e0d0b 100644 --- a/src/om1_speech/main.py +++ b/src/om1_speech/main.py @@ -66,6 +66,7 @@ def __init__(self, args: argparse.Namespace): logging.basicConfig( level=getattr(logging, self.args.log_level.upper()), format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + force=True, ) for name in logging.root.manager.loggerDict: logging.getLogger(name).setLevel( diff --git a/src/om1_speech/riva/args.py b/src/om1_speech/riva/args.py index 1b0eb6d..f8799a7 100644 --- a/src/om1_speech/riva/args.py +++ b/src/om1_speech/riva/args.py @@ -32,7 +32,7 @@ def add_asr_config_argparse_parameters( """ parser.add_argument( "--asr-sample-rate-hz", - default=16000, + default=48000, type=int, help="A number of frames per second in audio streamed from a microphone.", ) diff --git a/src/om1_speech/riva/asr_processor.py b/src/om1_speech/riva/asr_processor.py index af3df12..74fa274 100644 --- a/src/om1_speech/riva/asr_processor.py +++ b/src/om1_speech/riva/asr_processor.py @@ -39,9 +39,25 @@ def __init__( self.callback = callback self.running: bool = True - # Customize the ASR model - self.args.stop_threshold = 0.99 - self.args.stop_threshold_eou = 0.99 + # ASR settings + if ( + not hasattr(self.args, "stop_history_eou") + or self.args.stop_history_eou == -1 + ): + self.args.stop_history_eou = ( + 200 # milliseconds of silence before finalizing + ) + if ( + not hasattr(self.args, "stop_threshold_eou") + or self.args.stop_threshold_eou == -1.0 + ): + self.args.stop_threshold_eou = 0.80 # Balanced threshold + + # Reduce stop_history for faster resets but not too aggressive + if not hasattr(self.args, "stop_history") or self.args.stop_history == -1: + self.args.stop_history = 600 # Longer to prevent mid-sentence cuts + if not hasattr(self.args, "stop_threshold") or self.args.stop_threshold == -1.0: + self.args.stop_threshold = 0.90 # Higher to prevent false cuts self._initialize_model() @@ -123,17 +139,9 @@ def _yield_audio_chunks(self, audio_source: Any): while self.running: if audio_source: data = audio_source.get_audio_chunk() - if ( - data - and isinstance(data, dict) - and "audio" in data - and "rate" in data - ): - if data["rate"] != self.args.asr_sample_rate_hz: - self.args.asr_sample_rate_hz = data["rate"] - self._initialize_model() + if data and isinstance(data, dict) and "audio" in data: yield base64.b64decode(data["audio"]) - time.sleep(0.01) # Small delay to prevent busy waiting + time.sleep(0.01) def process_audio(self, audio_source: Any): """ @@ -151,6 +159,21 @@ def process_audio(self, audio_source: Any): if self.model is None or self.model_config is None: raise RuntimeError("ASR model is not initialized.") + logger.info("Waiting for first audio chunk to initialize sample rate...") + while self.running: + data = audio_source.get_audio_chunk() + if data and isinstance(data, dict) and "rate" in data: + if data["rate"] != self.args.asr_sample_rate_hz: + logger.info( + f"Updating sample rate from {self.args.asr_sample_rate_hz} to {data['rate']}" + ) + self.args.asr_sample_rate_hz = data["rate"] + self._initialize_model() + if hasattr(audio_source, "audio_queue"): + audio_source.audio_queue.put(data) + break + time.sleep(0.01) + responses = self.model.streaming_response_generator( audio_chunks=self._yield_audio_chunks(audio_source), streaming_config=self.model_config, @@ -169,9 +192,11 @@ def process_audio(self, audio_source: Any): continue if result.is_final: - logging.info(f"ASR: {transcript}") + logging.info(f"Final ASR Result: {transcript}") if self.callback: self.callback(json.dumps({"asr_reply": transcript})) + else: + logging.info(f"Interim ASR Result: {transcript}") def stop(self): """ diff --git a/src/om1_utils/ws/server.py b/src/om1_utils/ws/server.py index 7ad5c55..2f19a57 100644 --- a/src/om1_utils/ws/server.py +++ b/src/om1_utils/ws/server.py @@ -101,7 +101,7 @@ async def process_global_messages(self): await self.connections[connection_id].send(message) self.global_queue.task_done() except Empty: - await asyncio.sleep(0.05) + await asyncio.sleep(0.001) except ConnectionClosed: pass except Exception as e: @@ -125,7 +125,7 @@ async def process_connection_messages(self, connection_id: str): await websocket.send(message) queue.task_done() except Empty: - await asyncio.sleep(0.05) + await asyncio.sleep(0.001) except ConnectionClosed: break except Exception as e: