From 8234cb242015c332a02e8b72c603149bfa2e2ba0 Mon Sep 17 00:00:00 2001 From: hanhandi <1540984562@qq.com> Date: Fri, 8 May 2026 11:22:54 +0000 Subject: [PATCH 1/3] feat: add dialog context handling and omit empty text results configuration --- .../bytedance_llm_based_asr/config.py | 4 + .../bytedance_llm_based_asr/dialog_ctx.py | 67 ++++++++ .../bytedance_llm_based_asr/extension.py | 161 ++++++++++++++++-- .../bytedance_llm_based_asr/manifest.json | 45 ++++- .../bytedance_llm_based_asr/property.json | 1 + .../volcengine_asr_client.py | 8 + 6 files changed, 273 insertions(+), 13 deletions(-) create mode 100644 ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py index 09d7d09bf8..df1ed411bc 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/config.py @@ -44,6 +44,10 @@ def get_enable_utterance_grouping(self) -> bool: """Get enable utterance grouping from params.""" return self.params.get("enable_utterance_grouping", True) + def get_omit_empty_text_results(self) -> bool: + """When True, do not emit ASR results whose text is empty or whitespace-only.""" + return self.params.get("omit_empty_text_results", False) + def get_request_config(self) -> dict[str, Any]: """Get request configuration for ASR. diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py new file mode 100644 index 0000000000..5f0625408e --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py @@ -0,0 +1,67 @@ +# ABOUTME: Maps Main POC dialog_messages to Volcengine request.corpus.context string +# ABOUTME: context_data is newest-first per Volcengine dialog_ctx specification +from __future__ import annotations + +import json +from typing import Any + +MAX_CONTEXT_ITEMS = 20 + + +def _lang_is_zh(language: str) -> bool: + return language.lower().startswith("zh") + + +def _line_for_message(role: str, content: str, lang_zh: bool) -> str: + c = (content or "").strip() + if not c: + return "" + if role == "user": + return f"用户:{c}" if lang_zh else f"User: {c}" + if role == "assistant": + return f"助手:{c}" if lang_zh else f"Assistant: {c}" + return "" + + +def build_volc_dialog_context_str( + dialog_messages: list[Any], + *, + language: str = "zh-CN", + max_items: int = MAX_CONTEXT_ITEMS, +) -> str | None: + """Build corpus.context value: inner JSON as a single string. + + Returns None when there is no dialog context (caller should remove corpus.context). + """ + if not dialog_messages: + return None + + lang_zh = _lang_is_zh(language) + lines_chrono: list[str] = [] + for m in dialog_messages: + if not isinstance(m, dict): + continue + role = m.get("role") + if role not in ("user", "assistant"): + continue + line = _line_for_message( + str(role), str(m.get("content") or ""), lang_zh + ) + if line: + lines_chrono.append(line) + + if not lines_chrono: + return None + + newest_first = list(reversed(lines_chrono))[:max_items] + + context_data = [{"text": line} for line in newest_first] + + if not context_data: + return None + + inner = { + "context_type": "dialog_ctx", + "context_data": context_data, + } + return json.dumps(inner, ensure_ascii=False) diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py index 8c4af858a4..3c119c88f3 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py @@ -44,6 +44,7 @@ ) from .config import BytedanceASRLLMConfig +from .dialog_ctx import build_volc_dialog_context_str from .volcengine_asr_client import VolcengineASRClient, ASRResponse, Utterance from .log_id_dumper_manager import LogIdDumperManager from .const import ( @@ -137,6 +138,8 @@ def __init__(self, name: str): # Enable utterance grouping self.enable_utterance_grouping: bool = True + self._update_configs_lock: asyncio.Lock = asyncio.Lock() + @override def vendor(self) -> str: """Get the name of the ASR vendor.""" @@ -638,8 +641,18 @@ async def _send_asr_result_from_text( duration_ms: int, language: str, metadata: dict[str, Any], - ) -> None: - """Send ASR result with given text and metadata.""" + ) -> bool: + """Send ASR result with given text and metadata. + + Returns False when the result is not sent (e.g. empty text with omit enabled). + """ + if ( + self.config + and self.config.get_omit_empty_text_results() + and not text.strip() + ): + return False + asr_result = ASRResult( text=text, final=is_final, @@ -650,6 +663,7 @@ async def _send_asr_result_from_text( metadata=metadata, ) await self.send_asr_result(asr_result) + return True async def _track_utterance_timestamps(self, result: ASRResponse) -> None: """Track utterance timestamps and send two-pass delay metrics.""" @@ -763,7 +777,7 @@ async def _on_asr_result(self, result: ASRResponse) -> None: metadata = self._build_metadata_with_asr_info( result_level_fields=result_level_asr_info ) - await self._send_asr_result_from_text( + sent = await self._send_asr_result_from_text( text=result.text, is_final=False, start_ms=actual_start_ms, @@ -799,7 +813,6 @@ async def _on_asr_result(self, result: ASRResponse) -> None: # Extract metadata (always include invoke_type and source for all results) if is_final: - has_final_result = True metadata = self._extract_final_result_metadata( utterance, result_level_fields=result_level_asr_info, @@ -810,7 +823,7 @@ async def _on_asr_result(self, result: ASRResponse) -> None: result_level_fields=result_level_asr_info, ) - await self._send_asr_result_from_text( + sent = await self._send_asr_result_from_text( text=utterance.text, is_final=is_final, start_ms=actual_start_ms, @@ -818,6 +831,8 @@ async def _on_asr_result(self, result: ASRResponse) -> None: language=result.language, metadata=metadata, ) + if is_final and sent: + has_final_result = True else: # Filter out invalid utterances first valid_utterances = [ @@ -875,10 +890,7 @@ async def _on_asr_result(self, result: ASRResponse) -> None: merged["start_time"] ) - if is_final: - has_final_result = True - - await self._send_asr_result_from_text( + sent = await self._send_asr_result_from_text( text=merged["text"], is_final=is_final, start_ms=actual_start_ms, @@ -886,6 +898,8 @@ async def _on_asr_result(self, result: ASRResponse) -> None: language=result.language, metadata=merged["metadata"], ) + if is_final and sent: + has_final_result = True # finalize end signal if there was any final result if has_final_result: @@ -1046,7 +1060,132 @@ def _on_disconnected(self) -> None: ) self.connected = False + @staticmethod + def _set_update_configs_cmd_result( + cmd_result: CmdResult, *, code: int, message: str + ) -> None: + """Match ten_ai_base cmd_in convention: result.code + result.message.""" + cmd_result.set_property_int("code", code) + cmd_result.set_property_string("message", message) + + async def _run_update_configs( + self, payload: dict[str, Any] + ) -> tuple[bool, str]: + if not self.config: + return False, "config not loaded" + if payload.get("schema_version") != 1: + return False, "unsupported schema_version" + raw = payload.get("dialog_messages") + if raw is not None and not isinstance(raw, list): + return False, "dialog_messages must be a list" + + dialog_messages: list[Any] = raw if isinstance(raw, list) else [] + + context_str = build_volc_dialog_context_str( + dialog_messages, + language=self.config.language, + ) + req = self.config.params.setdefault("request", {}) + corpus = req.setdefault("corpus", {}) + if context_str is None: + corpus.pop("context", None) + else: + corpus["context"] = context_str + + self.ten_env.log_info( + f"update_configs: corpus.context set, reconnecting (messages={len(dialog_messages)}): " + f"{json.dumps(dialog_messages, ensure_ascii=False)}", + category=LOG_CATEGORY_KEY_POINT, + ) + + await self.stop_connection() + await self.start_connection() + + return True, "" + async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: """Handle commands.""" - cmd_result = CmdResult.create(StatusCode.OK, cmd) - await ten_env.return_result(cmd_result) + if cmd.get_name() != "update_configs": + await super().on_cmd(ten_env, cmd) + return + + try: + prop_json, jerr = cmd.get_property_to_json(None) + if jerr or not prop_json: + ten_env.log_error( + f"update_configs: missing_or_invalid_cmd_json err={jerr!r}", + category=LOG_CATEGORY_KEY_POINT, + ) + bad = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + bad, + code=-1, + message="missing_or_invalid_cmd_json", + ) + await ten_env.return_result(bad) + return + + try: + payload = json.loads(prop_json) + except json.JSONDecodeError as e: + ten_env.log_error( + f"update_configs: invalid_json {e}", + category=LOG_CATEGORY_KEY_POINT, + ) + bad = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + bad, + code=-1, + message=f"invalid_json:{e}", + ) + await ten_env.return_result(bad) + return + + if not isinstance(payload, dict): + ten_env.log_error( + "update_configs: payload_not_object", + category=LOG_CATEGORY_KEY_POINT, + ) + bad = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + bad, + code=-1, + message="payload_not_object", + ) + await ten_env.return_result(bad) + return + + async with self._update_configs_lock: + ok, err_msg = await self._run_update_configs(payload) + + if ok: + cmd_result = CmdResult.create(StatusCode.OK, cmd) + self._set_update_configs_cmd_result( + cmd_result, + code=0, + message="volc_dialog_ctx", + ) + else: + ten_env.log_error( + f"update_configs: {err_msg}", + category=LOG_CATEGORY_KEY_POINT, + ) + cmd_result = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + cmd_result, + code=-1, + message=err_msg, + ) + await ten_env.return_result(cmd_result) + except Exception as e: + ten_env.log_error( + f"update_configs: unexpected error: {e}", + category=LOG_CATEGORY_KEY_POINT, + ) + cmd_result = CmdResult.create(StatusCode.ERROR, cmd) + self._set_update_configs_cmd_result( + cmd_result, + code=-1, + message=str(e), + ) + await ten_env.return_result(cmd_result) diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json index 95235d8707..d8e9c00a2e 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json @@ -1,7 +1,7 @@ { "type": "extension", "name": "bytedance_llm_based_asr", - "version": "0.3.20", + "version": "0.4.0", "dependencies": [ { "type": "system", @@ -43,7 +43,48 @@ } } } - } + }, + "cmd_in": [ + { + "name": "update_configs", + "property": { + "properties": { + "schema_version": { + "type": "int32" + }, + "dialog_messages": { + "type": "array", + "items": { + "type": "object", + "properties": { + "role": { + "type": "string" + }, + "content": { + "type": "string" + } + } + } + } + }, + "required": [ + "schema_version" + ] + }, + "result": { + "property": { + "properties": { + "code": { + "type": "int64" + }, + "message": { + "type": "string" + } + } + } + } + } + ] }, "package": { "include": [ diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json index 49db5fd3e0..edc41256fc 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/property.json @@ -5,6 +5,7 @@ "app_key": "${env:BYTEDANCE_ASR_LLM_APP_KEY|}", "access_key": "${env:BYTEDANCE_ASR_LLM_ACCESS_KEY|}", "language": "zh-CN", + "omit_empty_text_results": false, "resource_id": "${env:BYTEDANCE_ASR_LLM_RESOURCE_ID|}", "model_version": "${env:BYTEDANCE_ASR_LLM_MODEL_VERSION|}", "audio": { diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py index 50f6246281..32e4e07bb0 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/volcengine_asr_client.py @@ -506,6 +506,14 @@ async def _send_full_client_request(self) -> None: if not self.websocket: raise RuntimeError("WebSocket not connected") + request_payload = { + "request": self.config.get_request_config(), + } + if self.ten_env: + self.ten_env.log_info( + "full_client_request params (request section): " + + json.dumps(request_payload, ensure_ascii=False), + ) request = RequestBuilder.new_full_client_request(self.seq, self.config) self.seq += 1 await self.websocket.send(request) From 5163cd0013f02e48808d29fe08d72532c4aad885 Mon Sep 17 00:00:00 2001 From: hanhandi <1540984562@qq.com> Date: Tue, 12 May 2026 04:49:33 +0000 Subject: [PATCH 2/3] feat: change update_configs --- .../bytedance_llm_based_asr/dialog_ctx.py | 67 ----------------- .../bytedance_llm_based_asr/extension.py | 74 ++++++++++++++----- .../bytedance_llm_based_asr/manifest.json | 27 +++---- .../test_apply_string_field_under_params.py | 50 +++++++++++++ 4 files changed, 117 insertions(+), 101 deletions(-) delete mode 100644 ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py create mode 100644 ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py deleted file mode 100644 index 5f0625408e..0000000000 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/dialog_ctx.py +++ /dev/null @@ -1,67 +0,0 @@ -# ABOUTME: Maps Main POC dialog_messages to Volcengine request.corpus.context string -# ABOUTME: context_data is newest-first per Volcengine dialog_ctx specification -from __future__ import annotations - -import json -from typing import Any - -MAX_CONTEXT_ITEMS = 20 - - -def _lang_is_zh(language: str) -> bool: - return language.lower().startswith("zh") - - -def _line_for_message(role: str, content: str, lang_zh: bool) -> str: - c = (content or "").strip() - if not c: - return "" - if role == "user": - return f"用户:{c}" if lang_zh else f"User: {c}" - if role == "assistant": - return f"助手:{c}" if lang_zh else f"Assistant: {c}" - return "" - - -def build_volc_dialog_context_str( - dialog_messages: list[Any], - *, - language: str = "zh-CN", - max_items: int = MAX_CONTEXT_ITEMS, -) -> str | None: - """Build corpus.context value: inner JSON as a single string. - - Returns None when there is no dialog context (caller should remove corpus.context). - """ - if not dialog_messages: - return None - - lang_zh = _lang_is_zh(language) - lines_chrono: list[str] = [] - for m in dialog_messages: - if not isinstance(m, dict): - continue - role = m.get("role") - if role not in ("user", "assistant"): - continue - line = _line_for_message( - str(role), str(m.get("content") or ""), lang_zh - ) - if line: - lines_chrono.append(line) - - if not lines_chrono: - return None - - newest_first = list(reversed(lines_chrono))[:max_items] - - context_data = [{"text": line} for line in newest_first] - - if not context_data: - return None - - inner = { - "context_type": "dialog_ctx", - "context_data": context_data, - } - return json.dumps(inner, ensure_ascii=False) diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py index 3c119c88f3..79fe8bf1f4 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py @@ -44,7 +44,6 @@ ) from .config import BytedanceASRLLMConfig -from .dialog_ctx import build_volc_dialog_context_str from .volcengine_asr_client import VolcengineASRClient, ASRResponse, Utterance from .log_id_dumper_manager import LogIdDumperManager from .const import ( @@ -53,6 +52,43 @@ ) +def _apply_string_field_under_params( + params: dict[str, Any], field_path: str, value: str +) -> tuple[bool, str]: + """Set or clear a nested key under params using dot-separated path. + + Empty ``value`` removes the terminal key (same as prior dialog_ctx clear). + + For each prefix segment: if the key is missing, an empty dict is created; + if the key exists and its value is not a dict, the call fails with + ``path_conflict_non_dict_at:...`` and ``params`` is left unchanged for + that traversal (no silent overwrite of scalars/lists/etc.). + """ + parts = [p for p in field_path.split(".") if p] + if not parts: + return False, "empty_field_path" + cur: Any = params + for key in parts[:-1]: + if not isinstance(cur, dict): + return False, f"path_not_dict_at:{key}" + if key not in cur: + nxt: dict[str, Any] = {} + cur[key] = nxt + cur = nxt + elif isinstance(cur[key], dict): + cur = cur[key] + else: + return False, f"path_conflict_non_dict_at:{key}" + last = parts[-1] + if not isinstance(cur, dict): + return False, "parent_not_dict" + if value == "": + cur.pop(last, None) + else: + cur[last] = value + return True, "" + + @dataclass class TwoPassDelayTracker: """Track two-pass delay metrics timestamps""" @@ -1075,26 +1111,26 @@ async def _run_update_configs( return False, "config not loaded" if payload.get("schema_version") != 1: return False, "unsupported schema_version" - raw = payload.get("dialog_messages") - if raw is not None and not isinstance(raw, list): - return False, "dialog_messages must be a list" - - dialog_messages: list[Any] = raw if isinstance(raw, list) else [] - - context_str = build_volc_dialog_context_str( - dialog_messages, - language=self.config.language, + field_path = payload.get("field_path") + if not isinstance(field_path, str) or not field_path.strip(): + return False, "field_path must be a non-empty string" + raw_val = payload.get("value", "") + if raw_val is not None and not isinstance(raw_val, str): + return False, "value must be a string when present" + value = "" if raw_val is None else raw_val + + params = self.config.params + if not isinstance(params, dict): + return False, "params_not_dict" + ok, err = _apply_string_field_under_params( + params, field_path.strip(), value ) - req = self.config.params.setdefault("request", {}) - corpus = req.setdefault("corpus", {}) - if context_str is None: - corpus.pop("context", None) - else: - corpus["context"] = context_str + if not ok: + return False, err self.ten_env.log_info( - f"update_configs: corpus.context set, reconnecting (messages={len(dialog_messages)}): " - f"{json.dumps(dialog_messages, ensure_ascii=False)}", + f"update_configs: params[{field_path!r}] updated, reconnecting " + f"(value_len={len(value)})", category=LOG_CATEGORY_KEY_POINT, ) @@ -1163,7 +1199,7 @@ async def on_cmd(self, ten_env: AsyncTenEnv, cmd: Cmd) -> None: self._set_update_configs_cmd_result( cmd_result, code=0, - message="volc_dialog_ctx", + message="", ) else: ten_env.log_error( diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json index d8e9c00a2e..e8ffc456aa 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json @@ -1,7 +1,7 @@ { "type": "extension", "name": "bytedance_llm_based_asr", - "version": "0.4.0", + "version": "0.4.1", "dependencies": [ { "type": "system", @@ -52,23 +52,20 @@ "schema_version": { "type": "int32" }, - "dialog_messages": { - "type": "array", - "items": { - "type": "object", - "properties": { - "role": { - "type": "string" - }, - "content": { - "type": "string" - } - } - } + "field_path": { + "type": "string" + }, + "value": { + "type": "string" + }, + "metadata": { + "type": "object", + "properties": {} } }, "required": [ - "schema_version" + "schema_version", + "field_path" ] }, "result": { diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py new file mode 100644 index 0000000000..84d0081e96 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py @@ -0,0 +1,50 @@ +"""Tests for _apply_string_field_under_params nested path writes.""" + +import os +import sys +import types + +extension_dir = os.path.join(os.path.dirname(__file__), "..") +sys.path.insert(0, extension_dir) + +package = types.ModuleType("bytedance_llm_based_asr") +package.__path__ = [extension_dir] +sys.modules["bytedance_llm_based_asr"] = package + +from bytedance_llm_based_asr import config as config_module # noqa: E402 + +sys.modules["bytedance_llm_based_asr.config"] = config_module + +from bytedance_llm_based_asr import extension as extension_module # noqa: E402 + +sys.modules["bytedance_llm_based_asr.extension"] = extension_module + +from bytedance_llm_based_asr.extension import ( # noqa: E402 + _apply_string_field_under_params, +) + + +def test_creates_missing_intermediate_dicts(): + params: dict = {} + ok, err = _apply_string_field_under_params( + params, "request.corpus.context", '{"x":1}' + ) + assert ok and err == "" + assert params["request"]["corpus"]["context"] == '{"x":1}' + + +def test_conflict_when_intermediate_is_non_dict_no_mutation(): + params = {"request": "not-a-dict"} + ok, err = _apply_string_field_under_params( + params, "request.corpus.context", "v" + ) + assert not ok + assert err.startswith("path_conflict_non_dict_at:") + assert params == {"request": "not-a-dict"} + + +def test_clear_terminal_key(): + params = {"a": {"b": "x"}} + ok, err = _apply_string_field_under_params(params, "a.b", "") + assert ok and err == "" + assert params == {"a": {}} From 393290d227543fa4b5667ce50b9008595e4b8515 Mon Sep 17 00:00:00 2001 From: hanhandi <1540984562@qq.com> Date: Tue, 12 May 2026 07:49:27 +0000 Subject: [PATCH 3/3] feat: change update_configs --- .../bytedance_llm_based_asr/extension.py | 104 +++---- .../bytedance_llm_based_asr/manifest.json | 42 +-- .../test_apply_string_field_under_params.py | 50 ---- .../test_omit_empty_merge_and_helpers.py | 271 ++++++++++++++++++ 4 files changed, 328 insertions(+), 139 deletions(-) delete mode 100644 ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py create mode 100644 ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_omit_empty_merge_and_helpers.py diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py index 79fe8bf1f4..a9a805b8a5 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/extension.py @@ -11,7 +11,7 @@ import websockets from dataclasses import asdict, dataclass from ten_ai_base.message import ModuleMetrics -from typing import Any +from typing import Any, cast from typing_extensions import override from ten_ai_base.asr import ( @@ -52,41 +52,41 @@ ) -def _apply_string_field_under_params( - params: dict[str, Any], field_path: str, value: str -) -> tuple[bool, str]: - """Set or clear a nested key under params using dot-separated path. +def _deep_merge_dict(target: dict[str, Any], patch: dict[str, Any]) -> None: + """Merge ``patch`` into ``target`` in place; dict values recurse.""" + for key, patch_val in patch.items(): + if ( + key in target + and isinstance(target[key], dict) + and isinstance(patch_val, dict) + ): + _deep_merge_dict(target[key], patch_val) + else: + target[key] = patch_val + - Empty ``value`` removes the terminal key (same as prior dialog_ctx clear). +def _strip_empty_request_corpus_context(params: dict[str, Any]) -> None: + """Normalize ``request.corpus.context`` after ``update_configs`` merges JSON into ``params``. - For each prefix segment: if the key is missing, an empty dict is created; - if the key exists and its value is not a dict, the call fails with - ``path_conflict_non_dict_at:...`` and ``params`` is left unchanged for - that traversal (no silent overwrite of scalars/lists/etc.). + Contract for runtime config patches (``update_configs`` payload → merged into + ``config.params``): + + - If ``request`` or ``corpus`` is missing or not a ``dict``, this function is a no-op. + - If ``corpus.context`` is the empty string ``""``, the ``context`` key is **removed** + from ``corpus``. That is interpreted as "clear dialog / hot context" for the next + connection, not as "send an explicit empty string field" to the service. + + Limitation: callers cannot use this path to force the wire request to include + ``context: ""`` while keeping the key; only omission vs non-empty string is supported. """ - parts = [p for p in field_path.split(".") if p] - if not parts: - return False, "empty_field_path" - cur: Any = params - for key in parts[:-1]: - if not isinstance(cur, dict): - return False, f"path_not_dict_at:{key}" - if key not in cur: - nxt: dict[str, Any] = {} - cur[key] = nxt - cur = nxt - elif isinstance(cur[key], dict): - cur = cur[key] - else: - return False, f"path_conflict_non_dict_at:{key}" - last = parts[-1] - if not isinstance(cur, dict): - return False, "parent_not_dict" - if value == "": - cur.pop(last, None) - else: - cur[last] = value - return True, "" + request = params.get("request") + if not isinstance(request, dict): + return + corpus = request.get("corpus") + if not isinstance(corpus, dict): + return + if corpus.get("context") == "": + corpus.pop("context", None) @dataclass @@ -1109,28 +1109,34 @@ async def _run_update_configs( ) -> tuple[bool, str]: if not self.config: return False, "config not loaded" - if payload.get("schema_version") != 1: - return False, "unsupported schema_version" - field_path = payload.get("field_path") - if not isinstance(field_path, str) or not field_path.strip(): - return False, "field_path must be a non-empty string" - raw_val = payload.get("value", "") - if raw_val is not None and not isinstance(raw_val, str): - return False, "value must be a string when present" - value = "" if raw_val is None else raw_val + + params_patch = payload.get("params") + url_val = payload.get("url") + dump_present = "dump" in payload + + has_params = isinstance(params_patch, dict) and bool(params_patch) + has_url = isinstance(url_val, str) and bool(url_val.strip()) params = self.config.params if not isinstance(params, dict): return False, "params_not_dict" - ok, err = _apply_string_field_under_params( - params, field_path.strip(), value - ) - if not ok: - return False, err + + if has_url: + params["api_url"] = cast(str, url_val).strip() + + if dump_present: + self.config.dump = bool(payload["dump"]) + + keys_for_log: list[str] = [] + if has_params: + patch = cast(dict[str, Any], params_patch) + keys_for_log = list(patch.keys()) + _deep_merge_dict(params, patch) + _strip_empty_request_corpus_context(params) self.ten_env.log_info( - f"update_configs: params[{field_path!r}] updated, reconnecting " - f"(value_len={len(value)})", + "update_configs: merged payload keys into config, reconnecting " + f"(params_keys={keys_for_log})", category=LOG_CATEGORY_KEY_POINT, ) diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json index e8ffc456aa..e850623f12 100644 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/manifest.json @@ -1,7 +1,7 @@ { "type": "extension", "name": "bytedance_llm_based_asr", - "version": "0.4.1", + "version": "0.4.3", "dependencies": [ { "type": "system", @@ -43,45 +43,7 @@ } } } - }, - "cmd_in": [ - { - "name": "update_configs", - "property": { - "properties": { - "schema_version": { - "type": "int32" - }, - "field_path": { - "type": "string" - }, - "value": { - "type": "string" - }, - "metadata": { - "type": "object", - "properties": {} - } - }, - "required": [ - "schema_version", - "field_path" - ] - }, - "result": { - "property": { - "properties": { - "code": { - "type": "int64" - }, - "message": { - "type": "string" - } - } - } - } - } - ] + } }, "package": { "include": [ diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py deleted file mode 100644 index 84d0081e96..0000000000 --- a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_apply_string_field_under_params.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Tests for _apply_string_field_under_params nested path writes.""" - -import os -import sys -import types - -extension_dir = os.path.join(os.path.dirname(__file__), "..") -sys.path.insert(0, extension_dir) - -package = types.ModuleType("bytedance_llm_based_asr") -package.__path__ = [extension_dir] -sys.modules["bytedance_llm_based_asr"] = package - -from bytedance_llm_based_asr import config as config_module # noqa: E402 - -sys.modules["bytedance_llm_based_asr.config"] = config_module - -from bytedance_llm_based_asr import extension as extension_module # noqa: E402 - -sys.modules["bytedance_llm_based_asr.extension"] = extension_module - -from bytedance_llm_based_asr.extension import ( # noqa: E402 - _apply_string_field_under_params, -) - - -def test_creates_missing_intermediate_dicts(): - params: dict = {} - ok, err = _apply_string_field_under_params( - params, "request.corpus.context", '{"x":1}' - ) - assert ok and err == "" - assert params["request"]["corpus"]["context"] == '{"x":1}' - - -def test_conflict_when_intermediate_is_non_dict_no_mutation(): - params = {"request": "not-a-dict"} - ok, err = _apply_string_field_under_params( - params, "request.corpus.context", "v" - ) - assert not ok - assert err.startswith("path_conflict_non_dict_at:") - assert params == {"request": "not-a-dict"} - - -def test_clear_terminal_key(): - params = {"a": {"b": "x"}} - ok, err = _apply_string_field_under_params(params, "a.b", "") - assert ok and err == "" - assert params == {"a": {}} diff --git a/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_omit_empty_merge_and_helpers.py b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_omit_empty_merge_and_helpers.py new file mode 100644 index 0000000000..f399280e3c --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/bytedance_llm_based_asr/tests/test_omit_empty_merge_and_helpers.py @@ -0,0 +1,271 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import copy +import os +import sys +import types +from unittest.mock import AsyncMock, MagicMock + +import pytest + +extension_dir = os.path.join(os.path.dirname(__file__), "..") +sys.path.insert(0, extension_dir) + +package = types.ModuleType("bytedance_llm_based_asr") +package.__path__ = [extension_dir] +sys.modules["bytedance_llm_based_asr"] = package + +from bytedance_llm_based_asr import config as config_module + +sys.modules["bytedance_llm_based_asr.config"] = config_module + +from bytedance_llm_based_asr import extension as extension_module + +sys.modules["bytedance_llm_based_asr.extension"] = extension_module + +from bytedance_llm_based_asr.config import BytedanceASRLLMConfig +from bytedance_llm_based_asr.extension import ( + BytedanceASRLLMExtension, + _deep_merge_dict, + _strip_empty_request_corpus_context, +) + + +def _minimal_config( + *, omit_empty_text_results: bool | None = None +) -> BytedanceASRLLMConfig: + params: dict = { + "audio": {"rate": 16000}, + "request": {"model_name": "bigmodel"}, + } + if omit_empty_text_results is not None: + params["omit_empty_text_results"] = omit_empty_text_results + return BytedanceASRLLMConfig.model_validate({"params": params}) + + +def test_get_omit_empty_text_results_defaults_false() -> None: + cfg = _minimal_config() + assert cfg.get_omit_empty_text_results() is False + + +def test_get_omit_empty_text_results_true_when_set() -> None: + cfg = _minimal_config(omit_empty_text_results=True) + assert cfg.get_omit_empty_text_results() is True + + +def test_deep_merge_dict_nested() -> None: + target: dict = {"a": {"b": 1}, "c": 2} + _deep_merge_dict(target, {"a": {"b": 2, "d": 3}}) + assert target == {"a": {"b": 2, "d": 3}, "c": 2} + + +def test_deep_merge_dict_non_dict_patch_replaces_dict() -> None: + target: dict = {"a": {"b": 1}} + _deep_merge_dict(target, {"a": "scalar"}) + assert target == {"a": "scalar"} + + +def test_deep_merge_dict_dict_patch_replaces_non_dict() -> None: + target: dict = {"a": 1} + _deep_merge_dict(target, {"a": {"nested": True}}) + assert target == {"a": {"nested": True}} + + +def test_strip_empty_context_removes_empty_string() -> None: + params = { + "request": {"corpus": {"context": "", "other": 1}}, + } + _strip_empty_request_corpus_context(params) + assert params["request"]["corpus"] == {"other": 1} + + +def test_strip_empty_context_keeps_non_empty_string() -> None: + params = {"request": {"corpus": {"context": "hello"}}} + _strip_empty_request_corpus_context(params) + assert params["request"]["corpus"]["context"] == "hello" + + +def test_strip_empty_context_missing_context_no_op() -> None: + params = {"request": {"corpus": {"other": 1}}} + before = copy.deepcopy(params) + _strip_empty_request_corpus_context(params) + assert params == before + + +def test_strip_empty_context_request_not_dict() -> None: + params = {"request": "bad"} + before = copy.deepcopy(params) + _strip_empty_request_corpus_context(params) + assert params == before + + +def test_strip_empty_context_corpus_not_dict() -> None: + params = {"request": {"corpus": "bad"}} + before = copy.deepcopy(params) + _strip_empty_request_corpus_context(params) + assert params == before + + +@pytest.fixture +def mock_ten_env(): + env = AsyncMock() + env.log_info = MagicMock() + return env + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_omit_on_skips_empty( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config(omit_empty_text_results=True) + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text=" ", + is_final=True, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is False + ext.send_asr_result.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_omit_on_sends_non_empty( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config(omit_empty_text_results=True) + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text="hi", + is_final=True, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is True + ext.send_asr_result.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_omit_off_sends_whitespace_only( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config(omit_empty_text_results=False) + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text=" ", + is_final=False, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is True + ext.send_asr_result.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_send_asr_result_from_text_no_config_sends(mock_ten_env) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = None + ext.send_asr_result = AsyncMock() + + sent = await ext._send_asr_result_from_text( + text="", + is_final=True, + start_ms=0, + duration_ms=1, + language="zh-CN", + metadata={}, + ) + assert sent is True + ext.send_asr_result.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_run_update_configs_returns_false_when_config_missing( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = None + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, msg = await ext._run_update_configs({"params": {"x": 1}}) + assert ok is False + assert msg == "config not loaded" + ext.stop_connection.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_run_update_configs_merges_and_strips_empty_context( + mock_ten_env, +) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config() + ext.config.params["request"]["corpus"] = {"context": "old"} + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, msg = await ext._run_update_configs( + { + "params": { + "request": { + "corpus": {"context": ""}, + "enable_nonstream": False, + } + } + } + ) + assert ok is True + assert msg == "" + req = ext.config.params["request"] + assert "context" not in req.get("corpus", {}) + assert req["enable_nonstream"] is False + ext.stop_connection.assert_awaited_once() + ext.start_connection.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_run_update_configs_sets_trimmed_api_url(mock_ten_env) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config() + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, _ = await ext._run_update_configs({"url": " wss://example.test/asr "}) + assert ok is True + assert ext.config.params["api_url"] == "wss://example.test/asr" + + +@pytest.mark.asyncio +async def test_run_update_configs_dump_flag(mock_ten_env) -> None: + ext = BytedanceASRLLMExtension("t") + ext.ten_env = mock_ten_env + ext.config = _minimal_config() + ext.config.dump = False + ext.stop_connection = AsyncMock() + ext.start_connection = AsyncMock() + + ok, _ = await ext._run_update_configs({"dump": True}) + assert ok is True + assert ext.config.dump is True