From 203ba6e073989f50211c935ab472da3c65b878c4 Mon Sep 17 00:00:00 2001 From: Paul Thurlow Date: Mon, 30 Mar 2026 11:23:12 -0700 Subject: [PATCH 1/3] feat(search): add basic vector search (l2_distance) --- src/command.rs | 10 ++-- src/embedding.rs | 118 +++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 59 +++++++++++++++++++----- src/query.rs | 14 +++++- src/table.rs | 16 +++++++ 5 files changed, 201 insertions(+), 16 deletions(-) create mode 100644 src/embedding.rs diff --git a/src/command.rs b/src/command.rs index 01d10b6..b6e6a41 100644 --- a/src/command.rs +++ b/src/command.rs @@ -115,10 +115,10 @@ pub enum Commands { command: IndexesCommands, }, - /// Full-text search across a table column + /// Full-text or vector search across a table column Search { - /// Search query text - query: String, + /// Search query text (omit to read a vector from stdin for vector search) + query: Option, /// Table to search (connection.schema.table) #[arg(long)] @@ -136,6 +136,10 @@ pub enum Commands { #[arg(long, default_value = "10")] limit: u32, + /// Embedding model to generate a vector from the query text (e.g. text-embedding-3-small) + #[arg(long, value_parser = ["text-embedding-3-small", "text-embedding-3-large"])] + model: Option, + /// Workspace ID (defaults to first workspace from login) #[arg(long, short = 'w')] workspace_id: Option, diff --git a/src/embedding.rs b/src/embedding.rs new file mode 100644 index 0000000..11ae4d4 --- /dev/null +++ b/src/embedding.rs @@ -0,0 +1,118 @@ +use serde_json::Value; + +/// Try to parse a vector from stdin. Accepts either: +/// - A raw JSON array of numbers: [0.1, -0.2, ...] +/// - An OpenAI-compatible response: {"data": [{"embedding": [...]}]} +pub fn read_vector_from_stdin() -> Vec { + use std::io::Read; + let mut input = String::new(); + std::io::stdin().read_to_string(&mut input).unwrap_or_else(|e| { + eprintln!("error reading stdin: {e}"); + std::process::exit(1); + }); + + let input = input.trim(); + if input.is_empty() { + eprintln!("error: no vector provided on stdin"); + std::process::exit(1); + } + + let parsed: Value = match serde_json::from_str(input) { + Ok(v) => v, + Err(e) => { + eprintln!("error parsing vector from stdin: {e}"); + std::process::exit(1); + } + }; + + extract_vector(&parsed) +} + +/// Extract a float vector from either a raw JSON array or an OpenAI embedding response. +fn extract_vector(value: &Value) -> Vec { + // Raw array: [0.1, -0.2, ...] + if let Some(arr) = value.as_array() { + return parse_float_array(arr); + } + + // OpenAI response: {"data": [{"embedding": [...]}]} + if let Some(embedding) = value.get("data") + .and_then(|d| d.get(0)) + .and_then(|d| d.get("embedding")) + .and_then(|e| e.as_array()) + { + return parse_float_array(embedding); + } + + eprintln!("error: stdin must be a JSON array of numbers or an OpenAI embedding response"); + std::process::exit(1); +} + +fn parse_float_array(arr: &[Value]) -> Vec { + arr.iter() + .enumerate() + .map(|(i, v)| { + v.as_f64().unwrap_or_else(|| { + eprintln!("error: vector element {i} is not a number: {v}"); + std::process::exit(1); + }) + }) + .collect() +} + +/// Call the OpenAI embeddings API to generate a vector from text. +pub fn openai_embed(text: &str, model: &str) -> Vec { + let api_key = match std::env::var("OPENAI_API_KEY") { + Ok(k) if !k.is_empty() => k, + _ => { + eprintln!("error: OPENAI_API_KEY environment variable is not set"); + std::process::exit(1); + } + }; + + let body = serde_json::json!({ + "input": text, + "model": model, + }); + + let client = reqwest::blocking::Client::new(); + let resp = match client + .post("https://api.openai.com/v1/embeddings") + .header("Authorization", format!("Bearer {api_key}")) + .header("Content-Type", "application/json") + .json(&body) + .send() + { + Ok(r) => r, + Err(e) => { + eprintln!("error connecting to OpenAI API: {e}"); + std::process::exit(1); + } + }; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().unwrap_or_default(); + let message = serde_json::from_str::(&body) + .ok() + .and_then(|v| v["error"]["message"].as_str().map(str::to_string)) + .unwrap_or(body); + eprintln!("error from OpenAI API ({status}): {message}"); + std::process::exit(1); + } + + let parsed: Value = match resp.json() { + Ok(v) => v, + Err(e) => { + eprintln!("error parsing OpenAI response: {e}"); + std::process::exit(1); + } + }; + + extract_vector(&parsed) +} + +/// Format a vector as a SQL ARRAY literal: ARRAY[0.1,-0.2,...] +pub fn vector_to_sql(vec: &[f64]) -> String { + format!("ARRAY[{}]", vec.iter().map(|v| v.to_string()).collect::>().join(",")) +} diff --git a/src/main.rs b/src/main.rs index 60a4e31..0756348 100644 --- a/src/main.rs +++ b/src/main.rs @@ -5,6 +5,7 @@ mod config; mod connections; mod connections_new; mod datasets; +mod embedding; mod indexes; mod jobs; mod queries; @@ -218,20 +219,54 @@ fn main() { } } } - Commands::Search { query, table, column, select, limit, workspace_id, output } => { + Commands::Search { query, table, column, select, limit, model, workspace_id, output } => { let workspace_id = resolve_workspace(workspace_id); - let columns = match select.as_deref() { - Some(cols) => format!("{}, score", cols), - None => "*".to_string(), + let select_cols = select.as_deref().unwrap_or("*"); + + // Determine search mode: + // 1. --model flag: embed the query text via the model provider + // 2. No query + piped stdin: read vector from stdin + // 3. Query text without --model: BM25 text search + let sql = if let Some(ref model_name) = model { + let query_text = match query { + Some(ref q) => q.as_str(), + None => { + eprintln!("error: --model requires a search query text"); + std::process::exit(1); + } + }; + let vec = embedding::openai_embed(query_text, model_name); + let vec_str = embedding::vector_to_sql(&vec); + format!( + "SELECT {}, l2_distance({}, {}) as dist FROM {} ORDER BY dist LIMIT {}", + select_cols, column, vec_str, table, limit, + ) + } else if query.is_none() { + use std::io::IsTerminal; + if std::io::stdin().is_terminal() { + eprintln!("error: provide a search query or pipe a vector via stdin"); + std::process::exit(1); + } + let vec = embedding::read_vector_from_stdin(); + let vec_str = embedding::vector_to_sql(&vec); + format!( + "SELECT {}, l2_distance({}, {}) as dist FROM {} ORDER BY dist LIMIT {}", + select_cols, column, vec_str, table, limit, + ) + } else { + let bm25_columns = match select.as_deref() { + Some(cols) => format!("{}, score", cols), + None => "*".to_string(), + }; + format!( + "SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}", + bm25_columns, + table.replace('\'', "''"), + column.replace('\'', "''"), + query.unwrap().replace('\'', "''"), + limit, + ) }; - let sql = format!( - "SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}", - columns, - table.replace('\'', "''"), - column.replace('\'', "''"), - query.replace('\'', "''"), - limit, - ); query::execute(&sql, &workspace_id, None, &output) } Commands::Queries { id, output, command } => { diff --git a/src/query.rs b/src/query.rs index b1a8e55..ed96119 100644 --- a/src/query.rs +++ b/src/query.rs @@ -13,13 +13,25 @@ pub struct QueryResponse { pub warning: Option, } +fn format_array(arr: &[Value]) -> String { + let is_numeric = arr.iter().all(|v| v.is_number()); + if is_numeric && arr.len() > 6 { + let head: Vec = arr[..3].iter().map(|v| v.to_string()).collect(); + let tail: Vec = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect(); + format!("[{}, ... {}, {}]", head.join(", "), tail.join(", "), format!("({} total)", arr.len())) + } else { + format!("[{}]", arr.iter().map(|v| value_to_string(v)).collect::>().join(", ")) + } +} + fn value_to_string(v: &Value) -> String { match v { Value::Null => "NULL".to_string(), Value::Bool(b) => b.to_string(), Value::Number(n) => n.to_string(), Value::String(s) => s.clone(), - Value::Array(_) | Value::Object(_) => v.to_string(), + Value::Array(arr) => format_array(arr), + Value::Object(_) => v.to_string(), } } diff --git a/src/table.rs b/src/table.rs index 7dc4306..d7a8973 100644 --- a/src/table.rs +++ b/src/table.rs @@ -5,6 +5,18 @@ use tabled::settings::{ width::Width, }; +/// Truncate numeric arrays to first 3 + last 3 when over 6 elements. +fn format_array(arr: &[serde_json::Value]) -> String { + let is_numeric = arr.iter().all(|v| v.is_number()); + if is_numeric && arr.len() > 6 { + let head: Vec = arr[..3].iter().map(|v| v.to_string()).collect(); + let tail: Vec = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect(); + format!("[{}, ..., {}] ({} items)", head.join(", "), tail.join(", "), arr.len()) + } else { + format!("[{}]", arr.iter().map(|v| v.to_string()).collect::>().join(", ")) + } +} + fn term_width() -> usize { crossterm::terminal::size() .map(|(w, _)| w as usize) @@ -108,6 +120,10 @@ pub fn print_json(headers: &[String], rows: &[Vec]) { colored_cells.push((ri + 1, ci, Color::FG_YELLOW)); b.to_string() } + serde_json::Value::Array(arr) => { + colored_cells.push((ri + 1, ci, Color::FG_BRIGHT_BLACK)); + format_array(arr) + } _ => v.as_str().map(str::to_string).unwrap_or_else(|| v.to_string()), } }) From c6c2dfaabe8ee989672e8468522d91ef227077a7 Mon Sep 17 00:00:00 2001 From: Paul Thurlow Date: Mon, 30 Mar 2026 11:50:11 -0700 Subject: [PATCH 2/3] tweak table formatting --- src/query.rs | 5 ++--- src/table.rs | 13 +++++-------- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/query.rs b/src/query.rs index ed96119..5f0cf4e 100644 --- a/src/query.rs +++ b/src/query.rs @@ -14,11 +14,10 @@ pub struct QueryResponse { } fn format_array(arr: &[Value]) -> String { - let is_numeric = arr.iter().all(|v| v.is_number()); - if is_numeric && arr.len() > 6 { + if arr.len() > 6 { let head: Vec = arr[..3].iter().map(|v| v.to_string()).collect(); let tail: Vec = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect(); - format!("[{}, ... {}, {}]", head.join(", "), tail.join(", "), format!("({} total)", arr.len())) + format!("[{}, ..., {}] ({} items)", head.join(", "), tail.join(", "), arr.len()) } else { format!("[{}]", arr.iter().map(|v| value_to_string(v)).collect::>().join(", ")) } diff --git a/src/table.rs b/src/table.rs index d7a8973..4753cbd 100644 --- a/src/table.rs +++ b/src/table.rs @@ -5,13 +5,13 @@ use tabled::settings::{ width::Width, }; -/// Truncate numeric arrays to first 3 + last 3 when over 6 elements. +/// Truncate arrays to first 3 + last 3 when over 6 elements. fn format_array(arr: &[serde_json::Value]) -> String { - let is_numeric = arr.iter().all(|v| v.is_number()); - if is_numeric && arr.len() > 6 { + use crossterm::style::Stylize; + if arr.len() > 6 { let head: Vec = arr[..3].iter().map(|v| v.to_string()).collect(); let tail: Vec = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect(); - format!("[{}, ..., {}] ({} items)", head.join(", "), tail.join(", "), arr.len()) + format!("[{}, ..., {}] {}", head.join(", "), tail.join(", "), format!("({} items)", arr.len()).dark_grey()) } else { format!("[{}]", arr.iter().map(|v| v.to_string()).collect::>().join(", ")) } @@ -120,10 +120,7 @@ pub fn print_json(headers: &[String], rows: &[Vec]) { colored_cells.push((ri + 1, ci, Color::FG_YELLOW)); b.to_string() } - serde_json::Value::Array(arr) => { - colored_cells.push((ri + 1, ci, Color::FG_BRIGHT_BLACK)); - format_array(arr) - } + serde_json::Value::Array(arr) => format_array(arr), _ => v.as_str().map(str::to_string).unwrap_or_else(|| v.to_string()), } }) From 8bb5379263b172cd2f892ab763ebc58740ddca37 Mon Sep 17 00:00:00 2001 From: Paul Thurlow Date: Mon, 30 Mar 2026 12:04:39 -0700 Subject: [PATCH 3/3] simplify table array formatting logic --- src/query.rs | 18 +++++++----------- src/table.rs | 18 ++++++++++++++---- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/src/query.rs b/src/query.rs index 5f0cf4e..4cbb386 100644 --- a/src/query.rs +++ b/src/query.rs @@ -13,23 +13,19 @@ pub struct QueryResponse { pub warning: Option, } -fn format_array(arr: &[Value]) -> String { - if arr.len() > 6 { - let head: Vec = arr[..3].iter().map(|v| v.to_string()).collect(); - let tail: Vec = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect(); - format!("[{}, ..., {}] ({} items)", head.join(", "), tail.join(", "), arr.len()) - } else { - format!("[{}]", arr.iter().map(|v| value_to_string(v)).collect::>().join(", ")) - } -} - fn value_to_string(v: &Value) -> String { match v { Value::Null => "NULL".to_string(), Value::Bool(b) => b.to_string(), Value::Number(n) => n.to_string(), Value::String(s) => s.clone(), - Value::Array(arr) => format_array(arr), + Value::Array(arr) => { + let (formatted, count) = crate::table::truncate_array(arr); + match count { + Some(n) => format!("{formatted} ({n} items)"), + None => formatted, + } + } Value::Object(_) => v.to_string(), } } diff --git a/src/table.rs b/src/table.rs index 4753cbd..f2935ad 100644 --- a/src/table.rs +++ b/src/table.rs @@ -6,14 +6,24 @@ use tabled::settings::{ }; /// Truncate arrays to first 3 + last 3 when over 6 elements. -fn format_array(arr: &[serde_json::Value]) -> String { - use crossterm::style::Stylize; +/// Returns (formatted_values, total_count) where total_count is Some when truncated. +pub fn truncate_array(arr: &[serde_json::Value]) -> (String, Option) { if arr.len() > 6 { let head: Vec = arr[..3].iter().map(|v| v.to_string()).collect(); let tail: Vec = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect(); - format!("[{}, ..., {}] {}", head.join(", "), tail.join(", "), format!("({} items)", arr.len()).dark_grey()) + (format!("[{}, ..., {}]", head.join(", "), tail.join(", ")), Some(arr.len())) } else { - format!("[{}]", arr.iter().map(|v| v.to_string()).collect::>().join(", ")) + (format!("[{}]", arr.iter().map(|v| v.to_string()).collect::>().join(", ")), None) + } +} + +/// Format an array for styled table output. +fn format_array(arr: &[serde_json::Value]) -> String { + use crossterm::style::Stylize; + let (formatted, count) = truncate_array(arr); + match count { + Some(n) => format!("{formatted} {}", format!("({n} items)").dark_grey()), + None => formatted, } }