Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ def get_enable_utterance_grouping(self) -> bool:
"""Get enable utterance grouping from params."""
return self.params.get("enable_utterance_grouping", True)

def get_omit_empty_text_results(self) -> bool:
    """Return the ``omit_empty_text_results`` param, defaulting to ``False``.

    A truthy value means ASR results whose text is empty or whitespace-only
    should not be emitted.
    """
    omit_empty = self.params.get("omit_empty_text_results", False)
    return omit_empty

def get_request_config(self) -> dict[str, Any]:
"""Get request configuration for ASR.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import websockets
from dataclasses import asdict, dataclass
from ten_ai_base.message import ModuleMetrics
from typing import Any
from typing import Any, cast
from typing_extensions import override

from ten_ai_base.asr import (
Expand Down Expand Up @@ -52,6 +52,43 @@
)


def _deep_merge_dict(target: dict[str, Any], patch: dict[str, Any]) -> None:
"""Merge ``patch`` into ``target`` in place; dict values recurse."""
for key, patch_val in patch.items():
if (
key in target
and isinstance(target[key], dict)
and isinstance(patch_val, dict)
):
_deep_merge_dict(target[key], patch_val)
else:
target[key] = patch_val


def _strip_empty_request_corpus_context(params: dict[str, Any]) -> None:
"""Normalize ``request.corpus.context`` after ``update_configs`` merges JSON into ``params``.

Contract for runtime config patches (``update_configs`` payload → merged into
``config.params``):

- If ``request`` or ``corpus`` is missing or not a ``dict``, this function is a no-op.
- If ``corpus.context`` is the empty string ``""``, the ``context`` key is **removed**
from ``corpus``. That is interpreted as "clear dialog / hot context" for the next
connection, not as "send an explicit empty string field" to the service.

Limitation: callers cannot use this path to force the wire request to include
``context: ""`` while keeping the key; only omission vs non-empty string is supported.
"""
request = params.get("request")
if not isinstance(request, dict):
return
corpus = request.get("corpus")
if not isinstance(corpus, dict):
return
if corpus.get("context") == "":
corpus.pop("context", None)


@dataclass
class TwoPassDelayTracker:
"""Track two-pass delay metrics timestamps"""
Expand Down Expand Up @@ -137,6 +174,8 @@ def __init__(self, name: str):
# Enable utterance grouping
self.enable_utterance_grouping: bool = True

self._update_configs_lock: asyncio.Lock = asyncio.Lock()

@override
def vendor(self) -> str:
"""Get the name of the ASR vendor."""
Expand Down Expand Up @@ -638,8 +677,18 @@ async def _send_asr_result_from_text(
duration_ms: int,
language: str,
metadata: dict[str, Any],
) -> None:
"""Send ASR result with given text and metadata."""
) -> bool:
"""Send ASR result with given text and metadata.

