From 68c233547a9674dfecc810ac460a0b3a9ea0549d Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 07:06:43 +0200 Subject: [PATCH 01/20] Initial implementation of google-dialog --- Cargo.toml | 6 + examples/dialog.rs | 442 +++++++++++++++++++++++++++ examples/openai-dialog.rs | 330 -------------------- services/google-dialog/Cargo.toml | 16 + services/google-dialog/src/client.rs | 318 +++++++++++++++++++ services/google-dialog/src/lib.rs | 48 +++ services/google-dialog/src/types.rs | 66 ++++ src/context_switch.rs | 1 + src/lib.rs | 1 + 9 files changed, 898 insertions(+), 330 deletions(-) create mode 100644 examples/dialog.rs delete mode 100644 examples/openai-dialog.rs create mode 100644 services/google-dialog/Cargo.toml create mode 100644 services/google-dialog/src/client.rs create mode 100644 services/google-dialog/src/lib.rs create mode 100644 services/google-dialog/src/types.rs diff --git a/Cargo.toml b/Cargo.toml index f7d43a1..12ba84b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "services/aristech", "services/azure", "services/elevenlabs", + "services/google-dialog", "services/google-transcribe", "services/openai-dialog", "services/playback", @@ -29,6 +30,7 @@ edition = "2024" context-switch-core = { workspace = true } openai-dialog = { path = "services/openai-dialog" } +google-dialog = { workspace = true } azure = { workspace = true } azure-speech = { workspace = true } aristech = { workspace = true } @@ -73,6 +75,8 @@ rodio = { workspace = true, features = ["playback"] } azure = { workspace = true } aristech = { workspace = true } google-transcribe = { workspace = true } +google-dialog = { workspace = true } +gemini-live = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread"] } @@ -96,6 +100,8 @@ playback = { path = "services/playback" } aristech = { path = "services/aristech" } elevenlabs = { path = "services/elevenlabs" } google-transcribe = { path = "services/google-transcribe" } 
+google-dialog = { path = "services/google-dialog" } +gemini-live = "0.1.8" anyhow = "1.0.102" derive_more = { version = "2.1.1", features = ["full"] } diff --git a/examples/dialog.rs b/examples/dialog.rs new file mode 100644 index 0000000..e7893a4 --- /dev/null +++ b/examples/dialog.rs @@ -0,0 +1,442 @@ +//! A context switch demo. Runs locally, gets voice data from your current microphone. + +use std::{ + env, + num::{NonZeroU16, NonZeroU32}, + str::FromStr, + thread, + time::Duration, +}; + +use anyhow::{Context, Result, bail}; +use chrono::Utc; +use clap::{Parser, ValueEnum}; +use context_switch::{InputModality, OutputModality}; +use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; +use gemini_live::types as gemini_types; +use google_dialog::{GoogleDialog, ServiceInputEvent as GoogleServiceInputEvent}; +use openai_api_rs::realtime::types as openai_types; +use openai_dialog::{ + OpenAIDialog, Protocol, ServiceInputEvent as OpenAIServiceInputEvent, + ServiceOutputEvent as OpenAIServiceOutputEvent, +}; +use rodio::{DeviceSinkBuilder, Player, Source}; +use serde_json::json; +use tokio::{ + select, + sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}, +}; +use tracing::info; + +use context_switch_core::{ + AudioFormat, AudioFrame, Service, audio, + conversation::{Conversation, Input, Output}, +}; + +#[derive(Debug, Parser)] +struct Cli { + #[arg(value_enum)] + provider: Provider, + #[arg(long)] + endpoint: Option, + #[arg(long)] + model: Option, + #[arg(long)] + voice: Option, +} + +#[derive(Debug, Clone, Copy, ValueEnum)] +enum Provider { + #[value(name = "openai")] + OpenAI, + Azure, + Google, +} + +impl Provider { + fn output_format(self, input_format: AudioFormat) -> AudioFormat { + match self { + Provider::OpenAI | Provider::Azure => input_format, + Provider::Google => AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE), + } + } +} + +#[tokio::main] +async fn main() -> Result<()> { + let cli = Cli::parse(); + 
dotenvy::dotenv_override().context("Reading .env file")?; + tracing_subscriber::fmt::init(); + + let host = cpal::default_host(); + let device = host + .default_input_device() + .expect("Failed to get default input device"); + let input_config = device + .default_input_config() + .expect("Failed to get default input config"); + + println!("Audio device input config: {input_config:?}"); + + let channels = input_config.channels(); + let sample_rate = input_config.sample_rate(); + let input_format = AudioFormat::new(channels, sample_rate); + let output_format = cli.provider.output_format(input_format); + + let (input_sender, input_receiver) = channel(256); + let input_sender2 = input_sender.clone(); + + let stream = device + .build_input_stream( + &input_config.into(), + move |data: &[f32], _: &cpal::InputCallbackInfo| { + let samples = audio::into_i16(data); + let frame = AudioFrame { + format: input_format, + samples, + }; + if input_sender2.try_send(Input::Audio { frame }).is_err() { + println!("Failed to send audio data") + } + }, + move |err| { + eprintln!("Error occurred on stream: {err}"); + }, + Some(Duration::from_secs(1)), + ) + .expect("Failed to build input stream"); + + stream.play().expect("Failed to play stream"); + + let (output_sender, output_receiver) = unbounded_channel(); + let conversation = Conversation::new( + InputModality::Audio { + format: input_format, + }, + [OutputModality::Audio { + format: output_format, + }], + input_receiver, + output_sender, + ); + + let conversation = start_conversation(&cli, conversation); + tokio::pin!(conversation); + let playback_task = + setup_audio_playback(cli.provider, output_format, input_sender, output_receiver).await; + let mut playback_handle = tokio::spawn(playback_task); + + select! { + r = &mut conversation => { + let _ = playback_handle.await; + r? + } + r = &mut playback_handle => { + r?? 
+ } + } + + Ok(()) +} + +async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> { + match cli.provider { + Provider::OpenAI | Provider::Azure => { + let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; + let model = cli + .model + .clone() + .or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .context("Provide --model or set OPENAI_REALTIME_API_MODEL")?; + + let mut params = openai_dialog::Params::new(key, model); + params.host = cli + .endpoint + .clone() + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = Some(match cli.provider { + Provider::OpenAI => Protocol::OpenAI, + Provider::Azure => Protocol::Azure, + Provider::Google => unreachable!(), + }); + params.voice = cli + .voice + .as_deref() + .map(parse_realtime_voice_value) + .transpose()?; + params.tools.push(openai_get_time_function_definition()); + + OpenAIDialog.conversation(params, conversation).await + } + Provider::Google => { + let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; + let model = cli + .model + .clone() + .or_else(|| env::var("GEMINI_LIVE_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .unwrap_or_else(|| "gemini-3.1-flash-live-preview".to_owned()); + + let mut params = google_dialog::Params::new(key, model); + params.host = cli + .endpoint + .clone() + .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.voice = cli.voice.clone(); + params.output_audio_transcription = true; + params.tools.push(gemini_get_time_tool()); + + GoogleDialog.conversation(params, conversation).await + } + } +} + +fn parse_realtime_voice_value(value: &str) -> Result { + openai_types::RealtimeVoice::from_str(value) + .map_err(|e| anyhow::anyhow!("Invalid voice value `{value}`: {e}")) +} + +enum AudioCommand { + PlayFrame(AudioFrame), + Clear, + 
Stop, +} + +async fn setup_audio_playback( + provider: Provider, + format: AudioFormat, + input: Sender, + mut output: UnboundedReceiver, +) -> impl std::future::Future> { + let (cmd_tx, cmd_rx) = std::sync::mpsc::channel(); + + let playback_thread = thread::spawn(move || { + let sink_handle = DeviceSinkBuilder::open_default_sink().unwrap(); + let player = Player::connect_new(sink_handle.mixer()); + + while let Ok(cmd) = cmd_rx.recv() { + match cmd { + AudioCommand::PlayFrame(frame) => { + let source = FrameSource { + frames: audio::from_i16(frame.samples), + position: 0, + sample_rate: format.sample_rate, + channels: format.channels, + }; + player.append(source); + } + AudioCommand::Clear => { + player.clear(); + player.play(); + } + AudioCommand::Stop => break, + } + } + + player.sleep_until_end(); + }); + + async move { + while let Some(output) = output.recv().await { + match output { + Output::ServiceStarted { .. } => {} + Output::Audio { frame } => { + if cmd_tx.send(AudioCommand::PlayFrame(frame)).is_err() { + break; + } + } + Output::Text { text, .. } => { + info!("Output text: {text}"); + } + Output::RequestCompleted { .. } => {} + Output::ClearAudio => { + if cmd_tx.send(AudioCommand::Clear).is_err() { + break; + } + } + Output::ServiceEvent { value, .. } => { + handle_service_event(provider, &input, value)?; + } + Output::BillingRecords { records, scope, .. } => { + info!("Billing: scope: {scope:?}, records: {records:?}"); + } + } + } + let _ = cmd_tx.send(AudioCommand::Stop); + let _ = playback_thread.join(); + Ok(()) + } +} + +fn handle_service_event( + provider: Provider, + input: &Sender, + value: serde_json::Value, +) -> Result<()> { + let call = match provider { + Provider::OpenAI | Provider::Azure => match serde_json::from_value(value)? 
{ + OpenAIServiceOutputEvent::FunctionCall { + name, + call_id, + arguments, + } => Some(FunctionCall { + name, + call_id, + arguments, + }), + OpenAIServiceOutputEvent::SessionUpdated { tools } => { + info!("Session updated: {tools:?}"); + None + } + }, + Provider::Google => match serde_json::from_value(value)? { + google_dialog::ServiceOutputEvent::FunctionCall { + name, + call_id, + arguments, + } => Some(FunctionCall { + name, + call_id, + arguments: Some(arguments), + }), + google_dialog::ServiceOutputEvent::ToolCallCancellation { call_ids } => { + info!("Tool calls cancelled: {call_ids:?}"); + None + } + google_dialog::ServiceOutputEvent::SessionUpdated { tools } => { + info!("Session updated: {tools:?}"); + None + } + }, + }; + + if let Some(call) = call { + info!( + "Processing function `{}` with arguments `{:?}`", + call.name, call.arguments + ); + let result = call_function(&call.name, call.arguments)?; + info!("Function result: `{result}`"); + send_function_result(provider, input, call.call_id, call.name, result)?; + } + + Ok(()) +} + +fn send_function_result( + provider: Provider, + input: &Sender, + call_id: String, + name: String, + result: String, +) -> Result<()> { + let output = json!({ "time": serde_json::Value::String(result) }); + let value = match provider { + Provider::OpenAI | Provider::Azure => { + serde_json::to_value(&OpenAIServiceInputEvent::FunctionCallResult { call_id, output })? 
+ } + Provider::Google => serde_json::to_value(&GoogleServiceInputEvent::FunctionCallResult { + call_id, + name, + output, + })?, + }; + input.try_send(Input::ServiceEvent { value })?; + Ok(()) +} + +#[derive(Debug)] +struct FunctionCall { + call_id: String, + name: String, + arguments: Option, +} + +fn openai_get_time_function_definition() -> openai_types::ToolDefinition { + openai_types::ToolDefinition::Function { + name: "get_time".into(), + description: "The current time to the exact second.".into(), + parameters: get_time_parameters_schema(), + } +} + +fn gemini_get_time_tool() -> gemini_types::Tool { + gemini_types::Tool::FunctionDeclarations(vec![gemini_types::FunctionDeclaration { + name: "get_time".into(), + description: "The current time to the exact second.".into(), + parameters: get_time_parameters_schema(), + scheduling: None, + behavior: None, + }]) +} + +fn get_time_parameters_schema() -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "IANA time zone identifier of the region and city." 
+ } + }, + "required": ["location"] + }) +} + +fn call_function(name: &str, arguments: Option) -> Result { + let arguments = arguments.context("No arguments provided for function call")?; + if name != "get_time" { + bail!("Unknown function: {name}"); + } + let location = arguments["location"] + .as_str() + .context("Invalid or missing 'location' field in arguments")?; + let tz = chrono_tz::Tz::from_str(location) + .with_context(|| format!("Unknown time zone: {location}"))?; + + let now = Utc::now().with_timezone(&tz); + Ok(now.format("%H:%M:%S").to_string()) +} + +struct FrameSource { + frames: Vec, + position: usize, + sample_rate: u32, + channels: u16, +} + +impl Iterator for FrameSource { + type Item = f32; + + fn next(&mut self) -> Option { + if self.position >= self.frames.len() { + None + } else { + let sample = self.frames[self.position]; + self.position += 1; + Some(sample) + } + } +} + +impl Source for FrameSource { + fn current_span_len(&self) -> Option { + Some(self.frames.len() - self.position) + } + + fn channels(&self) -> NonZeroU16 { + NonZeroU16::new(self.channels).expect("channels must be non-zero") + } + + fn sample_rate(&self) -> NonZeroU32 { + NonZeroU32::new(self.sample_rate).expect("sample rate must be non-zero") + } + + fn total_duration(&self) -> Option { + let seconds = self.frames.len() as f32 / (self.sample_rate as f32 * self.channels as f32); + Some(Duration::from_secs_f32(seconds)) + } +} diff --git a/examples/openai-dialog.rs b/examples/openai-dialog.rs deleted file mode 100644 index 47186e7..0000000 --- a/examples/openai-dialog.rs +++ /dev/null @@ -1,330 +0,0 @@ -//! A context switch demo. Runs locally, gets voice data from your current microphone. 
- -use std::{ - env, - num::{NonZeroU16, NonZeroU32}, - str::FromStr, - thread, - time::Duration, -}; - -use anyhow::{Context, Result, bail}; -use chrono::Utc; -use clap::builder::{PossibleValuesParser, TypedValueParser}; -use clap::{Parser, ValueEnum}; -use context_switch::{InputModality, OutputModality}; -use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; -use openai_api_rs::realtime::types; -use openai_dialog::{OpenAIDialog, Protocol, ServiceInputEvent, ServiceOutputEvent}; -use rodio::{DeviceSinkBuilder, Player, Source}; -use serde_json::json; -use tokio::{ - select, - sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}, -}; -use tracing::info; - -use context_switch_core::{ - AudioFormat, AudioFrame, Service, audio, - conversation::{Conversation, Input, Output}, -}; - -#[derive(Debug, Parser)] -struct Cli { - #[arg(long, value_enum)] - protocol: Option, - #[arg(long)] - endpoint: Option, - #[arg(long)] - model: Option, - #[arg(long, value_parser = realtime_voice_value_parser())] - voice: Option, -} - -#[derive(Debug, Clone, Copy, ValueEnum)] -enum CliProtocol { - #[value(name = "openai")] - OpenAI, - Azure, -} - -fn realtime_voice_value_parser() -> impl TypedValueParser { - PossibleValuesParser::new(::VARIANTS).try_map( - |value| { - parse_realtime_voice_value(&value) - .map_err(|e| format!("Invalid voice value `{value}`: {e}")) - }, - ) -} - -fn parse_realtime_voice_value(value: &str) -> Result { - types::RealtimeVoice::from_str(value) -} - -impl From for Protocol { - fn from(value: CliProtocol) -> Self { - match value { - CliProtocol::OpenAI => Protocol::OpenAI, - CliProtocol::Azure => Protocol::Azure, - } - } -} - -#[tokio::main] -async fn main() -> Result<()> { - let cli = Cli::parse(); - dotenvy::dotenv_override().context("Reading .env file")?; - tracing_subscriber::fmt::init(); - - let host = cpal::default_host(); - let device = host - .default_input_device() - .expect("Failed to get default input device"); - let input_config = 
device - .default_input_config() - .expect("Failed to get default input config"); - - println!("Audio device input config: {input_config:?}"); - - let channels = input_config.channels(); - let sample_rate = input_config.sample_rate(); - let format = AudioFormat::new(channels, sample_rate); - - let (input_sender, input_receiver) = channel(256); - - let input_sender2 = input_sender.clone(); - - // Create and run the input stream - let stream = device - .build_input_stream( - &input_config.into(), - move |data: &[f32], _: &cpal::InputCallbackInfo| { - let samples = audio::into_i16(data); - let frame = AudioFrame { format, samples }; - if input_sender2.try_send(Input::Audio { frame }).is_err() { - println!("Failed to send audio data") - } - }, - move |err| { - eprintln!("Error occurred on stream: {err}"); - }, - // timeout - Some(Duration::from_secs(1)), - ) - .expect("Failed to build input stream"); - - stream.play().expect("Failed to play stream"); - - let key = env::var("OPENAI_API_KEY").unwrap(); - let model = cli - .model - .or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) - .filter(|model| !model.trim().is_empty()) - .context("Provide --model or set OPENAI_REALTIME_API_MODEL")?; - - let openai = OpenAIDialog; - let mut params = openai_dialog::Params::new(key, model); - params.host = cli - .endpoint - .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) - .filter(|endpoint| !endpoint.trim().is_empty()); - params.protocol = cli.protocol.map(Into::into); - params.voice = cli.voice; - params.tools.push(get_time_function_definition()); - - let (output_sender, output_receiver) = unbounded_channel(); - - let conversation = Conversation::new( - InputModality::Audio { format }, - [OutputModality::Audio { format }], - input_receiver, - output_sender, - ); - - let mut conversation = openai.conversation(params, conversation); - - let playback_task = setup_audio_playback(format, input_sender, output_receiver).await; - - // Spawn audio playback task - let mut 
playback_handle = tokio::spawn(playback_task); - - select! { - // Drive conversation - r = &mut conversation => { - // When conversation ends, wait for playback to complete before returning. - let _ = playback_handle.await; - r? - } - - // Drive playback - r = &mut playback_handle => { - r?? - } - - } - - Ok(()) -} - -enum AudioCommand { - PlayFrame(AudioFrame), - Clear, - Stop, -} - -async fn setup_audio_playback( - format: AudioFormat, - input: Sender, - mut output: UnboundedReceiver, -) -> impl std::future::Future> { - let (cmd_tx, cmd_rx) = std::sync::mpsc::channel(); - - // Spawn a dedicated audio thread - let playback_thread = thread::spawn(move || { - // Create output stream in the audio thread - - let sink_handle = DeviceSinkBuilder::open_default_sink().unwrap(); - let player = Player::connect_new(sink_handle.mixer()); - - while let Ok(cmd) = cmd_rx.recv() { - match cmd { - AudioCommand::PlayFrame(frame) => { - let source = FrameSource { - frames: audio::from_i16(frame.samples), - position: 0, - sample_rate: format.sample_rate, - channels: format.channels, - }; - player.append(source); - } - AudioCommand::Clear => { - player.clear(); - player.play(); - } - AudioCommand::Stop => break, - } - } - - player.sleep_until_end(); - }); - - // Create async task to forward frames to the audio thread - - async move { - while let Some(output) = output.recv().await { - match output { - Output::ServiceStarted { .. } => {} - Output::Audio { frame } => { - if cmd_tx.send(AudioCommand::PlayFrame(frame)).is_err() { - break; - } - } - Output::Text { .. } | Output::RequestCompleted { .. } => {} - Output::ClearAudio => { - if cmd_tx.send(AudioCommand::Clear).is_err() { - break; - } - } - Output::ServiceEvent { value, .. } => match serde_json::from_value(value)? 
{ - ServiceOutputEvent::FunctionCall { - name, - call_id, - arguments, - } => { - info!("Processing function `{name}` with arguments `{arguments:?}`"); - let result = call_function(&name, arguments)?; - info!("Function result: `{result}`"); - let value = ServiceInputEvent::FunctionCallResult { - call_id, - output: json! ({ "time": serde_json::Value::String(result) }), - }; - let value = serde_json::to_value(&value)?; - input.try_send(Input::ServiceEvent { value })?; - } - ServiceOutputEvent::SessionUpdated { tools } => { - info!("Session Updated: {tools:?}"); - } - }, - Output::BillingRecords { records, scope, .. } => { - info!("Billing: scope: {scope:?}, records: {records:?}"); - } - } - } - let _ = cmd_tx.send(AudioCommand::Stop); - // TODO: this may block! - let _ = playback_thread.join(); - Ok(()) - } -} - -fn get_time_function_definition() -> types::ToolDefinition { - types::ToolDefinition::Function { - name: "get_time".into(), - description: "The current time to the exact second.".into(), - parameters: json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "IANA time zone identifier of the region and city." 
- } - }, - "required": ["location"] - }), - } -} - -fn call_function(name: &str, arguments: Option) -> Result { - let arguments = arguments.context("No arguments provided for function call")?; - if name != "get_time" { - bail!("Unknown function: {name}"); - } - let location = arguments["location"] - .as_str() - .context("Invalid or missing 'location' field in arguments")?; - let tz = chrono_tz::Tz::from_str(location) - .with_context(|| format!("Unknown time zone: {location}"))?; - - let now = Utc::now().with_timezone(&tz); - Ok(now.format("%H:%M:%S").to_string()) -} - -struct FrameSource { - frames: Vec, - position: usize, - sample_rate: u32, - channels: u16, -} - -impl Iterator for FrameSource { - type Item = f32; - - fn next(&mut self) -> Option { - if self.position >= self.frames.len() { - None - } else { - let sample = self.frames[self.position]; - self.position += 1; - Some(sample) - } - } -} - -impl Source for FrameSource { - fn current_span_len(&self) -> Option { - Some(self.frames.len() - self.position) - } - - fn channels(&self) -> NonZeroU16 { - NonZeroU16::new(self.channels).expect("channels must be non-zero") - } - - fn sample_rate(&self) -> NonZeroU32 { - NonZeroU32::new(self.sample_rate).expect("sample rate must be non-zero") - } - - fn total_duration(&self) -> Option { - let seconds = self.frames.len() as f32 / (self.sample_rate as f32 * self.channels as f32); - Some(Duration::from_secs_f32(seconds)) - } -} diff --git a/services/google-dialog/Cargo.toml b/services/google-dialog/Cargo.toml new file mode 100644 index 0000000..3be04a9 --- /dev/null +++ b/services/google-dialog/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "google-dialog" +version = "0.1.0" +edition.workspace = true + +[dependencies] +context-switch-core = { workspace = true } + +gemini-live = { workspace = true } + +anyhow = { workspace = true } +async-trait = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tokio = { 
workspace = true } +tracing = { workspace = true } diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs new file mode 100644 index 0000000..4b8c519 --- /dev/null +++ b/services/google-dialog/src/client.rs @@ -0,0 +1,318 @@ +use anyhow::{Context, Result, bail}; +use context_switch_core::{ + AudioFormat, AudioFrame, BillingRecord, OutputPath, + conversation::{BillingSchedule, ConversationInput, ConversationOutput, Input}, +}; +use gemini_live::{ + ReconnectPolicy, Session, SessionConfig, + transport::{Auth, Endpoint, TransportConfig}, + types::{ + AudioTranscriptionConfig, Content, FunctionDeclaration, FunctionResponse, GenerationConfig, + Modality, ModalityTokenCount, Part, PrebuiltVoiceConfig, ServerEvent, SetupConfig, + SpeechConfig, Tool, UsageMetadata, VoiceConfig, + }, +}; +use tracing::{debug, info, trace}; + +use crate::{Params, ServiceInputEvent, ServiceOutputEvent}; + +pub struct Client { + params: Params, +} + +impl Client { + pub fn new(params: Params) -> Self { + Self { params } + } + + pub async fn dialog( + self, + _input_format: AudioFormat, + output_format: AudioFormat, + output_transcription: bool, + mut input: ConversationInput, + output: ConversationOutput, + ) -> Result<()> { + let billing_scope = self.params.model.clone(); + let tools = function_declarations(&self.params.tools); + let mut session = Session::connect(self.session_config(output_transcription)) + .await + .context("Connecting to Gemini Live")?; + output.service_event( + OutputPath::Control, + ServiceOutputEvent::SessionUpdated { tools }, + )?; + + loop { + tokio::select! { + input = input.recv() => { + if let Some(input) = input { + self.process_input(&session, input).await?; + } else { + session.audio_stream_end().await.context("Ending Gemini audio stream")?; + break; + } + } + event = session.next_event() => { + match event { + Some(event) => { + match self.process_event(event, output_format, output_transcription, &output, &billing_scope).await? 
{ + FlowControl::Continue => {} + FlowControl::End => break, + } + } + None => break, + } + } + } + } + + session + .close() + .await + .context("Closing Gemini Live session")?; + Ok(()) + } + + fn session_config(&self, output_transcription: bool) -> SessionConfig { + let transport = TransportConfig { + endpoint: self + .params + .host + .clone() + .map(Endpoint::Custom) + .unwrap_or_default(), + auth: Auth::ApiKey(self.params.api_key.clone()), + ..Default::default() + }; + + SessionConfig { + transport, + setup: self.setup_config(output_transcription), + reconnect: ReconnectPolicy::default(), + } + } + + fn setup_config(&self, output_transcription: bool) -> SetupConfig { + SetupConfig { + model: model_resource_name(&self.params.model), + generation_config: Some(GenerationConfig { + response_modalities: Some(vec![Modality::Audio]), + speech_config: self.params.voice.clone().map(|voice_name| SpeechConfig { + voice_config: VoiceConfig { + prebuilt_voice_config: PrebuiltVoiceConfig { voice_name }, + }, + }), + ..Default::default() + }), + system_instruction: self.params.instructions.clone().map(system_instruction), + tools: (!self.params.tools.is_empty()).then(|| self.params.tools.clone()), + realtime_input_config: self.params.realtime_input_config.clone(), + input_audio_transcription: self + .params + .input_audio_transcription + .then_some(AudioTranscriptionConfig {}), + output_audio_transcription: (output_transcription_enabled(&self.params)) + .then_some(AudioTranscriptionConfig {}) + .or_else(|| output_transcription.then_some(AudioTranscriptionConfig {})), + ..Default::default() + } + } + + async fn process_input(&self, session: &Session, input: Input) -> Result<()> { + match input { + Input::Audio { frame } => { + let mono = frame.into_mono(); + let sample_rate = mono.format.sample_rate; + let audio = mono.to_le_bytes(); + session + .send_audio_at_rate(&audio, sample_rate) + .await + .context("Sending audio to Gemini Live")?; + } + Input::Text { text, .. 
} => { + session + .send_text(&text) + .await + .context("Sending text to Gemini Live")?; + } + Input::ServiceEvent { value } => match serde_json::from_value(value)? { + ServiceInputEvent::FunctionCallResult { + call_id, + name, + output, + } => { + let response = FunctionResponse { + id: call_id, + name, + response: output, + }; + session + .send_tool_response(vec![response]) + .await + .context("Sending Gemini tool response")?; + } + ServiceInputEvent::Prompt { text } => { + info!("Received prompt"); + session + .send_text(&text) + .await + .context("Sending prompt to Gemini Live")?; + } + }, + } + Ok(()) + } + + async fn process_event( + &self, + event: ServerEvent, + output_format: AudioFormat, + output_transcription: bool, + output: &ConversationOutput, + billing_scope: &str, + ) -> Result { + trace!(?event, "Gemini Live event"); + match event { + ServerEvent::SetupComplete => {} + ServerEvent::ModelText(text) => { + if output_transcription { + output.text(true, text, None, None)?; + } + } + ServerEvent::ModelAudio(audio) => { + let frame = AudioFrame::from_le_bytes(output_format, &audio); + output.audio_frame(frame)?; + } + ServerEvent::GenerationComplete => {} + ServerEvent::TurnComplete => { + output.request_completed(None)?; + } + ServerEvent::Interrupted => { + output.clear_audio()?; + } + ServerEvent::InputTranscription(text) => { + debug!(%text, "Gemini input transcription"); + } + ServerEvent::OutputTranscription(text) => { + if output_transcription { + output.text(true, text, None, None)?; + } + } + ServerEvent::ToolCall(calls) => { + for call in calls { + output.service_event( + OutputPath::Media, + ServiceOutputEvent::FunctionCall { + call_id: call.id, + name: call.name, + arguments: call.args, + }, + )?; + } + } + ServerEvent::ToolCallCancellation(ids) => { + output.service_event( + OutputPath::Control, + ServiceOutputEvent::ToolCallCancellation { call_ids: ids }, + )?; + } + ServerEvent::SessionResumption { .. 
} => {} + ServerEvent::GoAway { time_left } => { + debug!(?time_left, "Gemini Live goAway received"); + } + ServerEvent::Usage(usage) => { + bill_usage(output, billing_scope, usage)?; + } + ServerEvent::Closed { reason } => { + if !reason.is_empty() { + debug!(%reason, "Gemini Live connection closed"); + } + return Ok(FlowControl::End); + } + ServerEvent::Error(error) => { + bail!("Gemini Live error: {}", error.message); + } + } + Ok(FlowControl::Continue) + } +} + +fn model_resource_name(model: &str) -> String { + if model.starts_with("models/") { + model.to_owned() + } else { + format!("models/{model}") + } +} + +fn system_instruction(text: String) -> Content { + Content { + role: None, + parts: vec![Part { + text: Some(text), + inline_data: None, + }], + } +} + +fn output_transcription_enabled(params: &Params) -> bool { + params.output_audio_transcription +} + +fn function_declarations(tools: &[Tool]) -> Option> { + let declarations: Vec<_> = tools + .iter() + .filter_map(|tool| match tool { + Tool::FunctionDeclarations(declarations) => Some(declarations.as_slice()), + Tool::GoogleSearch(_) => None, + }) + .flatten() + .cloned() + .collect(); + + (!declarations.is_empty()).then_some(declarations) +} + +fn bill_usage( + output: &ConversationOutput, + billing_scope: &str, + usage: UsageMetadata, +) -> Result<()> { + let prompt_audio = modality_count(&usage.prompt_tokens_details, "AUDIO"); + let prompt_text = modality_count(&usage.prompt_tokens_details, "TEXT"); + let response_audio = modality_count(&usage.response_tokens_details, "AUDIO"); + let response_text = modality_count(&usage.response_tokens_details, "TEXT"); + + let records = [ + BillingRecord::count("tokens:input:audio", prompt_audio), + BillingRecord::count("tokens:input:text", prompt_text), + BillingRecord::count("tokens:input:cached", usage.cached_content_token_count as _), + BillingRecord::count("tokens:input:tool", usage.tool_use_prompt_token_count as _), + BillingRecord::count("tokens:output:audio", 
response_audio), + BillingRecord::count("tokens:output:text", response_text), + BillingRecord::count("tokens:thoughts", usage.thoughts_token_count as _), + ]; + + output.billing_records( + None, + Some(billing_scope.into()), + records, + BillingSchedule::Now, + )?; + Ok(()) +} + +fn modality_count(details: &Option>, modality: &str) -> usize { + details + .iter() + .flatten() + .filter(|detail| detail.modality.eq_ignore_ascii_case(modality)) + .map(|detail| detail.token_count as usize) + .sum() +} + +enum FlowControl { + Continue, + End, +} diff --git a/services/google-dialog/src/lib.rs b/services/google-dialog/src/lib.rs new file mode 100644 index 0000000..3a2608c --- /dev/null +++ b/services/google-dialog/src/lib.rs @@ -0,0 +1,48 @@ +//! Gemini Live audio dialog service. + +use anyhow::{Result, bail}; +use async_trait::async_trait; +use tracing::info; + +use context_switch_core::{AudioFormat, Service, conversation::Conversation}; + +mod client; +mod types; + +pub use client::Client; +pub use types::{Params, ServiceInputEvent, ServiceOutputEvent}; + +#[derive(Debug)] +pub struct GoogleDialog; + +#[async_trait] +impl Service for GoogleDialog { + type Params = Params; + + async fn conversation(&self, params: Params, conversation: Conversation) -> Result<()> { + let input_format = conversation.require_audio_input()?; + let output_format = conversation.require_one_audio_output()?; + let output_transcription = conversation.has_one_text_output()?; + + let expected_output = AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE); + if output_format != expected_output { + bail!( + "Audio output has the wrong format {:?}, expected: {:?}", + output_format, + expected_output + ); + } + + info!(model = %params.model, "Connecting to Gemini Live"); + let (input, output) = conversation.start()?; + Client::new(params) + .dialog( + input_format, + output_format, + output_transcription, + input, + output, + ) + .await + } +} diff --git a/services/google-dialog/src/types.rs 
b/services/google-dialog/src/types.rs new file mode 100644 index 0000000..184c665 --- /dev/null +++ b/services/google-dialog/src/types.rs @@ -0,0 +1,66 @@ +use gemini_live::types::{FunctionDeclaration, RealtimeInputConfig, Tool}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Params { + pub api_key: String, + pub model: String, + pub host: Option, + pub instructions: Option, + pub voice: Option, + #[serde(default)] + pub tools: Vec, + pub realtime_input_config: Option, + #[serde(default)] + pub input_audio_transcription: bool, + #[serde(default)] + pub output_audio_transcription: bool, +} + +impl Params { + pub fn new(api_key: impl Into, model: impl Into) -> Self { + Self { + api_key: api_key.into(), + model: model.into(), + host: None, + instructions: None, + voice: None, + tools: vec![], + realtime_input_config: None, + input_audio_transcription: false, + output_audio_transcription: false, + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ServiceInputEvent { + #[serde(rename_all = "camelCase")] + FunctionCallResult { + call_id: String, + name: String, + output: serde_json::Value, + }, + Prompt { + text: String, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ServiceOutputEvent { + #[serde(rename_all = "camelCase")] + FunctionCall { + call_id: String, + name: String, + arguments: serde_json::Value, + }, + #[serde(rename_all = "camelCase")] + ToolCallCancellation { call_ids: Vec }, + SessionUpdated { + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + }, +} diff --git a/src/context_switch.rs b/src/context_switch.rs index d0d10e1..181d82d 100644 --- a/src/context_switch.rs +++ b/src/context_switch.rs @@ -42,6 +42,7 @@ pub fn registry() -> Registry { .add_service("azure-synthesize", azure::AzureSynthesize) .add_service("azure-translate", 
azure::AzureTranslate) .add_service("elevenlabs-transcribe", elevenlabs::ElevenLabsTranscribe) + .add_service("google-dialog", google_dialog::GoogleDialog) .add_service("google-transcribe", google_transcribe::GoogleTranscribe) .add_service("openai-dialog", openai_dialog::OpenAIDialog) .add_service("aristech-transcribe", aristech::AristechTranscribe) diff --git a/src/lib.rs b/src/lib.rs index 2cdc8e7..51e3862 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,5 +15,6 @@ pub mod services { pub use aristech::AristechTranscribe; pub use azure::AzureTranscribe; pub use elevenlabs::ElevenLabsTranscribe; + pub use google_dialog::GoogleDialog; pub use google_transcribe::GoogleTranscribe; } From 95d379805781748059a046ce5057e32ebedc7d98 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 08:50:34 +0200 Subject: [PATCH 02/20] google-dialog: Review billing --- services/google-dialog/src/client.rs | 39 +++++++++++++++++++++++----- 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index 4b8c519..fa35804 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -279,16 +279,43 @@ fn bill_usage( billing_scope: &str, usage: UsageMetadata, ) -> Result<()> { - let prompt_audio = modality_count(&usage.prompt_tokens_details, "AUDIO"); - let prompt_text = modality_count(&usage.prompt_tokens_details, "TEXT"); + // TODO(gemini-live): Move input billing back to disjoint modality buckets once + // `UsageMetadata` exposes `cache_tokens_details` and + // `tool_use_prompt_tokens_details` from the Live API protocol. + // + // Why this is needed: + // - Input audio and input text may have different prices. + // - `prompt_tokens_details` already includes cached/tool prompt usage. + // - Without cached/tool modality details, splitting prompt input into + // audio/text can overlap with cached/tool categories and double count. 
+ // + // Current fallback keeps input categories disjoint, but loses input + // modality precision: + // - tokens:input + // - tokens:input:cached + // - tokens:input:tool + // + // Target mapping after gemini-live exposes the missing fields: + // - tokens:input:audio = prompt_audio - cached_audio - tool_audio + // - tokens:input:text = prompt_text - cached_text - tool_text + // - tokens:input:audio:cached = cached_audio + // - tokens:input:text:cached = cached_text + // - tokens:input:audio:tool = tool_audio + // - tokens:input:text:tool = tool_text + let prompt_total = usage.prompt_token_count as usize; + let cached_total = usage.cached_content_token_count as usize; + let tool_total = usage.tool_use_prompt_token_count as usize; let response_audio = modality_count(&usage.response_tokens_details, "AUDIO"); let response_text = modality_count(&usage.response_tokens_details, "TEXT"); + let prompt_uncached_untool = prompt_total + .checked_sub(cached_total + tool_total) + .context("Invalid Gemini usage: prompt tokens less than cached+tool prompt tokens")?; + let records = [ - BillingRecord::count("tokens:input:audio", prompt_audio), - BillingRecord::count("tokens:input:text", prompt_text), - BillingRecord::count("tokens:input:cached", usage.cached_content_token_count as _), - BillingRecord::count("tokens:input:tool", usage.tool_use_prompt_token_count as _), + BillingRecord::count("tokens:input", prompt_uncached_untool), + BillingRecord::count("tokens:input:cached", cached_total), + BillingRecord::count("tokens:input:tool", tool_total), BillingRecord::count("tokens:output:audio", response_audio), BillingRecord::count("tokens:output:text", response_text), BillingRecord::count("tokens:thoughts", usage.thoughts_token_count as _), From c4e68b76f85ec65ed61d4d093d4ce43a4ddee0ff Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 08:51:05 +0200 Subject: [PATCH 03/20] Add GEMINI_API_KEY to the example env --- .env.example | 4 ++++ 1 file changed, 4 insertions(+) 
diff --git a/.env.example b/.env.example index 64f9267..3e62a67 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,7 @@ +# Gemini + +GEMINI_API_KEY=your_gemini_api_key + # OpenAI Configuration OPENAI_API_KEY=your_openai_key OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview From c6d3c260a3d0af9b6959e385a70937d8004c86e0 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 09:08:51 +0200 Subject: [PATCH 04/20] Fix gemini billing records generation --- .gitmodules | 5 +++ Cargo.toml | 20 +++++++++-- external/gemini-live-rs | 1 + services/google-dialog/src/client.rs | 50 ++++++++++------------------ 4 files changed, 42 insertions(+), 34 deletions(-) create mode 160000 external/gemini-live-rs diff --git a/.gitmodules b/.gitmodules index aeb60dc..632287d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,3 +6,8 @@ path = external/openai-api-rs url = ../openai-api-rs.git branch = "context-switch-v9.0.0" +[submodule "external/gemini-live-rs"] + path = external/gemini-live-rs + url = ../gemini-live-rs.git + branch = context-switch + diff --git a/Cargo.toml b/Cargo.toml index 12ba84b..c1d1b44 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -22,6 +22,8 @@ members = [ [workspace.package] version = "2.3.0" edition = "2024" +license = "MIT" +repository = "https://github.com/pragmatrix/context-switch" [dependencies] @@ -101,13 +103,27 @@ aristech = { path = "services/aristech" } elevenlabs = { path = "services/elevenlabs" } google-transcribe = { path = "services/google-transcribe" } google-dialog = { path = "services/google-dialog" } -gemini-live = "0.1.8" +gemini-live = { path = "external/gemini-live-rs/crates/gemini-live" } + +# Dependencies required by `external/gemini-live-rs/crates/gemini-live`. +# The submodule crate inherits these via `workspace = true`, so we keep them +# centralized here and grouped to make future sync/review straightforward. 
+tokio-tungstenite = { version = "0.29", features = ["rustls-tls-webpki-roots"] } +futures-util = "0.3" +bytes = "1.11" +thiserror = "2.0" +rustls = { version = "0.23", features = ["ring"], default-features = false } +google-cloud-auth = { version = "1.9.0", default-features = false } anyhow = "1.0.102" derive_more = { version = "2.1.1", features = ["full"] } static_assertions = "1.1.0" async-stream = { version = "0.3.6" } -tokio = { version = "1.50.0", features = ["sync"] } +# Tokio features are intentionally explicit: +# - sync: channels/mutexes used throughout services +# - rt + macros: runtime and `tokio::select!`/task macros used by gemini-live +# - time: timeout/sleep used by gemini-live session/transport logic +tokio = { version = "1.50.0", features = ["sync", "rt", "macros", "time"] } futures = "0.3.31" serde = { version = "1.0.215", features = ["derive"] } serde_json = "1.0.149" diff --git a/external/gemini-live-rs b/external/gemini-live-rs new file mode 160000 index 0000000..4674ae2 --- /dev/null +++ b/external/gemini-live-rs @@ -0,0 +1 @@ +Subproject commit 4674ae228730ad75b24d088eee3228a4d5036ae1 diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index fa35804..fe6b5a0 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -279,43 +279,29 @@ fn bill_usage( billing_scope: &str, usage: UsageMetadata, ) -> Result<()> { - // TODO(gemini-live): Move input billing back to disjoint modality buckets once - // `UsageMetadata` exposes `cache_tokens_details` and - // `tool_use_prompt_tokens_details` from the Live API protocol. - // - // Why this is needed: - // - Input audio and input text may have different prices. - // - `prompt_tokens_details` already includes cached/tool prompt usage. - // - Without cached/tool modality details, splitting prompt input into - // audio/text can overlap with cached/tool categories and double count. 
- // - // Current fallback keeps input categories disjoint, but loses input - // modality precision: - // - tokens:input - // - tokens:input:cached - // - tokens:input:tool - // - // Target mapping after gemini-live exposes the missing fields: - // - tokens:input:audio = prompt_audio - cached_audio - tool_audio - // - tokens:input:text = prompt_text - cached_text - tool_text - // - tokens:input:audio:cached = cached_audio - // - tokens:input:text:cached = cached_text - // - tokens:input:audio:tool = tool_audio - // - tokens:input:text:tool = tool_text - let prompt_total = usage.prompt_token_count as usize; - let cached_total = usage.cached_content_token_count as usize; - let tool_total = usage.tool_use_prompt_token_count as usize; + let prompt_audio_total = modality_count(&usage.prompt_tokens_details, "AUDIO"); + let prompt_text_total = modality_count(&usage.prompt_tokens_details, "TEXT"); + let cached_audio = modality_count(&usage.cache_tokens_details, "AUDIO"); + let cached_text = modality_count(&usage.cache_tokens_details, "TEXT"); + let tool_audio = modality_count(&usage.tool_use_prompt_tokens_details, "AUDIO"); + let tool_text = modality_count(&usage.tool_use_prompt_tokens_details, "TEXT"); let response_audio = modality_count(&usage.response_tokens_details, "AUDIO"); let response_text = modality_count(&usage.response_tokens_details, "TEXT"); - let prompt_uncached_untool = prompt_total - .checked_sub(cached_total + tool_total) - .context("Invalid Gemini usage: prompt tokens less than cached+tool prompt tokens")?; + let prompt_audio = prompt_audio_total + .checked_sub(cached_audio + tool_audio) + .context("Invalid Gemini usage: prompt audio tokens less than cached+tool audio tokens")?; + let prompt_text = prompt_text_total + .checked_sub(cached_text + tool_text) + .context("Invalid Gemini usage: prompt text tokens less than cached+tool text tokens")?; let records = [ - BillingRecord::count("tokens:input", prompt_uncached_untool), - 
BillingRecord::count("tokens:input:cached", cached_total), - BillingRecord::count("tokens:input:tool", tool_total), + BillingRecord::count("tokens:input:audio", prompt_audio), + BillingRecord::count("tokens:input:text", prompt_text), + BillingRecord::count("tokens:input:audio:cached", cached_audio), + BillingRecord::count("tokens:input:text:cached", cached_text), + BillingRecord::count("tokens:input:audio:tool", tool_audio), + BillingRecord::count("tokens:input:text:tool", tool_text), BillingRecord::count("tokens:output:audio", response_audio), BillingRecord::count("tokens:output:text", response_text), BillingRecord::count("tokens:thoughts", usage.thoughts_token_count as _), From 4433ee7ca649a152e826b7049f8d6a05a39dbf67 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 09:27:39 +0200 Subject: [PATCH 05/20] Add --list-voices to the dialog example --- examples/dialog.rs | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/examples/dialog.rs b/examples/dialog.rs index e7893a4..dcd8038 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -22,6 +22,7 @@ use openai_dialog::{ }; use rodio::{DeviceSinkBuilder, Player, Source}; use serde_json::json; +use strum::VariantNames; use tokio::{ select, sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}, @@ -38,6 +39,8 @@ struct Cli { #[arg(value_enum)] provider: Provider, #[arg(long)] + list_voices: bool, + #[arg(long)] endpoint: Option, #[arg(long)] model: Option, @@ -65,6 +68,12 @@ impl Provider { #[tokio::main] async fn main() -> Result<()> { let cli = Cli::parse(); + + if cli.list_voices { + list_available_voices(cli.provider)?; + return Ok(()); + } + dotenvy::dotenv_override().context("Reading .env file")?; tracing_subscriber::fmt::init(); @@ -139,6 +148,21 @@ async fn main() -> Result<()> { Ok(()) } +fn list_available_voices(provider: Provider) -> Result<()> { + match provider { + Provider::OpenAI | Provider::Azure => { + println!("Available voices for {:?}:", 
provider); + for voice in ::VARIANTS { + println!("- {voice}"); + } + Ok(()) + } + Provider::Google => { + bail!("Voice listing is only available for openai and azure providers") + } + } +} + async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> { match cli.provider { Provider::OpenAI | Provider::Azure => { From f468b025204b682a6d86a871147e723b9de3fd5d Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 09:54:40 +0200 Subject: [PATCH 06/20] dialog: Improve support for --list-voices and --list-models --- Cargo.toml | 1 + examples/dialog.rs | 264 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 263 insertions(+), 2 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index c1d1b44..a714fe2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -81,6 +81,7 @@ google-dialog = { workspace = true } gemini-live = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread"] } +reqwest = { workspace = true } # For advanced params in openai-dialog openai-api-rs = { workspace = true } diff --git a/examples/dialog.rs b/examples/dialog.rs index dcd8038..a85c7e3 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -20,7 +20,9 @@ use openai_dialog::{ OpenAIDialog, Protocol, ServiceInputEvent as OpenAIServiceInputEvent, ServiceOutputEvent as OpenAIServiceOutputEvent, }; +use reqwest::Url; use rodio::{DeviceSinkBuilder, Player, Source}; +use serde::Deserialize; use serde_json::json; use strum::VariantNames; use tokio::{ @@ -39,6 +41,8 @@ struct Cli { #[arg(value_enum)] provider: Provider, #[arg(long)] + list_models: bool, + #[arg(long)] list_voices: bool, #[arg(long)] endpoint: Option, @@ -69,6 +73,12 @@ impl Provider { async fn main() -> Result<()> { let cli = Cli::parse(); + if cli.list_models { + let _ = dotenvy::dotenv_override(); + list_available_models(&cli).await?; + return Ok(()); + } + if cli.list_voices { list_available_voices(cli.provider)?; return Ok(()); @@ -158,11 +168,240 @@ fn 
list_available_voices(provider: Provider) -> Result<()> { Ok(()) } Provider::Google => { - bail!("Voice listing is only available for openai and azure providers") + println!("Available voices for {:?}:", provider); + for voice in GEMINI_VOICES { + println!("- {voice}"); + } + Ok(()) } } } +async fn list_available_models(cli: &Cli) -> Result<()> { + match cli.provider { + Provider::OpenAI => list_openai_models(cli).await, + Provider::Azure => list_azure_models(cli), + Provider::Google => list_google_models(cli).await, + } +} + +async fn list_openai_models(cli: &Cli) -> Result<()> { + let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; + let endpoint = cli + .endpoint + .clone() + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()); + let models_url = openai_models_url(endpoint.as_deref())?; + + let response: OpenAIModelsResponse = reqwest::Client::new() + .get(models_url) + .bearer_auth(key) + .send() + .await + .context("Requesting OpenAI models")? + .error_for_status() + .context("OpenAI models request failed")? 
+ .json() + .await + .context("Decoding OpenAI models response")?; + + let mut models: Vec<_> = response + .data + .into_iter() + .map(|m| m.id) + .filter(|id| is_openai_realtime_model(id)) + .collect(); + models.sort(); + + println!("Available models for OpenAI:"); + if models.is_empty() { + println!("- No realtime-capable models were returned by the models endpoint."); + println!("- Ensure your API key has access to OpenAI Realtime API models."); + } else { + for model in models { + println!("- {model}"); + } + } + Ok(()) +} + +fn is_openai_realtime_model(model_id: &str) -> bool { + model_id.to_ascii_lowercase().contains("realtime") +} + +fn list_azure_models(cli: &Cli) -> Result<()> { + println!("Available models for Azure:"); + println!("- Azure Realtime uses deployment names configured in your Azure OpenAI resource."); + println!("- The realtime endpoint does not provide a provider-agnostic model listing API here."); + + if let Some(model) = cli + .model + .clone() + .or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) + .filter(|m| !m.trim().is_empty()) + { + println!("- Configured deployment/model: {model}"); + } else { + println!("- Set --model or OPENAI_REALTIME_API_MODEL to your Azure deployment name."); + } + + Ok(()) +} + +async fn list_google_models(cli: &Cli) -> Result<()> { + let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; + let endpoint = cli + .endpoint + .clone() + .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()); + let models_url = google_models_url(endpoint.as_deref())?; + + let response: GeminiModelsResponse = reqwest::Client::new() + .get(models_url) + .query(&[("key", key)]) + .send() + .await + .context("Requesting Gemini models")? + .error_for_status() + .context("Gemini models request failed")? 
+ .json() + .await + .context("Decoding Gemini models response")?; + + let mut live_models: Vec<_> = response + .models + .into_iter() + .filter(|m| is_gemini_live_model(&m.name, &m.supported_generation_methods)) + .map(|m| m.name) + .collect(); + live_models.sort(); + + println!("Available models for Google (Live API capable):"); + if live_models.is_empty() { + println!("- No Live-capable models were detected from models.list."); + println!("- This can happen when model metadata does not include Live-specific methods."); + println!("- Try explicitly using a known Live model, for example:"); + println!(" - models/gemini-3.1-flash-live-preview"); + println!(" - models/gemini-2.5-flash-live-preview"); + } else { + for model in live_models { + println!("- {model}"); + } + } + Ok(()) +} + +fn is_gemini_live_model(model_name: &str, methods: &[String]) -> bool { + if model_name.to_ascii_lowercase().contains("live") { + return true; + } + + methods.iter().any(|method| { + method.eq_ignore_ascii_case("bidiGenerateContent") + || method.eq_ignore_ascii_case("streamGenerateContent") + }) +} + +fn openai_models_url(endpoint: Option<&str>) -> Result { + const OPENAI_MODELS_ENDPOINT: &str = "https://api.openai.com/v1/models"; + + let Some(endpoint) = endpoint else { + return Ok(OPENAI_MODELS_ENDPOINT.to_owned()); + }; + + let mut url = Url::parse(endpoint) + .or_else(|_| Url::parse(OPENAI_MODELS_ENDPOINT)) + .context("Parsing OpenAI model list URL")?; + + let normalized_scheme = match url.scheme() { + "wss" => "https".to_owned(), + "ws" => "http".to_owned(), + other => other.to_owned(), + }; + url.set_scheme(&normalized_scheme).ok(); + url.set_path("/v1/models"); + url.set_query(None); + Ok(url.to_string()) +} + +fn google_models_url(endpoint: Option<&str>) -> Result { + const GOOGLE_MODELS_ENDPOINT: &str = "https://generativelanguage.googleapis.com/v1beta/models"; + + let Some(endpoint) = endpoint else { + return Ok(GOOGLE_MODELS_ENDPOINT.to_owned()); + }; + + let mut url = 
Url::parse(endpoint) + .or_else(|_| Url::parse(GOOGLE_MODELS_ENDPOINT)) + .context("Parsing Gemini model list URL")?; + + let normalized_scheme = match url.scheme() { + "wss" => "https".to_owned(), + "ws" => "http".to_owned(), + other => other.to_owned(), + }; + url.set_scheme(&normalized_scheme).ok(); + url.set_path("/v1beta/models"); + url.set_query(None); + Ok(url.to_string()) +} + +const GEMINI_VOICES: &[&str] = &[ + "Zephyr", + "Puck", + "Charon", + "Kore", + "Fenrir", + "Leda", + "Orus", + "Aoede", + "Callirrhoe", + "Autonoe", + "Enceladus", + "Iapetus", + "Umbriel", + "Algieba", + "Despina", + "Erinome", + "Algenib", + "Rasalgethi", + "Laomedeia", + "Achernar", + "Alnilam", + "Schedar", + "Gacrux", + "Pulcherrima", + "Achird", + "Zubenelgenubi", + "Vindemiatrix", + "Sadachbia", + "Sadaltager", + "Sulafat", +]; + +#[derive(Debug, Deserialize)] +struct OpenAIModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAIModel { + id: String, +} + +#[derive(Debug, Deserialize)] +struct GeminiModelsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiModel { + name: String, + #[serde(default)] + supported_generation_methods: Vec, +} + async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> { match cli.provider { Provider::OpenAI | Provider::Azure => { @@ -209,7 +448,11 @@ async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> .clone() .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()) .filter(|endpoint| !endpoint.trim().is_empty()); - params.voice = cli.voice.clone(); + params.voice = cli + .voice + .as_deref() + .map(parse_gemini_voice_value) + .transpose()?; params.output_audio_transcription = true; params.tools.push(gemini_get_time_tool()); @@ -223,6 +466,23 @@ fn parse_realtime_voice_value(value: &str) -> Result Result { + if GEMINI_VOICES.iter().any(|v| v.eq_ignore_ascii_case(value)) { + let voice = GEMINI_VOICES + .iter() + .find(|v| 
v.eq_ignore_ascii_case(value)) + .copied() + .unwrap_or(value) + .to_owned(); + Ok(voice) + } else { + let available = GEMINI_VOICES.join(", "); + bail!( + "Invalid Gemini voice `{value}`. Available voices: {available}" + ) + } +} + enum AudioCommand { PlayFrame(AudioFrame), Clear, From 71ca08d2be8d1710076c001b3b6736f66f1731b7 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 09:59:23 +0200 Subject: [PATCH 07/20] Rename provider azure to azure-openai and use appropriate keys --- examples/dialog.rs | 53 ++++++++++++++++++++++++++++++++++------------ 1 file changed, 39 insertions(+), 14 deletions(-) diff --git a/examples/dialog.rs b/examples/dialog.rs index a85c7e3..1fc6eed 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -56,14 +56,15 @@ struct Cli { enum Provider { #[value(name = "openai")] OpenAI, - Azure, + #[value(name = "azure-openai")] + AzureOpenAI, Google, } impl Provider { fn output_format(self, input_format: AudioFormat) -> AudioFormat { match self { - Provider::OpenAI | Provider::Azure => input_format, + Provider::OpenAI | Provider::AzureOpenAI => input_format, Provider::Google => AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE), } } @@ -160,7 +161,7 @@ async fn main() -> Result<()> { fn list_available_voices(provider: Provider) -> Result<()> { match provider { - Provider::OpenAI | Provider::Azure => { + Provider::OpenAI | Provider::AzureOpenAI => { println!("Available voices for {:?}:", provider); for voice in ::VARIANTS { println!("- {voice}"); @@ -180,7 +181,7 @@ fn list_available_voices(provider: Provider) -> Result<()> { async fn list_available_models(cli: &Cli) -> Result<()> { match cli.provider { Provider::OpenAI => list_openai_models(cli).await, - Provider::Azure => list_azure_models(cli), + Provider::AzureOpenAI => list_azure_models(cli), Provider::Google => list_google_models(cli).await, } } @@ -237,12 +238,14 @@ fn list_azure_models(cli: &Cli) -> Result<()> { if let Some(model) = cli .model .clone() - 
.or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) + .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) .filter(|m| !m.trim().is_empty()) { println!("- Configured deployment/model: {model}"); } else { - println!("- Set --model or OPENAI_REALTIME_API_MODEL to your Azure deployment name."); + println!( + "- Set --model or AZURE_OPENAI_REALTIME_API_MODEL to your Azure deployment name." + ); } Ok(()) @@ -404,7 +407,7 @@ struct GeminiModel { async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> { match cli.provider { - Provider::OpenAI | Provider::Azure => { + Provider::OpenAI => { let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; let model = cli .model @@ -419,11 +422,33 @@ async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> .clone() .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) .filter(|endpoint| !endpoint.trim().is_empty()); - params.protocol = Some(match cli.provider { - Provider::OpenAI => Protocol::OpenAI, - Provider::Azure => Protocol::Azure, - Provider::Google => unreachable!(), - }); + params.protocol = Some(Protocol::OpenAI); + params.voice = cli + .voice + .as_deref() + .map(parse_realtime_voice_value) + .transpose()?; + params.tools.push(openai_get_time_function_definition()); + + OpenAIDialog.conversation(params, conversation).await + } + Provider::AzureOpenAI => { + let key = env::var("AZURE_OPENAI_API_KEY") + .context("AZURE_OPENAI_API_KEY undefined")?; + let model = cli + .model + .clone() + .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .context("Provide --model or set AZURE_OPENAI_REALTIME_API_MODEL")?; + + let mut params = openai_dialog::Params::new(key, model); + params.host = cli + .endpoint + .clone() + .or_else(|| env::var("AZURE_OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = Some(Protocol::Azure); params.voice = cli .voice 
.as_deref() @@ -561,7 +586,7 @@ fn handle_service_event( value: serde_json::Value, ) -> Result<()> { let call = match provider { - Provider::OpenAI | Provider::Azure => match serde_json::from_value(value)? { + Provider::OpenAI | Provider::AzureOpenAI => match serde_json::from_value(value)? { OpenAIServiceOutputEvent::FunctionCall { name, call_id, @@ -619,7 +644,7 @@ fn send_function_result( ) -> Result<()> { let output = json!({ "time": serde_json::Value::String(result) }); let value = match provider { - Provider::OpenAI | Provider::Azure => { + Provider::OpenAI | Provider::AzureOpenAI => { serde_json::to_value(&OpenAIServiceInputEvent::FunctionCallResult { call_id, output })? } Provider::Google => serde_json::to_value(&GoogleServiceInputEvent::FunctionCallResult { From 3c90598e94e5af299f7755393aa60fde2ce17624 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 10:37:29 +0200 Subject: [PATCH 08/20] dialog example: Separate provider specific code into modules --- .github/copilot-instructions.md | 4 + examples/dialog.rs | 449 ++-------------------- examples/dialog_providers/azure_openai.rs | 90 +++++ examples/dialog_providers/google.rs | 241 ++++++++++++ examples/dialog_providers/mod.rs | 49 +++ examples/dialog_providers/openai.rs | 176 +++++++++ 6 files changed, 587 insertions(+), 422 deletions(-) create mode 100644 examples/dialog_providers/azure_openai.rs create mode 100644 examples/dialog_providers/google.rs create mode 100644 examples/dialog_providers/mod.rs create mode 100644 examples/dialog_providers/openai.rs diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index f3a2f32..f6fa600 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -3,9 +3,13 @@ ## Rust Style - Prefer imports over deeply-qualified module paths. 
As a rule of thumb, avoid using more than one module prefix inline (for example, prefer importing a type and using `TypeName` instead of writing `foo::bar::TypeName` repeatedly). - Prefer high-level flow first: when practical, place local supporting definitions (for example helper structs, impls, functions, and type aliases) below their first use. +- In module files, keep definitions ordered top-down by call flow (entry points first, helpers after first use). - Keep imports grouped and sorted to match existing file style. - Avoid `maybe_` prefixes for optional variables; use neutral names and rely on type/context for optionality. - Avoid `_ref` suffixes for local variable names; use descriptive neutral names instead. +- Prefer explicit imports over repeated relative module qualification. +- Prefer private-by-default visibility; only widen visibility when required by a module boundary. +- For trait-based APIs, prefer focused request/context types over passing broad configuration structs. ## Change Communication - Include a short rationale for each non-trivial code change. diff --git a/examples/dialog.rs b/examples/dialog.rs index 1fc6eed..1e6e802 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -1,7 +1,6 @@ //! A context switch demo. Runs locally, gets voice data from your current microphone. 
use std::{ - env, num::{NonZeroU16, NonZeroU32}, str::FromStr, thread, @@ -13,18 +12,8 @@ use chrono::Utc; use clap::{Parser, ValueEnum}; use context_switch::{InputModality, OutputModality}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; -use gemini_live::types as gemini_types; -use google_dialog::{GoogleDialog, ServiceInputEvent as GoogleServiceInputEvent}; -use openai_api_rs::realtime::types as openai_types; -use openai_dialog::{ - OpenAIDialog, Protocol, ServiceInputEvent as OpenAIServiceInputEvent, - ServiceOutputEvent as OpenAIServiceOutputEvent, -}; -use reqwest::Url; use rodio::{DeviceSinkBuilder, Player, Source}; -use serde::Deserialize; use serde_json::json; -use strum::VariantNames; use tokio::{ select, sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}, @@ -32,10 +21,12 @@ use tokio::{ use tracing::info; use context_switch_core::{ - AudioFormat, AudioFrame, Service, audio, + AudioFormat, AudioFrame, audio, conversation::{Conversation, Input, Output}, }; +mod dialog_providers; + #[derive(Debug, Parser)] struct Cli { #[arg(value_enum)] @@ -62,11 +53,8 @@ enum Provider { } impl Provider { - fn output_format(self, input_format: AudioFormat) -> AudioFormat { - match self { - Provider::OpenAI | Provider::AzureOpenAI => input_format, - Provider::Google => AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE), - } + fn api(self) -> &'static dyn dialog_providers::ProviderApi { + dialog_providers::provider_api(self) } } @@ -101,7 +89,7 @@ async fn main() -> Result<()> { let channels = input_config.channels(); let sample_rate = input_config.sample_rate(); let input_format = AudioFormat::new(channels, sample_rate); - let output_format = cli.provider.output_format(input_format); + let output_format = cli.provider.api().output_format(input_format); let (input_sender, input_receiver) = channel(256); let input_sender2 = input_sender.clone(); @@ -160,352 +148,31 @@ async fn main() -> Result<()> { } fn list_available_voices(provider: 
Provider) -> Result<()> { - match provider { - Provider::OpenAI | Provider::AzureOpenAI => { - println!("Available voices for {:?}:", provider); - for voice in ::VARIANTS { - println!("- {voice}"); - } - Ok(()) - } - Provider::Google => { - println!("Available voices for {:?}:", provider); - for voice in GEMINI_VOICES { - println!("- {voice}"); - } - Ok(()) - } - } -} - -async fn list_available_models(cli: &Cli) -> Result<()> { - match cli.provider { - Provider::OpenAI => list_openai_models(cli).await, - Provider::AzureOpenAI => list_azure_models(cli), - Provider::Google => list_google_models(cli).await, - } -} - -async fn list_openai_models(cli: &Cli) -> Result<()> { - let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; - let endpoint = cli - .endpoint - .clone() - .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()); - let models_url = openai_models_url(endpoint.as_deref())?; - - let response: OpenAIModelsResponse = reqwest::Client::new() - .get(models_url) - .bearer_auth(key) - .send() - .await - .context("Requesting OpenAI models")? - .error_for_status() - .context("OpenAI models request failed")? 
- .json() - .await - .context("Decoding OpenAI models response")?; - - let mut models: Vec<_> = response - .data - .into_iter() - .map(|m| m.id) - .filter(|id| is_openai_realtime_model(id)) - .collect(); - models.sort(); - - println!("Available models for OpenAI:"); - if models.is_empty() { - println!("- No realtime-capable models were returned by the models endpoint."); - println!("- Ensure your API key has access to OpenAI Realtime API models."); - } else { - for model in models { - println!("- {model}"); - } - } - Ok(()) -} - -fn is_openai_realtime_model(model_id: &str) -> bool { - model_id.to_ascii_lowercase().contains("realtime") -} - -fn list_azure_models(cli: &Cli) -> Result<()> { - println!("Available models for Azure:"); - println!("- Azure Realtime uses deployment names configured in your Azure OpenAI resource."); - println!("- The realtime endpoint does not provide a provider-agnostic model listing API here."); - - if let Some(model) = cli - .model - .clone() - .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) - .filter(|m| !m.trim().is_empty()) - { - println!("- Configured deployment/model: {model}"); - } else { - println!( - "- Set --model or AZURE_OPENAI_REALTIME_API_MODEL to your Azure deployment name." - ); - } - - Ok(()) -} - -async fn list_google_models(cli: &Cli) -> Result<()> { - let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; - let endpoint = cli - .endpoint - .clone() - .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()); - let models_url = google_models_url(endpoint.as_deref())?; - - let response: GeminiModelsResponse = reqwest::Client::new() - .get(models_url) - .query(&[("key", key)]) - .send() - .await - .context("Requesting Gemini models")? - .error_for_status() - .context("Gemini models request failed")? 
- .json() - .await - .context("Decoding Gemini models response")?; - - let mut live_models: Vec<_> = response - .models - .into_iter() - .filter(|m| is_gemini_live_model(&m.name, &m.supported_generation_methods)) - .map(|m| m.name) - .collect(); - live_models.sort(); - - println!("Available models for Google (Live API capable):"); - if live_models.is_empty() { - println!("- No Live-capable models were detected from models.list."); - println!("- This can happen when model metadata does not include Live-specific methods."); - println!("- Try explicitly using a known Live model, for example:"); - println!(" - models/gemini-3.1-flash-live-preview"); - println!(" - models/gemini-2.5-flash-live-preview"); - } else { - for model in live_models { - println!("- {model}"); - } + println!("Available voices for {:?}:", provider); + for voice in provider.api().voices() { + println!("- {voice}"); } Ok(()) } -fn is_gemini_live_model(model_name: &str, methods: &[String]) -> bool { - if model_name.to_ascii_lowercase().contains("live") { - return true; - } - - methods.iter().any(|method| { - method.eq_ignore_ascii_case("bidiGenerateContent") - || method.eq_ignore_ascii_case("streamGenerateContent") - }) -} - -fn openai_models_url(endpoint: Option<&str>) -> Result { - const OPENAI_MODELS_ENDPOINT: &str = "https://api.openai.com/v1/models"; - - let Some(endpoint) = endpoint else { - return Ok(OPENAI_MODELS_ENDPOINT.to_owned()); - }; - - let mut url = Url::parse(endpoint) - .or_else(|_| Url::parse(OPENAI_MODELS_ENDPOINT)) - .context("Parsing OpenAI model list URL")?; - - let normalized_scheme = match url.scheme() { - "wss" => "https".to_owned(), - "ws" => "http".to_owned(), - other => other.to_owned(), - }; - url.set_scheme(&normalized_scheme).ok(); - url.set_path("/v1/models"); - url.set_query(None); - Ok(url.to_string()) -} - -fn google_models_url(endpoint: Option<&str>) -> Result { - const GOOGLE_MODELS_ENDPOINT: &str = "https://generativelanguage.googleapis.com/v1beta/models"; - - 
let Some(endpoint) = endpoint else { - return Ok(GOOGLE_MODELS_ENDPOINT.to_owned()); - }; - - let mut url = Url::parse(endpoint) - .or_else(|_| Url::parse(GOOGLE_MODELS_ENDPOINT)) - .context("Parsing Gemini model list URL")?; - - let normalized_scheme = match url.scheme() { - "wss" => "https".to_owned(), - "ws" => "http".to_owned(), - other => other.to_owned(), +async fn list_available_models(cli: &Cli) -> Result<()> { + let request = dialog_providers::ListModelsRequest { + endpoint: cli.endpoint.clone(), + model: cli.model.clone(), }; - url.set_scheme(&normalized_scheme).ok(); - url.set_path("/v1beta/models"); - url.set_query(None); - Ok(url.to_string()) -} - -const GEMINI_VOICES: &[&str] = &[ - "Zephyr", - "Puck", - "Charon", - "Kore", - "Fenrir", - "Leda", - "Orus", - "Aoede", - "Callirrhoe", - "Autonoe", - "Enceladus", - "Iapetus", - "Umbriel", - "Algieba", - "Despina", - "Erinome", - "Algenib", - "Rasalgethi", - "Laomedeia", - "Achernar", - "Alnilam", - "Schedar", - "Gacrux", - "Pulcherrima", - "Achird", - "Zubenelgenubi", - "Vindemiatrix", - "Sadachbia", - "Sadaltager", - "Sulafat", -]; - -#[derive(Debug, Deserialize)] -struct OpenAIModelsResponse { - data: Vec, -} - -#[derive(Debug, Deserialize)] -struct OpenAIModel { - id: String, -} - -#[derive(Debug, Deserialize)] -struct GeminiModelsResponse { - #[serde(default)] - models: Vec, -} - -#[derive(Debug, Deserialize)] -struct GeminiModel { - name: String, - #[serde(default)] - supported_generation_methods: Vec, + cli.provider.api().list_models(request).await } async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> { - match cli.provider { - Provider::OpenAI => { - let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; - let model = cli - .model - .clone() - .or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) - .filter(|model| !model.trim().is_empty()) - .context("Provide --model or set OPENAI_REALTIME_API_MODEL")?; - - let mut params = 
openai_dialog::Params::new(key, model); - params.host = cli - .endpoint - .clone() - .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) - .filter(|endpoint| !endpoint.trim().is_empty()); - params.protocol = Some(Protocol::OpenAI); - params.voice = cli - .voice - .as_deref() - .map(parse_realtime_voice_value) - .transpose()?; - params.tools.push(openai_get_time_function_definition()); - - OpenAIDialog.conversation(params, conversation).await - } - Provider::AzureOpenAI => { - let key = env::var("AZURE_OPENAI_API_KEY") - .context("AZURE_OPENAI_API_KEY undefined")?; - let model = cli - .model - .clone() - .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) - .filter(|model| !model.trim().is_empty()) - .context("Provide --model or set AZURE_OPENAI_REALTIME_API_MODEL")?; - - let mut params = openai_dialog::Params::new(key, model); - params.host = cli - .endpoint - .clone() - .or_else(|| env::var("AZURE_OPENAI_REALTIME_ENDPOINT").ok()) - .filter(|endpoint| !endpoint.trim().is_empty()); - params.protocol = Some(Protocol::Azure); - params.voice = cli - .voice - .as_deref() - .map(parse_realtime_voice_value) - .transpose()?; - params.tools.push(openai_get_time_function_definition()); - - OpenAIDialog.conversation(params, conversation).await - } - Provider::Google => { - let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; - let model = cli - .model - .clone() - .or_else(|| env::var("GEMINI_LIVE_API_MODEL").ok()) - .filter(|model| !model.trim().is_empty()) - .unwrap_or_else(|| "gemini-3.1-flash-live-preview".to_owned()); - - let mut params = google_dialog::Params::new(key, model); - params.host = cli - .endpoint - .clone() - .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()) - .filter(|endpoint| !endpoint.trim().is_empty()); - params.voice = cli - .voice - .as_deref() - .map(parse_gemini_voice_value) - .transpose()?; - params.output_audio_transcription = true; - params.tools.push(gemini_get_time_tool()); - - 
GoogleDialog.conversation(params, conversation).await - } - } -} - -fn parse_realtime_voice_value(value: &str) -> Result { - openai_types::RealtimeVoice::from_str(value) - .map_err(|e| anyhow::anyhow!("Invalid voice value `{value}`: {e}")) -} - -fn parse_gemini_voice_value(value: &str) -> Result { - if GEMINI_VOICES.iter().any(|v| v.eq_ignore_ascii_case(value)) { - let voice = GEMINI_VOICES - .iter() - .find(|v| v.eq_ignore_ascii_case(value)) - .copied() - .unwrap_or(value) - .to_owned(); - Ok(voice) - } else { - let available = GEMINI_VOICES.join(", "); - bail!( - "Invalid Gemini voice `{value}`. Available voices: {available}" - ) - } + let request = dialog_providers::StartConversationRequest { + endpoint: cli.endpoint.clone(), + model: cli.model.clone(), + voice: cli.voice.clone(), + }; + cli.provider + .api() + .start_conversation(request, conversation) + .await } enum AudioCommand { @@ -585,42 +252,7 @@ fn handle_service_event( input: &Sender, value: serde_json::Value, ) -> Result<()> { - let call = match provider { - Provider::OpenAI | Provider::AzureOpenAI => match serde_json::from_value(value)? { - OpenAIServiceOutputEvent::FunctionCall { - name, - call_id, - arguments, - } => Some(FunctionCall { - name, - call_id, - arguments, - }), - OpenAIServiceOutputEvent::SessionUpdated { tools } => { - info!("Session updated: {tools:?}"); - None - } - }, - Provider::Google => match serde_json::from_value(value)? 
{ - google_dialog::ServiceOutputEvent::FunctionCall { - name, - call_id, - arguments, - } => Some(FunctionCall { - name, - call_id, - arguments: Some(arguments), - }), - google_dialog::ServiceOutputEvent::ToolCallCancellation { call_ids } => { - info!("Tool calls cancelled: {call_ids:?}"); - None - } - google_dialog::ServiceOutputEvent::SessionUpdated { tools } => { - info!("Session updated: {tools:?}"); - None - } - }, - }; + let call = provider.api().parse_service_event(value)?; if let Some(call) = call { info!( @@ -642,17 +274,9 @@ fn send_function_result( name: String, result: String, ) -> Result<()> { - let output = json!({ "time": serde_json::Value::String(result) }); - let value = match provider { - Provider::OpenAI | Provider::AzureOpenAI => { - serde_json::to_value(&OpenAIServiceInputEvent::FunctionCallResult { call_id, output })? - } - Provider::Google => serde_json::to_value(&GoogleServiceInputEvent::FunctionCallResult { - call_id, - name, - output, - })?, - }; + let value = provider + .api() + .function_result_event(call_id, Some(name), result)?; input.try_send(Input::ServiceEvent { value })?; Ok(()) } @@ -663,25 +287,6 @@ struct FunctionCall { name: String, arguments: Option, } - -fn openai_get_time_function_definition() -> openai_types::ToolDefinition { - openai_types::ToolDefinition::Function { - name: "get_time".into(), - description: "The current time to the exact second.".into(), - parameters: get_time_parameters_schema(), - } -} - -fn gemini_get_time_tool() -> gemini_types::Tool { - gemini_types::Tool::FunctionDeclarations(vec![gemini_types::FunctionDeclaration { - name: "get_time".into(), - description: "The current time to the exact second.".into(), - parameters: get_time_parameters_schema(), - scheduling: None, - behavior: None, - }]) -} - fn get_time_parameters_schema() -> serde_json::Value { json!({ "type": "object", diff --git a/examples/dialog_providers/azure_openai.rs b/examples/dialog_providers/azure_openai.rs new file mode 100644 index 
0000000..fcde448 --- /dev/null +++ b/examples/dialog_providers/azure_openai.rs @@ -0,0 +1,90 @@ +use std::env; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use openai_api_rs::realtime::types as openai_types; +use openai_dialog::{OpenAIDialog, Protocol}; +use strum::VariantNames; + +use context_switch_core::{AudioFormat, Service, conversation::Conversation}; + +use super::openai::{self, OpenAIProvider}; +use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; + +pub struct AzureOpenAIProvider; + +#[async_trait(?Send)] +impl ProviderApi for AzureOpenAIProvider { + fn output_format(&self, input_format: AudioFormat) -> AudioFormat { + input_format + } + + fn voices(&self) -> &'static [&'static str] { + ::VARIANTS + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + println!("Available models for Azure:"); + println!( + "- Azure Realtime uses deployment names configured in your Azure OpenAI resource." + ); + println!( + "- The realtime endpoint does not provide a provider-agnostic model listing API here." + ); + + if let Some(model) = request + .model + .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + { + println!("- Configured deployment/model: {model}"); + } else { + println!( + "- Set --model or AZURE_OPENAI_REALTIME_API_MODEL to your Azure deployment name." 
+ ); + } + + Ok(()) + } + + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()> { + let key = env::var("AZURE_OPENAI_API_KEY").context("AZURE_OPENAI_API_KEY undefined")?; + let model = request + .model + .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .context("Provide --model or set AZURE_OPENAI_REALTIME_API_MODEL")?; + + let mut params = openai_dialog::Params::new(key, model); + params.host = request + .endpoint + .or_else(|| env::var("AZURE_OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = Some(Protocol::Azure); + params.voice = request + .voice + .as_deref() + .map(openai::parse_realtime_voice_value) + .transpose()?; + params.tools.push(openai::get_time_function_definition()); + + OpenAIDialog.conversation(params, conversation).await + } + + fn parse_service_event(&self, value: serde_json::Value) -> Result> { + OpenAIProvider.parse_service_event(value) + } + + fn function_result_event( + &self, + call_id: String, + _name: Option, + result: String, + ) -> Result { + OpenAIProvider.function_result_event(call_id, None, result) + } +} diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs new file mode 100644 index 0000000..227e0e6 --- /dev/null +++ b/examples/dialog_providers/google.rs @@ -0,0 +1,241 @@ +use std::env; + +use anyhow::{Context, Result, bail}; +use async_trait::async_trait; +use gemini_live::types as gemini_types; +use google_dialog::{GoogleDialog, ServiceInputEvent as GoogleServiceInputEvent}; +use reqwest::Url; +use serde::Deserialize; +use serde_json::json; + +use crate::{FunctionCall, get_time_parameters_schema}; +use context_switch_core::{AudioFormat, Service, conversation::Conversation}; + +use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; + +pub struct GoogleProvider; + +#[async_trait(?Send)] +impl ProviderApi 
for GoogleProvider { + fn output_format(&self, _input_format: AudioFormat) -> AudioFormat { + AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE) + } + + fn voices(&self) -> &'static [&'static str] { + VOICES + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; + let endpoint = request + .endpoint + .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()); + let models_url = models_url(endpoint.as_deref())?; + + let response: GeminiModelsResponse = reqwest::Client::new() + .get(models_url) + .query(&[("key", key)]) + .send() + .await + .context("Requesting Gemini models")? + .error_for_status() + .context("Gemini models request failed")? + .json() + .await + .context("Decoding Gemini models response")?; + + let mut live_models: Vec<_> = response + .models + .into_iter() + .filter(|model| is_live_model(&model.name, &model.supported_generation_methods)) + .map(|model| model.name) + .collect(); + live_models.sort(); + + println!("Available models for Google (Live API capable):"); + if live_models.is_empty() { + println!("- No Live-capable models were detected from models.list."); + println!( + "- This can happen when model metadata does not include Live-specific methods." 
+ ); + println!("- Try explicitly using a known Live model, for example:"); + println!(" - models/gemini-3.1-flash-live-preview"); + println!(" - models/gemini-2.5-flash-live-preview"); + } else { + for model in live_models { + println!("- {model}"); + } + } + Ok(()) + } + + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()> { + let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; + let model = request + .model + .or_else(|| env::var("GEMINI_LIVE_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .unwrap_or_else(|| "gemini-3.1-flash-live-preview".to_owned()); + + let mut params = google_dialog::Params::new(key, model); + params.host = request + .endpoint + .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.voice = request + .voice + .as_deref() + .map(parse_voice_value) + .transpose()?; + params.output_audio_transcription = true; + params.tools.push(get_time_tool()); + + GoogleDialog.conversation(params, conversation).await + } + + fn parse_service_event(&self, value: serde_json::Value) -> Result> { + match serde_json::from_value(value)? 
{ + google_dialog::ServiceOutputEvent::FunctionCall { + name, + call_id, + arguments, + } => Ok(Some(FunctionCall { + name, + call_id, + arguments: Some(arguments), + })), + google_dialog::ServiceOutputEvent::ToolCallCancellation { call_ids } => { + tracing::info!("Tool calls cancelled: {call_ids:?}"); + Ok(None) + } + google_dialog::ServiceOutputEvent::SessionUpdated { tools } => { + tracing::info!("Session updated: {tools:?}"); + Ok(None) + } + } + } + + fn function_result_event( + &self, + call_id: String, + name: Option, + result: String, + ) -> Result { + let name = name.context("Function name is required for Google function result")?; + let output = json!({ "time": serde_json::Value::String(result) }); + serde_json::to_value(&GoogleServiceInputEvent::FunctionCallResult { + call_id, + name, + output, + }) + .map_err(Into::into) + } +} + +fn models_url(endpoint: Option<&str>) -> Result { + const GOOGLE_MODELS_ENDPOINT: &str = "https://generativelanguage.googleapis.com/v1beta/models"; + + let Some(endpoint) = endpoint else { + return Ok(GOOGLE_MODELS_ENDPOINT.to_owned()); + }; + + let mut url = Url::parse(endpoint) + .or_else(|_| Url::parse(GOOGLE_MODELS_ENDPOINT)) + .context("Parsing Gemini model list URL")?; + + let normalized_scheme = match url.scheme() { + "wss" => "https".to_owned(), + "ws" => "http".to_owned(), + other => other.to_owned(), + }; + url.set_scheme(&normalized_scheme).ok(); + url.set_path("/v1beta/models"); + url.set_query(None); + Ok(url.to_string()) +} + +fn is_live_model(model_name: &str, methods: &[String]) -> bool { + if model_name.to_ascii_lowercase().contains("live") { + return true; + } + + methods.iter().any(|method| { + method.eq_ignore_ascii_case("bidiGenerateContent") + || method.eq_ignore_ascii_case("streamGenerateContent") + }) +} + +const VOICES: &[&str] = &[ + "Zephyr", + "Puck", + "Charon", + "Kore", + "Fenrir", + "Leda", + "Orus", + "Aoede", + "Callirrhoe", + "Autonoe", + "Enceladus", + "Iapetus", + "Umbriel", + "Algieba", + 
"Despina", + "Erinome", + "Algenib", + "Rasalgethi", + "Laomedeia", + "Achernar", + "Alnilam", + "Schedar", + "Gacrux", + "Pulcherrima", + "Achird", + "Zubenelgenubi", + "Vindemiatrix", + "Sadachbia", + "Sadaltager", + "Sulafat", +]; + +fn parse_voice_value(value: &str) -> Result { + if VOICES.iter().any(|voice| voice.eq_ignore_ascii_case(value)) { + let voice = VOICES + .iter() + .find(|voice| voice.eq_ignore_ascii_case(value)) + .copied() + .unwrap_or(value) + .to_owned(); + Ok(voice) + } else { + let available = VOICES.join(", "); + bail!("Invalid Gemini voice `{value}`. Available voices: {available}") + } +} + +fn get_time_tool() -> gemini_types::Tool { + gemini_types::Tool::FunctionDeclarations(vec![gemini_types::FunctionDeclaration { + name: "get_time".into(), + description: "The current time to the exact second.".into(), + parameters: get_time_parameters_schema(), + scheduling: None, + behavior: None, + }]) +} + +#[derive(Debug, Deserialize)] +struct GeminiModelsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiModel { + name: String, + #[serde(default)] + supported_generation_methods: Vec, +} diff --git a/examples/dialog_providers/mod.rs b/examples/dialog_providers/mod.rs new file mode 100644 index 0000000..c8f3cb9 --- /dev/null +++ b/examples/dialog_providers/mod.rs @@ -0,0 +1,49 @@ +use anyhow::Result; +use async_trait::async_trait; + +use crate::{FunctionCall, Provider}; +use context_switch_core::{AudioFormat, conversation::Conversation}; + +#[derive(Debug, Clone)] +pub struct ListModelsRequest { + pub endpoint: Option, + pub model: Option, +} + +#[derive(Debug, Clone)] +pub struct StartConversationRequest { + pub endpoint: Option, + pub model: Option, + pub voice: Option, +} + +#[async_trait(?Send)] +pub trait ProviderApi { + fn output_format(&self, input_format: AudioFormat) -> AudioFormat; + fn voices(&self) -> &'static [&'static str]; + async fn list_models(&self, request: ListModelsRequest) -> 
Result<()>; + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()>; + fn parse_service_event(&self, value: serde_json::Value) -> Result>; + fn function_result_event( + &self, + call_id: String, + name: Option, + result: String, + ) -> Result; +} + +pub fn provider_api(provider: Provider) -> &'static dyn ProviderApi { + match provider { + Provider::OpenAI => &openai::OpenAIProvider, + Provider::AzureOpenAI => &azure_openai::AzureOpenAIProvider, + Provider::Google => &google::GoogleProvider, + } +} + +mod azure_openai; +mod google; +mod openai; diff --git a/examples/dialog_providers/openai.rs b/examples/dialog_providers/openai.rs new file mode 100644 index 0000000..89beaf2 --- /dev/null +++ b/examples/dialog_providers/openai.rs @@ -0,0 +1,176 @@ +use std::{env, str::FromStr}; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use openai_api_rs::realtime::types as openai_types; +use openai_dialog::{ + OpenAIDialog, Protocol, ServiceInputEvent as OpenAIServiceInputEvent, + ServiceOutputEvent as OpenAIServiceOutputEvent, +}; +use reqwest::Url; +use serde::Deserialize; +use serde_json::json; +use strum::VariantNames; + +use crate::{FunctionCall, get_time_parameters_schema}; +use context_switch_core::{AudioFormat, Service, conversation::Conversation}; + +use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; + +pub struct OpenAIProvider; + +#[async_trait(?Send)] +impl ProviderApi for OpenAIProvider { + fn output_format(&self, input_format: AudioFormat) -> AudioFormat { + input_format + } + + fn voices(&self) -> &'static [&'static str] { + ::VARIANTS + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; + let endpoint = request + .endpoint + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()); + let models_url = models_url(endpoint.as_deref())?; + + let response: 
OpenAIModelsResponse = reqwest::Client::new() + .get(models_url) + .bearer_auth(key) + .send() + .await + .context("Requesting OpenAI models")? + .error_for_status() + .context("OpenAI models request failed")? + .json() + .await + .context("Decoding OpenAI models response")?; + + let mut models: Vec<_> = response + .data + .into_iter() + .map(|model| model.id) + .filter(|id| is_realtime_model(id)) + .collect(); + models.sort(); + + println!("Available models for OpenAI:"); + if models.is_empty() { + println!("- No realtime-capable models were returned by the models endpoint."); + println!("- Ensure your API key has access to OpenAI Realtime API models."); + } else { + for model in models { + println!("- {model}"); + } + } + Ok(()) + } + + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()> { + let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; + let model = request + .model + .or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .context("Provide --model or set OPENAI_REALTIME_API_MODEL")?; + + let mut params = openai_dialog::Params::new(key, model); + params.host = request + .endpoint + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = Some(Protocol::OpenAI); + params.voice = request + .voice + .as_deref() + .map(parse_realtime_voice_value) + .transpose()?; + params.tools.push(get_time_function_definition()); + + OpenAIDialog.conversation(params, conversation).await + } + + fn parse_service_event(&self, value: serde_json::Value) -> Result> { + match serde_json::from_value(value)? 
{ + OpenAIServiceOutputEvent::FunctionCall { + name, + call_id, + arguments, + } => Ok(Some(FunctionCall { + name, + call_id, + arguments, + })), + OpenAIServiceOutputEvent::SessionUpdated { tools } => { + tracing::info!("Session updated: {tools:?}"); + Ok(None) + } + } + } + + fn function_result_event( + &self, + call_id: String, + _name: Option, + result: String, + ) -> Result { + let output = json!({ "time": serde_json::Value::String(result) }); + serde_json::to_value(&OpenAIServiceInputEvent::FunctionCallResult { call_id, output }) + .map_err(Into::into) + } +} + +pub fn parse_realtime_voice_value(value: &str) -> Result { + openai_types::RealtimeVoice::from_str(value) + .map_err(|error| anyhow::anyhow!("Invalid voice value `{value}`: {error}")) +} + +pub fn get_time_function_definition() -> openai_types::ToolDefinition { + openai_types::ToolDefinition::Function { + name: "get_time".into(), + description: "The current time to the exact second.".into(), + parameters: get_time_parameters_schema(), + } +} + +fn is_realtime_model(model_id: &str) -> bool { + model_id.to_ascii_lowercase().contains("realtime") +} + +fn models_url(endpoint: Option<&str>) -> Result { + const OPENAI_MODELS_ENDPOINT: &str = "https://api.openai.com/v1/models"; + + let Some(endpoint) = endpoint else { + return Ok(OPENAI_MODELS_ENDPOINT.to_owned()); + }; + + let mut url = Url::parse(endpoint) + .or_else(|_| Url::parse(OPENAI_MODELS_ENDPOINT)) + .context("Parsing OpenAI model list URL")?; + + let normalized_scheme = match url.scheme() { + "wss" => "https".to_owned(), + "ws" => "http".to_owned(), + other => other.to_owned(), + }; + url.set_scheme(&normalized_scheme).ok(); + url.set_path("/v1/models"); + url.set_query(None); + Ok(url.to_string()) +} + +#[derive(Debug, Deserialize)] +struct OpenAIModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAIModel { + id: String, +} From 1b951e2aa1b97736a790780be924a4e287c194ce Mon Sep 17 00:00:00 2001 From: Armin Sander Date: 
Tue, 19 May 2026 11:12:04 +0200 Subject: [PATCH 09/20] google-dialog: Enable session resumption by default --- services/google-dialog/src/client.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index fe6b5a0..bf02e32 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -8,8 +8,8 @@ use gemini_live::{ transport::{Auth, Endpoint, TransportConfig}, types::{ AudioTranscriptionConfig, Content, FunctionDeclaration, FunctionResponse, GenerationConfig, - Modality, ModalityTokenCount, Part, PrebuiltVoiceConfig, ServerEvent, SetupConfig, - SpeechConfig, Tool, UsageMetadata, VoiceConfig, + Modality, ModalityTokenCount, Part, PrebuiltVoiceConfig, ServerEvent, + SessionResumptionConfig, SetupConfig, SpeechConfig, Tool, UsageMetadata, VoiceConfig, }, }; use tracing::{debug, info, trace}; @@ -108,6 +108,10 @@ impl Client { system_instruction: self.params.instructions.clone().map(system_instruction), tools: (!self.params.tools.is_empty()).then(|| self.params.tools.clone()), realtime_input_config: self.params.realtime_input_config.clone(), + // Opt in so Gemini sends resume handles. The session layer stores + // the latest handle and patches it into reconnect setup messages, + // keeping context across GoAway-triggered reconnects.
+ session_resumption: Some(SessionResumptionConfig::default()), input_audio_transcription: self .params .input_audio_transcription From 3ffc6321bfa954f4b4d0c005b75fb18073ce449c Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 11:25:45 +0200 Subject: [PATCH 10/20] google-dialog: Add a few more features --- services/google-dialog/src/client.rs | 32 ++++++++++++++++++++++++---- services/google-dialog/src/types.rs | 21 +++++++++++++++++- 2 files changed, 48 insertions(+), 5 deletions(-) diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index bf02e32..e9bbe17 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -7,9 +7,10 @@ use gemini_live::{ ReconnectPolicy, Session, SessionConfig, transport::{Auth, Endpoint, TransportConfig}, types::{ - AudioTranscriptionConfig, Content, FunctionDeclaration, FunctionResponse, GenerationConfig, - Modality, ModalityTokenCount, Part, PrebuiltVoiceConfig, ServerEvent, - SessionResumptionConfig, SetupConfig, SpeechConfig, Tool, UsageMetadata, VoiceConfig, + AudioTranscriptionConfig, Content, ContextWindowCompressionConfig, FunctionDeclaration, + FunctionResponse, GenerationConfig, GoogleSearchTool, Modality, ModalityTokenCount, Part, + PrebuiltVoiceConfig, ServerEvent, SessionResumptionConfig, SetupConfig, SlidingWindow, + SpeechConfig, ThinkingConfig, Tool, UsageMetadata, VoiceConfig, }, }; use tracing::{debug, info, trace}; @@ -103,15 +104,26 @@ impl Client { prebuilt_voice_config: PrebuiltVoiceConfig { voice_name }, }, }), + thinking_config: self.params.thinking_level.map(|thinking_level| ThinkingConfig { + thinking_level: Some(thinking_level), + ..Default::default() + }), + temperature: self.params.temperature, ..Default::default() }), system_instruction: self.params.instructions.clone().map(system_instruction), - tools: (!self.params.tools.is_empty()).then(|| self.params.tools.clone()), + tools: setup_tools(&self.params), 
realtime_input_config: self.params.realtime_input_config.clone(), // Opt in so Gemini sends resume handles. The session layer stores // the latest handle and patches it into reconnect setup messages, // keeping context across GoAway-triggered reconnects. session_resumption: Some(SessionResumptionConfig::default()), + context_window_compression: self.params.context_window_compression.then_some( + ContextWindowCompressionConfig { + sliding_window: Some(SlidingWindow::default()), + ..Default::default() + }, + ), input_audio_transcription: self .params .input_audio_transcription @@ -264,6 +276,18 @@ fn output_transcription_enabled(params: &Params) -> bool { params.output_audio_transcription } +fn setup_tools(params: &Params) -> Option> { + let mut tools = params.tools.clone(); + if params.enable_search + && !tools + .iter() + .any(|tool| matches!(tool, Tool::GoogleSearch(_))) + { + tools.push(Tool::GoogleSearch(GoogleSearchTool::default())); + } + (!tools.is_empty()).then_some(tools) +} + fn function_declarations(tools: &[Tool]) -> Option> { let declarations: Vec<_> = tools .iter() diff --git a/services/google-dialog/src/types.rs b/services/google-dialog/src/types.rs index 184c665..45c145f 100644 --- a/services/google-dialog/src/types.rs +++ b/services/google-dialog/src/types.rs @@ -1,16 +1,27 @@ -use gemini_live::types::{FunctionDeclaration, RealtimeInputConfig, Tool}; +use gemini_live::types::{FunctionDeclaration, RealtimeInputConfig, ThinkingLevel, Tool}; use serde::{Deserialize, Serialize}; #[derive(Debug, Serialize, Deserialize)] #[serde(rename_all = "camelCase")] pub struct Params { pub api_key: String, + /// Gemini Live model name, with or without the `models/` prefix. pub model: String, pub host: Option, pub instructions: Option, pub voice: Option, + pub temperature: Option, + /// Gemini 3.1 thinking level (`minimal`, `low`, `medium`, or `high`). 
+ pub thinking_level: Option, + /// Enabled by default to avoid context-window exhaustion during long audio sessions. + #[serde(default = "default_true")] + pub context_window_compression: bool, + /// Add Gemini's built-in Google Search tool unless it is already present. + #[serde(default)] + pub enable_search: bool, #[serde(default)] pub tools: Vec, + /// Gemini realtime input behavior, including VAD and turn coverage. pub realtime_input_config: Option, #[serde(default)] pub input_audio_transcription: bool, @@ -26,6 +37,10 @@ impl Params { host: None, instructions: None, voice: None, + temperature: None, + thinking_level: None, + context_window_compression: true, + enable_search: false, tools: vec![], realtime_input_config: None, input_audio_transcription: false, @@ -34,6 +49,10 @@ impl Params { } } +fn default_true() -> bool { + true +} + #[derive(Debug, Serialize, Deserialize)] #[serde(tag = "type", rename_all = "camelCase")] pub enum ServiceInputEvent { From 140a8230dc73f6813994db5fe0385d3c9ab3f21b Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Tue, 19 May 2026 12:45:13 +0200 Subject: [PATCH 11/20] google-dialog: Support interim and regular text events --- examples/dialog.rs | 15 ++-- examples/dialog_providers/google.rs | 1 + services/google-dialog/src/client.rs | 124 +++++++++++++++++---------- services/google-dialog/src/lib.rs | 42 ++++++++- services/google-dialog/src/types.rs | 7 +- 5 files changed, 131 insertions(+), 58 deletions(-) diff --git a/examples/dialog.rs b/examples/dialog.rs index 1e6e802..2377b14 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -117,13 +117,18 @@ async fn main() -> Result<()> { stream.play().expect("Failed to play stream"); let (output_sender, output_receiver) = unbounded_channel(); + // Keep text enabled at the context-switch layer for Google. 
let conversation = Conversation::new( InputModality::Audio { format: input_format, }, - [OutputModality::Audio { - format: output_format, - }], + [ + OutputModality::Audio { + format: output_format, + }, + OutputModality::Text, + OutputModality::InterimText, + ], input_receiver, output_sender, ); @@ -224,8 +229,8 @@ async fn setup_audio_playback( break; } } - Output::Text { text, .. } => { - info!("Output text: {text}"); + output @ Output::Text { .. } => { + println!("{output:?}"); } Output::RequestCompleted { .. } => {} Output::ClearAudio => { diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs index 227e0e6..90c66f2 100644 --- a/examples/dialog_providers/google.rs +++ b/examples/dialog_providers/google.rs @@ -91,6 +91,7 @@ impl ProviderApi for GoogleProvider { .as_deref() .map(parse_voice_value) .transpose()?; + params.input_audio_transcription = true; params.output_audio_transcription = true; params.tools.push(get_time_tool()); diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index e9bbe17..1515297 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -8,7 +8,7 @@ use gemini_live::{ transport::{Auth, Endpoint, TransportConfig}, types::{ AudioTranscriptionConfig, Content, ContextWindowCompressionConfig, FunctionDeclaration, - FunctionResponse, GenerationConfig, GoogleSearchTool, Modality, ModalityTokenCount, Part, + FunctionResponse, GenerationConfig, Modality, ModalityTokenCount, Part, PrebuiltVoiceConfig, ServerEvent, SessionResumptionConfig, SetupConfig, SlidingWindow, SpeechConfig, ThinkingConfig, Tool, UsageMetadata, VoiceConfig, }, @@ -21,6 +21,12 @@ pub struct Client { params: Params, } +#[derive(Debug, Clone, Copy)] +struct TextOutputConfig { + text: bool, + interim: bool, +} + impl Client { pub fn new(params: Params) -> Self { Self { params } @@ -30,13 +36,19 @@ impl Client { self, _input_format: AudioFormat, output_format: AudioFormat, - 
output_transcription: bool, + text_output_enabled: bool, + interim_text_output_enabled: bool, mut input: ConversationInput, output: ConversationOutput, ) -> Result<()> { + let text_output = TextOutputConfig { + text: text_output_enabled, + interim: interim_text_output_enabled, + }; let billing_scope = self.params.model.clone(); let tools = function_declarations(&self.params.tools); - let mut session = Session::connect(self.session_config(output_transcription)) + let mut output_transcription_buffer = String::new(); + let mut session = Session::connect(self.session_config(text_output)?) .await .context("Connecting to Gemini Live")?; output.service_event( @@ -57,7 +69,7 @@ impl Client { event = session.next_event() => { match event { Some(event) => { - match self.process_event(event, output_format, output_transcription, &output, &billing_scope).await? { + match self.process_event(event, output_format, text_output, &output, &billing_scope, &mut output_transcription_buffer).await? { FlowControl::Continue => {} FlowControl::End => break, } @@ -75,7 +87,7 @@ impl Client { Ok(()) } - fn session_config(&self, output_transcription: bool) -> SessionConfig { + fn session_config(&self, text_output: TextOutputConfig) -> Result { let transport = TransportConfig { endpoint: self .params @@ -87,32 +99,54 @@ impl Client { ..Default::default() }; - SessionConfig { + Ok(SessionConfig { transport, - setup: self.setup_config(output_transcription), + setup: self.setup_config(text_output)?, reconnect: ReconnectPolicy::default(), - } + }) } - fn setup_config(&self, output_transcription: bool) -> SetupConfig { - SetupConfig { + fn setup_config(&self, text_output: TextOutputConfig) -> Result { + let input_audio_transcription = self + .params + .input_audio_transcription + .then_some(AudioTranscriptionConfig {}); + let output_audio_transcription = self + .params + .output_audio_transcription + .then_some(AudioTranscriptionConfig {}); + + if !(text_output.text || text_output.interim) + && 
(input_audio_transcription.is_some() || output_audio_transcription.is_some()) + { + bail!( + "Google dialog requires text output modality when transcription is enabled: if inputAudioTranscription or outputAudioTranscription is set, add OutputModality::Text or OutputModality::InterimText to the conversation output modalities, or set both transcription flags to false." + ); + } + + Ok(SetupConfig { model: model_resource_name(&self.params.model), generation_config: Some(GenerationConfig { + // NOTE: Enabling Modality::Text here currently causes a Gemini setup-time + // "Internal error encountered." in this service flow. response_modalities: Some(vec![Modality::Audio]), speech_config: self.params.voice.clone().map(|voice_name| SpeechConfig { voice_config: VoiceConfig { prebuilt_voice_config: PrebuiltVoiceConfig { voice_name }, }, }), - thinking_config: self.params.thinking_level.map(|thinking_level| ThinkingConfig { - thinking_level: Some(thinking_level), - ..Default::default() - }), + thinking_config: self + .params + .thinking_level + .map(|thinking_level| ThinkingConfig { + thinking_level: Some(thinking_level), + ..Default::default() + }), temperature: self.params.temperature, ..Default::default() }), system_instruction: self.params.instructions.clone().map(system_instruction), - tools: setup_tools(&self.params), + tools: (!self.params.tools.is_empty()).then(|| self.params.tools.clone()), realtime_input_config: self.params.realtime_input_config.clone(), // Opt in so Gemini sends resume handles. 
The session layer stores // the latest handle and patches it into reconnect setup messages, @@ -124,15 +158,10 @@ impl Client { ..Default::default() }, ), - input_audio_transcription: self - .params - .input_audio_transcription - .then_some(AudioTranscriptionConfig {}), - output_audio_transcription: (output_transcription_enabled(&self.params)) - .then_some(AudioTranscriptionConfig {}) - .or_else(|| output_transcription.then_some(AudioTranscriptionConfig {})), + input_audio_transcription, + output_audio_transcription, ..Default::default() - } + }) } async fn process_input(&self, session: &Session, input: Input) -> Result<()> { @@ -184,17 +213,17 @@ impl Client { &self, event: ServerEvent, output_format: AudioFormat, - output_transcription: bool, + text_output: TextOutputConfig, output: &ConversationOutput, billing_scope: &str, + output_transcription_buffer: &mut String, ) -> Result { trace!(?event, "Gemini Live event"); match event { ServerEvent::SetupComplete => {} ServerEvent::ModelText(text) => { - if output_transcription { - output.text(true, text, None, None)?; - } + // This does not seem to work, even when we enable TEXT response modalities. 
+ debug!(%text, "Gemini model text"); } ServerEvent::ModelAudio(audio) => { let frame = AudioFrame::from_le_bytes(output_format, &audio); @@ -202,17 +231,36 @@ impl Client { } ServerEvent::GenerationComplete => {} ServerEvent::TurnComplete => { + if text_output.text && !output_transcription_buffer.is_empty() { + output.text( + true, + std::mem::take(output_transcription_buffer), + None, + Some(self.params.model.clone()), + )?; + } else { + output_transcription_buffer.clear(); + } output.request_completed(None)?; } ServerEvent::Interrupted => { + output_transcription_buffer.clear(); output.clear_audio()?; } ServerEvent::InputTranscription(text) => { - debug!(%text, "Gemini input transcription"); + if text_output.text { + output.text(true, text, None, None)?; + } } ServerEvent::OutputTranscription(text) => { - if output_transcription { - output.text(true, text, None, None)?; + output_transcription_buffer.push_str(&text); + if text_output.interim { + output.text( + false, + output_transcription_buffer.clone(), + None, + Some(self.params.model.clone()), + )?; } } ServerEvent::ToolCall(calls) => { @@ -272,22 +320,6 @@ fn system_instruction(text: String) -> Content { } } -fn output_transcription_enabled(params: &Params) -> bool { - params.output_audio_transcription -} - -fn setup_tools(params: &Params) -> Option> { - let mut tools = params.tools.clone(); - if params.enable_search - && !tools - .iter() - .any(|tool| matches!(tool, Tool::GoogleSearch(_))) - { - tools.push(Tool::GoogleSearch(GoogleSearchTool::default())); - } - (!tools.is_empty()).then_some(tools) -} - fn function_declarations(tools: &[Tool]) -> Option> { let declarations: Vec<_> = tools .iter() diff --git a/services/google-dialog/src/lib.rs b/services/google-dialog/src/lib.rs index 3a2608c..4691507 100644 --- a/services/google-dialog/src/lib.rs +++ b/services/google-dialog/src/lib.rs @@ -4,7 +4,7 @@ use anyhow::{Result, bail}; use async_trait::async_trait; use tracing::info; -use 
context_switch_core::{AudioFormat, Service, conversation::Conversation}; +use context_switch_core::{AudioFormat, OutputModality, Service, conversation::Conversation}; mod client; mod types; @@ -22,7 +22,7 @@ impl Service for GoogleDialog { async fn conversation(&self, params: Params, conversation: Conversation) -> Result<()> { let input_format = conversation.require_audio_input()?; let output_format = conversation.require_one_audio_output()?; - let output_transcription = conversation.has_one_text_output()?; + let text_outputs = TextOutputs::from_modalities(&conversation.output_modalities)?; let expected_output = AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE); if output_format != expected_output { @@ -39,10 +39,46 @@ impl Service for GoogleDialog { .dialog( input_format, output_format, - output_transcription, + text_outputs.text, + text_outputs.interim, input, output, ) .await } } + +#[derive(Debug, Clone, Copy)] +struct TextOutputs { + text: bool, + interim: bool, +} + +impl TextOutputs { + fn from_modalities(output_modalities: &[OutputModality]) -> Result { + let mut text_outputs = Self { + text: false, + interim: false, + }; + + for modality in output_modalities { + match modality { + OutputModality::Text => { + if text_outputs.text { + bail!("Expecting at most one text output") + } + text_outputs.text = true; + } + OutputModality::InterimText => { + if text_outputs.interim { + bail!("Expecting at most one interim text output") + } + text_outputs.interim = true; + } + OutputModality::Audio { .. } => {} + } + } + + Ok(text_outputs) + } +} diff --git a/services/google-dialog/src/types.rs b/services/google-dialog/src/types.rs index 45c145f..fa4494f 100644 --- a/services/google-dialog/src/types.rs +++ b/services/google-dialog/src/types.rs @@ -10,21 +10,21 @@ pub struct Params { pub host: Option, pub instructions: Option, pub voice: Option, + pub temperature: Option, /// Gemini 3.1 thinking level (`minimal`, `low`, `medium`, or `high`). 
pub thinking_level: Option, /// Enabled by default to avoid context-window exhaustion during long audio sessions. #[serde(default = "default_true")] pub context_window_compression: bool, - /// Add Gemini's built-in Google Search tool unless it is already present. - #[serde(default)] - pub enable_search: bool, #[serde(default)] pub tools: Vec, /// Gemini realtime input behavior, including VAD and turn coverage. pub realtime_input_config: Option, + /// Enable server-side transcription of user input audio. #[serde(default)] pub input_audio_transcription: bool, + /// Enable server-side transcription of model output audio. #[serde(default)] pub output_audio_transcription: bool, } @@ -40,7 +40,6 @@ impl Params { temperature: None, thinking_level: None, context_window_compression: true, - enable_search: false, tools: vec![], realtime_input_config: None, input_audio_transcription: false, From 94b93487c313d068f967098ff05b80d06cbf7c4d Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 08:24:05 +0200 Subject: [PATCH 12/20] Reorder --- examples/dialog_providers/azure_openai.rs | 64 ++++++------- examples/dialog_providers/google.rs | 104 +++++++++++----------- examples/dialog_providers/mod.rs | 6 +- examples/dialog_providers/openai.rs | 94 +++++++++---------- 4 files changed, 134 insertions(+), 134 deletions(-) diff --git a/examples/dialog_providers/azure_openai.rs b/examples/dialog_providers/azure_openai.rs index fcde448..588a7b8 100644 --- a/examples/dialog_providers/azure_openai.rs +++ b/examples/dialog_providers/azure_openai.rs @@ -15,38 +15,6 @@ pub struct AzureOpenAIProvider; #[async_trait(?Send)] impl ProviderApi for AzureOpenAIProvider { - fn output_format(&self, input_format: AudioFormat) -> AudioFormat { - input_format - } - - fn voices(&self) -> &'static [&'static str] { - ::VARIANTS - } - - async fn list_models(&self, request: ListModelsRequest) -> Result<()> { - println!("Available models for Azure:"); - println!( - "- Azure Realtime uses deployment 
names configured in your Azure OpenAI resource." - ); - println!( - "- The realtime endpoint does not provide a provider-agnostic model listing API here." - ); - - if let Some(model) = request - .model - .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) - .filter(|model| !model.trim().is_empty()) - { - println!("- Configured deployment/model: {model}"); - } else { - println!( - "- Set --model or AZURE_OPENAI_REALTIME_API_MODEL to your Azure deployment name." - ); - } - - Ok(()) - } - async fn start_conversation( &self, request: StartConversationRequest, @@ -87,4 +55,36 @@ impl ProviderApi for AzureOpenAIProvider { ) -> Result { OpenAIProvider.function_result_event(call_id, None, result) } + + fn output_format(&self, input_format: AudioFormat) -> AudioFormat { + input_format + } + + fn voices(&self) -> &'static [&'static str] { + ::VARIANTS + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + println!("Available models for Azure:"); + println!( + "- Azure Realtime uses deployment names configured in your Azure OpenAI resource." + ); + println!( + "- The realtime endpoint does not provide a provider-agnostic model listing API here." + ); + + if let Some(model) = request + .model + .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + { + println!("- Configured deployment/model: {model}"); + } else { + println!( + "- Set --model or AZURE_OPENAI_REALTIME_API_MODEL to your Azure deployment name." 
+ ); + } + + Ok(()) + } } diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs index 90c66f2..a9b6188 100644 --- a/examples/dialog_providers/google.rs +++ b/examples/dialog_providers/google.rs @@ -17,58 +17,6 @@ pub struct GoogleProvider; #[async_trait(?Send)] impl ProviderApi for GoogleProvider { - fn output_format(&self, _input_format: AudioFormat) -> AudioFormat { - AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE) - } - - fn voices(&self) -> &'static [&'static str] { - VOICES - } - - async fn list_models(&self, request: ListModelsRequest) -> Result<()> { - let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; - let endpoint = request - .endpoint - .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()); - let models_url = models_url(endpoint.as_deref())?; - - let response: GeminiModelsResponse = reqwest::Client::new() - .get(models_url) - .query(&[("key", key)]) - .send() - .await - .context("Requesting Gemini models")? - .error_for_status() - .context("Gemini models request failed")? - .json() - .await - .context("Decoding Gemini models response")?; - - let mut live_models: Vec<_> = response - .models - .into_iter() - .filter(|model| is_live_model(&model.name, &model.supported_generation_methods)) - .map(|model| model.name) - .collect(); - live_models.sort(); - - println!("Available models for Google (Live API capable):"); - if live_models.is_empty() { - println!("- No Live-capable models were detected from models.list."); - println!( - "- This can happen when model metadata does not include Live-specific methods." 
- ); - println!("- Try explicitly using a known Live model, for example:"); - println!(" - models/gemini-3.1-flash-live-preview"); - println!(" - models/gemini-2.5-flash-live-preview"); - } else { - for model in live_models { - println!("- {model}"); - } - } - Ok(()) - } - async fn start_conversation( &self, request: StartConversationRequest, @@ -135,6 +83,58 @@ impl ProviderApi for GoogleProvider { }) .map_err(Into::into) } + + fn output_format(&self, _input_format: AudioFormat) -> AudioFormat { + AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE) + } + + fn voices(&self) -> &'static [&'static str] { + VOICES + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; + let endpoint = request + .endpoint + .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()); + let models_url = models_url(endpoint.as_deref())?; + + let response: GeminiModelsResponse = reqwest::Client::new() + .get(models_url) + .query(&[("key", key)]) + .send() + .await + .context("Requesting Gemini models")? + .error_for_status() + .context("Gemini models request failed")? + .json() + .await + .context("Decoding Gemini models response")?; + + let mut live_models: Vec<_> = response + .models + .into_iter() + .filter(|model| is_live_model(&model.name, &model.supported_generation_methods)) + .map(|model| model.name) + .collect(); + live_models.sort(); + + println!("Available models for Google (Live API capable):"); + if live_models.is_empty() { + println!("- No Live-capable models were detected from models.list."); + println!( + "- This can happen when model metadata does not include Live-specific methods." 
+ ); + println!("- Try explicitly using a known Live model, for example:"); + println!(" - models/gemini-3.1-flash-live-preview"); + println!(" - models/gemini-2.5-flash-live-preview"); + } else { + for model in live_models { + println!("- {model}"); + } + } + Ok(()) + } } fn models_url(endpoint: Option<&str>) -> Result { diff --git a/examples/dialog_providers/mod.rs b/examples/dialog_providers/mod.rs index c8f3cb9..84a48f1 100644 --- a/examples/dialog_providers/mod.rs +++ b/examples/dialog_providers/mod.rs @@ -19,9 +19,6 @@ pub struct StartConversationRequest { #[async_trait(?Send)] pub trait ProviderApi { - fn output_format(&self, input_format: AudioFormat) -> AudioFormat; - fn voices(&self) -> &'static [&'static str]; - async fn list_models(&self, request: ListModelsRequest) -> Result<()>; async fn start_conversation( &self, request: StartConversationRequest, @@ -34,6 +31,9 @@ pub trait ProviderApi { name: Option, result: String, ) -> Result; + fn output_format(&self, input_format: AudioFormat) -> AudioFormat; + fn voices(&self) -> &'static [&'static str]; + async fn list_models(&self, request: ListModelsRequest) -> Result<()>; } pub fn provider_api(provider: Provider) -> &'static dyn ProviderApi { diff --git a/examples/dialog_providers/openai.rs b/examples/dialog_providers/openai.rs index 89beaf2..2bb4f0a 100644 --- a/examples/dialog_providers/openai.rs +++ b/examples/dialog_providers/openai.rs @@ -21,53 +21,6 @@ pub struct OpenAIProvider; #[async_trait(?Send)] impl ProviderApi for OpenAIProvider { - fn output_format(&self, input_format: AudioFormat) -> AudioFormat { - input_format - } - - fn voices(&self) -> &'static [&'static str] { - ::VARIANTS - } - - async fn list_models(&self, request: ListModelsRequest) -> Result<()> { - let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; - let endpoint = request - .endpoint - .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()); - let models_url = models_url(endpoint.as_deref())?; - - let 
response: OpenAIModelsResponse = reqwest::Client::new() - .get(models_url) - .bearer_auth(key) - .send() - .await - .context("Requesting OpenAI models")? - .error_for_status() - .context("OpenAI models request failed")? - .json() - .await - .context("Decoding OpenAI models response")?; - - let mut models: Vec<_> = response - .data - .into_iter() - .map(|model| model.id) - .filter(|id| is_realtime_model(id)) - .collect(); - models.sort(); - - println!("Available models for OpenAI:"); - if models.is_empty() { - println!("- No realtime-capable models were returned by the models endpoint."); - println!("- Ensure your API key has access to OpenAI Realtime API models."); - } else { - for model in models { - println!("- {model}"); - } - } - Ok(()) - } - async fn start_conversation( &self, request: StartConversationRequest, @@ -124,6 +77,53 @@ impl ProviderApi for OpenAIProvider { serde_json::to_value(&OpenAIServiceInputEvent::FunctionCallResult { call_id, output }) .map_err(Into::into) } + + fn output_format(&self, input_format: AudioFormat) -> AudioFormat { + input_format + } + + fn voices(&self) -> &'static [&'static str] { + ::VARIANTS + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; + let endpoint = request + .endpoint + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()); + let models_url = models_url(endpoint.as_deref())?; + + let response: OpenAIModelsResponse = reqwest::Client::new() + .get(models_url) + .bearer_auth(key) + .send() + .await + .context("Requesting OpenAI models")? + .error_for_status() + .context("OpenAI models request failed")? 
+ .json() + .await + .context("Decoding OpenAI models response")?; + + let mut models: Vec<_> = response + .data + .into_iter() + .map(|model| model.id) + .filter(|id| is_realtime_model(id)) + .collect(); + models.sort(); + + println!("Available models for OpenAI:"); + if models.is_empty() { + println!("- No realtime-capable models were returned by the models endpoint."); + println!("- Ensure your API key has access to OpenAI Realtime API models."); + } else { + for model in models { + println!("- {model}"); + } + } + Ok(()) + } } pub fn parse_realtime_voice_value(value: &str) -> Result { From e2287142ed1936d8f83ddbf0c66d60849281c201 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 08:47:52 +0200 Subject: [PATCH 13/20] Review some uses and import structures --- .harper-dictionary.txt | 2 +- .vscode/settings.json | 8 ------- core/src/billing_context.rs | 4 +++- core/src/conversation.rs | 2 +- core/src/lib.rs | 3 ++- examples/dialog.rs | 7 +++--- examples/dialog_providers/google.rs | 3 ++- examples/dialog_providers/mod.rs | 3 ++- examples/dialog_providers/openai.rs | 9 ++++---- services/aristech/src/synthesize.rs | 5 +--- services/aristech/src/transcribe.rs | 5 +--- services/azure/src/synthesize.rs | 3 +-- services/azure/src/transcribe.rs | 5 ++-- services/azure/src/translate.rs | 4 ++-- services/elevenlabs/src/transcribe.rs | 18 ++++++--------- services/google-dialog/src/client.rs | 9 ++++---- services/google-dialog/src/lib.rs | 2 +- services/google-transcribe/src/transcribe.rs | 7 +++--- services/openai-dialog/src/client.rs | 24 ++++++++------------ services/openai-dialog/src/lib.rs | 2 +- services/playback/src/lib.rs | 24 ++++++++------------ 21 files changed, 62 insertions(+), 87 deletions(-) delete mode 100644 .vscode/settings.json diff --git a/.harper-dictionary.txt b/.harper-dictionary.txt index dff3086..fd2abbd 100644 --- a/.harper-dictionary.txt +++ b/.harper-dictionary.txt @@ -2,6 +2,6 @@ AirPods BCP ContextSwitch FreeSWITCH -Inband 
+inband seekable subtag diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 8ef4f25..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "cSpell.words": [ - "alsa", - "aristech", - "denoiser", - "subtag" - ] -} \ No newline at end of file diff --git a/core/src/billing_context.rs b/core/src/billing_context.rs index 2269e69..004fa8a 100644 --- a/core/src/billing_context.rs +++ b/core/src/billing_context.rs @@ -2,7 +2,9 @@ use std::sync::{Arc, Mutex}; use anyhow::Result; -use crate::{BillingRecord, billing_collector::BillingCollector, conversation::BillingId}; +use crate::BillingRecord; +use crate::billing_collector::BillingCollector; +use crate::conversation::BillingId; #[derive(Debug, Clone)] pub struct BillingContext { diff --git a/core/src/conversation.rs b/core/src/conversation.rs index 038ca60..486e1a3 100644 --- a/core/src/conversation.rs +++ b/core/src/conversation.rs @@ -168,7 +168,7 @@ impl ConversationInput { self.input.recv().await } - /// Run a nested service conversation with one single input request and wait until its + /// Run a nested service conversation with one single input request and wait until it's /// completed. /// /// All output is sent to the conversation output. 
diff --git a/core/src/lib.rs b/core/src/lib.rs index 1e0b9c7..6e24baa 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,7 +1,7 @@ pub mod audio; pub mod billing_collector; mod billing_context; -pub mod conversation; +mod conversation; mod duration; pub mod language; mod protocol; @@ -15,6 +15,7 @@ use anyhow::{Context, Result, bail}; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender, unbounded_channel}; pub use billing_context::BillingContext; +pub use conversation::*; pub use duration::Duration; pub use protocol::*; pub use registry::*; diff --git a/examples/dialog.rs b/examples/dialog.rs index 2377b14..b343a71 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -60,10 +60,12 @@ impl Provider { #[tokio::main] async fn main() -> Result<()> { + dotenvy::dotenv_override().context("Reading .env file")?; + tracing_subscriber::fmt::init(); + let cli = Cli::parse(); if cli.list_models { - let _ = dotenvy::dotenv_override(); list_available_models(&cli).await?; return Ok(()); } @@ -73,9 +75,6 @@ async fn main() -> Result<()> { return Ok(()); } - dotenvy::dotenv_override().context("Reading .env file")?; - tracing_subscriber::fmt::init(); - let host = cpal::default_host(); let device = host .default_input_device() diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs index a9b6188..4a2e824 100644 --- a/examples/dialog_providers/google.rs +++ b/examples/dialog_providers/google.rs @@ -9,7 +9,8 @@ use serde::Deserialize; use serde_json::json; use crate::{FunctionCall, get_time_parameters_schema}; -use context_switch_core::{AudioFormat, Service, conversation::Conversation}; +use context_switch_core::conversation::Conversation; +use context_switch_core::{AudioFormat, Service}; use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; diff --git a/examples/dialog_providers/mod.rs b/examples/dialog_providers/mod.rs index 84a48f1..42b1c74 100644 --- a/examples/dialog_providers/mod.rs +++ 
b/examples/dialog_providers/mod.rs @@ -2,7 +2,8 @@ use anyhow::Result; use async_trait::async_trait; use crate::{FunctionCall, Provider}; -use context_switch_core::{AudioFormat, conversation::Conversation}; +use context_switch_core::AudioFormat; +use context_switch_core::conversation::Conversation; #[derive(Debug, Clone)] pub struct ListModelsRequest { diff --git a/examples/dialog_providers/openai.rs b/examples/dialog_providers/openai.rs index 2bb4f0a..04119f2 100644 --- a/examples/dialog_providers/openai.rs +++ b/examples/dialog_providers/openai.rs @@ -1,6 +1,6 @@ use std::{env, str::FromStr}; -use anyhow::{Context, Result}; +use anyhow::{Context, Result, anyhow}; use async_trait::async_trait; use openai_api_rs::realtime::types as openai_types; use openai_dialog::{ @@ -13,7 +13,8 @@ use serde_json::json; use strum::VariantNames; use crate::{FunctionCall, get_time_parameters_schema}; -use context_switch_core::{AudioFormat, Service, conversation::Conversation}; +use context_switch_core::conversation::Conversation; +use context_switch_core::{AudioFormat, Service}; use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; @@ -83,7 +84,7 @@ impl ProviderApi for OpenAIProvider { } fn voices(&self) -> &'static [&'static str] { - ::VARIANTS + openai_types::RealtimeVoice::VARIANTS } async fn list_models(&self, request: ListModelsRequest) -> Result<()> { @@ -128,7 +129,7 @@ impl ProviderApi for OpenAIProvider { pub fn parse_realtime_voice_value(value: &str) -> Result { openai_types::RealtimeVoice::from_str(value) - .map_err(|error| anyhow::anyhow!("Invalid voice value `{value}`: {error}")) + .map_err(|error| anyhow!("Invalid voice value `{value}`: {error}")) } pub fn get_time_function_definition() -> openai_types::ToolDefinition { diff --git a/services/aristech/src/synthesize.rs b/services/aristech/src/synthesize.rs index 2820c29..3a9a745 100644 --- a/services/aristech/src/synthesize.rs +++ b/services/aristech/src/synthesize.rs @@ -10,10 +10,7 @@ use 
async_trait::async_trait; use serde::{Deserialize, Serialize}; use tracing::debug; -use context_switch_core::{ - AudioFormat, AudioFrame, Service, - conversation::{Conversation, Input}, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, Service}; //TODO: Add `language` field as alternative to `voice_id` #[derive(Debug, Serialize, Deserialize)] diff --git a/services/aristech/src/transcribe.rs b/services/aristech/src/transcribe.rs index 1c3b078..0719084 100644 --- a/services/aristech/src/transcribe.rs +++ b/services/aristech/src/transcribe.rs @@ -12,10 +12,7 @@ use async_trait::async_trait; use serde::Deserialize; use tonic::codegen::CompressionEncoding; -use context_switch_core::{ - Service, - conversation::{Conversation, Input}, -}; +use context_switch_core::{Conversation, Input, Service}; /// Authentication configuration #[derive(Debug, Deserialize)] diff --git a/services/azure/src/synthesize.rs b/services/azure/src/synthesize.rs index 3d56163..fc47bcd 100644 --- a/services/azure/src/synthesize.rs +++ b/services/azure/src/synthesize.rs @@ -13,8 +13,7 @@ use tracing::debug; use crate::Host; use context_switch_core::{ - AudioFrame, BillingRecord, Service, - conversation::{BillingSchedule, Conversation, Input}, + AudioFrame, BillingRecord, BillingSchedule, Conversation, Input, Service, }; #[derive(Debug, Serialize, Deserialize)] diff --git a/services/azure/src/transcribe.rs b/services/azure/src/transcribe.rs index e4037a3..bdb38c5 100644 --- a/services/azure/src/transcribe.rs +++ b/services/azure/src/transcribe.rs @@ -7,10 +7,9 @@ use serde::Deserialize; use tracing::{error, info}; use crate::Host; +use context_switch_core::language::Languages; use context_switch_core::{ - BillingRecord, Service, - conversation::{BillingSchedule, Conversation, ConversationOutput, Input}, - language::Languages, + BillingRecord, BillingSchedule, Conversation, ConversationOutput, Input, Service, speech_gate::make_speech_gate_processor_soft_rms, }; diff --git 
a/services/azure/src/translate.rs b/services/azure/src/translate.rs index 7b6731d..129232d 100644 --- a/services/azure/src/translate.rs +++ b/services/azure/src/translate.rs @@ -8,8 +8,8 @@ use tracing::{debug, error}; use crate::Host; use context_switch_core::{ - AudioFormat, AudioFrame, BillingRecord, OutputModality, OutputPath, Service, - conversation::{BillingSchedule, Conversation, Input}, + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, Conversation, Input, OutputModality, + OutputPath, Service, }; #[derive(Debug, Deserialize)] diff --git a/services/elevenlabs/src/transcribe.rs b/services/elevenlabs/src/transcribe.rs index d0f8a9c..81f86ee 100644 --- a/services/elevenlabs/src/transcribe.rs +++ b/services/elevenlabs/src/transcribe.rs @@ -7,21 +7,17 @@ use serde_json::Value; use tokio::select; use tokio::sync::mpsc; use tokio::time::{Duration, sleep}; -use tokio_tungstenite::{ - connect_async_with_config, - tungstenite::{ - Message, - client::IntoClientRequest, - http::{HeaderName, HeaderValue}, - }, -}; +use tokio_tungstenite::connect_async_with_config; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::{HeaderName, HeaderValue}; use tracing::{debug, error, warn}; use url::Url; +use context_switch_core::language::{bcp47_to_iso639_3, iso639_to_bcp47}; use context_switch_core::{ - AudioFormat, AudioFrame, BillingRecord, Service, - conversation::{BillingSchedule, Conversation, ConversationInput, ConversationOutput, Input}, - language::{bcp47_to_iso639_3, iso639_to_bcp47}, + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, Conversation, ConversationInput, + ConversationOutput, Input, Service, }; // Observed Scribe v2 behavior as of 2026-04-02: diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index 1515297..331731d 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs 
@@ -1,8 +1,5 @@ use anyhow::{Context, Result, bail}; -use context_switch_core::{ - AudioFormat, AudioFrame, BillingRecord, OutputPath, - conversation::{BillingSchedule, ConversationInput, ConversationOutput, Input}, -}; + use gemini_live::{ ReconnectPolicy, Session, SessionConfig, transport::{Auth, Endpoint, TransportConfig}, @@ -16,6 +13,10 @@ use gemini_live::{ use tracing::{debug, info, trace}; use crate::{Params, ServiceInputEvent, ServiceOutputEvent}; +use context_switch_core::{ + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, ConversationInput, ConversationOutput, + Input, OutputPath, +}; pub struct Client { params: Params, diff --git a/services/google-dialog/src/lib.rs b/services/google-dialog/src/lib.rs index 4691507..333a8ef 100644 --- a/services/google-dialog/src/lib.rs +++ b/services/google-dialog/src/lib.rs @@ -4,7 +4,7 @@ use anyhow::{Result, bail}; use async_trait::async_trait; use tracing::info; -use context_switch_core::{AudioFormat, OutputModality, Service, conversation::Conversation}; +use context_switch_core::{AudioFormat, Conversation, OutputModality, Service}; mod client; mod types; diff --git a/services/google-transcribe/src/transcribe.rs b/services/google-transcribe/src/transcribe.rs index dc91808..397139b 100644 --- a/services/google-transcribe/src/transcribe.rs +++ b/services/google-transcribe/src/transcribe.rs @@ -10,9 +10,8 @@ use tokio::sync::mpsc::UnboundedReceiver; use tonic::Code; use context_switch_core::{ - AudioFormat, AudioFrame, AudioProducer, BillingRecord, OutputModality, Service, - conversation::{BillingSchedule, Conversation, ConversationOutput, Input}, - language::Languages, + AudioFormat, AudioFrame, AudioProducer, BillingRecord, BillingSchedule, Conversation, + ConversationOutput, Input, OutputModality, Service, language::Languages, }; use tracing::{info, warn}; @@ -176,7 +175,7 @@ async fn transcribe_and_process_stream_session( async fn process_stream_session( model: &str, include_detected_language: bool, - 
output: &context_switch_core::conversation::ConversationOutput, + output: &ConversationOutput, response_stream: S, ) -> Result where diff --git a/services/openai-dialog/src/client.rs b/services/openai-dialog/src/client.rs index c985f93..b304825 100644 --- a/services/openai-dialog/src/client.rs +++ b/services/openai-dialog/src/client.rs @@ -2,20 +2,12 @@ use std::collections::VecDeque; use anyhow::{Context, Result, bail}; use base64::prelude::*; -use context_switch_core::{ - AudioFormat, AudioFrame, BillingRecord, OutputPath, audio, - conversation::{BillingSchedule, ConversationInput, ConversationOutput, Input}, -}; -use futures::{ - SinkExt, StreamExt, - stream::{SplitSink, SplitStream}, -}; -use openai_api_rs::realtime::{ - client_event::{self, ClientEvent}, - server_event::{self, ServerEvent}, - types::{ - self, ItemContentType, ItemRole, ItemStatus, ItemType, OutputModality, ResponseStatus, - }, +use futures::stream::{SplitSink, SplitStream}; +use futures::{SinkExt, StreamExt}; +use openai_api_rs::realtime::client_event::{self, ClientEvent}; +use openai_api_rs::realtime::server_event::{self, ServerEvent}; +use openai_api_rs::realtime::types::{ + self, ItemContentType, ItemRole, ItemStatus, ItemType, OutputModality, ResponseStatus, }; use tokio::{net::TcpStream, select}; use tokio_tungstenite::{ @@ -26,6 +18,10 @@ use tracing::{debug, info, trace, warn}; use uuid::Uuid; use crate::{Params, ServiceInputEvent, ServiceOutputEvent}; +use context_switch_core::{ + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, ConversationInput, ConversationOutput, + Input, OutputPath, audio, +}; pub struct Client { read: SplitStream>>, diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index f3bd280..0cdbeee 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -6,7 +6,7 @@ use anyhow::{Result, bail}; use async_trait::async_trait; use tracing::info; -use context_switch_core::{Service, 
conversation::Conversation}; +use context_switch_core::{Conversation, Service}; mod client; mod host; diff --git a/services/playback/src/lib.rs b/services/playback/src/lib.rs index 6d8b0ad..74d149d 100644 --- a/services/playback/src/lib.rs +++ b/services/playback/src/lib.rs @@ -1,25 +1,19 @@ -use std::{ - fs::{self, File}, - io::{self, BufReader}, - num::{NonZeroU16, NonZeroU32}, - path::{Path, PathBuf}, -}; +use std::fs::{self, File}; +use std::io::{self, BufReader}; +use std::num::{NonZeroU16, NonZeroU32}; +use std::path::{Path, PathBuf}; -use anyhow::{Context, Result, bail}; +use anyhow::{Context, Result, anyhow, bail}; use async_trait::async_trait; -use context_switch_core::{BillingRecord, audio, conversation::BillingSchedule}; -use rodio::{ - Decoder, Source, - conversions::{ChannelCountConverter, SampleRateConverter}, -}; +use rodio::conversions::{ChannelCountConverter, SampleRateConverter}; +use rodio::{Decoder, Source}; use serde::{Deserialize, Serialize}; use tokio::task; use tracing::{debug, error}; use url::Url; use context_switch_core::{ - AudioFormat, AudioFrame, Service, - conversation::{Conversation, Input}, + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, Conversation, Input, Service, audio, }; mod stream_reader; @@ -353,7 +347,7 @@ pub fn check_supported_audio_type( } else { let guessed_mime = mime_guess2::from_path(path) .first() - .ok_or_else(|| anyhow::anyhow!("Invalid audio url (should end in `.mp3` or `.wav`)"))?; + .ok_or_else(|| anyhow!("Invalid audio url (should end in `.mp3` or `.wav`)"))?; guessed_mime.essence_str().to_string() }; From 015ada29a710a06cbdd38520c1d01c3dd1d4da3c Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 08:52:52 +0200 Subject: [PATCH 14/20] Readd lost comments --- examples/dialog.rs | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/examples/dialog.rs b/examples/dialog.rs index b343a71..6c69dad 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -93,6 +93,7 @@ async fn 
main() -> Result<()> { let (input_sender, input_receiver) = channel(256); let input_sender2 = input_sender.clone(); + // Create and run the input stream let stream = device .build_input_stream( &input_config.into(), @@ -109,6 +110,7 @@ async fn main() -> Result<()> { move |err| { eprintln!("Error occurred on stream: {err}"); }, + // Timeout Some(Duration::from_secs(1)), ) .expect("Failed to build input stream"); @@ -136,13 +138,17 @@ async fn main() -> Result<()> { tokio::pin!(conversation); let playback_task = setup_audio_playback(cli.provider, output_format, input_sender, output_receiver).await; + // Spawn audio playback task let mut playback_handle = tokio::spawn(playback_task); select! { + // Drive conversation r = &mut conversation => { + // When conversation ends, wait for playback to complete before returning. let _ = playback_handle.await; r? } + // Drive playback r = &mut playback_handle => { r?? } @@ -193,7 +199,9 @@ async fn setup_audio_playback( ) -> impl std::future::Future> { let (cmd_tx, cmd_rx) = std::sync::mpsc::channel(); + // Spawn a dedicated audio thread let playback_thread = thread::spawn(move || { + // Create output stream in the audio thread let sink_handle = DeviceSinkBuilder::open_default_sink().unwrap(); let player = Player::connect_new(sink_handle.mixer()); @@ -219,6 +227,7 @@ async fn setup_audio_playback( player.sleep_until_end(); }); + // Create async task to forward frames to the audio thread async move { while let Some(output) = output.recv().await { match output { @@ -246,6 +255,7 @@ async fn setup_audio_playback( } } let _ = cmd_tx.send(AudioCommand::Stop); + // TODO: this may block! 
let _ = playback_thread.join(); Ok(()) } From 830399cc20bdc1c2d37072baf57750e2c23ff1a4 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 09:00:11 +0200 Subject: [PATCH 15/20] Review --- examples/dialog.rs | 1 + services/google-dialog/src/client.rs | 18 ++++++++---------- services/google-dialog/src/lib.rs | 3 +-- 3 files changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/dialog.rs b/examples/dialog.rs index 6c69dad..f1b07c1 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -301,6 +301,7 @@ struct FunctionCall { name: String, arguments: Option, } + fn get_time_parameters_schema() -> serde_json::Value { json!({ "type": "object", diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index 331731d..b3c18c8 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -1,15 +1,13 @@ use anyhow::{Context, Result, bail}; -use gemini_live::{ - ReconnectPolicy, Session, SessionConfig, - transport::{Auth, Endpoint, TransportConfig}, - types::{ - AudioTranscriptionConfig, Content, ContextWindowCompressionConfig, FunctionDeclaration, - FunctionResponse, GenerationConfig, Modality, ModalityTokenCount, Part, - PrebuiltVoiceConfig, ServerEvent, SessionResumptionConfig, SetupConfig, SlidingWindow, - SpeechConfig, ThinkingConfig, Tool, UsageMetadata, VoiceConfig, - }, +use gemini_live::transport::{Auth, Endpoint, TransportConfig}; +use gemini_live::types::{ + AudioTranscriptionConfig, Content, ContextWindowCompressionConfig, FunctionDeclaration, + FunctionResponse, GenerationConfig, Modality, ModalityTokenCount, Part, PrebuiltVoiceConfig, + ServerEvent, SessionResumptionConfig, SetupConfig, SlidingWindow, SpeechConfig, ThinkingConfig, + Tool, UsageMetadata, VoiceConfig, }; +use gemini_live::{ReconnectPolicy, Session, SessionConfig}; use tracing::{debug, info, trace}; use crate::{Params, ServiceInputEvent, ServiceOutputEvent}; @@ -18,6 +16,7 @@ use context_switch_core::{ 
Input, OutputPath, }; +#[derive(Debug)] pub struct Client { params: Params, } @@ -35,7 +34,6 @@ impl Client { pub async fn dialog( self, - _input_format: AudioFormat, output_format: AudioFormat, text_output_enabled: bool, interim_text_output_enabled: bool, diff --git a/services/google-dialog/src/lib.rs b/services/google-dialog/src/lib.rs index 333a8ef..3346eaf 100644 --- a/services/google-dialog/src/lib.rs +++ b/services/google-dialog/src/lib.rs @@ -20,7 +20,7 @@ impl Service for GoogleDialog { type Params = Params; async fn conversation(&self, params: Params, conversation: Conversation) -> Result<()> { - let input_format = conversation.require_audio_input()?; + let _input_format = conversation.require_audio_input()?; let output_format = conversation.require_one_audio_output()?; let text_outputs = TextOutputs::from_modalities(&conversation.output_modalities)?; @@ -37,7 +37,6 @@ impl Service for GoogleDialog { let (input, output) = conversation.start()?; Client::new(params) .dialog( - input_format, output_format, text_outputs.text, text_outputs.interim, From ca88628a3e75427d001d824e6234674d606c1159 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 09:11:36 +0200 Subject: [PATCH 16/20] Fix compilation errors --- audio-knife/src/main.rs | 4 +-- examples/aristech-synthesize.rs | 23 ++++++--------- examples/aristech-transcribe.rs | 8 ++---- examples/azure-translate.rs | 21 +++++--------- examples/dialog.rs | 21 +++++--------- examples/dialog_providers/azure_openai.rs | 2 +- examples/dialog_providers/google.rs | 6 ++-- examples/dialog_providers/mod.rs | 3 +- examples/dialog_providers/openai.rs | 6 ++-- examples/transcribe.rs | 3 +- services/google-dialog/src/client.rs | 35 ++++++++--------------- services/google-dialog/src/lib.rs | 10 ++----- src/context_switch.rs | 3 +- src/protocol.rs | 3 +- src/tests.rs | 5 ++-- 15 files changed, 54 insertions(+), 99 deletions(-) diff --git a/audio-knife/src/main.rs b/audio-knife/src/main.rs index e36d385..d584a15 
100644 --- a/audio-knife/src/main.rs +++ b/audio-knife/src/main.rs @@ -37,8 +37,8 @@ use tokio::{ use tracing::{Instrument, Span, debug, error, info, info_span}; use context_switch::{ - AudioFormat, AudioFrame, ClientEvent, ContextSwitch, ConversationId, InputModality, - ServerEvent, audio, billing_collector::BillingCollector, conversation::BillingId, + AudioFormat, AudioFrame, BillingId, ClientEvent, ContextSwitch, ConversationId, InputModality, + ServerEvent, audio, billing_collector::BillingCollector, }; use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan}; use uuid::Uuid; diff --git a/examples/aristech-synthesize.rs b/examples/aristech-synthesize.rs index d57b7b1..206338b 100644 --- a/examples/aristech-synthesize.rs +++ b/examples/aristech-synthesize.rs @@ -1,23 +1,18 @@ -use std::{ - env, - num::{NonZeroU16, NonZeroU32}, - thread, - time::Duration, -}; +use std::env; +use std::num::{NonZeroU16, NonZeroU32}; +use std::thread; +use std::time::Duration; use anyhow::{Context as AnyhowContext, Result}; +use aristech::synthesize::{AristechSynthesize, Params as AristechParams}; use rodio::{DeviceSinkBuilder, Player, Source}; -use tokio::{ - select, - sync::mpsc::{channel, unbounded_channel}, -}; +use tokio::select; +use tokio::sync::mpsc::{channel, unbounded_channel}; -use aristech::synthesize::{AristechSynthesize, Params as AristechParams}; use context_switch::{InputModality, OutputModality}; +use context_switch_core::service::Service; use context_switch_core::{ - AudioFormat, AudioFrame, AudioProducer, audio, - conversation::{Conversation, Input, Output}, - service::Service, + AudioFormat, AudioFrame, AudioProducer, Conversation, Input, Output, audio, }; const SAMPLE_TEXT: &str = "Hallo! Dies ist eine Demonstration des Aristech Text-zu-Sprache-Dienstes. 
\ diff --git a/examples/aristech-transcribe.rs b/examples/aristech-transcribe.rs index af7a79e..c1144c1 100644 --- a/examples/aristech-transcribe.rs +++ b/examples/aristech-transcribe.rs @@ -9,12 +9,9 @@ use tokio::{ }; use aristech::transcribe::{ApiKeyAuth, AuthConfig, CredentialsAuth, Params as AristechParams}; + use context_switch::{InputModality, OutputModality, services::AristechTranscribe}; -use context_switch_core::{ - AudioFormat, AudioFrame, audio, - conversation::{Conversation, Input}, - service::Service, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, audio, service::Service}; #[tokio::main] async fn main() -> Result<()> { @@ -28,6 +25,7 @@ async fn main() -> Result<()> { let device = host .default_input_device() .expect("Failed to get default input device"); + // spellcheck: ignore // let config = device // .default_input_config() // .expect("Failed to get default input config"); diff --git a/examples/azure-translate.rs b/examples/azure-translate.rs index 1a7e6dd..962fff4 100644 --- a/examples/azure-translate.rs +++ b/examples/azure-translate.rs @@ -1,26 +1,19 @@ //! A context switch demo. Runs locally, gets voice data from your current microphone. 
-use std::{ - env, - num::{NonZeroU16, NonZeroU32}, - thread, - time::Duration, -}; +use std::env; +use std::num::{NonZeroU16, NonZeroU32}; +use std::thread; +use std::time::Duration; use anyhow::Result; use azure::AzureTranslate; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use rodio::{DeviceSinkBuilder, Player, Source}; +use tokio::select; +use tokio::sync::mpsc::{UnboundedReceiver, channel, unbounded_channel}; use context_switch::{InputModality, OutputModality}; -use context_switch_core::{ - AudioFormat, AudioFrame, Service, audio, - conversation::{Conversation, Input, Output}, -}; -use tokio::{ - select, - sync::mpsc::{UnboundedReceiver, channel, unbounded_channel}, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, Output, Service, audio}; #[tokio::main] async fn main() -> Result<()> { diff --git a/examples/dialog.rs b/examples/dialog.rs index f1b07c1..2eb1e59 100644 --- a/examples/dialog.rs +++ b/examples/dialog.rs @@ -1,11 +1,9 @@ //! A context switch demo. Runs locally, gets voice data from your current microphone. 
-use std::{ - num::{NonZeroU16, NonZeroU32}, - str::FromStr, - thread, - time::Duration, -}; +use std::num::{NonZeroU16, NonZeroU32}; +use std::str::FromStr; +use std::thread; +use std::time::Duration; use anyhow::{Context, Result, bail}; use chrono::Utc; @@ -14,16 +12,11 @@ use context_switch::{InputModality, OutputModality}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use rodio::{DeviceSinkBuilder, Player, Source}; use serde_json::json; -use tokio::{ - select, - sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}, -}; +use tokio::select; +use tokio::sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}; use tracing::info; -use context_switch_core::{ - AudioFormat, AudioFrame, audio, - conversation::{Conversation, Input, Output}, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, Output, audio}; mod dialog_providers; diff --git a/examples/dialog_providers/azure_openai.rs b/examples/dialog_providers/azure_openai.rs index 588a7b8..e0b2a61 100644 --- a/examples/dialog_providers/azure_openai.rs +++ b/examples/dialog_providers/azure_openai.rs @@ -6,7 +6,7 @@ use openai_api_rs::realtime::types as openai_types; use openai_dialog::{OpenAIDialog, Protocol}; use strum::VariantNames; -use context_switch_core::{AudioFormat, Service, conversation::Conversation}; +use context_switch_core::{AudioFormat, Conversation, Service}; use super::openai::{self, OpenAIProvider}; use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs index 4a2e824..b1af97e 100644 --- a/examples/dialog_providers/google.rs +++ b/examples/dialog_providers/google.rs @@ -8,11 +8,9 @@ use reqwest::Url; use serde::Deserialize; use serde_json::json; -use crate::{FunctionCall, get_time_parameters_schema}; -use context_switch_core::conversation::Conversation; -use context_switch_core::{AudioFormat, Service}; - use super::{ListModelsRequest, 
ProviderApi, StartConversationRequest}; +use crate::{FunctionCall, get_time_parameters_schema}; +use context_switch_core::{AudioFormat, Conversation, Service}; pub struct GoogleProvider; diff --git a/examples/dialog_providers/mod.rs b/examples/dialog_providers/mod.rs index 42b1c74..a62456f 100644 --- a/examples/dialog_providers/mod.rs +++ b/examples/dialog_providers/mod.rs @@ -2,8 +2,7 @@ use anyhow::Result; use async_trait::async_trait; use crate::{FunctionCall, Provider}; -use context_switch_core::AudioFormat; -use context_switch_core::conversation::Conversation; +use context_switch_core::{AudioFormat, Conversation}; #[derive(Debug, Clone)] pub struct ListModelsRequest { diff --git a/examples/dialog_providers/openai.rs b/examples/dialog_providers/openai.rs index 04119f2..4d543be 100644 --- a/examples/dialog_providers/openai.rs +++ b/examples/dialog_providers/openai.rs @@ -12,11 +12,9 @@ use serde::Deserialize; use serde_json::json; use strum::VariantNames; -use crate::{FunctionCall, get_time_parameters_schema}; -use context_switch_core::conversation::Conversation; -use context_switch_core::{AudioFormat, Service}; - use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; +use crate::{FunctionCall, get_time_parameters_schema}; +use context_switch_core::{AudioFormat, Conversation, Service}; pub struct OpenAIProvider; diff --git a/examples/transcribe.rs b/examples/transcribe.rs index d4cda99..480c849 100644 --- a/examples/transcribe.rs +++ b/examples/transcribe.rs @@ -13,10 +13,9 @@ use context_switch::services::{ AristechTranscribe, AzureTranscribe, ElevenLabsTranscribe, GoogleTranscribe, }; use context_switch::{AudioConsumer, InputModality, OutputModality}; -use context_switch_core::conversation::{Conversation, Input}; use context_switch_core::language::Languages; use context_switch_core::service::Service; -use context_switch_core::{AudioFormat, AudioFrame, audio}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, audio}; const 
DEFAULT_LANGUAGE: &str = "en-US"; diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index b3c18c8..316bc69 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -10,7 +10,7 @@ use gemini_live::types::{ use gemini_live::{ReconnectPolicy, Session, SessionConfig}; use tracing::{debug, info, trace}; -use crate::{Params, ServiceInputEvent, ServiceOutputEvent}; +use crate::{Params, ServiceInputEvent, ServiceOutputEvent, TextOutputs}; use context_switch_core::{ AudioFormat, AudioFrame, BillingRecord, BillingSchedule, ConversationInput, ConversationOutput, Input, OutputPath, @@ -21,12 +21,6 @@ pub struct Client { params: Params, } -#[derive(Debug, Clone, Copy)] -struct TextOutputConfig { - text: bool, - interim: bool, -} - impl Client { pub fn new(params: Params) -> Self { Self { params } @@ -35,19 +29,14 @@ impl Client { pub async fn dialog( self, output_format: AudioFormat, - text_output_enabled: bool, - interim_text_output_enabled: bool, + text_outputs: TextOutputs, mut input: ConversationInput, output: ConversationOutput, ) -> Result<()> { - let text_output = TextOutputConfig { - text: text_output_enabled, - interim: interim_text_output_enabled, - }; let billing_scope = self.params.model.clone(); let tools = function_declarations(&self.params.tools); let mut output_transcription_buffer = String::new(); - let mut session = Session::connect(self.session_config(text_output)?) + let mut session = Session::connect(self.session_config(text_outputs)?) .await .context("Connecting to Gemini Live")?; output.service_event( @@ -68,7 +57,7 @@ impl Client { event = session.next_event() => { match event { Some(event) => { - match self.process_event(event, output_format, text_output, &output, &billing_scope, &mut output_transcription_buffer).await? { + match self.process_event(event, output_format, text_outputs, &output, &billing_scope, &mut output_transcription_buffer).await? 
{ FlowControl::Continue => {} FlowControl::End => break, } @@ -86,7 +75,7 @@ impl Client { Ok(()) } - fn session_config(&self, text_output: TextOutputConfig) -> Result { + fn session_config(&self, text_outputs: TextOutputs) -> Result { let transport = TransportConfig { endpoint: self .params @@ -100,12 +89,12 @@ impl Client { Ok(SessionConfig { transport, - setup: self.setup_config(text_output)?, + setup: self.setup_config(text_outputs)?, reconnect: ReconnectPolicy::default(), }) } - fn setup_config(&self, text_output: TextOutputConfig) -> Result { + fn setup_config(&self, text_outputs: TextOutputs) -> Result { let input_audio_transcription = self .params .input_audio_transcription @@ -115,7 +104,7 @@ impl Client { .output_audio_transcription .then_some(AudioTranscriptionConfig {}); - if !(text_output.text || text_output.interim) + if !(text_outputs.text || text_outputs.interim) && (input_audio_transcription.is_some() || output_audio_transcription.is_some()) { bail!( @@ -212,7 +201,7 @@ impl Client { &self, event: ServerEvent, output_format: AudioFormat, - text_output: TextOutputConfig, + text_outputs: TextOutputs, output: &ConversationOutput, billing_scope: &str, output_transcription_buffer: &mut String, @@ -230,7 +219,7 @@ impl Client { } ServerEvent::GenerationComplete => {} ServerEvent::TurnComplete => { - if text_output.text && !output_transcription_buffer.is_empty() { + if text_outputs.text && !output_transcription_buffer.is_empty() { output.text( true, std::mem::take(output_transcription_buffer), @@ -247,13 +236,13 @@ impl Client { output.clear_audio()?; } ServerEvent::InputTranscription(text) => { - if text_output.text { + if text_outputs.text { output.text(true, text, None, None)?; } } ServerEvent::OutputTranscription(text) => { output_transcription_buffer.push_str(&text); - if text_output.interim { + if text_outputs.interim { output.text( false, output_transcription_buffer.clone(), diff --git a/services/google-dialog/src/lib.rs 
b/services/google-dialog/src/lib.rs index 3346eaf..45292fc 100644 --- a/services/google-dialog/src/lib.rs +++ b/services/google-dialog/src/lib.rs @@ -9,7 +9,7 @@ use context_switch_core::{AudioFormat, Conversation, OutputModality, Service}; mod client; mod types; -pub use client::Client; +use client::Client; pub use types::{Params, ServiceInputEvent, ServiceOutputEvent}; #[derive(Debug)] @@ -36,13 +36,7 @@ impl Service for GoogleDialog { info!(model = %params.model, "Connecting to Gemini Live"); let (input, output) = conversation.start()?; Client::new(params) - .dialog( - output_format, - text_outputs.text, - text_outputs.interim, - input, - output, - ) + .dialog(output_format, text_outputs, input, output) .await } } diff --git a/src/context_switch.rs b/src/context_switch.rs index 181d82d..f234b05 100644 --- a/src/context_switch.rs +++ b/src/context_switch.rs @@ -14,8 +14,7 @@ use tracing_futures::Instrument; use crate::{AudioTracer, ClientEvent, ConversationId, InputModality, ServerEvent}; use context_switch_core::billing_collector::BillingCollector; -use context_switch_core::conversation::{Conversation, Input, Output}; -use context_switch_core::{AudioFrame, BillingContext, Registry}; +use context_switch_core::{AudioFrame, BillingContext, Conversation, Input, Output, Registry}; #[derive(Debug)] pub struct ContextSwitch { diff --git a/src/protocol.rs b/src/protocol.rs index 3dbe8fa..a3f555d 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -3,8 +3,7 @@ use derive_more::derive::{Deref, Display, From, Into}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use context_switch_core::{ - BillingRecord, InputModality, OutputModality, OutputPath, audio, - conversation::{BillingId, RequestId}, + BillingId, BillingRecord, InputModality, OutputModality, OutputPath, RequestId, audio, }; /// Conversation identifier. 
diff --git a/src/tests.rs b/src/tests.rs index 807921d..8302536 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -103,9 +103,10 @@ mod helper { use anyhow::Result; use async_trait::async_trait; - use tokio::{sync::mpsc::Sender, time}; + use tokio::sync::mpsc::Sender; + use tokio::time; - use context_switch_core::{Service, conversation::Conversation}; + use context_switch_core::{Conversation, Service}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Notification { From 0c5e4a6d7e446b21463e3e591674ae4b934b09c0 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 09:25:48 +0200 Subject: [PATCH 17/20] google-dialog: Review log output --- services/google-dialog/src/client.rs | 33 ++++++++++++++++------------ src/context_switch.rs | 2 +- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index 316bc69..1084777 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -50,28 +50,34 @@ impl Client { if let Some(input) = input { self.process_input(&session, input).await?; } else { + debug!("Conversation input closed"); session.audio_stream_end().await.context("Ending Gemini audio stream")?; + debug!("Audio stream end sent to Gemini"); break; } } event = session.next_event() => { match event { Some(event) => { - match self.process_event(event, output_format, text_outputs, &output, &billing_scope, &mut output_transcription_buffer).await? { + match self.process_event(event, output_format, text_outputs, &output, &billing_scope, &mut output_transcription_buffer)? 
{ FlowControl::Continue => {} - FlowControl::End => break, + FlowControl::End => { + debug!("Received terminal server event"); + break; + } } } - None => break, + None => { + debug!("Server event stream ended"); + break; + } } } } } - session - .close() - .await - .context("Closing Gemini Live session")?; + debug!("Closing session"); + session.close().await.context("Closing session")?; Ok(()) } @@ -183,21 +189,18 @@ impl Client { session .send_tool_response(vec![response]) .await - .context("Sending Gemini tool response")?; + .context("Sending tool response")?; } ServiceInputEvent::Prompt { text } => { info!("Received prompt"); - session - .send_text(&text) - .await - .context("Sending prompt to Gemini Live")?; + session.send_text(&text).await.context("Sending prompt")?; } }, } Ok(()) } - async fn process_event( + fn process_event( &self, event: ServerEvent, output_format: AudioFormat, @@ -278,7 +281,9 @@ impl Client { } ServerEvent::Closed { reason } => { if !reason.is_empty() { - debug!(%reason, "Gemini Live connection closed"); + debug!(%reason, "Endpoint signaled connection closure"); + } else { + debug!("Endpoint signaled connection closure without a reason"); } return Ok(FlowControl::End); } diff --git a/src/context_switch.rs b/src/context_switch.rs index f234b05..611dde3 100644 --- a/src/context_switch.rs +++ b/src/context_switch.rs @@ -225,7 +225,7 @@ async fn process_conversation_protected( let service = registry.service(&service_name)?; // Temporarily use an unbounded channel for output forwarding because we may process rather - // large audio files (local playback for example) in one go are are not yet able to block sends. + // large audio files (local playback for example) in one go are not yet able to block sends. let (output_sender, mut output_receiver) = unbounded_channel(); // We might receive a large number of audio frames before the service can process them. 
let (input_sender, input_receiver) = channel(256); From 76dbe374fd8b30ec543043513d92dcd97d722a57 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 10:01:21 +0200 Subject: [PATCH 18/20] Clarify why functions calls are sent through the media path and make cancellation symmetric to the function call invocation --- examples/dialog_providers/google.rs | 8 ++++---- services/google-dialog/src/client.rs | 22 +++++++++++++++++----- 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs index b1af97e..47248f6 100644 --- a/examples/dialog_providers/google.rs +++ b/examples/dialog_providers/google.rs @@ -47,7 +47,7 @@ impl ProviderApi for GoogleProvider { fn parse_service_event(&self, value: serde_json::Value) -> Result> { match serde_json::from_value(value)? { - google_dialog::ServiceOutputEvent::FunctionCall { + ServiceOutputEvent::FunctionCall { name, call_id, arguments, @@ -56,11 +56,11 @@ impl ProviderApi for GoogleProvider { call_id, arguments: Some(arguments), })), - google_dialog::ServiceOutputEvent::ToolCallCancellation { call_ids } => { - tracing::info!("Tool calls cancelled: {call_ids:?}"); + ServiceOutputEvent::ToolCallCancellation { call_id } => { + tracing::info!("Tool call cancelled: {call_id}"); Ok(None) } - google_dialog::ServiceOutputEvent::SessionUpdated { tools } => { + ServiceOutputEvent::SessionUpdated { tools } => { tracing::info!("Session updated: {tools:?}"); Ok(None) } diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index 1084777..4ec1527 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -256,6 +256,14 @@ impl Client { } ServerEvent::ToolCall(calls) => { for call in calls { + // Send the function call via the media path. + // + // This means that audio scheduled before will finish playing before + // the client receives the event to execute the function call. 
+ // + // For example, if we use a prompt to initiate a function call, it + // might overtake currently pending audio output and an answer + // before the participant even heard the audio. output.service_event( OutputPath::Media, ServiceOutputEvent::FunctionCall { @@ -267,14 +275,18 @@ impl Client { } } ServerEvent::ToolCallCancellation(ids) => { - output.service_event( - OutputPath::Control, - ServiceOutputEvent::ToolCallCancellation { call_ids: ids }, - )?; + // Since we are sending function calls through the media path, we need to send + // cancellations too, so that they don't overtake. + for id in ids { + output.service_event( + OutputPath::Media, + ServiceOutputEvent::ToolCallCancellation { call_id: id }, + )?; + } } ServerEvent::SessionResumption { .. } => {} ServerEvent::GoAway { time_left } => { - debug!(?time_left, "Gemini Live goAway received"); + debug!(?time_left, "GoAway received"); } ServerEvent::Usage(usage) => { bill_usage(output, billing_scope, usage)?; From ce614fa7390d04fac9a6417778847a0cc9c7b03e Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 10:02:09 +0200 Subject: [PATCH 19/20] Reorder functions, document Gemini Params --- .harper-dictionary.txt | 1 + Cargo.toml | 25 ++-- examples/dialog_providers/google.rs | 4 +- services/google-dialog/src/client.rs | 178 +++++++++++++-------------- services/google-dialog/src/types.rs | 11 +- src/context_switch.rs | 2 +- 6 files changed, 110 insertions(+), 111 deletions(-) diff --git a/.harper-dictionary.txt b/.harper-dictionary.txt index fd2abbd..17c9691 100644 --- a/.harper-dictionary.txt +++ b/.harper-dictionary.txt @@ -2,6 +2,7 @@ AirPods BCP ContextSwitch FreeSWITCH +alsa inband seekable subtag diff --git a/Cargo.toml b/Cargo.toml index a714fe2..889f0ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -83,7 +83,7 @@ gemini-live = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread"] } reqwest = { workspace = true } -# For advanced params in openai-dialog +# For 
advanced parameters in `openai-dialog` openai-api-rs = { workspace = true } serde_json = { workspace = true } chrono-tz = { version = "0.10.3" } @@ -110,21 +110,21 @@ gemini-live = { path = "external/gemini-live-rs/crates/gemini-live" } # The submodule crate inherits these via `workspace = true`, so we keep them # centralized here and grouped to make future sync/review straightforward. tokio-tungstenite = { version = "0.29", features = ["rustls-tls-webpki-roots"] } -futures-util = "0.3" +futures-util = "0.3.32" bytes = "1.11" -thiserror = "2.0" +thiserror = "2.0.18" rustls = { version = "0.23", features = ["ring"], default-features = false } -google-cloud-auth = { version = "1.9.0", default-features = false } +google-cloud-auth = { version = "1.10.0", default-features = false } anyhow = "1.0.102" derive_more = { version = "2.1.1", features = ["full"] } static_assertions = "1.1.0" async-stream = { version = "0.3.6" } # Tokio features are intentionally explicit: -# - sync: channels/mutexes used throughout services -# - rt + macros: runtime and `tokio::select!`/task macros used by gemini-live -# - time: timeout/sleep used by gemini-live session/transport logic -tokio = { version = "1.50.0", features = ["sync", "rt", "macros", "time"] } +# - `sync`: channels/mutexes used throughout services +# - `rt` + `macros`: runtime and `tokio::select!`/task macros used by `gemini-live` +# - `time`: timeout/sleep used by `gemini-live` session/transport logic +tokio = { version = "1.52.3", features = ["sync", "rt", "macros", "time"] } futures = "0.3.31" serde = { version = "1.0.215", features = ["derive"] } serde_json = "1.0.149" @@ -134,7 +134,7 @@ async-trait = "0.1.83" tracing = "0.1.41" dotenvy = { version = "0.15.7" } url = { version = "2.5.8" } -reqwest = { version = "0.13.2" } +reqwest = { version = "0.13.3" } mime_guess2 = { version = "2.3.1" } hound = { version = "3.5.1" } chrono = { version = "0.4.44" } @@ -144,13 +144,12 @@ chrono = { version = "0.4.44" } # azure-speech = 
{ path = "external/azure-speech-sdk-rs" } -# openai-api-rs = "5.2.3" openai-api-rs = { path = "external/openai-api-rs" } -# - `symphonia-wav` is mandatory: The default WAV decoder does not seem to support A-Law and also -# panics with a few of our testcases. +# - `symphonia-wav` is mandatory: The default WAV decoder does not seem to support A-Law and +# panics with a few of our test cases. # - No default features because we don't want to pull alsa on Linux by default for local playback. # - We have to define at least _one_ decoder, otherwise `cargo clippy --all-targets` fails, so we select `symphonia-mp3`. rodio = { version = "0.22.2", default-features = false, features = ["symphonia-mp3"] } rstest = { version = "0.26.1" } -uuid = { version = "1.17.0", features = ["v4"] } +uuid = { version = "1.23.1", features = ["v4"] } diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs index 47248f6..a2b4ee1 100644 --- a/examples/dialog_providers/google.rs +++ b/examples/dialog_providers/google.rs @@ -3,7 +3,9 @@ use std::env; use anyhow::{Context, Result, bail}; use async_trait::async_trait; use gemini_live::types as gemini_types; -use google_dialog::{GoogleDialog, ServiceInputEvent as GoogleServiceInputEvent}; +use google_dialog::{ + GoogleDialog, ServiceInputEvent as GoogleServiceInputEvent, ServiceOutputEvent, +}; use reqwest::Url; use serde::Deserialize; use serde_json::json; diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs index 4ec1527..62a8325 100644 --- a/services/google-dialog/src/client.rs +++ b/services/google-dialog/src/client.rs @@ -36,7 +36,7 @@ impl Client { let billing_scope = self.params.model.clone(); let tools = function_declarations(&self.params.tools); let mut output_transcription_buffer = String::new(); - let mut session = Session::connect(self.session_config(text_outputs)?) + let mut session = Session::connect(session_config(&self.params, text_outputs)?) 
.await .context("Connecting to Gemini Live")?; output.service_event( @@ -81,83 +81,6 @@ impl Client { Ok(()) } - fn session_config(&self, text_outputs: TextOutputs) -> Result { - let transport = TransportConfig { - endpoint: self - .params - .host - .clone() - .map(Endpoint::Custom) - .unwrap_or_default(), - auth: Auth::ApiKey(self.params.api_key.clone()), - ..Default::default() - }; - - Ok(SessionConfig { - transport, - setup: self.setup_config(text_outputs)?, - reconnect: ReconnectPolicy::default(), - }) - } - - fn setup_config(&self, text_outputs: TextOutputs) -> Result { - let input_audio_transcription = self - .params - .input_audio_transcription - .then_some(AudioTranscriptionConfig {}); - let output_audio_transcription = self - .params - .output_audio_transcription - .then_some(AudioTranscriptionConfig {}); - - if !(text_outputs.text || text_outputs.interim) - && (input_audio_transcription.is_some() || output_audio_transcription.is_some()) - { - bail!( - "Google dialog requires text output modality when transcription is enabled: if inputAudioTranscription or outputAudioTranscription is set, add OutputModality::Text or OutputModality::InterimText to the conversation output modalities, or set both transcription flags to false." - ); - } - - Ok(SetupConfig { - model: model_resource_name(&self.params.model), - generation_config: Some(GenerationConfig { - // NOTE: Enabling Modality::Text here currently causes a Gemini setup-time - // "Internal error encountered." in this service flow. 
- response_modalities: Some(vec![Modality::Audio]), - speech_config: self.params.voice.clone().map(|voice_name| SpeechConfig { - voice_config: VoiceConfig { - prebuilt_voice_config: PrebuiltVoiceConfig { voice_name }, - }, - }), - thinking_config: self - .params - .thinking_level - .map(|thinking_level| ThinkingConfig { - thinking_level: Some(thinking_level), - ..Default::default() - }), - temperature: self.params.temperature, - ..Default::default() - }), - system_instruction: self.params.instructions.clone().map(system_instruction), - tools: (!self.params.tools.is_empty()).then(|| self.params.tools.clone()), - realtime_input_config: self.params.realtime_input_config.clone(), - // Opt in so Gemini sends resume handles. The session layer stores - // the latest handle and patches it into reconnect setup messages, - // keeping context across GoAway-triggered reconnects. - session_resumption: Some(SessionResumptionConfig::default()), - context_window_compression: self.params.context_window_compression.then_some( - ContextWindowCompressionConfig { - sliding_window: Some(SlidingWindow::default()), - ..Default::default() - }, - ), - input_audio_transcription, - output_audio_transcription, - ..Default::default() - }) - } - async fn process_input(&self, session: &Session, input: Input) -> Result<()> { match input { Input::Audio { frame } => { @@ -307,6 +230,91 @@ impl Client { } } +fn function_declarations(tools: &[Tool]) -> Option> { + let declarations: Vec<_> = tools + .iter() + .filter_map(|tool| match tool { + Tool::FunctionDeclarations(declarations) => Some(declarations.as_slice()), + Tool::GoogleSearch(_) => None, + }) + .flatten() + .cloned() + .collect(); + + (!declarations.is_empty()).then_some(declarations) +} + +fn session_config(params: &Params, text_outputs: TextOutputs) -> Result { + let transport = TransportConfig { + endpoint: params + .host + .clone() + .map(Endpoint::Custom) + .unwrap_or_default(), + auth: Auth::ApiKey(params.api_key.clone()), + 
..Default::default() + }; + + Ok(SessionConfig { + transport, + setup: setup_config(params, text_outputs)?, + reconnect: ReconnectPolicy::default(), + }) +} + +fn setup_config(params: &Params, text_outputs: TextOutputs) -> Result { + let input_audio_transcription = params + .input_audio_transcription + .then_some(AudioTranscriptionConfig {}); + let output_audio_transcription = params + .output_audio_transcription + .then_some(AudioTranscriptionConfig {}); + + if !(text_outputs.text || text_outputs.interim) + && (input_audio_transcription.is_some() || output_audio_transcription.is_some()) + { + bail!( + "Google dialog requires text output modality when transcription is enabled: if inputAudioTranscription or outputAudioTranscription is set, add OutputModality::Text or OutputModality::InterimText to the conversation output modalities, or set both transcription flags to false." + ); + } + + Ok(SetupConfig { + model: model_resource_name(¶ms.model), + generation_config: Some(GenerationConfig { + // NOTE: Enabling Modality::Text here currently causes a Gemini setup-time + // "Internal error encountered." in this service flow. + response_modalities: Some(vec![Modality::Audio]), + speech_config: params.voice.clone().map(|voice_name| SpeechConfig { + voice_config: VoiceConfig { + prebuilt_voice_config: PrebuiltVoiceConfig { voice_name }, + }, + }), + thinking_config: params.thinking_level.map(|thinking_level| ThinkingConfig { + thinking_level: Some(thinking_level), + ..Default::default() + }), + temperature: params.temperature, + ..Default::default() + }), + system_instruction: params.instructions.clone().map(system_instruction), + tools: (!params.tools.is_empty()).then(|| params.tools.clone()), + realtime_input_config: params.realtime_input_config.clone(), + // Opt in so Gemini sends resume handles. The session layer stores + // the latest handle and patches it into reconnect setup messages, + // keeping context across GoAway-triggered reconnects. 
+ session_resumption: Some(SessionResumptionConfig::default()), + context_window_compression: params.context_window_compression.then_some( + ContextWindowCompressionConfig { + sliding_window: Some(SlidingWindow::default()), + ..Default::default() + }, + ), + input_audio_transcription, + output_audio_transcription, + ..Default::default() + }) +} + fn model_resource_name(model: &str) -> String { if model.starts_with("models/") { model.to_owned() @@ -325,20 +333,6 @@ fn system_instruction(text: String) -> Content { } } -fn function_declarations(tools: &[Tool]) -> Option> { - let declarations: Vec<_> = tools - .iter() - .filter_map(|tool| match tool { - Tool::FunctionDeclarations(declarations) => Some(declarations.as_slice()), - Tool::GoogleSearch(_) => None, - }) - .flatten() - .cloned() - .collect(); - - (!declarations.is_empty()).then_some(declarations) -} - fn bill_usage( output: &ConversationOutput, billing_scope: &str, diff --git a/services/google-dialog/src/types.rs b/services/google-dialog/src/types.rs index fa4494f..6ee5e30 100644 --- a/services/google-dialog/src/types.rs +++ b/services/google-dialog/src/types.rs @@ -11,11 +11,14 @@ pub struct Params { pub instructions: Option, pub voice: Option, + /// Sampling temperature. Valid range: `0.0..=2.0`. + /// If omitted, Gemini uses the model-specific default temperature. pub temperature: Option, /// Gemini 3.1 thinking level (`minimal`, `low`, `medium`, or `high`). + /// In Live API, Gemini 3.1 defaults to `minimal` when omitted. pub thinking_level: Option, /// Enabled by default to avoid context-window exhaustion during long audio sessions. 
- #[serde(default = "default_true")] + #[serde(default = "default_context_window_compression")] pub context_window_compression: bool, #[serde(default)] pub tools: Vec, @@ -39,7 +42,7 @@ impl Params { voice: None, temperature: None, thinking_level: None, - context_window_compression: true, + context_window_compression: default_context_window_compression(), tools: vec![], realtime_input_config: None, input_audio_transcription: false, @@ -48,7 +51,7 @@ impl Params { } } -fn default_true() -> bool { +fn default_context_window_compression() -> bool { true } @@ -76,7 +79,7 @@ pub enum ServiceOutputEvent { arguments: serde_json::Value, }, #[serde(rename_all = "camelCase")] - ToolCallCancellation { call_ids: Vec }, + ToolCallCancellation { call_id: String }, SessionUpdated { #[serde(skip_serializing_if = "Option::is_none")] tools: Option>, diff --git a/src/context_switch.rs b/src/context_switch.rs index 611dde3..44ee848 100644 --- a/src/context_switch.rs +++ b/src/context_switch.rs @@ -41,9 +41,9 @@ pub fn registry() -> Registry { .add_service("azure-synthesize", azure::AzureSynthesize) .add_service("azure-translate", azure::AzureTranslate) .add_service("elevenlabs-transcribe", elevenlabs::ElevenLabsTranscribe) - .add_service("google-dialog", google_dialog::GoogleDialog) .add_service("google-transcribe", google_transcribe::GoogleTranscribe) .add_service("openai-dialog", openai_dialog::OpenAIDialog) + .add_service("google-dialog", google_dialog::GoogleDialog) .add_service("aristech-transcribe", aristech::AristechTranscribe) .add_service("aristech-synthesize", aristech::AristechSynthesize) } From 80c37fd12fcaf453cf2c6c74153dc7e9c3e11298 Mon Sep 17 00:00:00 2001 From: Armin Sander Date: Wed, 20 May 2026 10:15:46 +0200 Subject: [PATCH 20/20] Fix typo --- src/context_switch.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/context_switch.rs b/src/context_switch.rs index 44ee848..a2abd40 100644 --- a/src/context_switch.rs +++ b/src/context_switch.rs @@ 
-225,7 +225,7 @@ async fn process_conversation_protected( let service = registry.service(&service_name)?; // Temporarily use an unbounded channel for output forwarding because we may process rather - // large audio files (local playback for example) in one go are not yet able to block sends. + // large audio files (local playback for example) in one go and are not able to block sends. let (output_sender, mut output_receiver) = unbounded_channel(); // We might receive a large number of audio frames before the service can process them. let (input_sender, input_receiver) = channel(256);