Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,32 @@ result = query.execute({"Person": people})
print(result.to_pydict()) # {'name': ['Bob', 'David'], 'age': [34, 42]}
```

## Python example: Direct SQL query

For data analytics workflows where you prefer standard SQL, use `SqlQuery` or `SqlEngine`. No `GraphConfig` is needed:

```python
import pyarrow as pa
from lance_graph import SqlQuery, SqlEngine

person = pa.table({
"id": [1, 2, 3],
"name": ["Alice", "Bob", "Carol"],
"age": [28, 34, 29],
})

# One-off query
result = SqlQuery(
"SELECT name, age FROM person WHERE age > 30"
).execute({"person": person})
print(result.to_pydict()) # {'name': ['Bob'], 'age': [34]}

# Multi-query with cached context
engine = SqlEngine({"person": person})
r1 = engine.execute("SELECT COUNT(*) AS cnt FROM person")
r2 = engine.execute("SELECT name FROM person ORDER BY age DESC LIMIT 2")
```

## Knowledge Graph CLI & API

The `knowledge_graph` package layers a simple Lance-backed knowledge graph
Expand Down
255 changes: 238 additions & 17 deletions crates/lance-graph-python/src/graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ use lance_graph::{
ast::{DistanceMetric as RustDistanceMetric, GraphPattern, ReadingClause},
CypherQuery as RustCypherQuery, ExecutionStrategy as RustExecutionStrategy,
GraphConfig as RustGraphConfig, GraphError as RustGraphError, InMemoryCatalog,
VectorSearch as RustVectorSearch,
SqlQuery as RustSqlQuery, VectorSearch as RustVectorSearch,
};
use pyo3::{
exceptions::{PyNotImplementedError, PyRuntimeError, PyValueError},
Expand Down Expand Up @@ -710,7 +710,9 @@ impl CypherQuery {
vector_search: &VectorSearch,
) -> PyResult<PyObject> {
if vector_search.use_lance_index {
if let Some(result) = try_execute_with_lance_index(py, &self.inner, datasets, vector_search)? {
if let Some(result) =
try_execute_with_lance_index(py, &self.inner, datasets, vector_search)?
{
return record_batch_to_python_table(py, &result);
}
}
Expand Down Expand Up @@ -917,10 +919,7 @@ fn split_vector_column(column: &str) -> (Option<String>, &str) {
}
}

fn resolve_vector_label(
query: &RustCypherQuery,
alias: Option<&str>,
) -> PyResult<Option<String>> {
fn resolve_vector_label(query: &RustCypherQuery, alias: Option<&str>) -> PyResult<Option<String>> {
let alias_map = alias_map_from_query(query);
if let Some(alias) = alias {
return Ok(alias_map.get(alias).cloned());
Expand Down Expand Up @@ -956,11 +955,17 @@ fn collect_aliases_from_pattern(pattern: &GraphPattern, map: &mut HashMap<String
}
}
GraphPattern::Path(path) => {
if let (Some(var), Some(label)) = (path.start_node.variable.as_ref(), path.start_node.labels.first()) {
if let (Some(var), Some(label)) = (
path.start_node.variable.as_ref(),
path.start_node.labels.first(),
) {
map.entry(var.clone()).or_insert_with(|| label.clone());
}
for segment in &path.segments {
if let (Some(var), Some(label)) = (segment.end_node.variable.as_ref(), segment.end_node.labels.first()) {
if let (Some(var), Some(label)) = (
segment.end_node.variable.as_ref(),
segment.end_node.labels.first(),
) {
map.entry(var.clone()).or_insert_with(|| label.clone());
}
}
Expand Down Expand Up @@ -1168,14 +1173,20 @@ impl CypherEngine {
// Register all datasets as tables
for (name, batch) in &arrow_datasets {
let mem_table = Arc::new(
MemTable::try_new(batch.schema(), vec![vec![batch.clone()]])
.map_err(|e| PyRuntimeError::new_err(format!("Failed to create MemTable for {}: {}", name, e)))?,
MemTable::try_new(batch.schema(), vec![vec![batch.clone()]]).map_err(|e| {
PyRuntimeError::new_err(format!(
"Failed to create MemTable for {}: {}",
name, e
))
})?,
);

// Register in session context for execution
let normalized_name = name.to_lowercase();
ctx.register_table(&normalized_name, mem_table.clone())
.map_err(|e| PyRuntimeError::new_err(format!("Failed to register table {}: {}", name, e)))?;
.map_err(|e| {
PyRuntimeError::new_err(format!("Failed to register table {}: {}", name, e))
})?;

let table_source = Arc::new(DefaultTableSource::new(mem_table));

Expand All @@ -1186,7 +1197,7 @@ impl CypherEngine {
// based on the Cypher query pattern (e.g., MATCH (p:Person) vs -[:KNOWS]->).
//
// By registering all datasets in both catalogs, we allow the planner to look up
// the correct source based on query context. This pattern matches the Rust
// the correct source based on query context. This pattern matches the Rust
// implementation in query.rs:build_catalog_and_context_from_datasets.
catalog = catalog
.with_node_source(name, table_source.clone())
Expand Down Expand Up @@ -1226,11 +1237,7 @@ impl CypherEngine {
/// --------
/// >>> result = engine.execute("MATCH (p:Person) WHERE p.age > 30 RETURN p.name")
/// >>> print(result.to_pandas())
fn execute(
&self,
py: Python,
query: &str,
) -> PyResult<PyObject> {
fn execute(&self, py: Python, query: &str) -> PyResult<PyObject> {
// Parse the query
let cypher_query = RustCypherQuery::new(query)
.map_err(graph_error_to_pyerr)?
Expand Down Expand Up @@ -1314,6 +1321,218 @@ impl CypherEngine {
}
}

/// Execute raw SQL queries against in-memory datasets
///
/// This class allows executing standard SQL directly against Arrow tables,
/// without requiring a GraphConfig or Cypher parsing. DataFusion handles
/// SQL parsing and execution.
///
/// Examples
/// --------
/// >>> import pyarrow as pa
/// >>> from lance_graph import SqlQuery
/// >>>
/// >>> person = pa.table({"id": [1, 2], "name": ["Alice", "Bob"], "age": [28, 34]})
/// >>> query = SqlQuery("SELECT name, age FROM person WHERE age > 30")
/// >>> result = query.execute({"person": person})
/// >>> print(result.to_pandas())
#[pyclass(name = "SqlQuery", module = "lance.graph")]
pub struct SqlQuery {
inner: RustSqlQuery,
}

#[pymethods]
impl SqlQuery {
/// Create a new SQL query
///
/// Parameters
/// ----------
/// sql : str
/// The SQL query string
///
/// Returns
/// -------
/// SqlQuery
/// A new SQL query instance
#[new]
fn new(sql: &str) -> Self {
Self {
inner: RustSqlQuery::new(sql),
}
}

/// Get the SQL query text
fn sql(&self) -> &str {
self.inner.sql()
}

/// Execute query against in-memory datasets
///
/// Parameters
/// ----------
/// datasets : dict
/// Dictionary mapping table names to PyArrow tables or Lance datasets
///
/// Returns
/// -------
/// pyarrow.Table
/// Query results as Arrow table
///
/// Raises
/// ------
/// RuntimeError
/// If query execution fails
fn execute(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<PyObject> {
let arrow_datasets = python_datasets_to_batches(datasets)?;
let inner = self.inner.clone();

let result_batch = RT
.block_on(Some(py), inner.execute(arrow_datasets))?
.map_err(graph_error_to_pyerr)?;

record_batch_to_python_table(py, &result_batch)
}

/// Return the execution plan as a string
///
/// Parameters
/// ----------
/// datasets : dict
/// Dictionary mapping table names to PyArrow tables or Lance datasets
///
/// Returns
/// -------
/// str
/// The DataFusion logical and physical execution plan
///
/// Raises
/// ------
/// RuntimeError
/// If planning fails
fn explain(&self, py: Python, datasets: &Bound<'_, PyDict>) -> PyResult<String> {
let arrow_datasets = python_datasets_to_batches(datasets)?;
let inner = self.inner.clone();

let plan = RT
.block_on(Some(py), inner.explain(arrow_datasets))?
.map_err(graph_error_to_pyerr)?;

Ok(plan)
}

fn __repr__(&self) -> String {
format!("SqlQuery(\"{}\")", self.inner.sql())
}
}

/// Cached SQL execution engine for running multiple queries against the same datasets
///
/// This class registers datasets once during initialization and reuses the
/// DataFusion SessionContext for subsequent queries, avoiding repeated setup.
///
/// Examples
/// --------
/// >>> from lance_graph import SqlEngine
/// >>> import pyarrow as pa
/// >>>
/// >>> datasets = {
/// ... "person": pa.table({"id": [1, 2], "name": ["Alice", "Bob"], "age": [28, 34]}),
/// ... "knows": pa.table({"src": [1], "dst": [2]}),
/// ... }
/// >>>
/// >>> engine = SqlEngine(datasets)
/// >>> result1 = engine.execute("SELECT * FROM person WHERE age > 30")
/// >>> result2 = engine.execute("SELECT p.name FROM person p JOIN knows k ON p.id = k.src")
#[pyclass(name = "SqlEngine", module = "lance.graph")]
pub struct SqlEngine {
context: Arc<SessionContext>,
}

#[pymethods]
impl SqlEngine {
/// Create a new SqlEngine with cached datasets
///
/// Parameters
/// ----------
/// datasets : dict
/// Dictionary mapping table names to PyArrow tables or Lance datasets.
/// Table names are lowercased for consistency.
///
/// Returns
/// -------
/// SqlEngine
/// A new engine instance ready to execute queries
///
/// Raises
/// ------
/// ValueError
/// If no datasets are provided
/// RuntimeError
/// If table registration fails
#[new]
fn new(datasets: &Bound<'_, PyDict>) -> PyResult<Self> {
let arrow_datasets = python_datasets_to_batches(datasets)?;

if arrow_datasets.is_empty() {
return Err(PyValueError::new_err("No input datasets provided"));
}

let ctx = SessionContext::new();

for (name, batch) in &arrow_datasets {
let mem_table = Arc::new(
MemTable::try_new(batch.schema(), vec![vec![batch.clone()]]).map_err(|e| {
PyRuntimeError::new_err(format!(
"Failed to create MemTable for {}: {}",
name, e
))
})?,
);

let normalized_name = name.to_lowercase();
ctx.register_table(&normalized_name, mem_table)
.map_err(|e| {
PyRuntimeError::new_err(format!("Failed to register table {}: {}", name, e))
})?;
}

Ok(Self {
context: Arc::new(ctx),
})
}

/// Execute a SQL query using the cached datasets
///
/// Parameters
/// ----------
/// sql : str
/// The SQL query string to execute
///
/// Returns
/// -------
/// pyarrow.Table
/// Query results as Arrow table
///
/// Raises
/// ------
/// RuntimeError
/// If query execution fails
fn execute(&self, py: Python, sql: &str) -> PyResult<PyObject> {
let query = RustSqlQuery::new(sql);
let context = self.context.as_ref().clone();

let result_batch = RT
.block_on(Some(py), query.execute_with_context(context))?
.map_err(graph_error_to_pyerr)?;

record_batch_to_python_table(py, &result_batch)
}

fn __repr__(&self) -> String {
"SqlEngine(...)".to_string()
}
}

/// Register graph functionality with the Python module
pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) -> PyResult<()> {
let graph_module = PyModule::new(py, "graph")?;
Expand All @@ -1324,6 +1543,8 @@ pub fn register_graph_module(py: Python, parent_module: &Bound<'_, PyModule>) ->
graph_module.add_class::<GraphConfigBuilder>()?;
graph_module.add_class::<CypherQuery>()?;
graph_module.add_class::<CypherEngine>()?;
graph_module.add_class::<SqlQuery>()?;
graph_module.add_class::<SqlEngine>()?;
graph_module.add_class::<VectorSearch>()?;
graph_module.add_class::<PyDirNamespace>()?;

Expand Down
2 changes: 2 additions & 0 deletions crates/lance-graph/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ pub mod parser;
pub mod query;
pub mod semantic;
pub mod simple_executor;
pub mod sql_query;

/// Maximum allowed hops for variable-length relationship expansion (e.g., *1..N)
pub const MAX_VARIABLE_LENGTH_HOPS: u32 = 20;
Expand All @@ -58,3 +59,4 @@ pub use lance_graph_catalog::{
};
pub use lance_vector_search::VectorSearch;
pub use query::{CypherQuery, ExecutionStrategy};
pub use sql_query::SqlQuery;
4 changes: 2 additions & 2 deletions crates/lance-graph/src/query.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ use std::sync::Arc;
///
/// This ensures that column names in the dataset match the normalized
/// qualified column names used internally (e.g., "fullName" becomes "fullname").
fn normalize_schema(schema: SchemaRef) -> Result<SchemaRef> {
pub(crate) fn normalize_schema(schema: SchemaRef) -> Result<SchemaRef> {
let fields: Vec<_> = schema
.fields()
.iter()
Expand All @@ -42,7 +42,7 @@ fn normalize_schema(schema: SchemaRef) -> Result<SchemaRef> {
///
/// This creates a new RecordBatch with a normalized schema while
/// preserving all the data arrays.
fn normalize_record_batch(batch: &RecordBatch) -> Result<RecordBatch> {
pub(crate) fn normalize_record_batch(batch: &RecordBatch) -> Result<RecordBatch> {
let normalized_schema = normalize_schema(batch.schema())?;
RecordBatch::try_new(normalized_schema, batch.columns().to_vec()).map_err(|e| {
GraphError::PlanError {
Expand Down
Loading
Loading