diff --git a/.env.example b/.env.example index bc0d668..64f9267 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,7 @@ # OpenAI Configuration OPENAI_API_KEY=your_openai_key OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview +OPENAI_REALTIME_ENDPOINT= # Aristech ARISTECH_ENDPOINT= diff --git a/Cargo.toml b/Cargo.toml index 1bca914..54ed8cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -80,6 +80,8 @@ tokio = { workspace = true, features = ["rt-multi-thread"] } openai-api-rs = { workspace = true } serde_json = { workspace = true } chrono-tz = { version = "0.10.3" } +url = { workspace = true } +strum = { version = "0.28" } # For recognizing audio files in azure-transcribe. diff --git a/README.md b/README.md index 26a00f9..ab5e8e3 100644 --- a/README.md +++ b/README.md @@ -93,6 +93,7 @@ Configure the services by setting the appropriate environment variables in your # OpenAI Configuration OPENAI_API_KEY=your_openai_key OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview +OPENAI_REALTIME_ENDPOINT= # Azure Configuration AZURE_SUBSCRIPTION_KEY=your_azure_key @@ -105,6 +106,13 @@ ELEVENLABS_API_KEY=your_elevenlabs_key AUDIO_KNIFE_ADDRESS=127.0.0.1:8123 ``` +For Azure OpenAI realtime endpoints (`*.openai.azure.com`), the realtime client automatically appends +`api-key` as a query parameter to the websocket URL. For other hosts, it uses the standard +`Authorization: Bearer ...` header. + +The websocket client does not follow redirects. If the endpoint responds with `3xx` (for example +`302 Found`), update the configured endpoint URL to the final websocket target. 
+ ## License [MIT License](LICENSE) \ No newline at end of file diff --git a/examples/openai-dialog.rs b/examples/openai-dialog.rs index e070a49..a54d875 100644 --- a/examples/openai-dialog.rs +++ b/examples/openai-dialog.rs @@ -10,10 +10,12 @@ use std::{ use anyhow::{Context, Result, bail}; use chrono::Utc; +use clap::builder::{PossibleValuesParser, TypedValueParser}; +use clap::{Parser, ValueEnum}; use context_switch::{InputModality, OutputModality}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use openai_api_rs::realtime::types; -use openai_dialog::{OpenAIDialog, ServiceInputEvent, ServiceOutputEvent}; +use openai_dialog::{OpenAIDialog, Protocol, ServiceInputEvent, ServiceOutputEvent}; use rodio::{DeviceSinkBuilder, Player, Source}; use serde_json::json; use tokio::{ @@ -27,8 +29,48 @@ use context_switch_core::{ conversation::{Conversation, Input, Output}, }; +#[derive(Debug, Parser)] +struct Cli { + #[arg(long, value_enum)] + protocol: Option, + #[arg(long)] + endpoint: Option, + #[arg(long, value_parser = realtime_voice_value_parser())] + voice: Option, +} + +#[derive(Debug, Clone, Copy, ValueEnum)] +enum CliProtocol { + #[value(name = "openai")] + OpenAI, + Azure, +} + +fn realtime_voice_value_parser() -> impl TypedValueParser { + PossibleValuesParser::new(::VARIANTS).try_map( + |value| { + parse_realtime_voice_value(&value) + .map_err(|e| format!("Invalid voice value `{value}`: {e}")) + }, + ) +} + +fn parse_realtime_voice_value(value: &str) -> Result { + types::RealtimeVoice::from_str(value) +} + +impl From for Protocol { + fn from(value: CliProtocol) -> Self { + match value { + CliProtocol::OpenAI => Protocol::OpenAI, + CliProtocol::Azure => Protocol::Azure, + } + } +} + #[tokio::main] async fn main() -> Result<()> { + let cli = Cli::parse(); dotenvy::dotenv_override().context("Reading .env file")?; tracing_subscriber::fmt::init(); @@ -76,6 +118,12 @@ async fn main() -> Result<()> { let openai = OpenAIDialog; let mut params = 
openai_dialog::Params::new(key, model); + params.host = cli + .endpoint + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = cli.protocol.map(Into::into); + params.voice = cli.voice; params.tools.push(get_time_function_definition()); let (output_sender, output_receiver) = unbounded_channel(); diff --git a/external/openai-api-rs b/external/openai-api-rs index f7d659e..c37d725 160000 --- a/external/openai-api-rs +++ b/external/openai-api-rs @@ -1 +1 @@ -Subproject commit f7d659e293be65f2682ec232d3ff31dda2d139c7 +Subproject commit c37d72517f9645c6eed00a15d1b90cee0b14a7fe diff --git a/services/openai-dialog/Cargo.toml b/services/openai-dialog/Cargo.toml index 3654c0f..436f653 100644 --- a/services/openai-dialog/Cargo.toml +++ b/services/openai-dialog/Cargo.toml @@ -22,3 +22,4 @@ base64 = { workspace = true } serde = { workspace = true } async-trait = { workspace = true } uuid = { workspace = true } +url = { workspace = true } diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index 99f53ea..01333e1 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -12,12 +12,12 @@ use futures::{ stream::{SplitSink, SplitStream}, }; use openai_api_rs::realtime::{ - api::RealtimeClient, + api::{RealtimeClient, RealtimeProtocol}, client_event::{self, ClientEvent}, server_event::{self, ServerEvent}, types::{ - self, ItemContentType, ItemRole, ItemStatus, ItemType, RealtimeVoice, ResponseStatus, - ToolChoice, + self, ItemContentType, ItemRole, ItemStatus, ItemType, OutputModality, RealtimeVoice, + ResponseStatus, ToolChoice, }, }; use serde::{Deserialize, Serialize}; @@ -27,6 +27,7 @@ use tokio_tungstenite::{ tungstenite::{Bytes, protocol::Message}, }; use tracing::{debug, info, trace, warn}; +use url::Url; use uuid::Uuid; use context_switch_core::{ @@ -39,10 +40,10 @@ use context_switch_core::{ pub struct Params { pub api_key: String, pub model: 
String, + pub protocol: Option, pub host: Option, pub instructions: Option, pub voice: Option, - pub temperature: Option, #[serde(default)] pub tools: Vec, tool_choice: Option, @@ -53,16 +54,32 @@ impl Params { Self { api_key: api_key.into(), model: model.into(), + protocol: None, host: None, instructions: None, voice: None, - temperature: None, tools: vec![], tool_choice: None, } } } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Protocol { + OpenAI, + Azure, +} + +impl Protocol { + fn to_realtime_protocol(self) -> RealtimeProtocol { + match self { + Protocol::OpenAI => RealtimeProtocol::OpenAI, + Protocol::Azure => RealtimeProtocol::Azure, + } + } +} + #[derive(Debug)] pub struct OpenAIDialog; @@ -80,10 +97,12 @@ impl Service for OpenAIDialog { bail!("Input and output audio formats must match for OpenAI dialog service"); } + let protocol = resolve_protocol(params.protocol, params.host.as_deref())?; + let host = if let Some(host) = &params.host { - Host::new_with_host(host, &params.api_key, &params.model) + Host::new_with_host(host, &params.api_key, &params.model, protocol) } else { - Host::new(&params.api_key, &params.model) + Host::new(&params.api_key, &params.model, protocol) }; info!("Connecting to {host:?}"); let mut client = host.connect().await?; @@ -107,6 +126,45 @@ impl Service for OpenAIDialog { } } +fn resolve_protocol(protocol: Option, host: Option<&str>) -> Result { + let protocol = match protocol { + Some(protocol) => Ok(protocol), + None => infer_protocol_from_host(host), + }?; + + validate_protocol_host(protocol, host)?; + Ok(protocol) +} + +fn validate_protocol_host(protocol: Protocol, host: Option<&str>) -> Result<()> { + match (protocol, host) { + (Protocol::Azure, None) => { + bail!( + "Protocol `azure` requires an Azure OpenAI `host` endpoint. Set `host` to your Azure OpenAI realtime websocket URL." 
+ ) + } + (Protocol::Azure, Some(_)) => Ok(()), + (Protocol::OpenAI, _) => Ok(()), + } +} + +fn infer_protocol_from_host(host: Option<&str>) -> Result { + let host = match host { + Some(host) => host, + None => return Ok(Protocol::OpenAI), + }; + + let parsed = Url::parse(host).with_context(|| format!("Invalid host URL: {host}"))?; + + match (parsed.scheme(), parsed.host_str(), parsed.path()) { + ("wss", Some("api.openai.com"), "/v1/realtime") => Ok(Protocol::OpenAI), + (_, Some(host), _) if host.ends_with(".openai.azure.com") => Ok(Protocol::Azure), + _ => bail!( + "Cannot infer protocol from host `{host}`. Set `protocol` explicitly to `openai` or `azure`, use `wss://api.openai.com/v1/realtime`, or use an Azure OpenAI host." + ), + } +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum ServiceInputEvent { @@ -125,8 +183,6 @@ pub enum ServiceInputEvent { #[serde(skip_serializing_if = "Option::is_none")] voice: Option, #[serde(skip_serializing_if = "Option::is_none")] - temperature: Option, - #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, #[serde(skip_serializing_if = "Option::is_none")] tool_choice: Option, @@ -165,18 +221,24 @@ impl fmt::Debug for Host { } impl Host { - pub fn new_with_host(host: &str, api_key: &str, model: &str) -> Self { + pub fn new_with_host(host: &str, api_key: &str, model: &str, protocol: Protocol) -> Self { Host { - client: RealtimeClient::new_with_endpoint(host.into(), api_key.into(), model.into()), + client: RealtimeClient::new_with_endpoint_and_protocol( + host.into(), + api_key.into(), + model.into(), + protocol.to_realtime_protocol(), + ), } } - pub fn new(api_key: &str, model: &str) -> Self { + pub fn new(api_key: &str, model: &str, protocol: Protocol) -> Self { Host { - client: RealtimeClient::new_with_endpoint( + client: RealtimeClient::new_with_endpoint_and_protocol( "wss://api.openai.com/v1/realtime".into(), api_key.into(), model.into(), + 
protocol.to_realtime_protocol(), ), } } @@ -188,13 +250,14 @@ impl Host { .await .map_err(|e| anyhow!(e.to_string()))?; - Ok(Client::new(read, write)) + Ok(Client::new(read, write, self.client.protocol)) } } pub struct Client { read: SplitStream>>, write: SplitSink>, Message>, + protocol: RealtimeProtocol, response_state: ResponseState, inflight_prompt: Option<(String, PromptRequest)>, @@ -212,16 +275,32 @@ impl Client { fn new( read: SplitStream>>, write: SplitSink>, Message>, + protocol: RealtimeProtocol, ) -> Self { Self { read, write, + protocol, response_state: ResponseState::Idle, inflight_prompt: None, pending_prompts: Default::default(), } } + fn session_update_payload( + &self, + session: types::RealtimeSession, + ) -> client_event::SessionUpdatePayload { + match self.protocol { + RealtimeProtocol::OpenAI => client_event::SessionUpdatePayload::Untagged( + types::UntaggedSession::Realtime(session), + ), + RealtimeProtocol::Azure => { + client_event::SessionUpdatePayload::Tagged(types::Session::Realtime(session)) + } + } + } + /// Run an audio dialog. 
pub async fn dialog( &mut self, @@ -258,7 +337,7 @@ impl Client { { let mut send_update = false; - let mut session = types::Session::default(); + let mut session = types::RealtimeSession::default(); if let Some(instructions) = params.instructions { session.instructions = Some(instructions); @@ -266,12 +345,21 @@ impl Client { }; if let Some(voice) = params.voice { - session.voice = Some(voice); - send_update = true; - } - - if let Some(temperature) = params.temperature { - session.temperature = Some(temperature); + match self.protocol { + RealtimeProtocol::OpenAI => { + session.voice = Some(voice); + } + RealtimeProtocol::Azure => { + session.audio = Some(types::AudioConfig { + input: None, + output: Some(types::AudioOutput { + format: None, + speed: 1.0, + voice: Some(voice), + }), + }); + } + } send_update = true; } @@ -286,9 +374,11 @@ impl Client { } if send_update { + let payload = self.session_update_payload(session); + self.send_client_event(ClientEvent::SessionUpdate(client_event::SessionUpdate { event_id: None, - session, + session: payload, })) .await?; debug!("Session updated"); @@ -356,28 +446,42 @@ impl Client { } }; - let session = session_created.session; - - // PartialEq is not implemented for AudioFormat. - let Some(types::AudioFormat::PCM16) = session.input_audio_format else { - bail!( - "Unexpected input audio format: {:?}, expected {:?}", - session.input_audio_format, - types::AudioFormat::PCM16 - ) + let session = match session_created.session { + types::UntaggedSession::Realtime(session) => session, + other => bail!("Unexpected non-realtime session: {other:?}"), }; - let Some(types::AudioFormat::PCM16) = session.output_audio_format else { - bail!( - "Unexpected output audio format: {:?}, expected {:?}", - session.output_audio_format, - types::AudioFormat::PCM16 - ) - }; + // OpenAI may omit audio here; treat missing as default behavior. 
+ let audio = session.audio.as_ref(); + + let input_format = audio + .and_then(|a| a.input.as_ref()) + .and_then(|i| i.format.as_ref()); - let modalities = session.modalities.unwrap_or_default(); - if !modalities.iter().any(|m| m == "audio") { - bail!("Expect `audio` modality: {:?}", modalities); + if let Some(format) = input_format + && !matches!(format, types::AudioFormat::Pcm(_)) + { + bail!("Unexpected input audio format: {input_format:?}, expected PCM") + } + + let output_format = audio + .and_then(|a| a.output.as_ref()) + .and_then(|o| o.format.as_ref()); + + if let Some(format) = output_format + && !matches!(format, types::AudioFormat::Pcm(_)) + { + bail!("Unexpected output audio format: {output_format:?}, expected PCM") + } + + // OpenAI may omit output modalities here; treat missing as default behavior. + let modalities = session.output_modalities.unwrap_or_default(); + if !modalities.is_empty() + && !modalities + .iter() + .any(|m| matches!(m, OutputModality::Audio)) + { + bail!("Expect audio output modality: {:?}", modalities); } Ok(()) @@ -444,19 +548,37 @@ impl Client { ServiceInputEvent::SessionUpdate { instructions, voice, - temperature, tools, tool_choice, } => { + let (voice, audio) = match self.protocol { + RealtimeProtocol::OpenAI => (voice, None), + RealtimeProtocol::Azure => { + let audio = voice.map(|voice| types::AudioConfig { + input: None, + output: Some(types::AudioOutput { + format: None, + speed: 1.0, + voice: Some(voice), + }), + }); + (None, audio) + } + }; + + let session = types::RealtimeSession { + instructions, + voice, + audio, + tools, + tool_choice, + ..Default::default() + }; + + let payload = self.session_update_payload(session); + let event = ClientEvent::SessionUpdate(client_event::SessionUpdate { - session: types::Session { - instructions, - voice, - temperature, - tools, - tool_choice, - ..Default::default() - }, + session: payload, ..Default::default() }); self.send_client_event(event).await?; @@ -519,6 +641,12 @@ impl 
Client { ..delta.clone() }) } + ServerEvent::ResponseOutputAudioDelta(delta) => { + ServerEvent::ResponseOutputAudioDelta(server_event::ResponseOutputAudioDelta { + delta: "[REMOVED]".to_string(), + ..delta.clone() + }) + } event => event.clone(), }; @@ -553,6 +681,16 @@ impl Client { }; output.audio_frame(frame)?; } + ServerEvent::ResponseOutputAudioDelta(audio_delta) => { + let decoded = BASE64_STANDARD.decode(audio_delta.delta)?; + let samples = audio::from_le_bytes(&decoded); + trace!("Sending {} samples", samples.len()); + let frame = AudioFrame { + format: output_format, + samples, + }; + output.audio_frame(frame)?; + } ServerEvent::InputAudioBufferSpeechStarted(_) => output.clear_audio()?, ServerEvent::ResponseCreated(server_event::ResponseCreated { response: types::Response { object, .. }, @@ -684,13 +822,16 @@ impl Client { }) .await?; } - ServerEvent::SessionUpdated(server_event::SessionUpdated { - session: types::Session { tools, .. }, - .. - }) => output.service_event( - OutputPath::Control, - ServiceOutputEvent::SessionUpdated { tools }, - )?, + ServerEvent::SessionUpdated(server_event::SessionUpdated { session, .. }) => { + let tools = match session { + types::UntaggedSession::Realtime(session) => session.tools, + types::UntaggedSession::Transcription(_) => None, + }; + output.service_event( + OutputPath::Control, + ServiceOutputEvent::SessionUpdated { tools }, + )? + } response => { trace!("Unhandled response: {:?}", response) @@ -741,7 +882,7 @@ impl Client { let response = ClientEvent::ResponseCreate(client_event::ResponseCreate { event_id: Some(event_id.clone()), - response: Some(types::Session { + response: Some(types::RealtimeSession { instructions: Some(prompt_request.0.clone()), ..Default::default() }),