diff --git a/src/command.rs b/src/command.rs index 01d10b6..b6e6a41 100644 --- a/src/command.rs +++ b/src/command.rs @@ -115,10 +115,10 @@ pub enum Commands { command: IndexesCommands, }, - /// Full-text search across a table column + /// Full-text or vector search across a table column Search { - /// Search query text - query: String, + /// Search query text (omit to read a vector from stdin for vector search) + query: Option, /// Table to search (connection.schema.table) #[arg(long)] @@ -136,6 +136,10 @@ pub enum Commands { #[arg(long, default_value = "10")] limit: u32, + /// Embedding model to generate a vector from the query text (e.g. text-embedding-3-small) + #[arg(long, value_parser = ["text-embedding-3-small", "text-embedding-3-large"])] + model: Option, + /// Workspace ID (defaults to first workspace from login) #[arg(long, short = 'w')] workspace_id: Option, diff --git a/src/embedding.rs b/src/embedding.rs new file mode 100644 index 0000000..11ae4d4 --- /dev/null +++ b/src/embedding.rs @@ -0,0 +1,118 @@ +use serde_json::Value; + +/// Try to parse a vector from stdin. Accepts either: +/// - A raw JSON array of numbers: [0.1, -0.2, ...] +/// - An OpenAI-compatible response: {"data": [{"embedding": [...]}]} +pub fn read_vector_from_stdin() -> Vec { + use std::io::Read; + let mut input = String::new(); + std::io::stdin().read_to_string(&mut input).unwrap_or_else(|e| { + eprintln!("error reading stdin: {e}"); + std::process::exit(1); + }); + + let input = input.trim(); + if input.is_empty() { + eprintln!("error: no vector provided on stdin"); + std::process::exit(1); + } + + let parsed: Value = match serde_json::from_str(input) { + Ok(v) => v, + Err(e) => { + eprintln!("error parsing vector from stdin: {e}"); + std::process::exit(1); + } + }; + + extract_vector(&parsed) +} + +/// Extract a float vector from either a raw JSON array or an OpenAI embedding response. +fn extract_vector(value: &Value) -> Vec { + // Raw array: [0.1, -0.2, ...] + if let Some(arr) = value.as_array() { + return parse_float_array(arr); + } + + // OpenAI response: {"data": [{"embedding": [...]}]} + if let Some(embedding) = value.get("data") + .and_then(|d| d.get(0)) + .and_then(|d| d.get("embedding")) + .and_then(|e| e.as_array()) + { + return parse_float_array(embedding); + } + + eprintln!("error: stdin must be a JSON array of numbers or an OpenAI embedding response"); + std::process::exit(1); +} + +fn parse_float_array(arr: &[Value]) -> Vec { + arr.iter() + .enumerate() + .map(|(i, v)| { + v.as_f64().unwrap_or_else(|| { + eprintln!("error: vector element {i} is not a number: {v}"); + std::process::exit(1); + }) + }) + .collect() +} + +/// Call the OpenAI embeddings API to generate a vector from text. +pub fn openai_embed(text: &str, model: &str) -> Vec { + let api_key = match std::env::var("OPENAI_API_KEY") { + Ok(k) if !k.is_empty() => k, + _ => { + eprintln!("error: OPENAI_API_KEY environment variable is not set"); + std::process::exit(1); + } + }; + + let body = serde_json::json!({ + "input": text, + "model": model, + }); + + let client = reqwest::blocking::Client::new(); + let resp = match client + .post("https://api.openai.com/v1/embeddings") + .header("Authorization", format!("Bearer {api_key}")) + .header("Content-Type", "application/json") + .json(&body) + .send() + { + Ok(r) => r, + Err(e) => { + eprintln!("error connecting to OpenAI API: {e}"); + std::process::exit(1); + } + }; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().unwrap_or_default(); + let message = serde_json::from_str::(&body) + .ok() + .and_then(|v| v["error"]["message"].as_str().map(str::to_string)) + .unwrap_or(body); + eprintln!("error from OpenAI API ({status}): {message}"); + std::process::exit(1); + } + + let parsed: Value = match resp.json() { + Ok(v) => v, + Err(e) => { + eprintln!("error parsing OpenAI response: {e}"); + std::process::exit(1); + } + }; + + extract_vector(&parsed) +} + +/// Format a vector as a SQL ARRAY literal: ARRAY[0.1,-0.2,...] +pub fn vector_to_sql(vec: &[f64]) -> String { + format!("ARRAY[{}]", vec.iter().map(|v| v.to_string()).collect::>().join(",")) +} diff --git a/src/main.rs b/src/main.rs index 60a4e31..0756348 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ mod config; mod connections; mod connections_new; mod datasets; +mod embedding; mod indexes; mod jobs; mod queries; @@ -218,20 +219,54 @@ fn main() { } } } - Commands::Search { query, table, column, select, limit, workspace_id, output } => { + Commands::Search { query, table, column, select, limit, model, workspace_id, output } => { let workspace_id = resolve_workspace(workspace_id); - let columns = match select.as_deref() { - Some(cols) => format!("{}, score", cols), - None => "*".to_string(), + let select_cols = select.as_deref().unwrap_or("*"); + + // Determine search mode: + // 1. --model flag: embed the query text via the model provider + // 2. No query + piped stdin: read vector from stdin + // 3. Query text without --model: BM25 text search + let sql = if let Some(ref model_name) = model { + let query_text = match query { + Some(ref q) => q.as_str(), + None => { + eprintln!("error: --model requires a search query text"); + std::process::exit(1); + } + }; + let vec = embedding::openai_embed(query_text, model_name); + let vec_str = embedding::vector_to_sql(&vec); + format!( + "SELECT {}, l2_distance({}, {}) as dist FROM {} ORDER BY dist LIMIT {}", + select_cols, column, vec_str, table, limit, + ) + } else if query.is_none() { + use std::io::IsTerminal; + if std::io::stdin().is_terminal() { + eprintln!("error: provide a search query or pipe a vector via stdin"); + std::process::exit(1); + } + let vec = embedding::read_vector_from_stdin(); + let vec_str = embedding::vector_to_sql(&vec); + format!( + "SELECT {}, l2_distance({}, {}) as dist FROM {} ORDER BY dist LIMIT {}", + select_cols, column, vec_str, table, limit, + ) + } else { + let bm25_columns = match select.as_deref() { + Some(cols) => format!("{}, score", cols), + None => "*".to_string(), + }; + format!( + "SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}", + bm25_columns, + table.replace('\'', "''"), + column.replace('\'', "''"), + query.unwrap().replace('\'', "''"), + limit, + ) }; - let sql = format!( - "SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}", - columns, - table.replace('\'', "''"), - column.replace('\'', "''"), - query.replace('\'', "''"), - limit, - ); query::execute(&sql, &workspace_id, None, &output) } Commands::Queries { id, output, command } => { diff --git a/src/query.rs b/src/query.rs index b1a8e55..4cbb386 100644 --- a/src/query.rs +++ b/src/query.rs @@ -19,7 +19,14 @@ fn value_to_string(v: &Value) -> String { Value::Bool(b) => b.to_string(), Value::Number(n) => n.to_string(), Value::String(s) => s.clone(), - Value::Array(_) | Value::Object(_) => v.to_string(), + Value::Array(arr) => { + let (formatted, count) = crate::table::truncate_array(arr); + match count { + Some(n) => format!("{formatted} ({n} items)"), + None => formatted, + } + } + Value::Object(_) => v.to_string(), } } diff --git a/src/table.rs b/src/table.rs index 7dc4306..f2935ad 100644 --- a/src/table.rs +++ b/src/table.rs @@ -5,6 +5,28 @@ use tabled::settings::{ width::Width, }; +/// Truncate arrays to first 3 + last 3 when over 6 elements. +/// Returns (formatted_values, total_count) where total_count is Some when truncated. +pub fn truncate_array(arr: &[serde_json::Value]) -> (String, Option) { + if arr.len() > 6 { + let head: Vec = arr[..3].iter().map(|v| v.to_string()).collect(); + let tail: Vec = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect(); + (format!("[{}, ..., {}]", head.join(", "), tail.join(", ")), Some(arr.len())) + } else { + (format!("[{}]", arr.iter().map(|v| v.to_string()).collect::>().join(", ")), None) + } +} + +/// Format an array for styled table output. +fn format_array(arr: &[serde_json::Value]) -> String { + use crossterm::style::Stylize; + let (formatted, count) = truncate_array(arr); + match count { + Some(n) => format!("{formatted} {}", format!("({n} items)").dark_grey()), + None => formatted, + } +} + fn term_width() -> usize { crossterm::terminal::size() .map(|(w, _)| w as usize) @@ -108,6 +130,7 @@ pub fn print_json(headers: &[String], rows: &[Vec]) { colored_cells.push((ri + 1, ci, Color::FG_YELLOW)); b.to_string() } + serde_json::Value::Array(arr) => format_array(arr), _ => v.as_str().map(str::to_string).unwrap_or_else(|| v.to_string()), } })