From f3ef3d1c461ac1d5f24e9148ab3d094347df4300 Mon Sep 17 00:00:00 2001 From: Brad Fitzpatrick Date: Mon, 16 Feb 2026 13:17:20 -1000 Subject: [PATCH] sqlitepool: delete never-used code This sqlitepool was a Jun 2022 effort to move away from Go's database/sql package and have an SQLite-specific layer instead. But we've never used it, or even finished it. Instead, it just distracts people reading this repo, making them think we use this code. So just delete it. We can always resurrect it from git if we decide to go in that direction later. Updates #cleanup Signed-off-by: Brad Fitzpatrick --- sqlitepool/queryglue.go | 392 ---------------------------------- sqlitepool/queryglue_test.go | 101 --------- sqlitepool/sqlitepool.go | 323 ---------------------------- sqlitepool/sqlitepool_test.go | 178 --------------- sqlitepool/util.go | 88 -------- 5 files changed, 1082 deletions(-) delete mode 100644 sqlitepool/queryglue.go delete mode 100644 sqlitepool/queryglue_test.go delete mode 100644 sqlitepool/sqlitepool.go delete mode 100644 sqlitepool/sqlitepool_test.go delete mode 100644 sqlitepool/util.go diff --git a/sqlitepool/queryglue.go b/sqlitepool/queryglue.go deleted file mode 100644 index 4a84e80..0000000 --- a/sqlitepool/queryglue.go +++ /dev/null @@ -1,392 +0,0 @@ -//go:build cgo - -package sqlitepool - -// This file contains bridging functions designed to let users of -// database/sql move to sqlitepool without changing the semantics -// of their code. -// -// Eventually users should piece-wise migrate to another interface. -// (Or we should invest in this interface? Seems suboptimal.) - -import ( - sqlpkg "database/sql" - "database/sql/driver" - "encoding" - "fmt" - "reflect" - "strings" - "time" - - "github.com/tailscale/sqlite/sqliteh" -) - -// Exec is like database/sql.Tx.Exec. -// Only use this for one-off/rare queries. -// For normal queries, see the Exec method on Tx. -func Exec(db sqliteh.DB, sql string, args ...any) error { - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - return err - } - if err := bindAll(db, stmt, args...); err != nil { - return fmt.Errorf("Exec: %w", err) - } - _, _, _, _, err = stmt.StepResult() - if err != nil { - err = fmt.Errorf("%w: %v", err, db.ErrMsg()) - } - stmt.Finalize() - return err -} - -// QueryRow is like database/sql.Tx.QueryRow. -// Only use this for one-off/rare queries. -// For normal queries, see the methods on Rx. -func QueryRow(db sqliteh.DB, sql string, args ...any) *Row { - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, db.ErrMsg())} - } - if err := bindAll(db, stmt, args...); err != nil { - return &Row{err: fmt.Errorf("QueryRow: %w", err)} - } - row, err := stmt.Step(nil) - if err != nil { - msg := db.ErrMsg() - stmt.Finalize() - return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)} - } - if !row { - stmt.Finalize() - return &Row{err: sqlpkg.ErrNoRows} - } - return &Row{stmt: stmt, oneOff: true} -} - -// Query is like database/sql.Tx.Query. -// Only use this for one-off/rare queries. -// For normal queries, see the methods on Rx. -func Query(db sqliteh.DB, sql string, args ...any) (*Rows, error) { - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - return nil, fmt.Errorf("Query: %w: %v", err, db.ErrMsg()) - } - if err := bindAll(db, stmt, args...); err != nil { - return nil, err - } - return &Rows{stmt: stmt, oneOff: true}, nil -} - -// Exec is like database/sql.Tx.Exec. -func (tx *Tx) Exec(sql string, args ...any) error { - stmt := tx.Prepare(sql) - if err := bindAll(tx.conn.db, stmt, args...); err != nil { - return err - } - _, _, _, _, err := stmt.StepResult() - if err != nil { - return fmt.Errorf("%w: %v", err, tx.conn.db.ErrMsg()) - } - return nil -} - -func (tx *Tx) ExecRes(sql string, args ...any) (rowsAffected int64, err error) { - stmt := tx.Prepare(sql) - if err := bindAll(tx.conn.db, stmt, args...); err != nil { - return 0, err - } - _, _, rowsAffected, _, err = stmt.StepResult() - return rowsAffected, err -} - -// QueryRow is like database/sql.Tx.QueryRow. -func (rx *Rx) QueryRow(sql string, args ...any) *Row { - stmt := rx.Prepare(sql) - if err := bindAll(rx.conn.db, stmt, args...); err != nil { - return &Row{err: fmt.Errorf("QueryRow: %w", err)} - } - row, err := stmt.Step(nil) - if err != nil { - msg := rx.DB().ErrMsg() - stmt.ResetAndClear() - return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)} - } - if !row { - stmt.ResetAndClear() - return &Row{err: sqlpkg.ErrNoRows} - } - return &Row{stmt: stmt} -} - -// Query is like database/sql.Tx.Query. -func (rx *Rx) Query(sql string, args ...any) (*Rows, error) { - stmt := rx.Prepare(sql) - if err := bindAll(rx.conn.db, stmt, args...); err != nil { - return nil, fmt.Errorf("Query: %w", err) - } - return &Rows{stmt: stmt}, nil -} - -// Rows is like database/sql.Tx.Rows. -type Rows struct { - stmt sqliteh.Stmt - err error - oneOff bool -} - -func (rs *Rows) Next() bool { - if rs.err != nil { - return false - } - row, err := rs.stmt.Step(nil) - if err != nil { - rs.err = fmt.Errorf("QueryRow.Next: %w: %v", err, rs.stmt.DBHandle().ErrMsg()) - return false - } - if !row { - rs.stmt.ResetAndClear() - } - return row -} - -func (rs *Rows) Err() error { - return rs.err -} - -func (rs *Rows) Scan(dest ...any) error { - if rs.err != nil { - return rs.err - } - return scanAll(rs.stmt, dest...) -} - -func (rs *Rows) Close() error { - if rs.stmt == nil { - return nil - } - _, err := rs.stmt.ResetAndClear() - msg := rs.stmt.DBHandle().ErrMsg() - var err2 error - if rs.oneOff { - err2 = rs.stmt.Finalize() - } - rs.stmt = nil - if err != nil { - return fmt.Errorf("Rows.ResetAndClear: %w: %v", err, msg) - } - if err2 != nil { - return fmt.Errorf("Rows.ResetAndClear: %w: %v", err2, rs.stmt.DBHandle().ErrMsg()) - } - return nil -} - -// Row is like database/sql.Tx.Row. -type Row struct { - stmt sqliteh.Stmt - err error - oneOff bool -} - -func (r *Row) Err() error { - return r.err -} - -func (r *Row) Scan(dest ...any) error { - if r.err != nil { - return r.err - } - err := scanAll(r.stmt, dest...) - r.stmt.ResetAndClear() - if r.oneOff { - r.stmt.Finalize() - } - return err -} - -type scanner interface { - Scan(value any) error -} - -// scanAll mimics (some of) the sqlite driver's scanning logic, which is -// split across the driver and the database/sql package. -func scanAll(stmt sqliteh.Stmt, dest ...any) error { - for i := 0; i < len(dest); i++ { - if s, ok := dest[i].(scanner); ok { - // We have a handful of *sql.NullInt64 objects in - // our tree, so we implement minimal support for - // them here. TODO: remove some time. - var v any - switch stmt.ColumnType(i) { - case sqliteh.SQLITE_INTEGER: - v = stmt.ColumnInt64(i) - case sqliteh.SQLITE_FLOAT: - v = stmt.ColumnDouble(i) - case sqliteh.SQLITE_TEXT: - v = stmt.ColumnText(i) - case sqliteh.SQLITE_BLOB: - v = stmt.ColumnText(i) - case sqliteh.SQLITE_NULL: - v = nil - } - if err := s.Scan(v); err != nil { - return err - } - continue - } - v := reflect.ValueOf(dest[i]) - if v.Elem().Kind() == reflect.Slice && v.Elem().Type().Elem().Kind() == reflect.Uint8 { - b := append([]byte(nil), stmt.ColumnBlob(i)...) - v.Elem().SetBytes(b) - continue - } - switch v.Elem().Kind() { - case reflect.Bool: - v.Elem().SetBool(stmt.ColumnInt64(i) != 0) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v.Elem().SetInt(stmt.ColumnInt64(i)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v.Elem().SetUint(uint64(stmt.ColumnInt64(i))) - case reflect.Float32, reflect.Float64: - v.Elem().SetFloat(stmt.ColumnDouble(i)) - case reflect.String: - v.Elem().SetString(stmt.ColumnText(i)) - default: - return fmt.Errorf("sqlitepool.scan:%d: cannot handle destination kind %v (%T)", i, v.Kind(), dest[i]) - } - } - return nil -} - -func bindAll(db sqliteh.DB, stmt sqliteh.Stmt, args ...any) error { - for i, arg := range args { - if err := bind(db, stmt, i+1, arg); err != nil { - stmt.ResetAndClear() - return fmt.Errorf("bind: %d, %q: %w", i, arg, err) - } - } - return nil -} - -type driverValue interface { - Value() (driver.Value, error) -} - -// bind, from the driver in sqlite.go. -func bind(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) error { - // Start with obvious types, including time.Time before TextMarshaler. - found, err := bindBasic(db, s, ordinal, v) - if err != nil { - return err - } else if found { - return nil - } - - if m, _ := v.(driverValue); m != nil { - // We have a few NullInt64s we need to handle. - // TODO: remove or rethink in the future. - var err error - v, err = m.Value() - if err != nil { - return fmt.Errorf("sqlitepool.bind:%d: bad driver.Value: %w", ordinal, err) - } - if v == nil { - _, err := bindBasic(db, s, ordinal, nil) - return err - } - } - - if m, _ := v.(encoding.TextMarshaler); m != nil { - b, err := m.MarshalText() - if err != nil { - return fmt.Errorf("sqlitepool.bind:%d: cannot marshal %T: %w", ordinal, v, err) - } - _, err = bindBasic(db, s, ordinal, b) - return err - } - - // Look for named basic types or other convertible types. - val := reflect.ValueOf(v) - if val.Kind() == reflect.Pointer { - if val.IsNil() { - _, err := bindBasic(db, s, ordinal, nil) - return err - } - val = val.Elem() - } - typ := reflect.TypeOf(v) - if typ.Kind() == reflect.Pointer { - typ = typ.Elem() - } - switch typ.Kind() { - case reflect.Bool: - b := int64(0) - if val.Bool() { - b = 1 - } - _, err := bindBasic(db, s, ordinal, b) - return err - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - var i int64 - if !val.IsZero() { - i = val.Int() - } - _, err := bindBasic(db, s, ordinal, i) - return err - case reflect.Uint, reflect.Uint64: - return fmt.Errorf("sqlitepool.bind:%d: sqlite does not support uint64 (try a string or TextMarshaler)", ordinal) - case reflect.Uint8, reflect.Uint16, reflect.Uint32: - _, err := bindBasic(db, s, ordinal, int64(val.Uint())) - return err - case reflect.Float32, reflect.Float64: - _, err := bindBasic(db, s, ordinal, val.Float()) - return err - case reflect.String: - _, err := bindBasic(db, s, ordinal, val.String()) - return err - } - - return fmt.Errorf("sqlitepool.bind:%d: unknown value type %T (try a string or TextMarshaler)", ordinal, v) -} - -// bindBasic, from the driver in sqlite.go. -func bindBasic(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) (found bool, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("sqlitepool.bind:%d:%T: %w: %v", ordinal, v, err, db.ErrMsg()) - } - }() - switch v := v.(type) { - case nil: - return true, s.BindNull(ordinal) - case string: - return true, s.BindText64(ordinal, v) - case int: - return true, s.BindInt64(ordinal, int64(v)) - case int64: - return true, s.BindInt64(ordinal, v) - case float64: - return true, s.BindDouble(ordinal, v) - case []byte: - if len(v) == 0 { - return true, s.BindZeroBlob64(ordinal, 0) - } else { - return true, s.BindBlob64(ordinal, v) - } - case time.Time: - // Shortest of: - // YYYY-MM-DD HH:MM - // YYYY-MM-DD HH:MM:SS - // YYYY-MM-DD HH:MM:SS.SSS - str := v.Format(timeFormat) - str = strings.TrimSuffix(str, "-0000") - str = strings.TrimSuffix(str, ".000") - str = strings.TrimSuffix(str, ":00") - return true, s.BindText64(ordinal, str) - default: - return false, nil - } -} - -// timeFormat from the driver in sqlite.go. -const timeFormat = "2006-01-02 15:04:05.000-0700" diff --git a/sqlitepool/queryglue_test.go b/sqlitepool/queryglue_test.go deleted file mode 100644 index e29f798..0000000 --- a/sqlitepool/queryglue_test.go +++ /dev/null @@ -1,101 +0,0 @@ -//go:build cgo - -package sqlitepool - -import ( - "context" - "database/sql" - "testing" - - "github.com/tailscale/sqlite/sqliteh" - "github.com/tailscale/sqlite/sqlstats" -) - -func TestQueryGlue(t *testing.T) { - ctx := context.Background() - initFn := func(db sqliteh.DB) error { return ExecScript(db, "PRAGMA synchronous=OFF;") } - tracer := &sqlstats.Tracer{} - tempDir := t.TempDir() - p, err := NewPool("file:"+tempDir+"/sqlitepool_queryglue_test", 2, initFn, tracer) - if err != nil { - t.Fatal(err) - } - - tx, err := p.BeginTx(ctx, "insert-1") - if err != nil { - t.Fatal(err) - } - if err := tx.Exec("CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)"); err != nil { - t.Fatal(err) - } - if err := Exec(tx.DB(), "INSERT INTO t VALUES (?, ?)", 10, "skip"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 100, "a"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 200, "b"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 300, "c"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 400, "d"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 401, "skip"); err != nil { - t.Fatal(err) - } - - var count int - if err := tx.QueryRow("SELECT count(*) FROM t WHERE id >= ? AND id <= ?", 100, 400).Scan(&count); err != nil { - t.Fatal(err) - } - if count != 4 { - t.Fatalf("count=%d, want 4", count) - } - if err := tx.QueryRow("SELECT id FROM t WHERE id >= ?", 900).Scan(&count); err != sql.ErrNoRows { - t.Fatalf("QueryRow err=%v, want ErrNoRows", err) - } - - rows, err := tx.Query("SELECT * FROM t WHERE id >= ? AND id <= ?", 100, 400) - if err != nil { - t.Fatal(err) - } - for i := 0; i < 4; i++ { - if !rows.Next() { - t.Fatalf("pass %d: Next=false", i) - } - var id int64 - var val string - if err := rows.Scan(&id, &val); err != nil { - t.Fatalf("pass %d: Scan: %v", i, err) - } - if want := int64(i+1) * 100; id != want { - t.Fatalf("pass %d: id=%d, want %d", i, id, want) - } - if want := string([]byte{'a' + byte(i)}); val != want { - t.Fatalf("pass %d: val=%q want %q", i, val, want) - } - } - if rows.Next() { - t.Fatal("too many rows") - } - if err := rows.Err(); err != nil { - t.Fatal(err) - } - if err := rows.Close(); err != nil { - t.Fatal(err) - } - - var concat sql.RawBytes - if err := tx.QueryRow("SELECT val FROM t WHERE id = 401").Scan(&concat); err != nil { - t.Fatal(err) - } - if got, want := string(concat), "skip"; got != want { - t.Fatalf("concat=%q, want %q", got, want) - } - - tx.Rollback() - p.Close() -} diff --git a/sqlitepool/sqlitepool.go b/sqlitepool/sqlitepool.go deleted file mode 100644 index d699be6..0000000 --- a/sqlitepool/sqlitepool.go +++ /dev/null @@ -1,323 +0,0 @@ -//go:build cgo - -// Package sqlitepool implements a pool of SQLite database connections. -package sqlitepool - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/tailscale/sqlite/cgosqlite" - "github.com/tailscale/sqlite/sqliteh" -) - -// A Pool is a fixed-size pool of SQLite database connections. -// One is reserved for writable transactions, the others are -// used for read-only transactions. -type Pool struct { - poolSize int - rwConnFree chan *conn // cap == 1 - roConnsFree chan *conn // cap == poolSize-1 - tracer sqliteh.Tracer - closed chan struct{} -} - -type conn struct { - pool *Pool - db sqliteh.DB - stmts map[string]sqliteh.Stmt // persistent statements on db - id sqliteh.TraceConnID -} - -// NewPool creates a Pool of poolSize database connections. -// -// For each connection, initFn is called to initialize the connection. -// Tracer is used to report statistics about the use of the Pool. -func NewPool(filename string, poolSize int, initFn func(sqliteh.DB) error, tracer sqliteh.Tracer) (_ *Pool, err error) { - p := &Pool{ - poolSize: poolSize, - rwConnFree: make(chan *conn, 1), - roConnsFree: make(chan *conn, poolSize-1), - tracer: tracer, - closed: make(chan struct{}), - } - defer func() { - if err != nil { - err = fmt.Errorf("sqlitepool.NewPool: %w", err) - select { - case conn := <-p.rwConnFree: - conn.db.Close() - default: - } - close(p.roConnsFree) - for conn := range p.roConnsFree { - conn.db.Close() - } - } - }() - if poolSize < 2 { - return nil, fmt.Errorf("poolSize=%d is too small", poolSize) - } - for i := 0; i < poolSize; i++ { - db, err := cgosqlite.Open(filename, sqliteh.OpenFlagsDefault, "") - if err != nil { - return nil, err - } - if err := initFn(db); err != nil { - return nil, err - } - c := &conn{ - pool: p, - db: db, - stmts: make(map[string]sqliteh.Stmt), - id: sqliteh.TraceConnID(i), - } - if i == 0 { - p.rwConnFree <- c - } else { - if err := ExecScript(c.db, "PRAGMA query_only=true"); err != nil { - return nil, err - } - p.roConnsFree <- c - } - } - - return p, nil -} - -func (c *conn) close() error { - if c.db == nil { - return errors.New("sqlitepool conn already closed") - } - for _, stmt := range c.stmts { - stmt.Finalize() - } - c.stmts = nil - err := c.db.Close() - c.db = nil - return err -} - -func (p *Pool) Close() error { - select { - case <-p.closed: - return errors.New("pool already closed") - default: - } - close(p.closed) - - c := <-p.rwConnFree - err := c.close() - - for i := 0; i < p.poolSize-1; i++ { - c := <-p.roConnsFree - err2 := c.close() - if err == nil { - err = err2 - } - } - return err -} - -var errPoolClosed = fmt.Errorf("%w: sqlitepool closed", context.Canceled) - -// BeginTx creates a writable transaction using BEGIN IMMEDIATE. -// The parameter why is passed to the Tracer for debugging. -func (p *Pool) BeginTx(ctx context.Context, why string) (*Tx, error) { - select { - case <-p.closed: - return nil, errPoolClosed - case <-ctx.Done(): - return nil, ctx.Err() - case conn := <-p.rwConnFree: - tx := &Tx{Rx: &Rx{conn: conn, inTx: true}} - err := tx.Exec("BEGIN IMMEDIATE;") - if p.tracer != nil { - p.tracer.BeginTx(ctx, conn.id, why, false, err) - } - if err != nil { - p.rwConnFree <- conn // can't block, buffer is big enough - return nil, err - } - return tx, nil - } -} - -// BeginRx creates a read-only transaction. -// The parameter why is passed to the Tracer for debugging. -func (p *Pool) BeginRx(ctx context.Context, why string) (*Rx, error) { - select { - case <-p.closed: - return nil, errPoolClosed - case <-ctx.Done(): - return nil, ctx.Err() - case conn := <-p.roConnsFree: - rx := &Rx{conn: conn} - err := rx.Exec("BEGIN;") - if p.tracer != nil { - p.tracer.BeginTx(ctx, conn.id, why, true, err) - } - if err != nil { - p.roConnsFree <- conn // can't block, buffer is big enough - return nil, err - } - return &Rx{conn: conn}, nil - } -} - -// Rx is a read-only transaction. -// -// It is *not* safe for concurrent use. -type Rx struct { - conn *conn - inTx bool // true if this Rx is embedded in a writable Tx - - // OnRollback is an optional function called after rollback. - // If Rx is part of a Tx and it is committed, then OnRollback - // is not called. - OnRollback func() -} - -// Exec executes an SQL statement with no result. -func (rx *Rx) Exec(sql string) error { - _, _, _, _, err := rx.Prepare(sql).StepResult() - if err != nil { - return fmt.Errorf("%w: %v", err, rx.conn.db.ErrMsg()) - } - return nil -} - -// Prepare prepares an SQL statement. -// The Stmt is cached on the connection, so subsequent calls are fast. -func (rx *Rx) Prepare(sql string) sqliteh.Stmt { - stmt := rx.conn.stmts[sql] - if stmt != nil { - return stmt - } - stmt, _, err := rx.conn.db.Prepare(sql, sqliteh.SQLITE_PREPARE_PERSISTENT) - if err != nil { - // Persistent statements are constant strings hardcoded into - // programs. Failing to prepare one means the string is bad. - // Ideally we would detect this at compile time, but barring - // that, there is no point returning the error because this - // is not something the program can recover from or handle. - panic(fmt.Sprintf("%v: %v", err, rx.conn.db.ErrMsg())) - } - rx.conn.stmts[sql] = stmt - return stmt -} - -// DB returns the underlying database connection. -// -// Be careful: a transaction is in progress. Any use of BEGIN/COMMIT/ROLLBACK -// should be modelled as a nested transaction, and when done the original -// outer transaction should be left in-progress. -func (rx *Rx) DB() sqliteh.DB { - return rx.conn.db -} - -// ExecScript executes a series of SQL statements against a database connection. -// It is intended for one-off scripts, so the prepared Stmt objects are not -// cached for future calls. -func ExecScript(db sqliteh.DB, queries string) error { - for { - queries = strings.TrimSpace(queries) - if queries == "" { - return nil - } - stmt, rem, err := db.Prepare(queries, 0) - if err != nil { - return fmt.Errorf("ExecScript: %w: %v, in remaining script: %s", err, db.ErrMsg(), queries) - } - queries = rem - _, err = stmt.Step(nil) - if err != nil { - err = fmt.Errorf("ExecScript: %w: %s: %v", err, stmt.SQL(), db.ErrMsg()) - } - stmt.Finalize() - if err != nil { - return err - } - } -} - -// Rollback executes ROLLBACK and cleans up the Rx. -// It is a no-op if Rx is already rolled back. -func (rx *Rx) Rollback() { - if rx.conn == nil { - return - } - if rx.inTx { - panic("Tx.Rx.Rollback called, only call Rollback on the Tx object") - } - err := rx.Exec("ROLLBACK;") - if rx.conn.pool.tracer != nil { - rx.conn.pool.tracer.Rollback(rx.conn.id, err) - } - rx.conn.pool.roConnsFree <- rx.conn - rx.conn = nil - if rx.OnRollback != nil { - rx.OnRollback() - rx.OnRollback = nil - } - if err != nil { - panic(err) - } -} - -// Tx is a writable SQLite database transaction. -// -// It is *not* safe for concurrent use. -// -// A Tx contains an embedded Rx, which can be used to pass to functions -// that want to perform read-only queries on the writable Tx. -type Tx struct { - *Rx - - // OnCommit is an optional function called after successful commit. - OnCommit func() -} - -// Rollback executes ROLLBACK and cleans up the Tx. -// It is a no-op if the Tx is already rolled back or committed. -func (tx *Tx) Rollback() { - if tx.conn == nil { - return - } - err := tx.Exec("ROLLBACK;") - if tx.conn.pool.tracer != nil { - tx.conn.pool.tracer.Rollback(tx.conn.id, err) - } - tx.conn.pool.rwConnFree <- tx.conn - tx.conn = nil - if tx.OnRollback != nil { - tx.OnRollback() - tx.OnRollback = nil - tx.OnCommit = nil - } - if err != nil { - panic(err) - } -} - -// Commit executes COMMIT and cleans up the Tx. -// It is an error to call if the Tx is already rolled back or committed. -func (tx *Tx) Commit() error { - if tx.conn == nil { - return errors.New("tx already done") - } - err := tx.Exec("COMMIT;") - if tx.conn.pool.tracer != nil { - tx.conn.pool.tracer.Commit(tx.conn.id, err) - } - tx.conn.pool.rwConnFree <- tx.conn - tx.conn = nil - if tx.OnCommit != nil { - tx.OnCommit() - tx.OnCommit = nil - tx.OnRollback = nil - } - return err -} diff --git a/sqlitepool/sqlitepool_test.go b/sqlitepool/sqlitepool_test.go deleted file mode 100644 index f476d66..0000000 --- a/sqlitepool/sqlitepool_test.go +++ /dev/null @@ -1,178 +0,0 @@ -//go:build cgo - -package sqlitepool - -import ( - "context" - "errors" - "testing" - - "github.com/tailscale/sqlite/sqliteh" - "github.com/tailscale/sqlite/sqlstats" -) - -func TestPool(t *testing.T) { - ctx := context.Background() - initFn := func(db sqliteh.DB) error { - err := ExecScript(db, ` - PRAGMA synchronous=OFF; - PRAGMA journal_mode=WAL; - `) - return err - } - tracer := &sqlstats.Tracer{} - tempDir := t.TempDir() - p, err := NewPool("file:"+tempDir+"/sqlitepool_test", 3, initFn, tracer) - if err != nil { - t.Fatal(err) - } - - tx, err := p.BeginTx(ctx, "insert-1") - if err != nil { - t.Fatal(err) - } - if err := tx.Exec("CREATE TABLE t (c);"); err != nil { - t.Fatal(err) - } - stmt := tx.Prepare("INSERT INTO t (c) VALUES (?);") - stmt.BindInt64(1, 1) - if _, _, _, _, err := stmt.StepResult(); err != nil { - t.Fatal(err) - } - var onCommitCalled, onRollbackCalled bool - tx.OnCommit = func() { onCommitCalled = true } - tx.OnRollback = func() { onRollbackCalled = true } - if err := tx.Commit(); err != nil { - t.Fatal(err) - } - tx.Rollback() // no-op, does not call OnRollback - if !onCommitCalled { - t.Fatal("onCommit not called") - } - if onRollbackCalled { - t.Fatal("onRollback called") - } - if err := tx.Commit(); err == nil { - t.Fatalf("want error on second commit, got: %v", err) - } - - tx, err = p.BeginTx(ctx, "insert-2") - if err != nil { - t.Fatal(err) - } - stmt2 := tx.Prepare("INSERT INTO t (c) VALUES (?);") - if stmt != stmt2 { - t.Fatalf("second call to prepare returned a different stmt: %p vs. %p", stmt, stmt2) - } - stmt = stmt2 - stmt.BindInt64(1, 2) - if _, _, _, _, err := stmt.StepResult(); err != nil { - t.Fatal(err) - } - func() { - defer func() { - const want = `SQLITE_ERROR: near "INVALID": syntax error` - if r := recover(); r == nil { - t.Fatal("no panic from invalid prepare") - } else if r != want { - t.Fatalf("invalid sql recover: %q, want %q", r, want) - } - }() - tx.Prepare("INVALID SQL") - }() - onCommitCalled = false - onRollbackCalled = false - tx.OnCommit = func() { onCommitCalled = true } - tx.OnRollback = func() { onRollbackCalled = true } - tx.Rollback() - if onCommitCalled { - t.Fatal("onCommit called") - } - if !onRollbackCalled { - t.Fatal("onRollback not called") - } - if err := tx.Commit(); err == nil { - t.Fatalf("want error on commit after rollback, got: %v", err) - } - tx.Rollback() // no-op - - rx1, err := p.BeginRx(ctx, "read-1") - if err != nil { - t.Fatal(err) - } - defer rx1.Rollback() - rx2, err := p.BeginRx(ctx, "read-2") - if err != nil { - t.Fatal(err) - } - defer rx2.Rollback() - - ctxCancel, cancel := context.WithCancel(ctx) - rx3Err := make(chan error, 1) - go func() { - rx3, err := p.BeginRx(ctxCancel, "read-3") - if err != nil { - rx3Err <- err - return - } - rx3.Rollback() - rx3Err <- errors.New("BeginRx(read-3) did not fail") - }() - cancel() - if err := <-rx3Err; err != context.Canceled { - t.Fatalf("read-3, not context canceled: %v", err) - } - - stmt = rx1.Prepare("SELECT count(*) FROM t") - if row, err := stmt.Step(nil); err != nil { - t.Fatal(err) - } else if !row { - t.Fatal("no row from select count") - } - if got, want := int(stmt.ColumnInt64(0)), 1; got != want { - t.Fatalf("got=%d, want %d", got, want) - } - rx1.Rollback() - rx1.Rollback() // no-op - - rx1, err = p.BeginRx(ctx, "read-1") // now another rx is available - if err != nil { - t.Fatal(err) - } - rx1.Rollback() - rx2.Rollback() - - tx, err = p.BeginTx(ctx, "insert-3") - if err != nil { - t.Fatal(err) - } - if err := ExecScript(tx.DB(), "PRAGMA user_version=5"); err != nil { - t.Fatal(err) - } - func() { - defer func() { - if r := recover(); r != "Tx.Rx.Rollback called, only call Rollback on the Tx object" { - t.Fatalf("expected panic from Tx.Rx.Rollback, got: %q", r) - } - }() - tx.Rx.Rollback() - }() - if err := tx.Commit(); err != nil { - t.Fatal(err) - } - if err := tx.Commit(); err == nil { - t.Fatalf("second commit did not fail, want 'already done'") - } - - if err := p.Close(); err != nil { - t.Fatal(err) - } - p.Close() // no-op - - if _, err := p.BeginTx(ctx, "after-close"); err == nil { - t.Fatal("tx-after-close did not fail") - } - if _, err := p.BeginRx(ctx, "after-close"); err == nil { - t.Fatal("rx-after-close did not fail") - } -} diff --git a/sqlitepool/util.go b/sqlitepool/util.go deleted file mode 100644 index d17c370..0000000 --- a/sqlitepool/util.go +++ /dev/null @@ -1,88 +0,0 @@ -//go:build cgo - -package sqlitepool - -import ( - "fmt" - "strings" - - "github.com/tailscale/sqlite/sqliteh" -) - -// CopyAll copies the contents of one database to another. -// -// Traditionally this is done in sqlite by closing the database and copying -// the file. However it can be useful to do it online: a single exclusive -// transaction can cross multiple databases, and if multiple processes are -// using a file, this lets one replace the database without first -// communicating with the other processes, asking them to close the DB first. -// -// The dstSchemaName and srcSchemaName parameters follow the SQLite PRAMGA -// schema-name conventions: https://sqlite.org/pragma.html#syntax -func CopyAll(db sqliteh.DB, dstSchemaName, srcSchemaName string) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("sqlitepool.CopyAll: %w", err) - } - }() - if dstSchemaName == "" { - dstSchemaName = "main" - } - if srcSchemaName == "" { - srcSchemaName = "main" - } - if dstSchemaName == srcSchemaName { - return fmt.Errorf("source matches destination: %q", srcSchemaName) - } - // Filter on sql to avoid auto indexes. - // See https://www.sqlite.org/schematab.html for sqlite_schema docs. - rows, err := Query(db, fmt.Sprintf("SELECT name, type, sql FROM %q.sqlite_schema WHERE sql != ''", srcSchemaName)) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var name, sqlType, sqlText string - if err := rows.Scan(&name, &sqlType, &sqlText); err != nil { - return err - } - // Regardless of the case or whitespace used in the original - // create statement (or whether or not "if not exists" is used), - // the SQL text in the sqlite_schema table always reads: - // "CREATE (TABLE|VIEW|INDEX|TRIGGER) name". - // We take advantage of that here to rewrite the create - // statement for a different schema. - switch sqlType { - case "index": - sqlText = strings.TrimPrefix(sqlText, "CREATE INDEX ") - sqlText = fmt.Sprintf("CREATE INDEX %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - case "table": - sqlText = strings.TrimPrefix(sqlText, "CREATE TABLE ") - sqlText = fmt.Sprintf("CREATE TABLE %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - if err := ExecScript(db, fmt.Sprintf("INSERT INTO %q.%q SELECT * FROM %q.%q;", dstSchemaName, name, srcSchemaName, name)); err != nil { - return err - } - case "trigger": - sqlText = strings.TrimPrefix(sqlText, "CREATE TRIGGER ") - sqlText = fmt.Sprintf("CREATE TRIGGER %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - case "view": - sqlText = strings.TrimPrefix(sqlText, "CREATE VIEW ") - sqlText = fmt.Sprintf("CREATE VIEW %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - default: - return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name) - } - } - return rows.Err() -}