diff --git a/internal/compiler/expand.go b/internal/compiler/expand.go index c60b7618b2..9649c65f93 100644 --- a/internal/compiler/expand.go +++ b/internal/compiler/expand.go @@ -132,7 +132,8 @@ func (c *Compiler) expandStmt(qc *QueryCatalog, raw *ast.RawStmt, node ast.Node) } } for _, t := range tables { - if scope != "" && scope != t.Rel.Name { + isOldNew := strings.EqualFold(scope, "old") || strings.EqualFold(scope, "new") + if scope != "" && !isOldNew && scope != t.Rel.Name { continue } tableName := c.quoteIdent(t.Rel.Name) diff --git a/internal/compiler/output_columns.go b/internal/compiler/output_columns.go index dbd486359a..07fbe67516 100644 --- a/internal/compiler/output_columns.go +++ b/internal/compiler/output_columns.go @@ -3,6 +3,7 @@ package compiler import ( "errors" "fmt" + "strings" "github.com/sqlc-dev/sqlc/internal/sql/ast" "github.com/sqlc-dev/sqlc/internal/sql/astutils" @@ -269,7 +270,8 @@ func (c *Compiler) outputColumns(qc *QueryCatalog, node ast.Node) ([]*Column, er // TODO: This code is copied in func expand() for _, t := range tables { scope := astutils.Join(n.Fields, ".") - if scope != "" && scope != t.Rel.Name { + isOldNew := strings.EqualFold(scope, "old") || strings.EqualFold(scope, "new") + if scope != "" && !isOldNew && scope != t.Rel.Name { continue } for _, c := range t.Columns { @@ -669,7 +671,8 @@ func outputColumnRefs(res *ast.ResTarget, tables []*Table, node *ast.ColumnRef) if schema != "" && t.Rel.Schema != schema { continue } - if alias != "" && t.Rel.Name != alias { + isOldNew := strings.EqualFold(alias, "old") || strings.EqualFold(alias, "new") + if alias != "" && !isOldNew && t.Rel.Name != alias { continue } for _, c := range t.Columns { diff --git a/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/db.go b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/db.go new file mode 100644 index 0000000000..3b320aa168 --- /dev/null +++ b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/db.go @@ -0,0 +1,31 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +import ( + "context" + "database/sql" +) + +type DBTX interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func New(db DBTX) *Queries { + return &Queries{db: db} +} + +type Queries struct { + db DBTX +} + +func (q *Queries) WithTx(tx *sql.Tx) *Queries { + return &Queries{ + db: tx, + } +} diff --git a/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/models.go b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/models.go new file mode 100644 index 0000000000..937e7bcdef --- /dev/null +++ b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/models.go @@ -0,0 +1,11 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 + +package querytest + +type Foo struct { + ID int32 + Bar string + Baz int32 +} diff --git a/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/query.sql.go b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/query.sql.go new file mode 100644 index 0000000000..e8bf1b3e3a --- /dev/null +++ b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/go/query.sql.go @@ -0,0 +1,74 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.30.0 +// source: query.sql + +package querytest + +import ( + "context" +) + +const deleteReturningOldStar = `-- name: DeleteReturningOldStar :one +DELETE FROM foo WHERE id = $1 RETURNING old.id, old.bar, old.baz +` + +func (q *Queries) DeleteReturningOldStar(ctx context.Context, id int32) (Foo, error) { + row := q.db.QueryRowContext(ctx, deleteReturningOldStar, id) + var i Foo + err := row.Scan(&i.ID, &i.Bar, &i.Baz) + return i, err +} + +const updateReturningNewStar = `-- name: UpdateReturningNewStar :one +UPDATE foo SET bar = $1 WHERE id = $2 RETURNING new.id, new.bar, new.baz +` + +type UpdateReturningNewStarParams struct { + Bar string + ID int32 +} + +func (q *Queries) UpdateReturningNewStar(ctx context.Context, arg UpdateReturningNewStarParams) (Foo, error) { + row := q.db.QueryRowContext(ctx, updateReturningNewStar, arg.Bar, arg.ID) + var i Foo + err := row.Scan(&i.ID, &i.Bar, &i.Baz) + return i, err +} + +const updateReturningOldNewCols = `-- name: UpdateReturningOldNewCols :one +UPDATE foo SET bar = $1 WHERE id = $2 RETURNING OLD.bar, NEW.bar +` + +type UpdateReturningOldNewColsParams struct { + Bar string + ID int32 +} + +type UpdateReturningOldNewColsRow struct { + Bar string + Bar_2 string +} + +func (q *Queries) UpdateReturningOldNewCols(ctx context.Context, arg UpdateReturningOldNewColsParams) (UpdateReturningOldNewColsRow, error) { + row := q.db.QueryRowContext(ctx, updateReturningOldNewCols, arg.Bar, arg.ID) + var i UpdateReturningOldNewColsRow + err := row.Scan(&i.Bar, &i.Bar_2) + return i, err +} + +const updateReturningOldStar = `-- name: UpdateReturningOldStar :one +UPDATE foo SET bar = $1 WHERE id = $2 RETURNING old.id, old.bar, old.baz +` + +type UpdateReturningOldStarParams struct { + Bar string + ID int32 +} + +func (q *Queries) UpdateReturningOldStar(ctx context.Context, arg UpdateReturningOldStarParams) (Foo, error) { + row := q.db.QueryRowContext(ctx, updateReturningOldStar, arg.Bar, arg.ID) + var i Foo + err := row.Scan(&i.ID, &i.Bar, &i.Baz) + return i, err +} diff --git a/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/query.sql b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/query.sql new file mode 100644 index 0000000000..6418ebea22 --- /dev/null +++ b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/query.sql @@ -0,0 +1,11 @@ +-- name: UpdateReturningOldStar :one +UPDATE foo SET bar = $1 WHERE id = $2 RETURNING OLD.*; + +-- name: UpdateReturningNewStar :one +UPDATE foo SET bar = $1 WHERE id = $2 RETURNING NEW.*; + +-- name: UpdateReturningOldNewCols :one +UPDATE foo SET bar = $1 WHERE id = $2 RETURNING OLD.bar, NEW.bar; + +-- name: DeleteReturningOldStar :one +DELETE FROM foo WHERE id = $1 RETURNING OLD.*; diff --git a/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/schema.sql b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/schema.sql new file mode 100644 index 0000000000..441a8b87af --- /dev/null +++ b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/schema.sql @@ -0,0 +1,5 @@ +CREATE TABLE foo ( + id serial primary key, + bar text not null, + baz int not null default 0 +); diff --git a/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/sqlc.json b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/sqlc.json new file mode 100644 index 0000000000..f717ca2e66 --- /dev/null +++ b/internal/endtoend/testdata/returning_old_new/postgresql/stdlib/sqlc.json @@ -0,0 +1,12 @@ +{ + "version": "1", + "packages": [ + { + "path": "go", + "engine": "postgresql", + "name": "querytest", + "schema": "schema.sql", + "queries": "query.sql" + } + ] +}