diff --git a/ast/ast.go b/ast/ast.go
index 122fb78c45..1c86b8b483 100644
--- a/ast/ast.go
+++ b/ast/ast.go
@@ -290,9 +290,10 @@ func (c *Constraint) End() token.Position { return c.Position }
 
 // EngineClause represents an ENGINE clause.
 type EngineClause struct {
-	Position   token.Position `json:"-"`
-	Name       string         `json:"name"`
-	Parameters []Expression   `json:"parameters,omitempty"`
+	Position       token.Position `json:"-"`
+	Name           string         `json:"name"`
+	Parameters     []Expression   `json:"parameters,omitempty"`
+	HasParentheses bool           `json:"has_parentheses,omitempty"` // true if the engine was written with (), even an empty pair
 }
 
 func (e *EngineClause) Pos() token.Position { return e.Position }
@@ -781,10 +782,11 @@ func (w *WhenClause) End() token.Position { return w.Position }
 
 // CastExpr represents a CAST expression.
 type CastExpr struct {
-	Position token.Position `json:"-"`
-	Expr     Expression     `json:"expr"`
-	Type     *DataType      `json:"type"`
-	Alias    string         `json:"alias,omitempty"`
+	Position       token.Position `json:"-"`
+	Expr           Expression     `json:"expr"`
+	Type           *DataType      `json:"type"`
+	Alias          string         `json:"alias,omitempty"`
+	OperatorSyntax bool           `json:"operator_syntax,omitempty"` // true if written with the :: operator rather than CAST(...)
 }
 
 func (c *CastExpr) Pos() token.Position { return c.Position }
