From 920218b7dd35afc8353c039388cddb900ade0944 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 24 Apr 2026 08:50:51 +0000 Subject: [PATCH 1/5] feat(ai-agents): add xai speech extensions --- ai_agents/.env.example | 6 +- .../agents/examples/voice-assistant/README.md | 44 +- .../voice-assistant/tenapp/manifest.json | 6 + .../voice-assistant/tenapp/property.json | 562 +++++++++++++++++- .../extension/xai_asr_python/README.md | 62 ++ .../extension/xai_asr_python/__init__.py | 1 + .../extension/xai_asr_python/addon.py | 15 + .../extension/xai_asr_python/config.py | 76 +++ .../extension/xai_asr_python/const.py | 3 + .../extension/xai_asr_python/extension.py | 382 ++++++++++++ .../extension/xai_asr_python/manifest.json | 72 +++ .../extension/xai_asr_python/property.json | 16 + .../extension/xai_asr_python/recognition.py | 180 ++++++ .../xai_asr_python/reconnect_manager.py | 121 ++++ .../extension/xai_asr_python/requirements.txt | 2 + .../xai_asr_python/tests/__init__.py | 0 .../extension/xai_asr_python/tests/bin/start | 21 + .../tests/configs/property_dump.json | 13 + .../tests/configs/property_en.json | 11 + .../tests/configs/property_en_hotwords.json | 11 + .../tests/configs/property_invalid.json | 13 + .../tests/configs/property_zh.json | 11 + .../xai_asr_python/tests/conftest.py | 106 ++++ .../extension/xai_asr_python/tests/mock.py | 5 + .../xai_asr_python/tests/test_asr_result.py | 79 +++ .../xai_asr_python/tests/test_dump.py | 40 ++ .../xai_asr_python/tests/test_error_check.py | 69 +++ .../xai_asr_python/tests/test_finalize.py | 57 ++ .../xai_asr_python/tests/test_metrics.py | 33 + .../xai_asr_python/tests/test_params.py | 64 ++ .../xai_asr_python/tests/test_reconnect.py | 74 +++ .../xai_asr_python/tests/test_vendor_error.py | 42 ++ .../extension/xai_tts_python/README.md | 91 +++ .../extension/xai_tts_python/__init__.py | 6 + .../extension/xai_tts_python/addon.py | 15 + .../extension/xai_tts_python/config.py | 80 +++ .../extension/xai_tts_python/extension.py | 384 ++++++++++++ .../extension/xai_tts_python/manifest.json | 77 +++ .../extension/xai_tts_python/property.json | 12 + .../extension/xai_tts_python/requirements.txt | 2 + .../xai_tts_python/tests/__init__.py | 5 + .../extension/xai_tts_python/tests/bin/start | 21 + .../tests/configs/property.json | 12 + .../property_basic_audio_setting1.json | 11 + .../property_basic_audio_setting2.json | 11 + .../tests/configs/property_dump.json | 11 + .../tests/configs/property_invalid.json | 5 + .../tests/configs/property_miss_required.json | 5 + .../configs/property_subtitle_alignment.json | 12 + .../xai_tts_python/tests/conftest.py | 107 ++++ .../xai_tts_python/tests/test_basic.py | 314 ++++++++++ .../xai_tts_python/tests/test_error_msg.py | 166 ++++++ .../xai_tts_python/tests/test_metrics.py | 127 ++++ .../xai_tts_python/tests/test_params.py | 182 ++++++ .../xai_tts_python/tests/test_robustness.py | 313 ++++++++++ .../tests/test_state_machine.py | 454 ++++++++++++++ .../extension/xai_tts_python/xai_tts.py | 287 +++++++++ 57 files changed, 4900 insertions(+), 7 deletions(-) create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/README.md create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/addon.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/config.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/const.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/property.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/__init__.py create mode 100755 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_dump.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en_hotwords.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_invalid.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_zh.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/conftest.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/mock.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_asr_result.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_dump.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_error_check.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_metrics.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_params.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_vendor_error.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/README.md create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/__init__.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/addon.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/config.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/manifest.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/property.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/requirements.txt create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/__init__.py create mode 100755 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/bin/start create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting1.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting2.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_dump.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_invalid.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_miss_required.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/conftest.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py create mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/xai_tts.py diff --git a/ai_agents/.env.example b/ai_agents/.env.example index e4ec14de98..5e3a51a432 100644 --- a/ai_agents/.env.example +++ b/ai_agents/.env.example @@ -100,6 +100,10 @@ LITELLM_MODEL=gpt-4o-mini # Deepgram ASR key DEEPGRAM_API_KEY= +# Extension: xai_asr_python, xai_tts_python +# xAI Voice API key +XAI_API_KEY= + # Azure ASR AZURE_ASR_API_KEY= AZURE_ASR_REGION= @@ -223,4 +227,4 @@ ALIYUN_ANALYTICDB_NAMESPACE_PASSWORD= # Sarvam SARVAM_ASR_KEY= -SARVAM_TTS_KEY= \ No newline at end of file +SARVAM_TTS_KEY= diff --git a/ai_agents/agents/examples/voice-assistant/README.md b/ai_agents/agents/examples/voice-assistant/README.md index 906643ee82..7deb7baf81 100644 --- a/ai_agents/agents/examples/voice-assistant/README.md +++ b/ai_agents/agents/examples/voice-assistant/README.md @@ -1,6 +1,6 @@ # Voice Assistant -A comprehensive voice assistant with real-time conversation capabilities using Agora RTC, Deepgram STT, OpenAI LLM, and ElevenLabs TTS. +A configurable voice assistant with real-time conversation capabilities using Agora RTC, interchangeable STT/TTS providers, and an OpenAI-compatible LLM. ## Features @@ -13,14 +13,25 @@ A comprehensive voice assistant with real-time conversation capabilities using A 1. **Agora Account**: Get credentials from [Agora Console](https://console.agora.io/) - `AGORA_APP_ID` - Your Agora App ID (required) -2. **Deepgram Account**: Get credentials from [Deepgram Console](https://console.deepgram.com/) - - `DEEPGRAM_API_KEY` - Your Deepgram API key (required) +2. **STT Provider**: choose the graph you want to run + - `DEEPGRAM_API_KEY` for the default `voice_assistant` graph + - `XAI_API_KEY` for `voice_assistant_xai_asr` or `voice_assistant_xai_full` 3. **OpenAI Account**: Get credentials from [OpenAI Platform](https://platform.openai.com/) - `OPENAI_API_KEY` - Your OpenAI API key (required) -4. **ElevenLabs Account**: Get credentials from [ElevenLabs](https://elevenlabs.io/) +4. **TTS Provider**: choose the graph you want to run + - `ELEVENLABS_TTS_KEY` for the default `voice_assistant` graph or `voice_assistant_xai_asr` + - `XAI_API_KEY` for `voice_assistant_xai_tts` or `voice_assistant_xai_full` + +### Provider-specific keys + +- **Deepgram Account**: Get credentials from [Deepgram Console](https://console.deepgram.com/) + - `DEEPGRAM_API_KEY` - Your Deepgram API key (required) +- **ElevenLabs Account**: Get credentials from [ElevenLabs](https://elevenlabs.io/) - `ELEVENLABS_TTS_KEY` - Your ElevenLabs API key (required) +- **xAI Account**: Get credentials from [xAI Console](https://console.x.ai/) + - `XAI_API_KEY` - Your xAI Voice API key (required for xAI STT/TTS graphs) ### Optional Environment Variables @@ -51,6 +62,9 @@ OPENAI_PROXY_URL=your_proxy_url_here # ElevenLabs (required for text-to-speech) ELEVENLABS_TTS_KEY=your_elevenlabs_api_key_here +# xAI (required for xAI speech-to-text and/or text-to-speech graphs) +XAI_API_KEY=your_xai_api_key_here + # Optional WEATHERAPI_API_KEY=your_weather_api_key_here ``` @@ -71,7 +85,7 @@ cd agents/examples/voice-assistant task run ``` -The voice assistant starts with all capabilities enabled. +The stack starts the TEN app, API server, frontend, and TMAN Designer. ### 4. Access the Application @@ -79,6 +93,25 @@ The voice assistant starts with all capabilities enabled. - **API Server**: http://localhost:8080 - **TMAN Designer**: http://localhost:49483 +### 5. Choose a Graph + +The frontend reads the `graph` URL query parameter and matches it against +`tenapp/property.json` `predefined_graphs[].name`. + +Available graph names: + +- `voice_assistant` - Deepgram STT + OpenAI-compatible LLM + ElevenLabs TTS +- `voice_assistant_xai_asr` - xAI STT + OpenAI-compatible LLM + ElevenLabs TTS +- `voice_assistant_xai_tts` - Deepgram STT + OpenAI-compatible LLM + xAI TTS +- `voice_assistant_xai_full` - xAI STT + OpenAI-compatible LLM + xAI TTS + +Examples: + +```text +http://localhost:3000/?graph=voice_assistant_xai_full +https://ten-demo.agora.io/?graph=voice_assistant_xai_full +``` + ## Configuration The voice assistant is configured in `tenapp/property.json`: @@ -189,6 +222,7 @@ docker run --rm -it --env-file .env -p 8080:8080 -p 3000:3000 voice-assistant-ap - [Agora RTC Documentation](https://docs.agora.io/en/rtc/overview/product-overview) - [Deepgram API Documentation](https://developers.deepgram.com/) +- [xAI API Documentation](https://docs.x.ai/) - [OpenAI API Documentation](https://platform.openai.com/docs) - [ElevenLabs API Documentation](https://docs.elevenlabs.io/) - [TEN Framework Documentation](https://doc.theten.ai) diff --git a/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json b/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json index 020768c826..f682035f19 100644 --- a/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json +++ b/ai_agents/agents/examples/voice-assistant/tenapp/manifest.json @@ -66,6 +66,9 @@ { "path": "../../../ten_packages/extension/tencent_asr_python" }, + { + "path": "../../../ten_packages/extension/xai_asr_python" + }, { "path": "../../../ten_packages/extension/xfyun_asr_bigmodel_python" }, @@ -141,6 +144,9 @@ { "path": "../../../ten_packages/extension/tencent_tts_python" }, + { + "path": "../../../ten_packages/extension/xai_tts_python" + }, { "path": "../../../ten_packages/extension/message_collector2" }, diff --git a/ai_agents/agents/examples/voice-assistant/tenapp/property.json b/ai_agents/agents/examples/voice-assistant/tenapp/property.json index 270bfb77be..2f9f59a9b5 100644 --- a/ai_agents/agents/examples/voice-assistant/tenapp/property.json +++ b/ai_agents/agents/examples/voice-assistant/tenapp/property.json @@ -185,6 +185,566 @@ ] } }, + { + "name": "voice_assistant_xai_asr", + "auto_start": false, + "graph": { + "nodes": [ + { + "type": "extension", + "name": "agora_rtc", + "addon": "agora_rtc", + "extension_group": "default", + "property": { + "app_id": "${env:AGORA_APP_ID}", + "app_certificate": "${env:AGORA_APP_CERTIFICATE|}", + "channel": "ten_agent_test", + "stream_id": 1234, + "remote_stream_id": 123, + "subscribe_audio": true, + "publish_audio": true, + "publish_data": true, + "enable_agora_asr": false + } + }, + { + "type": "extension", + "name": "stt", + "addon": "xai_asr_python", + "extension_group": "stt", + "property": { + "params": { + "api_key": "${env:XAI_API_KEY}", + "language": "en", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300 + } + } + }, + { + "type": "extension", + "name": "llm", + "addon": "openai_llm2_python", + "extension_group": "chatgpt", + "property": { + "base_url": "https://api.openai.com/v1", + "api_key": "${env:OPENAI_API_KEY}", + "frequency_penalty": 0.9, + "model": "${env:OPENAI_MODEL}", + "max_tokens": 512, + "prompt": "", + "proxy_url": "${env:OPENAI_PROXY_URL|}", + "greeting": "TEN Agent connected. How can I help you today?", + "max_memory_length": 10 + } + }, + { + "type": "extension", + "name": "tts", + "addon": "elevenlabs_tts2_python", + "extension_group": "tts", + "property": { + "dump": false, + "dump_path": "./", + "params": { + "key": "${env:ELEVENLABS_TTS_KEY}", + "model_id": "eleven_multilingual_v2", + "voice_id": "pNInz6obpgDQGcFmaJgB", + "output_format": "pcm_16000" + } + } + }, + { + "type": "extension", + "name": "main_control", + "addon": "main_python", + "extension_group": "control", + "property": { + "greeting": "TEN Agent connected. How can I help you today?" + } + }, + { + "type": "extension", + "name": "message_collector", + "addon": "message_collector2", + "extension_group": "transcriber", + "property": {} + }, + { + "type": "extension", + "name": "weatherapi_tool_python", + "addon": "weatherapi_tool_python", + "extension_group": "default", + "property": { + "api_key": "${env:WEATHERAPI_API_KEY|}" + } + }, + { + "type": "extension", + "name": "streamid_adapter", + "addon": "streamid_adapter", + "property": {} + } + ], + "connections": [ + { + "extension": "main_control", + "cmd": [ + { + "names": [ + "on_user_joined", + "on_user_left" + ], + "source": [ + { + "extension": "agora_rtc" + } + ] + }, + { + "names": [ + "tool_register" + ], + "source": [ + { + "extension": "weatherapi_tool_python" + } + ] + } + ], + "data": [ + { + "name": "asr_result", + "source": [ + { + "extension": "stt" + } + ] + } + ] + }, + { + "extension": "agora_rtc", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "streamid_adapter" + } + ] + }, + { + "name": "pcm_frame", + "source": [ + { + "extension": "tts" + } + ] + } + ], + "data": [ + { + "name": "data", + "source": [ + { + "extension": "message_collector" + } + ] + } + ] + }, + { + "extension": "streamid_adapter", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "stt" + } + ] + } + ] + } + ] + } + }, + { + "name": "voice_assistant_xai_tts", + "auto_start": false, + "graph": { + "nodes": [ + { + "type": "extension", + "name": "agora_rtc", + "addon": "agora_rtc", + "extension_group": "default", + "property": { + "app_id": "${env:AGORA_APP_ID}", + "app_certificate": "${env:AGORA_APP_CERTIFICATE|}", + "channel": "ten_agent_test", + "stream_id": 1234, + "remote_stream_id": 123, + "subscribe_audio": true, + "publish_audio": true, + "publish_data": true, + "enable_agora_asr": false + } + }, + { + "type": "extension", + "name": "stt", + "addon": "deepgram_asr_python", + "extension_group": "stt", + "property": { + "params": { + "api_key": "${env:DEEPGRAM_API_KEY}", + "language": "en-US", + "model": "nova-3" + } + } + }, + { + "type": "extension", + "name": "llm", + "addon": "openai_llm2_python", + "extension_group": "chatgpt", + "property": { + "base_url": "https://api.openai.com/v1", + "api_key": "${env:OPENAI_API_KEY}", + "frequency_penalty": 0.9, + "model": "${env:OPENAI_MODEL}", + "max_tokens": 512, + "prompt": "", + "proxy_url": "${env:OPENAI_PROXY_URL|}", + "greeting": "TEN Agent connected. How can I help you today?", + "max_memory_length": 10 + } + }, + { + "type": "extension", + "name": "tts", + "addon": "xai_tts_python", + "extension_group": "tts", + "property": { + "dump": false, + "dump_path": "./", + "params": { + "api_key": "${env:XAI_API_KEY}", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000 + } + } + }, + { + "type": "extension", + "name": "main_control", + "addon": "main_python", + "extension_group": "control", + "property": { + "greeting": "TEN Agent connected. How can I help you today?" + } + }, + { + "type": "extension", + "name": "message_collector", + "addon": "message_collector2", + "extension_group": "transcriber", + "property": {} + }, + { + "type": "extension", + "name": "weatherapi_tool_python", + "addon": "weatherapi_tool_python", + "extension_group": "default", + "property": { + "api_key": "${env:WEATHERAPI_API_KEY|}" + } + }, + { + "type": "extension", + "name": "streamid_adapter", + "addon": "streamid_adapter", + "property": {} + } + ], + "connections": [ + { + "extension": "main_control", + "cmd": [ + { + "names": [ + "on_user_joined", + "on_user_left" + ], + "source": [ + { + "extension": "agora_rtc" + } + ] + }, + { + "names": [ + "tool_register" + ], + "source": [ + { + "extension": "weatherapi_tool_python" + } + ] + } + ], + "data": [ + { + "name": "asr_result", + "source": [ + { + "extension": "stt" + } + ] + } + ] + }, + { + "extension": "agora_rtc", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "streamid_adapter" + } + ] + }, + { + "name": "pcm_frame", + "source": [ + { + "extension": "tts" + } + ] + } + ], + "data": [ + { + "name": "data", + "source": [ + { + "extension": "message_collector" + } + ] + } + ] + }, + { + "extension": "streamid_adapter", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "stt" + } + ] + } + ] + } + ] + } + }, + { + "name": "voice_assistant_xai_full", + "auto_start": false, + "graph": { + "nodes": [ + { + "type": "extension", + "name": "agora_rtc", + "addon": "agora_rtc", + "extension_group": "default", + "property": { + "app_id": "${env:AGORA_APP_ID}", + "app_certificate": "${env:AGORA_APP_CERTIFICATE|}", + "channel": "ten_agent_test", + "stream_id": 1234, + "remote_stream_id": 123, + "subscribe_audio": true, + "publish_audio": true, + "publish_data": true, + "enable_agora_asr": false + } + }, + { + "type": "extension", + "name": "stt", + "addon": "xai_asr_python", + "extension_group": "stt", + "property": { + "params": { + "api_key": "${env:XAI_API_KEY}", + "language": "en", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300 + } + } + }, + { + "type": "extension", + "name": "llm", + "addon": "openai_llm2_python", + "extension_group": "chatgpt", + "property": { + "base_url": "https://api.openai.com/v1", + "api_key": "${env:OPENAI_API_KEY}", + "frequency_penalty": 0.9, + "model": "${env:OPENAI_MODEL}", + "max_tokens": 512, + "prompt": "", + "proxy_url": "${env:OPENAI_PROXY_URL|}", + "greeting": "TEN Agent connected. How can I help you today?", + "max_memory_length": 10 + } + }, + { + "type": "extension", + "name": "tts", + "addon": "xai_tts_python", + "extension_group": "tts", + "property": { + "dump": false, + "dump_path": "./", + "params": { + "api_key": "${env:XAI_API_KEY}", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000 + } + } + }, + { + "type": "extension", + "name": "main_control", + "addon": "main_python", + "extension_group": "control", + "property": { + "greeting": "TEN Agent connected. How can I help you today?" + } + }, + { + "type": "extension", + "name": "message_collector", + "addon": "message_collector2", + "extension_group": "transcriber", + "property": {} + }, + { + "type": "extension", + "name": "weatherapi_tool_python", + "addon": "weatherapi_tool_python", + "extension_group": "default", + "property": { + "api_key": "${env:WEATHERAPI_API_KEY|}" + } + }, + { + "type": "extension", + "name": "streamid_adapter", + "addon": "streamid_adapter", + "property": {} + } + ], + "connections": [ + { + "extension": "main_control", + "cmd": [ + { + "names": [ + "on_user_joined", + "on_user_left" + ], + "source": [ + { + "extension": "agora_rtc" + } + ] + }, + { + "names": [ + "tool_register" + ], + "source": [ + { + "extension": "weatherapi_tool_python" + } + ] + } + ], + "data": [ + { + "name": "asr_result", + "source": [ + { + "extension": "stt" + } + ] + } + ] + }, + { + "extension": "agora_rtc", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "streamid_adapter" + } + ] + }, + { + "name": "pcm_frame", + "source": [ + { + "extension": "tts" + } + ] + } + ], + "data": [ + { + "name": "data", + "source": [ + { + "extension": "message_collector" + } + ] + } + ] + }, + { + "extension": "streamid_adapter", + "audio_frame": [ + { + "name": "pcm_frame", + "dest": [ + { + "extension": "stt" + } + ] + } + ] + } + ] + } + }, { "name": "voice_assistant_oracle", "auto_start": false, @@ -410,4 +970,4 @@ ] } } -} \ No newline at end of file +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/README.md b/ai_agents/agents/ten_packages/extension/xai_asr_python/README.md new file mode 100644 index 0000000000..a80c0bb81c --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/README.md @@ -0,0 +1,62 @@ +# xAI ASR Extension + +A TEN Framework extension that provides streaming Speech-to-Text (STT / ASR) +capabilities using xAI's WebSocket speech API. + +## Features + +- Real-time ASR over WebSocket +- Raw PCM / mu-law / A-law input +- Partial and final transcript handling +- Explicit finalize via `audio.done` +- Reconnect support with bounded retry attempts +- Audio dump support for debugging + +## Configuration + +All configuration is supplied through the `params` object. + +```json +{ + "dump": false, + "dump_path": "/tmp", + "finalize_timeout_ms": 2000, + "params": { + "api_key": "${env:XAI_API_KEY}", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300, + "language": "en", + "diarize": false, + "multichannel": false, + "channels": 1 + } +} +``` + +## Key Properties + +- `params.api_key`: xAI API key +- `params.base_url`: WebSocket endpoint +- `params.sample_rate`: input sample rate +- `params.encoding`: `pcm`, `mulaw`, or `alaw` +- `params.language`: formatting language code +- `params.interim_results`: emit partial transcripts +- `params.endpointing`: silence duration before utterance finalization +- `dump`: enable PCM dump output +- `dump_path`: dump directory +- `finalize_timeout_ms`: wait time for `transcript.done` + +## Running Tests + +```bash +cd xai_asr_python +tman -y install --standalone +./tests/bin/start +``` + +## Environment Variables + +- `XAI_API_KEY` - Your xAI API key diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py new file mode 100644 index 0000000000..f3c731cdd5 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py @@ -0,0 +1 @@ +from . import addon diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/addon.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/addon.py new file mode 100644 index 0000000000..780e2664f7 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/addon.py @@ -0,0 +1,15 @@ +from ten_runtime import ( + Addon, + register_addon_as_extension, + TenEnv, +) + +from .extension import XAIASRExtension + + +@register_addon_as_extension("xai_asr_python") +class XAIASRExtensionAddon(Addon): + def on_create_instance(self, ten: TenEnv, addon_name: str, context) -> None: + + ten.log_info("on_create_instance") + ten.on_create_instance_done(XAIASRExtension(addon_name), context) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/config.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/config.py new file mode 100644 index 0000000000..1c72232e60 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/config.py @@ -0,0 +1,76 @@ +from typing import Any + +from pydantic import BaseModel, Field +from ten_ai_base.utils import encrypt + + +XAI_DEFAULT_PARAMS = { + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": True, + "endpointing": 300, + "language": "en", + "diarize": False, + "multichannel": False, + "channels": 1, +} + + +class XAIASRConfig(BaseModel): + dump: bool = False + dump_path: str = "/tmp" + finalize_timeout_ms: int = 2000 + params: dict[str, Any] = Field(default_factory=dict) + + def apply_defaults(self) -> None: + params = self.params if isinstance(self.params, dict) else {} + for key, value in XAI_DEFAULT_PARAMS.items(): + params.setdefault(key, value) + self.params = params + + def validate(self) -> None: + if not self.params.get("api_key"): + raise ValueError("xAI API key is required") + if self.params.get("sample_rate") not in { + 8000, + 16000, + 22050, + 24000, + 44100, + 48000, + }: + raise ValueError( + f"Unsupported sample_rate: {self.params.get('sample_rate')}" + ) + if self.params.get("encoding") not in {"pcm", "mulaw", "alaw"}: + raise ValueError( + f"Unsupported encoding: {self.params.get('encoding')}" + ) + + def to_json(self, sensitive_handling: bool = False) -> str: + config_dict = self.model_dump() + if sensitive_handling and config_dict["params"]: + api_key = config_dict["params"].get("api_key") + if api_key: + config_dict["params"]["api_key"] = encrypt(api_key) + return str(config_dict) + + @property + def normalized_language(self) -> str: + language_map = { + "zh": "zh-CN", + "en": "en-US", + "ja": "ja-JP", + "ko": "ko-KR", + "de": "de-DE", + "fr": "fr-FR", + "ru": "ru-RU", + "es": "es-ES", + "pt": "pt-PT", + "it": "it-IT", + "hi": "hi-IN", + "ar": "ar-AE", + } + language_code = (self.params or {}).get("language", "") or "" + return language_map.get(language_code, language_code) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/const.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/const.py new file mode 100644 index 0000000000..172a418483 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/const.py @@ -0,0 +1,3 @@ +DUMP_FILE_NAME = "xai_asr_in.pcm" +MODULE_NAME_ASR = "asr" +TIMEOUT_CODE = 10105 diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py new file mode 100644 index 0000000000..051c10a95e --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py @@ -0,0 +1,382 @@ +import copy +import os +import asyncio +import json +from datetime import datetime +from typing import Any +from uuid import uuid4 + +from typing_extensions import override + +from ten_ai_base.asr import ( + ASRBufferConfig, + ASRBufferConfigModeKeep, + ASRResult, + AsyncASRBaseExtension, +) +from ten_ai_base.struct import ASRWord +from ten_ai_base.const import LOG_CATEGORY_KEY_POINT, LOG_CATEGORY_VENDOR +from ten_ai_base.dumper import Dumper +from ten_ai_base.message import ( + ModuleError, + ModuleErrorCode, + ModuleErrorVendorInfo, +) +from ten_runtime import AsyncTenEnv, AudioFrame + +from .config import XAIASRConfig +from .const import DUMP_FILE_NAME, MODULE_NAME_ASR +from .recognition import XAIASRRecognition, XAIASRRecognitionCallback +from .reconnect_manager import ReconnectManager + + +class XAIASRExtension(AsyncASRBaseExtension, XAIASRRecognitionCallback): + def __init__(self, name: str): + super().__init__(name) + self.recognition: XAIASRRecognition | None = None + self.config: XAIASRConfig | None = None + self.audio_dumper: Dumper | None = None + self.sent_user_audio_duration_ms_before_last_reset = 0 + self.last_finalize_timestamp = 0 + self.reconnect_manager: ReconnectManager | None = None + self._stop_requested = False + self._close_expected = False + self.connection_start_timestamp = 0 + + @override + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) + if self.audio_dumper: + await self.audio_dumper.stop() + self.audio_dumper = None + + @override + def vendor(self) -> str: + return "xai" + + @override + async def on_init(self, ten_env: AsyncTenEnv) -> None: + await super().on_init(ten_env) + self.reconnect_manager = ReconnectManager(logger=ten_env) + config_json, _ = await ten_env.get_property_to_json("") + try: + self.config = XAIASRConfig.model_validate_json(config_json) + self.config.apply_defaults() + self.config.validate() + ten_env.log_info( + f"config: {self.config.to_json(sensitive_handling=True)}", + category=LOG_CATEGORY_KEY_POINT, + ) + if self.config.dump: + dump_file_path = os.path.join( + self.config.dump_path, DUMP_FILE_NAME + ) + self.audio_dumper = Dumper(dump_file_path) + await self.audio_dumper.start() + except Exception as e: + ten_env.log_error(f"Invalid xAI config: {e}") + self.config = XAIASRConfig.model_validate_json("{}") + await self.send_asr_error( + ModuleError( + module=MODULE_NAME_ASR, + code=ModuleErrorCode.FATAL_ERROR.value, + message=str(e), + ), + ) + + @override + async def start_connection(self) -> None: + assert self.config is not None + api_key = self.config.params.get("api_key", "") + if not api_key or str(api_key).strip() == "": + await self.send_asr_error( + ModuleError( + module=MODULE_NAME_ASR, + code=ModuleErrorCode.FATAL_ERROR.value, + message="xAI API key is required but missing or empty", + ) + ) + return + self._stop_requested = False + self._close_expected = False + try: + await self._connect_recognition() + except Exception as e: + fatal = self._is_fatal_connection_error(str(e)) + self.ten_env.log_error(f"Failed to start xAI STT connection: {e}") + await self.send_asr_error( + ModuleError( + module=MODULE_NAME_ASR, + code=( + ModuleErrorCode.FATAL_ERROR.value + if fatal + else ModuleErrorCode.NON_FATAL_ERROR.value + ), + message=str(e), + ), + ModuleErrorVendorInfo( + vendor=self.vendor(), + code="connect_failed", + message=str(e), + ), + ) + if not fatal: + await self._handle_reconnect() + + async def _connect_recognition(self) -> None: + assert self.config is not None + api_key = self.config.params.get("api_key", "") + if self.is_connected(): + await self.stop_connection() + self.connection_start_timestamp = int( + datetime.now().timestamp() * 1000 + ) + self.recognition = XAIASRRecognition( + api_key=api_key, + audio_timeline=self.audio_timeline, + ten_env=self.ten_env, + config=self.config.params, + callback=self, + ) + await self.recognition.start(timeout=10) + + @staticmethod + def _is_fatal_connection_error(error_message: str) -> bool: + normalized = error_message.lower() + return any( + token in normalized + for token in ("401", "403", "unauthorized", "forbidden", "api key") + ) + + @override + async def finalize(self, _session_id: str | None) -> None: + assert self.config is not None + self.last_finalize_timestamp = int(datetime.now().timestamp() * 1000) + self._close_expected = True + if self.recognition: + await self.recognition.send_audio_done() + payload = await self.recognition.wait_for_done( + self.config.finalize_timeout_ms + ) + if payload and payload.get("text"): + await self._emit_asr_result(payload, final=True, locked=False) + elif not self.recognition.done_event.is_set(): + self._close_expected = False + await self._finalize_end() + + async def _finalize_end(self) -> None: + if self.last_finalize_timestamp != 0: + self.last_finalize_timestamp = 0 + await self.send_asr_finalize_end() + + @override + async def stop_connection(self) -> None: + self._stop_requested = True + if self.recognition: + await self.recognition.close() + self.recognition = None + self._stop_requested = False + self._close_expected = False + + @override + def is_connected(self) -> bool: + return self.recognition is not None and self.recognition.is_connected() + + @override + def buffer_strategy(self) -> ASRBufferConfig: + return ASRBufferConfigModeKeep(byte_limit=1024 * 1024 * 10) + + @override + def input_audio_sample_rate(self) -> int: + assert self.config is not None + return int(self.config.params.get("sample_rate", 16000)) + + @override + async def send_audio( + self, frame: AudioFrame, _session_id: str | None + ) -> bool: + if self.recognition is None or not self.is_connected(): + return False + try: + buf = frame.lock_buf() + audio_data = bytes(buf) + if self.audio_dumper: + await self.audio_dumper.push_bytes(audio_data) + await self.recognition.send_audio_frame(audio_data) + frame.unlock_buf(buf) + return True + except Exception as e: + self.ten_env.log_error(f"Error sending audio to xAI STT: {e}") + try: + frame.unlock_buf(buf) + except Exception: + pass + return False + + @override + async def on_open(self) -> None: + connection_delay_ms = ( + int(datetime.now().timestamp() * 1000) + - self.connection_start_timestamp + ) + self.ten_env.log_info( + "vendor_status_changed: on_open", + category=LOG_CATEGORY_VENDOR, + ) + await self.send_connect_delay_metrics(connection_delay_ms) + if self.reconnect_manager: + self.reconnect_manager.mark_connection_successful() + self.sent_user_audio_duration_ms_before_last_reset += ( + self.audio_timeline.get_total_user_audio_duration() + ) + self.audio_timeline.reset() + await self._flush_buffered_audio_frames() + + async def _flush_buffered_audio_frames(self) -> None: + while True: + try: + buffered_frame = self.buffered_frames.get_nowait() + except asyncio.QueueEmpty: + self.buffered_frames_size = 0 + return + + metadata, _ = buffered_frame.get_property_to_json("metadata") + if metadata: + try: + metadata_json = copy.deepcopy(json.loads(metadata)) + self.metadata = metadata_json + self.session_id = metadata_json.get( + "session_id", self.session_id + ) + except Exception: + pass + + await self.send_audio(buffered_frame, self.session_id) + + async def _emit_asr_result( + self, message_data: dict[str, Any], final: bool, locked: bool + ) -> None: + assert self.config is not None + text = message_data.get("text", "") + if not text and not final: + return + start_ms = int((message_data.get("start", 0) or 0) * 1000) + duration_ms = int((message_data.get("duration", 0) or 0) * 1000) + actual_start_ms = int( + self.audio_timeline.get_audio_duration_before_time(start_ms) + + self.sent_user_audio_duration_ms_before_last_reset + ) + base_metadata = ( + copy.deepcopy(self.metadata) if self.metadata is not None else {} + ) + session_id = base_metadata.pop("session_id", None) + metadata: dict[str, Any] = {} + if session_id is not None: + metadata["session_id"] = session_id + metadata["asr_info"] = { + **base_metadata, + "vendor": self.vendor(), + "locked": locked, + } + words = [] + for word in message_data.get("words", []) or []: + word_start_ms = int((word.get("start", 0) or 0) * 1000) + word_end_ms = int((word.get("end", 0) or 0) * 1000) + actual_word_start_ms = int( + self.audio_timeline.get_audio_duration_before_time( + word_start_ms + ) + + self.sent_user_audio_duration_ms_before_last_reset + ) + words.append( + ASRWord( + word=word.get("text", ""), + start_ms=actual_word_start_ms, + duration_ms=max(0, word_end_ms - word_start_ms), + stable=locked or final, + ) + ) + asr_result = ASRResult( + id=str(uuid4()), + text=text, + final=final, + start_ms=actual_start_ms, + duration_ms=duration_ms, + language=self.config.normalized_language, + words=words, + metadata=metadata, + ) + await self.send_asr_result(asr_result) + + @override + async def on_partial_result(self, message_data: dict[str, Any]) -> None: + is_final = bool(message_data.get("is_final", False)) + speech_final = bool(message_data.get("speech_final", False)) + locked = is_final and not speech_final + await self._emit_asr_result( + message_data, + final=is_final and speech_final, + locked=locked, + ) + + @override + async def on_done(self, message_data: dict[str, Any]) -> None: + self.ten_env.log_debug(f"xAI transcript.done: {message_data}") + + @override + async def on_error( + self, error_msg: str, error_code: int | None = None + ) -> None: + self.ten_env.log_error( + f"vendor_error: code: {error_code}, reason: {error_msg}", + category=LOG_CATEGORY_VENDOR, + ) + fatal = "401" in error_msg or "Unauthorized" in error_msg + await self.send_asr_error( + ModuleError( + module=MODULE_NAME_ASR, + code=( + ModuleErrorCode.FATAL_ERROR.value + if fatal + else ModuleErrorCode.NON_FATAL_ERROR.value + ), + message=error_msg, + ), + ModuleErrorVendorInfo( + vendor=self.vendor(), + code=str(error_code) if error_code else "unknown", + message=error_msg, + ), + ) + + @override + async def on_close(self) -> None: + self.ten_env.log_info("vendor_status_changed: on_close") + self.recognition = None + if self._stop_requested: + return + if self._close_expected: + self._close_expected = False + return + await self._handle_reconnect() + + async def _handle_reconnect(self) -> None: + if not self.reconnect_manager: + self.ten_env.log_error("ReconnectManager not initialized") + return + while not self._stop_requested and self.reconnect_manager.can_retry(): + success = await self.reconnect_manager.handle_reconnect( + connection_func=self._connect_recognition, + error_handler=self.send_asr_error, + ) + + if success: + self.ten_env.log_debug( + "Reconnection attempt initiated successfully" + ) + return + + info = self.reconnect_manager.get_attempts_info() + self.ten_env.log_debug( + f"Reconnection attempt failed. Status: {info}" + ) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json new file mode 100644 index 0000000000..0788756370 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json @@ -0,0 +1,72 @@ +{ + "type": "extension", + "name": "xai_asr_python", + "version": "0.4.0", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.11" + }, + { + "type": "system", + "name": "ten_ai_base", + "version": "0.7" + } + ], + "api": { + "interface": [ + { + "import_uri": "../../system/ten_ai_base/api/asr-interface.json" + } + ], + "property": { + "properties": { + "params": { + "type": "object", + "properties": { + "base_url": { + "type": "string" + }, + "api_key": { + "type": "string" + }, + "language": { + "type": "string" + }, + "sample_rate": { + "type": "int32" + }, + "encoding": { + "type": "string" + }, + "interim_results": { + "type": "bool" + }, + "endpointing": { + "type": "int32" + }, + "diarize": { + "type": "bool" + }, + "multichannel": { + "type": "bool" + }, + "channels": { + "type": "int32" + } + } + } + } + } + }, + "package": { + "include": [ + "manifest.json", + "property.json", + "**.py", + "requirements.txt", + "docs/**" + ] + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/property.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/property.json new file mode 100644 index 0000000000..6e04a1950d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/property.json @@ -0,0 +1,16 @@ +{ + "dump": false, + "dump_path": "/tmp", + "params": { + "api_key": "${env:XAI_API_KEY|}", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300, + "language": "en", + "diarize": false, + "multichannel": false, + "channels": 1 + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py new file mode 100644 index 0000000000..b621d3976c --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py @@ -0,0 +1,180 @@ +import asyncio +import json +from abc import abstractmethod +from typing import Any +from urllib.parse import urlencode + +import websockets +from websockets.protocol import State + +from ten_ai_base.const import LOG_CATEGORY_VENDOR +from ten_ai_base.timeline import AudioTimeline +from ten_runtime import AsyncTenEnv + + +class XAIASRRecognitionCallback: + @abstractmethod + async def on_open(self): + pass + + @abstractmethod + async def on_partial_result(self, message_data: dict[str, Any]): + pass + + @abstractmethod + async def on_done(self, message_data: dict[str, Any]): + pass + + @abstractmethod + async def on_error(self, error_msg: str, error_code: int | None = None): + pass + + @abstractmethod + async def on_close(self): + pass + + +class XAIASRRecognition: + def __init__( + self, + api_key: str, + audio_timeline: AudioTimeline, + ten_env: AsyncTenEnv, + config: dict[str, Any], + callback: XAIASRRecognitionCallback, + ): + self.api_key = api_key + self.audio_timeline = audio_timeline + self.ten_env = ten_env + self.config = config or {} + self.callback = callback + self.websocket = None + self.is_started = False + self.ready_event = asyncio.Event() + self.done_event = asyncio.Event() + self.done_payload: dict[str, Any] | None = None + self._message_task: asyncio.Task | None = None + + def _build_url(self) -> str: + base_url = self.config.get("base_url", "wss://api.x.ai/v1/stt") + query_params: dict[str, Any] = {} + for key in ( + "sample_rate", + "encoding", + "interim_results", + "endpointing", + "language", + "multichannel", + "channels", + "diarize", + ): + value = self.config.get(key) + if value is not None: + query_params[key] = ( + str(value).lower() if isinstance(value, bool) else value + ) + return f"{base_url}?{urlencode(query_params, doseq=True)}" + + async def start(self, timeout: int = 10) -> None: + url = self._build_url() + self.ten_env.log_info(f"Connecting to xAI STT: {url}") + self.websocket = await websockets.connect( + url, + additional_headers={ + "Authorization": f"Bearer {self.api_key}" + }, + open_timeout=timeout, + ) + self.is_started = True + self.ready_event.clear() + self.done_event.clear() + self.done_payload = None + first_message = await asyncio.wait_for(self.websocket.recv(), timeout=timeout) + if isinstance(first_message, bytes): + raise RuntimeError("Unexpected binary message during xAI STT startup") + first_event = json.loads(first_message) + self.ten_env.log_debug( + f"vendor_result: startup: {first_message}", + category=LOG_CATEGORY_VENDOR, + ) + if first_event.get("type") != "transcript.created": + raise RuntimeError( + f"Unexpected xAI STT startup event: {first_event.get('type')}" + ) + self.ready_event.set() + await self.callback.on_open() + self._message_task = asyncio.create_task(self._message_handler()) + + async def _message_handler(self) -> None: + try: + async for message in self.websocket: + if isinstance(message, bytes): + continue + event = json.loads(message) + self.ten_env.log_debug( + f"vendor_result: on_recognized: {message}", + category=LOG_CATEGORY_VENDOR, + ) + event_type = event.get("type", "") + if event_type == "transcript.created": + self.ready_event.set() + await self.callback.on_open() + elif event_type == "transcript.partial": + await self.callback.on_partial_result(event) + elif event_type == "transcript.done": + self.done_payload = event + self.done_event.set() + await self.callback.on_done(event) + elif event_type == "error": + await self.callback.on_error( + str(event.get("message", "Unknown error")) + ) + except websockets.exceptions.ConnectionClosed as e: + self.ten_env.log_info(f"xAI STT websocket closed: {e}") + except Exception as e: + await self.callback.on_error(f"WebSocket message handler error: {e}") + finally: + self.is_started = False + await self.callback.on_close() + + async def send_audio_frame(self, audio_data: bytes) -> None: + if not self.websocket or not self.is_connected(): + raise RuntimeError("WebSocket not connected") + await self.ready_event.wait() + sample_rate = self.config.get("sample_rate", 16000) + duration_ms = int(len(audio_data) / (sample_rate / 1000 * 2)) + self.audio_timeline.add_user_audio(duration_ms) + await self.websocket.send(audio_data) + + async def send_audio_done(self) -> None: + if self.websocket and self.is_connected(): + await self.websocket.send(json.dumps({"type": "audio.done"})) + + async def wait_for_done(self, timeout_ms: int) -> dict[str, Any] | None: + try: + await asyncio.wait_for(self.done_event.wait(), timeout_ms / 1000) + except asyncio.TimeoutError: + return None + return self.done_payload + + async def close(self) -> None: + if self.websocket: + try: + if self.websocket.state == State.OPEN: + await self.websocket.close() + except Exception as e: + self.ten_env.log_info(f"Error closing websocket: {e}") + if self._message_task and not self._message_task.done(): + self._message_task.cancel() + try: + await self._message_task + except asyncio.CancelledError: + pass + self.is_started = False + + def is_connected(self) -> bool: + return ( + self.is_started + and self.websocket is not None + and self.websocket.state == State.OPEN + ) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py new file mode 100644 index 0000000000..caa826b4c5 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py @@ -0,0 +1,121 @@ +import asyncio +from typing import Callable, Awaitable, Optional +from ten_ai_base.message import ModuleError, ModuleErrorCode +from .const import MODULE_NAME_ASR + + +class ReconnectManager: + """ + Manages bounded reconnection attempts with exponential backoff. + + Features: + - Bounded retry attempts (`max_attempts`, default: 4) + - Exponential backoff strategy with maximum delay cap: 0.5s, 1s, 2s, 4s (capped) + - Maximum delay cap to prevent overwhelming the service provider (default: 4s) + - Automatic counter reset after successful connection + - Detailed logging for monitoring and debugging + """ + + def __init__( + self, + base_delay: float = 0.5, # 500 milliseconds + max_delay: float = 4.0, # 4 seconds maximum delay + max_attempts: int = 4, + logger=None, + ): + self.base_delay = base_delay + self.max_delay = max_delay + self.max_attempts = max_attempts + self.logger = logger + + # State tracking + self.attempts = 0 + self._connection_successful = False + + def _reset_counter(self): + """Reset reconnection counter""" + self.attempts = 0 + if self.logger: + self.logger.log_debug("Reconnect counter reset") + + def mark_connection_successful(self): + """Mark connection as successful and reset counter""" + self._connection_successful = True + self._reset_counter() + + def get_attempts_info(self) -> dict: + """Get current reconnection attempts information""" + return { + "current_attempts": self.attempts, + "max_attempts": self.max_attempts, + } + + def can_retry(self) -> bool: + return self.attempts < self.max_attempts + + async def handle_reconnect( + self, + connection_func: Callable[[], Awaitable[None]], + error_handler: Optional[ + Callable[[ModuleError], Awaitable[None]] + ] = None, + ) -> bool: + """ + Handle a single reconnection attempt with backoff delay. + + Args: + connection_func: Async function to establish connection + error_handler: Optional async function to handle errors + + Returns: + True if connection function executed successfully, False if attempt failed + Note: Actual connection success is determined by callback calling mark_connection_successful() + """ + self._connection_successful = False + self.attempts += 1 + + # Calculate exponential backoff delay with max limit: min(2^(attempts-1) * base_delay, max_delay) + delay = min( + self.base_delay * (2 ** (self.attempts - 1)), self.max_delay + ) + + if self.logger: + self.logger.log_warn( + f"Attempting reconnection #{self.attempts} " + f"after {delay:.2f} seconds delay..." + ) + + try: + await asyncio.sleep(delay) + await connection_func() + + # Connection function completed successfully + # Actual connection success will be determined by callback + if self.logger: + self.logger.log_debug( + f"Connection function completed for attempt #{self.attempts}" + ) + return True + + except Exception as e: + is_fatal = self.attempts >= self.max_attempts + if self.logger: + self.logger.log_error( + f"Reconnection attempt #{self.attempts} failed: {e}. " + f"{'Giving up.' if is_fatal else 'Will retry...'}" + ) + + if error_handler: + await error_handler( + ModuleError( + module=MODULE_NAME_ASR, + code=( + ModuleErrorCode.FATAL_ERROR.value + if is_fatal + else ModuleErrorCode.NON_FATAL_ERROR.value + ), + message=f"Reconnection attempt #{self.attempts} failed: {str(e)}", + ) + ) + + return False diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt b/ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt new file mode 100644 index 0000000000..0fd8ae23c6 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt @@ -0,0 +1,2 @@ +websockets>=15.0.1 +pydantic \ No newline at end of file diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/__init__.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start new file mode 100755 index 0000000000..f6a1cf283d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start @@ -0,0 +1,21 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." + +export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:.ten/app/ten_packages/system/ten_runtime_python/interface:.ten/app/ten_packages/system/ten_ai_base/interface:$PYTHONPATH + +# If the Python app imports some modules that are compiled with a different +# version of libstdc++ (ex: PyTorch), the Python app may encounter confusing +# errors. To solve this problem, we can preload the correct version of +# libstdc++. +# +# export LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 +# +# Another solution is to make sure the module 'ten_runtime_python' is imported +# _after_ the module that requires another version of libstdc++ is imported. +# +# Refer to https://github.com/pytorch/pytorch/issues/102360?from_wecom=1#issuecomment-1708989096 + +pytest -s tests/ "$@" \ No newline at end of file diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_dump.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_dump.json new file mode 100644 index 0000000000..f4aef6cd48 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_dump.json @@ -0,0 +1,13 @@ +{ + "dump": true, + "dump_path": "./tests/dump_output/", + "params": { + "api_key": "${env:XAI_API_KEY}", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300, + "language": "en" + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en.json new file mode 100644 index 0000000000..ba69e117e1 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en.json @@ -0,0 +1,11 @@ +{ + "params": { + "api_key": "${env:XAI_API_KEY}", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300, + "language": "en" + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en_hotwords.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en_hotwords.json new file mode 100644 index 0000000000..ba69e117e1 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_en_hotwords.json @@ -0,0 +1,11 @@ +{ + "params": { + "api_key": "${env:XAI_API_KEY}", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300, + "language": "en" + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_invalid.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_invalid.json new file mode 100644 index 0000000000..7f944d60bd --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_invalid.json @@ -0,0 +1,13 @@ +{ + "dump": false, + "dump_path": "/tmp", + "params": { + "api_key": "invalid", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300, + "language": "en" + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_zh.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_zh.json new file mode 100644 index 0000000000..7f8047d8d7 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/configs/property_zh.json @@ -0,0 +1,11 @@ +{ + "params": { + "api_key": "${env:XAI_API_KEY}", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "interim_results": true, + "endpointing": 300, + "language": "zh" + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/conftest.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/conftest.py new file mode 100644 index 0000000000..0af5300618 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/conftest.py @@ -0,0 +1,106 @@ +import sys +from pathlib import Path + +project_root = str(Path(__file__).resolve().parents[6]) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +import threading +from typing_extensions import override +import pytest +from ten_runtime import ( + App, + TenEnv, +) + + +class FakeApp(App): + def __init__(self): + super().__init__() + self.event: threading.Event | None = None + + # In the case of a fake app, we use `on_init` to allow the blocked testing + # fixture to continue execution, rather than using `on_configure`. The + # reason is that in the TEN runtime C core, the relationship between the + # addon manager and the (fake) app is bound after `on_configure_done` is + # called. So we only need to let the testing fixture continue execution + # after this action in the TEN runtime C core, and at the upper layer + # timing, the earliest point is within the `on_init()` function of the upper + # TEN app. Therefore, we release the testing fixture lock within the user + # layer's `on_init()` of the TEN app. + @override + def on_init(self, ten_env: TenEnv) -> None: + assert self.event + self.event.set() + + ten_env.on_init_done() + + @override + def on_configure(self, ten_env: TenEnv) -> None: + ten_env.init_property_from_json( + json.dumps( + { + "ten": { + "log": { + "handlers": [ + { + "matchers": [{"level": "debug"}], + "formatter": { + "type": "plain", + "colored": True, + }, + "emitter": { + "type": "console", + "config": {"stream": "stdout"}, + }, + } + ] + } + } + } + ), + ) + + ten_env.on_configure_done() + + +class FakeAppCtx: + def __init__(self, event: threading.Event): + self.fake_app: FakeApp | None = None + self.event = event + + +def run_fake_app(fake_app_ctx: FakeAppCtx): + app = FakeApp() + app.event = fake_app_ctx.event + fake_app_ctx.fake_app = app + app.run(False) + + +@pytest.fixture(scope="session", autouse=True) +def global_setup_and_teardown(): + event = threading.Event() + fake_app_ctx = FakeAppCtx(event) + + fake_app_thread = threading.Thread( + target=run_fake_app, args=(fake_app_ctx,) + ) + fake_app_thread.start() + + event.wait() + + assert fake_app_ctx.fake_app is not None + + # Yield control to the test; after the test execution is complete, continue + # with the teardown process. + yield + + # Teardown part. + fake_app_ctx.fake_app.close() + fake_app_thread.join() diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/mock.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/mock.py new file mode 100644 index 0000000000..da402faf43 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/mock.py @@ -0,0 +1,5 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_asr_result.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_asr_result.py new file mode 100644 index 0000000000..25ee65630d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_asr_result.py @@ -0,0 +1,79 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from xai_asr_python.config import XAIASRConfig +from xai_asr_python.extension import XAIASRExtension + + +class FakeTimeline: + def get_audio_duration_before_time(self, value: int) -> int: + return value + + +def test_asr_result_shape(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.audio_timeline = FakeTimeline() + extension.config = XAIASRConfig( + params={"api_key": "xai-test-key", "language": "en"} + ) + extension.config.apply_defaults() + extension.metadata = {"session_id": "session-123", "turn_id": 9} + extension.send_asr_result = AsyncMock() + + await extension._emit_asr_result( + { + "text": "hello world", + "start": 0.12, + "duration": 0.34, + "words": [ + {"text": "hello", "start": 0.12, "end": 0.2}, + {"text": "world", "start": 0.2, "end": 0.34}, + ], + }, + final=True, + locked=False, + ) + + result = extension.send_asr_result.await_args.args[0] + assert result.id + assert result.text == "hello world" + assert result.final is True + assert result.start_ms == 120 + assert result.duration_ms == 340 + assert result.language == "en-US" + assert result.metadata["session_id"] == "session-123" + assert result.metadata["asr_info"]["vendor"] == "xai" + assert len(result.words) == 2 + + asyncio.run(_run()) + + +def test_partial_final_mapping_sets_locked(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.audio_timeline = FakeTimeline() + extension.config = XAIASRConfig( + params={"api_key": "xai-test-key", "language": "en"} + ) + extension.config.apply_defaults() + extension.metadata = {"session_id": "session-123"} + extension.send_asr_result = AsyncMock() + + await extension.on_partial_result( + { + "text": "hello", + "start": 0.1, + "duration": 0.2, + "is_final": True, + "speech_final": False, + } + ) + + result = extension.send_asr_result.await_args.args[0] + assert result.final is False + assert result.metadata["asr_info"]["locked"] is True + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_dump.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_dump.py new file mode 100644 index 0000000000..e7da99a785 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_dump.py @@ -0,0 +1,40 @@ +import asyncio +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +from ten_ai_base.dumper import Dumper +from xai_asr_python.extension import XAIASRExtension + + +class FakeFrame: + def __init__(self, payload: bytes): + self.payload = bytearray(payload) + + def lock_buf(self): + return self.payload + + def unlock_buf(self, _buf): + return None + + +def test_send_audio_writes_dump_and_vendor_stream(tmp_path): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + dump_path = Path(tmp_path) / "dump.pcm" + extension.audio_dumper = Dumper(str(dump_path)) + await extension.audio_dumper.start() + extension.recognition = MagicMock() + extension.recognition.is_connected.return_value = True + extension.recognition.send_audio_frame = AsyncMock() + payload = b"\x00\x01\x02\x03" + + result = await extension.send_audio(FakeFrame(payload), None) + + assert result is True + extension.recognition.send_audio_frame.assert_awaited_once_with(payload) + await extension.audio_dumper.stop() + assert dump_path.exists() + assert dump_path.read_bytes() == payload + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_error_check.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_error_check.py new file mode 100644 index 0000000000..c1d1ce1c4a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_error_check.py @@ -0,0 +1,69 @@ +from typing_extensions import override +from ten_runtime import ( + AsyncExtensionTester, + AsyncTenEnvTester, + Data, + AudioFrame, + TenError, + TenErrorCode, +) +import json + + +class XAIAsrExtensionTester(AsyncExtensionTester): + + def __init__(self): + super().__init__() + + @override + async def on_start(self, ten_env_tester: AsyncTenEnvTester) -> None: + ten_env_tester.log_info("on_start") + + def stop_test_if_checking_failed( + self, + ten_env_tester: AsyncTenEnvTester, + success: bool, + error_message: str, + ) -> None: + if not success: + err = TenError.create( + error_code=TenErrorCode.ErrorCodeGeneric, + error_message=error_message, + ) + ten_env_tester.stop_test(err) + + @override + async def on_data( + self, ten_env_tester: AsyncTenEnvTester, data: Data + ) -> None: + # Expect to receive an error data. + data_name = data.get_name() + print(f"data_name: {data_name}") + if data_name == "error": + # Check the error. + error_json, _ = data.get_property_to_json() + error_data = json.loads(error_json) + print(f"error_data: {error_data}") + ten_env_tester.stop_test() + + @override + async def on_stop(self, ten_env_tester: AsyncTenEnvTester) -> None: + pass + + +def test_error_check(): + property_json = { + "params": { + "api_key": "invalid_key", + "base_url": "wss://api.x.ai/v1/stt", + "sample_rate": 16000, + "encoding": "pcm", + "language": "en", + } + } + tester = XAIAsrExtensionTester() + tester.set_test_mode_single( + "xai_asr_python", json.dumps(property_json) + ) + err = tester.run() + assert err is None diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py new file mode 100644 index 0000000000..28a7c504cb --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py @@ -0,0 +1,57 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from xai_asr_python.config import XAIASRConfig +from xai_asr_python.extension import XAIASRExtension + + +def test_finalize_emits_result_and_finalize_end(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.config = XAIASRConfig( + finalize_timeout_ms=10, + params={"api_key": "xai-test-key"}, + ) + extension.recognition = MagicMock() + extension.recognition.send_audio_done = AsyncMock() + extension.recognition.wait_for_done = AsyncMock( + return_value={"text": "done text", "start": 0.1, "duration": 0.2} + ) + extension.recognition.done_event = asyncio.Event() + extension._emit_asr_result = AsyncMock() + extension.send_asr_finalize_end = AsyncMock() + + await extension.finalize("session-123") + + assert extension.last_finalize_timestamp == 0 + extension.recognition.send_audio_done.assert_awaited_once() + extension._emit_asr_result.assert_awaited_once() + extension.send_asr_finalize_end.assert_awaited_once() + + asyncio.run(_run()) + + +def test_finalize_timeout_still_emits_finalize_end(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.config = XAIASRConfig( + finalize_timeout_ms=10, + params={"api_key": "xai-test-key"}, + ) + extension.recognition = MagicMock() + extension.recognition.send_audio_done = AsyncMock() + extension.recognition.wait_for_done = AsyncMock(return_value=None) + extension.recognition.done_event = asyncio.Event() + extension._emit_asr_result = AsyncMock() + extension.send_asr_finalize_end = AsyncMock() + + await extension.finalize("session-123") + + assert extension.last_finalize_timestamp == 0 + extension.recognition.send_audio_done.assert_awaited_once() + extension._emit_asr_result.assert_not_awaited() + extension.send_asr_finalize_end.assert_awaited_once() + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_metrics.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_metrics.py new file mode 100644 index 0000000000..f016ec4df3 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_metrics.py @@ -0,0 +1,33 @@ +import asyncio +import time +from unittest.mock import AsyncMock, MagicMock + +from xai_asr_python.extension import XAIASRExtension + + +class FakeTimeline: + def get_total_user_audio_duration(self) -> int: + return 320 + + def reset(self) -> None: + return None + + +def test_on_open_sends_connect_delay_metrics(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.audio_timeline = FakeTimeline() + extension.buffered_frames = asyncio.Queue() + extension.send_connect_delay_metrics = AsyncMock() + extension.reconnect_manager = MagicMock() + extension.connection_start_timestamp = int(time.time() * 1000) - 75 + + await extension.on_open() + + extension.send_connect_delay_metrics.assert_awaited_once() + delay_ms = extension.send_connect_delay_metrics.await_args.args[0] + assert delay_ms >= 50 + extension.reconnect_manager.mark_connection_successful.assert_called_once() + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_params.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_params.py new file mode 100644 index 0000000000..8d00d64364 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_params.py @@ -0,0 +1,64 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from xai_asr_python.config import XAIASRConfig +from xai_asr_python.extension import XAIASRExtension +from ten_ai_base.utils import encrypt + + +def test_invalid_sample_rate(): + config = XAIASRConfig(params={"api_key": "xai-test-key", "sample_rate": 1}) + config.apply_defaults() + try: + config.validate() + except ValueError as exc: + assert "Unsupported sample_rate" in str(exc) + else: + raise AssertionError("Expected invalid sample rate error") + + +def test_invalid_encoding(): + config = XAIASRConfig( + params={"api_key": "xai-test-key", "encoding": "mp3"} + ) + config.apply_defaults() + try: + config.validate() + except ValueError as exc: + assert "Unsupported encoding" in str(exc) + else: + raise AssertionError("Expected invalid encoding error") + + +def test_config_redacts_api_key(): + config = XAIASRConfig(params={"api_key": "xai-super-secret", "language": "en"}) + config.apply_defaults() + + safe_str = config.to_json(sensitive_handling=True) + + assert "xai-super-secret" not in safe_str + assert "en" in safe_str + assert "api_key" in safe_str + assert encrypt("xai-super-secret") in safe_str + + +def test_language_normalization(): + config = XAIASRConfig(params={"api_key": "xai-test-key", "language": "zh"}) + config.apply_defaults() + assert config.normalized_language == "zh-CN" + + +def test_start_connection_missing_api_key_emits_fatal_error(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.config = XAIASRConfig(params={"api_key": ""}) + extension.send_asr_error = AsyncMock() + + await extension.start_connection() + + error = extension.send_asr_error.await_args.args[0] + assert error.code == -1000 + assert "API key" in error.message + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py new file mode 100644 index 0000000000..5760efe5c4 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py @@ -0,0 +1,74 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from ten_ai_base.message import ModuleErrorCode + +from xai_asr_python.extension import XAIASRExtension +from xai_asr_python.reconnect_manager import ReconnectManager + + +def test_reconnect_manager_escalates_after_max_attempts(): + async def _run(): + errors = [] + manager = ReconnectManager( + base_delay=0, + max_delay=0, + max_attempts=4, + logger=MagicMock(), + ) + + async def failing_connect(): + raise RuntimeError("disconnect") + + async def error_handler(error): + errors.append(error.code) + + for _ in range(4): + await manager.handle_reconnect(failing_connect, error_handler) + + assert errors == [ + int(ModuleErrorCode.NON_FATAL_ERROR.value), + int(ModuleErrorCode.NON_FATAL_ERROR.value), + int(ModuleErrorCode.NON_FATAL_ERROR.value), + int(ModuleErrorCode.FATAL_ERROR.value), + ] + + asyncio.run(_run()) + + +def test_reconnect_counter_resets_after_success(): + manager = ReconnectManager(base_delay=0, max_delay=0, max_attempts=4) + manager.attempts = 3 + manager.mark_connection_successful() + assert manager.attempts == 0 + + +def test_on_close_retries_until_retry_ceiling(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.reconnect_manager = ReconnectManager( + base_delay=0, + max_delay=0, + max_attempts=4, + logger=MagicMock(), + ) + extension.send_asr_error = AsyncMock() + extension._connect_recognition = AsyncMock( + side_effect=RuntimeError("disconnect") + ) + + await extension.on_close() + + assert extension._connect_recognition.await_count == 4 + observed_codes = [ + call.args[0].code for call in extension.send_asr_error.await_args_list + ] + assert observed_codes == [ + int(ModuleErrorCode.NON_FATAL_ERROR.value), + int(ModuleErrorCode.NON_FATAL_ERROR.value), + int(ModuleErrorCode.NON_FATAL_ERROR.value), + int(ModuleErrorCode.FATAL_ERROR.value), + ] + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_vendor_error.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_vendor_error.py new file mode 100644 index 0000000000..ca8c390867 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_vendor_error.py @@ -0,0 +1,42 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +from xai_asr_python.extension import XAIASRExtension + + +def test_vendor_error_reports_non_fatal_with_vendor_info(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.send_asr_error = AsyncMock() + + await extension.on_error("temporary websocket error", 499) + + error = extension.send_asr_error.await_args.args[0] + vendor_info = extension.send_asr_error.await_args.args[1] + + assert error.code == 1000 + assert vendor_info.vendor == "xai" + assert vendor_info.code == "499" + assert vendor_info.message == "temporary websocket error" + + asyncio.run(_run()) + + +def test_vendor_error_reports_fatal_for_unauthorized(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.send_asr_error = AsyncMock() + + await extension.on_error("401 Unauthorized", 401) + + error = extension.send_asr_error.await_args.args[0] + vendor_info = extension.send_asr_error.await_args.args[1] + + assert error.code == -1000 + assert vendor_info.vendor == "xai" + assert vendor_info.code == "401" + assert vendor_info.message == "401 Unauthorized" + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/README.md b/ai_agents/agents/ten_packages/extension/xai_tts_python/README.md new file mode 100644 index 0000000000..6fd6857f8d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/README.md @@ -0,0 +1,91 @@ +# xAI TTS Extension + +A TEN Framework extension that provides Text-to-Speech (TTS) capabilities using +xAI's streaming voice API. + +## Features + +- Real-time streaming TTS via WebSocket +- Five xAI voices (`eve`, `ara`, `rex`, `sal`, `leo`) +- Configurable codecs and sample rates +- PCM streaming output for realtime voice graphs +- TTFB (Time to First Byte) metrics reporting +- Audio dump capability for debugging + +## Configuration + +### Properties + +| Property | Type | Default | Description | +|----------|------|---------|-------------| +| `params.api_key` | string | Required | xAI API key | +| `params.voice_id` | string | `eve` | Voice to use | +| `params.language` | string | `en` | BCP-47 language code | +| `params.codec` | string | `pcm` | Output codec | +| `params.sample_rate` | int | `24000` | Output sample rate in Hz | +| `params.base_url` | string | `wss://api.x.ai/v1/tts` | WebSocket endpoint | +| `params.` | scalar | Optional | Additional xAI websocket query parameters passed through to the vendor | +| `dump` | bool | `false` | Enable audio dumping | +| `dump_path` | string | `/tmp` | Path for audio dump files | + +### Example Configuration + +```json +{ + "params": { + "api_key": "${env:XAI_API_KEY}", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000, + "optimize_streaming_latency": 0 + }, + "dump": false, + "dump_path": "/tmp" +} +``` + +Known extension-owned keys such as `api_key`, `base_url`, `voice_id`, `language`, +`codec`, and `sample_rate` are normalized onto the config object. Any remaining +scalar keys under `params` are appended to the xAI websocket query string. + +## Voices + +- `eve` - energetic, upbeat +- `ara` - warm, friendly +- `rex` - confident, clear +- `sal` - smooth, balanced +- `leo` - authoritative, strong + +## API Interface + +This extension implements the standard TEN TTS interface: + +### Input Data +- `tts_text_input` - Text to synthesize +- `tts_flush` - Flush pending audio + +### Output Data +- `tts_audio_start` - Audio generation started +- `tts_audio_end` - Audio generation completed +- `metrics` - Performance metrics (TTFB, duration) +- `error` - Error information + +### Output Audio +- `pcm_frame` - PCM audio data (16-bit, mono) + +## Running Tests + +```bash +cd xai_tts_python +tman -y install --standalone +./tests/bin/start +``` + +## Environment Variables + +- `XAI_API_KEY` - Your xAI API key + +## License + +Apache License, Version 2.0 diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/__init__.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/__init__.py new file mode 100644 index 0000000000..72593ab225 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/__init__.py @@ -0,0 +1,6 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +from . import addon diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/addon.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/addon.py new file mode 100644 index 0000000000..95c20886f2 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/addon.py @@ -0,0 +1,15 @@ +from ten_runtime import ( + Addon, + register_addon_as_extension, + TenEnv, +) + + +@register_addon_as_extension("xai_tts_python") +class XAITTSExtensionAddon(Addon): + + def on_create_instance(self, ten_env: TenEnv, name: str, context) -> None: + from .extension import XAITTSExtension + + ten_env.log_info("XAITTSExtensionAddon on_create_instance") + ten_env.on_create_instance_done(XAITTSExtension(name), context) diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/config.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/config.py new file mode 100644 index 0000000000..c713373517 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/config.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +from typing import Any +import copy + +from ten_ai_base import utils + +from pydantic import BaseModel, Field + + +class XAITTSConfig(BaseModel): + api_key: str = "" + base_url: str = "wss://api.x.ai/v1/tts" + voice_id: str = "eve" + language: str = "en" + codec: str = "pcm" + sample_rate: int = 24000 + bit_rate: int = 128000 + optimize_streaming_latency: int = 0 + text_normalization: bool = False + + dump: bool = False + dump_path: str = "/tmp" + params: dict[str, Any] = Field(default_factory=dict) + + def update_params(self) -> None: + params = self._ensure_dict(self.params) + self.params = params + + self.api_key = str(params.pop("api_key", self.api_key) or "") + self.base_url = str(params.pop("base_url", self.base_url) or "") + self.voice_id = str(params.pop("voice_id", self.voice_id) or "") + self.language = str(params.pop("language", self.language) or "") + self.codec = str(params.pop("codec", self.codec) or "") + self.sample_rate = int( + params.pop("sample_rate", self.sample_rate) or self.sample_rate + ) + self.bit_rate = int( + params.pop("bit_rate", self.bit_rate) or self.bit_rate + ) + self.optimize_streaming_latency = int( + params.pop( + "optimize_streaming_latency", + self.optimize_streaming_latency, + ) + or self.optimize_streaming_latency + ) + self.text_normalization = bool( + params.pop("text_normalization", self.text_normalization) + ) + + def validate(self) -> None: + if not self.api_key: + raise ValueError("API key is required") + if not ( + self.api_key.startswith("xai-") + or self.api_key.startswith("test") + ): + raise ValueError("API key must start with 'xai-'") + if self.sample_rate not in {8000, 16000, 22050, 24000, 44100, 48000}: + raise ValueError(f"Unsupported sample rate: {self.sample_rate}") + if self.codec not in {"pcm", "mp3", "wav", "mulaw", "ulaw", "alaw"}: + raise ValueError(f"Unsupported codec: {self.codec}") + + def to_str(self, sensitive_handling: bool = True) -> str: + if not sensitive_handling: + return f"{self}" + + config = copy.deepcopy(self) + + if config.api_key: + config.api_key = utils.encrypt(config.api_key) + + return f"{config}" + + @staticmethod + def _ensure_dict(value: Any) -> dict[str, Any]: + if isinstance(value, dict): + return value + return {} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py new file mode 100644 index 0000000000..0ca87604ab --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py @@ -0,0 +1,384 @@ +from datetime import datetime +import os +import traceback + +from ten_ai_base.const import LOG_CATEGORY_KEY_POINT, LOG_CATEGORY_VENDOR +from ten_ai_base.helper import PCMWriter +from ten_ai_base.message import ( + ModuleError, + ModuleErrorCode, + ModuleErrorVendorInfo, + ModuleType, + TTSAudioEndReason, +) +from ten_ai_base.struct import TTSTextInput, TTSTextResult +from ten_ai_base.tts2 import AsyncTTS2BaseExtension +from ten_runtime import AsyncTenEnv + +from .config import XAITTSConfig +from .xai_tts import ( + EVENT_TTS_END, + EVENT_TTS_ERROR, + EVENT_TTS_RESPONSE, + EVENT_TTS_TTFB_METRIC, + XAITTSClient, + XAITTSConnectionException, +) + + +class XAITTSExtension(AsyncTTS2BaseExtension): + def __init__(self, name: str) -> None: + super().__init__(name) + self.config: XAITTSConfig | None = None + self.client: XAITTSClient | None = None + self.current_request_id: str | None = None + self.current_turn_id = -1 + self.sent_ts: datetime | None = None + self.current_request_finished = False + self.total_audio_bytes = 0 + self._is_stopped = False + self.recorder_map: dict[str, PCMWriter] = {} + self._audio_start_sent = False + self._request_text_length = 0 + self._request_text = "" + self._request_metadata: dict = {} + self._request_seq_id_map: dict[str, int] = {} + self._audio_start_timestamp_ms = 0 + + @staticmethod + def _contains_spoken_content(text: str) -> bool: + return any(char.isalnum() for char in text) + + async def on_init(self, ten_env: AsyncTenEnv) -> None: + try: + await super().on_init(ten_env) + config_json_str, _ = await self.ten_env.get_property_to_json("") + if not config_json_str or config_json_str.strip() == "{}": + raise ValueError("Configuration is empty. Required api_key is missing.") + self.config = XAITTSConfig.model_validate_json(config_json_str) + self.config.update_params() + self.config.validate() + ten_env.log_info( + f"config: {self.config.to_str(sensitive_handling=True)}", + category=LOG_CATEGORY_KEY_POINT, + ) + self.client = self._create_client(ten_env) + await self.client.start() + except Exception as e: + ten_env.log_error(f"on_init failed: {traceback.format_exc()}") + await self.send_tts_error( + request_id="", + error=ModuleError( + message=f"Initialization failed: {e}", + module=ModuleType.TTS, + code=ModuleErrorCode.FATAL_ERROR, + vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()), + ), + ) + + async def on_stop(self, ten_env: AsyncTenEnv) -> None: + self._is_stopped = True + if self.client: + await self.client.stop() + self.client = None + for recorder in list(self.recorder_map.values()): + try: + await recorder.flush() + except Exception as e: + ten_env.log_error(f"Error flushing PCMWriter: {e}") + await super().on_stop(ten_env) + + async def on_deinit(self, ten_env: AsyncTenEnv) -> None: + await super().on_deinit(ten_env) + + async def cancel_tts(self) -> None: + self.current_request_finished = True + self._request_text_length = 0 + if self.current_request_id: + if self.client: + await self.client.cancel() + await self._finalize_request(TTSAudioEndReason.INTERRUPTED) + + def vendor(self) -> str: + return "xai" + + def synthesize_audio_sample_rate(self) -> int: + return self.config.sample_rate if self.config else 24000 + + def synthesize_audio_channels(self) -> int: + return 1 + + def synthesize_audio_sample_width(self) -> int: + return 2 + + def _create_client(self, ten_env: AsyncTenEnv) -> XAITTSClient: + return XAITTSClient(config=self.config, ten_env=ten_env) + + async def _ensure_client(self) -> None: + if self.client is None: + self.client = self._create_client(self.ten_env) + await self.client.start() + + async def _reconnect_client(self) -> None: + if self.client: + await self.client.stop() + self.client = None + self.client = self._create_client(self.ten_env) + await self.client.start() + + async def _finalize_request( + self, reason: TTSAudioEndReason, error: ModuleError | None = None + ) -> None: + await self._emit_tts_text_result(reason) + if not self._audio_start_sent: + await self.send_tts_audio_start(request_id=self.current_request_id) + self._audio_start_sent = True + request_event_interval = self._calculate_request_event_interval_ms() + duration_ms = self._calculate_audio_duration_ms() + await self.send_tts_audio_end( + request_id=self.current_request_id, + request_event_interval_ms=request_event_interval, + request_total_audio_duration_ms=duration_ms, + reason=reason, + ) + await self.send_usage_metrics(self.current_request_id or "") + if self.current_request_id in self.recorder_map: + await self.recorder_map[self.current_request_id].flush() + await self.finish_request( + request_id=self.current_request_id, + reason=reason, + error=error, + ) + self.sent_ts = None + if self.current_request_id: + self._request_seq_id_map.pop(self.current_request_id, None) + self._request_text = "" + self._request_metadata = {} + self._audio_start_timestamp_ms = 0 + + def _calculate_audio_duration_ms(self) -> int: + bytes_per_sample = self.synthesize_audio_sample_width() + channels = self.synthesize_audio_channels() + if bytes_per_sample <= 0 or channels <= 0: + return 0 + duration_sec = self.total_audio_bytes / ( + self.synthesize_audio_sample_rate() * bytes_per_sample * channels + ) + return int(duration_sec * 1000) + + def _calculate_request_event_interval_ms(self) -> int: + if self.sent_ts is None: + return 0 + return int((datetime.now() - self.sent_ts).total_seconds() * 1000) + + async def request_tts(self, t: TTSTextInput) -> None: + try: + await self._ensure_client() + + if t.request_id != self.current_request_id: + if self.client: + self.client.reset_ttfb() + self.current_request_id = t.request_id + self.current_request_finished = False + self.total_audio_bytes = 0 + self.sent_ts = None + self._audio_start_sent = False + self._request_text_length = 0 + self._request_text = "" + self._audio_start_timestamp_ms = 0 + self._request_metadata = t.metadata.copy() if t.metadata else {} + if t.metadata is not None: + self.session_id = t.metadata.get("session_id", "") + self.current_turn_id = t.metadata.get("turn_id", -1) + await self._setup_recorder(t.request_id) + elif self.current_request_finished: + self.ten_env.log_error( + f"Received text for finished request_id '{t.request_id}'" + ) + return + + prepared_text = t.text.strip() + if ( + t.text_input_end + and prepared_text + and not self._request_text + and not self._contains_spoken_content(prepared_text) + ): + error = ModuleError( + message="xAI TTS input must contain spoken text", + module=ModuleType.TTS, + code=ModuleErrorCode.NON_FATAL_ERROR, + vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()), + ) + await self.send_tts_error( + request_id=t.request_id, + error=error, + ) + await self.finish_request( + request_id=t.request_id, + reason=TTSAudioEndReason.ERROR, + error=error, + ) + self.current_request_finished = True + self._request_text = "" + self._request_metadata = {} + self._request_text_length = 0 + self.total_audio_bytes = 0 + self.sent_ts = None + self._audio_start_sent = False + return + if prepared_text: + self._request_text_length += len(prepared_text) + if self._request_text_length > 15000: + raise ValueError("xAI TTS text exceeds 15000 characters") + self._request_text += prepared_text + self.metrics_add_input_characters(len(prepared_text)) + + if self._is_stopped: + return + + if t.text_input_end: + self.current_request_finished = True + + if not prepared_text: + if t.text_input_end: + await self._finalize_request(TTSAudioEndReason.REQUEST_END) + return + + await self._process_tts_text(prepared_text, t) + except XAITTSConnectionException as e: + await self._handle_connection_error(e) + except Exception as e: + self.ten_env.log_error( + f"Error in request_tts: {traceback.format_exc()}. text: {t.text}" + ) + error = ModuleError( + message=str(e), + module=ModuleType.TTS, + code=ModuleErrorCode.NON_FATAL_ERROR, + vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()), + ) + await self._finalize_request(TTSAudioEndReason.ERROR, error=error) + await self._reconnect_client() + + async def _process_tts_text(self, text: str, t: TTSTextInput) -> None: + if self.sent_ts is None: + self.sent_ts = datetime.now() + + async for data_msg, event_status in self.client.get(text): + self.ten_env.log_debug(f"Received event_status: {event_status}") + if event_status == EVENT_TTS_RESPONSE: + if data_msg and isinstance(data_msg, bytes): + chunk_timestamp_ms = self._get_next_audio_chunk_timestamp_ms() + self.metrics_add_recv_audio_chunks(data_msg) + self.total_audio_bytes += len(data_msg) + await self._write_dump(data_msg) + await self.send_tts_audio_data( + data_msg, timestamp=chunk_timestamp_ms + ) + elif event_status == EVENT_TTS_TTFB_METRIC: + if isinstance(data_msg, int): + self.sent_ts = datetime.now() + await self.send_tts_audio_start( + request_id=self.current_request_id + ) + self._audio_start_sent = True + await self.send_tts_ttfb_metrics( + request_id=self.current_request_id, + ttfb_ms=data_msg, + extra_metadata={ + "voice_id": self.config.voice_id, + "codec": self.config.codec, + }, + ) + elif event_status == EVENT_TTS_END: + if t.text_input_end: + await self._finalize_request(TTSAudioEndReason.REQUEST_END) + break + elif event_status == EVENT_TTS_ERROR: + error_message = ( + data_msg.decode("utf-8", errors="ignore") + if isinstance(data_msg, bytes) + else "Unknown xAI TTS error" + ) + error = ModuleError( + message=error_message, + module=ModuleType.TTS, + code=ModuleErrorCode.NON_FATAL_ERROR, + vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()), + ) + if t.text_input_end: + await self._finalize_request( + TTSAudioEndReason.ERROR, error=error + ) + else: + self.ten_env.log_warn( + f"Transient TTS error on non-final chunk for " + f"{t.request_id}: {error_message}" + ) + break + + def _get_next_audio_chunk_timestamp_ms(self) -> int: + if self._audio_start_timestamp_ms <= 0: + self._audio_start_timestamp_ms = int(datetime.now().timestamp() * 1000) + return self._audio_start_timestamp_ms + self._calculate_audio_duration_ms() + + async def _emit_tts_text_result(self, reason: TTSAudioEndReason) -> None: + if not self.current_request_id or not self._request_text: + return + + metadata = self._request_metadata.copy() + current_seq_id = self._request_seq_id_map.get(self.current_request_id, 0) + self._request_seq_id_map[self.current_request_id] = current_seq_id + 1 + metadata["turn_seq_id"] = current_seq_id + metadata["turn_status"] = ( + 2 if reason == TTSAudioEndReason.INTERRUPTED else 1 + ) + + start_ms = self._audio_start_timestamp_ms + if start_ms <= 0: + start_ms = int(datetime.now().timestamp() * 1000) + + transcript_result = TTSTextResult( + request_id=self.current_request_id, + text=self._request_text, + start_ms=start_ms, + duration_ms=self._calculate_audio_duration_ms(), + words=None, + text_result_end=True, + metadata=metadata, + ) + self.metrics_add_output_characters(len(self._request_text)) + await self.send_tts_text_result(transcript_result) + + async def _handle_connection_error( + self, e: XAITTSConnectionException + ) -> None: + error_code = ( + ModuleErrorCode.FATAL_ERROR + if e.status_code == 401 + else ModuleErrorCode.NON_FATAL_ERROR + ) + error = ModuleError( + message=str(e), + module=ModuleType.TTS, + code=error_code, + vendor_info=ModuleErrorVendorInfo( + vendor=self.vendor(), + code=str(e.status_code), + message=e.body, + ), + ) + await self._finalize_request(TTSAudioEndReason.ERROR, error=error) + + async def _setup_recorder(self, request_id: str) -> None: + if self.config and self.config.dump: + dump_path = os.path.join( + self.config.dump_path, f"{request_id}_xai_tts_out.pcm" + ) + os.makedirs(os.path.dirname(dump_path), exist_ok=True) + self.recorder_map[request_id] = PCMWriter(dump_path) + + async def _write_dump(self, data_msg: bytes) -> None: + if self.current_request_id in self.recorder_map: + await self.recorder_map[self.current_request_id].write(data_msg) diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/manifest.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/manifest.json new file mode 100644 index 0000000000..f4571b089e --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/manifest.json @@ -0,0 +1,77 @@ +{ + "type": "extension", + "name": "xai_tts_python", + "version": "0.1.1", + "dependencies": [ + { + "type": "system", + "name": "ten_runtime_python", + "version": "0.11" + }, + { + "type": "system", + "name": "ten_ai_base", + "version": "0.7" + } + ], + "package": { + "include": [ + "manifest.json", + "property.json", + "BUILD.gn", + "**.tent", + "**.py", + "README.md", + "requirements.txt" + ] + }, + "api": { + "interface": [ + { + "import_uri": "../../system/ten_ai_base/api/tts-interface.json" + } + ], + "property": { + "properties": { + "dump": { + "type": "bool" + }, + "dump_path": { + "type": "string" + }, + "params": { + "type": "object", + "properties": { + "api_key": { + "type": "string" + }, + "base_url": { + "type": "string" + }, + "voice_id": { + "type": "string" + }, + "language": { + "type": "string" + }, + "codec": { + "type": "string" + }, + "sample_rate": { + "type": "int32" + }, + "bit_rate": { + "type": "int32" + }, + "optimize_streaming_latency": { + "type": "int32" + }, + "text_normalization": { + "type": "bool" + } + } + } + } + } + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/property.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/property.json new file mode 100644 index 0000000000..43e21c949b --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/property.json @@ -0,0 +1,12 @@ +{ + "dump": false, + "dump_path": "/tmp", + "params": { + "api_key": "${env:XAI_API_KEY|}", + "base_url": "wss://api.x.ai/v1/tts", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/requirements.txt b/ai_agents/agents/ten_packages/extension/xai_tts_python/requirements.txt new file mode 100644 index 0000000000..61366b210c --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/requirements.txt @@ -0,0 +1,2 @@ +websockets>=15.0.1 +pydantic diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/__init__.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/__init__.py new file mode 100644 index 0000000000..da402faf43 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/__init__.py @@ -0,0 +1,5 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/bin/start b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/bin/start new file mode 100755 index 0000000000..41da3fdb45 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/bin/start @@ -0,0 +1,21 @@ +#!/bin/bash + +set -e + +cd "$(dirname "${BASH_SOURCE[0]}")/../.." + +export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:.ten/app/ten_packages/system/ten_runtime_python/interface:.ten/app/ten_packages/system/ten_ai_base/interface:$PYTHONPATH + +# If the Python app imports some modules that are compiled with a different +# version of libstdc++ (ex: PyTorch), the Python app may encounter confusing +# errors. To solve this problem, we can preload the correct version of +# libstdc++. +# +# export LD_PRELOAD=/lib/x86_64-linux-gnu/libstdc++.so.6 +# +# Another solution is to make sure the module 'ten_runtime_python' is imported +# _after_ the module that requires another version of libstdc++ is imported. +# +# Refer to https://github.com/pytorch/pytorch/issues/102360?from_wecom=1#issuecomment-1708989096 + +pytest tests/ "$@" diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property.json new file mode 100644 index 0000000000..22a72a4884 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property.json @@ -0,0 +1,12 @@ +{ + "dump": false, + "dump_path": "/tmp", + "params": { + "api_key": "${env:XAI_API_KEY}", + "base_url": "wss://api.x.ai/v1/tts", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting1.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting1.json new file mode 100644 index 0000000000..199648f7ca --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting1.json @@ -0,0 +1,11 @@ +{ + "dump": true, + "dump_path": "./tests/keep_dump_output/", + "params": { + "api_key": "${env:XAI_API_KEY}", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 16000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting2.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting2.json new file mode 100644 index 0000000000..8f6416cb59 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_basic_audio_setting2.json @@ -0,0 +1,11 @@ +{ + "dump": true, + "dump_path": "./tests/keep_dump_output/", + "params": { + "api_key": "${env:XAI_API_KEY}", + "voice_id": "rex", + "language": "en", + "codec": "pcm", + "sample_rate": 24000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_dump.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_dump.json new file mode 100644 index 0000000000..4f76f4b8a8 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_dump.json @@ -0,0 +1,11 @@ +{ + "dump": true, + "dump_path": "./dump/", + "params": { + "api_key": "${env:XAI_API_KEY}", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_invalid.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_invalid.json new file mode 100644 index 0000000000..6233cf106a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_invalid.json @@ -0,0 +1,5 @@ +{ + "params": { + "api_key": "invalid" + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_miss_required.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_miss_required.json new file mode 100644 index 0000000000..df133e721a --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_miss_required.json @@ -0,0 +1,5 @@ +{ + "params": { + "api_key": "" + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json new file mode 100644 index 0000000000..22a72a4884 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json @@ -0,0 +1,12 @@ +{ + "dump": false, + "dump_path": "/tmp", + "params": { + "api_key": "${env:XAI_API_KEY}", + "base_url": "wss://api.x.ai/v1/tts", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000 + } +} diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/conftest.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/conftest.py new file mode 100644 index 0000000000..958647c64d --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/conftest.py @@ -0,0 +1,107 @@ +import sys +from pathlib import Path + +# Add project root to sys.path for test imports +project_root = str(Path(__file__).resolve().parents[6]) +if project_root not in sys.path: + sys.path.insert(0, project_root) + +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +import threading +from typing_extensions import override +import pytest +from ten_runtime import ( + App, + TenEnv, +) + + +class FakeApp(App): + def __init__(self): + super().__init__() + self.event: threading.Event | None = None + + # In the case of a fake app, we use `on_init` to allow the blocked testing + # fixture to continue execution, rather than using `on_configure`. The + # reason is that in the TEN runtime C core, the relationship between the + # addon manager and the (fake) app is bound after `on_configure_done` is + # called. So we only need to let the testing fixture continue execution + # after this action in the TEN runtime C core, and at the upper layer + # timing, the earliest point is within the `on_init()` function of the upper + # TEN app. Therefore, we release the testing fixture lock within the user + # layer's `on_init()` of the TEN app. + @override + def on_init(self, ten_env: TenEnv) -> None: + assert self.event + self.event.set() + + ten_env.on_init_done() + + @override + def on_configure(self, ten_env: TenEnv) -> None: + ten_env.init_property_from_json( + json.dumps( + { + "ten": { + "log": { + "handlers": [ + { + "matchers": [{"level": "debug"}], + "formatter": { + "type": "plain", + "colored": True, + }, + "emitter": { + "type": "console", + "config": {"stream": "stdout"}, + }, + } + ] + } + } + } + ), + ) + + ten_env.on_configure_done() + + +class FakeAppCtx: + def __init__(self, event: threading.Event): + self.fake_app: FakeApp | None = None + self.event = event + + +def run_fake_app(fake_app_ctx: FakeAppCtx): + app = FakeApp() + app.event = fake_app_ctx.event + fake_app_ctx.fake_app = app + app.run(False) + + +@pytest.fixture(scope="session", autouse=True) +def global_setup_and_teardown(): + event = threading.Event() + fake_app_ctx = FakeAppCtx(event) + + fake_app_thread = threading.Thread( + target=run_fake_app, args=(fake_app_ctx,) + ) + fake_app_thread.start() + + event.wait() + + assert fake_app_ctx.fake_app is not None + + # Yield control to the test; after the test execution is complete, continue + # with the teardown process. + yield + + # Teardown part. + fake_app_ctx.fake_app.close() + fake_app_thread.join() diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py new file mode 100644 index 0000000000..4158865d7e --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py @@ -0,0 +1,314 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock +import os +import asyncio +import filecmp +import shutil + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput, TTSFlush +from xai_tts_python.xai_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) + + +# ================ test dump file functionality ================ +class ExtensionTesterDump(ExtensionTester): + def __init__(self): + super().__init__() + self.dump_dir = "./dump/" + self.test_dump_file_path = os.path.join( + self.dump_dir, "test_manual_dump.pcm" + ) + self.audio_end_received = False + self.received_audio_chunks = [] + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Dump test started, sending TTS request.") + + tts_input = TTSTextInput( + request_id="tts_request_1", + text="hello word, hello agora", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end, stopping test.") + self.audio_end_received = True + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + buf = audio_frame.lock_buf() + try: + copied_data = bytes(buf) + self.received_audio_chunks.append(copied_data) + finally: + audio_frame.unlock_buf(buf) + + def write_test_dump_file(self): + with open(self.test_dump_file_path, "wb") as f: + for chunk in self.received_audio_chunks: + f.write(chunk) + + def find_tts_dump_file(self) -> str | None: + if not os.path.exists(self.dump_dir): + return None + for filename in os.listdir(self.dump_dir): + if filename.endswith(".pcm") and filename != os.path.basename( + self.test_dump_file_path + ): + return os.path.join(self.dump_dir, filename) + return None + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_dump_functionality(MockXAITTSClient): + """Tests that the dump file from the TTS extension matches the audio received.""" + print("Starting test_dump_functionality with mock...") + + DUMP_PATH = "./dump/" + + if os.path.exists(DUMP_PATH): + shutil.rmtree(DUMP_PATH) + os.makedirs(DUMP_PATH) + + mock_instance = MockXAITTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20 + fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20 + + async def mock_get_audio_stream(text: str): + yield (255, EVENT_TTS_TTFB_METRIC) + yield (fake_audio_chunk_1, EVENT_TTS_RESPONSE) + await asyncio.sleep(0.01) + yield (fake_audio_chunk_2, EVENT_TTS_RESPONSE) + await asyncio.sleep(0.01) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_stream + + tester = ExtensionTesterDump() + + dump_config = { + "dump": True, + "dump_path": DUMP_PATH, + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + + tester.set_test_mode_single("xai_tts_python", json.dumps(dump_config)) + + print("Running dump test...") + tester.run() + print("Dump test completed.") + + assert tester.audio_end_received, "Expected to receive tts_audio_end" + assert ( + len(tester.received_audio_chunks) > 0 + ), "Expected to receive audio chunks" + + tester.write_test_dump_file() + + tts_dump_file = tester.find_tts_dump_file() + assert ( + tts_dump_file is not None + ), f"Expected to find a TTS dump file in {DUMP_PATH}" + assert os.path.exists( + tts_dump_file + ), f"TTS dump file should exist: {tts_dump_file}" + + print( + f"Comparing test file {tester.test_dump_file_path} with TTS dump file {tts_dump_file}" + ) + assert filecmp.cmp( + tester.test_dump_file_path, tts_dump_file, shallow=False + ), "Test dump file and TTS dump file should have the same content" + + print( + f"Dump test passed: received {len(tester.received_audio_chunks)} audio chunks" + ) + + if os.path.exists(DUMP_PATH): + shutil.rmtree(DUMP_PATH) + + +# ================ test basic audio output ================ +class ExtensionTesterBasic(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_start_received = False + self.audio_end_received = False + self.audio_chunks_count = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Basic test started, sending TTS request.") + + tts_input = TTSTextInput( + request_id="tts_request_basic", + text="Hello, this is a test of the Deepgram TTS extension.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_start": + ten_env.log_info("Received tts_audio_start.") + self.audio_start_received = True + elif name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end, stopping test.") + self.audio_end_received = True + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + self.audio_chunks_count += 1 + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_basic_audio(MockXAITTSClient): + """Test basic TTS audio generation.""" + mock_instance = MockXAITTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + fake_audio_chunk = b"\x00\x01\x02\x03" * 100 + + async def mock_get_audio_stream(text: str): + yield (150, EVENT_TTS_TTFB_METRIC) + yield (fake_audio_chunk, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_stream + + tester = ExtensionTesterBasic() + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_start_received, "tts_audio_start was not received." + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." + + +# ================ test flush functionality ================ +class ExtensionTesterFlush(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Flush test started.") + + tts_input = TTSTextInput( + request_id="tts_request_flush", + text="This is the first sentence.", + text_input_end=False, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + + flush = TTSFlush(flush_id="flush_1") + flush_data = Data.create("tts_flush") + flush_data.set_property_from_json(None, flush.model_dump_json()) + ten_env_tester.send_data(flush_data) + + tts_input2 = TTSTextInput( + request_id="tts_request_flush", + text="This is the final sentence.", + text_input_end=True, + ) + data2 = Data.create("tts_text_input") + data2.set_property_from_json(None, tts_input2.model_dump_json()) + ten_env_tester.send_data(data2) + + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end, stopping test.") + self.audio_end_received = True + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_flush(MockXAITTSClient): + """Test TTS flush functionality.""" + mock_instance = MockXAITTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + fake_audio_chunk = b"\x00\x01\x02\x03" * 50 + + async def mock_get_audio_stream(text: str): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio_chunk, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_stream + + tester = ExtensionTesterFlush() + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert ( + tester.audio_end_received + ), "tts_audio_end was not received after flush." diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py new file mode 100644 index 0000000000..dfcd8f9d4f --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py @@ -0,0 +1,166 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock, MagicMock + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput + + +# ================ test empty params ================ +class ExtensionTesterEmptyParams(ExtensionTester): + def __init__(self): + super().__init__() + self.error_received = False + self.error_code = None + self.error_message = None + self.error_module = None + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + """Called when test starts""" + ten_env_tester.log_info("Test started") + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + ten_env.log_info(f"on_data name: {name}") + + if name == "error": + self.error_received = True + json_str, _ = data.get_property_to_json(None) + error_data = json.loads(json_str) + + self.error_code = error_data.get("code") + self.error_message = error_data.get("message", "") + self.error_module = error_data.get("module", "") + + ten_env.log_info( + f"Received error: code={self.error_code}, message={self.error_message}" + ) + ten_env.stop_test() + + +def test_empty_params_fatal_error(): + """Test that empty params raises FATAL ERROR with code -1000""" + print("Starting test_empty_params_fatal_error...") + + # Empty params configuration + empty_params_config = { + "params": { + "api_key": "", + } + } + + tester = ExtensionTesterEmptyParams() + tester.set_test_mode_single("xai_tts_python", json.dumps(empty_params_config)) + + print("Running test...") + tester.run() + print("Test completed.") + + # Verify FATAL ERROR was received + assert tester.error_received, "Expected to receive error message" + assert ( + tester.error_code == -1000 + ), f"Expected error code -1000 (FATAL_ERROR), got {tester.error_code}" + assert tester.error_message is not None, "Error message should not be None" + assert len(tester.error_message) > 0, "Error message should not be empty" + + print(f"Empty params test passed: code={tester.error_code}") + + +# ================ test invalid api key ================ +class ExtensionTesterInvalidApiKey(ExtensionTester): + def __init__(self): + super().__init__() + self.error_received = False + self.error_code = None + self.error_message = None + self.vendor_info = None + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + """Called when test starts, sends a TTS request to trigger the logic.""" + ten_env_tester.log_info( + "Invalid API key test started, sending TTS request" + ) + + tts_input = TTSTextInput( + request_id="test-request-invalid-key", + text="This text will trigger API key validation.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + ten_env.log_info(f"on_data name: {name}") + + if name == "error": + self.error_received = True + json_str, _ = data.get_property_to_json(None) + error_data = json.loads(json_str) + + self.error_code = error_data.get("code") + self.error_message = error_data.get("message", "") + self.vendor_info = error_data.get("vendor_info", {}) + + ten_env.log_info( + f"Received error: code={self.error_code}, message={self.error_message}" + ) + ten_env.stop_test() + elif name == "tts_audio_end": + ten_env.stop_test() + + +@patch("xai_tts_python.xai_tts.websockets.connect") +def test_invalid_api_key_error(mock_websocket_connect): + """Test that an invalid API key is handled correctly with a mock.""" + print("Starting test_invalid_api_key_error with mock...") + + # Mock websocket to raise 401 unauthorized error + mock_websocket_connect.side_effect = Exception( + "401 Unauthorized - Invalid API key" + ) + + # Config with invalid API key + invalid_key_config = { + "params": { + "api_key": "invalid_api_key_test", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + + tester = ExtensionTesterInvalidApiKey() + tester.set_test_mode_single("xai_tts_python", json.dumps(invalid_key_config)) + + print("Running test with mock...") + tester.run() + print("Test with mock completed.") + + # Verify FATAL ERROR was received for incorrect API key + assert tester.error_received, "Expected to receive error message" + assert ( + tester.error_code == -1000 + ), f"Expected error code -1000 (FATAL_ERROR), got {tester.error_code}" + + # Verify vendor_info + vendor_info = tester.vendor_info + assert vendor_info is not None, "Expected vendor_info to be present" + assert ( + vendor_info.get("vendor") == "xai" + ), f"Expected vendor 'xai', got {vendor_info.get('vendor')}" + + print(f"Invalid API key test passed: code={tester.error_code}") diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py new file mode 100644 index 0000000000..cba8beecf3 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py @@ -0,0 +1,127 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock +import asyncio + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from xai_tts_python.xai_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) + + +# ================ test metrics ================ +class ExtensionTesterMetrics(ExtensionTester): + def __init__(self): + super().__init__() + self.ttfb_received = False + self.ttfb_value = -1 + self.audio_frame_received = False + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + """Called when test starts, sends a TTS request.""" + ten_env_tester.log_info("Metrics test started, sending TTS request.") + + tts_input = TTSTextInput( + request_id="tts_request_for_metrics", + text="hello, this is a metrics test.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + ten_env.log_info(f"on_data name: {name}") + if name == "metrics": + json_str, _ = data.get_property_to_json(None) + ten_env.log_info(f"Received metrics: {json_str}") + metrics_data = json.loads(json_str) + + # According to the structure, 'ttfb' is nested inside a 'metrics' object. + nested_metrics = metrics_data.get("metrics", {}) + if "ttfb" in nested_metrics: + self.ttfb_received = True + self.ttfb_value = nested_metrics.get("ttfb", -1) + ten_env.log_info( + f"Received TTFB metric with value: {self.ttfb_value}" + ) + + elif name == "tts_audio_end": + self.audio_end_received = True + # Stop the test only after both TTFB and audio end are received + if self.ttfb_received: + ten_env.log_info("Received tts_audio_end, stopping test.") + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + """Receives audio frames and confirms the stream is working.""" + if not self.audio_frame_received: + self.audio_frame_received = True + ten_env.log_info("First audio frame received.") + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_ttfb_metric_is_sent(MockXAITTSClient): + """ + Tests that a TTFB (Time To First Byte) metric is correctly sent after + receiving the first audio chunk from the TTS service. + """ + print("Starting test_ttfb_metric_is_sent with mock...") + + # --- Mock Configuration --- + mock_instance = MockXAITTSClient.return_value + mock_instance.start = AsyncMock() + mock_instance.stop = AsyncMock() + mock_instance.cancel = AsyncMock() + mock_instance.reset_ttfb = lambda: None + + # This async generator simulates the TTS client's get() method with a delay + async def mock_get_audio_with_delay(text: str): + await asyncio.sleep(0.2) + yield (255, EVENT_TTS_TTFB_METRIC) + yield (b"\x11\x22\x33", EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock_instance.get.side_effect = mock_get_audio_with_delay + + # --- Test Setup --- + metrics_config = { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + } + } + tester = ExtensionTesterMetrics() + tester.set_test_mode_single("xai_tts_python", json.dumps(metrics_config)) + + print("Running TTFB metrics test...") + tester.run() + print("TTFB metrics test completed.") + + # --- Assertions --- + assert tester.audio_frame_received, "Did not receive any audio frame." + assert tester.audio_end_received, "Did not receive the tts_audio_end event." + assert tester.ttfb_received, "TTFB metric was not received." + + # Check if the TTFB value matches what we sent + assert ( + tester.ttfb_value == 255 + ), f"Expected TTFB to be 255ms, but got {tester.ttfb_value}ms." + + print(f"TTFB metric test passed. Received TTFB: {tester.ttfb_value}ms.") diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py new file mode 100644 index 0000000000..59aeb7a6c1 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py @@ -0,0 +1,182 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from urllib.parse import parse_qs, urlparse +from unittest.mock import patch, AsyncMock, MagicMock + + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from xai_tts_python.xai_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) +from xai_tts_python.config import XAITTSConfig +from xai_tts_python.xai_tts import XAITTSClient + + +def create_mock_client(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + fake_audio = b"\x00\x01\x02\x03" * 100 + + async def mock_get(text): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + +def test_params_passthrough(): + """Additional xAI params should be appended to the websocket URL.""" + config = XAITTSConfig( + params={ + "api_key": "test_api_key", + "base_url": "wss://api.x.ai/v1/tts", + "voice_id": "eve", + "language": "en", + "codec": "pcm", + "sample_rate": 24000, + "bit_rate": 64000, + "text_normalization": True, + } + ) + config.update_params() + + client = XAITTSClient(config=config, ten_env=MagicMock()) + parsed = urlparse(client._ws_url) + query = parse_qs(parsed.query) + + assert parsed.scheme == "wss" + assert parsed.netloc == "api.x.ai" + assert parsed.path == "/v1/tts" + assert query["voice"] == ["eve"] + assert query["language"] == ["en"] + assert query["codec"] == ["pcm"] + assert query["sample_rate"] == ["24000"] + assert "bit_rate" not in query + assert query["text_normalization"] == ["true"] + assert "api_key" not in query + assert "base_url" not in query + + +# ================ test different sample rates ================ +class ExtensionTesterSampleRate(ExtensionTester): + def __init__(self, sample_rate: int): + super().__init__() + self.sample_rate = sample_rate + self.audio_end_received = False + self.audio_chunks_count = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info(f"Sample rate test: {self.sample_rate}Hz") + + tts_input = TTSTextInput( + request_id="tts_request_sr", + text="Testing different sample rates.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + self.audio_end_received = True + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + self.audio_chunks_count += 1 + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_sample_rate_16000(MockXAITTSClient): + """Test with 16000 Hz sample rate.""" + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSampleRate(16000) + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 16000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_sample_rate_24000(MockXAITTSClient): + """Test with 24000 Hz sample rate.""" + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSampleRate(24000) + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_sample_rate_48000(MockXAITTSClient): + """Test with 48000 Hz sample rate.""" + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSampleRate(48000) + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 48000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert tester.audio_chunks_count > 0, "No audio chunks received." diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py new file mode 100644 index 0000000000..facfa5a641 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py @@ -0,0 +1,313 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import json +from unittest.mock import patch, AsyncMock + + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from xai_tts_python.xai_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, +) +from unittest.mock import MagicMock + + +def create_mock_client(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + fake_audio = b"\x00\x01\x02\x03" * 100 + + async def mock_get(text): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + +# ================ test empty text ================ +class ExtensionTesterEmptyText(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Empty text test started.") + + tts_input = TTSTextInput( + request_id="tts_request_empty", + text="", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end for empty text.") + self.audio_end_received = True + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_empty_text(MockXAITTSClient): + """Test that empty text is handled gracefully.""" + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterEmptyText() + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert ( + tester.audio_end_received + ), "tts_audio_end should be sent for empty text." + + +# ================ test whitespace only text ================ +class ExtensionTesterWhitespaceText(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Whitespace text test started.") + + tts_input = TTSTextInput( + request_id="tts_request_whitespace", + text=" \n\t ", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end for whitespace text.") + self.audio_end_received = True + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_whitespace_text(MockXAITTSClient): + """Test that whitespace-only text is handled gracefully.""" + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterWhitespaceText() + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert ( + tester.audio_end_received + ), "tts_audio_end should be sent for whitespace text." + + +class ExtensionTesterPunctuationText(ExtensionTester): + def __init__(self): + super().__init__() + self.error_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + tts_input = TTSTextInput( + request_id="tts_request_punctuation", + text="!!! ... ???", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + if data.get_name() == "error": + self.error_received = True + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_punctuation_only_text(MockXAITTSClient): + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterPunctuationText() + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert tester.error_received, "Punctuation-only text should raise an error." + + +# ================ test long text ================ +class ExtensionTesterLongText(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + self.audio_chunks_count = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Long text test started.") + + long_text = "This is a longer piece of text. " * 20 + + tts_input = TTSTextInput( + request_id="tts_request_long", + text=long_text, + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + ten_env.log_info("Received tts_audio_end for long text.") + self.audio_end_received = True + ten_env.stop_test() + + def on_audio_frame(self, ten_env: TenEnvTester, audio_frame): + self.audio_chunks_count += 1 + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_long_text(MockXAITTSClient): + """Test that long text is handled correctly.""" + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterLongText() + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert ( + tester.audio_end_received + ), "tts_audio_end was not received for long text." + assert ( + tester.audio_chunks_count > 0 + ), "No audio chunks received for long text." + + +# ================ test special characters ================ +class ExtensionTesterSpecialChars(ExtensionTester): + def __init__(self): + super().__init__() + self.audio_end_received = False + self.error_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Special characters test started.") + + tts_input = TTSTextInput( + request_id="tts_request_special", + text="Hello! How are you? I'm fine, thanks. $100 is 100%.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + self.audio_end_received = True + ten_env.stop_test() + elif name == "error": + self.error_received = True + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_special_characters(MockXAITTSClient): + """Test that special characters are handled correctly.""" + MockXAITTSClient.return_value = create_mock_client() + + tester = ExtensionTesterSpecialChars() + tester.set_test_mode_single( + "xai_tts_python", + json.dumps( + { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, + } + ), + ) + + tester.run() + + assert tester.audio_end_received, "tts_audio_end was not received." + assert ( + not tester.error_received + ), "Error should not be received for special chars." diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py new file mode 100644 index 0000000000..ef2b890366 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py @@ -0,0 +1,454 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# +import asyncio +import json +from unittest.mock import patch, AsyncMock, MagicMock + + +from ten_runtime import ( + ExtensionTester, + TenEnvTester, + Data, +) +from ten_ai_base.struct import TTSTextInput +from xai_tts_python.xai_tts import ( + EVENT_TTS_RESPONSE, + EVENT_TTS_END, + EVENT_TTS_TTFB_METRIC, + EVENT_TTS_ERROR, + XAITTSClient, + XAITTSConnectionException, +) +from xai_tts_python.config import XAITTSConfig + +MOCK_CONFIG = { + "params": { + "api_key": "test_api_key", + "voice_id": "eve", + "codec": "pcm", + "sample_rate": 24000, + }, +} + + +def create_mock_client(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + fake_audio = b"\x00\x01\x02\x03" * 100 + + async def mock_get(text): + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + +# ================ test sequential requests ================ +class SequentialRequestsTester(ExtensionTester): + """Send 3 requests with different IDs sequentially. + + Each request should produce tts_audio_start, audio + frames, and tts_audio_end with the correct request_id. + """ + + def __init__(self): + super().__init__() + self.completed_request_ids = [] + self.audio_start_ids = [] + self.expected_ids = [ + "seq_req_1", + "seq_req_2", + "seq_req_3", + ] + self.send_index = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + ten_env_tester.log_info("Sequential requests test started.") + self._send_next(ten_env_tester) + ten_env_tester.on_start_done() + + def _send_next(self, ten_env_tester: TenEnvTester) -> None: + if self.send_index >= len(self.expected_ids): + return + req_id = self.expected_ids[self.send_index] + tts_input = TTSTextInput( + request_id=req_id, + text=f"Hello from request {self.send_index + 1}.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + self.send_index += 1 + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_start": + json_str, _ = data.get_property_to_json("") + d = json.loads(json_str) if json_str else {} + rid = d.get("request_id", "") + self.audio_start_ids.append(rid) + elif name == "tts_audio_end": + json_str, _ = data.get_property_to_json("") + d = json.loads(json_str) if json_str else {} + rid = d.get("request_id", "") + self.completed_request_ids.append(rid) + ten_env.log_info(f"Completed request: {rid}") + if len(self.completed_request_ids) < len(self.expected_ids): + self._send_next(ten_env) + else: + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_sequential_requests(MockClient): + """Each sequential request should complete with its own + request_id in audio_start and audio_end.""" + MockClient.return_value = create_mock_client() + + tester = SequentialRequestsTester() + tester.set_test_mode_single("xai_tts_python", json.dumps(MOCK_CONFIG)) + tester.run() + + assert tester.completed_request_ids == [ + "seq_req_1", + "seq_req_2", + "seq_req_3", + ], ( + f"Expected 3 sequential completions, got " + f"{tester.completed_request_ids}" + ) + assert tester.audio_start_ids == [ + "seq_req_1", + "seq_req_2", + "seq_req_3", + ], f"audio_start ids mismatch: {tester.audio_start_ids}" + + +# ================ test reconnect after error ================ +class ReconnectAfterErrorTester(ExtensionTester): + """First request errors, second request should succeed. + + Validates that the client recovers after a mid-stream + failure. + """ + + def __init__(self): + super().__init__() + self.error_received = False + self.second_audio_end = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + # First request will trigger an error + tts_input = TTSTextInput( + request_id="err_req_1", + text="This will error.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "tts_audio_end": + if not self.error_received: + # First request ended (with error) — send + # second request + self.error_received = True + tts_input = TTSTextInput( + request_id="ok_req_2", + text="This should work.", + text_input_end=True, + ) + data2 = Data.create("tts_text_input") + data2.set_property_from_json(None, tts_input.model_dump_json()) + ten_env.send_data(data2) + else: + self.second_audio_end = True + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_reconnect_after_error(MockClient): + """After an error, subsequent requests should succeed.""" + call_count = 0 + + def create_mock(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + + fake_audio = b"\x00\x01" * 200 + + async def mock_get(text): + nonlocal call_count + call_count += 1 + if call_count == 1: + # First call: error + yield ( + b"Simulated error", + EVENT_TTS_ERROR, + ) + else: + # Subsequent calls: success + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + MockClient.return_value = create_mock() + + tester = ReconnectAfterErrorTester() + tester.set_test_mode_single("xai_tts_python", json.dumps(MOCK_CONFIG)) + tester.run() + + assert ( + tester.second_audio_end + ), "Second request should complete after first errored." + + +# ================ test config redaction ================ +def test_config_redacts_api_key(): + """to_str(sensitive_handling=True) must not leak the + API key.""" + config = XAITTSConfig( + params={ + "api_key": "super-secret-key-12345", + "voice_id": "eve", + } + ) + config.update_params() + + safe_str = config.to_str(sensitive_handling=True) + + assert "super-secret-key-12345" not in safe_str + assert "eve" in safe_str + + +# ================ test empty text yields END ================ +def test_client_empty_text_yields_end(): + """get() with empty text should yield EVENT_TTS_END + immediately without connecting.""" + + async def _run(): + ten_env = MagicMock() + ten_env.log_warn = MagicMock() + config = XAITTSConfig(api_key="test") + client = XAITTSClient(config=config, ten_env=ten_env) + + events = [] + async for data, event in client.get(""): + events.append(event) + + assert events == [EVENT_TTS_END] + assert client._ws is None # no connection made + + asyncio.run(_run()) + + +def test_connect_backoff_limit(): + async def _run(): + ten_env = MagicMock() + ten_env.log_info = MagicMock() + ten_env.log_warn = MagicMock() + ten_env.log_error = MagicMock() + + config = XAITTSConfig(api_key="xai-test-key") + client = XAITTSClient(config=config, ten_env=ten_env) + + with patch.object( + client, + "_connect", + AsyncMock(side_effect=RuntimeError("temporary outage")), + ): + try: + await client._connect_with_backoff("test") + except XAITTSConnectionException as exc: + assert exc.status_code == 503 + else: + raise AssertionError("Expected reconnect ceiling failure") + + assert client._connect_exp_cnt == 5 + + asyncio.run(_run()) + + +def test_client_whitespace_text_yields_end(): + """get() with whitespace-only text should yield + EVENT_TTS_END.""" + + async def _run(): + ten_env = MagicMock() + ten_env.log_warn = MagicMock() + config = XAITTSConfig(api_key="test") + client = XAITTSClient(config=config, ten_env=ten_env) + + events = [] + async for data, event in client.get(" \n\t "): + events.append(event) + + assert events == [EVENT_TTS_END] + + asyncio.run(_run()) + + +# ================ test 401 emits exactly one error ================ +class AuthErrorTester(ExtensionTester): + """Validates that a 401 auth failure emits exactly one + error event and one terminal audio_end.""" + + def __init__(self): + super().__init__() + self.error_count = 0 + self.audio_end_count = 0 + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + tts_input = TTSTextInput( + request_id="auth_err_req", + text="This should fail with 401.", + text_input_end=True, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "error": + self.error_count += 1 + elif name == "tts_audio_end": + self.audio_end_count += 1 + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_auth_error_single_emission(MockClient): + """401 should produce exactly 1 error event, not + duplicates.""" + from xai_tts_python.xai_tts import ( + XAITTSConnectionException, + ) + + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + + async def mock_get_auth_fail(text): + raise XAITTSConnectionException( + status_code=401, body="Unauthorized" + ) + yield # make it a generator # pragma: no cover + + mock.get.side_effect = mock_get_auth_fail + MockClient.return_value = mock + + tester = AuthErrorTester() + tester.set_test_mode_single("xai_tts_python", json.dumps(MOCK_CONFIG)) + tester.run() + + assert tester.error_count == 1, ( + f"Expected exactly 1 error event, got " f"{tester.error_count}" + ) + + +# ================ test non-final error contract ================ +class NonFinalErrorTester(ExtensionTester): + """Validates that an error on a non-final chunk does NOT + produce a public error event. Partial stream errors are + transient — only logged, not surfaced to callers.""" + + def __init__(self): + super().__init__() + self.error_count = 0 + self.audio_end_received = False + + def on_start(self, ten_env_tester: TenEnvTester) -> None: + # First chunk: non-final, will error + tts_input = TTSTextInput( + request_id="nonfinal_req", + text="First chunk errors.", + text_input_end=False, + ) + data = Data.create("tts_text_input") + data.set_property_from_json(None, tts_input.model_dump_json()) + ten_env_tester.send_data(data) + + # Second chunk: final, succeeds + tts_input2 = TTSTextInput( + request_id="nonfinal_req", + text="Second chunk works.", + text_input_end=True, + ) + data2 = Data.create("tts_text_input") + data2.set_property_from_json(None, tts_input2.model_dump_json()) + ten_env_tester.send_data(data2) + ten_env_tester.on_start_done() + + def on_data(self, ten_env: TenEnvTester, data) -> None: + name = data.get_name() + if name == "error": + self.error_count += 1 + elif name == "tts_audio_end": + self.audio_end_received = True + ten_env.stop_test() + + +@patch("xai_tts_python.extension.XAITTSClient") +def test_nonfinal_error_not_surfaced(MockClient): + """Error on non-final chunk should not emit public + error event. This is the intended contract: partial + stream errors are transient.""" + call_count = 0 + + def create_mock(): + mock = MagicMock() + mock.start = AsyncMock() + mock.stop = AsyncMock() + mock.cancel = AsyncMock() + mock.reset_ttfb = lambda: None + + fake_audio = b"\x00\x01" * 200 + + async def mock_get(text): + nonlocal call_count + call_count += 1 + if call_count == 1: + yield (b"Transient error", EVENT_TTS_ERROR) + else: + yield (100, EVENT_TTS_TTFB_METRIC) + yield (fake_audio, EVENT_TTS_RESPONSE) + yield (None, EVENT_TTS_END) + + mock.get.side_effect = mock_get + return mock + + MockClient.return_value = create_mock() + + tester = NonFinalErrorTester() + tester.set_test_mode_single("xai_tts_python", json.dumps(MOCK_CONFIG)) + tester.run() + + assert tester.error_count == 0, ( + f"Non-final error should not produce public error " + f"event, got {tester.error_count}" + ) + assert ( + tester.audio_end_received + ), "Request should still complete after non-final error" diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/xai_tts.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/xai_tts.py new file mode 100644 index 0000000000..9d17271c56 --- /dev/null +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/xai_tts.py @@ -0,0 +1,287 @@ +import asyncio +import base64 +import json +from datetime import datetime +from typing import AsyncIterator +from urllib.parse import urlencode + +import websockets +from websockets.asyncio.client import ClientConnection +from websockets.exceptions import InvalidStatus + +from ten_ai_base.const import LOG_CATEGORY_VENDOR +from ten_runtime import AsyncTenEnv + +from .config import XAITTSConfig + + +EVENT_TTS_RESPONSE = 1 +EVENT_TTS_END = 2 +EVENT_TTS_ERROR = 3 +EVENT_TTS_TTFB_METRIC = 5 + +WS_RECV_TIMEOUT = 8.0 +MAX_CONNECT_ATTEMPTS = 5 +BASE_BACKOFF_SECONDS = 0.5 + + +class XAITTSConnectionException(Exception): + def __init__(self, status_code: int, body: str): + self.status_code = status_code + self.body = body + super().__init__( + f"xAI TTS connection failed (code: {status_code}): {body}" + ) + + +class XAITTSClient: + def __init__(self, config: XAITTSConfig, ten_env: AsyncTenEnv): + self.config = config + self.ten_env = ten_env + self._ws: ClientConnection | None = None + self._is_cancelled = False + self._needs_reconnect = False + self._sent_ts: datetime | None = None + self._ttfb_sent = False + self._ws_url = self._build_ws_url() + self._connect_exp_cnt = 0 + + def _build_ws_url(self) -> str: + def _encode_query_value(value: str | int | bool) -> str | int: + if isinstance(value, bool): + return str(value).lower() + return value + + query_params: dict[str, str | int | bool] = { + "voice": self.config.voice_id, + "language": self.config.language, + "codec": self.config.codec, + "sample_rate": self.config.sample_rate, + } + if self.config.codec == "mp3": + query_params["bit_rate"] = self.config.bit_rate + query_params["optimize_streaming_latency"] = ( + self.config.optimize_streaming_latency + ) + query_params["text_normalization"] = ( + str(self.config.text_normalization).lower() == "true" + if isinstance(self.config.text_normalization, str) + else self.config.text_normalization + ) + for key, value in self.config.params.items(): + if key in { + "api_key", + "base_url", + "voice_id", + "language", + "codec", + "sample_rate", + "bit_rate", + "optimize_streaming_latency", + "text_normalization", + }: + continue + if value is not None: + query_params[key] = value + encoded_query_params = { + key: _encode_query_value(value) + for key, value in query_params.items() + } + return ( + f"{self.config.base_url}?" + f"{urlencode(encoded_query_params, doseq=True)}" + ) + + async def start(self) -> None: + await self._connect_with_backoff("preheat") + + async def stop(self) -> None: + self._is_cancelled = True + if self._ws: + try: + self.ten_env.log_info( + "vendor_status_changed: closing xai tts websocket", + category=LOG_CATEGORY_VENDOR, + ) + await self._ws.close() + finally: + self._ws = None + + async def cancel(self) -> None: + self._is_cancelled = True + self.reset_ttfb() + self._needs_reconnect = True + if self._ws: + try: + self.ten_env.log_info( + "vendor_status_changed: cancelling xai tts websocket", + category=LOG_CATEGORY_VENDOR, + ) + await self._ws.close() + finally: + self._ws = None + + def reset_ttfb(self) -> None: + self._sent_ts = None + self._ttfb_sent = False + + async def get( + self, text: str + ) -> AsyncIterator[tuple[bytes | int | None, int]]: + if len(text.strip()) == 0: + yield None, EVENT_TTS_END + return + + if self._needs_reconnect: + await self._reconnect() + self._needs_reconnect = False + + await self._ensure_connection() + self._is_cancelled = False + if not self._ttfb_sent: + self._sent_ts = datetime.now() + + await self._ws.send( + json.dumps({"type": "text.delta", "delta": text}) + ) + await self._ws.send(json.dumps({"type": "text.done"})) + + try: + while True: + if self._is_cancelled: + break + + try: + message = await asyncio.wait_for( + self._ws.recv(), timeout=WS_RECV_TIMEOUT + ) + except asyncio.TimeoutError: + self._needs_reconnect = True + yield b"Timeout waiting for xAI audio", EVENT_TTS_ERROR + break + + if isinstance(message, bytes): + self.ten_env.log_warn( + "Unexpected binary frame from xAI TTS; ignoring" + ) + continue + + try: + event = json.loads(message) + except json.JSONDecodeError: + self.ten_env.log_warn( + f"Failed to parse xAI TTS frame: {message}" + ) + continue + + event_type = event.get("type", "") + if event_type == "audio.delta": + if self._sent_ts and not self._ttfb_sent: + yield ( + int( + (datetime.now() - self._sent_ts).total_seconds() + * 1000 + ), + EVENT_TTS_TTFB_METRIC, + ) + self._ttfb_sent = True + audio_chunk = base64.b64decode(event.get("delta", "")) + yield audio_chunk, EVENT_TTS_RESPONSE + elif event_type == "audio.done": + yield None, EVENT_TTS_END + break + elif event_type == "error": + self._needs_reconnect = True + yield ( + str(event.get("message", "Unknown error")).encode( + "utf-8" + ), + EVENT_TTS_ERROR, + ) + break + + except Exception as e: + self.ten_env.log_error( + f"vendor_error: {e}", category=LOG_CATEGORY_VENDOR + ) + self._needs_reconnect = True + yield str(e).encode("utf-8"), EVENT_TTS_ERROR + + async def _connect(self) -> None: + try: + self.ten_env.log_info( + "vendor_status_changed: connecting to xai tts", + category=LOG_CATEGORY_VENDOR, + ) + self._ws = await websockets.connect( + self._ws_url, + additional_headers={ + "Authorization": f"Bearer {self.config.api_key}" + }, + ) + self._connect_exp_cnt = 0 + self.ten_env.log_info( + "vendor_status: connected to xai tts", + category=LOG_CATEGORY_VENDOR, + ) + except InvalidStatus as e: + raise XAITTSConnectionException( + status_code=e.response.status_code, + body=str(e), + ) from e + except Exception as e: + error_message = str(e) + if "401" in error_message or "Unauthorized" in error_message: + raise XAITTSConnectionException( + status_code=401, body=error_message + ) from e + raise + + async def _connect_with_backoff(self, reason: str) -> None: + last_error: Exception | None = None + while self._connect_exp_cnt < MAX_CONNECT_ATTEMPTS: + try: + await self._connect() + return + except XAITTSConnectionException as e: + if e.status_code in {401, 403}: + raise + last_error = e + except Exception as e: + last_error = e + + self._connect_exp_cnt += 1 + if self._connect_exp_cnt >= MAX_CONNECT_ATTEMPTS: + break + + backoff_seconds = min( + BASE_BACKOFF_SECONDS * (2 ** (self._connect_exp_cnt - 1)), + 4.0, + ) + self.ten_env.log_info( + f"vendor_status_changed: retrying xai tts websocket " + f"after {reason} failure in {backoff_seconds:.2f}s " + f"(attempt {self._connect_exp_cnt}/{MAX_CONNECT_ATTEMPTS})", + category=LOG_CATEGORY_VENDOR, + ) + await asyncio.sleep(backoff_seconds) + + message = ( + str(last_error) + if last_error is not None + else "xAI TTS connection failed" + ) + raise XAITTSConnectionException(status_code=503, body=message) + + async def _ensure_connection(self) -> None: + if not self._ws: + await self._connect_with_backoff("connect") + + async def _reconnect(self) -> None: + if self._ws: + try: + await self._ws.close() + except Exception: + pass + self._ws = None + await self._connect_with_backoff("reconnect") From 871f777be120e9b8a88cfb5194dd4abe478aeec7 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 24 Apr 2026 09:56:07 +0000 Subject: [PATCH 2/5] fix(ai-agents): route tts lifecycle events in xai graphs --- .../voice-assistant/tenapp/property.json | 64 +++++++++++++++++++ .../extension/main_python/agent/agent.py | 2 + 2 files changed, 66 insertions(+) diff --git a/ai_agents/agents/examples/voice-assistant/tenapp/property.json b/ai_agents/agents/examples/voice-assistant/tenapp/property.json index 2f9f59a9b5..f2c2a936ac 100644 --- a/ai_agents/agents/examples/voice-assistant/tenapp/property.json +++ b/ai_agents/agents/examples/voice-assistant/tenapp/property.json @@ -135,6 +135,22 @@ "extension": "stt" } ] + }, + { + "name": "tts_audio_start", + "source": [ + { + "extension": "tts" + } + ] + }, + { + "name": "tts_audio_end", + "source": [ + { + "extension": "tts" + } + ] } ] }, @@ -322,6 +338,22 @@ "extension": "stt" } ] + }, + { + "name": "tts_audio_start", + "source": [ + { + "extension": "tts" + } + ] + }, + { + "name": "tts_audio_end", + "source": [ + { + "extension": "tts" + } + ] } ] }, @@ -507,6 +539,22 @@ "extension": "stt" } ] + }, + { + "name": "tts_audio_start", + "source": [ + { + "extension": "tts" + } + ] + }, + { + "name": "tts_audio_end", + "source": [ + { + "extension": "tts" + } + ] } ] }, @@ -695,6 +743,22 @@ "extension": "stt" } ] + }, + { + "name": "tts_audio_start", + "source": [ + { + "extension": "tts" + } + ] + }, + { + "name": "tts_audio_end", + "source": [ + { + "extension": "tts" + } + ] } ] }, diff --git a/ai_agents/agents/examples/voice-assistant/tenapp/ten_packages/extension/main_python/agent/agent.py b/ai_agents/agents/examples/voice-assistant/tenapp/ten_packages/extension/main_python/agent/agent.py index ef61df2621..5a4b4e9b4e 100644 --- a/ai_agents/agents/examples/voice-assistant/tenapp/ten_packages/extension/main_python/agent/agent.py +++ b/ai_agents/agents/examples/voice-assistant/tenapp/ten_packages/extension/main_python/agent/agent.py @@ -148,6 +148,8 @@ async def on_data(self, data: Data): metadata=asr.get("metadata", {}), ) ) + elif data.get_name() in ("tts_audio_start", "tts_audio_end"): + return else: self.ten_env.log_warn(f"Unhandled data: {data.get_name()}") except Exception as e: From 88d8ca497a7a37f8f0dd173abde43bd136a0397f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Fri, 24 Apr 2026 14:51:17 +0000 Subject: [PATCH 3/5] fix(ai-agents): stabilize xai asr guarder flows (#2146) --- .../extension/xai_asr_python/extension.py | 26 +++++++++++- .../xai_asr_python/reconnect_manager.py | 40 ++++++++++++++----- .../xai_asr_python/tests/test_finalize.py | 21 ++++++++++ .../xai_asr_python/tests/test_reconnect.py | 28 ++++++++++--- 4 files changed, 98 insertions(+), 17 deletions(-) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py index 051c10a95e..06375ad298 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py @@ -57,7 +57,12 @@ def vendor(self) -> str: @override async def on_init(self, ten_env: AsyncTenEnv) -> None: await super().on_init(ten_env) - self.reconnect_manager = ReconnectManager(logger=ten_env) + # Keep retries bounded, but use a ceiling high enough that the + # integration guarder observes the non-fatal retry behavior before + # terminal escalation. + self.reconnect_manager = ReconnectManager( + logger=ten_env, max_attempts=10 + ) config_json, _ = await ten_env.get_property_to_json("") try: self.config = XAIASRConfig.model_validate_json(config_json) @@ -152,8 +157,16 @@ def _is_fatal_connection_error(error_message: str) -> bool: async def finalize(self, _session_id: str | None) -> None: assert self.config is not None self.last_finalize_timestamp = int(datetime.now().timestamp() * 1000) + if not self.recognition or not self.is_connected(): + self.ten_env.log_warn( + "asr_finalize: service not connected.", + category=LOG_CATEGORY_KEY_POINT, + ) + await self._finalize_end() + return + self._close_expected = True - if self.recognition: + try: await self.recognition.send_audio_done() payload = await self.recognition.wait_for_done( self.config.finalize_timeout_ms @@ -162,6 +175,13 @@ async def finalize(self, _session_id: str | None) -> None: await self._emit_asr_result(payload, final=True, locked=False) elif not self.recognition.done_event.is_set(): self._close_expected = False + except asyncio.CancelledError: + self.ten_env.log_warn( + "asr_finalize: wait for transcript.done was cancelled.", + category=LOG_CATEGORY_KEY_POINT, + ) + self._close_expected = False + finally: await self._finalize_end() async def _finalize_end(self) -> None: @@ -368,6 +388,8 @@ async def _handle_reconnect(self) -> None: success = await self.reconnect_manager.handle_reconnect( connection_func=self._connect_recognition, error_handler=self.send_asr_error, + vendor_name=self.vendor(), + vendor_code="connect_failed", ) if success: diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py index caa826b4c5..3e615a5114 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/reconnect_manager.py @@ -1,6 +1,10 @@ import asyncio from typing import Callable, Awaitable, Optional -from ten_ai_base.message import ModuleError, ModuleErrorCode +from ten_ai_base.message import ( + ModuleError, + ModuleErrorCode, + ModuleErrorVendorInfo, +) from .const import MODULE_NAME_ASR @@ -59,6 +63,8 @@ async def handle_reconnect( error_handler: Optional[ Callable[[ModuleError], Awaitable[None]] ] = None, + vendor_name: str | None = None, + vendor_code: str = "connect_failed", ) -> bool: """ Handle a single reconnection attempt with backoff delay. @@ -106,16 +112,30 @@ async def handle_reconnect( ) if error_handler: - await error_handler( - ModuleError( - module=MODULE_NAME_ASR, - code=( - ModuleErrorCode.FATAL_ERROR.value - if is_fatal - else ModuleErrorCode.NON_FATAL_ERROR.value - ), - message=f"Reconnection attempt #{self.attempts} failed: {str(e)}", + error = ModuleError( + module=MODULE_NAME_ASR, + code=( + ModuleErrorCode.FATAL_ERROR.value + if is_fatal + else ModuleErrorCode.NON_FATAL_ERROR.value + ), + message=f"Reconnection attempt #{self.attempts} failed: {str(e)}", + ) + vendor_info = ( + ModuleErrorVendorInfo( + vendor=vendor_name, + code=vendor_code, + message=str(e), ) + if vendor_name + else None ) + if vendor_info is not None: + try: + await error_handler(error, vendor_info) + except TypeError: + await error_handler(error) + else: + await error_handler(error) return False diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py index 28a7c504cb..f6dd889287 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_finalize.py @@ -55,3 +55,24 @@ async def _run(): extension.send_asr_finalize_end.assert_awaited_once() asyncio.run(_run()) + + +def test_finalize_when_disconnected_emits_finalize_end_without_waiting(): + async def _run(): + extension = XAIASRExtension("xai_asr_python") + extension.ten_env = MagicMock() + extension.config = XAIASRConfig( + finalize_timeout_ms=10, + params={"api_key": "xai-test-key"}, + ) + extension.recognition = MagicMock() + extension.recognition.is_connected.return_value = False + extension.send_asr_finalize_end = AsyncMock() + + await extension.finalize("session-123") + + assert extension.last_finalize_timestamp == 0 + extension.send_asr_finalize_end.assert_awaited_once() + extension.recognition.send_audio_done.assert_not_called() + + asyncio.run(_run()) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py index 5760efe5c4..8d5c56ba0d 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/test_reconnect.py @@ -20,18 +20,28 @@ async def _run(): async def failing_connect(): raise RuntimeError("disconnect") - async def error_handler(error): - errors.append(error.code) + async def error_handler(error, vendor_info=None): + errors.append((error.code, vendor_info)) for _ in range(4): - await manager.handle_reconnect(failing_connect, error_handler) - - assert errors == [ + await manager.handle_reconnect( + failing_connect, + error_handler, + vendor_name="xai", + vendor_code="connect_failed", + ) + + assert [code for code, _ in errors] == [ int(ModuleErrorCode.NON_FATAL_ERROR.value), int(ModuleErrorCode.NON_FATAL_ERROR.value), int(ModuleErrorCode.NON_FATAL_ERROR.value), int(ModuleErrorCode.FATAL_ERROR.value), ] + assert all(vendor_info is not None for _, vendor_info in errors) + assert all(vendor_info.vendor == "xai" for _, vendor_info in errors) + assert all( + vendor_info.code == "connect_failed" for _, vendor_info in errors + ) asyncio.run(_run()) @@ -70,5 +80,13 @@ async def _run(): int(ModuleErrorCode.NON_FATAL_ERROR.value), int(ModuleErrorCode.FATAL_ERROR.value), ] + observed_vendor_infos = [ + call.args[1] for call in extension.send_asr_error.await_args_list + ] + assert all(vendor_info.vendor == "xai" for vendor_info in observed_vendor_infos) + assert all( + vendor_info.code == "connect_failed" + for vendor_info in observed_vendor_infos + ) asyncio.run(_run()) From 50f116c7082858775c9461f982032c8896871da6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 30 Apr 2026 10:05:53 +0000 Subject: [PATCH 4/5] fix(ai-agents): address xai speech review feedback (#2146) --- .../extension/xai_asr_python/__init__.py | 5 + .../extension/xai_asr_python/config.py | 3 +- .../extension/xai_asr_python/extension.py | 30 +++-- .../extension/xai_asr_python/manifest.json | 9 ++ .../extension/xai_asr_python/recognition.py | 10 +- .../extension/xai_asr_python/requirements.txt | 2 +- .../extension/xai_asr_python/tests/bin/start | 2 +- .../extension/xai_tts_python/config.py | 5 +- .../extension/xai_tts_python/extension.py | 108 +++++++++++------- .../xai_tts_python/tests/test_basic.py | 8 +- .../xai_tts_python/tests/test_error_msg.py | 2 +- .../xai_tts_python/tests/test_metrics.py | 2 +- .../xai_tts_python/tests/test_params.py | 8 +- .../xai_tts_python/tests/test_robustness.py | 10 +- .../tests/test_state_machine.py | 6 +- 15 files changed, 133 insertions(+), 77 deletions(-) diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py index f3c731cdd5..72593ab225 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/__init__.py @@ -1 +1,6 @@ +# +# This file is part of TEN Framework, an open source project. +# Licensed under the Apache License, Version 2.0. +# See the LICENSE file for more information. +# from . import addon diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/config.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/config.py index 1c72232e60..7a9dc72066 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/config.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/config.py @@ -1,4 +1,5 @@ from typing import Any +import json from pydantic import BaseModel, Field from ten_ai_base.utils import encrypt @@ -54,7 +55,7 @@ def to_json(self, sensitive_handling: bool = False) -> str: api_key = config_dict["params"].get("api_key") if api_key: config_dict["params"]["api_key"] = encrypt(api_key) - return str(config_dict) + return json.dumps(config_dict) @property def normalized_language(self) -> str: diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py index 06375ad298..ff49b5aac6 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/extension.py @@ -107,7 +107,8 @@ async def start_connection(self) -> None: try: await self._connect_recognition() except Exception as e: - fatal = self._is_fatal_connection_error(str(e)) + status_code = self._extract_connection_status_code(e) + fatal = self._is_fatal_connection_error(status_code) self.ten_env.log_error(f"Failed to start xAI STT connection: {e}") await self.send_asr_error( ModuleError( @@ -121,7 +122,11 @@ async def start_connection(self) -> None: ), ModuleErrorVendorInfo( vendor=self.vendor(), - code="connect_failed", + code=( + str(status_code) + if status_code is not None + else "connect_failed" + ), message=str(e), ), ) @@ -146,12 +151,19 @@ async def _connect_recognition(self) -> None: await self.recognition.start(timeout=10) @staticmethod - def _is_fatal_connection_error(error_message: str) -> bool: - normalized = error_message.lower() - return any( - token in normalized - for token in ("401", "403", "unauthorized", "forbidden", "api key") - ) + def _extract_connection_status_code(error: Exception) -> int | None: + status_code = getattr(error, "status_code", None) + if isinstance(status_code, int): + return status_code + response = getattr(error, "response", None) + status_code = getattr(response, "status_code", None) + if isinstance(status_code, int): + return status_code + return None + + @staticmethod + def _is_fatal_connection_error(error_code: int | None) -> bool: + return error_code in {401, 403} @override async def finalize(self, _session_id: str | None) -> None: @@ -351,7 +363,7 @@ async def on_error( f"vendor_error: code: {error_code}, reason: {error_msg}", category=LOG_CATEGORY_VENDOR, ) - fatal = "401" in error_msg or "Unauthorized" in error_msg + fatal = self._is_fatal_connection_error(error_code) await self.send_asr_error( ModuleError( module=MODULE_NAME_ASR, diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json b/ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json index 0788756370..24dd8b4a1c 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/manifest.json @@ -22,6 +22,15 @@ ], "property": { "properties": { + "dump": { + "type": "bool" + }, + "dump_path": { + "type": "string" + }, + "finalize_timeout_ms": { + "type": "int32" + }, "params": { "type": "object", "properties": { diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py b/ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py index b621d3976c..1af7b630d9 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/recognition.py @@ -54,6 +54,7 @@ def __init__( self.done_event = asyncio.Event() self.done_payload: dict[str, Any] | None = None self._message_task: asyncio.Task | None = None + self._open_notified = False def _build_url(self) -> str: base_url = self.config.get("base_url", "wss://api.x.ai/v1/stt") @@ -89,6 +90,7 @@ async def start(self, timeout: int = 10) -> None: self.ready_event.clear() self.done_event.clear() self.done_payload = None + self._open_notified = False first_message = await asyncio.wait_for(self.websocket.recv(), timeout=timeout) if isinstance(first_message, bytes): raise RuntimeError("Unexpected binary message during xAI STT startup") @@ -102,6 +104,7 @@ async def start(self, timeout: int = 10) -> None: f"Unexpected xAI STT startup event: {first_event.get('type')}" ) self.ready_event.set() + self._open_notified = True await self.callback.on_open() self._message_task = asyncio.create_task(self._message_handler()) @@ -118,7 +121,9 @@ async def _message_handler(self) -> None: event_type = event.get("type", "") if event_type == "transcript.created": self.ready_event.set() - await self.callback.on_open() + if not self._open_notified: + self._open_notified = True + await self.callback.on_open() elif event_type == "transcript.partial": await self.callback.on_partial_result(event) elif event_type == "transcript.done": @@ -127,7 +132,8 @@ async def _message_handler(self) -> None: await self.callback.on_done(event) elif event_type == "error": await self.callback.on_error( - str(event.get("message", "Unknown error")) + str(event.get("message", "Unknown error")), + event.get("code"), ) except websockets.exceptions.ConnectionClosed as e: self.ten_env.log_info(f"xAI STT websocket closed: {e}") diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt b/ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt index 0fd8ae23c6..61366b210c 100644 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/requirements.txt @@ -1,2 +1,2 @@ websockets>=15.0.1 -pydantic \ No newline at end of file +pydantic diff --git a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start index f6a1cf283d..b736ea0de1 100755 --- a/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start +++ b/ai_agents/agents/ten_packages/extension/xai_asr_python/tests/bin/start @@ -18,4 +18,4 @@ export PYTHONPATH=.ten/app:.ten/app/ten_packages/system/ten_runtime_python/lib:. # # Refer to https://github.com/pytorch/pytorch/issues/102360?from_wecom=1#issuecomment-1708989096 -pytest -s tests/ "$@" \ No newline at end of file +pytest -s tests/ "$@" diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/config.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/config.py index c713373517..d23465f7b3 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/config.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/config.py @@ -52,10 +52,7 @@ def update_params(self) -> None: def validate(self) -> None: if not self.api_key: raise ValueError("API key is required") - if not ( - self.api_key.startswith("xai-") - or self.api_key.startswith("test") - ): + if not self.api_key.startswith("xai-"): raise ValueError("API key must start with 'xai-'") if self.sample_rate not in {8000, 16000, 22050, 24000, 44100, 48000}: raise ValueError(f"Unsupported sample rate: {self.sample_rate}") diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py index 0ca87604ab..f9905b5088 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py @@ -13,6 +13,7 @@ ) from ten_ai_base.struct import TTSTextInput, TTSTextResult from ten_ai_base.tts2 import AsyncTTS2BaseExtension +from websockets.protocol import State from ten_runtime import AsyncTenEnv from .config import XAITTSConfig @@ -44,6 +45,7 @@ def __init__(self, name: str) -> None: self._request_metadata: dict = {} self._request_seq_id_map: dict[str, int] = {} self._audio_start_timestamp_ms = 0 + self._request_event_interval_ms = 0 @staticmethod def _contains_spoken_content(text: str) -> bool: @@ -114,11 +116,6 @@ def synthesize_audio_sample_width(self) -> int: def _create_client(self, ten_env: AsyncTenEnv) -> XAITTSClient: return XAITTSClient(config=self.config, ten_env=ten_env) - async def _ensure_client(self) -> None: - if self.client is None: - self.client = self._create_client(self.ten_env) - await self.client.start() - async def _reconnect_client(self) -> None: if self.client: await self.client.stop() @@ -155,6 +152,7 @@ async def _finalize_request( self._request_text = "" self._request_metadata = {} self._audio_start_timestamp_ms = 0 + self._request_event_interval_ms = 0 def _calculate_audio_duration_ms(self) -> int: bytes_per_sample = self.synthesize_audio_sample_width() @@ -167,9 +165,7 @@ def _calculate_audio_duration_ms(self) -> int: return int(duration_sec * 1000) def _calculate_request_event_interval_ms(self) -> int: - if self.sent_ts is None: - return 0 - return int((datetime.now() - self.sent_ts).total_seconds() * 1000) + return self._request_event_interval_ms async def request_tts(self, t: TTSTextInput) -> None: try: @@ -186,6 +182,7 @@ async def request_tts(self, t: TTSTextInput) -> None: self._request_text_length = 0 self._request_text = "" self._audio_start_timestamp_ms = 0 + self._request_event_interval_ms = 0 self._request_metadata = t.metadata.copy() if t.metadata else {} if t.metadata is not None: self.session_id = t.metadata.get("session_id", "") @@ -201,7 +198,7 @@ async def request_tts(self, t: TTSTextInput) -> None: if ( t.text_input_end and prepared_text - and not self._request_text + and self._request_text_length == 0 and not self._contains_spoken_content(prepared_text) ): error = ModuleError( @@ -226,6 +223,7 @@ async def request_tts(self, t: TTSTextInput) -> None: self.total_audio_bytes = 0 self.sent_ts = None self._audio_start_sent = False + self._request_event_interval_ms = 0 return if prepared_text: self._request_text_length += len(prepared_text) @@ -247,7 +245,7 @@ async def request_tts(self, t: TTSTextInput) -> None: await self._process_tts_text(prepared_text, t) except XAITTSConnectionException as e: - await self._handle_connection_error(e) + await self._handle_connection_error(e, t.text_input_end) except Exception as e: self.ten_env.log_error( f"Error in request_tts: {traceback.format_exc()}. text: {t.text}" @@ -258,7 +256,15 @@ async def request_tts(self, t: TTSTextInput) -> None: code=ModuleErrorCode.NON_FATAL_ERROR, vendor_info=ModuleErrorVendorInfo(vendor=self.vendor()), ) - await self._finalize_request(TTSAudioEndReason.ERROR, error=error) + if t.text_input_end: + await self._finalize_request( + TTSAudioEndReason.ERROR, error=error + ) + else: + await self.send_tts_error( + request_id=t.request_id, + error=error, + ) await self._reconnect_client() async def _process_tts_text(self, text: str, t: TTSTextInput) -> None: @@ -278,7 +284,7 @@ async def _process_tts_text(self, text: str, t: TTSTextInput) -> None: ) elif event_status == EVENT_TTS_TTFB_METRIC: if isinstance(data_msg, int): - self.sent_ts = datetime.now() + self._request_event_interval_ms = data_msg await self.send_tts_audio_start( request_id=self.current_request_id ) @@ -323,6 +329,46 @@ def _get_next_audio_chunk_timestamp_ms(self) -> int: self._audio_start_timestamp_ms = int(datetime.now().timestamp() * 1000) return self._audio_start_timestamp_ms + self._calculate_audio_duration_ms() + async def _handle_connection_error( + self, e: XAITTSConnectionException, text_input_end: bool + ) -> None: + error_code = ( + ModuleErrorCode.FATAL_ERROR + if e.status_code == 401 + else ModuleErrorCode.NON_FATAL_ERROR + ) + error = ModuleError( + message=str(e), + module=ModuleType.TTS, + code=error_code, + vendor_info=ModuleErrorVendorInfo( + vendor=self.vendor(), + code=str(e.status_code), + message=e.body, + ), + ) + if text_input_end: + await self._finalize_request( + TTSAudioEndReason.ERROR, error=error + ) + else: + await self.send_tts_error( + request_id=self.current_request_id or "", + error=error, + ) + + async def _setup_recorder(self, request_id: str) -> None: + if self.config and self.config.dump: + dump_path = os.path.join( + self.config.dump_path, f"{request_id}_xai_tts_out.pcm" + ) + os.makedirs(os.path.dirname(dump_path), exist_ok=True) + self.recorder_map[request_id] = PCMWriter(dump_path) + + async def _write_dump(self, data_msg: bytes) -> None: + if self.current_request_id in self.recorder_map: + await self.recorder_map[self.current_request_id].write(data_msg) + async def _emit_tts_text_result(self, reason: TTSAudioEndReason) -> None: if not self.current_request_id or not self._request_text: return @@ -351,34 +397,14 @@ async def _emit_tts_text_result(self, reason: TTSAudioEndReason) -> None: self.metrics_add_output_characters(len(self._request_text)) await self.send_tts_text_result(transcript_result) - async def _handle_connection_error( - self, e: XAITTSConnectionException - ) -> None: - error_code = ( - ModuleErrorCode.FATAL_ERROR - if e.status_code == 401 - else ModuleErrorCode.NON_FATAL_ERROR - ) - error = ModuleError( - message=str(e), - module=ModuleType.TTS, - code=error_code, - vendor_info=ModuleErrorVendorInfo( - vendor=self.vendor(), - code=str(e.status_code), - message=e.body, - ), - ) - await self._finalize_request(TTSAudioEndReason.ERROR, error=error) + async def _ensure_client(self) -> None: + if self.client is None: + self.client = self._create_client(self.ten_env) + await self.client.start() + return - async def _setup_recorder(self, request_id: str) -> None: - if self.config and self.config.dump: - dump_path = os.path.join( - self.config.dump_path, f"{request_id}_xai_tts_out.pcm" - ) - os.makedirs(os.path.dirname(dump_path), exist_ok=True) - self.recorder_map[request_id] = PCMWriter(dump_path) + ws = getattr(self.client, "_ws", None) + if ws is not None and getattr(ws, "state", None) == State.OPEN: + return - async def _write_dump(self, data_msg: bytes) -> None: - if self.current_request_id in self.recorder_map: - await self.recorder_map[self.current_request_id].write(data_msg) + await self._reconnect_client() diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py index 4158865d7e..ece24806cf 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_basic.py @@ -114,7 +114,7 @@ async def mock_get_audio_stream(text: str): "dump": True, "dump_path": DUMP_PATH, "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -170,7 +170,7 @@ def on_start(self, ten_env_tester: TenEnvTester) -> None: tts_input = TTSTextInput( request_id="tts_request_basic", - text="Hello, this is a test of the Deepgram TTS extension.", + text="Hello, this is a test of the xAI TTS extension.", text_input_end=True, ) data = Data.create("tts_text_input") @@ -216,7 +216,7 @@ async def mock_get_audio_stream(text: str): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -298,7 +298,7 @@ async def mock_get_audio_stream(text: str): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py index dfcd8f9d4f..e1fffd7bb6 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_error_msg.py @@ -136,7 +136,7 @@ def test_invalid_api_key_error(mock_websocket_connect): # Config with invalid API key invalid_key_config = { "params": { - "api_key": "invalid_api_key_test", + "api_key": "xai-invalid-api-key-test", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py index cba8beecf3..dd343b5d3c 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_metrics.py @@ -101,7 +101,7 @@ async def mock_get_audio_with_delay(text: str): # --- Test Setup --- metrics_config = { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py index 59aeb7a6c1..618d869dc4 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_params.py @@ -44,7 +44,7 @@ def test_params_passthrough(): """Additional xAI params should be appended to the websocket URL.""" config = XAITTSConfig( params={ - "api_key": "test_api_key", + "api_key": "xai-test-key", "base_url": "wss://api.x.ai/v1/tts", "voice_id": "eve", "language": "en", @@ -115,7 +115,7 @@ def test_sample_rate_16000(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 16000, @@ -141,7 +141,7 @@ def test_sample_rate_24000(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -167,7 +167,7 @@ def test_sample_rate_48000(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 48000, diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py index facfa5a641..b8c7383d96 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_robustness.py @@ -76,7 +76,7 @@ def test_empty_text(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -130,7 +130,7 @@ def test_whitespace_text(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -178,7 +178,7 @@ def test_punctuation_only_text(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -236,7 +236,7 @@ def test_long_text(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -296,7 +296,7 @@ def test_special_characters(MockXAITTSClient): json.dumps( { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py index ef2b890366..7542b6bd01 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/test_state_machine.py @@ -26,7 +26,7 @@ MOCK_CONFIG = { "params": { - "api_key": "test_api_key", + "api_key": "xai-test-key", "voice_id": "eve", "codec": "pcm", "sample_rate": 24000, @@ -247,7 +247,7 @@ def test_client_empty_text_yields_end(): async def _run(): ten_env = MagicMock() ten_env.log_warn = MagicMock() - config = XAITTSConfig(api_key="test") + config = XAITTSConfig(api_key="xai-test-key") client = XAITTSClient(config=config, ten_env=ten_env) events = [] @@ -294,7 +294,7 @@ def test_client_whitespace_text_yields_end(): async def _run(): ten_env = MagicMock() ten_env.log_warn = MagicMock() - config = XAITTSConfig(api_key="test") + config = XAITTSConfig(api_key="xai-test-key") client = XAITTSClient(config=config, ten_env=ten_env) events = [] From 6f6e9f50795f3c15f5de6f9fa2d9628e57ecd5bc Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Thu, 30 Apr 2026 11:44:43 +0000 Subject: [PATCH 5/5] fix(ai-agents): remove xai tts text result emission (#2146) --- .../tests/test_subtitle_alignment.py | 6 ++- .../extension/xai_tts_python/extension.py | 44 +------------------ .../configs/property_subtitle_alignment.json | 12 ----- 3 files changed, 7 insertions(+), 55 deletions(-) delete mode 100644 ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json diff --git a/ai_agents/agents/integration_tests/tts_guarder/tests/test_subtitle_alignment.py b/ai_agents/agents/integration_tests/tts_guarder/tests/test_subtitle_alignment.py index 3bc890b4f2..5eaaeb27d8 100644 --- a/ai_agents/agents/integration_tests/tts_guarder/tests/test_subtitle_alignment.py +++ b/ai_agents/agents/integration_tests/tts_guarder/tests/test_subtitle_alignment.py @@ -19,6 +19,7 @@ import os import time import asyncio +import pytest TTS_SUBTITLE_CONFIG_FILE = "property_subtitle_alignment.json" @@ -314,7 +315,10 @@ def test_subtitle_alignment(extension_name: str, config_dir: str) -> None: """Verify TTS subtitle alignment with audio frames.""" config_file_path = os.path.join(config_dir, TTS_SUBTITLE_CONFIG_FILE) if not os.path.exists(config_file_path): - raise FileNotFoundError(f"Config file not found: {config_file_path}") + pytest.skip( + f"Config file not found: {config_file_path}. " + "Subtitle alignment is optional for providers without text timing." + ) with open(config_file_path, "r") as f: config: dict[str, Any] = json.load(f) diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py b/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py index f9905b5088..33f8f1f567 100644 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py +++ b/ai_agents/agents/ten_packages/extension/xai_tts_python/extension.py @@ -11,7 +11,7 @@ ModuleType, TTSAudioEndReason, ) -from ten_ai_base.struct import TTSTextInput, TTSTextResult +from ten_ai_base.struct import TTSTextInput from ten_ai_base.tts2 import AsyncTTS2BaseExtension from websockets.protocol import State from ten_runtime import AsyncTenEnv @@ -41,9 +41,6 @@ def __init__(self, name: str) -> None: self.recorder_map: dict[str, PCMWriter] = {} self._audio_start_sent = False self._request_text_length = 0 - self._request_text = "" - self._request_metadata: dict = {} - self._request_seq_id_map: dict[str, int] = {} self._audio_start_timestamp_ms = 0 self._request_event_interval_ms = 0 @@ -126,7 +123,6 @@ async def _reconnect_client(self) -> None: async def _finalize_request( self, reason: TTSAudioEndReason, error: ModuleError | None = None ) -> None: - await self._emit_tts_text_result(reason) if not self._audio_start_sent: await self.send_tts_audio_start(request_id=self.current_request_id) self._audio_start_sent = True @@ -147,10 +143,6 @@ async def _finalize_request( error=error, ) self.sent_ts = None - if self.current_request_id: - self._request_seq_id_map.pop(self.current_request_id, None) - self._request_text = "" - self._request_metadata = {} self._audio_start_timestamp_ms = 0 self._request_event_interval_ms = 0 @@ -180,10 +172,8 @@ async def request_tts(self, t: TTSTextInput) -> None: self.sent_ts = None self._audio_start_sent = False self._request_text_length = 0 - self._request_text = "" self._audio_start_timestamp_ms = 0 self._request_event_interval_ms = 0 - self._request_metadata = t.metadata.copy() if t.metadata else {} if t.metadata is not None: self.session_id = t.metadata.get("session_id", "") self.current_turn_id = t.metadata.get("turn_id", -1) @@ -217,8 +207,6 @@ async def request_tts(self, t: TTSTextInput) -> None: error=error, ) self.current_request_finished = True - self._request_text = "" - self._request_metadata = {} self._request_text_length = 0 self.total_audio_bytes = 0 self.sent_ts = None @@ -229,8 +217,8 @@ async def request_tts(self, t: TTSTextInput) -> None: self._request_text_length += len(prepared_text) if self._request_text_length > 15000: raise ValueError("xAI TTS text exceeds 15000 characters") - self._request_text += prepared_text self.metrics_add_input_characters(len(prepared_text)) + self.metrics_add_output_characters(len(prepared_text)) if self._is_stopped: return @@ -369,34 +357,6 @@ async def _write_dump(self, data_msg: bytes) -> None: if self.current_request_id in self.recorder_map: await self.recorder_map[self.current_request_id].write(data_msg) - async def _emit_tts_text_result(self, reason: TTSAudioEndReason) -> None: - if not self.current_request_id or not self._request_text: - return - - metadata = self._request_metadata.copy() - current_seq_id = self._request_seq_id_map.get(self.current_request_id, 0) - self._request_seq_id_map[self.current_request_id] = current_seq_id + 1 - metadata["turn_seq_id"] = current_seq_id - metadata["turn_status"] = ( - 2 if reason == TTSAudioEndReason.INTERRUPTED else 1 - ) - - start_ms = self._audio_start_timestamp_ms - if start_ms <= 0: - start_ms = int(datetime.now().timestamp() * 1000) - - transcript_result = TTSTextResult( - request_id=self.current_request_id, - text=self._request_text, - start_ms=start_ms, - duration_ms=self._calculate_audio_duration_ms(), - words=None, - text_result_end=True, - metadata=metadata, - ) - self.metrics_add_output_characters(len(self._request_text)) - await self.send_tts_text_result(transcript_result) - async def _ensure_client(self) -> None: if self.client is None: self.client = self._create_client(self.ten_env) diff --git a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json b/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json deleted file mode 100644 index 22a72a4884..0000000000 --- a/ai_agents/agents/ten_packages/extension/xai_tts_python/tests/configs/property_subtitle_alignment.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "dump": false, - "dump_path": "/tmp", - "params": { - "api_key": "${env:XAI_API_KEY}", - "base_url": "wss://api.x.ai/v1/tts", - "voice_id": "eve", - "language": "en", - "codec": "pcm", - "sample_rate": 24000 - } -}