Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# OpenAI Configuration
OPENAI_API_KEY=your_openai_key
OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview
OPENAI_REALTIME_ENDPOINT=

# Aristech
ARISTECH_ENDPOINT=
Expand Down
2 changes: 2 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ tokio = { workspace = true, features = ["rt-multi-thread"] }
openai-api-rs = { workspace = true }
serde_json = { workspace = true }
chrono-tz = { version = "0.10.3" }
url = { workspace = true }
strum = { version = "0.28" }


# For recognizing audio files in azure-transcribe.
Expand Down
8 changes: 8 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ Configure the services by setting the appropriate environment variables in your
# OpenAI Configuration
OPENAI_API_KEY=your_openai_key
OPENAI_REALTIME_API_MODEL=gpt-4o-mini-realtime-preview
OPENAI_REALTIME_ENDPOINT=

# Azure Configuration
AZURE_SUBSCRIPTION_KEY=your_azure_key
Expand All @@ -105,6 +106,13 @@ ELEVENLABS_API_KEY=your_elevenlabs_key
AUDIO_KNIFE_ADDRESS=127.0.0.1:8123
```

For Azure OpenAI realtime endpoints (`*.openai.azure.com`), the realtime client automatically appends
`api-key` as a query parameter to the websocket URL. For other hosts, it uses the standard
`Authorization: Bearer ...` header.

The websocket client does not follow redirects. If the endpoint responds with `3xx` (for example
`302 Found`), update the configured endpoint URL to the final websocket target.

## License

[MIT License](LICENSE)
50 changes: 49 additions & 1 deletion examples/openai-dialog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ use std::{

use anyhow::{Context, Result, bail};
use chrono::Utc;
use clap::builder::{PossibleValuesParser, TypedValueParser};
use clap::{Parser, ValueEnum};
use context_switch::{InputModality, OutputModality};
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use openai_api_rs::realtime::types;
use openai_dialog::{OpenAIDialog, ServiceInputEvent, ServiceOutputEvent};
use openai_dialog::{OpenAIDialog, Protocol, ServiceInputEvent, ServiceOutputEvent};
use rodio::{DeviceSinkBuilder, Player, Source};
use serde_json::json;
use tokio::{
Expand All @@ -27,8 +29,48 @@ use context_switch_core::{
conversation::{Conversation, Input, Output},
};

// Command-line options for the OpenAI dialog example, parsed with clap's derive API.
//
// NOTE: plain `//` comments are used here on purpose. clap's derive turns `///` doc
// comments on the struct and its fields into `--help` text, which would change the
// program's output.
#[derive(Debug, Parser)]
struct Cli {
// `--protocol <openai|azure>`: optional; converted into `openai_dialog::Protocol`
// and stored in `params.protocol` by `main`.
#[arg(long, value_enum)]
protocol: Option<CliProtocol>,
// `--endpoint <STRING>`: optional realtime endpoint. `main` falls back to the
// `OPENAI_REALTIME_ENDPOINT` environment variable when this is absent and discards
// values that are empty after trimming.
#[arg(long)]
endpoint: Option<String>,
// `--voice <NAME>`: optional voice, validated and converted by
// `realtime_voice_value_parser()`.
#[arg(long, value_parser = realtime_voice_value_parser())]
voice: Option<types::RealtimeVoice>,
}

// Protocol choices accepted by `--protocol`; a CLI-side mirror of
// `openai_dialog::Protocol` (see the `From<CliProtocol> for Protocol` impl).
//
// NOTE: plain `//` comments are used on purpose. With `ValueEnum`, `///` doc comments
// on variants become help text for the possible values.
#[derive(Debug, Clone, Copy, ValueEnum)]
enum CliProtocol {
// Explicit name so the flag value is spelled `openai`.
// NOTE(review): presumably needed because clap's default kebab-case renaming would
// produce `open-ai` — confirm against the clap version in use.
#[value(name = "openai")]
OpenAI,
Azure,
}

/// Builds the clap value parser used for the `--voice` argument.
///
/// The parser only accepts the names listed in `RealtimeVoice`'s
/// `strum::VariantNames::VARIANTS` and then converts the accepted string into a
/// `types::RealtimeVoice` via [`parse_realtime_voice_value`]. A conversion failure is
/// reported as a message naming the offending value.
fn realtime_voice_value_parser() -> impl TypedValueParser<Value = types::RealtimeVoice> {
    let accepted_names = <types::RealtimeVoice as strum::VariantNames>::VARIANTS;

    PossibleValuesParser::new(accepted_names).try_map(|raw| {
        match parse_realtime_voice_value(&raw) {
            Ok(voice) => Ok(voice),
            Err(cause) => Err(format!("Invalid voice value `{raw}`: {cause}")),
        }
    })
}

/// Converts a voice name into a `types::RealtimeVoice` using the type's `FromStr`
/// implementation.
///
/// # Errors
///
/// Returns the `strum::ParseError` produced by `FromStr` when `value` does not name a
/// voice.
fn parse_realtime_voice_value(value: &str) -> Result<types::RealtimeVoice, strum::ParseError> {
    // `str::parse` dispatches to the same `FromStr` impl as `RealtimeVoice::from_str`.
    value.parse::<types::RealtimeVoice>()
}

/// Maps the command-line protocol selection onto the `openai_dialog` protocol type.
impl From<CliProtocol> for Protocol {
    fn from(value: CliProtocol) -> Self {
        // Kept exhaustive (no wildcard arm) so that adding a CLI variant forces an
        // explicit mapping here.
        match value {
            CliProtocol::Azure => Self::Azure,
            CliProtocol::OpenAI => Self::OpenAI,
        }
    }
}

#[tokio::main]
async fn main() -> Result<()> {
let cli = Cli::parse();
dotenvy::dotenv_override().context("Reading .env file")?;
tracing_subscriber::fmt::init();

Expand Down Expand Up @@ -76,6 +118,12 @@ async fn main() -> Result<()> {

let openai = OpenAIDialog;
let mut params = openai_dialog::Params::new(key, model);
params.host = cli
.endpoint
.or_else(|| env::var("OPENAI_REALTIME_ENDPOINT").ok())
.filter(|endpoint| !endpoint.trim().is_empty());
params.protocol = cli.protocol.map(Into::into);
params.voice = cli.voice;
params.tools.push(get_time_function_definition());

let (output_sender, output_receiver) = unbounded_channel();
Expand Down
1 change: 1 addition & 0 deletions services/openai-dialog/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ base64 = { workspace = true }
serde = { workspace = true }
async-trait = { workspace = true }
uuid = { workspace = true }
url = { workspace = true }
Loading
Loading