diff --git a/ast/explain.go b/ast/explain.go
new file mode 100644
index 0000000000..709f25e1bf
--- /dev/null
+++ b/ast/explain.go
@@ -0,0 +1,1735 @@
+package ast
+
+import (
+	"fmt"
+	"strings"
+)
+
+// Explain renders stmt as an indented tree, one node per line, in the same format
+// as ClickHouse's EXPLAIN AST output, and returns the accumulated text.
+func Explain(stmt Statement) string {
+	var b strings.Builder
+	explainNode(&b, stmt, 0) // depth 0: stmt is the root of the printed tree
+	return b.String()
+}
+
+// explainNode recursively writes the AST node to the builder.
+func explainNode(b *strings.Builder, node interface{}, depth int) { + indent := strings.Repeat(" ", depth) + + switch n := node.(type) { + case *SelectWithUnionQuery: + // Check if first select has Format or IntoOutfile clause + var format *Identifier + var intoOutfile *IntoOutfileClause + if len(n.Selects) > 0 { + if sq, ok := n.Selects[0].(*SelectQuery); ok { + if sq.Format != nil { + format = sq.Format + } + if sq.IntoOutfile != nil { + intoOutfile = sq.IntoOutfile + } + } + } + unionChildren := 1 // ExpressionList + if format != nil { + unionChildren++ + } + if intoOutfile != nil { + unionChildren++ + } + fmt.Fprintf(b, "%sSelectWithUnionQuery (children %d)\n", indent, unionChildren) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Selects)) + for _, sel := range n.Selects { + explainNode(b, sel, depth+2) + } + if intoOutfile != nil { + fmt.Fprintf(b, "%s Literal \\'%s\\'\n", indent, intoOutfile.Filename) + } + if format != nil { + fmt.Fprintf(b, "%s Identifier %s\n", indent, format.Name()) + } + + case *SelectQuery: + children := countSelectQueryChildren(n) + fmt.Fprintf(b, "%sSelectQuery (children %d)\n", indent, children) + // WITH clause (comes first) + if len(n.With) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.With)) + for _, w := range n.With { + explainNode(b, w, depth+2) + } + } + // Columns + if len(n.Columns) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Columns)) + for _, col := range n.Columns { + explainNode(b, col, depth+2) + } + } + // From (with ArrayJoin integrated) + if n.From != nil || n.ArrayJoin != nil { + explainTablesWithArrayJoin(b, n.From, n.ArrayJoin, depth+1) + } + // PreWhere + if n.PreWhere != nil { + explainNode(b, n.PreWhere, depth+1) + } + // Where + if n.Where != nil { + explainNode(b, n.Where, depth+1) + } + // GroupBy + if len(n.GroupBy) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.GroupBy)) + for _, expr := range 
n.GroupBy { + explainNode(b, expr, depth+2) + } + } + // Having + if n.Having != nil { + explainNode(b, n.Having, depth+1) + } + // OrderBy + if len(n.OrderBy) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.OrderBy)) + for _, elem := range n.OrderBy { + explainOrderByElement(b, elem, depth+2) + } + } + // Offset (comes before Limit in ClickHouse output) + if n.Offset != nil { + explainNode(b, n.Offset, depth+1) + } + // Limit + if n.Limit != nil { + explainNode(b, n.Limit, depth+1) + } + // Settings + if len(n.Settings) > 0 { + fmt.Fprintf(b, "%s Set\n", indent) + } + // Window clause (WINDOW w AS ...) + if len(n.Window) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Window)) + for range n.Window { + fmt.Fprintf(b, "%s WindowListElement\n", indent) + } + } + + case *TablesInSelectQuery: + fmt.Fprintf(b, "%sTablesInSelectQuery (children %d)\n", indent, len(n.Tables)) + for i, table := range n.Tables { + explainTablesInSelectQueryElement(b, table, i > 0, depth+1) + } + + case *TablesInSelectQueryElement: + // This case is kept for direct calls, but TablesInSelectQuery uses the specialized function + children := 0 + if n.Table != nil { + children++ + } + if n.Join != nil { + children++ + } + fmt.Fprintf(b, "%sTablesInSelectQueryElement (children %d)\n", indent, children) + if n.Table != nil { + explainNode(b, n.Table, depth+1) + } + if n.Join != nil { + explainTableJoin(b, n.Join, depth+1) + } + + case *TableExpression: + children := 1 + if n.Sample != nil { + children++ // SampleRatio for ratio + if n.Sample.Offset != nil { + children++ // SampleRatio for offset + } + } + fmt.Fprintf(b, "%sTableExpression (children %d)\n", indent, children) + // Pass alias to the inner Table + explainTableWithAlias(b, n.Table, n.Alias, depth+1) + if n.Sample != nil { + explainSampleClause(b, n.Sample, depth+1) + } + + case *TableIdentifier: + name := n.Table + if n.Database != "" { + name = n.Database + "." 
+ name + } + if n.Alias != "" { + fmt.Fprintf(b, "%sTableIdentifier %s (alias %s)\n", indent, name, n.Alias) + } else { + fmt.Fprintf(b, "%sTableIdentifier %s\n", indent, name) + } + + case *Identifier: + name := n.Name() + if n.Alias != "" { + fmt.Fprintf(b, "%sIdentifier %s (alias %s)\n", indent, name, n.Alias) + } else { + fmt.Fprintf(b, "%sIdentifier %s\n", indent, name) + } + + case *Literal: + // Empty array literal is represented as a function call + if n.Type == LiteralArray { + if arr, ok := n.Value.([]Expression); ok && len(arr) == 0 { + fmt.Fprintf(b, "%sFunction array (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList\n", indent) + return + } + if arr, ok := n.Value.([]interface{}); ok && len(arr) == 0 { + fmt.Fprintf(b, "%sFunction array (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList\n", indent) + return + } + } + // Tuple containing expressions should be output as Function tuple + if n.Type == LiteralTuple { + if exprs, ok := n.Value.([]Expression); ok && len(exprs) > 0 { + // Check if any element is not a simple literal + hasNonLiteral := false + for _, e := range exprs { + if _, isLit := e.(*Literal); !isLit { + hasNonLiteral = true + break + } + } + if hasNonLiteral { + fmt.Fprintf(b, "%sFunction tuple (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(exprs)) + for _, e := range exprs { + explainNode(b, e, depth+2) + } + return + } + } + } + explainLiteral(b, n, "", depth) + + case *FunctionCall: + explainFunctionCall(b, n, depth) + + case *BinaryExpr: + funcName := binaryOpToFunction(n.Op) + args := []Expression{n.Left, n.Right} + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, funcName) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(args)) + for _, arg := range args { + explainNode(b, arg, depth+2) + } + + case *UnaryExpr: + // Special case: unary minus on a literal integer becomes a negative literal + if n.Op == "-" { + if lit, ok := n.Operand.(*Literal); 
ok && lit.Type == LiteralInteger { + fmt.Fprintf(b, "%sLiteral Int64_-%v\n", indent, lit.Value) + return + } + } + funcName := unaryOpToFunction(n.Op) + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, funcName) + fmt.Fprintf(b, "%s ExpressionList (children 1)\n", indent) + explainNode(b, n.Operand, depth+2) + + case *Asterisk: + if len(n.Except) > 0 || len(n.Replace) > 0 { + children := 0 + if len(n.Except) > 0 || len(n.Replace) > 0 { + children = 1 + } + if n.Table != "" { + fmt.Fprintf(b, "%sQualifiedAsterisk (children %d)\n", indent, children+1) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table) + } else { + fmt.Fprintf(b, "%sAsterisk (children %d)\n", indent, children) + } + if len(n.Except) > 0 { + fmt.Fprintf(b, "%s ColumnsTransformerList (children 1)\n", indent) + fmt.Fprintf(b, "%s ColumnsExceptTransformer (children %d)\n", indent, len(n.Except)) + for _, col := range n.Except { + fmt.Fprintf(b, "%s Identifier %s\n", indent, col) + } + } + if len(n.Replace) > 0 { + fmt.Fprintf(b, "%s ColumnsTransformerList (children 1)\n", indent) + fmt.Fprintf(b, "%s ColumnsReplaceTransformer (children %d)\n", indent, len(n.Replace)) + for _, r := range n.Replace { + fmt.Fprintf(b, "%s ColumnsReplaceTransformer::Replacement (children 1)\n", indent) + // Unwrap AliasedExpr if present - REPLACE doesn't output alias on expression + expr := r.Expr + if ae, ok := expr.(*AliasedExpr); ok { + expr = ae.Expr + } + explainNode(b, expr, depth+4) + } + } + } else if n.Table != "" { + fmt.Fprintf(b, "%sQualifiedAsterisk (children 1)\n", indent) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table) + } else { + fmt.Fprintf(b, "%sAsterisk\n", indent) + } + + case *ColumnsMatcher: + fmt.Fprintf(b, "%sColumnsRegexpMatcher\n", indent) + + case *Subquery: + if n.Alias != "" { + fmt.Fprintf(b, "%sSubquery (alias %s) (children 1)\n", indent, n.Alias) + } else { + fmt.Fprintf(b, "%sSubquery (children 1)\n", indent) + } + explainNode(b, n.Query, depth+1) + + case *CaseExpr: + 
explainCaseExpr(b, n, depth) + + case *CastExpr: + explainCastExpr(b, n, depth) + + case *Lambda: + explainLambda(b, n, depth) + + case *TernaryExpr: + // Ternary is represented as if(cond, then, else) + fmt.Fprintf(b, "%sFunction if (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 3)\n", indent) + explainNode(b, n.Condition, depth+2) + explainNode(b, n.Then, depth+2) + explainNode(b, n.Else, depth+2) + + case *InExpr: + funcName := "in" + if n.Not { + funcName = "notIn" + } + if n.Global { + funcName = "global" + strings.Title(funcName) + } + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, funcName) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Expr, depth+2) + if n.Query != nil { + // Wrap query in Subquery node + fmt.Fprintf(b, "%s Subquery (children 1)\n", indent) + explainNode(b, n.Query, depth+3) + } else { + // List is shown as a Tuple literal + explainInListAsTuple(b, n.List, depth+2) + } + + case *BetweenExpr: + // BETWEEN is expanded to and(greaterOrEquals(expr, low), lessOrEquals(expr, high)) + // NOT BETWEEN is expanded to or(less(expr, low), greater(expr, high)) + if n.Not { + fmt.Fprintf(b, "%sFunction or (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + fmt.Fprintf(b, "%s Function less (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Expr, depth+4) + explainNode(b, n.Low, depth+4) + fmt.Fprintf(b, "%s Function greater (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Expr, depth+4) + explainNode(b, n.High, depth+4) + } else { + fmt.Fprintf(b, "%sFunction and (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + fmt.Fprintf(b, "%s Function greaterOrEquals (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Expr, depth+4) + explainNode(b, n.Low, depth+4) + 
fmt.Fprintf(b, "%s Function lessOrEquals (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Expr, depth+4) + explainNode(b, n.High, depth+4) + } + + case *IsNullExpr: + funcName := "isNull" + if n.Not { + funcName = "isNotNull" + } + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, funcName) + fmt.Fprintf(b, "%s ExpressionList (children 1)\n", indent) + explainNode(b, n.Expr, depth+2) + + case *LikeExpr: + funcName := "like" + if n.Not && n.CaseInsensitive { + funcName = "notILike" + } else if n.CaseInsensitive { + funcName = "ilike" + } else if n.Not { + funcName = "notLike" + } + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, funcName) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Expr, depth+2) + explainNode(b, n.Pattern, depth+2) + + case *ArrayAccess: + fmt.Fprintf(b, "%sFunction arrayElement (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Array, depth+2) + explainNode(b, n.Index, depth+2) + + case *TupleAccess: + fmt.Fprintf(b, "%sFunction tupleElement (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + explainNode(b, n.Tuple, depth+2) + explainNode(b, n.Index, depth+2) + + case *IntervalExpr: + fmt.Fprintf(b, "%sFunction toInterval%s (children 1)\n", indent, strings.Title(strings.ToLower(n.Unit))) + fmt.Fprintf(b, "%s ExpressionList (children 1)\n", indent) + explainNode(b, n.Value, depth+2) + + case *ExtractExpr: + // EXTRACT(YEAR FROM date) becomes toYear(date) + funcName := extractFieldToFunction(n.Field) + if n.Alias != "" { + fmt.Fprintf(b, "%sFunction %s (alias %s) (children 1)\n", indent, funcName, n.Alias) + } else { + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, funcName) + } + fmt.Fprintf(b, "%s ExpressionList (children 1)\n", indent) + explainNode(b, n.From, depth+2) + + case *AliasedExpr: + // For aliased expressions, we need to print the inner 
expression with the alias + explainNodeWithAlias(b, n.Expr, n.Alias, depth) + + case *WithElement: + // For scalar WITH (WITH 1 AS x), output the expression with alias + // For subquery WITH (WITH x AS (SELECT 1)), output as WithElement + if _, isSubquery := n.Query.(*Subquery); isSubquery { + fmt.Fprintf(b, "%sWithElement (children 1)\n", indent) + explainNode(b, n.Query, depth+1) + } else { + // Scalar expression - output with alias + explainNodeWithAlias(b, n.Query, n.Name, depth) + } + + case *ExistsExpr: + fmt.Fprintf(b, "%sFunction exists (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 1)\n", indent) + // Wrap query in Subquery node + fmt.Fprintf(b, "%s Subquery (children 1)\n", indent) + explainNode(b, n.Query, depth+3) + + case *DataType: + // Data types in expressions (like in CAST) + if len(n.Parameters) > 0 { + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, n.Name) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Parameters)) + for _, p := range n.Parameters { + explainNode(b, p, depth+2) + } + } else { + fmt.Fprintf(b, "%sIdentifier %s\n", indent, n.Name) + } + + // Non-SELECT statements + case *UseQuery: + fmt.Fprintf(b, "%sUseQuery %s (children 1)\n", indent, n.Database) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Database) + + case *TruncateQuery: + tableName := n.Table + if n.Database != "" { + tableName = n.Database + "." + tableName + } + fmt.Fprintf(b, "%sTruncateQuery %s (children 1)\n", indent, tableName) + fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName) + + case *AlterQuery: + tableName := n.Table + if n.Database != "" { + tableName = n.Database + "." 
+ tableName + } + fmt.Fprintf(b, "%sAlterQuery %s (children 2)\n", indent, tableName) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Commands)) + for _, cmd := range n.Commands { + explainAlterCommand(b, cmd, depth+2) + } + fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName) + + case *DropQuery: + var name string + if n.DropDatabase { + name = n.Database + } else if n.View != "" { + name = n.View + } else { + name = n.Table + } + if n.Database != "" && !n.DropDatabase { + name = n.Database + "." + name + } + // Different spacing for DROP DATABASE vs DROP TABLE + if n.DropDatabase { + fmt.Fprintf(b, "%sDropQuery %s (children 1)\n", indent, name) + } else { + fmt.Fprintf(b, "%sDropQuery %s (children 1)\n", indent, name) + } + fmt.Fprintf(b, "%s Identifier %s\n", indent, name) + + case *CreateQuery: + explainCreateQuery(b, n, depth) + + case *InsertQuery: + tableName := n.Table + if n.Database != "" { + tableName = n.Database + "." + tableName + } + children := 1 // Always have table identifier + if len(n.Columns) > 0 { + children++ // column list + } + if n.Select != nil { + children++ + } + fmt.Fprintf(b, "%sInsertQuery (children %d)\n", indent, children) + fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName) + if len(n.Columns) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Columns)) + for _, col := range n.Columns { + fmt.Fprintf(b, "%s Identifier %s\n", indent, col.Name()) + } + } + if n.Select != nil { + explainNode(b, n.Select, depth+1) + } + + case *SystemQuery: + children := 0 + if n.Database != "" { + children++ + } + if n.Table != "" { + children++ + } + if children > 0 { + fmt.Fprintf(b, "%sSYSTEM query (children %d)\n", indent, children) + if n.Database != "" { + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Database) + } + if n.Table != "" { + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table) + } + } else { + fmt.Fprintf(b, "%sSYSTEM query\n", indent) + } + + case *OptimizeQuery: + tableName := 
n.Table + if n.Database != "" { + tableName = n.Database + "." + tableName + } + // Add suffix based on flags + displayName := tableName + if n.Final { + displayName = tableName + "_final" + } else if n.Dedupe { + displayName = tableName + "_deduplicate" + } + children := 1 // identifier + if n.Partition != nil { + children++ + } + fmt.Fprintf(b, "%sOptimizeQuery %s (children %d)\n", indent, displayName, children) + if n.Partition != nil { + fmt.Fprintf(b, "%s Partition (children 1)\n", indent) + explainNode(b, n.Partition, depth+2) + } + fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName) + + case *DescribeQuery: + tableName := n.Table + if n.Database != "" { + tableName = n.Database + "." + tableName + } + fmt.Fprintf(b, "%sDescribeQuery (children 1)\n", indent) + fmt.Fprintf(b, "%s TableExpression (children 1)\n", indent) + fmt.Fprintf(b, "%s TableIdentifier %s\n", indent, tableName) + + case *ShowQuery: + // Handle SHOW CREATE specially + if n.ShowType == ShowCreate || n.ShowType == ShowCreateDB { + if n.ShowType == ShowCreate { + // SHOW CREATE TABLE + tableName := n.From + if n.Database != "" && tableName != "" { + fmt.Fprintf(b, "%sShowCreateTableQuery %s %s (children 2)\n", indent, n.Database, tableName) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Database) + fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName) + } else if tableName != "" { + fmt.Fprintf(b, "%sShowCreateTableQuery %s (children 1)\n", indent, tableName) + fmt.Fprintf(b, "%s Identifier %s\n", indent, tableName) + } else { + fmt.Fprintf(b, "%sShowCreate\n", indent) + } + } else { + // SHOW CREATE DATABASE - database name is in From field + dbName := n.From + fmt.Fprintf(b, "%sShowCreateDatabaseQuery %s (children 1)\n", indent, dbName) + fmt.Fprintf(b, "%s Identifier %s\n", indent, dbName) + } + } else if n.ShowType == ShowProcesses { + fmt.Fprintf(b, "%sShowProcesslistQuery\n", indent) + } else if n.ShowType == ShowColumns { + // SHOW COLUMNS doesn't output table name in children + 
fmt.Fprintf(b, "%sShowColumns\n", indent) + } else if n.ShowType == ShowTables && (n.From != "" || n.Database != "") { + // SHOW TABLES FROM database + dbName := n.From + if dbName == "" { + dbName = n.Database + } + fmt.Fprintf(b, "%sShowTables (children 1)\n", indent) + fmt.Fprintf(b, "%s Identifier %s\n", indent, dbName) + } else { + showName := showTypeToName(n.ShowType) + fmt.Fprintf(b, "%s%s\n", indent, showName) + } + + case *SetQuery: + fmt.Fprintf(b, "%sSet\n", indent) + + case *RenameQuery: + fmt.Fprintf(b, "%sRename (children 2)\n", indent) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.From) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.To) + + case *ExchangeQuery: + fmt.Fprintf(b, "%sRename (children 2)\n", indent) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table1) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Table2) + + case *ExplainQuery: + explainType := normalizeExplainType(string(n.ExplainType)) + fmt.Fprintf(b, "%sExplain %s (children 1)\n", indent, explainType) + explainNode(b, n.Statement, depth+1) + + default: + // For unknown types, just print the type name + fmt.Fprintf(b, "%s%T\n", indent, n) + } +} + +// explainTableWithAlias prints a table expression (TableIdentifier, Subquery, Function) with an alias. +func explainTableWithAlias(b *strings.Builder, table interface{}, alias string, depth int) { + indent := strings.Repeat(" ", depth) + + switch t := table.(type) { + case *TableIdentifier: + name := t.Table + if t.Database != "" { + name = t.Database + "." 
+ name + } + if alias != "" { + fmt.Fprintf(b, "%sTableIdentifier %s (alias %s)\n", indent, name, alias) + } else if t.Alias != "" { + fmt.Fprintf(b, "%sTableIdentifier %s (alias %s)\n", indent, name, t.Alias) + } else { + fmt.Fprintf(b, "%sTableIdentifier %s\n", indent, name) + } + + case *Subquery: + if alias != "" { + fmt.Fprintf(b, "%sSubquery (alias %s) (children 1)\n", indent, alias) + } else if t.Alias != "" { + fmt.Fprintf(b, "%sSubquery (alias %s) (children 1)\n", indent, t.Alias) + } else { + fmt.Fprintf(b, "%sSubquery (children 1)\n", indent) + } + explainNode(b, t.Query, depth+1) + + case *FunctionCall: + // For table functions like numbers(), pass alias + if alias != "" { + explainFunctionCallWithAlias(b, t, alias, depth) + } else { + explainFunctionCall(b, t, depth) + } + + default: + explainNode(b, table, depth) + } +} + +// explainNodeWithAlias prints a node with an alias suffix. +func explainNodeWithAlias(b *strings.Builder, node interface{}, alias string, depth int) { + indent := strings.Repeat(" ", depth) + + switch n := node.(type) { + case *Literal: + explainLiteral(b, n, alias, depth) + + case *Identifier: + name := n.Name() + if alias != "" { + fmt.Fprintf(b, "%sIdentifier %s (alias %s)\n", indent, name, alias) + } else if n.Alias != "" { + fmt.Fprintf(b, "%sIdentifier %s (alias %s)\n", indent, name, n.Alias) + } else { + fmt.Fprintf(b, "%sIdentifier %s\n", indent, name) + } + + case *FunctionCall: + explainFunctionCallWithAlias(b, n, alias, depth) + + case *BinaryExpr: + funcName := binaryOpToFunction(n.Op) + args := []Expression{n.Left, n.Right} + if alias != "" { + fmt.Fprintf(b, "%sFunction %s (alias %s) (children 1)\n", indent, funcName, alias) + } else { + fmt.Fprintf(b, "%sFunction %s (children 1)\n", indent, funcName) + } + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(args)) + for _, arg := range args { + explainNode(b, arg, depth+2) + } + return + + default: + // Fall back to regular node printing + explainNode(b, 
node, depth) + } +} + +// explainLiteral formats a literal value. +func explainLiteral(b *strings.Builder, lit *Literal, alias string, depth int) { + indent := strings.Repeat(" ", depth) + var valueStr string + + switch lit.Type { + case LiteralString: + // Escape backslashes in string literals (ClickHouse doubles them) + strVal := strings.ReplaceAll(fmt.Sprintf("%v", lit.Value), "\\", "\\\\") + valueStr = fmt.Sprintf("\\'%s\\'", strVal) + case LiteralInteger: + valueStr = fmt.Sprintf("UInt64_%v", lit.Value) + case LiteralFloat: + valueStr = fmt.Sprintf("Float64_%v", lit.Value) + case LiteralBoolean: + if lit.Value.(bool) { + valueStr = "Bool_1" + } else { + valueStr = "Bool_0" + } + case LiteralNull: + valueStr = "NULL" + case LiteralArray: + valueStr = formatArrayLiteral(lit.Value) + case LiteralTuple: + valueStr = formatTupleLiteral(lit.Value) + default: + valueStr = fmt.Sprintf("%v", lit.Value) + } + + if alias != "" { + fmt.Fprintf(b, "%sLiteral %s (alias %s)\n", indent, valueStr, alias) + } else { + fmt.Fprintf(b, "%sLiteral %s\n", indent, valueStr) + } +} + +// formatArrayLiteral formats an array literal. 
+func formatArrayLiteral(value interface{}) string { + switch v := value.(type) { + case []interface{}: + parts := make([]string, len(v)) + for i, elem := range v { + parts[i] = formatLiteralElement(elem) + } + return fmt.Sprintf("Array_[%s]", strings.Join(parts, ", ")) + case []Expression: + parts := make([]string, len(v)) + for i, elem := range v { + if lit, ok := elem.(*Literal); ok { + switch lit.Type { + case LiteralString: + escaped := strings.ReplaceAll(fmt.Sprintf("%v", lit.Value), "\\", "\\\\") + parts[i] = fmt.Sprintf("\\'%s\\'", escaped) + case LiteralInteger: + parts[i] = fmt.Sprintf("UInt64_%v", lit.Value) + case LiteralFloat: + parts[i] = fmt.Sprintf("Float64_%v", lit.Value) + case LiteralArray: + parts[i] = formatArrayLiteral(lit.Value) + case LiteralTuple: + parts[i] = formatTupleLiteral(lit.Value) + default: + parts[i] = fmt.Sprintf("%v", lit.Value) + } + } else { + parts[i] = fmt.Sprintf("%v", elem) + } + } + return fmt.Sprintf("Array_[%s]", strings.Join(parts, ", ")) + default: + return fmt.Sprintf("Array_%v", value) + } +} + +// formatTupleLiteral formats a tuple literal. 
+func formatTupleLiteral(value interface{}) string { + switch v := value.(type) { + case []interface{}: + parts := make([]string, len(v)) + for i, elem := range v { + parts[i] = formatLiteralElement(elem) + } + return fmt.Sprintf("Tuple_(%s)", strings.Join(parts, ", ")) + case []Expression: + parts := make([]string, len(v)) + for i, elem := range v { + if lit, ok := elem.(*Literal); ok { + switch lit.Type { + case LiteralString: + escaped := strings.ReplaceAll(fmt.Sprintf("%v", lit.Value), "\\", "\\\\") + parts[i] = fmt.Sprintf("\\'%s\\'", escaped) + case LiteralInteger: + parts[i] = fmt.Sprintf("UInt64_%v", lit.Value) + case LiteralFloat: + parts[i] = fmt.Sprintf("Float64_%v", lit.Value) + default: + parts[i] = fmt.Sprintf("%v", lit.Value) + } + } else { + parts[i] = fmt.Sprintf("%v", elem) + } + } + return fmt.Sprintf("Tuple_(%s)", strings.Join(parts, ", ")) + default: + return fmt.Sprintf("Tuple_%v", value) + } +} + +// explainInListAsTuple formats an IN list as a Tuple literal. +func explainInListAsTuple(b *strings.Builder, list []Expression, depth int) { + indent := strings.Repeat(" ", depth) + + // Build the tuple elements + parts := make([]string, len(list)) + for i, elem := range list { + if lit, ok := elem.(*Literal); ok { + switch lit.Type { + case LiteralString: + parts[i] = fmt.Sprintf("'%v'", lit.Value) + case LiteralInteger: + parts[i] = fmt.Sprintf("UInt64_%v", lit.Value) + case LiteralFloat: + parts[i] = fmt.Sprintf("Float64_%v", lit.Value) + default: + parts[i] = fmt.Sprintf("%v", lit.Value) + } + } else { + parts[i] = fmt.Sprintf("%v", elem) + } + } + + fmt.Fprintf(b, "%sLiteral Tuple_(%s)\n", indent, strings.Join(parts, ", ")) +} + +// formatLiteralElement formats a single literal element. 
+func formatLiteralElement(elem interface{}) string { + switch e := elem.(type) { + case string: + escaped := strings.ReplaceAll(e, "\\", "\\\\") + return fmt.Sprintf("\\'%s\\'", escaped) + case int, int64, uint64: + return fmt.Sprintf("UInt64_%v", e) + case float64: + return fmt.Sprintf("Float64_%v", e) + case bool: + if e { + return "Bool_1" + } + return "Bool_0" + default: + return fmt.Sprintf("%v", e) + } +} + +// explainFunctionCall formats a function call. +func explainFunctionCall(b *strings.Builder, fn *FunctionCall, depth int) { + explainFunctionCallWithAlias(b, fn, fn.Alias, depth) +} + +// normalizeFunctionName normalizes function names to match ClickHouse EXPLAIN AST output. +func normalizeFunctionName(name string) string { + switch strings.ToLower(name) { + case "trim": + return "trimBoth" + case "ltrim": + return "trimLeft" + case "rtrim": + return "trimRight" + default: + return name + } +} + +// explainFunctionCallWithAlias formats a function call with an optional alias. 
+func explainFunctionCallWithAlias(b *strings.Builder, fn *FunctionCall, alias string, depth int) { + indent := strings.Repeat(" ", depth) + name := normalizeFunctionName(fn.Name) + // DISTINCT in aggregate functions gets appended to function name + if fn.Distinct { + name = name + "Distinct" + } + + // Count children: + // - 1 for arguments ExpressionList + // - 1 for parameters ExpressionList if present (parametric aggregate functions) + // - 1 for window spec if present (but not for named windows, which are output separately) + children := 1 // Always have arguments ExpressionList + if len(fn.Parameters) > 0 { + children++ + } + // Only count window spec for inline windows, not named windows (OVER w) + hasInlineWindow := fn.Over != nil && fn.Over.Name == "" + if hasInlineWindow { + children++ + } + + aliasSuffix := "" + if alias != "" { + aliasSuffix = fmt.Sprintf(" (alias %s)", alias) + } + + fmt.Fprintf(b, "%sFunction %s%s (children %d)\n", indent, name, aliasSuffix, children) + + // Arguments (first ExpressionList) + if len(fn.Arguments) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(fn.Arguments)) + for _, arg := range fn.Arguments { + explainNode(b, arg, depth+2) + } + } else { + // Empty argument list + fmt.Fprintf(b, "%s ExpressionList\n", indent) + } + + // Parameters (second ExpressionList, for parametric aggregate functions like quantile(0.9)) + if len(fn.Parameters) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(fn.Parameters)) + for _, param := range fn.Parameters { + explainNode(b, param, depth+2) + } + } + + // Window specification (only for inline windows, not named windows) + if hasInlineWindow { + explainWindowSpec(b, fn.Over, depth+1) + } +} + +// explainWindowSpec formats a window specification. 
+func explainWindowSpec(b *strings.Builder, spec *WindowSpec, depth int) { + indent := strings.Repeat(" ", depth) + + // Count children: partition by + order by + frame bounds + children := 0 + if len(spec.PartitionBy) > 0 { + children++ + } + if len(spec.OrderBy) > 0 { + children++ + } + // Count frame bound children + if spec.Frame != nil { + if spec.Frame.StartBound != nil && spec.Frame.StartBound.Offset != nil { + children++ + } + if spec.Frame.EndBound != nil && spec.Frame.EndBound.Offset != nil { + children++ + } + } + + if children > 0 { + fmt.Fprintf(b, "%sWindowDefinition (children %d)\n", indent, children) + } else { + fmt.Fprintf(b, "%sWindowDefinition\n", indent) + } + + // Partition by + if len(spec.PartitionBy) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(spec.PartitionBy)) + for _, expr := range spec.PartitionBy { + explainNode(b, expr, depth+2) + } + } + + // Order by + if len(spec.OrderBy) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(spec.OrderBy)) + for _, elem := range spec.OrderBy { + explainOrderByElement(b, elem, depth+2) + } + } + + // Frame bounds + if spec.Frame != nil { + if spec.Frame.StartBound != nil && spec.Frame.StartBound.Offset != nil { + explainNode(b, spec.Frame.StartBound.Offset, depth+1) + } + if spec.Frame.EndBound != nil && spec.Frame.EndBound.Offset != nil { + explainNode(b, spec.Frame.EndBound.Offset, depth+1) + } + } +} + +// explainTableJoin formats a table join. 
+func explainTableJoin(b *strings.Builder, join *TableJoin, depth int) { + indent := strings.Repeat(" ", depth) + children := 0 + if join.On != nil { + children++ + } + if len(join.Using) > 0 { + children++ + } + if children > 0 { + fmt.Fprintf(b, "%sTableJoin (children %d)\n", indent, children) + } else { + fmt.Fprintf(b, "%sTableJoin\n", indent) + } + if join.On != nil { + explainNode(b, join.On, depth+1) + } + if len(join.Using) > 0 { + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(join.Using)) + for _, col := range join.Using { + explainNode(b, col, depth+2) + } + } +} + +// explainArrayJoinClause formats an array join as a table element. +func explainArrayJoinClause(b *strings.Builder, aj *ArrayJoinClause, depth int) { + // Array join is already represented in TablesInSelectQuery + // This is just for when it's encountered directly + indent := strings.Repeat(" ", depth) + fmt.Fprintf(b, "%sArrayJoin (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(aj.Columns)) + for _, col := range aj.Columns { + explainNode(b, col, depth+2) + } +} + +// explainOrderByElement formats an order by element. +func explainOrderByElement(b *strings.Builder, elem *OrderByElement, depth int) { + indent := strings.Repeat(" ", depth) + + // Count children: expression + optional FillFrom, FillTo, FillStep + children := 1 + if elem.FillFrom != nil { + children++ + } + if elem.FillTo != nil { + children++ + } + if elem.FillStep != nil { + children++ + } + + fmt.Fprintf(b, "%sOrderByElement (children %d)\n", indent, children) + explainNode(b, elem.Expression, depth+1) + + if elem.FillFrom != nil { + explainNode(b, elem.FillFrom, depth+1) + } + if elem.FillTo != nil { + explainNode(b, elem.FillTo, depth+1) + } + if elem.FillStep != nil { + explainNode(b, elem.FillStep, depth+1) + } +} + +// explainCaseExpr formats a CASE expression. 
+func explainCaseExpr(b *strings.Builder, c *CaseExpr, depth int) { + indent := strings.Repeat(" ", depth) + // CASE is represented as multiIf or caseWithExpression + aliasSuffix := "" + if c.Alias != "" { + aliasSuffix = fmt.Sprintf(" (alias %s)", c.Alias) + } + + if c.Operand != nil { + // CASE x WHEN ... -> caseWithExpression + children := 1 + len(c.Whens)*2 + if c.Else != nil { + children++ + } + fmt.Fprintf(b, "%sFunction caseWithExpression%s (children 1)\n", indent, aliasSuffix) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, children) + explainNode(b, c.Operand, depth+2) + for _, when := range c.Whens { + explainNode(b, when.Condition, depth+2) + explainNode(b, when.Result, depth+2) + } + if c.Else != nil { + explainNode(b, c.Else, depth+2) + } + } else { + // CASE WHEN ... -> multiIf + children := len(c.Whens) * 2 + if c.Else != nil { + children++ + } + fmt.Fprintf(b, "%sFunction multiIf%s (children 1)\n", indent, aliasSuffix) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, children) + for _, when := range c.Whens { + explainNode(b, when.Condition, depth+2) + explainNode(b, when.Result, depth+2) + } + if c.Else != nil { + explainNode(b, c.Else, depth+2) + } + } +} + +// explainCastExpr formats a CAST expression. 
+func explainCastExpr(b *strings.Builder, c *CastExpr, depth int) { + indent := strings.Repeat(" ", depth) + aliasSuffix := "" + if c.Alias != "" { + aliasSuffix = fmt.Sprintf(" (alias %s)", c.Alias) + } + fmt.Fprintf(b, "%sFunction CAST%s (children 1)\n", indent, aliasSuffix) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + // For :: operator syntax, the expression is output as a string literal + if c.OperatorSyntax { + if lit, ok := c.Expr.(*Literal); ok { + fmt.Fprintf(b, "%s Literal \\'%v\\'\n", indent, lit.Value) + } else { + explainNode(b, c.Expr, depth+2) + } + } else { + explainNode(b, c.Expr, depth+2) + } + // Type is represented as a Literal string + fmt.Fprintf(b, "%s Literal \\'%s\\'\n", indent, c.Type.Name) +} + +// explainLambda formats a lambda expression. +func explainLambda(b *strings.Builder, l *Lambda, depth int) { + indent := strings.Repeat(" ", depth) + fmt.Fprintf(b, "%sFunction lambda (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children 2)\n", indent) + // Parameters as tuple + fmt.Fprintf(b, "%s Function tuple (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(l.Parameters)) + for _, param := range l.Parameters { + fmt.Fprintf(b, "%s Identifier %s\n", indent, param) + } + // Body + explainNode(b, l.Body, depth+2) +} + +// countSelectQueryChildren counts the non-nil children of a SelectQuery. 
+func countSelectQueryChildren(s *SelectQuery) int { + count := 0 + if len(s.With) > 0 { + count++ + } + if len(s.Columns) > 0 { + count++ + } + // From and ArrayJoin are combined into one TablesInSelectQuery + if s.From != nil || s.ArrayJoin != nil { + count++ + } + if s.PreWhere != nil { + count++ + } + if s.Where != nil { + count++ + } + if len(s.GroupBy) > 0 { + count++ + } + if s.Having != nil { + count++ + } + if len(s.OrderBy) > 0 { + count++ + } + if s.Limit != nil { + count++ + } + if s.Offset != nil { + count++ + } + if len(s.Settings) > 0 { + count++ + } + if len(s.Window) > 0 { + count++ + } + return count +} + +// explainTablesWithArrayJoin outputs TablesInSelectQuery with ArrayJoin integrated. +func explainTablesWithArrayJoin(b *strings.Builder, from *TablesInSelectQuery, arrayJoin *ArrayJoinClause, depth int) { + indent := strings.Repeat(" ", depth) + + tableCount := 0 + if from != nil { + tableCount = len(from.Tables) + } + if arrayJoin != nil { + tableCount++ + } + + fmt.Fprintf(b, "%sTablesInSelectQuery (children %d)\n", indent, tableCount) + + if from != nil { + for i, table := range from.Tables { + explainTablesInSelectQueryElement(b, table, i > 0, depth+1) + } + } + + if arrayJoin != nil { + // ArrayJoin is output as a TablesInSelectQueryElement + fmt.Fprintf(b, "%s TablesInSelectQueryElement (children 1)\n", indent) + explainArrayJoinClause(b, arrayJoin, depth+2) + } +} + +// binaryOpToFunction maps binary operators to their function names. 
+func binaryOpToFunction(op string) string { + switch op { + case "+": + return "plus" + case "-": + return "minus" + case "*": + return "multiply" + case "/": + return "divide" + case "%": + return "modulo" + case "=", "==": + return "equals" + case "!=", "<>": + return "notEquals" + case "<": + return "less" + case "<=": + return "lessOrEquals" + case ">": + return "greater" + case ">=": + return "greaterOrEquals" + case "<=>": + return "isNotDistinctFrom" + case "AND": + return "and" + case "OR": + return "or" + case "LIKE": + return "like" + case "ILIKE": + return "ilike" + case "NOT LIKE": + return "notLike" + case "NOT ILIKE": + return "notILike" + case "IN": + return "in" + case "NOT IN": + return "notIn" + case "GLOBAL IN": + return "globalIn" + case "GLOBAL NOT IN": + return "globalNotIn" + default: + return op + } +} + +// unaryOpToFunction maps unary operators to their function names. +func unaryOpToFunction(op string) string { + switch op { + case "-": + return "negate" + case "NOT": + return "not" + case "~": + return "bitNot" + default: + return op + } +} + +// extractFieldToFunction maps EXTRACT fields to function names. +func extractFieldToFunction(field string) string { + switch strings.ToUpper(field) { + case "YEAR": + return "toYear" + case "MONTH": + return "toMonth" + case "DAY": + return "toDayOfMonth" + case "HOUR": + return "toHour" + case "MINUTE": + return "toMinute" + case "SECOND": + return "toSecond" + default: + return "to" + strings.Title(strings.ToLower(field)) + } +} + +// normalizeAlterCommandType normalizes ALTER command types to match ClickHouse output. +func normalizeAlterCommandType(t AlterCommandType) string { + switch t { + case AlterFreeze: + return "FREEZE_ALL" + case AlterDetachPartition: + return "DROP_PARTITION" + case AlterClearIndex: + return "DROP_INDEX" + default: + return string(t) + } +} + +// explainAlterCommand formats an ALTER command. 
+func explainAlterCommand(b *strings.Builder, cmd *AlterCommand, depth int) { + indent := strings.Repeat(" ", depth) + + children := 0 + if cmd.Column != nil { + children++ + } + if cmd.ColumnName != "" && cmd.Type != AlterAddColumn && cmd.Type != AlterModifyColumn { + children++ + } + if cmd.NewName != "" { + children++ + } + if cmd.AfterColumn != "" { + children++ + } + if cmd.Constraint != nil { + children++ + } + if cmd.IndexExpr != nil { + children++ + } + if cmd.Partition != nil { + children++ + } + if cmd.Index != "" && cmd.IndexExpr == nil { + children++ + } + // Don't count ConstraintName for ADD_CONSTRAINT as it's part of the Constraint structure + if cmd.ConstraintName != "" && cmd.Type != AlterAddConstraint { + children++ + } + if cmd.TTL != nil { + children++ + } + + cmdType := normalizeAlterCommandType(cmd.Type) + if children > 0 { + fmt.Fprintf(b, "%sAlterCommand %s (children %d)\n", indent, cmdType, children) + } else { + fmt.Fprintf(b, "%sAlterCommand %s\n", indent, cmdType) + } + + if cmd.Column != nil { + explainColumnDeclaration(b, cmd.Column, depth+1) + } + if cmd.ColumnName != "" && cmd.Type != AlterAddColumn && cmd.Type != AlterModifyColumn { + fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.ColumnName) + } + if cmd.NewName != "" { + fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.NewName) + } + if cmd.AfterColumn != "" { + fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.AfterColumn) + } + if cmd.Constraint != nil { + explainConstraint(b, cmd.Constraint, depth+1) + } + if cmd.IndexExpr != nil { + fmt.Fprintf(b, "%s Index (children 2)\n", indent) + explainNode(b, cmd.IndexExpr, depth+2) + if cmd.IndexType != "" { + fmt.Fprintf(b, "%s Function %s (children 1)\n", indent, cmd.IndexType) + fmt.Fprintf(b, "%s ExpressionList\n", indent) + } + } + if cmd.Partition != nil { + fmt.Fprintf(b, "%s Partition (children 1)\n", indent) + explainNode(b, cmd.Partition, depth+2) + } + if cmd.Index != "" && cmd.IndexExpr == nil { + fmt.Fprintf(b, "%s 
Identifier %s\n", indent, cmd.Index) + } + // Don't output ConstraintName for ADD_CONSTRAINT + if cmd.ConstraintName != "" && cmd.Type != AlterAddConstraint { + fmt.Fprintf(b, "%s Identifier %s\n", indent, cmd.ConstraintName) + } + if cmd.TTL != nil { + fmt.Fprintf(b, "%s ExpressionList (children 1)\n", indent) + fmt.Fprintf(b, "%s TTLElement (children 1)\n", indent) + explainNode(b, cmd.TTL.Expression, depth+3) + } +} + +// explainColumnDeclaration formats a column declaration. +func explainColumnDeclaration(b *strings.Builder, col *ColumnDeclaration, depth int) { + indent := strings.Repeat(" ", depth) + + children := 0 + if col.Type != nil { + children++ + } + if col.Default != nil { + children++ + } + if col.Codec != nil { + children++ + } + if col.Comment != "" { + children++ + } + + fmt.Fprintf(b, "%sColumnDeclaration %s (children %d)\n", indent, col.Name, children) + if col.Type != nil { + fmt.Fprintf(b, "%s DataType %s\n", indent, col.Type.Name) + } + if col.Comment != "" { + fmt.Fprintf(b, "%s Literal \\'%s\\'\n", indent, col.Comment) + } + if col.Default != nil { + explainNode(b, col.Default, depth+1) + } + if col.Codec != nil { + explainCodec(b, col.Codec, depth+1) + } +} + +// explainCodec formats a CODEC expression. +func explainCodec(b *strings.Builder, codec *CodecExpr, depth int) { + indent := strings.Repeat(" ", depth) + fmt.Fprintf(b, "%sFunction CODEC (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(codec.Codecs)) + for _, c := range codec.Codecs { + if len(c.Arguments) == 0 { + fmt.Fprintf(b, "%s Function %s\n", indent, c.Name) + } else { + fmt.Fprintf(b, "%s Function %s (children 1)\n", indent, c.Name) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(c.Arguments)) + for _, arg := range c.Arguments { + explainNode(b, arg, depth+4) + } + } + } +} + +// explainConstraint formats a constraint. 
+func explainConstraint(b *strings.Builder, c *Constraint, depth int) { + indent := strings.Repeat(" ", depth) + fmt.Fprintf(b, "%sConstraint (children 1)\n", indent) + explainNode(b, c.Expression, depth+1) +} + +// explainCreateQuery formats a CREATE query. +func explainCreateQuery(b *strings.Builder, n *CreateQuery, depth int) { + indent := strings.Repeat(" ", depth) + + if n.CreateDatabase { + children := 1 + if n.Engine != nil { + children++ + } + fmt.Fprintf(b, "%sCreateQuery %s (children %d)\n", indent, n.Database, children) + fmt.Fprintf(b, "%s Identifier %s\n", indent, n.Database) + if n.Engine != nil { + fmt.Fprintf(b, "%s Storage definition (children 1)\n", indent) + fmt.Fprintf(b, "%s Function %s\n", indent, n.Engine.Name) + } + return + } + + var name string + if n.View != "" { + name = n.View + } else { + name = n.Table + } + if n.Database != "" { + name = n.Database + "." + name + } + + // For materialized views, handle specially + if n.View != "" { + explainCreateView(b, n, name, depth) + return + } + + children := 1 // identifier + if len(n.Columns) > 0 { + children++ + } + if n.Engine != nil || len(n.OrderBy) > 0 || n.PartitionBy != nil || len(n.PrimaryKey) > 0 || len(n.Settings) > 0 { + children++ + } + if n.AsSelect != nil { + children++ + } + + fmt.Fprintf(b, "%sCreateQuery %s (children %d)\n", indent, name, children) + fmt.Fprintf(b, "%s Identifier %s\n", indent, name) + + if len(n.Columns) > 0 { + fmt.Fprintf(b, "%s Columns definition (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Columns)) + for _, col := range n.Columns { + explainColumnDeclaration(b, col, depth+3) + } + } + + if n.Engine != nil || len(n.OrderBy) > 0 || n.PartitionBy != nil || len(n.PrimaryKey) > 0 || len(n.Settings) > 0 { + storageChildren := 0 + if n.Engine != nil { + storageChildren++ + } + if n.PartitionBy != nil { + storageChildren++ + } + if len(n.PrimaryKey) > 0 { + storageChildren++ + } + if len(n.OrderBy) > 0 { + 
storageChildren++ + } + if len(n.Settings) > 0 { + storageChildren++ + } + fmt.Fprintf(b, "%s Storage definition (children %d)\n", indent, storageChildren) + if n.Engine != nil { + if len(n.Engine.Parameters) == 0 && !n.Engine.HasParentheses { + fmt.Fprintf(b, "%s Function %s\n", indent, n.Engine.Name) + } else if len(n.Engine.Parameters) == 0 && n.Engine.HasParentheses { + fmt.Fprintf(b, "%s Function %s (children 1)\n", indent, n.Engine.Name) + fmt.Fprintf(b, "%s ExpressionList\n", indent) + } else { + fmt.Fprintf(b, "%s Function %s (children 1)\n", indent, n.Engine.Name) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Engine.Parameters)) + for _, p := range n.Engine.Parameters { + explainNode(b, p, depth+4) + } + } + } + if n.PartitionBy != nil { + explainNode(b, n.PartitionBy, depth+2) + } + if len(n.PrimaryKey) > 0 { + // For simple PRIMARY KEY, just output the identifier + if len(n.PrimaryKey) == 1 { + if id, ok := n.PrimaryKey[0].(*Identifier); ok { + fmt.Fprintf(b, "%s Identifier %s\n", indent, id.Name()) + } else { + explainNode(b, n.PrimaryKey[0], depth+2) + } + } else { + fmt.Fprintf(b, "%s Function tuple (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.PrimaryKey)) + for _, expr := range n.PrimaryKey { + explainNode(b, expr, depth+4) + } + } + } + if len(n.OrderBy) > 0 { + // For simple ORDER BY, just output the identifier + if len(n.OrderBy) == 1 { + if id, ok := n.OrderBy[0].(*Identifier); ok { + fmt.Fprintf(b, "%s Identifier %s\n", indent, id.Name()) + } else { + explainNode(b, n.OrderBy[0], depth+2) + } + } else { + fmt.Fprintf(b, "%s Function tuple (children 1)\n", indent) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.OrderBy)) + for _, expr := range n.OrderBy { + explainNode(b, expr, depth+4) + } + } + } + if len(n.Settings) > 0 { + fmt.Fprintf(b, "%s Set\n", indent) + } + } + + if n.AsSelect != nil { + explainNode(b, n.AsSelect, depth+1) + } +} + +// 
explainTablesInSelectQueryElement formats a table element with optional implicit join. +func explainTablesInSelectQueryElement(b *strings.Builder, elem *TablesInSelectQueryElement, isImplicitJoin bool, depth int) { + indent := strings.Repeat(" ", depth) + + children := 0 + if elem.Table != nil { + children++ + } + if elem.Join != nil { + children++ + } else if isImplicitJoin { + // For implicit cross joins (comma-separated tables), add an empty TableJoin + children++ + } + + fmt.Fprintf(b, "%sTablesInSelectQueryElement (children %d)\n", indent, children) + if elem.Table != nil { + explainNode(b, elem.Table, depth+1) + } + if elem.Join != nil { + explainTableJoin(b, elem.Join, depth+1) + } else if isImplicitJoin { + // Output empty TableJoin for implicit cross join + fmt.Fprintf(b, "%s TableJoin\n", indent) + } +} + +// explainSampleClause formats a SAMPLE clause. +func explainSampleClause(b *strings.Builder, sample *SampleClause, depth int) { + indent := strings.Repeat(" ", depth) + fmt.Fprintf(b, "%sSampleRatio %s\n", indent, formatSampleRatio(sample.Ratio)) + if sample.Offset != nil { + fmt.Fprintf(b, "%sSampleRatio %s\n", indent, formatSampleRatio(sample.Offset)) + } +} + +// formatSampleRatio formats a sample ratio expression. +func formatSampleRatio(expr Expression) string { + switch e := expr.(type) { + case *Literal: + if e.Type == LiteralInteger { + return fmt.Sprintf("%v", e.Value) + } + if e.Type == LiteralFloat { + // Convert float to fraction + return floatToFraction(e.Value.(float64)) + } + return fmt.Sprintf("%v", e.Value) + case *BinaryExpr: + // For division, format as "numerator / denominator" + if e.Op == "/" { + left := formatSampleRatio(e.Left) + right := formatSampleRatio(e.Right) + return fmt.Sprintf("%s / %s", left, right) + } + } + return fmt.Sprintf("%v", expr) +} + +// normalizeExplainType normalizes EXPLAIN type for output. 
+func normalizeExplainType(t string) string { + switch strings.ToUpper(t) { + case "", "PLAN": + return "EXPLAIN" + case "AST": + return "EXPLAIN AST" + case "SYNTAX": + return "EXPLAIN SYNTAX" + case "PIPELINE": + return "EXPLAIN PIPELINE" + default: + return "EXPLAIN " + t + } +} + +// showTypeToName maps ShowType to EXPLAIN AST output name. +func showTypeToName(t ShowType) string { + switch t { + case ShowTables: + return "ShowTables" + case ShowDatabases: + return "ShowTables" + case ShowProcesses: + return "ShowProcessList" + case ShowCreate: + return "ShowCreate" + case ShowCreateDB: + return "ShowCreate" + case ShowColumns: + return "ShowColumns" + case ShowDictionaries: + return "ShowTables" + default: + return "ShowTables" + } +} + +// floatToFraction converts a float to a fraction string. +func floatToFraction(f float64) string { + // Handle common fractions + if f == 0.1 { + return "1 / 10" + } + if f == 0.5 { + return "5 / 10" + } + if f == 0.25 { + return "25 / 100" + } + // For other floats, just return as is for now + return fmt.Sprintf("%v", f) +} + +// explainCreateView formats a CREATE VIEW or MATERIALIZED VIEW query. 
+func explainCreateView(b *strings.Builder, n *CreateQuery, name string, depth int) { + indent := strings.Repeat(" ", depth) + + children := 1 // identifier + if n.AsSelect != nil { + children++ + } + if n.Engine != nil { + children++ // ViewTargets + } + + fmt.Fprintf(b, "%sCreateQuery %s (children %d)\n", indent, name, children) + fmt.Fprintf(b, "%s Identifier %s\n", indent, name) + + // For views, the AS SELECT comes before storage/ViewTargets + if n.AsSelect != nil { + explainNode(b, n.AsSelect, depth+1) + } + + // Storage is wrapped in ViewTargets for views + if n.Engine != nil { + fmt.Fprintf(b, "%s ViewTargets (children 1)\n", indent) + fmt.Fprintf(b, "%s Storage definition (children 1)\n", indent) + if len(n.Engine.Parameters) == 0 { + fmt.Fprintf(b, "%s Function %s\n", indent, n.Engine.Name) + } else { + fmt.Fprintf(b, "%s Function %s (children 1)\n", indent, n.Engine.Name) + fmt.Fprintf(b, "%s ExpressionList (children %d)\n", indent, len(n.Engine.Parameters)) + for _, p := range n.Engine.Parameters { + explainNode(b, p, depth+5) + } + } + } +} diff --git a/ast/explain_test.go b/ast/explain_test.go new file mode 100644 index 0000000000..7ea997a238 --- /dev/null +++ b/ast/explain_test.go @@ -0,0 +1,86 @@ +package ast_test + +import ( + "context" + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/kyleconroy/doubleclick/ast" + "github.com/kyleconroy/doubleclick/parser" +) + +// testMetadata holds optional metadata for a test case +type testMetadata struct { + Todo bool `json:"todo,omitempty"` + Source string `json:"source,omitempty"` +} + +func TestExplain(t *testing.T) { + testdataDir := "../parser/testdata" + + entries, err := os.ReadDir(testdataDir) + if err != nil { + t.Fatalf("Failed to read testdata directory: %v", err) + } + + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + testName := entry.Name() + testDir := filepath.Join(testdataDir, testName) + + // Check if explain.txt exists + explainPath 
:= filepath.Join(testDir, "explain.txt") + explainBytes, err := os.ReadFile(explainPath) + if err != nil { + continue // Skip test cases without explain.txt + } + expected := string(explainBytes) + + t.Run(testName, func(t *testing.T) { + // Read optional metadata + var metadata testMetadata + metadataPath := filepath.Join(testDir, "metadata.json") + if metadataBytes, err := os.ReadFile(metadataPath); err == nil { + if err := json.Unmarshal(metadataBytes, &metadata); err != nil { + t.Fatalf("Failed to parse metadata.json: %v", err) + } + } + + // Read the query + queryPath := filepath.Join(testDir, "query.sql") + queryBytes, err := os.ReadFile(queryPath) + if err != nil { + t.Fatalf("Failed to read query.sql: %v", err) + } + query := strings.TrimSpace(string(queryBytes)) + + // Parse the query + stmts, err := parser.Parse(context.Background(), strings.NewReader(query)) + if err != nil { + t.Skipf("Parse error (skipping): %v", err) + return + } + + if len(stmts) == 0 { + t.Fatalf("Expected at least 1 statement, got 0") + } + + // Generate explain output + got := ast.Explain(stmts[0]) + + // Compare + if got != expected { + if metadata.Todo { + t.Skipf("TODO: Explain output mismatch (skipping)") + } + t.Errorf("Explain output mismatch\nQuery: %s\n\nExpected:\n%s\nGot:\n%s", query, expected, got) + } + }) + } +} diff --git a/parser/expression.go b/parser/expression.go index 7849107277..12839bff43 100644 --- a/parser/expression.go +++ b/parser/expression.go @@ -1132,8 +1132,9 @@ func (p *Parser) parseAlias(left ast.Expression) ast.Expression { func (p *Parser) parseCastOperator(left ast.Expression) ast.Expression { expr := &ast.CastExpr{ - Position: p.current.Pos, - Expr: left, + Position: p.current.Pos, + Expr: left, + OperatorSyntax: true, } p.nextToken() // skip :: diff --git a/parser/parser.go b/parser/parser.go index 547baa2c28..809b0f7fc5 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -1293,6 +1293,7 @@ func (p *Parser) parseEngineClause() 
*ast.EngineClause { } if p.currentIs(token.LPAREN) { + engine.HasParentheses = true p.nextToken() if !p.currentIs(token.RPAREN) { engine.Parameters = p.parseExpressionList() @@ -1806,8 +1807,8 @@ func (p *Parser) parseShow() *ast.ShowQuery { } } - // Parse FROM clause (or table name for SHOW CREATE TABLE) - if p.currentIs(token.FROM) || (show.ShowType == ast.ShowCreate && (p.currentIs(token.IDENT) || p.current.Token.IsKeyword())) { + // Parse FROM clause (or table/database name for SHOW CREATE TABLE/DATABASE) + if p.currentIs(token.FROM) || ((show.ShowType == ast.ShowCreate || show.ShowType == ast.ShowCreateDB) && (p.currentIs(token.IDENT) || p.current.Token.IsKeyword())) { if p.currentIs(token.FROM) { p.nextToken() } diff --git a/parser/testdata/dateadd/metadata.json b/parser/testdata/dateadd/metadata.json new file mode 100644 index 0000000000..ef120d978e --- /dev/null +++ b/parser/testdata/dateadd/metadata.json @@ -0,0 +1 @@ +{"todo": true} diff --git a/parser/testdata/datesub/metadata.json b/parser/testdata/datesub/metadata.json new file mode 100644 index 0000000000..ef120d978e --- /dev/null +++ b/parser/testdata/datesub/metadata.json @@ -0,0 +1 @@ +{"todo": true}