diff --git a/.env.example b/.env.example index 64f9267..3e62a67 100644 --- a/.env.example +++ b/.env.example @@ -1,3 +1,7 @@ +# Gemini + +GEMINI_API_KEY=your_gemini_api_key + # OpenAI Configuration OPENAI_API_KEY=your_openai_key OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md index f3a2f32..f6fa600 100644 --- a/.github/copilot-instructions.md +++ b/.github/copilot-instructions.md @@ -3,9 +3,13 @@ ## Rust Style - Prefer imports over deeply-qualified module paths. As a rule of thumb, avoid using more than one module prefix inline (for example, prefer importing a type and using `TypeName` instead of writing `foo::bar::TypeName` repeatedly). - Prefer high-level flow first: when practical, place local supporting definitions (for example helper structs, impls, functions, and type aliases) below their first use. +- In module files, keep definitions ordered top-down by call flow (entry points first, helpers after first use). - Keep imports grouped and sorted to match existing file style. - Avoid `maybe_` prefixes for optional variables; use neutral names and rely on type/context for optionality. - Avoid `_ref` suffixes for local variable names; use descriptive neutral names instead. +- Prefer explicit imports over repeated relative module qualification. +- Prefer private-by-default visibility; only widen visibility when required by a module boundary. +- For trait-based APIs, prefer focused request/context types over passing broad configuration structs. ## Change Communication - Include a short rationale for each non-trivial code change. diff --git a/.gitmodules b/.gitmodules index aeb60dc..632287d 100644 --- a/.gitmodules +++ b/.gitmodules @@ -6,3 +6,8 @@ path = external/openai-api-rs url = ../openai-api-rs.git branch = "context-switch-v9.0.0" +[submodule "external/gemini-live-rs"] + path = external/gemini-live-rs + url = ../gemini-live-rs.git + branch = context-switch + diff --git a/.harper-dictionary.txt b/.harper-dictionary.txt index dff3086..17c9691 100644 --- a/.harper-dictionary.txt +++ b/.harper-dictionary.txt @@ -2,6 +2,7 @@ AirPods BCP ContextSwitch FreeSWITCH -Inband +alsa +inband seekable subtag diff --git a/.vscode/settings.json b/.vscode/settings.json deleted file mode 100644 index 8ef4f25..0000000 --- a/.vscode/settings.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "cSpell.words": [ - "alsa", - "aristech", - "denoiser", - "subtag" - ] -} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index f7d43a1..889f0ad 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,6 +13,7 @@ members = [ "services/aristech", "services/azure", "services/elevenlabs", + "services/google-dialog", "services/google-transcribe", "services/openai-dialog", "services/playback", @@ -21,6 +22,8 @@ members = [ [workspace.package] version = "2.3.0" edition = "2024" +license = "MIT" +repository = "https://github.com/pragmatrix/context-switch" [dependencies] @@ -29,6 +32,7 @@ edition = "2024" context-switch-core = { workspace = true } openai-dialog = { path = "services/openai-dialog" } +google-dialog = { workspace = true } azure = { workspace = true } azure-speech = { workspace = true } aristech = { workspace = true } @@ -73,10 +77,13 @@ rodio = { workspace = true, features = ["playback"] } azure = { workspace = true } aristech = { workspace = true } google-transcribe = { workspace = true } +google-dialog = { workspace = true } +gemini-live = { workspace = true } tokio = { workspace = true, features = ["rt-multi-thread"] } +reqwest = { workspace = true } -# For advanced params in openai-dialog +# For advanced parameters in `openai-dialog` openai-api-rs = { workspace = true } serde_json = { workspace = true } chrono-tz = { version = "0.10.3" } @@ -96,12 +103,28 @@ playback = { path = "services/playback" } aristech = { path = "services/aristech" } elevenlabs = { path = "services/elevenlabs" } google-transcribe = { path = "services/google-transcribe" } +google-dialog = { path = "services/google-dialog" } +gemini-live = { path = "external/gemini-live-rs/crates/gemini-live" } + +# Dependencies required by `external/gemini-live-rs/crates/gemini-live`. +# The submodule crate inherits these via `workspace = true`, so we keep them +# centralized here and grouped to make future sync/review straightforward. +tokio-tungstenite = { version = "0.29", features = ["rustls-tls-webpki-roots"] } +futures-util = "0.3.32" +bytes = "1.11" +thiserror = "2.0.18" +rustls = { version = "0.23", features = ["ring"], default-features = false } +google-cloud-auth = { version = "1.10.0", default-features = false } anyhow = "1.0.102" derive_more = { version = "2.1.1", features = ["full"] } static_assertions = "1.1.0" async-stream = { version = "0.3.6" } -tokio = { version = "1.50.0", features = ["sync"] } +# Tokio features are intentionally explicit: +# - `sync`: channels/mutexes used throughout services +# - `rt` + `macros`: runtime and `tokio::select!`/task macros used by `gemini-live` +# - `time`: timeout/sleep used by `gemini-live` session/transport logic +tokio = { version = "1.52.3", features = ["sync", "rt", "macros", "time"] } futures = "0.3.31" serde = { version = "1.0.215", features = ["derive"] } serde_json = "1.0.149" @@ -111,7 +134,7 @@ async-trait = "0.1.83" tracing = "0.1.41" dotenvy = { version = "0.15.7" } url = { version = "2.5.8" } -reqwest = { version = "0.13.2" } +reqwest = { version = "0.13.3" } mime_guess2 = { version = "2.3.1" } hound = { version = "3.5.1" } chrono = { version = "0.4.44" } @@ -121,13 +144,12 @@ chrono = { version = "0.4.44" } # azure-speech = { path = "external/azure-speech-sdk-rs" } -# openai-api-rs = "5.2.3" openai-api-rs = { path = "external/openai-api-rs" } -# - `symphonia-wav` is mandatory: The default WAV decoder does not seem to support A-Law and also -# panics with a few of our testcases. +# - `symphonia-wav` is mandatory: The default WAV decoder does not seem to support A-Law and +# panics with a few of our test cases. # - No default features because we don't want to pull alsa on Linux by default for local playback. # - We have to define at least _one_ decoder, otherwise `cargo clippy --all-targets` fails, so we select `symphonia-mp3`. rodio = { version = "0.22.2", default-features = false, features = ["symphonia-mp3"] } rstest = { version = "0.26.1" } -uuid = { version = "1.17.0", features = ["v4"] } +uuid = { version = "1.23.1", features = ["v4"] } diff --git a/audio-knife/src/main.rs b/audio-knife/src/main.rs index e36d385..d584a15 100644 --- a/audio-knife/src/main.rs +++ b/audio-knife/src/main.rs @@ -37,8 +37,8 @@ use tokio::{ use tracing::{Instrument, Span, debug, error, info, info_span}; use context_switch::{ - AudioFormat, AudioFrame, ClientEvent, ContextSwitch, ConversationId, InputModality, - ServerEvent, audio, billing_collector::BillingCollector, conversation::BillingId, + AudioFormat, AudioFrame, BillingId, ClientEvent, ContextSwitch, ConversationId, InputModality, + ServerEvent, audio, billing_collector::BillingCollector, }; use tracing_subscriber::{EnvFilter, fmt::format::FmtSpan}; use uuid::Uuid; diff --git a/core/src/billing_context.rs b/core/src/billing_context.rs index 2269e69..004fa8a 100644 --- a/core/src/billing_context.rs +++ b/core/src/billing_context.rs @@ -2,7 +2,9 @@ use std::sync::{Arc, Mutex}; use anyhow::Result; -use crate::{BillingRecord, billing_collector::BillingCollector, conversation::BillingId}; +use crate::BillingRecord; +use crate::billing_collector::BillingCollector; +use crate::conversation::BillingId; #[derive(Debug, Clone)] pub struct BillingContext { diff --git a/core/src/conversation.rs b/core/src/conversation.rs index 038ca60..486e1a3 100644 --- a/core/src/conversation.rs +++ b/core/src/conversation.rs @@ -168,7 +168,7 @@ impl ConversationInput { self.input.recv().await } - /// Run a nested service conversation with one single input request and wait until its + /// Run a nested service conversation with one single input request and wait until it's /// completed. /// /// All output is sent to the conversation output. diff --git a/core/src/lib.rs b/core/src/lib.rs index 1e0b9c7..6e24baa 100644 --- a/core/src/lib.rs +++ b/core/src/lib.rs @@ -1,7 +1,7 @@ pub mod audio; pub mod billing_collector; mod billing_context; -pub mod conversation; +mod conversation; mod duration; pub mod language; mod protocol; @@ -15,6 +15,7 @@ use anyhow::{Context, Result, bail}; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender, unbounded_channel}; pub use billing_context::BillingContext; +pub use conversation::*; pub use duration::Duration; pub use protocol::*; pub use registry::*; diff --git a/examples/aristech-synthesize.rs b/examples/aristech-synthesize.rs index d57b7b1..206338b 100644 --- a/examples/aristech-synthesize.rs +++ b/examples/aristech-synthesize.rs @@ -1,23 +1,18 @@ -use std::{ - env, - num::{NonZeroU16, NonZeroU32}, - thread, - time::Duration, -}; +use std::env; +use std::num::{NonZeroU16, NonZeroU32}; +use std::thread; +use std::time::Duration; use anyhow::{Context as AnyhowContext, Result}; +use aristech::synthesize::{AristechSynthesize, Params as AristechParams}; use rodio::{DeviceSinkBuilder, Player, Source}; -use tokio::{ - select, - sync::mpsc::{channel, unbounded_channel}, -}; +use tokio::select; +use tokio::sync::mpsc::{channel, unbounded_channel}; -use aristech::synthesize::{AristechSynthesize, Params as AristechParams}; use context_switch::{InputModality, OutputModality}; +use context_switch_core::service::Service; use context_switch_core::{ - AudioFormat, AudioFrame, AudioProducer, audio, - conversation::{Conversation, Input, Output}, - service::Service, + AudioFormat, AudioFrame, AudioProducer, Conversation, Input, Output, audio, }; const SAMPLE_TEXT: &str = "Hallo! Dies ist eine Demonstration des Aristech Text-zu-Sprache-Dienstes. \ diff --git a/examples/aristech-transcribe.rs b/examples/aristech-transcribe.rs index af7a79e..c1144c1 100644 --- a/examples/aristech-transcribe.rs +++ b/examples/aristech-transcribe.rs @@ -9,12 +9,9 @@ use tokio::{ }; use aristech::transcribe::{ApiKeyAuth, AuthConfig, CredentialsAuth, Params as AristechParams}; + use context_switch::{InputModality, OutputModality, services::AristechTranscribe}; -use context_switch_core::{ - AudioFormat, AudioFrame, audio, - conversation::{Conversation, Input}, - service::Service, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, audio, service::Service}; #[tokio::main] async fn main() -> Result<()> { @@ -28,6 +25,7 @@ async fn main() -> Result<()> { let device = host .default_input_device() .expect("Failed to get default input device"); + // spellcheck: ignore // let config = device // .default_input_config() // .expect("Failed to get default input config"); diff --git a/examples/azure-translate.rs b/examples/azure-translate.rs index 1a7e6dd..962fff4 100644 --- a/examples/azure-translate.rs +++ b/examples/azure-translate.rs @@ -1,26 +1,19 @@ //! A context switch demo. Runs locally, gets voice data from your current microphone. -use std::{ - env, - num::{NonZeroU16, NonZeroU32}, - thread, - time::Duration, -}; +use std::env; +use std::num::{NonZeroU16, NonZeroU32}; +use std::thread; +use std::time::Duration; use anyhow::Result; use azure::AzureTranslate; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; use rodio::{DeviceSinkBuilder, Player, Source}; +use tokio::select; +use tokio::sync::mpsc::{UnboundedReceiver, channel, unbounded_channel}; use context_switch::{InputModality, OutputModality}; -use context_switch_core::{ - AudioFormat, AudioFrame, Service, audio, - conversation::{Conversation, Input, Output}, -}; -use tokio::{ - select, - sync::mpsc::{UnboundedReceiver, channel, unbounded_channel}, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, Output, Service, audio}; #[tokio::main] async fn main() -> Result<()> { diff --git a/examples/openai-dialog.rs b/examples/dialog.rs similarity index 60% rename from examples/openai-dialog.rs rename to examples/dialog.rs index 47186e7..2eb1e59 100644 --- a/examples/openai-dialog.rs +++ b/examples/dialog.rs @@ -1,81 +1,73 @@ //! A context switch demo. Runs locally, gets voice data from your current microphone. -use std::{ - env, - num::{NonZeroU16, NonZeroU32}, - str::FromStr, - thread, - time::Duration, -}; +use std::num::{NonZeroU16, NonZeroU32}; +use std::str::FromStr; +use std::thread; +use std::time::Duration; use anyhow::{Context, Result, bail}; use chrono::Utc; -use clap::builder::{PossibleValuesParser, TypedValueParser}; use clap::{Parser, ValueEnum}; use context_switch::{InputModality, OutputModality}; use cpal::traits::{DeviceTrait, HostTrait, StreamTrait}; -use openai_api_rs::realtime::types; -use openai_dialog::{OpenAIDialog, Protocol, ServiceInputEvent, ServiceOutputEvent}; use rodio::{DeviceSinkBuilder, Player, Source}; use serde_json::json; -use tokio::{ - select, - sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}, -}; +use tokio::select; +use tokio::sync::mpsc::{Sender, UnboundedReceiver, channel, unbounded_channel}; use tracing::info; -use context_switch_core::{ - AudioFormat, AudioFrame, Service, audio, - conversation::{Conversation, Input, Output}, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, Output, audio}; + +mod dialog_providers; #[derive(Debug, Parser)] struct Cli { - #[arg(long, value_enum)] - protocol: Option, + #[arg(value_enum)] + provider: Provider, + #[arg(long)] + list_models: bool, + #[arg(long)] + list_voices: bool, #[arg(long)] endpoint: Option, #[arg(long)] model: Option, - #[arg(long, value_parser = realtime_voice_value_parser())] - voice: Option, + #[arg(long)] + voice: Option, } #[derive(Debug, Clone, Copy, ValueEnum)] -enum CliProtocol { +enum Provider { #[value(name = "openai")] OpenAI, - Azure, -} - -fn realtime_voice_value_parser() -> impl TypedValueParser { - PossibleValuesParser::new(::VARIANTS).try_map( - |value| { - parse_realtime_voice_value(&value) - .map_err(|e| format!("Invalid voice value `{value}`: {e}")) - }, - ) + #[value(name = "azure-openai")] + AzureOpenAI, + Google, } -fn parse_realtime_voice_value(value: &str) -> Result { - types::RealtimeVoice::from_str(value) -} - -impl From for Protocol { - fn from(value: CliProtocol) -> Self { - match value { - CliProtocol::OpenAI => Protocol::OpenAI, - CliProtocol::Azure => Protocol::Azure, - } +impl Provider { + fn api(self) -> &'static dyn dialog_providers::ProviderApi { + dialog_providers::provider_api(self) } } #[tokio::main] async fn main() -> Result<()> { - let cli = Cli::parse(); dotenvy::dotenv_override().context("Reading .env file")?; tracing_subscriber::fmt::init(); + let cli = Cli::parse(); + + if cli.list_models { + list_available_models(&cli).await?; + return Ok(()); + } + + if cli.list_voices { + list_available_voices(cli.provider)?; + return Ok(()); + } + let host = cpal::default_host(); let device = host .default_input_device() @@ -88,10 +80,10 @@ async fn main() -> Result<()> { let channels = input_config.channels(); let sample_rate = input_config.sample_rate(); - let format = AudioFormat::new(channels, sample_rate); + let input_format = AudioFormat::new(channels, sample_rate); + let output_format = cli.provider.api().output_format(input_format); let (input_sender, input_receiver) = channel(256); - let input_sender2 = input_sender.clone(); // Create and run the input stream @@ -100,7 +92,10 @@ async fn main() -> Result<()> { &input_config.into(), move |data: &[f32], _: &cpal::InputCallbackInfo| { let samples = audio::into_i16(data); - let frame = AudioFrame { format, samples }; + let frame = AudioFrame { + format: input_format, + samples, + }; if input_sender2.try_send(Input::Audio { frame }).is_err() { println!("Failed to send audio data") } @@ -108,43 +103,34 @@ async fn main() -> Result<()> { move |err| { eprintln!("Error occurred on stream: {err}"); }, - // timeout + // Timeout Some(Duration::from_secs(1)), ) .expect("Failed to build input stream"); stream.play().expect("Failed to play stream"); - let key = env::var("OPENAI_API_KEY").unwrap(); - let model = cli - .model - .or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) - .filter(|model| !model.trim().is_empty()) - .context("Provide --model or set OPENAI_REALTIME_API_MODEL")?; - - let openai = OpenAIDialog; - let mut params = openai_dialog::Params::new(key, model); - params.host = cli - .endpoint - .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) - .filter(|endpoint| !endpoint.trim().is_empty()); - params.protocol = cli.protocol.map(Into::into); - params.voice = cli.voice; - params.tools.push(get_time_function_definition()); - let (output_sender, output_receiver) = unbounded_channel(); - + // Keep text enabled at the context-switch layer for Google. let conversation = Conversation::new( - InputModality::Audio { format }, - [OutputModality::Audio { format }], + InputModality::Audio { + format: input_format, + }, + [ + OutputModality::Audio { + format: output_format, + }, + OutputModality::Text, + OutputModality::InterimText, + ], input_receiver, output_sender, ); - let mut conversation = openai.conversation(params, conversation); - - let playback_task = setup_audio_playback(format, input_sender, output_receiver).await; - + let conversation = start_conversation(&cli, conversation); + tokio::pin!(conversation); + let playback_task = + setup_audio_playback(cli.provider, output_format, input_sender, output_receiver).await; // Spawn audio playback task let mut playback_handle = tokio::spawn(playback_task); @@ -155,17 +141,43 @@ async fn main() -> Result<()> { let _ = playback_handle.await; r? } - // Drive playback r = &mut playback_handle => { r?? } - } Ok(()) } +fn list_available_voices(provider: Provider) -> Result<()> { + println!("Available voices for {:?}:", provider); + for voice in provider.api().voices() { + println!("- {voice}"); + } + Ok(()) +} + +async fn list_available_models(cli: &Cli) -> Result<()> { + let request = dialog_providers::ListModelsRequest { + endpoint: cli.endpoint.clone(), + model: cli.model.clone(), + }; + cli.provider.api().list_models(request).await +} + +async fn start_conversation(cli: &Cli, conversation: Conversation) -> Result<()> { + let request = dialog_providers::StartConversationRequest { + endpoint: cli.endpoint.clone(), + model: cli.model.clone(), + voice: cli.voice.clone(), + }; + cli.provider + .api() + .start_conversation(request, conversation) + .await +} + enum AudioCommand { PlayFrame(AudioFrame), Clear, @@ -173,6 +185,7 @@ enum AudioCommand { } async fn setup_audio_playback( + provider: Provider, format: AudioFormat, input: Sender, mut output: UnboundedReceiver, @@ -182,7 +195,6 @@ async fn setup_audio_playback( // Spawn a dedicated audio thread let playback_thread = thread::spawn(move || { // Create output stream in the audio thread - let sink_handle = DeviceSinkBuilder::open_default_sink().unwrap(); let player = Player::connect_new(sink_handle.mixer()); @@ -209,7 +221,6 @@ async fn setup_audio_playback( }); // Create async task to forward frames to the audio thread - async move { while let Some(output) = output.recv().await { match output { @@ -219,32 +230,18 @@ async fn setup_audio_playback( break; } } - Output::Text { .. } | Output::RequestCompleted { .. } => {} + output @ Output::Text { .. } => { + println!("{output:?}"); + } + Output::RequestCompleted { .. } => {} Output::ClearAudio => { if cmd_tx.send(AudioCommand::Clear).is_err() { break; } } - Output::ServiceEvent { value, .. } => match serde_json::from_value(value)? { - ServiceOutputEvent::FunctionCall { - name, - call_id, - arguments, - } => { - info!("Processing function `{name}` with arguments `{arguments:?}`"); - let result = call_function(&name, arguments)?; - info!("Function result: `{result}`"); - let value = ServiceInputEvent::FunctionCallResult { - call_id, - output: json! ({ "time": serde_json::Value::String(result) }), - }; - let value = serde_json::to_value(&value)?; - input.try_send(Input::ServiceEvent { value })?; - } - ServiceOutputEvent::SessionUpdated { tools } => { - info!("Session Updated: {tools:?}"); - } - }, + Output::ServiceEvent { value, .. } => { + handle_service_event(provider, &input, value)?; + } Output::BillingRecords { records, scope, .. } => { info!("Billing: scope: {scope:?}, records: {records:?}"); } @@ -257,21 +254,58 @@ async fn setup_audio_playback( } } -fn get_time_function_definition() -> types::ToolDefinition { - types::ToolDefinition::Function { - name: "get_time".into(), - description: "The current time to the exact second.".into(), - parameters: json!({ - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "IANA time zone identifier of the region and city." - } - }, - "required": ["location"] - }), +fn handle_service_event( + provider: Provider, + input: &Sender, + value: serde_json::Value, +) -> Result<()> { + let call = provider.api().parse_service_event(value)?; + + if let Some(call) = call { + info!( + "Processing function `{}` with arguments `{:?}`", + call.name, call.arguments + ); + let result = call_function(&call.name, call.arguments)?; + info!("Function result: `{result}`"); + send_function_result(provider, input, call.call_id, call.name, result)?; } + + Ok(()) +} + +fn send_function_result( + provider: Provider, + input: &Sender, + call_id: String, + name: String, + result: String, +) -> Result<()> { + let value = provider + .api() + .function_result_event(call_id, Some(name), result)?; + input.try_send(Input::ServiceEvent { value })?; + Ok(()) +} + +#[derive(Debug)] +struct FunctionCall { + call_id: String, + name: String, + arguments: Option, +} + +fn get_time_parameters_schema() -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "IANA time zone identifier of the region and city." + } + }, + "required": ["location"] + }) } fn call_function(name: &str, arguments: Option) -> Result { diff --git a/examples/dialog_providers/azure_openai.rs b/examples/dialog_providers/azure_openai.rs new file mode 100644 index 0000000..e0b2a61 --- /dev/null +++ b/examples/dialog_providers/azure_openai.rs @@ -0,0 +1,90 @@ +use std::env; + +use anyhow::{Context, Result}; +use async_trait::async_trait; +use openai_api_rs::realtime::types as openai_types; +use openai_dialog::{OpenAIDialog, Protocol}; +use strum::VariantNames; + +use context_switch_core::{AudioFormat, Conversation, Service}; + +use super::openai::{self, OpenAIProvider}; +use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; + +pub struct AzureOpenAIProvider; + +#[async_trait(?Send)] +impl ProviderApi for AzureOpenAIProvider { + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()> { + let key = env::var("AZURE_OPENAI_API_KEY").context("AZURE_OPENAI_API_KEY undefined")?; + let model = request + .model + .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .context("Provide --model or set AZURE_OPENAI_REALTIME_API_MODEL")?; + + let mut params = openai_dialog::Params::new(key, model); + params.host = request + .endpoint + .or_else(|| env::var("AZURE_OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = Some(Protocol::Azure); + params.voice = request + .voice + .as_deref() + .map(openai::parse_realtime_voice_value) + .transpose()?; + params.tools.push(openai::get_time_function_definition()); + + OpenAIDialog.conversation(params, conversation).await + } + + fn parse_service_event(&self, value: serde_json::Value) -> Result> { + OpenAIProvider.parse_service_event(value) + } + + fn function_result_event( + &self, + call_id: String, + _name: Option, + result: String, + ) -> Result { + OpenAIProvider.function_result_event(call_id, None, result) + } + + fn output_format(&self, input_format: AudioFormat) -> AudioFormat { + input_format + } + + fn voices(&self) -> &'static [&'static str] { + ::VARIANTS + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + println!("Available models for Azure:"); + println!( + "- Azure Realtime uses deployment names configured in your Azure OpenAI resource." + ); + println!( + "- The realtime endpoint does not provide a provider-agnostic model listing API here." + ); + + if let Some(model) = request + .model + .or_else(|| env::var("AZURE_OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + { + println!("- Configured deployment/model: {model}"); + } else { + println!( + "- Set --model or AZURE_OPENAI_REALTIME_API_MODEL to your Azure deployment name." + ); + } + + Ok(()) + } +} diff --git a/examples/dialog_providers/google.rs b/examples/dialog_providers/google.rs new file mode 100644 index 0000000..a2b4ee1 --- /dev/null +++ b/examples/dialog_providers/google.rs @@ -0,0 +1,243 @@ +use std::env; + +use anyhow::{Context, Result, bail}; +use async_trait::async_trait; +use gemini_live::types as gemini_types; +use google_dialog::{ + GoogleDialog, ServiceInputEvent as GoogleServiceInputEvent, ServiceOutputEvent, +}; +use reqwest::Url; +use serde::Deserialize; +use serde_json::json; + +use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; +use crate::{FunctionCall, get_time_parameters_schema}; +use context_switch_core::{AudioFormat, Conversation, Service}; + +pub struct GoogleProvider; + +#[async_trait(?Send)] +impl ProviderApi for GoogleProvider { + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()> { + let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; + let model = request + .model + .or_else(|| env::var("GEMINI_LIVE_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .unwrap_or_else(|| "gemini-3.1-flash-live-preview".to_owned()); + + let mut params = google_dialog::Params::new(key, model); + params.host = request + .endpoint + .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.voice = request + .voice + .as_deref() + .map(parse_voice_value) + .transpose()?; + params.input_audio_transcription = true; + params.output_audio_transcription = true; + params.tools.push(get_time_tool()); + + GoogleDialog.conversation(params, conversation).await + } + + fn parse_service_event(&self, value: serde_json::Value) -> Result> { + match serde_json::from_value(value)? { + ServiceOutputEvent::FunctionCall { + name, + call_id, + arguments, + } => Ok(Some(FunctionCall { + name, + call_id, + arguments: Some(arguments), + })), + ServiceOutputEvent::ToolCallCancellation { call_id } => { + tracing::info!("Tool call cancelled: {call_id}"); + Ok(None) + } + ServiceOutputEvent::SessionUpdated { tools } => { + tracing::info!("Session updated: {tools:?}"); + Ok(None) + } + } + } + + fn function_result_event( + &self, + call_id: String, + name: Option, + result: String, + ) -> Result { + let name = name.context("Function name is required for Google function result")?; + let output = json!({ "time": serde_json::Value::String(result) }); + serde_json::to_value(&GoogleServiceInputEvent::FunctionCallResult { + call_id, + name, + output, + }) + .map_err(Into::into) + } + + fn output_format(&self, _input_format: AudioFormat) -> AudioFormat { + AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE) + } + + fn voices(&self) -> &'static [&'static str] { + VOICES + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + let key = env::var("GEMINI_API_KEY").context("GEMINI_API_KEY undefined")?; + let endpoint = request + .endpoint + .or_else(|| env::var("GEMINI_LIVE_ENDPOINT").ok()); + let models_url = models_url(endpoint.as_deref())?; + + let response: GeminiModelsResponse = reqwest::Client::new() + .get(models_url) + .query(&[("key", key)]) + .send() + .await + .context("Requesting Gemini models")? + .error_for_status() + .context("Gemini models request failed")? + .json() + .await + .context("Decoding Gemini models response")?; + + let mut live_models: Vec<_> = response + .models + .into_iter() + .filter(|model| is_live_model(&model.name, &model.supported_generation_methods)) + .map(|model| model.name) + .collect(); + live_models.sort(); + + println!("Available models for Google (Live API capable):"); + if live_models.is_empty() { + println!("- No Live-capable models were detected from models.list."); + println!( + "- This can happen when model metadata does not include Live-specific methods." + ); + println!("- Try explicitly using a known Live model, for example:"); + println!(" - models/gemini-3.1-flash-live-preview"); + println!(" - models/gemini-2.5-flash-live-preview"); + } else { + for model in live_models { + println!("- {model}"); + } + } + Ok(()) + } +} + +fn models_url(endpoint: Option<&str>) -> Result { + const GOOGLE_MODELS_ENDPOINT: &str = "https://generativelanguage.googleapis.com/v1beta/models"; + + let Some(endpoint) = endpoint else { + return Ok(GOOGLE_MODELS_ENDPOINT.to_owned()); + }; + + let mut url = Url::parse(endpoint) + .or_else(|_| Url::parse(GOOGLE_MODELS_ENDPOINT)) + .context("Parsing Gemini model list URL")?; + + let normalized_scheme = match url.scheme() { + "wss" => "https".to_owned(), + "ws" => "http".to_owned(), + other => other.to_owned(), + }; + url.set_scheme(&normalized_scheme).ok(); + url.set_path("/v1beta/models"); + url.set_query(None); + Ok(url.to_string()) +} + +fn is_live_model(model_name: &str, methods: &[String]) -> bool { + if model_name.to_ascii_lowercase().contains("live") { + return true; + } + + methods.iter().any(|method| { + method.eq_ignore_ascii_case("bidiGenerateContent") + || method.eq_ignore_ascii_case("streamGenerateContent") + }) +} + +const VOICES: &[&str] = &[ + "Zephyr", + "Puck", + "Charon", + "Kore", + "Fenrir", + "Leda", + "Orus", + "Aoede", + "Callirrhoe", + "Autonoe", + "Enceladus", + "Iapetus", + "Umbriel", + "Algieba", + "Despina", + "Erinome", + "Algenib", + "Rasalgethi", + "Laomedeia", + "Achernar", + "Alnilam", + "Schedar", + "Gacrux", + "Pulcherrima", + "Achird", + "Zubenelgenubi", + "Vindemiatrix", + "Sadachbia", + "Sadaltager", + "Sulafat", +]; + +fn parse_voice_value(value: &str) -> Result { + if VOICES.iter().any(|voice| voice.eq_ignore_ascii_case(value)) { + let voice = VOICES + .iter() + .find(|voice| voice.eq_ignore_ascii_case(value)) + .copied() + .unwrap_or(value) + .to_owned(); + Ok(voice) + } else { + let available = VOICES.join(", "); + bail!("Invalid Gemini voice `{value}`. Available voices: {available}") + } +} + +fn get_time_tool() -> gemini_types::Tool { + gemini_types::Tool::FunctionDeclarations(vec![gemini_types::FunctionDeclaration { + name: "get_time".into(), + description: "The current time to the exact second.".into(), + parameters: get_time_parameters_schema(), + scheduling: None, + behavior: None, + }]) +} + +#[derive(Debug, Deserialize)] +struct GeminiModelsResponse { + #[serde(default)] + models: Vec, +} + +#[derive(Debug, Deserialize)] +struct GeminiModel { + name: String, + #[serde(default)] + supported_generation_methods: Vec, +} diff --git a/examples/dialog_providers/mod.rs b/examples/dialog_providers/mod.rs new file mode 100644 index 0000000..a62456f --- /dev/null +++ b/examples/dialog_providers/mod.rs @@ -0,0 +1,49 @@ +use anyhow::Result; +use async_trait::async_trait; + +use crate::{FunctionCall, Provider}; +use context_switch_core::{AudioFormat, Conversation}; + +#[derive(Debug, Clone)] +pub struct ListModelsRequest { + pub endpoint: Option, + pub model: Option, +} + +#[derive(Debug, Clone)] +pub struct StartConversationRequest { + pub endpoint: Option, + pub model: Option, + pub voice: Option, +} + +#[async_trait(?Send)] +pub trait ProviderApi { + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()>; + fn parse_service_event(&self, value: serde_json::Value) -> Result>; + fn function_result_event( + &self, + call_id: String, + name: Option, + result: String, + ) -> Result; + fn output_format(&self, input_format: AudioFormat) -> AudioFormat; + fn voices(&self) -> &'static [&'static str]; + async fn list_models(&self, request: ListModelsRequest) -> Result<()>; +} + +pub fn provider_api(provider: Provider) -> &'static dyn ProviderApi { + match provider { + Provider::OpenAI => &openai::OpenAIProvider, + Provider::AzureOpenAI => &azure_openai::AzureOpenAIProvider, + Provider::Google => &google::GoogleProvider, + } +} + +mod azure_openai; +mod google; +mod openai; diff --git a/examples/dialog_providers/openai.rs b/examples/dialog_providers/openai.rs new file mode 100644 index 0000000..4d543be --- /dev/null +++ b/examples/dialog_providers/openai.rs @@ -0,0 +1,175 @@ +use std::{env, str::FromStr}; + +use anyhow::{Context, Result, anyhow}; +use async_trait::async_trait; +use openai_api_rs::realtime::types as openai_types; +use openai_dialog::{ + OpenAIDialog, Protocol, ServiceInputEvent as OpenAIServiceInputEvent, + ServiceOutputEvent as OpenAIServiceOutputEvent, +}; +use reqwest::Url; +use serde::Deserialize; +use serde_json::json; +use strum::VariantNames; + +use super::{ListModelsRequest, ProviderApi, StartConversationRequest}; +use crate::{FunctionCall, get_time_parameters_schema}; +use context_switch_core::{AudioFormat, Conversation, Service}; + +pub struct OpenAIProvider; + +#[async_trait(?Send)] +impl ProviderApi for OpenAIProvider { + async fn start_conversation( + &self, + request: StartConversationRequest, + conversation: Conversation, + ) -> Result<()> { + let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; + let model = request + .model + .or_else(|| env::var("OPENAI_REALTIME_API_MODEL").ok()) + .filter(|model| !model.trim().is_empty()) + .context("Provide --model or set OPENAI_REALTIME_API_MODEL")?; + + let mut params = openai_dialog::Params::new(key, model); + params.host = request + .endpoint + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()) + .filter(|endpoint| !endpoint.trim().is_empty()); + params.protocol = Some(Protocol::OpenAI); + params.voice = request + .voice + .as_deref() + .map(parse_realtime_voice_value) + .transpose()?; + params.tools.push(get_time_function_definition()); + + OpenAIDialog.conversation(params, conversation).await + } + + fn parse_service_event(&self, value: serde_json::Value) -> Result> { + match serde_json::from_value(value)? { + OpenAIServiceOutputEvent::FunctionCall { + name, + call_id, + arguments, + } => Ok(Some(FunctionCall { + name, + call_id, + arguments, + })), + OpenAIServiceOutputEvent::SessionUpdated { tools } => { + tracing::info!("Session updated: {tools:?}"); + Ok(None) + } + } + } + + fn function_result_event( + &self, + call_id: String, + _name: Option, + result: String, + ) -> Result { + let output = json!({ "time": serde_json::Value::String(result) }); + serde_json::to_value(&OpenAIServiceInputEvent::FunctionCallResult { call_id, output }) + .map_err(Into::into) + } + + fn output_format(&self, input_format: AudioFormat) -> AudioFormat { + input_format + } + + fn voices(&self) -> &'static [&'static str] { + openai_types::RealtimeVoice::VARIANTS + } + + async fn list_models(&self, request: ListModelsRequest) -> Result<()> { + let key = env::var("OPENAI_API_KEY").context("OPENAI_API_KEY undefined")?; + let endpoint = request + .endpoint + .or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok()); + let models_url = models_url(endpoint.as_deref())?; + + let response: OpenAIModelsResponse = reqwest::Client::new() + .get(models_url) + .bearer_auth(key) + .send() + .await + .context("Requesting OpenAI models")? + .error_for_status() + .context("OpenAI models request failed")? + .json() + .await + .context("Decoding OpenAI models response")?; + + let mut models: Vec<_> = response + .data + .into_iter() + .map(|model| model.id) + .filter(|id| is_realtime_model(id)) + .collect(); + models.sort(); + + println!("Available models for OpenAI:"); + if models.is_empty() { + println!("- No realtime-capable models were returned by the models endpoint."); + println!("- Ensure your API key has access to OpenAI Realtime API models."); + } else { + for model in models { + println!("- {model}"); + } + } + Ok(()) + } +} + +pub fn parse_realtime_voice_value(value: &str) -> Result { + openai_types::RealtimeVoice::from_str(value) + .map_err(|error| anyhow!("Invalid voice value `{value}`: {error}")) +} + +pub fn get_time_function_definition() -> openai_types::ToolDefinition { + openai_types::ToolDefinition::Function { + name: "get_time".into(), + description: "The current time to the exact second.".into(), + parameters: get_time_parameters_schema(), + } +} + +fn is_realtime_model(model_id: &str) -> bool { + model_id.to_ascii_lowercase().contains("realtime") +} + +fn models_url(endpoint: Option<&str>) -> Result { + const OPENAI_MODELS_ENDPOINT: &str = "https://api.openai.com/v1/models"; + + let Some(endpoint) = endpoint else { + return Ok(OPENAI_MODELS_ENDPOINT.to_owned()); + }; + + let mut url = Url::parse(endpoint) + .or_else(|_| Url::parse(OPENAI_MODELS_ENDPOINT)) + .context("Parsing OpenAI model list URL")?; + + let normalized_scheme = match url.scheme() { + "wss" => "https".to_owned(), + "ws" => "http".to_owned(), + other => other.to_owned(), + }; + url.set_scheme(&normalized_scheme).ok(); + url.set_path("/v1/models"); + url.set_query(None); + Ok(url.to_string()) +} + +#[derive(Debug, Deserialize)] +struct OpenAIModelsResponse { + data: Vec, +} + +#[derive(Debug, Deserialize)] +struct OpenAIModel { + id: String, +} diff --git a/examples/transcribe.rs b/examples/transcribe.rs index d4cda99..480c849 100644 --- a/examples/transcribe.rs +++ b/examples/transcribe.rs @@ -13,10 +13,9 @@ use context_switch::services::{ AristechTranscribe, AzureTranscribe, ElevenLabsTranscribe, GoogleTranscribe, }; use context_switch::{AudioConsumer, InputModality, OutputModality}; -use context_switch_core::conversation::{Conversation, Input}; use context_switch_core::language::Languages; use context_switch_core::service::Service; -use context_switch_core::{AudioFormat, AudioFrame, audio}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, audio}; const DEFAULT_LANGUAGE: &str = "en-US"; diff --git a/external/gemini-live-rs b/external/gemini-live-rs new file mode 160000 index 0000000..4674ae2 --- /dev/null +++ b/external/gemini-live-rs @@ -0,0 +1 @@ +Subproject commit 4674ae228730ad75b24d088eee3228a4d5036ae1 diff --git a/services/aristech/src/synthesize.rs b/services/aristech/src/synthesize.rs index 2820c29..3a9a745 100644 --- a/services/aristech/src/synthesize.rs +++ b/services/aristech/src/synthesize.rs @@ -10,10 +10,7 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; use tracing::debug; -use context_switch_core::{ - AudioFormat, AudioFrame, Service, - conversation::{Conversation, Input}, -}; +use context_switch_core::{AudioFormat, AudioFrame, Conversation, Input, Service}; //TODO: Add `language` field as alternative to `voice_id` #[derive(Debug, Serialize, Deserialize)] diff --git a/services/aristech/src/transcribe.rs b/services/aristech/src/transcribe.rs index 1c3b078..0719084 100644 --- a/services/aristech/src/transcribe.rs +++ b/services/aristech/src/transcribe.rs @@ -12,10 +12,7 @@ use async_trait::async_trait; use serde::Deserialize; use tonic::codegen::CompressionEncoding; -use context_switch_core::{ - Service, - conversation::{Conversation, Input}, -}; +use context_switch_core::{Conversation, Input, Service}; /// Authentication configuration #[derive(Debug, Deserialize)] diff --git a/services/azure/src/synthesize.rs b/services/azure/src/synthesize.rs index 3d56163..fc47bcd 100644 --- a/services/azure/src/synthesize.rs +++ b/services/azure/src/synthesize.rs @@ -13,8 +13,7 @@ use tracing::debug; use crate::Host; use context_switch_core::{ - AudioFrame, BillingRecord, Service, - conversation::{BillingSchedule, Conversation, Input}, + AudioFrame, BillingRecord, BillingSchedule, Conversation, Input, Service, }; #[derive(Debug, Serialize, Deserialize)] diff --git a/services/azure/src/transcribe.rs b/services/azure/src/transcribe.rs index e4037a3..bdb38c5 100644 --- a/services/azure/src/transcribe.rs +++ b/services/azure/src/transcribe.rs @@ -7,10 +7,9 @@ use serde::Deserialize; use tracing::{error, info}; use crate::Host; +use context_switch_core::language::Languages; use context_switch_core::{ - BillingRecord, Service, - conversation::{BillingSchedule, Conversation, ConversationOutput, Input}, - language::Languages, + BillingRecord, BillingSchedule, Conversation, ConversationOutput, Input, Service, speech_gate::make_speech_gate_processor_soft_rms, }; diff --git a/services/azure/src/translate.rs b/services/azure/src/translate.rs index 7b6731d..129232d 100644 --- a/services/azure/src/translate.rs +++ b/services/azure/src/translate.rs @@ -8,8 +8,8 @@ use tracing::{debug, error}; use crate::Host; use context_switch_core::{ - AudioFormat, AudioFrame, BillingRecord, OutputModality, OutputPath, Service, - conversation::{BillingSchedule, Conversation, Input}, + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, Conversation, Input, OutputModality, + OutputPath, Service, }; #[derive(Debug, Deserialize)] diff --git a/services/elevenlabs/src/transcribe.rs b/services/elevenlabs/src/transcribe.rs index d0f8a9c..81f86ee 100644 --- a/services/elevenlabs/src/transcribe.rs +++ b/services/elevenlabs/src/transcribe.rs @@ -7,21 +7,17 @@ use serde_json::Value; use tokio::select; use tokio::sync::mpsc; use tokio::time::{Duration, sleep}; -use tokio_tungstenite::{ - connect_async_with_config, - tungstenite::{ - Message, - client::IntoClientRequest, - http::{HeaderName, HeaderValue}, - }, -}; +use tokio_tungstenite::connect_async_with_config; +use tokio_tungstenite::tungstenite::Message; +use tokio_tungstenite::tungstenite::client::IntoClientRequest; +use tokio_tungstenite::tungstenite::http::{HeaderName, HeaderValue}; use tracing::{debug, error, warn}; use url::Url; +use context_switch_core::language::{bcp47_to_iso639_3, iso639_to_bcp47}; use context_switch_core::{ - AudioFormat, AudioFrame, BillingRecord, Service, - conversation::{BillingSchedule, Conversation, ConversationInput, ConversationOutput, Input}, - language::{bcp47_to_iso639_3, iso639_to_bcp47}, + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, Conversation, ConversationInput, + ConversationOutput, Input, Service, }; // Observed Scribe v2 behavior as of 2026-04-02: diff --git a/services/google-dialog/Cargo.toml b/services/google-dialog/Cargo.toml new file mode 100644 index 0000000..3be04a9 --- /dev/null +++ b/services/google-dialog/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "google-dialog" +version = "0.1.0" +edition.workspace = true + +[dependencies] +context-switch-core = { workspace = true } + +gemini-live = { workspace = true } + +anyhow = { workspace = true } +async-trait = { workspace = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } diff --git a/services/google-dialog/src/client.rs b/services/google-dialog/src/client.rs new file mode 100644 index 0000000..62a8325 --- /dev/null +++ b/services/google-dialog/src/client.rs @@ -0,0 +1,390 @@ +use anyhow::{Context, Result, bail}; + +use gemini_live::transport::{Auth, Endpoint, TransportConfig}; +use gemini_live::types::{ + AudioTranscriptionConfig, Content, ContextWindowCompressionConfig, FunctionDeclaration, + FunctionResponse, GenerationConfig, Modality, ModalityTokenCount, Part, PrebuiltVoiceConfig, + ServerEvent, SessionResumptionConfig, SetupConfig, SlidingWindow, SpeechConfig, ThinkingConfig, + Tool, UsageMetadata, VoiceConfig, +}; +use gemini_live::{ReconnectPolicy, Session, SessionConfig}; +use tracing::{debug, info, trace}; + +use crate::{Params, ServiceInputEvent, ServiceOutputEvent, TextOutputs}; +use context_switch_core::{ + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, ConversationInput, ConversationOutput, + Input, OutputPath, +}; + +#[derive(Debug)] +pub struct Client { + params: Params, +} + +impl Client { + pub fn new(params: Params) -> Self { + Self { params } + } + + pub async fn dialog( + self, + output_format: AudioFormat, + text_outputs: TextOutputs, + mut input: ConversationInput, + output: ConversationOutput, + ) -> Result<()> { + let billing_scope = self.params.model.clone(); + let tools = function_declarations(&self.params.tools); + let mut output_transcription_buffer = String::new(); + let mut session = Session::connect(session_config(&self.params, text_outputs)?) + .await + .context("Connecting to Gemini Live")?; + output.service_event( + OutputPath::Control, + ServiceOutputEvent::SessionUpdated { tools }, + )?; + + loop { + tokio::select! { + input = input.recv() => { + if let Some(input) = input { + self.process_input(&session, input).await?; + } else { + debug!("Conversation input closed"); + session.audio_stream_end().await.context("Ending Gemini audio stream")?; + debug!("Audio stream end sent to Gemini"); + break; + } + } + event = session.next_event() => { + match event { + Some(event) => { + match self.process_event(event, output_format, text_outputs, &output, &billing_scope, &mut output_transcription_buffer)? { + FlowControl::Continue => {} + FlowControl::End => { + debug!("Received terminal server event"); + break; + } + } + } + None => { + debug!("Server event stream ended"); + break; + } + } + } + } + } + + debug!("Closing session"); + session.close().await.context("Closing session")?; + Ok(()) + } + + async fn process_input(&self, session: &Session, input: Input) -> Result<()> { + match input { + Input::Audio { frame } => { + let mono = frame.into_mono(); + let sample_rate = mono.format.sample_rate; + let audio = mono.to_le_bytes(); + session + .send_audio_at_rate(&audio, sample_rate) + .await + .context("Sending audio to Gemini Live")?; + } + Input::Text { text, .. } => { + session + .send_text(&text) + .await + .context("Sending text to Gemini Live")?; + } + Input::ServiceEvent { value } => match serde_json::from_value(value)? { + ServiceInputEvent::FunctionCallResult { + call_id, + name, + output, + } => { + let response = FunctionResponse { + id: call_id, + name, + response: output, + }; + session + .send_tool_response(vec![response]) + .await + .context("Sending tool response")?; + } + ServiceInputEvent::Prompt { text } => { + info!("Received prompt"); + session.send_text(&text).await.context("Sending prompt")?; + } + }, + } + Ok(()) + } + + fn process_event( + &self, + event: ServerEvent, + output_format: AudioFormat, + text_outputs: TextOutputs, + output: &ConversationOutput, + billing_scope: &str, + output_transcription_buffer: &mut String, + ) -> Result { + trace!(?event, "Gemini Live event"); + match event { + ServerEvent::SetupComplete => {} + ServerEvent::ModelText(text) => { + // This does not seem to work, even when we enable TEXT response modalities. + debug!(%text, "Gemini model text"); + } + ServerEvent::ModelAudio(audio) => { + let frame = AudioFrame::from_le_bytes(output_format, &audio); + output.audio_frame(frame)?; + } + ServerEvent::GenerationComplete => {} + ServerEvent::TurnComplete => { + if text_outputs.text && !output_transcription_buffer.is_empty() { + output.text( + true, + std::mem::take(output_transcription_buffer), + None, + Some(self.params.model.clone()), + )?; + } else { + output_transcription_buffer.clear(); + } + output.request_completed(None)?; + } + ServerEvent::Interrupted => { + output_transcription_buffer.clear(); + output.clear_audio()?; + } + ServerEvent::InputTranscription(text) => { + if text_outputs.text { + output.text(true, text, None, None)?; + } + } + ServerEvent::OutputTranscription(text) => { + output_transcription_buffer.push_str(&text); + if text_outputs.interim { + output.text( + false, + output_transcription_buffer.clone(), + None, + Some(self.params.model.clone()), + )?; + } + } + ServerEvent::ToolCall(calls) => { + for call in calls { + // Send the function call via the media path. + // + // This means that audio scheduled before will finish playing before + // the client receives the event to execute the function call. + // + // For example, if we use a prompt to initiate a function call, it + // might overtake currently pending audio output and an answer + // before the participant even heard the audio. + output.service_event( + OutputPath::Media, + ServiceOutputEvent::FunctionCall { + call_id: call.id, + name: call.name, + arguments: call.args, + }, + )?; + } + } + ServerEvent::ToolCallCancellation(ids) => { + // Since we are sending function calls through the media path, we need to send + // cancellations too, so that they don't overtake. + for id in ids { + output.service_event( + OutputPath::Media, + ServiceOutputEvent::ToolCallCancellation { call_id: id }, + )?; + } + } + ServerEvent::SessionResumption { .. } => {} + ServerEvent::GoAway { time_left } => { + debug!(?time_left, "GoAway received"); + } + ServerEvent::Usage(usage) => { + bill_usage(output, billing_scope, usage)?; + } + ServerEvent::Closed { reason } => { + if !reason.is_empty() { + debug!(%reason, "Endpoint signaled connection closure"); + } else { + debug!("Endpoint signaled connection closure without a reason"); + } + return Ok(FlowControl::End); + } + ServerEvent::Error(error) => { + bail!("Gemini Live error: {}", error.message); + } + } + Ok(FlowControl::Continue) + } +} + +fn function_declarations(tools: &[Tool]) -> Option> { + let declarations: Vec<_> = tools + .iter() + .filter_map(|tool| match tool { + Tool::FunctionDeclarations(declarations) => Some(declarations.as_slice()), + Tool::GoogleSearch(_) => None, + }) + .flatten() + .cloned() + .collect(); + + (!declarations.is_empty()).then_some(declarations) +} + +fn session_config(params: &Params, text_outputs: TextOutputs) -> Result { + let transport = TransportConfig { + endpoint: params + .host + .clone() + .map(Endpoint::Custom) + .unwrap_or_default(), + auth: Auth::ApiKey(params.api_key.clone()), + ..Default::default() + }; + + Ok(SessionConfig { + transport, + setup: setup_config(params, text_outputs)?, + reconnect: ReconnectPolicy::default(), + }) +} + +fn setup_config(params: &Params, text_outputs: TextOutputs) -> Result { + let input_audio_transcription = params + .input_audio_transcription + .then_some(AudioTranscriptionConfig {}); + let output_audio_transcription = params + .output_audio_transcription + .then_some(AudioTranscriptionConfig {}); + + if !(text_outputs.text || text_outputs.interim) + && (input_audio_transcription.is_some() || output_audio_transcription.is_some()) + { + bail!( + "Google dialog requires text output modality when transcription is enabled: if inputAudioTranscription or outputAudioTranscription is set, add OutputModality::Text or OutputModality::InterimText to the conversation output modalities, or set both transcription flags to false." + ); + } + + Ok(SetupConfig { + model: model_resource_name(¶ms.model), + generation_config: Some(GenerationConfig { + // NOTE: Enabling Modality::Text here currently causes a Gemini setup-time + // "Internal error encountered." in this service flow. + response_modalities: Some(vec![Modality::Audio]), + speech_config: params.voice.clone().map(|voice_name| SpeechConfig { + voice_config: VoiceConfig { + prebuilt_voice_config: PrebuiltVoiceConfig { voice_name }, + }, + }), + thinking_config: params.thinking_level.map(|thinking_level| ThinkingConfig { + thinking_level: Some(thinking_level), + ..Default::default() + }), + temperature: params.temperature, + ..Default::default() + }), + system_instruction: params.instructions.clone().map(system_instruction), + tools: (!params.tools.is_empty()).then(|| params.tools.clone()), + realtime_input_config: params.realtime_input_config.clone(), + // Opt in so Gemini sends resume handles. The session layer stores + // the latest handle and patches it into reconnect setup messages, + // keeping context across GoAway-triggered reconnects. + session_resumption: Some(SessionResumptionConfig::default()), + context_window_compression: params.context_window_compression.then_some( + ContextWindowCompressionConfig { + sliding_window: Some(SlidingWindow::default()), + ..Default::default() + }, + ), + input_audio_transcription, + output_audio_transcription, + ..Default::default() + }) +} + +fn model_resource_name(model: &str) -> String { + if model.starts_with("models/") { + model.to_owned() + } else { + format!("models/{model}") + } +} + +fn system_instruction(text: String) -> Content { + Content { + role: None, + parts: vec![Part { + text: Some(text), + inline_data: None, + }], + } +} + +fn bill_usage( + output: &ConversationOutput, + billing_scope: &str, + usage: UsageMetadata, +) -> Result<()> { + let prompt_audio_total = modality_count(&usage.prompt_tokens_details, "AUDIO"); + let prompt_text_total = modality_count(&usage.prompt_tokens_details, "TEXT"); + let cached_audio = modality_count(&usage.cache_tokens_details, "AUDIO"); + let cached_text = modality_count(&usage.cache_tokens_details, "TEXT"); + let tool_audio = modality_count(&usage.tool_use_prompt_tokens_details, "AUDIO"); + let tool_text = modality_count(&usage.tool_use_prompt_tokens_details, "TEXT"); + let response_audio = modality_count(&usage.response_tokens_details, "AUDIO"); + let response_text = modality_count(&usage.response_tokens_details, "TEXT"); + + let prompt_audio = prompt_audio_total + .checked_sub(cached_audio + tool_audio) + .context("Invalid Gemini usage: prompt audio tokens less than cached+tool audio tokens")?; + let prompt_text = prompt_text_total + .checked_sub(cached_text + tool_text) + .context("Invalid Gemini usage: prompt text tokens less than cached+tool text tokens")?; + + let records = [ + BillingRecord::count("tokens:input:audio", prompt_audio), + BillingRecord::count("tokens:input:text", prompt_text), + BillingRecord::count("tokens:input:audio:cached", cached_audio), + BillingRecord::count("tokens:input:text:cached", cached_text), + BillingRecord::count("tokens:input:audio:tool", tool_audio), + BillingRecord::count("tokens:input:text:tool", tool_text), + BillingRecord::count("tokens:output:audio", response_audio), + BillingRecord::count("tokens:output:text", response_text), + BillingRecord::count("tokens:thoughts", usage.thoughts_token_count as _), + ]; + + output.billing_records( + None, + Some(billing_scope.into()), + records, + BillingSchedule::Now, + )?; + Ok(()) +} + +fn modality_count(details: &Option>, modality: &str) -> usize { + details + .iter() + .flatten() + .filter(|detail| detail.modality.eq_ignore_ascii_case(modality)) + .map(|detail| detail.token_count as usize) + .sum() +} + +enum FlowControl { + Continue, + End, +} diff --git a/services/google-dialog/src/lib.rs b/services/google-dialog/src/lib.rs new file mode 100644 index 0000000..45292fc --- /dev/null +++ b/services/google-dialog/src/lib.rs @@ -0,0 +1,77 @@ +//! Gemini Live audio dialog service. + +use anyhow::{Result, bail}; +use async_trait::async_trait; +use tracing::info; + +use context_switch_core::{AudioFormat, Conversation, OutputModality, Service}; + +mod client; +mod types; + +use client::Client; +pub use types::{Params, ServiceInputEvent, ServiceOutputEvent}; + +#[derive(Debug)] +pub struct GoogleDialog; + +#[async_trait] +impl Service for GoogleDialog { + type Params = Params; + + async fn conversation(&self, params: Params, conversation: Conversation) -> Result<()> { + let _input_format = conversation.require_audio_input()?; + let output_format = conversation.require_one_audio_output()?; + let text_outputs = TextOutputs::from_modalities(&conversation.output_modalities)?; + + let expected_output = AudioFormat::new(1, gemini_live::audio::OUTPUT_SAMPLE_RATE); + if output_format != expected_output { + bail!( + "Audio output has the wrong format {:?}, expected: {:?}", + output_format, + expected_output + ); + } + + info!(model = %params.model, "Connecting to Gemini Live"); + let (input, output) = conversation.start()?; + Client::new(params) + .dialog(output_format, text_outputs, input, output) + .await + } +} + +#[derive(Debug, Clone, Copy)] +struct TextOutputs { + text: bool, + interim: bool, +} + +impl TextOutputs { + fn from_modalities(output_modalities: &[OutputModality]) -> Result { + let mut text_outputs = Self { + text: false, + interim: false, + }; + + for modality in output_modalities { + match modality { + OutputModality::Text => { + if text_outputs.text { + bail!("Expecting at most one text output") + } + text_outputs.text = true; + } + OutputModality::InterimText => { + if text_outputs.interim { + bail!("Expecting at most one interim text output") + } + text_outputs.interim = true; + } + OutputModality::Audio { .. } => {} + } + } + + Ok(text_outputs) + } +} diff --git a/services/google-dialog/src/types.rs b/services/google-dialog/src/types.rs new file mode 100644 index 0000000..6ee5e30 --- /dev/null +++ b/services/google-dialog/src/types.rs @@ -0,0 +1,87 @@ +use gemini_live::types::{FunctionDeclaration, RealtimeInputConfig, ThinkingLevel, Tool}; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Params { + pub api_key: String, + /// Gemini Live model name, with or without the `models/` prefix. + pub model: String, + pub host: Option, + pub instructions: Option, + pub voice: Option, + + /// Sampling temperature. Valid range: `0.0..=2.0`. + /// If omitted, Gemini uses the model-specific default temperature. + pub temperature: Option, + /// Gemini 3.1 thinking level (`minimal`, `low`, `medium`, or `high`). + /// In Live API, Gemini 3.1 defaults to `minimal` when omitted. + pub thinking_level: Option, + /// Enabled by default to avoid context-window exhaustion during long audio sessions. + #[serde(default = "default_context_window_compression")] + pub context_window_compression: bool, + #[serde(default)] + pub tools: Vec, + /// Gemini realtime input behavior, including VAD and turn coverage. + pub realtime_input_config: Option, + /// Enable server-side transcription of user input audio. + #[serde(default)] + pub input_audio_transcription: bool, + /// Enable server-side transcription of model output audio. + #[serde(default)] + pub output_audio_transcription: bool, +} + +impl Params { + pub fn new(api_key: impl Into, model: impl Into) -> Self { + Self { + api_key: api_key.into(), + model: model.into(), + host: None, + instructions: None, + voice: None, + temperature: None, + thinking_level: None, + context_window_compression: default_context_window_compression(), + tools: vec![], + realtime_input_config: None, + input_audio_transcription: false, + output_audio_transcription: false, + } + } +} + +fn default_context_window_compression() -> bool { + true +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ServiceInputEvent { + #[serde(rename_all = "camelCase")] + FunctionCallResult { + call_id: String, + name: String, + output: serde_json::Value, + }, + Prompt { + text: String, + }, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "camelCase")] +pub enum ServiceOutputEvent { + #[serde(rename_all = "camelCase")] + FunctionCall { + call_id: String, + name: String, + arguments: serde_json::Value, + }, + #[serde(rename_all = "camelCase")] + ToolCallCancellation { call_id: String }, + SessionUpdated { + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + }, +} diff --git a/services/google-transcribe/src/transcribe.rs b/services/google-transcribe/src/transcribe.rs index dc91808..397139b 100644 --- a/services/google-transcribe/src/transcribe.rs +++ b/services/google-transcribe/src/transcribe.rs @@ -10,9 +10,8 @@ use tokio::sync::mpsc::UnboundedReceiver; use tonic::Code; use context_switch_core::{ - AudioFormat, AudioFrame, AudioProducer, BillingRecord, OutputModality, Service, - conversation::{BillingSchedule, Conversation, ConversationOutput, Input}, - language::Languages, + AudioFormat, AudioFrame, AudioProducer, BillingRecord, BillingSchedule, Conversation, + ConversationOutput, Input, OutputModality, Service, language::Languages, }; use tracing::{info, warn}; @@ -176,7 +175,7 @@ async fn transcribe_and_process_stream_session( async fn process_stream_session( model: &str, include_detected_language: bool, - output: &context_switch_core::conversation::ConversationOutput, + output: &ConversationOutput, response_stream: S, ) -> Result where diff --git a/services/openai-dialog/src/client.rs b/services/openai-dialog/src/client.rs index c985f93..b304825 100644 --- a/services/openai-dialog/src/client.rs +++ b/services/openai-dialog/src/client.rs @@ -2,20 +2,12 @@ use std::collections::VecDeque; use anyhow::{Context, Result, bail}; use base64::prelude::*; -use context_switch_core::{ - AudioFormat, AudioFrame, BillingRecord, OutputPath, audio, - conversation::{BillingSchedule, ConversationInput, ConversationOutput, Input}, -}; -use futures::{ - SinkExt, StreamExt, - stream::{SplitSink, SplitStream}, -}; -use openai_api_rs::realtime::{ - client_event::{self, ClientEvent}, - server_event::{self, ServerEvent}, - types::{ - self, ItemContentType, ItemRole, ItemStatus, ItemType, OutputModality, ResponseStatus, - }, +use futures::stream::{SplitSink, SplitStream}; +use futures::{SinkExt, StreamExt}; +use openai_api_rs::realtime::client_event::{self, ClientEvent}; +use openai_api_rs::realtime::server_event::{self, ServerEvent}; +use openai_api_rs::realtime::types::{ + self, ItemContentType, ItemRole, ItemStatus, ItemType, OutputModality, ResponseStatus, }; use tokio::{net::TcpStream, select}; use tokio_tungstenite::{ @@ -26,6 +18,10 @@ use tracing::{debug, info, trace, warn}; use uuid::Uuid; use crate::{Params, ServiceInputEvent, ServiceOutputEvent}; +use context_switch_core::{ + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, ConversationInput, ConversationOutput, + Input, OutputPath, audio, +}; pub struct Client { read: SplitStream>>, diff --git a/services/openai-dialog/src/lib.rs b/services/openai-dialog/src/lib.rs index f3bd280..0cdbeee 100644 --- a/services/openai-dialog/src/lib.rs +++ b/services/openai-dialog/src/lib.rs @@ -6,7 +6,7 @@ use anyhow::{Result, bail}; use async_trait::async_trait; use tracing::info; -use context_switch_core::{Service, conversation::Conversation}; +use context_switch_core::{Conversation, Service}; mod client; mod host; diff --git a/services/playback/src/lib.rs b/services/playback/src/lib.rs index 6d8b0ad..74d149d 100644 --- a/services/playback/src/lib.rs +++ b/services/playback/src/lib.rs @@ -1,25 +1,19 @@ -use std::{ - fs::{self, File}, - io::{self, BufReader}, - num::{NonZeroU16, NonZeroU32}, - path::{Path, PathBuf}, -}; +use std::fs::{self, File}; +use std::io::{self, BufReader}; +use std::num::{NonZeroU16, NonZeroU32}; +use std::path::{Path, PathBuf}; -use anyhow::{Context, Result, bail}; +use anyhow::{Context, Result, anyhow, bail}; use async_trait::async_trait; -use context_switch_core::{BillingRecord, audio, conversation::BillingSchedule}; -use rodio::{ - Decoder, Source, - conversions::{ChannelCountConverter, SampleRateConverter}, -}; +use rodio::conversions::{ChannelCountConverter, SampleRateConverter}; +use rodio::{Decoder, Source}; use serde::{Deserialize, Serialize}; use tokio::task; use tracing::{debug, error}; use url::Url; use context_switch_core::{ - AudioFormat, AudioFrame, Service, - conversation::{Conversation, Input}, + AudioFormat, AudioFrame, BillingRecord, BillingSchedule, Conversation, Input, Service, audio, }; mod stream_reader; @@ -353,7 +347,7 @@ pub fn check_supported_audio_type( } else { let guessed_mime = mime_guess2::from_path(path) .first() - .ok_or_else(|| anyhow::anyhow!("Invalid audio url (should end in `.mp3` or `.wav`)"))?; + .ok_or_else(|| anyhow!("Invalid audio url (should end in `.mp3` or `.wav`)"))?; guessed_mime.essence_str().to_string() }; diff --git a/src/context_switch.rs b/src/context_switch.rs index d0d10e1..a2abd40 100644 --- a/src/context_switch.rs +++ b/src/context_switch.rs @@ -14,8 +14,7 @@ use tracing_futures::Instrument; use crate::{AudioTracer, ClientEvent, ConversationId, InputModality, ServerEvent}; use context_switch_core::billing_collector::BillingCollector; -use context_switch_core::conversation::{Conversation, Input, Output}; -use context_switch_core::{AudioFrame, BillingContext, Registry}; +use context_switch_core::{AudioFrame, BillingContext, Conversation, Input, Output, Registry}; #[derive(Debug)] pub struct ContextSwitch { @@ -44,6 +43,7 @@ pub fn registry() -> Registry { .add_service("elevenlabs-transcribe", elevenlabs::ElevenLabsTranscribe) .add_service("google-transcribe", google_transcribe::GoogleTranscribe) .add_service("openai-dialog", openai_dialog::OpenAIDialog) + .add_service("google-dialog", google_dialog::GoogleDialog) .add_service("aristech-transcribe", aristech::AristechTranscribe) .add_service("aristech-synthesize", aristech::AristechSynthesize) } @@ -225,7 +225,7 @@ async fn process_conversation_protected( let service = registry.service(&service_name)?; // Temporarily use an unbounded channel for output forwarding because we may process rather - // large audio files (local playback for example) in one go are are not yet able to block sends. + // large audio files (local playback for example) in one go and are not able to block sends. let (output_sender, mut output_receiver) = unbounded_channel(); // We might receive a large number of audio frames before the service can process them. let (input_sender, input_receiver) = channel(256); diff --git a/src/lib.rs b/src/lib.rs index 2cdc8e7..51e3862 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,5 +15,6 @@ pub mod services { pub use aristech::AristechTranscribe; pub use azure::AzureTranscribe; pub use elevenlabs::ElevenLabsTranscribe; + pub use google_dialog::GoogleDialog; pub use google_transcribe::GoogleTranscribe; } diff --git a/src/protocol.rs b/src/protocol.rs index 3dbe8fa..a3f555d 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -3,8 +3,7 @@ use derive_more::derive::{Deref, Display, From, Into}; use serde::{Deserialize, Deserializer, Serialize, Serializer}; use context_switch_core::{ - BillingRecord, InputModality, OutputModality, OutputPath, audio, - conversation::{BillingId, RequestId}, + BillingId, BillingRecord, InputModality, OutputModality, OutputPath, RequestId, audio, }; /// Conversation identifier. diff --git a/src/tests.rs b/src/tests.rs index 807921d..8302536 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -103,9 +103,10 @@ mod helper { use anyhow::Result; use async_trait::async_trait; - use tokio::{sync::mpsc::Sender, time}; + use tokio::sync::mpsc::Sender; + use tokio::time; - use context_switch_core::{Service, conversation::Conversation}; + use context_switch_core::{Conversation, Service}; #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Notification {