From caa792d4ed3b49dd6932a512fcc31d5d8a39c69e Mon Sep 17 00:00:00 2001 From: "jianjian.xie" Date: Wed, 25 Feb 2026 23:19:12 -0800 Subject: [PATCH 1/2] feat(sql): add direct SQL query support for data analytics Add SqlQuery and SqlEngine that let users run standard SQL directly against their datasets without requiring a GraphConfig. This is useful for data analytics workflows where users want explicit JOINs and aggregations against node/relationship tables. DataFusion handles SQL parsing and execution. --- crates/lance-graph-python/src/graph.rs | 255 ++++++++++++++- crates/lance-graph/src/lib.rs | 2 + crates/lance-graph/src/query.rs | 4 +- crates/lance-graph/src/sql_query.rs | 356 +++++++++++++++++++++ crates/lance-graph/tests/test_sql_query.rs | 338 +++++++++++++++++++ python/python/lance_graph/__init__.py | 4 + python/python/tests/test_sql.py | 175 ++++++++++ 7 files changed, 1115 insertions(+), 19 deletions(-) create mode 100644 crates/lance-graph/src/sql_query.rs create mode 100644 crates/lance-graph/tests/test_sql_query.rs create mode 100644 python/python/tests/test_sql.py diff --git a/crates/lance-graph-python/src/graph.rs b/crates/lance-graph-python/src/graph.rs index b7a4337c..4a9c3c30 100644 --- a/crates/lance-graph-python/src/graph.rs +++ b/crates/lance-graph-python/src/graph.rs @@ -27,7 +27,7 @@ use lance_graph::{ ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause}, CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy, GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog, - VectorSearch as RustVectorSearch, + SqlQuery as RustSqlQuery, VectorSearch as RustVectorSearch, }; use pyo3::{ exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError}, @@ -710,7 +710,9 @@ impl CypherQuery { vector_search: &VectorSearch, ) -> PyResult { if vector_search.use_lance_index { - if let Some(result) = try_execute_with_lance_index(py, &self.inner, datasets, vector_search)? { + if let Some(result) = + try_execute_with_lance_index(py, &self.inner, datasets, vector_search)? + { return record_batch_to_python_table(py, &result); } } @@ -917,10 +919,7 @@ fn split_vector_column(column: &str) -> (Option, &str) { } } -fn resolve_vector_label( - query: &RustCypherQuery, - alias: Option<&str>, -) -> PyResult> { +fn resolve_vector_label(query: &RustCypherQuery, alias: Option<&str>) -> PyResult> { let alias_map = alias_map_from_query(query); if let Some(alias) = alias { return Ok(alias_map.get(alias).cloned()); @@ -956,11 +955,17 @@ fn collect_aliases_from_pattern(pattern: &GraphPattern, map: &mut HashMap { - if let (Some(var), Some(label)) = (path.start_node.variable.as_ref(), path.start_node.labels.first()) { + if let (Some(var), Some(label)) = ( + path.start_node.variable.as_ref(), + path.start_node.labels.first(), + ) { map.entry(var.clone()).or_insert_with(|| label.clone()); } for segment in &path.segments { - if let (Some(var), Some(label)) = (segment.end_node.variable.as_ref(), segment.end_node.labels.first()) { + if let (Some(var), Some(label)) = ( + segment.end_node.variable.as_ref(), + segment.end_node.labels.first(), + ) { map.entry(var.clone()).or_insert_with(|| label.clone()); } } @@ -1168,14 +1173,20 @@ impl CypherEngine { // Register all datasets as tables for (name, batch) in &arrow_datasets { let mem_table = Arc::new( - MemTable::try_new(batch.schema(), vec![vec![batch.clone()]]) - .map_err(|e| PyRuntimeError::new_err(format!("Failed to create MemTable for {}: {}", name, e)))?, + MemTable::try_new(batch.schema(), vec![vec![batch.clone()]]).map_err(|e| { + PyRuntimeError::new_err(format!( + "Failed to create MemTable for {}: {}", + name, e + )) + })?, ); // Register in session context for execution let normalized_name = name.to_lowercase(); ctx.register_table(&normalized_name, mem_table.clone()) - .map_err(|e| PyRuntimeError::new_err(format!("Failed to register table {}: {}", name, e)))?; + .map_err(|e| { + PyRuntimeError::new_err(format!("Failed to register table {}: {}", name, e)) + })?; let table_source = Arc::new(DefaultTableSource::new(mem_table)); @@ -1186,7 +1197,7 @@ impl CypherEngine { // based on the Cypher query pattern (e.g., MATCH (p:Person) vs -[:KNOWS]->). // // By registering all datasets in both catalogs, we allow the planner to look up - // the correct source based on query context. This pattern matches the Rust + // the correct source based on query context. This pattern matches the Rust // implementation in query.rs:build_catalog_and_context_from_datasets. catalog = catalog .with_node_source(name, table_source.clone()) @@ -1226,11 +1237,7 @@ impl CypherEngine { /// -------- /// >>> result = engine.execute("MATCH (p:Person) WHERE p.age > 30 RETURN p.name") /// >>> print(result.to_pandas()) - fn execute( - &self, - py: Python, - query: &str, - ) -> PyResult { + fn execute(&self, py: Python, query: &str) -> PyResult { // Parse the query let cypher_query = RustCypherQuery::new(query) .map_err(graph_error_to_pyerr)? @@ -1314,6 +1321,218 @@ impl CypherEngine { } } +/// Execute raw SQL queries against in-memory datasets +/// +/// This class allows executing standard SQL directly against Arrow tables, +/// without requiring a GraphConfig or Cypher parsing. DataFusion handles +/// SQL parsing and execution. +/// +/// Examples +/// -------- +/// >>> import pyarrow as pa +/// >>> from lance_graph import SqlQuery +/// >>> +/// >>> person = pa.table({"id": [1, 2], "name": ["Alice", "Bob"], "age": [28, 34]}) +/// >>> query = SqlQuery("SELECT name, age FROM person WHERE age > 30") +/// >>> result = query.execute({"person": person}) +/// >>> print(result.to_pandas()) +#[pyclass(name = "SqlQuery", module = "lance.graph")] +pub struct SqlQuery { + inner: RustSqlQuery, +} + +#[pymethods] +impl SqlQuery { + /// Create a new SQL query + /// + /// Parameters + /// ---------- + /// sql : str + /// The SQL query string + /// + /// Returns + /// ------- + /// SqlQuery + /// A new SQL query instance + #[new] + fn new(sql: &str) -> Self { + Self { + inner: RustSqlQuery::new(sql), + } + } + + /// Get the SQL query text + fn sql(&self) -> &str { + self.inner.sql() + } + + /// Execute query against in-memory datasets + /// + /// Parameters + /// ---------- + /// datasets : dict + /// Dictionary mapping table names to PyArrow tables or Lance datasets + /// + /// Returns + /// ------- + /// pyarrow.Table + /// Query results as Arrow table + /// + /// Raises + /// ------ + /// RuntimeError + /// If query execution fails + fn execute(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult { + let arrow_datasets = python_datasets_to_batches(datasets)?; + let inner = self.inner.clone(); + + let result_batch = RT + .block_on(Some(py), inner.execute(arrow_datasets))? + .map_err(graph_error_to_pyerr)?; + + record_batch_to_python_table(py, &result_batch) + } + + /// Return the execution plan as a string + /// + /// Parameters + /// ---------- + /// datasets : dict + /// Dictionary mapping table names to PyArrow tables or Lance datasets + /// + /// Returns + /// ------- + /// str + /// The DataFusion logical and physical execution plan + /// + /// Raises + /// ------ + /// RuntimeError + /// If planning fails + fn explain(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult { + let arrow_datasets = python_datasets_to_batches(datasets)?; + let inner = self.inner.clone(); + + let plan = RT + .block_on(Some(py), inner.explain(arrow_datasets))? + .map_err(graph_error_to_pyerr)?; + + Ok(plan) + } + + fn __repr__(&self) -> String { + format!("SqlQuery(\"{}\")", self.inner.sql()) + } +} + +/// Cached SQL execution engine for running multiple queries against the same datasets +/// +/// This class registers datasets once during initialization and reuses the +/// DataFusion SessionContext for subsequent queries, avoiding repeated setup. +/// +/// Examples +/// -------- +/// >>> from lance_graph import SqlEngine +/// >>> import pyarrow as pa +/// >>> +/// >>> datasets = { +/// ... "person": pa.table({"id": [1, 2], "name": ["Alice", "Bob"], "age": [28, 34]}), +/// ... "knows": pa.table({"src": [1], "dst": [2]}), +/// ... } +/// >>> +/// >>> engine = SqlEngine(datasets) +/// >>> result1 = engine.execute("SELECT * FROM person WHERE age > 30") +/// >>> result2 = engine.execute("SELECT p.name FROM person p JOIN knows k ON p.id = k.src") +#[pyclass(name = "SqlEngine", module = "lance.graph")] +pub struct SqlEngine { + context: Arc, +} + +#[pymethods] +impl SqlEngine { + /// Create a new SqlEngine with cached datasets + /// + /// Parameters + /// ---------- + /// datasets : dict + /// Dictionary mapping table names to PyArrow tables or Lance datasets. + /// Table names are lowercased for consistency. + /// + /// Returns + /// ------- + /// SqlEngine + /// A new engine instance ready to execute queries + /// + /// Raises + /// ------ + /// ValueError + /// If no datasets are provided + /// RuntimeError + /// If table registration fails + #[new] + fn new(datasets: &Bound<'_, PyDict>) -> PyResult { + let arrow_datasets = python_datasets_to_batches(datasets)?; + + if arrow_datasets.is_empty() { + return Err(PyValueError::new_err("No input datasets provided")); + } + + let ctx = SessionContext::new(); + + for (name, batch) in &arrow_datasets { + let mem_table = Arc::new( + MemTable::try_new(batch.schema(), vec![vec![batch.clone()]]).map_err(|e| { + PyRuntimeError::new_err(format!( + "Failed to create MemTable for {}: {}", + name, e + )) + })?, + ); + + let normalized_name = name.to_lowercase(); + ctx.register_table(&normalized_name, mem_table) + .map_err(|e| { + PyRuntimeError::new_err(format!("Failed to register table {}: {}", name, e)) + })?; + } + + Ok(Self { + context: Arc::new(ctx), + }) + } + + /// Execute a SQL query using the cached datasets + /// + /// Parameters + /// ---------- + /// sql : str + /// The SQL query string to execute + /// + /// Returns + /// ------- + /// pyarrow.Table + /// Query results as Arrow table + /// + /// Raises + /// ------ + /// RuntimeError + /// If query execution fails + fn execute(&self, py: Python, sql: &str) -> PyResult { + let query = RustSqlQuery::new(sql); + let context = self.context.as_ref().clone(); + + let result_batch = RT + .block_on(Some(py), query.execute_with_context(context))? + .map_err(graph_error_to_pyerr)?; + + record_batch_to_python_table(py, &result_batch) + } + + fn __repr__(&self) -> String { + "SqlEngine(...)".to_string() + } +} + /// Register graph functionality with the Python module pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> PyResult<()> { let graph_module = PyModule::new(py, "graph")?; @@ -1324,6 +1543,8 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> graph_module.add_class::()?; graph_module.add_class::()?; graph_module.add_class::()?; + graph_module.add_class::()?; + graph_module.add_class::()?; graph_module.add_class::()?; graph_module.add_class::()?; diff --git a/crates/lance-graph/src/lib.rs b/crates/lance-graph/src/lib.rs index 692773ad..25d8ebec 100644 --- a/crates/lance-graph/src/lib.rs +++ b/crates/lance-graph/src/lib.rs @@ -47,6 +47,7 @@ pub mod parser; pub mod query; pub mod semantic; pub mod simple_executor; +pub mod sql_query; /// Maximum allowed hops for variable-length relationship expansion (e.g., *1..N) pub const MAX_VARIABLE_LENGTH_HOPS: u32 = 20; @@ -58,3 +59,4 @@ pub use lance_graph_catalog::{ }; pub use lance_vector_search::VectorSearch; pub use query::{CypherQuery, ExecutionStrategy}; +pub use sql_query::SqlQuery; diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index bdf33384..cc4f96c9 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -23,7 +23,7 @@ use std::sync::Arc; /// /// This ensures that column names in the dataset match the normalized /// qualified column names used internally (e.g., "fullName" becomes "fullname"). -fn normalize_schema(schema: SchemaRef) -> Result { +pub(crate) fn normalize_schema(schema: SchemaRef) -> Result { let fields: Vec<_> = schema .fields() .iter() @@ -42,7 +42,7 @@ fn normalize_schema(schema: SchemaRef) -> Result { /// /// This creates a new RecordBatch with a normalized schema while /// preserving all the data arrays. -fn normalize_record_batch(batch: &RecordBatch) -> Result { +pub(crate) fn normalize_record_batch(batch: &RecordBatch) -> Result { let normalized_schema = normalize_schema(batch.schema())?; RecordBatch::try_new(normalized_schema, batch.columns().to_vec()).map_err(|e| { GraphError::PlanError { diff --git a/crates/lance-graph/src/sql_query.rs b/crates/lance-graph/src/sql_query.rs new file mode 100644 index 00000000..97705d1b --- /dev/null +++ b/crates/lance-graph/src/sql_query.rs @@ -0,0 +1,356 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Direct SQL query interface for Lance datasets +//! +//! This module provides a way to execute standard SQL queries directly against +//! in-memory datasets (as RecordBatches) or a pre-configured DataFusion SessionContext, +//! without requiring a GraphConfig or Cypher parsing. + +use crate::error::{GraphError, Result}; +use crate::query::{normalize_record_batch, normalize_schema}; +use arrow_array::RecordBatch; +use datafusion::datasource::MemTable; +use datafusion::execution::context::SessionContext; +use std::collections::HashMap; +use std::sync::Arc; + +/// A SQL query that can be executed against in-memory datasets or a DataFusion SessionContext. +/// +/// Unlike `CypherQuery`, this does not require a `GraphConfig` — users write standard SQL +/// with explicit JOINs against their node/relationship tables. +/// +/// # Example +/// +/// ```no_run +/// use lance_graph::SqlQuery; +/// use arrow_array::RecordBatch; +/// use std::collections::HashMap; +/// +/// # async fn example() -> lance_graph::Result<()> { +/// let mut datasets: HashMap = HashMap::new(); +/// // datasets.insert("person".to_string(), person_batch); +/// +/// let query = SqlQuery::new("SELECT name, age FROM person WHERE age > 30"); +/// // let result = query.execute(datasets).await?; +/// # Ok(()) +/// # } +/// ``` +#[derive(Debug, Clone)] +pub struct SqlQuery { + sql: String, +} + +impl SqlQuery { + /// Create a new SQL query from a SQL string. + /// + /// No parsing is done at construction time — the SQL is validated when executed. + pub fn new(sql: &str) -> Self { + Self { + sql: sql.to_string(), + } + } + + /// Get the SQL query text. + pub fn sql(&self) -> &str { + &self.sql + } + + /// Execute the SQL query against in-memory datasets. + /// + /// Each entry in `datasets` is registered as a table in a fresh DataFusion + /// SessionContext. Table names are lowercased for consistency. + /// + /// # Arguments + /// * `datasets` - HashMap of table name to RecordBatch + /// + /// # Returns + /// A single `RecordBatch` containing all result rows. + pub async fn execute(&self, datasets: HashMap) -> Result { + let ctx = self.build_context(datasets)?; + self.execute_with_context(ctx).await + } + + /// Execute the SQL query against a pre-configured DataFusion SessionContext. + /// + /// Use this when tables are already registered (e.g., CSV/Parquet files, + /// external data sources, or a context shared across queries). + pub async fn execute_with_context(&self, ctx: SessionContext) -> Result { + let df = ctx + .sql(&self.sql) + .await + .map_err(|e| GraphError::PlanError { + message: format!("SQL execution error: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + let batches = df.collect().await.map_err(|e| GraphError::PlanError { + message: format!("Failed to collect SQL results: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + if batches.is_empty() { + // Return an empty batch with the schema from the logical plan + let schema = df_schema_from_ctx(&ctx, &self.sql).await?; + return Ok(RecordBatch::new_empty(schema)); + } + + let schema = batches[0].schema(); + arrow::compute::concat_batches(&schema, &batches).map_err(|e| GraphError::PlanError { + message: format!("Failed to concatenate result batches: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + }) + } + + /// Return the DataFusion execution plan as a formatted string. + /// + /// Useful for debugging and understanding how the query will be executed. + pub async fn explain(&self, datasets: HashMap) -> Result { + let ctx = self.build_context(datasets)?; + + let df = ctx + .sql(&self.sql) + .await + .map_err(|e| GraphError::PlanError { + message: format!("SQL explain error: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + let logical_plan = df.logical_plan(); + + let physical_plan = ctx + .state() + .create_physical_plan(logical_plan) + .await + .map_err(|e| GraphError::PlanError { + message: format!("Failed to create physical plan: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + let physical_plan_str = datafusion::physical_plan::displayable(physical_plan.as_ref()) + .indent(true) + .to_string(); + + Ok(format!( + "== Logical Plan ==\n{}\n\n== Physical Plan ==\n{}", + logical_plan.display_indent(), + physical_plan_str, + )) + } + + /// Build a DataFusion SessionContext from in-memory datasets. + fn build_context(&self, datasets: HashMap) -> Result { + let ctx = SessionContext::new(); + + for (name, batch) in datasets { + let normalized_batch = normalize_record_batch(&batch)?; + let schema = normalized_batch.schema(); + let mem_table = Arc::new( + MemTable::try_new(schema, vec![vec![normalized_batch]]).map_err(|e| { + GraphError::PlanError { + message: format!("Failed to create MemTable for {}: {}", name, e), + location: snafu::Location::new(file!(), line!(), column!()), + } + })?, + ); + + let normalized_name = name.to_lowercase(); + ctx.register_table(&normalized_name, mem_table) + .map_err(|e| GraphError::PlanError { + message: format!("Failed to register table {}: {}", name, e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + } + + Ok(ctx) + } +} + +/// Helper to get the output schema from a SQL query without executing it. +async fn df_schema_from_ctx(ctx: &SessionContext, sql: &str) -> Result> { + let df = ctx.sql(sql).await.map_err(|e| GraphError::PlanError { + message: format!("Failed to plan SQL for schema: {}", e), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + let arrow_schema = Arc::new(arrow_schema::Schema::from(df.schema())); + normalize_schema(arrow_schema) +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::{Float64Array, Int64Array, StringArray}; + use arrow_schema::{DataType, Field, Schema}; + + fn person_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int64, false), + Field::new("city", DataType::Utf8, false), + ])); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3, 4])), + Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])), + Arc::new(Int64Array::from(vec![28, 34, 29, 42])), + Arc::new(StringArray::from(vec![ + "New York", + "San Francisco", + "New York", + "Chicago", + ])), + ], + ) + .unwrap() + } + + fn datasets_with(name: &str, batch: RecordBatch) -> HashMap { + let mut datasets = HashMap::new(); + datasets.insert(name.to_string(), batch); + datasets + } + + #[tokio::test] + async fn test_basic_select() { + let query = SqlQuery::new("SELECT name, age FROM person WHERE age > 30 ORDER BY age"); + let result = query + .execute(datasets_with("person", person_batch())) + .await + .unwrap(); + + let names: Vec<&str> = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + assert_eq!(names, vec!["Bob", "David"]); + } + + #[tokio::test] + async fn test_select_star() { + let query = SqlQuery::new("SELECT * FROM person"); + let result = query + .execute(datasets_with("person", person_batch())) + .await + .unwrap(); + assert_eq!(result.num_rows(), 4); + assert_eq!(result.num_columns(), 4); + } + + #[tokio::test] + async fn test_limit() { + let query = SqlQuery::new("SELECT name FROM person ORDER BY name LIMIT 2"); + let result = query + .execute(datasets_with("person", person_batch())) + .await + .unwrap(); + assert_eq!(result.num_rows(), 2); + } + + #[tokio::test] + async fn test_aggregation() { + let query = SqlQuery::new( + "SELECT COUNT(*) as cnt, AVG(age) as avg_age, SUM(age) as total_age FROM person", + ); + let result = query + .execute(datasets_with("person", person_batch())) + .await + .unwrap(); + assert_eq!(result.num_rows(), 1); + + let cnt = result + .column_by_name("cnt") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert_eq!(cnt, 4); + + let avg_age = result + .column_by_name("avg_age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert!((avg_age - 33.25).abs() < 0.01); + } + + #[tokio::test] + async fn test_group_by() { + let query = SqlQuery::new( + "SELECT city, COUNT(*) as cnt FROM person GROUP BY city ORDER BY cnt DESC", + ); + let result = query + .execute(datasets_with("person", person_batch())) + .await + .unwrap(); + + let cities: Vec<&str> = result + .column_by_name("city") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + // New York has 2, others have 1 + assert_eq!(cities[0], "New York"); + } + + #[tokio::test] + async fn test_invalid_sql() { + let query = SqlQuery::new("INVALID SQL STATEMENT"); + let result = query.execute(datasets_with("person", person_batch())).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn test_explain() { + let query = SqlQuery::new("SELECT name FROM person WHERE age > 30"); + let plan = query + .explain(datasets_with("person", person_batch())) + .await + .unwrap(); + assert!(plan.contains("Logical Plan")); + assert!(plan.contains("Physical Plan")); + } + + #[tokio::test] + async fn test_execute_with_context() { + // Build context manually and execute against it + let ctx = SessionContext::new(); + let batch = person_batch(); + let schema = batch.schema(); + let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap()); + ctx.register_table("people", mem_table).unwrap(); + + let query = SqlQuery::new("SELECT name FROM people ORDER BY name LIMIT 1"); + let result = query.execute_with_context(ctx).await.unwrap(); + + let names: Vec<&str> = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + assert_eq!(names, vec!["Alice"]); + } + + #[tokio::test] + async fn test_sql_text_accessor() { + let query = SqlQuery::new("SELECT 1"); + assert_eq!(query.sql(), "SELECT 1"); + } +} diff --git a/crates/lance-graph/tests/test_sql_query.rs b/crates/lance-graph/tests/test_sql_query.rs new file mode 100644 index 00000000..09baad9f --- /dev/null +++ b/crates/lance-graph/tests/test_sql_query.rs @@ -0,0 +1,338 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Integration tests for SqlQuery + +use arrow_array::{Float64Array, Int64Array, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use lance_graph::SqlQuery; +use std::collections::HashMap; +use std::sync::Arc; + +fn person_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + Field::new("age", DataType::Int64, false), + Field::new("city", DataType::Utf8, false), + ])); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3, 4])), + Arc::new(StringArray::from(vec!["Alice", "Bob", "Carol", "David"])), + Arc::new(Int64Array::from(vec![28, 34, 29, 42])), + Arc::new(StringArray::from(vec![ + "New York", + "San Francisco", + "New York", + "Chicago", + ])), + ], + ) + .unwrap() +} + +fn knows_batch() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("src_id", DataType::Int64, false), + Field::new("dst_id", DataType::Int64, false), + Field::new("since_year", DataType::Int64, false), + ])); + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 1, 2, 3])), + Arc::new(Int64Array::from(vec![2, 3, 4, 4])), + Arc::new(Int64Array::from(vec![2015, 2018, 2020, 2021])), + ], + ) + .unwrap() +} + +fn make_datasets() -> HashMap { + let mut datasets = HashMap::new(); + datasets.insert("person".to_string(), person_batch()); + datasets.insert("knows".to_string(), knows_batch()); + datasets +} + +// ============================================================================ +// Basic SELECT with WHERE, ORDER BY, LIMIT +// ============================================================================ + +#[tokio::test] +async fn test_select_with_where_order_by_limit() { + let query = SqlQuery::new("SELECT name, age FROM person WHERE age > 30 ORDER BY age LIMIT 10"); + let result = query.execute(make_datasets()).await.unwrap(); + + let names: Vec<&str> = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + + assert_eq!(names, vec!["Bob", "David"]); + + let ages: Vec = result + .column_by_name("age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + + assert_eq!(ages, vec![34, 42]); +} + +#[tokio::test] +async fn test_select_star() { + let query = SqlQuery::new("SELECT * FROM person ORDER BY id"); + let result = query.execute(make_datasets()).await.unwrap(); + assert_eq!(result.num_rows(), 4); + assert_eq!(result.num_columns(), 4); +} + +#[tokio::test] +async fn test_select_limit() { + let query = SqlQuery::new("SELECT name FROM person ORDER BY name LIMIT 2"); + let result = query.execute(make_datasets()).await.unwrap(); + assert_eq!(result.num_rows(), 2); +} + +// ============================================================================ +// JOINs between node and relationship tables +// ============================================================================ + +#[tokio::test] +async fn test_inner_join() { + let query = SqlQuery::new( + "SELECT p.name, k.dst_id, k.since_year \ + FROM person p \ + JOIN knows k ON p.id = k.src_id \ + ORDER BY p.name, k.dst_id", + ); + let result = query.execute(make_datasets()).await.unwrap(); + + let names: Vec<&str> = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + + // Alice->2, Alice->3, Bob->4, Carol->4 + assert_eq!(names, vec!["Alice", "Alice", "Bob", "Carol"]); +} + +#[tokio::test] +async fn test_self_join_friends() { + let query = SqlQuery::new( + "SELECT p1.name AS person, p2.name AS friend \ + FROM person p1 \ + JOIN knows k ON p1.id = k.src_id \ + JOIN person p2 ON p2.id = k.dst_id \ + ORDER BY p1.name, p2.name", + ); + let result = query.execute(make_datasets()).await.unwrap(); + + let persons: Vec<&str> = result + .column_by_name("person") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + let friends: Vec<&str> = result + .column_by_name("friend") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + + assert_eq!(persons, vec!["Alice", "Alice", "Bob", "Carol"]); + assert_eq!(friends, vec!["Bob", "Carol", "David", "David"]); +} + +// ============================================================================ +// Aggregations (COUNT, SUM, AVG) +// ============================================================================ + +#[tokio::test] +async fn test_count() { + let query = SqlQuery::new("SELECT COUNT(*) AS cnt FROM person"); + let result = query.execute(make_datasets()).await.unwrap(); + + let cnt = result + .column_by_name("cnt") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert_eq!(cnt, 4); +} + +#[tokio::test] +async fn test_sum() { + let query = SqlQuery::new("SELECT SUM(age) AS total_age FROM person"); + let result = query.execute(make_datasets()).await.unwrap(); + + let total = result + .column_by_name("total_age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert_eq!(total, 28 + 34 + 29 + 42); +} + +#[tokio::test] +async fn test_avg() { + let query = SqlQuery::new("SELECT AVG(age) AS avg_age FROM person"); + let result = query.execute(make_datasets()).await.unwrap(); + + let avg = result + .column_by_name("avg_age") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .value(0); + assert!((avg - 33.25).abs() < 0.01); +} + +#[tokio::test] +async fn test_group_by_with_count() { + let query = SqlQuery::new( + "SELECT city, COUNT(*) AS cnt FROM person GROUP BY city ORDER BY cnt DESC, city", + ); + let result = query.execute(make_datasets()).await.unwrap(); + + let cities: Vec<&str> = result + .column_by_name("city") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + let counts: Vec = result + .column_by_name("cnt") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + + assert_eq!(cities[0], "New York"); + assert_eq!(counts[0], 2); +} + +// ============================================================================ +// Execute with SessionContext (pre-registered tables) +// ============================================================================ + +#[tokio::test] +async fn test_execute_with_session_context() { + use datafusion::datasource::MemTable; + use datafusion::execution::context::SessionContext; + + let ctx = SessionContext::new(); + + // Register person table + let batch = person_batch(); + let schema = batch.schema(); + let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap()); + ctx.register_table("people", mem_table).unwrap(); + + // Register knows table + let batch = knows_batch(); + let schema = batch.schema(); + let mem_table = Arc::new(MemTable::try_new(schema, vec![vec![batch]]).unwrap()); + ctx.register_table("relationships", mem_table).unwrap(); + + let query = SqlQuery::new( + "SELECT p.name, r.since_year \ + FROM people p \ + JOIN relationships r ON p.id = r.src_id \ + ORDER BY p.name, r.since_year", + ); + let result = query.execute_with_context(ctx).await.unwrap(); + + assert_eq!(result.num_rows(), 4); + + let names: Vec<&str> = result + .column_by_name("name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap() + .iter() + .map(|v| v.unwrap()) + .collect(); + assert_eq!(names[0], "Alice"); +} + +// ============================================================================ +// Explain +// ============================================================================ + +#[tokio::test] +async fn test_explain_output() { + let query = SqlQuery::new("SELECT p.name FROM person p JOIN knows k ON p.id = k.src_id"); + let plan = query.explain(make_datasets()).await.unwrap(); + assert!(plan.contains("Logical Plan")); + assert!(plan.contains("Physical Plan")); +} + +// ============================================================================ +// Error handling +// ============================================================================ + +#[tokio::test] +async fn test_invalid_sql() { + let query = SqlQuery::new("NOT VALID SQL"); + let result = query.execute(make_datasets()).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_missing_table() { + let query = SqlQuery::new("SELECT * FROM nonexistent_table"); + let result = query.execute(make_datasets()).await; + assert!(result.is_err()); +} + +// ============================================================================ +// Case insensitivity (table names are lowercased) +// ============================================================================ + +#[tokio::test] +async fn test_case_insensitive_table_names() { + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch()); + + // Table registered as lowercase "person", so SQL should use lowercase + let query = SqlQuery::new("SELECT name FROM person ORDER BY name LIMIT 1"); + let result = query.execute(datasets).await.unwrap(); + assert_eq!(result.num_rows(), 1); +} diff --git a/python/python/lance_graph/__init__.py b/python/python/lance_graph/__init__.py index 8ea65ee1..3a7418e4 100644 --- a/python/python/lance_graph/__init__.py +++ b/python/python/lance_graph/__init__.py @@ -72,6 +72,8 @@ def _load_dev_build() -> ModuleType: GraphConfigBuilder = _bindings.graph.GraphConfigBuilder CypherQuery = _bindings.graph.CypherQuery CypherEngine = _bindings.graph.CypherEngine +SqlQuery = _bindings.graph.SqlQuery +SqlEngine = _bindings.graph.SqlEngine ExecutionStrategy = _bindings.graph.ExecutionStrategy VectorSearch = _bindings.graph.VectorSearch DistanceMetric = _bindings.graph.DistanceMetric @@ -83,6 +85,8 @@ def _load_dev_build() -> ModuleType: "GraphConfigBuilder", "CypherQuery", "CypherEngine", + "SqlQuery", + "SqlEngine", "ExecutionStrategy", "VectorSearch", "DistanceMetric", diff --git a/python/python/tests/test_sql.py b/python/python/tests/test_sql.py new file mode 100644 index 00000000..e9d1b533 --- /dev/null +++ b/python/python/tests/test_sql.py @@ -0,0 +1,175 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright The Lance Authors + +import pyarrow as pa +import pytest +from lance_graph import SqlEngine, SqlQuery + + +@pytest.fixture +def person_table(): + return pa.table( + { + "id": [1, 2, 3, 4], + "name": ["Alice", "Bob", "Carol", "David"], + "age": [28, 34, 29, 42], + "city": ["New York", "San Francisco", "New York", "Chicago"], + } + ) + + +@pytest.fixture +def knows_table(): + return pa.table( + { + "src_id": [1, 1, 2, 3], + "dst_id": [2, 3, 4, 4], + "since_year": [2015, 2018, 2020, 2021], + } + ) + + +@pytest.fixture +def datasets(person_table, knows_table): + return {"person": person_table, "knows": knows_table} + + +# ========================================================================== +# SqlQuery tests +# ========================================================================== + + +class TestSqlQuery: + def test_basic_select(self, datasets): + query = SqlQuery("SELECT name, age FROM person WHERE age > 30 ORDER BY age") + result = query.execute(datasets) + data = result.to_pydict() + + assert data["name"] == ["Bob", "David"] + assert data["age"] == [34, 42] + + def test_select_star(self, datasets): + query = SqlQuery("SELECT * FROM person") + result = query.execute(datasets) + assert result.num_rows == 4 + assert result.num_columns == 4 + + def test_limit(self, datasets): + query = SqlQuery("SELECT name FROM person ORDER BY name LIMIT 2") + result = query.execute(datasets) + assert result.num_rows == 2 + + def test_join(self, datasets): + query = SqlQuery( + "SELECT p.name, k.dst_id " + "FROM person p " + "JOIN knows k ON p.id = k.src_id " + "ORDER BY p.name, k.dst_id" + ) + result = query.execute(datasets) + data = result.to_pydict() + + assert data["name"] == ["Alice", "Alice", "Bob", "Carol"] + + def test_self_join(self, datasets): + query = SqlQuery( + "SELECT p1.name AS person, p2.name AS friend " + "FROM person p1 " + "JOIN knows k ON p1.id = k.src_id " + "JOIN person p2 ON p2.id = k.dst_id " + "ORDER BY p1.name, p2.name" + ) + result = query.execute(datasets) + data = result.to_pydict() + + assert data["person"] == ["Alice", "Alice", "Bob", "Carol"] + assert data["friend"] == ["Bob", "Carol", "David", "David"] + + def test_count(self, datasets): + query = SqlQuery("SELECT COUNT(*) AS cnt FROM person") + result = query.execute(datasets) + assert result.to_pydict()["cnt"] == [4] + + def test_sum(self, datasets): + query = SqlQuery("SELECT SUM(age) AS total FROM person") + result = query.execute(datasets) + assert result.to_pydict()["total"] == [28 + 34 + 29 + 42] + + def test_avg(self, datasets): + query = SqlQuery("SELECT AVG(age) AS avg_age FROM person") + result = query.execute(datasets) + avg = result.to_pydict()["avg_age"][0] + assert abs(avg - 33.25) < 0.01 + + def test_group_by(self, datasets): + query = SqlQuery( + "SELECT city, COUNT(*) AS cnt " + "FROM person GROUP BY city ORDER BY cnt DESC, city" + ) + result = query.execute(datasets) + data = result.to_pydict() + assert data["city"][0] == "New York" + assert data["cnt"][0] == 2 + + def test_explain(self, datasets): + query = SqlQuery("SELECT name FROM person WHERE age > 30") + plan = query.explain(datasets) + assert "Logical Plan" in plan + assert "Physical Plan" in plan + + def test_sql_accessor(self): + query = SqlQuery("SELECT 1") + assert query.sql() == "SELECT 1" + + def test_repr(self): + query = SqlQuery("SELECT 1") + assert "SqlQuery" in repr(query) + + def test_invalid_sql(self, datasets): + query = SqlQuery("INVALID SQL") + with pytest.raises((RuntimeError, ValueError)): + query.execute(datasets) + + def test_case_insensitive_table_names(self, person_table): + """Table name 'Person' should be lowercased to 'person'.""" + query = SqlQuery("SELECT name FROM person LIMIT 1") + result = query.execute({"Person": person_table}) + assert result.num_rows == 1 + + +# ========================================================================== +# SqlEngine tests +# ========================================================================== + + +class TestSqlEngine: + def test_basic_query(self, datasets): + engine = SqlEngine(datasets) + result = engine.execute( + "SELECT name, age FROM person WHERE age > 30 ORDER BY age" + ) + data = result.to_pydict() + assert data["name"] == ["Bob", "David"] + + def test_multiple_queries(self, datasets): + engine = SqlEngine(datasets) + + r1 = engine.execute("SELECT COUNT(*) AS cnt FROM person") + r2 = engine.execute("SELECT name FROM person WHERE age > 30 ORDER BY name") + r3 = engine.execute( + "SELECT p.name, k.dst_id " + "FROM person p JOIN knows k ON p.id = k.src_id " + "ORDER BY p.name LIMIT 2" + ) + + assert r1.to_pydict()["cnt"] == [4] + assert r2.to_pydict()["name"] == ["Bob", "David"] + assert r3.num_rows == 2 + + def test_repr(self, datasets): + engine = SqlEngine(datasets) + assert "SqlEngine" in repr(engine) + + def test_empty_datasets_raises(self): + with pytest.raises(ValueError, match="No input datasets"): + SqlEngine({}) From 49e490d33f3bcfd8874b1dda75c08bf03e108188 Mon Sep 17 00:00:00 2001 From: "jianjian.xie" Date: Thu, 26 Feb 2026 13:27:49 -0800 Subject: [PATCH 2/2] docs: add SqlQuery/SqlEngine to Python and project READMEs --- README.md | 26 ++++++++++++++++++++ python/README.md | 34 +++++++++++++++++++++++++++ python/python/lance_graph/__init__.py | 10 +++++++- 3 files changed, 69 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index f1ed4070..8e61c4b6 100644 --- a/README.md +++ b/README.md @@ -67,6 +67,32 @@ result = query.execute({"Person": people}) print(result.to_pydict()) # {'name': ['Bob', 'David'], 'age': [34, 42]} ``` +## Python example: Direct SQL query + +For data analytics workflows where you prefer standard SQL, use `SqlQuery` or `SqlEngine`. No `GraphConfig` is needed: + +```python +import pyarrow as pa +from lance_graph import SqlQuery, SqlEngine + +person = pa.table({ + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Carol"], + "age": [28, 34, 29], +}) + +# One-off query +result = SqlQuery( + "SELECT name, age FROM person WHERE age > 30" +).execute({"person": person}) +print(result.to_pydict()) # {'name': ['Bob'], 'age': [34]} + +# Multi-query with cached context +engine = SqlEngine({"person": person}) +r1 = engine.execute("SELECT COUNT(*) AS cnt FROM person") +r2 = engine.execute("SELECT name FROM person ORDER BY age DESC LIMIT 2") +``` + ## Knowledge Graph CLI & API The `knowledge_graph` package layers a simple Lance-backed knowledge graph diff --git a/python/README.md b/python/README.md index 2a44a911..21a64a8f 100644 --- a/python/README.md +++ b/python/README.md @@ -82,6 +82,40 @@ print(result1.to_pylist()) # [{'p.name': 'Alice'}] ``` +### 3. Direct SQL Queries + +For data analytics workflows where you prefer standard SQL over Cypher, use `SqlQuery` or `SqlEngine`. No `GraphConfig` is needed — write explicit JOINs against your tables directly: + +```python +import pyarrow as pa +from lance_graph import SqlQuery, SqlEngine + +person = pa.table({ + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Carol"], + "age": [28, 34, 29], +}) +knows = pa.table({"src_id": [1, 1, 2], "dst_id": [2, 3, 3]}) +datasets = {"person": person, "knows": knows} + +# One-off query +result = SqlQuery( + "SELECT p.name, p.age FROM person p WHERE p.age > 30" +).execute(datasets) +print(result.to_pylist()) +# [{'name': 'Bob', 'age': 34}] + +# Multi-query with cached context +engine = SqlEngine(datasets) +r1 = engine.execute("SELECT COUNT(*) AS cnt FROM person") +r2 = engine.execute( + "SELECT p1.name AS person, p2.name AS friend " + "FROM person p1 " + "JOIN knows k ON p1.id = k.src_id " + "JOIN person p2 ON p2.id = k.dst_id" +) +``` + ### 3. Build a Knowledge Graph from Text ```python diff --git a/python/python/lance_graph/__init__.py b/python/python/lance_graph/__init__.py index 3a7418e4..4a348fd8 100644 --- a/python/python/lance_graph/__init__.py +++ b/python/python/lance_graph/__init__.py @@ -1,4 +1,12 @@ -"""Python bindings for the ``lance-graph`` crate.""" +"""Python bindings for the ``lance-graph`` crate. + +Provides two query interfaces: + +- **Cypher**: ``CypherQuery`` and ``CypherEngine`` for graph-pattern queries + (requires a ``GraphConfig`` with node/relationship mappings). +- **SQL**: ``SqlQuery`` and ``SqlEngine`` for standard SQL queries executed + directly against datasets via DataFusion (no ``GraphConfig`` needed). +""" from __future__ import annotations