From be76fc9ccb2f5f1d63d58aacc99b38ece3c7e097 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 5 May 2026 12:00:56 +0200 Subject: [PATCH 1/9] Basic auth and connection support for OpenAI via Azure --- .env.example | 1 + README.md | 8 ++++++++ examples/openai-dialog.rs | 5 +++++ external/openai-api-rs | 2 +- 4 files changed, 15 insertions(+), 1 deletion(-) diff --git a/.env.example b/.env.example index bc0d668..64f9267 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,7 @@ # OpenAI Configuration OPENAI_API_KEY=your_openai_key OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview +OPENAI_REALTIME_ENDPOINT= # Aristech ARISTECH_ENDPOINT= diff --git a/README.md b/README.md index 26a00f9..ab5e8e3 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ Configure the services by setting the appropriate environment variables in your # OpenAI Configuration OPENAI_API_KEY=your_openai_key OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview +OPENAI_REALTIME_ENDPOINT= # Azure Configuration AZURE_SUBSCRIPTION_KEY=your_azure_key @@ -105,6 +106,13 @@ ELEVENLABS_API_KEY=your_elevenlabs_key AUDIO_KNIFE_ADDRESS=127.0.0.1:8123 ``` +For Azure OpenAI realtime endpoints (`*.openai.azure.com`), the realtime client automatically appends +`api-key` as a query parameter to the websocket URL. For other hosts, it uses the standard +`Authorization: Bearer ...` header. + +The websocket client does not follow redirects. If the endpoint responds with `3xx` (for example +`302 Found`), update the configured endpoint URL to the final websocket target. 
+ ## License [MIT License](LICENSE) \ No newline at end of file diff --git a/examples/openai-dialog.rs b/examples/openai-dialog.rs index e070a49..ac55a1e 100644 --- a/examples/openai-dialog.rs +++ b/examples/openai-dialog.rs @@ -76,6 +76,11 @@ async fn main() -> Result<()> { let openai = OpenAIDialog; let mut params = openai_dialog::Params::new(key, model); + if let Ok(endpoint) = env::var("OPENAI_REALTIME_ENDPOINT") + && !endpoint.trim().is_empty() + { + params.host = Some(endpoint); + } params.tools.push(get_time_function_definition()); let (output_sender, output_receiver) = unbounded_channel(); diff --git a/external/openai-api-rs b/external/openai-api-rs index f7d659e..6e14db0 160000 --- a/external/openai-api-rs +++ b/external/openai-api-rs @@ -1 +1 @@ -Subproject commit f7d659e293be65f2682ec232d3ff31dda2d139c7 +Subproject commit 6e14db059c8bd5b868d5f9832ea420803a075aaf From 20cdac751395be2608cc4bef9c3ee86e877bf948 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 5 May 2026 12:35:46 +0200 Subject: [PATCH 2/9] Port openai-dialog from beta types to GA --- external/openai-api-rs | 2 +- services/openai-dialog/src/lib.rs | 114 ++++++++++++++++++++---------- 2 files changed, 77 insertions(+), 39 deletions(-) diff --git a/external/openai-api-rs b/external/openai-api-rs index 6e14db0..03b265e 160000 --- a/external/openai-api-rs +++ b/external/openai-api-rs @@ -1 +1 @@ -Subproject commit 6e14db059c8bd5b868d5f9832ea420803a075aaf +Subproject commit 03b265e7501c85d164994e058491ca1d773e0d56 diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index 99f53ea..07f3fbc 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -16,8 +16,8 @@ use openai_api_rs::realtime::{ client_event::{self, ClientEvent}, server_event::{self, ServerEvent}, types::{ - self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus, - ToolChoice, + self, ItemContentType, ItemRole, ItemStatus, ItemType, 
OutputModality, RealtimeVoice, + ResponseStatus, ToolChoice, }, }; use serde::{Deserialize, Serialize}; @@ -258,7 +258,7 @@ impl Client { { let mut send_update = false; - let mut session = types::Session::default(); + let mut session = types::RealtimeSession::default(); if let Some(instructions) = params.instructions { session.instructions = Some(instructions); @@ -266,13 +266,23 @@ impl Client { }; if let Some(voice) = params.voice { - session.voice = Some(voice); + let mut audio = session.audio.unwrap_or(types::AudioConfig { + input: None, + output: None, + }); + let mut output = audio.output.unwrap_or(types::AudioOutput { + format: None, + speed: 1.0, + voice: None, + }); + output.voice = Some(voice); + audio.output = Some(output); + session.audio = Some(audio); send_update = true; } if let Some(temperature) = params.temperature { - session.temperature = Some(temperature); - send_update = true; + warn!("Ignoring unsupported realtime session temperature: {temperature}"); } if !params.tools.is_empty() { @@ -288,7 +298,7 @@ impl Client { if send_update { self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate { event_id: None, - session, + session: types::UntaggedSession::Realtime(session), })) .await?; debug!("Session updated"); @@ -356,28 +366,41 @@ impl Client { } }; - let session = session_created.session; - - // PartialEq is not implemented for AudioFormat. 
- let Some(types::AudioFormat::PCM16) = session.input_audio_format else { - bail!( - "Unexpected input audio format: {:?}, expected {:?}", - session.input_audio_format, - types::AudioFormat::PCM16 - ) + let session = match session_created.session { + types::UntaggedSession::Realtime(session) => session, + other => bail!("Unexpected non-realtime session: {other:?}"), }; - let Some(types::AudioFormat::PCM16) = session.output_audio_format else { - bail!( - "Unexpected output audio format: {:?}, expected {:?}", - session.output_audio_format, - types::AudioFormat::PCM16 - ) - }; + let input_format = session + .audio + .as_ref() + .and_then(|a| a.input.as_ref()) + .and_then(|i| i.format.as_ref()); + let output_format = session + .audio + .as_ref() + .and_then(|a| a.output.as_ref()) + .and_then(|o| o.format.as_ref()); + + if let Some(format) = input_format + && !matches!(format, types::AudioFormat::Pcm(_)) + { + bail!("Unexpected input audio format: {input_format:?}, expected PCM") + } - let modalities = session.modalities.unwrap_or_default(); - if !modalities.iter().any(|m| m == "audio") { - bail!("Expect `audio` modality: {:?}", modalities); + if let Some(format) = output_format + && !matches!(format, types::AudioFormat::Pcm(_)) + { + bail!("Unexpected output audio format: {output_format:?}, expected PCM") + } + + let modalities = session.output_modalities.unwrap_or_default(); + if !modalities.is_empty() + && !modalities + .iter() + .any(|m| matches!(m, OutputModality::Audio)) + { + bail!("Expect audio output modality: {:?}", modalities); } Ok(()) @@ -448,15 +471,27 @@ impl Client { tools, tool_choice, } => { + if let Some(temperature) = temperature { + warn!("Ignoring unsupported realtime session temperature: {temperature}"); + } + + let audio = voice.map(|voice| types::AudioConfig { + input: None, + output: Some(types::AudioOutput { + format: None, + speed: 1.0, + voice: Some(voice), + }), + }); + let event = ClientEvent::SessionUpdate(client_event::SessionUpdate { - 
session: types::Session { + session: types::UntaggedSession::Realtime(types::RealtimeSession { instructions, - voice, - temperature, + audio, tools, tool_choice, ..Default::default() - }, + }), ..Default::default() }); self.send_client_event(event).await?; @@ -684,13 +719,16 @@ impl Client { }) .await?; } - ServerEvent::SessionUpdated(server_event::SessionUpdated { - session: types::Session { tools, .. }, - .. - }) => output.service_event( - OutputPath::Control, - ServiceOutputEvent::SessionUpdated { tools }, - )?, + ServerEvent::SessionUpdated(server_event::SessionUpdated { session, .. }) => { + let tools = match session { + types::UntaggedSession::Realtime(session) => session.tools, + types::UntaggedSession::Transcription(_) => None, + }; + output.service_event( + OutputPath::Control, + ServiceOutputEvent::SessionUpdated { tools }, + )? + } response => { trace!("Unhandled response: {:?}", response) @@ -741,7 +779,7 @@ impl Client { let response = ClientEvent::ResponseCreate(client_event::ResponseCreate { event_id: Some(event_id.clone()), - response: Some(types::Session { + response: Some(types::RealtimeSession { instructions: Some(prompt_request.0.clone()), ..Default::default() }), From 36902e96b7462020ab6883bf45c612f6c6739c47 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 5 May 2026 12:45:11 +0200 Subject: [PATCH 3/9] Support tagged / untagged session updates, depending on which host is used --- external/openai-api-rs | 2 +- services/openai-dialog/src/lib.rs | 53 +++++++++++++++++++++++++------ 2 files changed, 44 insertions(+), 11 deletions(-) diff --git a/external/openai-api-rs b/external/openai-api-rs index 03b265e..35685c4 160000 --- a/external/openai-api-rs +++ b/external/openai-api-rs @@ -1 +1 @@ -Subproject commit 03b265e7501c85d164994e058491ca1d773e0d56 +Subproject commit 35685c4cf23c089b61a9aba50f2cf423db5b051b diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index 07f3fbc..7e195b0 100644 ---
a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -188,13 +188,22 @@ impl Host { .await .map_err(|e| anyhow!(e.to_string()))?; - Ok(Client::new(read, write)) + Ok(Client::new(read, write, self.is_azure_host())) + } + + fn is_azure_host(&self) -> bool { + self.client + .wss_url + .split('/').nth(2) + .map(|host| host.ends_with(".openai.azure.com")) + .unwrap_or(false) } } pub struct Client { read: SplitStream>>, write: SplitSink>, Message>, + use_tagged_session_update: bool, response_state: ResponseState, inflight_prompt: Option<(String, PromptRequest)>, @@ -212,10 +221,12 @@ impl Client { fn new( read: SplitStream>>, write: SplitSink>, Message>, + use_tagged_session_update: bool, ) -> Self { Self { read, write, + use_tagged_session_update, response_state: ResponseState::Idle, inflight_prompt: None, pending_prompts: Default::default(), @@ -296,9 +307,17 @@ impl Client { } if send_update { + let payload = if self.use_tagged_session_update { + client_event::SessionUpdatePayload::Tagged(types::Session::Realtime(session)) + } else { + client_event::SessionUpdatePayload::Untagged(types::UntaggedSession::Realtime( + session, + )) + }; + self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate { event_id: None, - session: types::UntaggedSession::Realtime(session), + session: payload, })) .await?; debug!("Session updated"); @@ -472,7 +491,9 @@ impl Client { tool_choice, } => { if let Some(temperature) = temperature { - warn!("Ignoring unsupported realtime session temperature: {temperature}"); + warn!( + "Ignoring unsupported realtime session temperature: {temperature}" + ); } let audio = voice.map(|voice| types::AudioConfig { @@ -484,14 +505,26 @@ impl Client { }), }); + let session = types::RealtimeSession { + instructions, + audio, + tools, + tool_choice, + ..Default::default() + }; + + let payload = if self.use_tagged_session_update { + client_event::SessionUpdatePayload::Tagged(types::Session::Realtime( + session, + )) + 
} else { + client_event::SessionUpdatePayload::Untagged( + types::UntaggedSession::Realtime(session), + ) + }; + let event = ClientEvent::SessionUpdate(client_event::SessionUpdate { - session: types::UntaggedSession::Realtime(types::RealtimeSession { - instructions, - audio, - tools, - tool_choice, - ..Default::default() - }), + session: payload, ..Default::default() }); self.send_client_event(event).await?; From 41256cb10e5bb3f2d37aad4b7410e81a042ffe7c Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 5 May 2026 12:50:23 +0200 Subject: [PATCH 4/9] azure openai: Support delta frames --- services/openai-dialog/src/lib.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index 7e195b0..d6f65f9 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -587,6 +587,12 @@ impl Client { ..delta.clone() }) } + ServerEvent::ResponseOutputAudioDelta(delta) => { + ServerEvent::ResponseOutputAudioDelta(server_event::ResponseOutputAudioDelta { + delta: "[REMOVED]".to_string(), + ..delta.clone() + }) + } event => event.clone(), }; @@ -621,6 +627,16 @@ impl Client { }; output.audio_frame(frame)?; } + ServerEvent::ResponseOutputAudioDelta(audio_delta) => { + let decoded = BASE64_STANDARD.decode(audio_delta.delta)?; + let samples = audio::from_le_bytes(&decoded); + trace!("Sending {} samples", samples.len()); + let frame = AudioFrame { + format: output_format, + samples, + }; + output.audio_frame(frame)?; + } ServerEvent::InputAudioBufferSpeechStarted(_) => output.clear_audio()?, ServerEvent::ResponseCreated(server_event::ResponseCreated { response: types::Response { object, .. 
}, From 4c6d210dd68dcf731e111fe184bfdfa219d714c2 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 5 May 2026 13:04:08 +0200 Subject: [PATCH 5/9] fmt --- services/openai-dialog/src/lib.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index d6f65f9..6836d72 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -194,7 +194,8 @@ impl Host { fn is_azure_host(&self) -> bool { self.client .wss_url - .split('/').nth(2) + .split('/') + .nth(2) .map(|host| host.ends_with(".openai.azure.com")) .unwrap_or(false) } From aed5aea038ee9fa3ea7266b4cfb725e254142ef1 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 6 May 2026 14:54:21 +0200 Subject: [PATCH 6/9] openai-dialog: Improve protocol handling --- Cargo.toml | 1 + examples/openai-dialog.rs | 38 +++++++-- external/openai-api-rs | 2 +- services/openai-dialog/Cargo.toml | 1 + services/openai-dialog/src/lib.rs | 136 ++++++++++++++++++++++-------- 5 files changed, 135 insertions(+), 43 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1bca914..e7174dd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,7 @@ tokio = { workspace = true, features = ["rt-multi-thread"] } openai-api-rs = { workspace = true } serde_json = { workspace = true } chrono-tz = { version = "0.10.3" } +url = { workspace = true } # For recognizing audio files in azure-transcribe. 
diff --git a/examples/openai-dialog.rs b/examples/openai-dialog.rs index ac55a1e..f02350d 100644 --- a/examples/openai-dialog.rs +++ b/examples/openai-dialog.rs @@ -10,10 +10,11 @@ use std::{ use anyhow::{Context, Result, bail}; use chrono::Utc; +use clap::{Parser, ValueEnum}; use context_switch::{InputModality, OutputModality}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use openai_api_rs::realtime::types; -use openai_dialog::{OpenAIDialog, ServiceInputEvent, ServiceOutputEvent}; +use openai_dialog::{OpenAIDialog, Protocol, ServiceInputEvent, ServiceOutputEvent}; use rodio::{DeviceSinkBuilder, Player, Source}; use serde_json::json; use tokio::{ @@ -27,8 +28,33 @@ use context_switch_core::{ conversation::{Conversation, Input, Output}, }; +#[derive(Debug, Parser)] +struct Cli { + #[arg(long, value_enum)] + protocol: Option, + #[arg(long)] + endpoint: Option, +} + +#[derive(Debug, Clone, Copy, ValueEnum)] +enum CliProtocol { + #[value(name = "openai")] + OpenAI, + Azure, +} + +impl From for Protocol { + fn from(value: CliProtocol) -> Self { + match value { + CliProtocol::OpenAI => Protocol::OpenAI, + CliProtocol::Azure => Protocol::Azure, + } + } +} + #[tokio::main] async fn main() -> Result<()> { + let cli = Cli::parse(); dotenvy::dotenv_override().context("Reading .env file")?; tracing_subscriber::fmt::init(); @@ -76,11 +102,11 @@ async fn main() -> Result<()> { let openai = OpenAIDialog; let mut params = openai_dialog::Params::new(key, model); - if let Ok(endpoint) = env::var("OPENAI_REALTIME_ENDPOINT") - && !endpoint.trim().is_empty() - { - params.host = Some(endpoint); - } + params.host = cli + .endpoint + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = cli.protocol.map(Into::into); params.tools.push(get_time_function_definition()); let (output_sender, output_receiver) = unbounded_channel(); diff --git a/external/openai-api-rs b/external/openai-api-rs index 
35685c4..2b2792d 160000 --- a/external/openai-api-rs +++ b/external/openai-api-rs @@ -1 +1 @@ -Subproject commit 35685c4cf23c089b61a9aba50f2cf423db5b051b +Subproject commit 2b2792d561a3dfe95c1ee9de2968c08727c481da diff --git a/services/openai-dialog/Cargo.toml b/services/openai-dialog/Cargo.toml index 3654c0f..436f653 100644 --- a/services/openai-dialog/Cargo.toml +++ b/services/openai-dialog/Cargo.toml @@ -22,3 +22,4 @@ base64 = { workspace = true } serde = { workspace = true } async-trait = { workspace = true } uuid = { workspace = true } +url = { workspace = true } diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index 6836d72..d9feeee 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -12,7 +12,7 @@ use futures::{ stream::{SplitSink, SplitStream}, }; use openai_api_rs::realtime::{ - api::RealtimeClient, + api::{RealtimeClient, RealtimeProtocol}, client_event::{self, ClientEvent}, server_event::{self, ServerEvent}, types::{ @@ -27,6 +27,7 @@ use tokio_tungstenite::{ tungstenite::{Bytes, protocol::Message}, }; use tracing::{debug, info, trace, warn}; +use url::Url; use uuid::Uuid; use context_switch_core::{ @@ -39,6 +40,7 @@ use context_switch_core::{ pub struct Params { pub api_key: String, pub model: String, + pub protocol: Option, pub host: Option, pub instructions: Option, pub voice: Option, @@ -53,6 +55,7 @@ impl Params { Self { api_key: api_key.into(), model: model.into(), + protocol: None, host: None, instructions: None, voice: None, @@ -63,6 +66,22 @@ impl Params { } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Protocol { + OpenAI, + Azure, +} + +impl Protocol { + fn to_realtime_protocol(self) -> RealtimeProtocol { + match self { + Protocol::OpenAI => RealtimeProtocol::OpenAI, + Protocol::Azure => RealtimeProtocol::Azure, + } + } +} + #[derive(Debug)] pub struct OpenAIDialog; @@ -80,10 +99,12 @@ impl Service for 
OpenAIDialog { bail!("Input and output audio formats must match for OpenAI dialog service"); } + let protocol = resolve_protocol(params.protocol, params.host.as_deref())?; + let host = if let Some(host) = &params.host { - Host::new_with_host(host, &params.api_key, &params.model) + Host::new_with_host(host, &params.api_key, &params.model, protocol) } else { - Host::new(&params.api_key, &params.model) + Host::new(&params.api_key, &params.model, protocol) }; info!("Connecting to {host:?}"); let mut client = host.connect().await?; @@ -107,6 +128,52 @@ impl Service for OpenAIDialog { } } +fn resolve_protocol(protocol: Option<Protocol>, host: Option<&str>) -> Result<Protocol> { + let protocol = match protocol { + Some(protocol) => Ok(protocol), + None => infer_protocol_from_host(host), + }?; + + validate_protocol_host(protocol, host)?; + Ok(protocol) +} + +fn validate_protocol_host(protocol: Protocol, host: Option<&str>) -> Result<()> { + match (protocol, host) { + (Protocol::Azure, None) => { + bail!( + "Protocol `azure` requires an Azure OpenAI `host` endpoint. Set `host` to your Azure OpenAI realtime websocket URL." + ) + } + (Protocol::Azure, Some(_)) => Ok(()), + (Protocol::OpenAI, _) => Ok(()), + } +} + +fn infer_protocol_from_host(host: Option<&str>) -> Result<Protocol> { + let host = match host { + Some(host) => host, + None => return Ok(Protocol::OpenAI), + }; + + let parsed = Url::parse(host).with_context(|| format!("Invalid host URL: {host}"))?; + + let is_openai_realtime_endpoint = parsed.scheme() == "wss" + && parsed.host_str() == Some("api.openai.com") + && parsed.path() == "/v1/realtime"; + + if is_openai_realtime_endpoint { + return Ok(Protocol::OpenAI); + } + + match parsed.host_str() { + Some(host) if host.ends_with(".openai.azure.com") => Ok(Protocol::Azure), + _ => bail!( + "Cannot infer protocol from host `{host}`. Set `protocol` explicitly to `openai` or `azure`, use `wss://api.openai.com/v1/realtime`, or use an Azure OpenAI host."
+ ), + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum ServiceInputEvent { @@ -165,18 +232,24 @@ impl fmt::Debug for Host { } impl Host { - pub fn new_with_host(host: &str, api_key: &str, model: &str) -> Self { + pub fn new_with_host(host: &str, api_key: &str, model: &str, protocol: Protocol) -> Self { Host { - client: RealtimeClient::new_with_endpoint(host.into(), api_key.into(), model.into()), + client: RealtimeClient::new_with_endpoint_and_protocol( + host.into(), + api_key.into(), + model.into(), + protocol.to_realtime_protocol(), + ), } } - pub fn new(api_key: &str, model: &str) -> Self { + pub fn new(api_key: &str, model: &str, protocol: Protocol) -> Self { Host { - client: RealtimeClient::new_with_endpoint( + client: RealtimeClient::new_with_endpoint_and_protocol( "wss://api.openai.com/v1/realtime".into(), api_key.into(), model.into(), + protocol.to_realtime_protocol(), ), } } @@ -188,23 +261,14 @@ impl Host { .await .map_err(|e| anyhow!(e.to_string()))?; - Ok(Client::new(read, write, self.is_azure_host())) - } - - fn is_azure_host(&self) -> bool { - self.client - .wss_url - .split('/') - .nth(2) - .map(|host| host.ends_with(".openai.azure.com")) - .unwrap_or(false) + Ok(Client::new(read, write, self.client.protocol)) } } pub struct Client { read: SplitStream>>, write: SplitSink>, Message>, - use_tagged_session_update: bool, + protocol: RealtimeProtocol, response_state: ResponseState, inflight_prompt: Option<(String, PromptRequest)>, @@ -222,18 +286,32 @@ impl Client { fn new( read: SplitStream>>, write: SplitSink>, Message>, - use_tagged_session_update: bool, + protocol: RealtimeProtocol, ) -> Self { Self { read, write, - use_tagged_session_update, + protocol, response_state: ResponseState::Idle, inflight_prompt: None, pending_prompts: Default::default(), } } + fn session_update_payload( + &self, + session: types::RealtimeSession, + ) -> client_event::SessionUpdatePayload { + match self.protocol { + 
RealtimeProtocol::OpenAI => client_event::SessionUpdatePayload::Untagged( + types::UntaggedSession::Realtime(session), + ), + RealtimeProtocol::Azure => { + client_event::SessionUpdatePayload::Tagged(types::Session::Realtime(session)) + } + } + } + /// Run an audio dialog. pub async fn dialog( &mut self, @@ -308,13 +386,7 @@ impl Client { } if send_update { - let payload = if self.use_tagged_session_update { - client_event::SessionUpdatePayload::Tagged(types::Session::Realtime(session)) - } else { - client_event::SessionUpdatePayload::Untagged(types::UntaggedSession::Realtime( - session, - )) - }; + let payload = self.session_update_payload(session); self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate { event_id: None, @@ -514,15 +586,7 @@ impl Client { ..Default::default() }; - let payload = if self.use_tagged_session_update { - client_event::SessionUpdatePayload::Tagged(types::Session::Realtime( - session, - )) - } else { - client_event::SessionUpdatePayload::Untagged( - types::UntaggedSession::Realtime(session), - ) - }; + let payload = self.session_update_payload(session); let event = ClientEvent::SessionUpdate(client_event::SessionUpdate { session: payload, From cf2ff7ef684ded14ac915b888140faf613402c14 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 6 May 2026 15:16:47 +0200 Subject: [PATCH 7/9] openai-dialog: Remove temperature support --- services/openai-dialog/src/lib.rs | 15 --------------- 1 file changed, 15 deletions(-) diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index d9feeee..cd0e1e1 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -44,7 +44,6 @@ pub struct Params { pub host: Option, pub instructions: Option, pub voice: Option, - pub temperature: Option, #[serde(default)] pub tools: Vec, tool_choice: Option, @@ -59,7 +58,6 @@ impl Params { host: None, instructions: None, voice: None, - temperature: None, tools: vec![], tool_choice: None, } 
@@ -192,8 +190,6 @@ pub enum ServiceInputEvent { #[serde(skip_serializing_if = "Option::is_none")] voice: Option, #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] tool_choice: Option, @@ -371,10 +367,6 @@ impl Client { send_update = true; } - if let Some(temperature) = params.temperature { - warn!("Ignoring unsupported realtime session temperature: {temperature}"); - } - if !params.tools.is_empty() { session.tools = Some(params.tools); send_update = true; @@ -559,16 +551,9 @@ impl Client { ServiceInputEvent::SessionUpdate { instructions, voice, - temperature, tools, tool_choice, } => { - if let Some(temperature) = temperature { - warn!( - "Ignoring unsupported realtime session temperature: {temperature}" - ); - } - let audio = voice.map(|voice| types::AudioConfig { input: None, output: Some(types::AudioOutput { From a8f7b201b68d397f82a96a82cce62cbcc03990bd Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 6 May 2026 15:53:44 +0200 Subject: [PATCH 8/9] openai-dialog: Implement protocol specific ways to specify the voice --- Cargo.toml | 1 + examples/openai-dialog.rs | 17 +++++++++++ external/openai-api-rs | 2 +- services/openai-dialog/src/lib.rs | 50 ++++++++++++++++++------------- 4 files changed, 49 insertions(+), 21 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index e7174dd..54ed8cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,6 +81,7 @@ openai-api-rs = { workspace = true } serde_json = { workspace = true } chrono-tz = { version = "0.10.3" } url = { workspace = true } +strum = { version = "0.28" } # For recognizing audio files in azure-transcribe. 
diff --git a/examples/openai-dialog.rs b/examples/openai-dialog.rs index f02350d..a54d875 100644 --- a/examples/openai-dialog.rs +++ b/examples/openai-dialog.rs @@ -10,6 +10,7 @@ use std::{ use anyhow::{Context, Result, bail}; use chrono::Utc; +use clap::builder::{PossibleValuesParser, TypedValueParser}; use clap::{Parser, ValueEnum}; use context_switch::{InputModality, OutputModality}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; @@ -34,6 +35,8 @@ struct Cli { protocol: Option, #[arg(long)] endpoint: Option, + #[arg(long, value_parser = realtime_voice_value_parser())] + voice: Option, } #[derive(Debug, Clone, Copy, ValueEnum)] @@ -43,6 +46,19 @@ enum CliProtocol { Azure, } +fn realtime_voice_value_parser() -> impl TypedValueParser { + PossibleValuesParser::new(::VARIANTS).try_map( + |value| { + parse_realtime_voice_value(&value) + .map_err(|e| format!("Invalid voice value `{value}`: {e}")) + }, + ) +} + +fn parse_realtime_voice_value(value: &str) -> Result { + types::RealtimeVoice::from_str(value) +} + impl From for Protocol { fn from(value: CliProtocol) -> Self { match value { @@ -107,6 +123,7 @@ async fn main() -> Result<()> { .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) .filter(|endpoint| !endpoint.trim().is_empty()); params.protocol = cli.protocol.map(Into::into); + params.voice = cli.voice; params.tools.push(get_time_function_definition()); let (output_sender, output_receiver) = unbounded_channel(); diff --git a/external/openai-api-rs b/external/openai-api-rs index 2b2792d..c37d725 160000 --- a/external/openai-api-rs +++ b/external/openai-api-rs @@ -1 +1 @@ -Subproject commit 2b2792d561a3dfe95c1ee9de2968c08727c481da +Subproject commit c37d72517f9645c6eed00a15d1b90cee0b14a7fe diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index cd0e1e1..109ed08 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -352,18 +352,21 @@ impl Client { }; if let Some(voice) = params.voice { 
- let mut audio = session.audio.unwrap_or(types::AudioConfig { - input: None, - output: None, - }); - let mut output = audio.output.unwrap_or(types::AudioOutput { - format: None, - speed: 1.0, - voice: None, - }); - output.voice = Some(voice); - audio.output = Some(output); - session.audio = Some(audio); + match self.protocol { + RealtimeProtocol::OpenAI => { + session.voice = Some(voice); + } + RealtimeProtocol::Azure => { + session.audio = Some(types::AudioConfig { + input: None, + output: Some(types::AudioOutput { + format: None, + speed: 1.0, + voice: Some(voice), + }), + }); + } + } send_update = true; } @@ -554,17 +557,24 @@ impl Client { tools, tool_choice, } => { - let audio = voice.map(|voice| types::AudioConfig { - input: None, - output: Some(types::AudioOutput { - format: None, - speed: 1.0, - voice: Some(voice), - }), - }); + let (voice, audio) = match self.protocol { + RealtimeProtocol::OpenAI => (voice, None), + RealtimeProtocol::Azure => { + let audio = voice.map(|voice| types::AudioConfig { + input: None, + output: Some(types::AudioOutput { + format: None, + speed: 1.0, + voice: Some(voice), + }), + }); + (None, audio) + } + }; let session = types::RealtimeSession { instructions, + voice, audio, tools, tool_choice, From 602b4350534d118c94ac87f08e793c30af9ac3d2 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 6 May 2026 16:12:14 +0200 Subject: [PATCH 9/9] Review code --- services/openai-dialog/src/lib.rs | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index 109ed08..01333e1 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -156,16 +156,9 @@ fn infer_protocol_from_host(host: Option<&str>) -> Result { let parsed = Url::parse(host).with_context(|| format!("Invalid host URL: {host}"))?; - let is_openai_realtime_endpoint = parsed.scheme() == "wss" - && parsed.host_str() == 
Some("api.openai.com") - && parsed.path() == "/v1/realtime"; - - if is_openai_realtime_endpoint { - return Ok(Protocol::OpenAI); - } - - match parsed.host_str() { - Some(host) if host.ends_with(".openai.azure.com") => Ok(Protocol::Azure), + match (parsed.scheme(), parsed.host_str(), parsed.path()) { + ("wss", Some("api.openai.com"), "/v1/realtime") => Ok(Protocol::OpenAI), + (_, Some(host), _) if host.ends_with(".openai.azure.com") => Ok(Protocol::Azure), _ => bail!( "Cannot infer protocol from host `{host}`. Set `protocol` explicitly to `openai` or `azure`, use `wss://api.openai.com/v1/realtime`, or use an Azure OpenAI host." ), @@ -458,16 +451,12 @@ impl Client { other => bail!("Unexpected non-realtime session: {other:?}"), }; - let input_format = session - .audio - .as_ref() + // OpenAI may omit audio here; treat missing as default behavior. + let audio = session.audio.as_ref(); + + let input_format = audio .and_then(|a| a.input.as_ref()) .and_then(|i| i.format.as_ref()); - let output_format = session - .audio - .as_ref() - .and_then(|a| a.output.as_ref()) - .and_then(|o| o.format.as_ref()); if let Some(format) = input_format && !matches!(format, types::AudioFormat::Pcm(_)) @@ -475,12 +464,17 @@ impl Client { bail!("Unexpected input audio format: {input_format:?}, expected PCM") } + let output_format = audio + .and_then(|a| a.output.as_ref()) + .and_then(|o| o.format.as_ref()); + if let Some(format) = output_format && !matches!(format, types::AudioFormat::Pcm(_)) { bail!("Unexpected output audio format: {output_format:?}, expected PCM") } + // OpenAI may omit output modalities here; treat missing as default behavior. let modalities = session.output_modalities.unwrap_or_default(); if !modalities.is_empty() && !modalities