diff --git a/src/execute/cte.rs b/src/execute/cte.rs index b83abc33..5a6b665f 100644 --- a/src/execute/cte.rs +++ b/src/execute/cte.rs @@ -16,6 +16,8 @@ pub struct CteDefinition { pub name: String, /// Full SQL text of the CTE body (including the SELECT statement inside) pub body: String, + /// Optional column aliases: WITH t(value, label) AS (...) → ["value", "label"] + pub column_aliases: Vec, } /// Extract CTE definitions from the source tree @@ -36,6 +38,7 @@ pub fn extract_ctes(source_tree: &SourceTree) -> Vec { /// Parse a single CTE definition node into a CteDefinition fn parse_cte_definition(node: &Node, source: &str) -> Option { let mut name: Option = None; + let mut column_aliases: Vec = Vec::new(); let mut body_start: Option = None; let mut body_end: Option = None; @@ -43,10 +46,14 @@ fn parse_cte_definition(node: &Node, source: &str) -> Option { for child in node.children(&mut cursor) { match child.kind() { "identifier" => { - name = Some(get_node_text(&child, source).to_string()); + // First identifier is the CTE name, subsequent ones are column aliases + if name.is_none() { + name = Some(get_node_text(&child, source).to_string()); + } else { + column_aliases.push(get_node_text(&child, source).to_string()); + } } - "select_statement" => { - // The SELECT inside the CTE + "select_statement" | "subquery_body" | "with_statement" => { body_start = Some(child.start_byte()); body_end = Some(child.end_byte()); } @@ -57,7 +64,11 @@ fn parse_cte_definition(node: &Node, source: &str) -> Option { match (name, body_start, body_end) { (Some(n), Some(start), Some(end)) => { let body = source[start..end].to_string(); - Some(CteDefinition { name: n, body }) + Some(CteDefinition { + name: n, + body, + column_aliases, + }) } _ => None, } @@ -136,9 +147,27 @@ pub fn materialize_ctes(ctes: &[CteDefinition], reader: &dyn Reader) -> Result = df + .get_column_names() + .iter() + .map(|s| s.to_string()) + .collect(); + for (old, new) in current_names.iter().zip(cte.column_aliases.iter()) { + df.rename(old, new.into()).map_err(|e| { + GgsqlError::ReaderError(format!( + "Failed to apply column alias '{}' for CTE '{}': {}", + new, cte.name, e + )) + })?; + } + } + reader.register(&temp_table_name, df, true).map_err(|e| { GgsqlError::ReaderError(format!("Failed to register CTE '{}': {}", cte.name, e)) })?; @@ -287,6 +316,28 @@ mod tests { assert_eq!(ctes[1].name, "targets"); } + #[test] + fn test_extract_ctes_with_column_aliases() { + let sql = "WITH t(value, label) AS (SELECT * FROM (VALUES (70, 'Target'))) SELECT * FROM t"; + let source_tree = SourceTree::new(sql).unwrap(); + let ctes = extract_ctes(&source_tree); + + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].name, "t"); + assert_eq!(ctes[0].column_aliases, vec!["value", "label"]); + } + + #[test] + fn test_extract_ctes_without_column_aliases() { + let sql = "WITH sales AS (SELECT * FROM raw_sales) SELECT * FROM sales"; + let source_tree = SourceTree::new(sql).unwrap(); + let ctes = extract_ctes(&source_tree); + + assert_eq!(ctes.len(), 1); + assert_eq!(ctes[0].name, "sales"); + assert!(ctes[0].column_aliases.is_empty()); + } + #[test] fn test_extract_ctes_none() { let sql = "SELECT * FROM sales WHERE year = 2024"; diff --git a/src/execute/mod.rs b/src/execute/mod.rs index 470a55cd..7cdfdd91 100644 --- a/src/execute/mod.rs +++ b/src/execute/mod.rs @@ -1366,6 +1366,36 @@ mod tests { assert_eq!(layer1_df.height(), 2); } + #[cfg(feature = "duckdb")] + #[test] + fn test_layer_references_cte_with_column_aliases() { + let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap(); + + let query = r#" + WITH t(value, label) AS ( + SELECT * FROM (VALUES + (70, 'Target'), + (80, 'Warning'), + (90, 'Critical') + ) + ) + SELECT 1 AS date, 75 AS temperature + VISUALISE + DRAW line MAPPING date AS x, temperature AS y + DRAW rule MAPPING value AS y, label AS colour FROM t + "#; + + let result = prepare_data_with_reader(query, &reader).unwrap(); + + // Layer 0: line from global data + let layer0_df = result.data.get(&naming::layer_key(0)).unwrap(); + assert_eq!(layer0_df.height(), 1); + + // Layer 1: rule from CTE with column aliases + let layer1_df = result.data.get(&naming::layer_key(1)).unwrap(); + assert_eq!(layer1_df.height(), 3); + } + #[cfg(feature = "duckdb")] #[test] fn test_histogram_stat_transform() { diff --git a/tree-sitter-ggsql/grammar.js b/tree-sitter-ggsql/grammar.js index e6583468..7a01d95f 100644 --- a/tree-sitter-ggsql/grammar.js +++ b/tree-sitter-ggsql/grammar.js @@ -81,11 +81,18 @@ module.exports = grammar({ cte_definition: $ => seq( $.identifier, + optional(seq( // Optional column list: df(x, y, id) + '(', + $.identifier, + repeat(seq(',', $.identifier)), + ')' + )), caseInsensitive('AS'), '(', choice( $.with_statement, // Allow nested CTEs - $.select_statement + $.select_statement, + $.subquery_body // VALUES (...) and other non-SELECT bodies ), ')' ),