diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py index 09d7d09bf8..df1ed411bc 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py @@ -44,6 +44,10 @@ def get_enable_utterance_grouping(self) -> bool: """Get enable utterance grouping from params.""" return self.params.get("enable_utterance_grouping", True) + def get_omit_empty_text_results(self) -> bool: + """When True, do not emit ASR results whose text is empty or whitespace-only.""" + return self.params.get("omit_empty_text_results", False) + def get_request_config(self) -> dict[str, Any]: """Get request configuration for ASR. diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py index 8c4af858a4..a9a805b8a5 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py @@ -11,7 +11,7 @@ import websockets from dataclasses import asdict, dataclass from ten_ai_base.message import ModuleMetrics -from typing import Any +from typing import Any, cast from typing_extensions import override from ten_ai_base.asr import ( @@ -52,6 +52,43 @@ ) +def _deep_merge_dict(target: dict[str, Any], patch: dict[str, Any]) -> None: + """Merge ``patch`` into ``target`` in place; dict values recurse.""" + for key, patch_val in patch.items(): + if ( + key in target + and isinstance(target[key], dict) + and isinstance(patch_val, dict) + ): + _deep_merge_dict(target[key], patch_val) + else: + target[key] = patch_val + + +def _strip_empty_request_corpus_context(params: dict[str, Any]) -> None: + """Normalize ``request.corpus.context`` after ``update_configs`` merges JSON into ``params``. + + Contract for runtime config patches (``update_configs`` payload → merged into + ``config.params``): + + - If ``request`` or ``corpus`` is missing or not a ``dict``, this function is a no-op. + - If ``corpus.context`` is the empty string ``""``, the ``context`` key is **removed** + from ``corpus``. That is interpreted as "clear dialog / hot context" for the next + connection, not as "send an explicit empty string field" to the service. + + Limitation: callers cannot use this path to force the wire request to include + ``context: ""`` while keeping the key; only omission vs non-empty string is supported. + """ + request = params.get("request") + if not isinstance(request, dict): + return + corpus = request.get("corpus") + if not isinstance(corpus, dict): + return + if corpus.get("context") == "": + corpus.pop("context", None) + + @dataclass class TwoPassDelayTracker: """Track two-pass delay metrics timestamps""" @@ -137,6 +174,8 @@ def __init__(self, name: str): # Enable utterance grouping self.enable_utterance_grouping: bool = True + self._update_configs_lock: asyncio.Lock = asyncio.Lock() + @override def vendor(self) -> str: """Get the name of the ASR vendor.""" @@ -638,8 +677,18 @@ async def _send_asr_result_from_text( duration_ms: int, language: str, metadata: dict[str, Any], - ) -> None: - """Send ASR result with given text and metadata.""" + ) -> bool: + """Send ASR result with given text and metadata. + + Returns False when the result is not sent (e.g. empty text with omit enabled). + """ + if ( + self.config + and self.config.get_omit_empty_text_results() + and not text.strip() + ): + return False + asr_result = ASRResult( text=text, final=is_final, @@ -650,6 +699,7 @@ async def _send_asr_result_from_text( metadata=metadata, ) await self.send_asr_result(asr_result) + return True async def _track_utterance_timestamps(self, result: ASRResponse) -> None: """Track utterance timestamps and send two-pass delay metrics.""" @@ -763,7 +813,7 @@ async def _on_asr_result(self, result: ASRResponse) -> None: metadata = self._build_metadata_with_asr_info( result_level_fields=result_level_asr_info ) - await self._send_asr_result_from_text( + sent = await self._send_asr_result_from_text( text=result.text, is_final=False, start_ms=actual_start_ms, @@ -799,7 +849,6 @@ async def _on_asr_result(self, result: ASRResponse) -> None: # Extract metadata (always include invoke_type and source for all results) if is_final: - has_final_result = True metadata = self._extract_final_result_metadata( utterance, result_level_fields=result_level_asr_info, @@ -810,7 +859,7 @@ async def _on_asr_result(self, result: ASRResponse) -> None: result_level_fields=result_level_asr_info, ) - await self._send_asr_result_from_text( + sent = await self._send_asr_result_from_text( text=utterance.text, is_final=is_final, start_ms=actual_start_ms, @@ -818,6 +867,8 @@ async def _on_asr_result(self, result: ASRResponse) -> None: language=result.language, metadata=metadata, ) + if is_final and sent: + has_final_result = True else: # Filter out invalid utterances first valid_utterances = [ @@ -875,10 +926,7 @@ async def _on_asr_result(self, result: ASRResponse) -> None: merged["start_time"] ) - if is_final: - has_final_result = True - - await self._send_asr_result_from_text( + sent = await self._send_asr_result_from_text( text=merged["text"], is_final=is_final, start_ms=actual_start_ms, @@ -886,6 +934,8 @@ async def _on_asr_result(self, result: ASRResponse) -> None: language=result.language, metadata=merged["metadata"], ) + if is_final and sent: + has_final_result = True # finalize end signal if there was any final result if has_final_result: @@ -1046,7 +1096,138 @@ def _on_disconnected(self) -> None: ) self.connected = False + @staticmethod + def _set_update_configs_cmd_result( + cmd_result: CmdResult, *, code: int, message: str + ) -> None: + """Match ten_ai_base cmd_in convention: result.code + result.message.""" + cmd_result.set_property_int("code", code) + cmd_result.set_property_string("message", message) + + async def _run_update_configs( + self, payload: dict[str, Any] + ) -> tuple[bool, str]: + if not self.config: + return False, "config not loaded" + + params_patch = payload.get("params") + url_val = payload.get("url") + dump_present = "dump" in payload + + has_params = isinstance(params_patch, dict) and bool(params_patch) + has_url = isinstance(url_val, str) and bool(url_val.strip()) + + params = self.config.params + if not isinstance(params, dict): + return False, "params_not_dict" + + if has_url: + params["api_url"] = cast(str, url_val).strip() + + if dump_present: + self.config.dump = bool(payload["dump"]) + + keys_for_log: list[str] = [] + if has_params: + patch = cast(dict[str, Any], params_patch) + keys_for_log = list(patch.keys()) + _deep_merge_dict(params, patch) + _strip_empty_request_corpus_context(params) + + self.ten_env.log_info( + "update_configs: merged payload keys into config, reconnecting " + f"(params_keys={keys_for_log})", + category=LOG_CATEGORY_KEY_POINT, + ) + + await self.stop_connection() + await self.start_connection() + + return True, "" + async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: """Handle commands.""" - cmd_result = CmdResult.create(StatusCode.OK, cmd) - await ten_env.return_result(cmd_result) + if cmd.get_name() != "update_configs": + await super().on_cmd(ten_env, cmd) + return + + try: + prop_json, jerr = cmd.get_property_to_json(None) + if jerr or not prop_json: + ten_env.log_error( + f"update_configs: missing_or_invalid_cmd_json err={jerr!r}", + category=LOG_CATEGORY_KEY_POINT, + ) + bad = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + bad, + code=-1, + message="missing_or_invalid_cmd_json", + ) + await ten_env.return_result(bad) + return + + try: + payload = json.loads(prop_json) + except json.JSONDecodeError as e: + ten_env.log_error( + f"update_configs: invalid_json {e}", + category=LOG_CATEGORY_KEY_POINT, + ) + bad = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + bad, + code=-1, + message=f"invalid_json:{e}", + ) + await ten_env.return_result(bad) + return + + if not isinstance(payload, dict): + ten_env.log_error( + "update_configs: payload_not_object", + category=LOG_CATEGORY_KEY_POINT, + ) + bad = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + bad, + code=-1, + message="payload_not_object", + ) + await ten_env.return_result(bad) + return + + async with self._update_configs_lock: + ok, err_msg = await self._run_update_configs(payload) + + if ok: + cmd_result = CmdResult.create(StatusCode.OK, cmd) + self._set_update_configs_cmd_result( + cmd_result, + code=0, + message="", + ) + else: + ten_env.log_error( + f"update_configs: {err_msg}", + category=LOG_CATEGORY_KEY_POINT, + ) + cmd_result = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + cmd_result, + code=-1, + message=err_msg, + ) + await ten_env.return_result(cmd_result) + except Exception as e: + ten_env.log_error( + f"update_configs: unexpected error: {e}", + category=LOG_CATEGORY_KEY_POINT, + ) + cmd_result = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + cmd_result, + code=-1, + message=str(e), + ) + await ten_env.return_result(cmd_result) diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json index 95235d8707..e850623f12 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json @@ -1,7 +1,7 @@ { "type": "extension", "name": "bytedance_llm_based_asr", - "version": "0.3.20", + "version": "0.4.3", "dependencies": [ { "type": "system", diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json index 49db5fd3e0..edc41256fc 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json @@ -5,6 +5,7 @@ "app_key": "${env:BYTEDANCE_ASR_LLM_APP_KEY|}", "access_key": "${env:BYTEDANCE_ASR_LLM_ACCESS_KEY|}", "language": "zh-CN", + "omit_empty_text_results": false, "resource_id": "${env:BYTEDANCE_ASR_LLM_RESOURCE_ID|}", "model_version": "${env:BYTEDANCE_ASR_LLM_MODEL_VERSION|}", "audio": { diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_omit_empty_merge_and_helpers.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_omit_empty_merge_and_helpers.py new file mode 100644 index 0000000000..f399280e3c --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_omit_empty_merge_and_helpers.py @@ -0,0 +1,271 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import copy +import os +import sys +import types +from unittest.mock import AsyncMock, MagicMock + +import pytest + +extension_dir = os.path.join(os.path.dirname(__file__), "..") +sys.path.insert(0, extension_dir) + +package = types.ModuleType("bytedance_llm_based_asr") +package.__path__ = [extension_dir] +sys.modules["bytedance_llm_based_asr"] = package + +from bytedance_llm_based_asr import config as config_module + +sys.modules["bytedance_llm_based_asr.config"] = config_module + +from bytedance_llm_based_asr import extension as extension_module + +sys.modules["bytedance_llm_based_asr.extension"] = extension_module + +from bytedance_llm_based_asr.config import BytedanceASRLLMConfig +from bytedance_llm_based_asr.extension import ( + BytedanceASRLLMExtension, + _deep_merge_dict, + _strip_empty_request_corpus_context, +) + + +def _minimal_config( + *, omit_empty_text_results: bool | None = None +) -> BytedanceASRLLMConfig: + params: dict = { + "audio": {"rate": 16000}, + "request": {"model_name": "bigmodel"}, + } + if omit_empty_text_results is not None: + params["omit_empty_text_results"] = omit_empty_text_results + return BytedanceASRLLMConfig.model_validate({"params": params}) + + +def test_get_omit_empty_text_results_defaults_false() -> None: + cfg = _minimal_config() + assert cfg.get_omit_empty_text_results() is False + + +def test_get_omit_empty_text_results_true_when_set() -> None: + cfg = _minimal_config(omit_empty_text_results=True) + assert cfg.get_omit_empty_text_results() is True + + +def test_deep_merge_dict_nested() -> None: + target: dict = {"a": {"b": 1}, "c": 2} + _deep_merge_dict(target, {"a": {"b": 2, "d": 3}}) + assert target == {"a": {"b": 2, "d": 3}, "c": 2} + + +def test_deep_merge_dict_non_dict_patch_replaces_dict() -> None: + target: dict = {"a": {"b": 1}} + _deep_merge_dict(target, {"a": "scalar"}) + assert target == {"a": "scalar"} + + +def test_deep_merge_dict_dict_patch_replaces_non_dict() -> None: + target: dict = {"a": 1} + _deep_merge_dict(target, {"a": {"nested": True}}) + assert target == {"a": {"nested": True}} + + +def test_strip_empty_context_removes_empty_string() -> None: + params = { + "request": {"corpus": {"context": "", "other": 1}}, + } + _strip_empty_request_corpus_context(params) + assert params["request"]["corpus"] == {"other": 1} + + +def test_strip_empty_context_keeps_non_empty_string() -> None: + params = {"request": {"corpus": {"context": "hello"}}} + _strip_empty_request_corpus_context(params) + assert params["request"]["corpus"]["context"] == "hello" + + +def test_strip_empty_context_missing_context_no_op() -> None: + params = {"request": {"corpus": {"other": 1}}} + before = copy.deepcopy(params) + _strip_empty_request_corpus_context(params) + assert params == before + + +def test_strip_empty_context_request_not_dict() -> None: + params = {"request": "bad"} + before = copy.deepcopy(params) + _strip_empty_request_corpus_context(params) + assert params == before + + +def test_strip_empty_context_corpus_not_dict() -> None: + params = {"request": {"corpus": "bad"}} + before = copy.deepcopy(params) + _strip_empty_request_corpus_context(params) + assert params == before + + +@pytest.fixture +def mock_ten_env(): + env = AsyncMock() + env.log_info = MagicMock() + return env + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_omit_on_skips_empty( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config(omit_empty_text_results=True) + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text=" ", + is_final=True, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is False + ext.send_asr_result.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_omit_on_sends_non_empty( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config(omit_empty_text_results=True) + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text="hi", + is_final=True, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is True + ext.send_asr_result.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_omit_off_sends_whitespace_only( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config(omit_empty_text_results=False) + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text=" ", + is_final=False, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is True + ext.send_asr_result.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_no_config_sends(mock_ten_env) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = None + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text="", + is_final=True, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is True + ext.send_asr_result.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_run_update_configs_returns_false_when_config_missing( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = None + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, msg = await ext._run_update_configs({"params": {"x": 1}}) + assert ok is False + assert msg == "config not loaded" + ext.stop_connection.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_run_update_configs_merges_and_strips_empty_context( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config() + ext.config.params["request"]["corpus"] = {"context": "old"} + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, msg = await ext._run_update_configs( + { + "params": { + "request": { + "corpus": {"context": ""}, + "enable_nonstream": False, + } + } + } + ) + assert ok is True + assert msg == "" + req = ext.config.params["request"] + assert "context" not in req.get("corpus", {}) + assert req["enable_nonstream"] is False + ext.stop_connection.assert_awaited_once() + ext.start_connection.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_run_update_configs_sets_trimmed_api_url(mock_ten_env) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config() + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, _ = await ext._run_update_configs({"url": " wss://example.test/asr "}) + assert ok is True + assert ext.config.params["api_url"] == "wss://example.test/asr" + + +@pytest.mark.asyncio +async def test_run_update_configs_dump_flag(mock_ten_env) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config() + ext.config.dump = False + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, _ = await ext._run_update_configs({"dump": True}) + assert ok is True + assert ext.config.dump is True diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py index 50f6246281..32e4e07bb0 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py @@ -506,6 +506,14 @@ async def _send_full_client_request(self) -> None: if not self.websocket: raise RuntimeError("WebSocket not connected") + request_payload = { + "request": self.config.get_request_config(), + } + if self.ten_env: + self.ten_env.log_info( + "full_client_request params (request section): " + + json.dumps(request_payload, ensure_ascii=False), + ) request = RequestBuilder.new_full_client_request(self.seq, self.config) self.seq += 1 await self.websocket.send(request)