Returns False when the result is not sent (e.g. empty text with omit enabled).
"""
if (
self.config
and self.config.get_omit_empty_text_results()
and not text.strip()
):
return False

asr_result = ASRResult(
text=text,
final=is_final,
Expand All @@ -650,6 +699,7 @@ async def _send_asr_result_from_text(
metadata=metadata,
)
await self.send_asr_result(asr_result)
return True

async def _track_utterance_timestamps(self, result: ASRResponse) -> None:
"""Track utterance timestamps and send two-pass delay metrics."""
Expand Down Expand Up @@ -763,7 +813,7 @@ async def _on_asr_result(self, result: ASRResponse) -> None:
metadata = self._build_metadata_with_asr_info(
result_level_fields=result_level_asr_info
)
await self._send_asr_result_from_text(
sent = await self._send_asr_result_from_text(
text=result.text,
is_final=False,
start_ms=actual_start_ms,
Expand Down Expand Up @@ -799,7 +849,6 @@ async def _on_asr_result(self, result: ASRResponse) -> None:

# Extract metadata (always include invoke_type and source for all results)
if is_final:
has_final_result = True
metadata = self._extract_final_result_metadata(
utterance,
result_level_fields=result_level_asr_info,
Expand All @@ -810,14 +859,16 @@ async def _on_asr_result(self, result: ASRResponse) -> None:
result_level_fields=result_level_asr_info,
)

await self._send_asr_result_from_text(
sent = await self._send_asr_result_from_text(
text=utterance.text,
is_final=is_final,
start_ms=actual_start_ms,
duration_ms=duration_ms,
language=result.language,
metadata=metadata,
)
if is_final and sent:
has_final_result = True
else:
# Filter out invalid utterances first
valid_utterances = [
Expand Down Expand Up @@ -875,17 +926,16 @@ async def _on_asr_result(self, result: ASRResponse) -> None:
merged["start_time"]
)

if is_final:
has_final_result = True

await self._send_asr_result_from_text(
sent = await self._send_asr_result_from_text(
text=merged["text"],
is_final=is_final,
start_ms=actual_start_ms,
duration_ms=merged["duration_ms"],
language=result.language,
metadata=merged["metadata"],
)
if is_final and sent:
has_final_result = True

# finalize end signal if there was any final result
if has_final_result:
Expand Down Expand Up @@ -1046,7 +1096,138 @@ def _on_disconnected(self) -> None:
)
self.connected = False

@staticmethod
def _set_update_configs_cmd_result(
    cmd_result: CmdResult, *, code: int, message: str
) -> None:
    """Write the ``code`` and ``message`` properties onto ``cmd_result``.

    Follows the ``result.code`` + ``result.message`` convention used for
    ten_ai_base cmd_in handling.

    Args:
        cmd_result: Result object that receives the two properties.
        code: Integer stored under the ``"code"`` property.
        message: String stored under the ``"message"`` property.
    """
    cmd_result.set_property_int("code", code)
    cmd_result.set_property_string("message", message)

async def _run_update_configs(
    self, payload: dict[str, Any]
) -> tuple[bool, str]:
    """Apply an ``update_configs`` payload to the config and reconnect.

    Recognized payload keys:

    - ``url``: a non-blank string is stripped and stored as
      ``params["api_url"]``.
    - ``dump``: when the key is present, ``config.dump`` is set to its
      truthiness.
    - ``params``: a non-empty dict is deep-merged into ``config.params``
      (applied after ``url``, so a patched ``api_url`` wins).

    After the merge, an empty ``request.corpus.context`` is stripped and the
    connection is stopped and started again.

    Returns:
        ``(True, "")`` on success, or ``(False, reason)`` when the config is
        not loaded or ``config.params`` is not a dict.
    """
    config = self.config
    if not config:
        return False, "config not loaded"

    incoming_params = payload.get("params")
    incoming_url = payload.get("url")

    current_params = config.params
    if not isinstance(current_params, dict):
        return False, "params_not_dict"

    # Only a string with non-whitespace content counts as a URL override.
    if isinstance(incoming_url, str):
        stripped_url = incoming_url.strip()
        if stripped_url:
            current_params["api_url"] = stripped_url

    if "dump" in payload:
        config.dump = bool(payload["dump"])

    merged_keys: list[str] = []
    if isinstance(incoming_params, dict) and incoming_params:
        merged_keys = list(incoming_params)
        _deep_merge_dict(current_params, incoming_params)

    _strip_empty_request_corpus_context(current_params)

    self.ten_env.log_info(
        "update_configs: merged payload keys into config, reconnecting "
        f"(params_keys={merged_keys})",
        category=LOG_CATEGORY_KEY_POINT,
    )

    # Reconnect so the next session uses the updated configuration.
    await self.stop_connection()
    await self.start_connection()

    return True, ""

async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None:
    """Handle commands; ``update_configs`` is applied here at runtime.

    Exactly one result is returned per command. Commands other than
    ``update_configs`` are delegated to the base class.
    NOTE(review): this assumes the base ``on_cmd`` returns the result for
    those commands - confirm against ten_ai_base.

    For ``update_configs`` the command properties are read as JSON, validated
    to be an object, and applied via ``_run_update_configs`` while holding
    ``_update_configs_lock`` so concurrent updates do not interleave. The
    returned result carries ``code`` (0 on success, -1 on failure) and
    ``message`` properties.
    """
    if cmd.get_name() != "update_configs":
        await super().on_cmd(ten_env, cmd)
        return

    try:
        prop_json, jerr = cmd.get_property_to_json(None)
        if jerr or not prop_json:
            await self._fail_update_configs(
                ten_env,
                cmd,
                log_text=(
                    "update_configs: missing_or_invalid_cmd_json "
                    f"err={jerr!r}"
                ),
                message="missing_or_invalid_cmd_json",
            )
            return

        try:
            payload = json.loads(prop_json)
        except json.JSONDecodeError as e:
            await self._fail_update_configs(
                ten_env,
                cmd,
                log_text=f"update_configs: invalid_json {e}",
                message=f"invalid_json:{e}",
            )
            return

        if not isinstance(payload, dict):
            await self._fail_update_configs(
                ten_env,
                cmd,
                log_text="update_configs: payload_not_object",
                message="payload_not_object",
            )
            return

        # Serialize config mutation + reconnect across concurrent commands.
        async with self._update_configs_lock:
            ok, err_msg = await self._run_update_configs(payload)

        if not ok:
            await self._fail_update_configs(
                ten_env,
                cmd,
                log_text=f"update_configs: {err_msg}",
                message=err_msg,
            )
            return

        cmd_result = CmdResult.create(StatusCode.OK, cmd)
        self._set_update_configs_cmd_result(cmd_result, code=0, message="")
        await ten_env.return_result(cmd_result)
    except Exception as e:
        # Boundary handler: report any unexpected failure to the caller
        # instead of leaving the command without a result.
        await self._fail_update_configs(
            ten_env,
            cmd,
            log_text=f"update_configs: unexpected error: {e}",
            message=str(e),
        )

async def _fail_update_configs(
    self,
    ten_env: AsyncTenEnv,
    cmd: Cmd,
    *,
    log_text: str,
    message: str,
) -> None:
    """Log ``log_text`` and answer ``cmd`` with an ERROR result (code -1)."""
    ten_env.log_error(log_text, category=LOG_CATEGORY_KEY_POINT)
    failed = CmdResult.create(StatusCode.ERROR, cmd)
    self._set_update_configs_cmd_result(failed, code=-1, message=message)
    await ten_env.return_result(failed)
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"type": "extension",
"name": "bytedance_llm_based_asr",
"version": "0.3.20",
"version": "0.4.3",
"dependencies": [
{
"type": "system",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"app_key": "${env:BYTEDANCE_ASR_LLM_APP_KEY|}",
"access_key": "${env:BYTEDANCE_ASR_LLM_ACCESS_KEY|}",
"language": "zh-CN",
"omit_empty_text_results": false,
"resource_id": "${env:BYTEDANCE_ASR_LLM_RESOURCE_ID|}",
"model_version": "${env:BYTEDANCE_ASR_LLM_MODEL_VERSION|}",
"audio": {
Expand Down
Loading
Loading