diff --git a/crates/lance-graph/src/datafusion_planner/expression.rs b/crates/lance-graph/src/datafusion_planner/expression.rs index 2f4dd7e4..8688ba58 100644 --- a/crates/lance-graph/src/datafusion_planner/expression.rs +++ b/crates/lance-graph/src/datafusion_planner/expression.rs @@ -122,7 +122,12 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { VE::Literal(PV::Null) => { datafusion::logical_expr::Expr::Literal(datafusion::scalar::ScalarValue::Null, None) } - VE::Literal(PV::Parameter(_)) => lit(0), + VE::Literal(PV::Parameter(name)) => { + panic!( + "Parameter ${} should have been substituted during semantic analysis", + name + ); + } VE::Literal(PV::Property(prop)) => { // Create qualified column name: variable__property (lowercase for case-insensitivity) col(qualify_column(&prop.variable, &prop.property)) @@ -316,18 +321,10 @@ pub(crate) fn to_df_value_expr(expr: &ValueExpression) -> Expr { lit(scalar) } VE::Parameter(name) => { - // TODO: Implement proper parameter resolution - // Parameters ($param) should be resolved to literal values from the query's - // parameter map (CypherQuery::parameters()) before or during planning. - // - // Current limitation: This creates a column reference as a placeholder, - // which will fail at execution if the column doesn't exist. - // - // Proper fix requires one of: - // 1. Resolve parameters during semantic analysis (substitute before planning) - // 2. Pass parameter map to to_df_value_expr and resolve here - // 3. Use DataFusion's parameter binding mechanism - col(format!("${}", name)) + panic!( + "Parameter ${} should have been substituted during semantic analysis", + name + ); } } } diff --git a/crates/lance-graph/src/lib.rs b/crates/lance-graph/src/lib.rs index 692773ad..d36043cd 100644 --- a/crates/lance-graph/src/lib.rs +++ b/crates/lance-graph/src/lib.rs @@ -43,6 +43,7 @@ pub mod error; pub mod lance_native_planner; pub mod lance_vector_search; pub mod logical_plan; +pub mod parameter_substitution; pub mod parser; pub mod query; pub mod semantic; diff --git a/crates/lance-graph/src/parameter_substitution.rs b/crates/lance-graph/src/parameter_substitution.rs new file mode 100644 index 00000000..5f641c29 --- /dev/null +++ b/crates/lance-graph/src/parameter_substitution.rs @@ -0,0 +1,280 @@ +use crate::ast::*; +use crate::error::{GraphError, Result}; +use std::collections::HashMap; + +/// Substitute parameters with literal values in the AST +pub fn substitute_parameters( + query: &mut CypherQuery, + parameters: &HashMap, +) -> Result<()> { + // Substitute in READING clauses + for reading_clause in &mut query.reading_clauses { + substitute_in_reading_clause(reading_clause, parameters)?; + } + + // Substitute in WHERE clause + if let Some(where_clause) = &mut query.where_clause { + substitute_in_where_clause(where_clause, parameters)?; + } + + // Substitute in WITH clause + if let Some(with_clause) = &mut query.with_clause { + substitute_in_with_clause(with_clause, parameters)?; + } + + // Substitute in post-WITH READING clauses + for reading_clause in &mut query.post_with_reading_clauses { + substitute_in_reading_clause(reading_clause, parameters)?; + } + + // Substitute in post-WITH WHERE clause + if let Some(post_where) = &mut query.post_with_where_clause { + substitute_in_where_clause(post_where, parameters)?; + } + + // Substitute in RETURN clause + substitute_in_return_clause(&mut query.return_clause, parameters)?; + + // Substitute in ORDER BY clause + if let Some(order_by) = &mut query.order_by { + substitute_in_order_by_clause(order_by, parameters)?; + } + + Ok(()) +} + +fn substitute_in_reading_clause( + clause: &mut ReadingClause, + parameters: &HashMap, +) -> Result<()> { + match clause { + ReadingClause::Match(match_clause) => { + for pattern in &mut match_clause.patterns { + substitute_in_graph_pattern(pattern, parameters)?; + } + } + ReadingClause::Unwind(unwind_clause) => { + substitute_in_value_expression(&mut unwind_clause.expression, parameters)?; + } + } + Ok(()) +} + +fn substitute_in_graph_pattern( + pattern: &mut GraphPattern, + parameters: &HashMap, +) -> Result<()> { + match pattern { + GraphPattern::Node(node) => { + for value in node.properties.values_mut() { + substitute_in_property_value(value, parameters)?; + } + } + GraphPattern::Path(path) => { + substitute_in_node_pattern(&mut path.start_node, parameters)?; + for segment in &mut path.segments { + substitute_in_relationship_pattern(&mut segment.relationship, parameters)?; + substitute_in_node_pattern(&mut segment.end_node, parameters)?; + } + } + } + Ok(()) +} + +fn substitute_in_node_pattern( + node: &mut NodePattern, + parameters: &HashMap, +) -> Result<()> { + for value in node.properties.values_mut() { + substitute_in_property_value(value, parameters)?; + } + Ok(()) +} + +fn substitute_in_relationship_pattern( + rel: &mut RelationshipPattern, + parameters: &HashMap, +) -> Result<()> { + for value in rel.properties.values_mut() { + substitute_in_property_value(value, parameters)?; + } + Ok(()) +} + +fn substitute_in_property_value( + value: &mut PropertyValue, + parameters: &HashMap, +) -> Result<()> { + if let PropertyValue::Parameter(name) = value { + let param_value = + parameters + .get(&name.to_lowercase()) + .ok_or_else(|| GraphError::PlanError { + message: format!("Missing parameter: ${}", name), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + *value = json_to_property_value(param_value)?; + } + Ok(()) +} + +fn substitute_in_where_clause( + where_clause: &mut WhereClause, + parameters: &HashMap, +) -> Result<()> { + substitute_in_boolean_expression(&mut where_clause.expression, parameters) +} + +fn substitute_in_with_clause( + with_clause: &mut WithClause, + parameters: &HashMap, +) -> Result<()> { + for item in &mut with_clause.items { + substitute_in_value_expression(&mut item.expression, parameters)?; + } + if let Some(order_by) = &mut with_clause.order_by { + substitute_in_order_by_clause(order_by, parameters)?; + } + Ok(()) +} + +fn substitute_in_return_clause( + return_clause: &mut ReturnClause, + parameters: &HashMap, +) -> Result<()> { + for item in &mut return_clause.items { + substitute_in_value_expression(&mut item.expression, parameters)?; + } + Ok(()) +} + +fn substitute_in_order_by_clause( + order_by: &mut OrderByClause, + parameters: &HashMap, +) -> Result<()> { + for item in &mut order_by.items { + substitute_in_value_expression(&mut item.expression, parameters)?; + } + Ok(()) +} + +fn substitute_in_boolean_expression( + expr: &mut BooleanExpression, + parameters: &HashMap, +) -> Result<()> { + match expr { + BooleanExpression::Comparison { left, right, .. } => { + substitute_in_value_expression(left, parameters)?; + substitute_in_value_expression(right, parameters)?; + } + BooleanExpression::And(left, right) | BooleanExpression::Or(left, right) => { + substitute_in_boolean_expression(left, parameters)?; + substitute_in_boolean_expression(right, parameters)?; + } + BooleanExpression::Not(inner) => { + substitute_in_boolean_expression(inner, parameters)?; + } + BooleanExpression::Exists(_) => {} + BooleanExpression::In { expression, list } => { + substitute_in_value_expression(expression, parameters)?; + for item in list { + substitute_in_value_expression(item, parameters)?; + } + } + BooleanExpression::Like { expression, .. } + | BooleanExpression::ILike { expression, .. } + | BooleanExpression::Contains { expression, .. } + | BooleanExpression::StartsWith { expression, .. } + | BooleanExpression::EndsWith { expression, .. } + | BooleanExpression::IsNull(expression) + | BooleanExpression::IsNotNull(expression) => { + substitute_in_value_expression(expression, parameters)?; + } + } + Ok(()) +} + +fn substitute_in_value_expression( + expr: &mut ValueExpression, + parameters: &HashMap, +) -> Result<()> { + match expr { + ValueExpression::Parameter(name) => { + let param_value = + parameters + .get(&name.to_lowercase()) + .ok_or_else(|| GraphError::PlanError { + message: format!("Missing parameter: ${}", name), + location: snafu::Location::new(file!(), line!(), column!()), + })?; + + // Check for array to VectorLiteral conversion + if let serde_json::Value::Array(arr) = param_value { + let mut floats = Vec::new(); + for v in arr { + if let Some(f) = v.as_f64() { + floats.push(f as f32); + } else { + return Err(GraphError::PlanError { + message: format!( + "Parameter ${} is a list but contains non-numeric values. Only float vectors are supported as list parameters currently.", + name + ), + location: snafu::Location::new(file!(), line!(), column!()), + }); + } + } + *expr = ValueExpression::VectorLiteral(floats); + return Ok(()); + } + + // Scalar conversion + let prop_val = json_to_property_value(param_value)?; + *expr = ValueExpression::Literal(prop_val); + } + ValueExpression::ScalarFunction { args, .. } + | ValueExpression::AggregateFunction { args, .. } => { + for arg in args { + substitute_in_value_expression(arg, parameters)?; + } + } + ValueExpression::Arithmetic { left, right, .. } => { + substitute_in_value_expression(left, parameters)?; + substitute_in_value_expression(right, parameters)?; + } + ValueExpression::VectorDistance { left, right, .. } + | ValueExpression::VectorSimilarity { left, right, .. } => { + substitute_in_value_expression(left, parameters)?; + substitute_in_value_expression(right, parameters)?; + } + _ => {} + } + Ok(()) +} + +fn json_to_property_value(value: &serde_json::Value) -> Result { + match value { + serde_json::Value::Null => Ok(PropertyValue::Null), + serde_json::Value::Bool(b) => Ok(PropertyValue::Boolean(*b)), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + Ok(PropertyValue::Integer(i)) + } else if let Some(f) = n.as_f64() { + Ok(PropertyValue::Float(f)) + } else { + Err(GraphError::PlanError { + message: format!("Number parameter could not be converted to i64 or f64: {}", n), + location: snafu::Location::new(file!(), line!(), column!()), + }) + } + } + serde_json::Value::String(s) => Ok(PropertyValue::String(s.clone())), + serde_json::Value::Array(_) | serde_json::Value::Object(_) => { + Err(GraphError::PlanError { + message: "Complex types (List, Map) are not fully supported as parameters yet (except float vectors).".to_string(), + location: snafu::Location::new(file!(), line!(), column!()), + }) + } + } +} diff --git a/crates/lance-graph/src/parser.rs b/crates/lance-graph/src/parser.rs index 8887ce23..03a79e58 100644 --- a/crates/lance-graph/src/parser.rs +++ b/crates/lance-graph/src/parser.rs @@ -589,9 +589,8 @@ fn parse_vector_similarity(input: &str) -> IResult<&str, ValueExpression> { // Parse parameter reference: $name fn parse_parameter(input: &str) -> IResult<&str, ValueExpression> { - let (input, _) = char('$')(input)?; - let (input, name) = identifier(input)?; - Ok((input, ValueExpression::Parameter(name.to_string()))) + let (input, name) = parameter(input)?; + Ok((input, ValueExpression::Parameter(name))) } // Parse a function call: function_name(args) @@ -973,9 +972,8 @@ fn boolean_literal(input: &str) -> IResult<&str, bool> { // Parse a parameter reference fn parameter(input: &str) -> IResult<&str, String> { - let (input, _) = char('$')(input)?; - let (input, name) = identifier(input)?; - Ok((input, name.to_string())) + // Only support $param syntax + map(preceded(char('$'), identifier), |s| s.to_string())(input) } // Parse comma with optional whitespace @@ -1699,6 +1697,58 @@ mod tests { } } + #[test] + fn test_parse_multiple_parameters() { + let query = "MATCH (p:Person) WHERE p.age > $min_age AND p.age < $max_age RETURN p"; + let result = parse_cypher_query(query); + assert!( + result.is_ok(), + "Multiple parameters should parse successfully" + ); + + let ast = result.unwrap(); + let where_clause = ast.where_clause.expect("Expected WHERE clause"); + + match where_clause.expression { + BooleanExpression::And(left, right) => { + // Check left: p.age > $min_age + match *left { + BooleanExpression::Comparison { + right: val_right, .. + } => match val_right { + ValueExpression::Parameter(name) => { + assert_eq!(name, "min_age"); + } + _ => panic!("Expected Parameter min_age"), + }, + _ => panic!("Expected comparison on left"), + } + + // Check right: p.age < $max_age + match *right { + BooleanExpression::Comparison { + right: val_right, .. + } => match val_right { + ValueExpression::Parameter(name) => { + assert_eq!(name, "max_age"); + } + _ => panic!("Expected Parameter max_age"), + }, + _ => panic!("Expected comparison on right"), + } + } + _ => panic!("Expected AND expression"), + } + } + + #[test] + fn test_parse_parameter_formats() { + // Test $param (should succeed) + let query = "MATCH (p:Person) WHERE p.age > $min_age RETURN p"; + let result = parse_cypher_query(query); + assert!(result.is_ok(), "$param should parse successfully"); + } + #[test] fn test_vector_distance_metrics() { for metric in &["cosine", "l2", "dot"] { diff --git a/crates/lance-graph/src/query.rs b/crates/lance-graph/src/query.rs index bdf33384..b3a2dcec 100644 --- a/crates/lance-graph/src/query.rs +++ b/crates/lance-graph/src/query.rs @@ -101,13 +101,16 @@ impl CypherQuery { K: Into, V: Into, { - self.parameters.insert(key.into(), value.into()); + self.parameters + .insert(key.into().to_lowercase(), value.into()); self } /// Add multiple parameters to the query pub fn with_parameters(mut self, params: HashMap) -> Self { - self.parameters.extend(params); + for (k, v) in params { + self.parameters.insert(k.to_lowercase(), v); + } self } @@ -769,7 +772,7 @@ impl CypherQuery { // Phase 1: Semantic Analysis let mut analyzer = SemanticAnalyzer::new(config.clone()); - let semantic = analyzer.analyze(&self.ast)?; + let semantic = analyzer.analyze(&self.ast, &self.parameters)?; if !semantic.errors.is_empty() { return Err(GraphError::PlanError { message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")), @@ -779,7 +782,8 @@ impl CypherQuery { // Phase 2: Graph Logical Plan let mut logical_planner = LogicalPlanner::new(config); - let logical_plan = logical_planner.plan(&self.ast)?; + // Use the transformed AST (with parameters substituted) + let logical_plan = logical_planner.plan(&semantic.ast)?; // Phase 3: DataFusion Logical Plan let df_planner = DataFusionPlanner::with_catalog(config.clone(), catalog); @@ -943,7 +947,7 @@ impl CypherQuery { // Ensure we don't silently ignore unsupported features (e.g. scalar functions). let mut analyzer = SemanticAnalyzer::new(config); - let semantic = analyzer.analyze(&self.ast)?; + let semantic = analyzer.analyze(&self.ast, &self.parameters)?; if !semantic.errors.is_empty() { return Err(GraphError::PlanError { message: format!("Semantic analysis failed:\n{}", semantic.errors.join("\n")), @@ -1489,12 +1493,16 @@ mod tests { fn test_query_with_parameters() { let mut params = HashMap::new(); params.insert("minAge".to_string(), serde_json::Value::Number(30.into())); + params.insert("maxAge".to_string(), serde_json::Value::Number(50.into())); - let query = CypherQuery::new("MATCH (n:Person) WHERE n.age > $minAge RETURN n.name") - .unwrap() - .with_parameters(params); + let query = CypherQuery::new( + "MATCH (n:Person) WHERE n.age > $minAge AND n.age < $maxAge RETURN n.name", + ) + .unwrap() + .with_parameters(params); - assert!(query.parameters().contains_key("minAge")); + assert!(query.parameters().contains_key("minage")); + assert!(query.parameters().contains_key("maxage")); } #[test] diff --git a/crates/lance-graph/src/semantic.rs b/crates/lance-graph/src/semantic.rs index 065e36ae..c1b64028 100644 --- a/crates/lance-graph/src/semantic.rs +++ b/crates/lance-graph/src/semantic.rs @@ -54,6 +54,8 @@ pub enum ScopeType { /// Semantic analysis result with validated and enriched AST #[derive(Debug, Clone)] pub struct SemanticResult { + /// The AST with parameters substituted and validated + pub ast: CypherQuery, pub variables: HashMap, pub errors: Vec, pub warnings: Vec, @@ -69,13 +71,23 @@ impl SemanticAnalyzer { } /// Analyze a Cypher query AST - pub fn analyze(&mut self, query: &CypherQuery) -> Result { + pub fn analyze( + &mut self, + query: &CypherQuery, + parameters: &HashMap, + ) -> Result { + // Clone the query to perform parameter substitution + let mut analyzed_query = query.clone(); + + // Perform parameter substitution + self.substitute_parameters(&mut analyzed_query, parameters)?; + let mut errors = Vec::new(); let mut warnings = Vec::new(); // Phase 1: Variable discovery in READING clauses (MATCH/UNWIND) self.current_scope = ScopeType::Match; - for clause in &query.reading_clauses { + for clause in &analyzed_query.reading_clauses { match clause { ReadingClause::Match(match_clause) => { if let Err(e) = self.analyze_match_clause(match_clause) { @@ -91,7 +103,7 @@ impl SemanticAnalyzer { } // Phase 2: Validate WHERE clause (before WITH) - if let Some(where_clause) = &query.where_clause { + if let Some(where_clause) = &analyzed_query.where_clause { self.current_scope = ScopeType::Where; if let Err(e) = self.analyze_where_clause(where_clause) { errors.push(format!("WHERE clause error: {}", e)); @@ -99,7 +111,7 @@ impl SemanticAnalyzer { } // Phase 3: Validate WITH clause if present - if let Some(with_clause) = &query.with_clause { + if let Some(with_clause) = &analyzed_query.with_clause { self.current_scope = ScopeType::With; if let Err(e) = self.analyze_with_clause(with_clause) { errors.push(format!("WITH clause error: {}", e)); @@ -108,7 +120,7 @@ impl SemanticAnalyzer { // Phase 4: Variable discovery in post-WITH READING clauses (query chaining) self.current_scope = ScopeType::Match; - for clause in &query.post_with_reading_clauses { + for clause in &analyzed_query.post_with_reading_clauses { match clause { ReadingClause::Match(match_clause) => { if let Err(e) = self.analyze_match_clause(match_clause) { @@ -124,7 +136,7 @@ impl SemanticAnalyzer { } // Phase 4: Validate post-WITH WHERE clause if present - if let Some(post_where) = &query.post_with_where_clause { + if let Some(post_where) = &analyzed_query.post_with_where_clause { self.current_scope = ScopeType::PostWithWhere; if let Err(e) = self.analyze_where_clause(post_where) { errors.push(format!("Post-WITH WHERE clause error: {}", e)); @@ -133,12 +145,12 @@ impl SemanticAnalyzer { // Phase 5: Validate RETURN clause self.current_scope = ScopeType::Return; - if let Err(e) = self.analyze_return_clause(&query.return_clause) { + if let Err(e) = self.analyze_return_clause(&analyzed_query.return_clause) { errors.push(format!("RETURN clause error: {}", e)); } // Phase 6: Validate ORDER BY clause - if let Some(order_by) = &query.order_by { + if let Some(order_by) = &analyzed_query.order_by { self.current_scope = ScopeType::OrderBy; if let Err(e) = self.analyze_order_by_clause(order_by) { errors.push(format!("ORDER BY clause error: {}", e)); @@ -152,6 +164,7 @@ impl SemanticAnalyzer { self.validate_types(&mut errors); Ok(SemanticResult { + ast: analyzed_query, variables: self.variables.clone(), errors, warnings, @@ -730,6 +743,14 @@ impl SemanticAnalyzer { } Ok(()) } + /// Substitute parameters with literal values in the AST + fn substitute_parameters( + &self, + query: &mut CypherQuery, + parameters: &HashMap, + ) -> Result<()> { + crate::parameter_substitution::substitute_parameters(query, parameters) + } } #[cfg(test)] @@ -772,7 +793,7 @@ mod tests { skip: None, }; let mut analyzer = SemanticAnalyzer::new(test_config()); - analyzer.analyze(&query) + analyzer.analyze(&query, &HashMap::new()) } // Helper: analyze a query with a single MATCH (var:label) and a RETURN expression @@ -802,7 +823,7 @@ mod tests { skip: None, }; let mut analyzer = SemanticAnalyzer::new(test_config()); - analyzer.analyze(&query) + analyzer.analyze(&query, &HashMap::new()) } #[test] @@ -833,7 +854,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(test_config()); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); assert!(result.errors.is_empty()); let n = result.variables.get("n").expect("variable n present"); // Labels merged @@ -882,7 +903,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(test_config()); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); assert!(result .errors .iter() @@ -914,7 +935,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(test_config()); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); assert!(result .errors .iter() @@ -956,7 +977,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(test_config()); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); assert!(result .errors .iter() @@ -985,7 +1006,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(test_config()); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); assert!(result .warnings .iter() @@ -1025,7 +1046,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(custom_config); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); assert!(result .errors .iter() @@ -1070,7 +1091,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(test_config()); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); assert!(result .errors .iter() @@ -1137,7 +1158,7 @@ mod tests { }; let mut analyzer = SemanticAnalyzer::new(custom_config); - let result = analyzer.analyze(&query).unwrap(); + let result = analyzer.analyze(&query, &HashMap::new()).unwrap(); let r = result.variables.get("r").expect("variable r present"); // Types merged assert!(r.labels.contains(&"KNOWS".to_string())); @@ -1147,6 +1168,58 @@ mod tests { assert!(r.properties.contains("level")); } + #[test] + fn test_parameter_substitution() { + // MATCH (n:Person) WHERE n.age > $min_age RETURN n + let node = NodePattern::new(Some("n".to_string())).with_label("Person"); + let where_clause = WhereClause { + expression: BooleanExpression::Comparison { + left: ValueExpression::Property(PropertyRef::new("n", "age")), + operator: crate::ast::ComparisonOperator::GreaterThan, + right: ValueExpression::Parameter("min_age".to_string()), + }, + }; + let query = CypherQuery { + reading_clauses: vec![ReadingClause::Match(MatchClause { + patterns: vec![GraphPattern::Node(node)], + })], + where_clause: Some(where_clause), + with_clause: None, + post_with_reading_clauses: vec![], + post_with_where_clause: None, + return_clause: ReturnClause { + distinct: false, + items: vec![ReturnItem { + expression: ValueExpression::Variable("n".to_string()), + alias: None, + }], + }, + limit: None, + order_by: None, + skip: None, + }; + + let mut parameters = HashMap::new(); + parameters.insert("min_age".to_string(), serde_json::json!(18)); + + let mut analyzer = SemanticAnalyzer::new(test_config()); + let result = analyzer + .analyze(&query, ¶meters) + .expect("Analysis failed"); + + // Verify substitution in AST + let where_clause = result.ast.where_clause.as_ref().unwrap(); + match &where_clause.expression { + BooleanExpression::Comparison { right, .. } => match right { + ValueExpression::Literal(PropertyValue::Integer(val)) => { + assert_eq!(*val, 18); + } + _ => panic!("Expected Integer literal, got {:?}", right), + }, + _ => panic!("Expected Comparison expression"), + } + } + #[test] fn test_function_argument_undefined_variable_in_return() { // RETURN toUpper(m.name) diff --git a/crates/lance-graph/tests/test_datafusion_pipeline.rs b/crates/lance-graph/tests/test_datafusion_pipeline.rs index b1b203bb..906f27e6 100644 --- a/crates/lance-graph/tests/test_datafusion_pipeline.rs +++ b/crates/lance-graph/tests/test_datafusion_pipeline.rs @@ -5101,3 +5101,52 @@ async fn test_datafusion_variable_reuse_multi_pattern_optimization() { person_scan_count ); } + +#[tokio::test] +async fn test_datafusion_parameter_filtering_age() { + let config = create_graph_config(); + let person_batch = create_person_dataset(); + + let mut params = HashMap::new(); + // Filter for people older than 30 (Bob:35, David:40) + params.insert("min_age".to_string(), serde_json::json!(30)); + + let query = CypherQuery::new("MATCH (p:Person) WHERE p.age > $min_age RETURN p.name, p.age") + .unwrap() + .with_config(config) + .with_parameters(params); + + let mut datasets = HashMap::new(); + datasets.insert("Person".to_string(), person_batch); + + let result = query + .execute(datasets, Some(ExecutionStrategy::DataFusion)) + .await + .unwrap(); + + // Should return 2 people (Bob:35, David:40) + assert_eq!(result.num_rows(), 2); + assert_eq!(result.num_columns(), 2); + + let names = result + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + let ages = result + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + + let mut results = Vec::new(); + for i in 0..result.num_rows() { + results.push((names.value(i).to_string(), ages.value(i))); + } + + results.sort(); + assert_eq!( + results, + vec![("Bob".to_string(), 35), ("David".to_string(), 40)] + ); +} diff --git a/python/python/tests/test_graph.py b/python/python/tests/test_graph.py index 4fca84fa..bb1143fa 100644 --- a/python/python/tests/test_graph.py +++ b/python/python/tests/test_graph.py @@ -215,3 +215,30 @@ def test_execute_with_directory_namespace(graph_env, tmp_path): data = result.to_pydict() assert set(data["p.name"]) == {"Bob", "David"} + + +def test_cypher_parameter_syntax(graph_env): + """Test Cypher parameter syntax ($).""" + config, datasets, _ = graph_env + + # 1. Test $param + query_dollar = CypherQuery( + "MATCH (p:Person) WHERE p.age > $age RETURN p.name" + ).with_config(config) + result = query_dollar.with_parameter("age", 30).execute(datasets) + data = result.to_pydict() + assert set(data["p.name"]) == {"Bob", "David"} + + # 2. Test multiple parameters + query_multi = CypherQuery( + "MATCH (p:Person) WHERE p.age > $min_age AND p.age < $max_age RETURN p.name" + ).with_config(config) + result = ( + query_multi.with_parameter("min_age", 25) + .with_parameter("max_age", 35) + .execute(datasets) + ) + data = result.to_pydict() + # Should get Alice (28), Carol (29), Bob (34) + # David is 42 (excluded) + assert set(data["p.name"]) == {"Alice", "Carol", "Bob"}