@@ -2704,11 +2704,7 @@ func TestEatIdentifier(t *testing.T) {
27042704func TestExtractSetStatementsFromHints (t * testing.T ) {
27052705 t .Parallel ()
27062706
2707- parser , err := NewStatementParser (databasepb .DatabaseDialect_GOOGLE_STANDARD_SQL , 1000 )
2708- if err != nil {
2709- t .Fatal (err )
2710- }
2711- tests := []struct {
2707+ googleSQLTests := []struct {
27122708 input string
27132709 want * ParsedSetStatement
27142710 wantErr bool
@@ -2766,6 +2762,10 @@ func TestExtractSetStatementsFromHints(t *testing.T) {
27662762 input : "select * from my_table" ,
27672763 want : nil ,
27682764 },
2765+ {
2766+ input : "@{} select * from my_table" ,
2767+ want : nil ,
2768+ },
27692769 {
27702770 input : "@{ foo is 'bar' } select * from my_table" ,
27712771 wantErr : true ,
@@ -2809,34 +2809,7 @@ func TestExtractSetStatementsFromHints(t *testing.T) {
28092809 wantErr : true ,
28102810 },
28112811 }
2812- for _ , test := range tests {
2813- t .Run (test .input , func (t * testing.T ) {
2814- statement , err := parser .extractSetStatementsFromHints (test .input )
2815- if test .wantErr {
2816- if err == nil {
2817- t .Fatal ("missing expected error" )
2818- }
2819- } else {
2820- if err != nil {
2821- t .Fatal (err )
2822- }
2823- opts := cmpopts .IgnoreUnexported (ParsedSetStatement {})
2824- if ! cmp .Equal (statement , test .want , opts ) {
2825- t .Fatalf ("mismatch (-want +got):\n %s" , cmp .Diff (test .want , statement , opts ))
2826- }
2827- }
2828- })
2829- }
2830- }
2831-
2832- func TestExtractSetStatementsFromHintsPostgreSQL (t * testing.T ) {
2833- t .Parallel ()
2834-
2835- parser , err := NewStatementParser (databasepb .DatabaseDialect_POSTGRESQL , 1000 )
2836- if err != nil {
2837- t .Fatal (err )
2838- }
2839- tests := []struct {
2812+ pgTests := []struct {
28402813 input string
28412814 want * ParsedSetStatement
28422815 wantErr bool
@@ -2894,6 +2867,10 @@ func TestExtractSetStatementsFromHintsPostgreSQL(t *testing.T) {
28942867 input : "select * from my_table" ,
28952868 want : nil ,
28962869 },
2870+ {
2871+ input : "/*@*/select * from my_table" ,
2872+ want : nil ,
2873+ },
28972874 {
28982875 input : "/*@ foo is 'bar' */ select * from my_table" ,
28992876 wantErr : true ,
@@ -2937,24 +2914,44 @@ func TestExtractSetStatementsFromHintsPostgreSQL(t *testing.T) {
29372914 wantErr : true ,
29382915 },
29392916 }
2940- for _ , test := range tests {
2941- t .Run (test .input , func (t * testing.T ) {
2942- statement , err := parser .extractSetStatementsFromHints (test .input )
2943- if test .wantErr {
2944- if err == nil {
2945- t .Fatal ("missing expected error" )
2946- }
2947- } else {
2948- if err != nil {
2949- t .Fatal (err )
2950- }
2951- opts := cmpopts .IgnoreUnexported (ParsedSetStatement {})
2952- if ! cmp .Equal (statement , test .want , opts ) {
2953- t .Fatalf ("mismatch (-want +got):\n %s" , cmp .Diff (test .want , statement , opts ))
2917+
2918+ runHintTests := func (t * testing.T , dialect databasepb.DatabaseDialect , tests []struct {
2919+ input string
2920+ want * ParsedSetStatement
2921+ wantErr bool
2922+ }) {
2923+ parser , err := NewStatementParser (dialect , 1000 )
2924+ if err != nil {
2925+ t .Fatal (err )
2926+ }
2927+ for _ , test := range tests {
2928+ t .Run (test .input , func (t * testing.T ) {
2929+ statement , err := parser .extractSetStatementsFromHints (test .input )
2930+ if test .wantErr {
2931+ if err == nil {
2932+ t .Fatal ("missing expected error" )
2933+ }
2934+ } else {
2935+ if err != nil {
2936+ t .Fatal (err )
2937+ }
2938+ opts := cmpopts .IgnoreUnexported (ParsedSetStatement {})
2939+ if ! cmp .Equal (statement , test .want , opts ) {
2940+ t .Fatalf ("mismatch (-want +got):\n %s" , cmp .Diff (test .want , statement , opts ))
2941+ }
29542942 }
2955- }
2956- })
2943+ })
2944+ }
29572945 }
2946+
2947+ t .Run ("GoogleSQL" , func (t * testing.T ) {
2948+ t .Parallel ()
2949+ runHintTests (t , databasepb .DatabaseDialect_GOOGLE_STANDARD_SQL , googleSQLTests )
2950+ })
2951+ t .Run ("PostgreSQL" , func (t * testing.T ) {
2952+ t .Parallel ()
2953+ runHintTests (t , databasepb .DatabaseDialect_POSTGRESQL , pgTests )
2954+ })
29582955}
29592956
29602957func TestSplit (t * testing.T ) {
0 commit comments