diff --git a/parser.go b/parser.go index f0c25ba..2b31e1e 100644 --- a/parser.go +++ b/parser.go @@ -2373,6 +2373,16 @@ func (p *Parser) parsePrefix() (sqlast.Node, error) { func (p *Parser) parseFunction(name *sqlast.ObjectName) (sqlast.Node, error) { p.expectToken(sqltoken.LParen) + + var filter *sqlast.Ident + if ok, _, _ := p.parseKeyword("DISTINCT"); ok { + p.prevToken() + filter, _ = p.parseIdentifier() + } else if ok, _, _ := p.parseKeyword("ALL"); ok { + p.prevToken() + filter, _ = p.parseIdentifier() + } + args, err := p.parseOptionalArgs() if err != nil { return nil, errors.Errorf("parseOptionalArgs failed: %w", err) @@ -2434,6 +2444,7 @@ func (p *Parser) parseFunction(name *sqlast.ObjectName) (sqlast.Node, error) { Args: args, Over: over, ArgsRParen: r.To, + Filter: filter, }, nil } diff --git a/parser_test.go b/parser_test.go index e10e857..df7f1f1 100644 --- a/parser_test.go +++ b/parser_test.go @@ -908,6 +908,56 @@ FROM user WHERE id BETWEEN 1 AND 2`, }, }, }, + { + name: "count distinct", + in: "SELECT COUNT(DISTINCT email) FROM user", + out: &sqlast.QueryStmt{ + Body: &sqlast.SQLSelect{ + Select: sqltoken.NewPos(1, 1), + Projection: []sqlast.SQLSelectItem{ + &sqlast.UnnamedSelectItem{ + Node: &sqlast.Function{ + Name: &sqlast.ObjectName{ + Idents: []*sqlast.Ident{ + { + Value: "COUNT", + From: sqltoken.NewPos(1, 8), + To: sqltoken.NewPos(1, 13), + }, + }, + }, + Args: []sqlast.Node{ + &sqlast.Ident{ + Value: "email", + From: sqltoken.NewPos(1, 23), + To: sqltoken.NewPos(1, 28), + }, + }, + Filter: &sqlast.Ident{ + Value: "DISTINCT", + From: sqltoken.NewPos(1, 14), + To: sqltoken.NewPos(1, 22), + }, + ArgsRParen: sqltoken.NewPos(1, 29), + }, + }, + }, + FromClause: []sqlast.TableReference{ + &sqlast.Table{ + Name: &sqlast.ObjectName{ + Idents: []*sqlast.Ident{ + { + Value: "user", + From: sqltoken.NewPos(1, 35), + To: sqltoken.NewPos(1, 39), + }, + }, + }, + }, + }, + }, + }, + }, } for _, c := range cases { diff --git a/sqlast/ast.go b/sqlast/ast.go index 20144c4..a55876a 100644 --- a/sqlast/ast.go +++ b/sqlast/ast.go @@ -405,6 +405,7 @@ func (s *UnaryExpr) WriteTo(w io.Writer) (int64, error) { type Function struct { Name *ObjectName // Function Name Args []Node + Filter *Ident ArgsRParen sqltoken.Pos // function args RParen position Over *WindowSpec OverRparen sqltoken.Pos // Over RParen position (if Over is not nil)