Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/om1_speech/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def __init__(self, args: argparse.Namespace):
logging.basicConfig(
level=getattr(logging, self.args.log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
force=True,
)
for name in logging.root.manager.loggerDict:
logging.getLogger(name).setLevel(
Expand Down
2 changes: 1 addition & 1 deletion src/om1_speech/riva/args.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def add_asr_config_argparse_parameters(
"""
parser.add_argument(
"--asr-sample-rate-hz",
default=16000,
default=48000,
type=int,
help="A number of frames per second in audio streamed from a microphone.",
)
Expand Down
53 changes: 39 additions & 14 deletions src/om1_speech/riva/asr_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,25 @@ def __init__(
self.callback = callback
self.running: bool = True

# Customize the ASR model
self.args.stop_threshold = 0.99
self.args.stop_threshold_eou = 0.99
# ASR settings
if (
not hasattr(self.args, "stop_history_eou")
or self.args.stop_history_eou == -1
):
self.args.stop_history_eou = (
200 # milliseconds of silence before finalizing
)
if (
not hasattr(self.args, "stop_threshold_eou")
or self.args.stop_threshold_eou == -1.0
):
self.args.stop_threshold_eou = 0.80 # Balanced threshold

# Reduce stop_history for faster resets but not too aggressive
if not hasattr(self.args, "stop_history") or self.args.stop_history == -1:
self.args.stop_history = 600 # Longer to prevent mid-sentence cuts
if not hasattr(self.args, "stop_threshold") or self.args.stop_threshold == -1.0:
self.args.stop_threshold = 0.90 # Higher to prevent false cuts

self._initialize_model()

Expand Down Expand Up @@ -123,17 +139,9 @@ def _yield_audio_chunks(self, audio_source: Any):
while self.running:
if audio_source:
data = audio_source.get_audio_chunk()
if (
data
and isinstance(data, dict)
and "audio" in data
and "rate" in data
):
if data["rate"] != self.args.asr_sample_rate_hz:
self.args.asr_sample_rate_hz = data["rate"]
self._initialize_model()
if data and isinstance(data, dict) and "audio" in data:
yield base64.b64decode(data["audio"])
time.sleep(0.01) # Small delay to prevent busy waiting
time.sleep(0.01)

def process_audio(self, audio_source: Any):
"""
Expand All @@ -151,6 +159,21 @@ def process_audio(self, audio_source: Any):
if self.model is None or self.model_config is None:
raise RuntimeError("ASR model is not initialized.")

logger.info("Waiting for first audio chunk to initialize sample rate...")
while self.running:
data = audio_source.get_audio_chunk()
if data and isinstance(data, dict) and "rate" in data:
if data["rate"] != self.args.asr_sample_rate_hz:
logger.info(
f"Updating sample rate from {self.args.asr_sample_rate_hz} to {data['rate']}"
)
self.args.asr_sample_rate_hz = data["rate"]
self._initialize_model()
if hasattr(audio_source, "audio_queue"):
audio_source.audio_queue.put(data)
break
time.sleep(0.01)

responses = self.model.streaming_response_generator(
audio_chunks=self._yield_audio_chunks(audio_source),
streaming_config=self.model_config,
Expand All @@ -169,9 +192,11 @@ def process_audio(self, audio_source: Any):
continue

if result.is_final:
logging.info(f"ASR: {transcript}")
logging.info(f"Final ASR Result: {transcript}")
if self.callback:
self.callback(json.dumps({"asr_reply": transcript}))
else:
logging.info(f"Interim ASR Result: {transcript}")

def stop(self):
"""
Expand Down
4 changes: 2 additions & 2 deletions src/om1_utils/ws/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ async def process_global_messages(self):
await self.connections[connection_id].send(message)
self.global_queue.task_done()
except Empty:
await asyncio.sleep(0.05)
await asyncio.sleep(0.001)
except ConnectionClosed:
pass
except Exception as e:
Expand All @@ -125,7 +125,7 @@ async def process_connection_messages(self, connection_id: str):
await websocket.send(message)
queue.task_done()
except Empty:
await asyncio.sleep(0.05)
await asyncio.sleep(0.001)
except ConnectionClosed:
break
except Exception as e:
Expand Down
Loading