diff --git a/crates/lance-graph/src/logical_plan.rs b/crates/lance-graph/src/logical_plan.rs index 236706a1..9d6c32f8 100644 --- a/crates/lance-graph/src/logical_plan.rs +++ b/crates/lance-graph/src/logical_plan.rs @@ -8,9 +8,10 @@ //! //! Logical plans describe WHAT operations to perform, not HOW to perform them. -use crate::ast::*; use crate::error::{GraphError, Result}; +use crate::{ast::*, GraphConfig}; use serde::{Deserialize, Serialize}; +use snafu::Location; use std::collections::HashMap; /// A logical plan operator - describes what operation to perform @@ -149,15 +150,17 @@ pub struct SortItem { } /// Logical plan builder - converts AST to logical plan -pub struct LogicalPlanner { +pub struct LogicalPlanner<'a> { /// Track variables in scope variables: HashMap, // variable -> label + config: &'a GraphConfig, } -impl LogicalPlanner { - pub fn new() -> Self { +impl<'a> LogicalPlanner<'a> { + pub fn new(config: &'a GraphConfig) -> Self { Self { variables: HashMap::new(), + config, } } @@ -506,14 +509,58 @@ impl LogicalPlanner { return_clause: &ReturnClause, input: LogicalOperator, ) -> Result { - let projections = return_clause - .items - .iter() - .map(|item| ProjectionItem { - expression: item.expression.clone(), - alias: item.alias.clone(), - }) - .collect(); + let mut projections: Vec = Vec::new(); + + for item in &return_clause.items { + let alias = &item.alias; + match &item.expression { + ValueExpression::Variable(var) => { + match self.variables.get(var) { + // if it is a node variable, expand to all properties + Some(label) if label != "Unwound" => { + let mapping = self.config.get_node_mapping(label).ok_or_else(|| { + GraphError::PlanError { + message: format!("Node label '{}' doesn't exist", label), + location: Location::new(file!(), line!(), column!()), + } + })?; + + projections.push(ProjectionItem { + expression: ValueExpression::Property(PropertyRef { + variable: var.clone(), + property: mapping.id_field.clone(), + }), + alias: alias + .clone() + .map(|name| format!("{}.{}", name, mapping.id_field)), + }); + + for prop in &mapping.property_fields { + projections.push(ProjectionItem { + expression: ValueExpression::Property(PropertyRef { + variable: var.clone(), + property: prop.clone(), + }), + alias: alias.clone().map(|name| format!("{}.{}", name, prop)), + }); + } + } + _ => { + projections.push(ProjectionItem { + expression: item.expression.clone(), + alias: alias.clone(), + }); + } + } + } + _ => { + projections.push(ProjectionItem { + expression: item.expression.clone(), + alias: alias.clone(), + }); + } + } + } let mut plan = LogicalOperator::Project { input: Box::new(input), @@ -578,16 +625,10 @@ impl LogicalPlanner { } } -impl Default for LogicalPlanner { - fn default() -> Self { - Self::new() - } -} - #[cfg(test)] mod tests { use super::*; - use crate::parser::parse_cypher_query; + use crate::{parser::parse_cypher_query, NodeMapping}; #[test] fn test_relationship_query_logical_plan_structure() { @@ -597,7 +638,8 @@ mod tests { let ast = parse_cypher_query(query_text).unwrap(); // Plan to logical operators - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Verify the overall structure is a projection @@ -698,7 +740,8 @@ mod tests { let query_text = "MATCH (n:Person) RETURN n.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Should be: Project { input: ScanByLabel } @@ -724,7 +767,8 @@ mod tests { let query_text = "MATCH (n:Person {age: 25}) RETURN n.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Should be: Project { input: ScanByLabel with properties } @@ -759,7 +803,8 @@ mod tests { let query_text = "MATCH (a:Person)-[:KNOWS*1..2]->(b:Person) RETURN b.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Should be: Project { input: VariableLengthExpand { input: ScanByLabel } } @@ -802,7 +847,8 @@ mod tests { let query_text = r#"MATCH (n:Person) WHERE n.age > 25 RETURN n.name"#; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Should be: Project { input: Filter { input: ScanByLabel } } @@ -849,7 +895,8 @@ mod tests { let query_text = "MATCH (a:Person), (b:Company) RETURN a.name, b.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Expect: Project { input: Join { left: Scan(a:Person), right: Scan(b:Company) } } @@ -895,7 +942,8 @@ mod tests { "MATCH (a:Person)-[:KNOWS]->(b:Person), (b)-[:LIKES]->(c:Thing) RETURN c.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Expect: Project { input: Expand (b->c) { input: Expand (a->b) { input: Scan(a) } } } @@ -942,7 +990,8 @@ mod tests { let query_text = "MATCH (a:Person)-[:KNOWS*1..1]->(b:Person) RETURN b.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); match &logical_plan { @@ -966,7 +1015,8 @@ mod tests { // DISTINCT should wrap Project with Distinct let q1 = "MATCH (n:Person) RETURN DISTINCT n.name"; let ast1 = parse_cypher_query(q1).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical1 = planner.plan(&ast1).unwrap(); match logical1 { LogicalOperator::Distinct { input } => match *input { @@ -979,7 +1029,7 @@ mod tests { // ORDER BY + LIMIT should be Limit(Sort(Project(..))) let q2 = "MATCH (n:Person) RETURN n.name ORDER BY n.name LIMIT 10"; let ast2 = parse_cypher_query(q2).unwrap(); - let mut planner2 = LogicalPlanner::new(); + let mut planner2 = LogicalPlanner::new(&config); let logical2 = planner2.plan(&ast2).unwrap(); match logical2 { LogicalOperator::Limit { input, count } => { @@ -1001,7 +1051,8 @@ mod tests { // ORDER BY + SKIP + LIMIT should be Limit(Offset(Sort(Project(..)))) let q = "MATCH (n:Person) RETURN n.name ORDER BY n.name SKIP 5 LIMIT 10"; let ast = parse_cypher_query(q).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical = planner.plan(&ast).unwrap(); match logical { LogicalOperator::Limit { input, count } => { @@ -1032,7 +1083,8 @@ mod tests { // SKIP only should be Offset(Project(..)) let q = "MATCH (n:Person) RETURN n.name SKIP 3"; let ast = parse_cypher_query(q).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical = planner.plan(&ast).unwrap(); match logical { LogicalOperator::Offset { input, offset } => { @@ -1048,9 +1100,10 @@ mod tests { #[test] fn test_relationship_properties_pushed_into_expand() { - let q = "MATCH (a)-[:KNOWS {since: 2020}]->(b) RETURN b"; + let q = "MATCH (a)-[:KNOWS {since: 2020}]->(b) RETURN b.name"; let ast = parse_cypher_query(q).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical = planner.plan(&ast).unwrap(); match logical { LogicalOperator::Project { input, .. } => match *input { @@ -1068,7 +1121,8 @@ mod tests { fn test_multiple_match_clauses_cross_join() { let q = "MATCH (a:Person) MATCH (b:Company) RETURN a.name, b.name"; let ast = parse_cypher_query(q).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical = planner.plan(&ast).unwrap(); match logical { LogicalOperator::Project { input, .. } => match *input { @@ -1107,9 +1161,10 @@ mod tests { #[test] fn test_variable_only_node_default_label() { - let q = "MATCH (x) RETURN x"; + let q = "MATCH (x) RETURN x.name"; let ast = parse_cypher_query(q).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical = planner.plan(&ast).unwrap(); match logical { LogicalOperator::Project { input, .. } => match *input { @@ -1127,9 +1182,10 @@ mod tests { #[test] fn test_multi_label_node_uses_first_label() { - let q = "MATCH (n:Person:Employee) RETURN n"; + let q = "MATCH (n:Person:Employee) RETURN n.name"; let ast = parse_cypher_query(q).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical = planner.plan(&ast).unwrap(); match logical { LogicalOperator::Project { input, .. } => match *input { @@ -1145,9 +1201,10 @@ mod tests { #[test] fn test_open_ended_and_partial_var_length_ranges() { // * (unbounded) - let q1 = "MATCH (a)-[:R*]->(b) RETURN b"; + let q1 = "MATCH (a)-[:R*]->(b:Node) RETURN b.name"; let ast1 = parse_cypher_query(q1).unwrap(); - let mut planner1 = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner1 = LogicalPlanner::new(&config); let plan1 = planner1.plan(&ast1).unwrap(); match plan1 { LogicalOperator::Project { input, .. } => match *input { @@ -1165,9 +1222,9 @@ mod tests { } // *2.. (min only) - let q2 = "MATCH (a)-[:R*2..]->(b) RETURN b"; + let q2 = "MATCH (a)-[:R*2..]->(b) RETURN b.name"; let ast2 = parse_cypher_query(q2).unwrap(); - let mut planner2 = LogicalPlanner::new(); + let mut planner2 = LogicalPlanner::new(&config); let plan2 = planner2.plan(&ast2).unwrap(); match plan2 { LogicalOperator::Project { input, .. } => match *input { @@ -1185,9 +1242,9 @@ mod tests { } // *..3 (max only) - let q3 = "MATCH (a)-[:R*..3]->(b) RETURN b"; + let q3 = "MATCH (a)-[:R*..3]->(b) RETURN b.name"; let ast3 = parse_cypher_query(q3).unwrap(); - let mut planner3 = LogicalPlanner::new(); + let mut planner3 = LogicalPlanner::new(&config); let plan3 = planner3.plan(&ast3).unwrap(); match plan3 { LogicalOperator::Project { input, .. } => match *input { @@ -1211,7 +1268,8 @@ mod tests { "MATCH (a:Person)-[:KNOWS]->(shared:Person), (shared)-[:KNOWS]->(b:Person) RETURN b.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let logical_plan = planner.plan(&ast).unwrap(); // Expect: Project { Expand(shared->b) { Expand(a->shared) { Scan(a) } } } @@ -1250,7 +1308,8 @@ mod tests { "MATCH (a:Person)-[:KNOWS]->(shared:Person), (shared:Company)-[:EMPLOYS]->(b:Person) RETURN b.name"; let ast = parse_cypher_query(query_text).unwrap(); - let mut planner = LogicalPlanner::new(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); let err = planner.plan(&ast).unwrap_err(); let err_msg = err.to_string(); @@ -1261,4 +1320,98 @@ mod tests { err_msg ); } + + #[test] + fn test_return_node_variable() { + let query_text = "MATCH (a:Person) RETURN a"; + + let ast = parse_cypher_query(query_text).unwrap(); + let config = GraphConfig::builder() + .with_node_mapping(NodeMapping { + label: "Person".to_string(), + id_field: "id".to_string(), + property_fields: vec!["name".to_string(), "age".to_string()], + filter_conditions: None, + }) + .build() + .unwrap(); + let mut planner = LogicalPlanner::new(&config); + let logical_plan = planner.plan(&ast).unwrap(); + + match &logical_plan { + LogicalOperator::Project { projections, .. } => { + assert_eq!(projections.len(), 3); + match &projections[0].expression { + ValueExpression::Property(prop_ref) => { + assert_eq!(prop_ref.variable, "a"); + assert_eq!(prop_ref.property, "id"); + } + _ => panic!("Expected property reference for a.id"), + } + match &projections[1].expression { + ValueExpression::Property(prop_ref) => { + assert_eq!(prop_ref.variable, "a"); + assert_eq!(prop_ref.property, "name"); + } + _ => panic!("Expected property reference for a.name"), + } + match &projections[2].expression { + ValueExpression::Property(prop_ref) => { + assert_eq!(prop_ref.variable, "a"); + assert_eq!(prop_ref.property, "age"); + } + _ => panic!("Expected property reference for a.age"), + } + } + _ => panic!("Expected Project at the top level"), + } + } + + #[test] + fn test_return_node_variable_with_alias() { + let query_text = "MATCH (a:Person) RETURN a AS b"; + + let ast = parse_cypher_query(query_text).unwrap(); + let config = GraphConfig::builder() + .with_node_label("Person", "id") + .build() + .unwrap(); + let mut planner = LogicalPlanner::new(&config); + let logical_plan = planner.plan(&ast).unwrap(); + + match &logical_plan { + LogicalOperator::Project { projections, .. } => { + assert_eq!(projections.len(), 1); + match &projections[0].expression { + ValueExpression::Property(prop_ref) => { + assert_eq!(prop_ref.variable, "a"); + assert_eq!(prop_ref.property, "id"); + } + _ => panic!("Expected property reference for a.id"), + } + match &projections[0].alias { + Some(alias) => assert_eq!(alias, "b.id"), + None => panic!("Expected alias for a.id as b.id"), + } + } + _ => panic!("Expected Project at the top level"), + } + } + + #[test] + fn test_return_node_variable_no_label() { + let query_text = "MATCH (a:Person) RETURN a"; + + let ast = parse_cypher_query(query_text).unwrap(); + let config = GraphConfig::default(); + let mut planner = LogicalPlanner::new(&config); + let err = planner.plan(&ast).unwrap_err(); + let err_msg = err.to_string(); + + assert!( + err_msg.contains("Node label 'Person' doesn't exist"), + "Expected error about missing label 'Person', got: {}", + err_msg + ); + } } diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index 0a10e7f1..bdf33384 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -778,7 +778,7 @@ impl CypherQuery { } // Phase 2: Graph Logical Plan - let mut logical_planner = LogicalPlanner::new(); + let mut logical_planner = LogicalPlanner::new(config); let logical_plan = logical_planner.plan(&self.ast)?; // Phase 3: DataFusion Logical Plan diff --git a/crates/lance-graph/tests/test_complex_return_clauses.rs b/crates/lance-graph/tests/test_complex_return_clauses.rs new file mode 100644 index 00000000..24e53e32 --- /dev/null +++ b/crates/lance-graph/tests/test_complex_return_clauses.rs @@ -0,0 +1,142 @@ +use arrow_array::{Int64Array, RecordBatch, StringArray}; +use arrow_schema::{DataType, Field, Schema}; +use lance_arrow::SchemaExt; +use lance_graph::config::GraphConfig; +use lance_graph::{CypherQuery, ExecutionStrategy, NodeMapping}; +use std::collections::HashMap; +use std::sync::Arc; + +// This test suite validates complex RETURN clause scenarios +// +// Datasets used: +// +// Person Dataset: +// | id | name | +// |----|---------| +// | 1 | Bob | +// | 2 | Alice | +// | 3 | Charlie | +// +// Person Dataset with Duplicates: +// | id | name | +// |----|---------| +// | 1 | Bob | +// | 1 | Bob | +// | 3 | Charlie | +// +// Scenarios Tested: +// 1. RETURN node_variable; should expand to all properties in RETURN clause +// 2. RETURN DISTINCT node_variable; should expand to all properties and return unique rows +// 3. RETURN node_variable ORDER BY; should expand to all properties and sort accordingly + +/// Helper to create Person dataset +fn create_person_dataset() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 2, 3])), + Arc::new(StringArray::from(vec!["Bob", "Alice", "Charlie"])), + ], + ) + .unwrap() +} + +/// Helper to create Person dataset with duplicates +fn create_person_dataset_with_dups() -> RecordBatch { + let schema = Arc::new(Schema::new(vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8, false), + ])); + + RecordBatch::try_new( + schema, + vec![ + Arc::new(Int64Array::from(vec![1, 1, 3])), + Arc::new(StringArray::from(vec!["Bob", "Bob", "Charlie"])), + ], + ) + .unwrap() +} + +fn create_graph_config() -> GraphConfig { + GraphConfig::builder() + .with_node_mapping(NodeMapping { + label: "Person".to_string(), + id_field: "id".to_string(), + property_fields: vec!["name".to_string()], + filter_conditions: None, + }) + .build() + .unwrap() +} + +#[tokio::test] +async fn test_return_node_variable_expands_properties() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN p") + .unwrap() + .with_config(config); + let datasets = HashMap::from([("Person".to_string(), person_batch)]); + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + assert_eq!(result.num_columns(), 2); + assert_eq!(result.schema().field_names(), vec!["p.id", "p.name"]); + assert_eq!(result.num_rows(), 3); +} + +#[tokio::test] +async fn test_return_node_variable_expands_properties_with_distinct() { + let config = create_graph_config(); + let person_batch = create_person_dataset_with_dups(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN DISTINCT p") + .unwrap() + .with_config(config); + let datasets = HashMap::from([("Person".to_string(), person_batch)]); + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + assert_eq!(result.num_columns(), 2); + assert_eq!(result.schema().field_names(), vec!["p.id", "p.name"]); + assert_eq!(result.num_rows(), 2); +} + +#[tokio::test] +async fn test_return_node_variable_expands_properties_with_order_by() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + + let query = CypherQuery::new("MATCH (p:Person) RETURN p ORDER BY p.name") + .unwrap() + .with_config(config); + let datasets = HashMap::from([("Person".to_string(), person_batch)]); + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + assert_eq!(result.num_columns(), 2); + assert_eq!(result.schema().field_names(), vec!["p.id", "p.name"]); + + let names: &StringArray = result + .column_by_name("p.name") + .unwrap() + .as_any() + .downcast_ref::() + .unwrap(); + let expected: StringArray = StringArray::from(vec!["Alice", "Bob", "Charlie"]); + + assert_eq!(names, &expected); +}