Skip to content

Commit d311b82

Browse files
committed
chore: address review comments
1 parent 9ac2928 commit d311b82

File tree

2 files changed

+48
-50
lines changed

2 files changed

+48
-50
lines changed

parser/simple_parser.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package parser
1616

1717
import (
18+
"bytes"
1819
"strings"
1920
"unicode"
2021
"unicode/utf8"
@@ -402,7 +403,7 @@ func (p *simpleParser) skipStatementHint() (bool, int) {
402403
// Skip all other whitespaces and comments, but not comments that contain a PG hint.
403404
p.skipWhitespacesAndCommentsWithPgHintOption( /*skipPgHints=*/ false)
404405
// Check if the next tokens are a PG hint.
405-
if len(p.sql) > p.pos+2 && p.sql[p.pos] == '/' && p.sql[p.pos+1] == '*' && p.sql[p.pos+2] == '@' {
406+
if bytes.HasPrefix(p.sql[p.pos:], postgreSqlStatementHintPrefix) {
406407
startPos := p.pos
407408
// Move to the end of this comment.
408409
p.pos = p.statementParser.skipMultiLineComment(p.sql, p.pos)

parser/statement_parser_test.go

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -2704,11 +2704,7 @@ func TestEatIdentifier(t *testing.T) {
27042704
func TestExtractSetStatementsFromHints(t *testing.T) {
27052705
t.Parallel()
27062706

2707-
parser, err := NewStatementParser(databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, 1000)
2708-
if err != nil {
2709-
t.Fatal(err)
2710-
}
2711-
tests := []struct {
2707+
googleSQLTests := []struct {
27122708
input string
27132709
want *ParsedSetStatement
27142710
wantErr bool
@@ -2766,6 +2762,10 @@ func TestExtractSetStatementsFromHints(t *testing.T) {
27662762
input: "select * from my_table",
27672763
want: nil,
27682764
},
2765+
{
2766+
input: "@{} select * from my_table",
2767+
want: nil,
2768+
},
27692769
{
27702770
input: "@{ foo is 'bar' } select * from my_table",
27712771
wantErr: true,
@@ -2809,34 +2809,7 @@ func TestExtractSetStatementsFromHints(t *testing.T) {
28092809
wantErr: true,
28102810
},
28112811
}
2812-
for _, test := range tests {
2813-
t.Run(test.input, func(t *testing.T) {
2814-
statement, err := parser.extractSetStatementsFromHints(test.input)
2815-
if test.wantErr {
2816-
if err == nil {
2817-
t.Fatal("missing expected error")
2818-
}
2819-
} else {
2820-
if err != nil {
2821-
t.Fatal(err)
2822-
}
2823-
opts := cmpopts.IgnoreUnexported(ParsedSetStatement{})
2824-
if !cmp.Equal(statement, test.want, opts) {
2825-
t.Fatalf("mismatch (-want +got):\n%s", cmp.Diff(test.want, statement, opts))
2826-
}
2827-
}
2828-
})
2829-
}
2830-
}
2831-
2832-
func TestExtractSetStatementsFromHintsPostgreSQL(t *testing.T) {
2833-
t.Parallel()
2834-
2835-
parser, err := NewStatementParser(databasepb.DatabaseDialect_POSTGRESQL, 1000)
2836-
if err != nil {
2837-
t.Fatal(err)
2838-
}
2839-
tests := []struct {
2812+
pgTests := []struct {
28402813
input string
28412814
want *ParsedSetStatement
28422815
wantErr bool
@@ -2894,6 +2867,10 @@ func TestExtractSetStatementsFromHintsPostgreSQL(t *testing.T) {
28942867
input: "select * from my_table",
28952868
want: nil,
28962869
},
2870+
{
2871+
input: "/*@*/select * from my_table",
2872+
want: nil,
2873+
},
28972874
{
28982875
input: "/*@ foo is 'bar' */ select * from my_table",
28992876
wantErr: true,
@@ -2937,24 +2914,44 @@ func TestExtractSetStatementsFromHintsPostgreSQL(t *testing.T) {
29372914
wantErr: true,
29382915
},
29392916
}
2940-
for _, test := range tests {
2941-
t.Run(test.input, func(t *testing.T) {
2942-
statement, err := parser.extractSetStatementsFromHints(test.input)
2943-
if test.wantErr {
2944-
if err == nil {
2945-
t.Fatal("missing expected error")
2946-
}
2947-
} else {
2948-
if err != nil {
2949-
t.Fatal(err)
2950-
}
2951-
opts := cmpopts.IgnoreUnexported(ParsedSetStatement{})
2952-
if !cmp.Equal(statement, test.want, opts) {
2953-
t.Fatalf("mismatch (-want +got):\n%s", cmp.Diff(test.want, statement, opts))
2917+
2918+
runHintTests := func(t *testing.T, dialect databasepb.DatabaseDialect, tests []struct {
2919+
input string
2920+
want *ParsedSetStatement
2921+
wantErr bool
2922+
}) {
2923+
parser, err := NewStatementParser(dialect, 1000)
2924+
if err != nil {
2925+
t.Fatal(err)
2926+
}
2927+
for _, test := range tests {
2928+
t.Run(test.input, func(t *testing.T) {
2929+
statement, err := parser.extractSetStatementsFromHints(test.input)
2930+
if test.wantErr {
2931+
if err == nil {
2932+
t.Fatal("missing expected error")
2933+
}
2934+
} else {
2935+
if err != nil {
2936+
t.Fatal(err)
2937+
}
2938+
opts := cmpopts.IgnoreUnexported(ParsedSetStatement{})
2939+
if !cmp.Equal(statement, test.want, opts) {
2940+
t.Fatalf("mismatch (-want +got):\n%s", cmp.Diff(test.want, statement, opts))
2941+
}
29542942
}
2955-
}
2956-
})
2943+
})
2944+
}
29572945
}
2946+
2947+
t.Run("GoogleSQL", func(t *testing.T) {
2948+
t.Parallel()
2949+
runHintTests(t, databasepb.DatabaseDialect_GOOGLE_STANDARD_SQL, googleSQLTests)
2950+
})
2951+
t.Run("PostgreSQL", func(t *testing.T) {
2952+
t.Parallel()
2953+
runHintTests(t, databasepb.DatabaseDialect_POSTGRESQL, pgTests)
2954+
})
29582955
}
29592956

29602957
func TestSplit(t *testing.T) {

0 commit comments

Comments
 (0)