Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/command.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,10 @@ pub enum Commands {
command: IndexesCommands,
},

/// Full-text search across a table column
/// Full-text or vector search across a table column
Search {
/// Search query text
query: String,
/// Search query text (omit to read a vector from stdin for vector search)
query: Option<String>,

/// Table to search (connection.schema.table)
#[arg(long)]
Expand All @@ -136,6 +136,10 @@ pub enum Commands {
#[arg(long, default_value = "10")]
limit: u32,

/// Embedding model to generate a vector from the query text (e.g. text-embedding-3-small)
#[arg(long, value_parser = ["text-embedding-3-small", "text-embedding-3-large"])]
model: Option<String>,

/// Workspace ID (defaults to first workspace from login)
#[arg(long, short = 'w')]
workspace_id: Option<String>,
Expand Down
118 changes: 118 additions & 0 deletions src/embedding.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
use serde_json::Value;

/// Try to parse a vector from stdin. Accepts either:
/// - A raw JSON array of numbers: [0.1, -0.2, ...]
/// - An OpenAI-compatible response: {"data": [{"embedding": [...]}]}
pub fn read_vector_from_stdin() -> Vec<f64> {
use std::io::Read;
let mut input = String::new();
std::io::stdin().read_to_string(&mut input).unwrap_or_else(|e| {
eprintln!("error reading stdin: {e}");
std::process::exit(1);
});

let input = input.trim();
if input.is_empty() {
eprintln!("error: no vector provided on stdin");
std::process::exit(1);
}

let parsed: Value = match serde_json::from_str(input) {
Ok(v) => v,
Err(e) => {
eprintln!("error parsing vector from stdin: {e}");
std::process::exit(1);
}
};

extract_vector(&parsed)
}

/// Extract a float vector from either a raw JSON array or an OpenAI embedding response.
fn extract_vector(value: &Value) -> Vec<f64> {
// Raw array: [0.1, -0.2, ...]
if let Some(arr) = value.as_array() {
return parse_float_array(arr);
}

// OpenAI response: {"data": [{"embedding": [...]}]}
if let Some(embedding) = value.get("data")
.and_then(|d| d.get(0))
.and_then(|d| d.get("embedding"))
.and_then(|e| e.as_array())
{
return parse_float_array(embedding);
}

eprintln!("error: stdin must be a JSON array of numbers or an OpenAI embedding response");
std::process::exit(1);
}

fn parse_float_array(arr: &[Value]) -> Vec<f64> {
arr.iter()
.enumerate()
.map(|(i, v)| {
v.as_f64().unwrap_or_else(|| {
eprintln!("error: vector element {i} is not a number: {v}");
std::process::exit(1);
})
})
.collect()
}

/// Call the OpenAI embeddings API to generate a vector from text.
pub fn openai_embed(text: &str, model: &str) -> Vec<f64> {
let api_key = match std::env::var("OPENAI_API_KEY") {
Ok(k) if !k.is_empty() => k,
_ => {
eprintln!("error: OPENAI_API_KEY environment variable is not set");
std::process::exit(1);
}
};

let body = serde_json::json!({
"input": text,
"model": model,
});

let client = reqwest::blocking::Client::new();
let resp = match client
.post("https://api.openai.com/v1/embeddings")
.header("Authorization", format!("Bearer {api_key}"))
.header("Content-Type", "application/json")
.json(&body)
.send()
{
Ok(r) => r,
Err(e) => {
eprintln!("error connecting to OpenAI API: {e}");
std::process::exit(1);
}
};

if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().unwrap_or_default();
let message = serde_json::from_str::<Value>(&body)
.ok()
.and_then(|v| v["error"]["message"].as_str().map(str::to_string))
.unwrap_or(body);
eprintln!("error from OpenAI API ({status}): {message}");
std::process::exit(1);
}

let parsed: Value = match resp.json() {
Ok(v) => v,
Err(e) => {
eprintln!("error parsing OpenAI response: {e}");
std::process::exit(1);
}
};

extract_vector(&parsed)
}

/// Format a vector as a SQL ARRAY literal: ARRAY[0.1,-0.2,...]
pub fn vector_to_sql(vec: &[f64]) -> String {
format!("ARRAY[{}]", vec.iter().map(|v| v.to_string()).collect::<Vec<_>>().join(","))
}
59 changes: 47 additions & 12 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ mod config;
mod connections;
mod connections_new;
mod datasets;
mod embedding;
mod indexes;
mod jobs;
mod queries;
Expand Down Expand Up @@ -218,20 +219,54 @@ fn main() {
}
}
}
Commands::Search { query, table, column, select, limit, workspace_id, output } => {
Commands::Search { query, table, column, select, limit, model, workspace_id, output } => {
let workspace_id = resolve_workspace(workspace_id);
let columns = match select.as_deref() {
Some(cols) => format!("{}, score", cols),
None => "*".to_string(),
let select_cols = select.as_deref().unwrap_or("*");

// Determine search mode:
// 1. --model flag: embed the query text via the model provider
// 2. No query + piped stdin: read vector from stdin
// 3. Query text without --model: BM25 text search
let sql = if let Some(ref model_name) = model {
let query_text = match query {
Some(ref q) => q.as_str(),
None => {
eprintln!("error: --model requires a search query text");
std::process::exit(1);
}
};
let vec = embedding::openai_embed(query_text, model_name);
let vec_str = embedding::vector_to_sql(&vec);
format!(
"SELECT {}, l2_distance({}, {}) as dist FROM {} ORDER BY dist LIMIT {}",
Comment thread
pthurlow marked this conversation as resolved.
select_cols, column, vec_str, table, limit,
Comment thread
pthurlow marked this conversation as resolved.
)
} else if query.is_none() {
use std::io::IsTerminal;
if std::io::stdin().is_terminal() {
eprintln!("error: provide a search query or pipe a vector via stdin");
std::process::exit(1);
}
let vec = embedding::read_vector_from_stdin();
let vec_str = embedding::vector_to_sql(&vec);
format!(
"SELECT {}, l2_distance({}, {}) as dist FROM {} ORDER BY dist LIMIT {}",
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same unquoted-identifier issue as above — column and table need identifier quoting here too.

select_cols, column, vec_str, table, limit,
)
} else {
let bm25_columns = match select.as_deref() {
Some(cols) => format!("{}, score", cols),
None => "*".to_string(),
};
format!(
"SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}",
bm25_columns,
table.replace('\'', "''"),
column.replace('\'', "''"),
query.unwrap().replace('\'', "''"),
limit,
)
};
let sql = format!(
"SELECT {} FROM bm25_search('{}', '{}', '{}') ORDER BY score DESC LIMIT {}",
columns,
table.replace('\'', "''"),
column.replace('\'', "''"),
query.replace('\'', "''"),
limit,
);
query::execute(&sql, &workspace_id, None, &output)
}
Commands::Queries { id, output, command } => {
Expand Down
9 changes: 8 additions & 1 deletion src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ fn value_to_string(v: &Value) -> String {
Value::Bool(b) => b.to_string(),
Value::Number(n) => n.to_string(),
Value::String(s) => s.clone(),
Value::Array(_) | Value::Object(_) => v.to_string(),
Value::Array(arr) => {
let (formatted, count) = crate::table::truncate_array(arr);
match count {
Some(n) => format!("{formatted} ({n} items)"),
None => formatted,
}
}
Value::Object(_) => v.to_string(),
}
}

Expand Down
23 changes: 23 additions & 0 deletions src/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,28 @@ use tabled::settings::{
width::Width,
};

/// Truncate arrays to first 3 + last 3 when over 6 elements.
/// Returns (formatted_values, total_count) where total_count is Some when truncated.
pub fn truncate_array(arr: &[serde_json::Value]) -> (String, Option<usize>) {
if arr.len() > 6 {
let head: Vec<String> = arr[..3].iter().map(|v| v.to_string()).collect();
let tail: Vec<String> = arr[arr.len()-3..].iter().map(|v| v.to_string()).collect();
(format!("[{}, ..., {}]", head.join(", "), tail.join(", ")), Some(arr.len()))
} else {
(format!("[{}]", arr.iter().map(|v| v.to_string()).collect::<Vec<_>>().join(", ")), None)
}
}

/// Format an array for styled table output.
fn format_array(arr: &[serde_json::Value]) -> String {
use crossterm::style::Stylize;
let (formatted, count) = truncate_array(arr);
match count {
Some(n) => format!("{formatted} {}", format!("({n} items)").dark_grey()),
None => formatted,
}
}

fn term_width() -> usize {
crossterm::terminal::size()
.map(|(w, _)| w as usize)
Expand Down Expand Up @@ -108,6 +130,7 @@ pub fn print_json(headers: &[String], rows: &[Vec<serde_json::Value>]) {
colored_cells.push((ri + 1, ci, Color::FG_YELLOW));
b.to_string()
}
serde_json::Value::Array(arr) => format_array(arr),
_ => v.as_str().map(str::to_string).unwrap_or_else(|| v.to_string()),
}
})
Expand Down
Loading