Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 56 additions & 5 deletions src/execute/cte.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub struct CteDefinition {
pub name: String,
/// Full SQL text of the CTE body (including the SELECT statement inside)
pub body: String,
/// Optional column aliases: WITH t(value, label) AS (...) → ["value", "label"]
pub column_aliases: Vec<String>,
}

/// Extract CTE definitions from the source tree
Expand All @@ -36,17 +38,22 @@ pub fn extract_ctes(source_tree: &SourceTree) -> Vec<CteDefinition> {
/// Parse a single CTE definition node into a CteDefinition
fn parse_cte_definition(node: &Node, source: &str) -> Option<CteDefinition> {
let mut name: Option<String> = None;
let mut column_aliases: Vec<String> = Vec::new();
let mut body_start: Option<usize> = None;
let mut body_end: Option<usize> = None;

let mut cursor = node.walk();
for child in node.children(&mut cursor) {
match child.kind() {
"identifier" => {
name = Some(get_node_text(&child, source).to_string());
// First identifier is the CTE name, subsequent ones are column aliases
if name.is_none() {
name = Some(get_node_text(&child, source).to_string());
} else {
column_aliases.push(get_node_text(&child, source).to_string());
}
}
"select_statement" => {
// The SELECT inside the CTE
"select_statement" | "subquery_body" | "with_statement" => {
body_start = Some(child.start_byte());
body_end = Some(child.end_byte());
}
Expand All @@ -57,7 +64,11 @@ fn parse_cte_definition(node: &Node, source: &str) -> Option<CteDefinition> {
match (name, body_start, body_end) {
(Some(n), Some(start), Some(end)) => {
let body = source[start..end].to_string();
Some(CteDefinition { name: n, body })
Some(CteDefinition {
name: n,
body,
column_aliases,
})
}
_ => None,
}
Expand Down Expand Up @@ -136,9 +147,27 @@ pub fn materialize_ctes(ctes: &[CteDefinition], reader: &dyn Reader) -> Result<H
let temp_table_name = naming::cte_table(&cte.name);

// Execute the CTE body SQL to get a DataFrame, then register it
let df = reader.execute_sql(&transformed_body).map_err(|e| {
let mut df = reader.execute_sql(&transformed_body).map_err(|e| {
GgsqlError::ReaderError(format!("Failed to materialize CTE '{}': {}", cte.name, e))
})?;

// Apply column aliases if present: WITH t(value, label) AS (...) renames columns
if !cte.column_aliases.is_empty() && cte.column_aliases.len() == df.width() {
let current_names: Vec<String> = df
.get_column_names()
.iter()
.map(|s| s.to_string())
.collect();
for (old, new) in current_names.iter().zip(cte.column_aliases.iter()) {
df.rename(old, new.into()).map_err(|e| {
GgsqlError::ReaderError(format!(
"Failed to apply column alias '{}' for CTE '{}': {}",
new, cte.name, e
))
})?;
}
}

reader.register(&temp_table_name, df, true).map_err(|e| {
GgsqlError::ReaderError(format!("Failed to register CTE '{}': {}", cte.name, e))
})?;
Expand Down Expand Up @@ -287,6 +316,28 @@ mod tests {
assert_eq!(ctes[1].name, "targets");
}

#[test]
fn test_extract_ctes_with_column_aliases() {
let sql = "WITH t(value, label) AS (SELECT * FROM (VALUES (70, 'Target'))) SELECT * FROM t";
let source_tree = SourceTree::new(sql).unwrap();
let ctes = extract_ctes(&source_tree);

assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].name, "t");
assert_eq!(ctes[0].column_aliases, vec!["value", "label"]);
}

#[test]
fn test_extract_ctes_without_column_aliases() {
let sql = "WITH sales AS (SELECT * FROM raw_sales) SELECT * FROM sales";
let source_tree = SourceTree::new(sql).unwrap();
let ctes = extract_ctes(&source_tree);

assert_eq!(ctes.len(), 1);
assert_eq!(ctes[0].name, "sales");
assert!(ctes[0].column_aliases.is_empty());
}

#[test]
fn test_extract_ctes_none() {
let sql = "SELECT * FROM sales WHERE year = 2024";
Expand Down
30 changes: 30 additions & 0 deletions src/execute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1366,6 +1366,36 @@ mod tests {
assert_eq!(layer1_df.height(), 2);
}

#[cfg(feature = "duckdb")]
#[test]
fn test_layer_references_cte_with_column_aliases() {
let reader = DuckDBReader::from_connection_string("duckdb://memory").unwrap();

let query = r#"
WITH t(value, label) AS (
SELECT * FROM (VALUES
(70, 'Target'),
(80, 'Warning'),
(90, 'Critical')
)
)
SELECT 1 AS date, 75 AS temperature
VISUALISE
DRAW line MAPPING date AS x, temperature AS y
DRAW rule MAPPING value AS y, label AS colour FROM t
"#;

let result = prepare_data_with_reader(query, &reader).unwrap();

// Layer 0: line from global data
let layer0_df = result.data.get(&naming::layer_key(0)).unwrap();
assert_eq!(layer0_df.height(), 1);

// Layer 1: rule from CTE with column aliases
let layer1_df = result.data.get(&naming::layer_key(1)).unwrap();
assert_eq!(layer1_df.height(), 3);
}

#[cfg(feature = "duckdb")]
#[test]
fn test_histogram_stat_transform() {
Expand Down
9 changes: 8 additions & 1 deletion tree-sitter-ggsql/grammar.js
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,18 @@ module.exports = grammar({

cte_definition: $ => seq(
$.identifier,
optional(seq( // Optional column list: df(x, y, id)
'(',
$.identifier,
repeat(seq(',', $.identifier)),
')'
)),
caseInsensitive('AS'),
'(',
choice(
$.with_statement, // Allow nested CTEs
$.select_statement
$.select_statement,
$.subquery_body // VALUES (...) and other non-SELECT bodies
),
')'
),
Expand Down
Loading