diff --git a/sqlitepool/queryglue.go b/sqlitepool/queryglue.go deleted file mode 100644 index 4a84e80..0000000 --- a/sqlitepool/queryglue.go +++ /dev/null @@ -1,392 +0,0 @@ -//go:build cgo - -package sqlitepool - -// This file contains bridging functions designed to let users of -// database/sql move to sqlitepool without changing the semantics -// of their code. -// -// Eventually users should piece-wise migrate to another interface. -// (Or we should invest in this interface? Seems suboptimal.) - -import ( - sqlpkg "database/sql" - "database/sql/driver" - "encoding" - "fmt" - "reflect" - "strings" - "time" - - "github.com/tailscale/sqlite/sqliteh" -) - -// Exec is like database/sql.Tx.Exec. -// Only use this for one-off/rare queries. -// For normal queries, see the Exec method on Tx. -func Exec(db sqliteh.DB, sql string, args ...any) error { - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - return err - } - if err := bindAll(db, stmt, args...); err != nil { - return fmt.Errorf("Exec: %w", err) - } - _, _, _, _, err = stmt.StepResult() - if err != nil { - err = fmt.Errorf("%w: %v", err, db.ErrMsg()) - } - stmt.Finalize() - return err -} - -// QueryRow is like database/sql.Tx.QueryRow. -// Only use this for one-off/rare queries. -// For normal queries, see the methods on Rx. -func QueryRow(db sqliteh.DB, sql string, args ...any) *Row { - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, db.ErrMsg())} - } - if err := bindAll(db, stmt, args...); err != nil { - return &Row{err: fmt.Errorf("QueryRow: %w", err)} - } - row, err := stmt.Step(nil) - if err != nil { - msg := db.ErrMsg() - stmt.Finalize() - return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)} - } - if !row { - stmt.Finalize() - return &Row{err: sqlpkg.ErrNoRows} - } - return &Row{stmt: stmt, oneOff: true} -} - -// Query is like database/sql.Tx.Query. -// Only use this for one-off/rare queries. -// For normal queries, see the methods on Rx. -func Query(db sqliteh.DB, sql string, args ...any) (*Rows, error) { - stmt, _, err := db.Prepare(sql, 0) - if err != nil { - return nil, fmt.Errorf("Query: %w: %v", err, db.ErrMsg()) - } - if err := bindAll(db, stmt, args...); err != nil { - return nil, err - } - return &Rows{stmt: stmt, oneOff: true}, nil -} - -// Exec is like database/sql.Tx.Exec. -func (tx *Tx) Exec(sql string, args ...any) error { - stmt := tx.Prepare(sql) - if err := bindAll(tx.conn.db, stmt, args...); err != nil { - return err - } - _, _, _, _, err := stmt.StepResult() - if err != nil { - return fmt.Errorf("%w: %v", err, tx.conn.db.ErrMsg()) - } - return nil -} - -func (tx *Tx) ExecRes(sql string, args ...any) (rowsAffected int64, err error) { - stmt := tx.Prepare(sql) - if err := bindAll(tx.conn.db, stmt, args...); err != nil { - return 0, err - } - _, _, rowsAffected, _, err = stmt.StepResult() - return rowsAffected, err -} - -// QueryRow is like database/sql.Tx.QueryRow. -func (rx *Rx) QueryRow(sql string, args ...any) *Row { - stmt := rx.Prepare(sql) - if err := bindAll(rx.conn.db, stmt, args...); err != nil { - return &Row{err: fmt.Errorf("QueryRow: %w", err)} - } - row, err := stmt.Step(nil) - if err != nil { - msg := rx.DB().ErrMsg() - stmt.ResetAndClear() - return &Row{err: fmt.Errorf("QueryRow: %w: %v", err, msg)} - } - if !row { - stmt.ResetAndClear() - return &Row{err: sqlpkg.ErrNoRows} - } - return &Row{stmt: stmt} -} - -// Query is like database/sql.Tx.Query. -func (rx *Rx) Query(sql string, args ...any) (*Rows, error) { - stmt := rx.Prepare(sql) - if err := bindAll(rx.conn.db, stmt, args...); err != nil { - return nil, fmt.Errorf("Query: %w", err) - } - return &Rows{stmt: stmt}, nil -} - -// Rows is like database/sql.Tx.Rows. -type Rows struct { - stmt sqliteh.Stmt - err error - oneOff bool -} - -func (rs *Rows) Next() bool { - if rs.err != nil { - return false - } - row, err := rs.stmt.Step(nil) - if err != nil { - rs.err = fmt.Errorf("QueryRow.Next: %w: %v", err, rs.stmt.DBHandle().ErrMsg()) - return false - } - if !row { - rs.stmt.ResetAndClear() - } - return row -} - -func (rs *Rows) Err() error { - return rs.err -} - -func (rs *Rows) Scan(dest ...any) error { - if rs.err != nil { - return rs.err - } - return scanAll(rs.stmt, dest...) -} - -func (rs *Rows) Close() error { - if rs.stmt == nil { - return nil - } - _, err := rs.stmt.ResetAndClear() - msg := rs.stmt.DBHandle().ErrMsg() - var err2 error - if rs.oneOff { - err2 = rs.stmt.Finalize() - } - rs.stmt = nil - if err != nil { - return fmt.Errorf("Rows.ResetAndClear: %w: %v", err, msg) - } - if err2 != nil { - return fmt.Errorf("Rows.ResetAndClear: %w: %v", err2, rs.stmt.DBHandle().ErrMsg()) - } - return nil -} - -// Row is like database/sql.Tx.Row. -type Row struct { - stmt sqliteh.Stmt - err error - oneOff bool -} - -func (r *Row) Err() error { - return r.err -} - -func (r *Row) Scan(dest ...any) error { - if r.err != nil { - return r.err - } - err := scanAll(r.stmt, dest...) - r.stmt.ResetAndClear() - if r.oneOff { - r.stmt.Finalize() - } - return err -} - -type scanner interface { - Scan(value any) error -} - -// scanAll mimics (some of) the sqlite driver's scanning logic, which is -// split across the driver and the database/sql package. -func scanAll(stmt sqliteh.Stmt, dest ...any) error { - for i := 0; i < len(dest); i++ { - if s, ok := dest[i].(scanner); ok { - // We have a handful of *sql.NullInt64 objects in - // our tree, so we implement minimal support for - // them here. TODO: remove some time. - var v any - switch stmt.ColumnType(i) { - case sqliteh.SQLITE_INTEGER: - v = stmt.ColumnInt64(i) - case sqliteh.SQLITE_FLOAT: - v = stmt.ColumnDouble(i) - case sqliteh.SQLITE_TEXT: - v = stmt.ColumnText(i) - case sqliteh.SQLITE_BLOB: - v = stmt.ColumnText(i) - case sqliteh.SQLITE_NULL: - v = nil - } - if err := s.Scan(v); err != nil { - return err - } - continue - } - v := reflect.ValueOf(dest[i]) - if v.Elem().Kind() == reflect.Slice && v.Elem().Type().Elem().Kind() == reflect.Uint8 { - b := append([]byte(nil), stmt.ColumnBlob(i)...) - v.Elem().SetBytes(b) - continue - } - switch v.Elem().Kind() { - case reflect.Bool: - v.Elem().SetBool(stmt.ColumnInt64(i) != 0) - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - v.Elem().SetInt(stmt.ColumnInt64(i)) - case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - v.Elem().SetUint(uint64(stmt.ColumnInt64(i))) - case reflect.Float32, reflect.Float64: - v.Elem().SetFloat(stmt.ColumnDouble(i)) - case reflect.String: - v.Elem().SetString(stmt.ColumnText(i)) - default: - return fmt.Errorf("sqlitepool.scan:%d: cannot handle destination kind %v (%T)", i, v.Kind(), dest[i]) - } - } - return nil -} - -func bindAll(db sqliteh.DB, stmt sqliteh.Stmt, args ...any) error { - for i, arg := range args { - if err := bind(db, stmt, i+1, arg); err != nil { - stmt.ResetAndClear() - return fmt.Errorf("bind: %d, %q: %w", i, arg, err) - } - } - return nil -} - -type driverValue interface { - Value() (driver.Value, error) -} - -// bind, from the driver in sqlite.go. -func bind(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) error { - // Start with obvious types, including time.Time before TextMarshaler. - found, err := bindBasic(db, s, ordinal, v) - if err != nil { - return err - } else if found { - return nil - } - - if m, _ := v.(driverValue); m != nil { - // We have a few NullInt64s we need to handle. - // TODO: remove or rethink in the future. - var err error - v, err = m.Value() - if err != nil { - return fmt.Errorf("sqlitepool.bind:%d: bad driver.Value: %w", ordinal, err) - } - if v == nil { - _, err := bindBasic(db, s, ordinal, nil) - return err - } - } - - if m, _ := v.(encoding.TextMarshaler); m != nil { - b, err := m.MarshalText() - if err != nil { - return fmt.Errorf("sqlitepool.bind:%d: cannot marshal %T: %w", ordinal, v, err) - } - _, err = bindBasic(db, s, ordinal, b) - return err - } - - // Look for named basic types or other convertible types. - val := reflect.ValueOf(v) - if val.Kind() == reflect.Pointer { - if val.IsNil() { - _, err := bindBasic(db, s, ordinal, nil) - return err - } - val = val.Elem() - } - typ := reflect.TypeOf(v) - if typ.Kind() == reflect.Pointer { - typ = typ.Elem() - } - switch typ.Kind() { - case reflect.Bool: - b := int64(0) - if val.Bool() { - b = 1 - } - _, err := bindBasic(db, s, ordinal, b) - return err - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: - var i int64 - if !val.IsZero() { - i = val.Int() - } - _, err := bindBasic(db, s, ordinal, i) - return err - case reflect.Uint, reflect.Uint64: - return fmt.Errorf("sqlitepool.bind:%d: sqlite does not support uint64 (try a string or TextMarshaler)", ordinal) - case reflect.Uint8, reflect.Uint16, reflect.Uint32: - _, err := bindBasic(db, s, ordinal, int64(val.Uint())) - return err - case reflect.Float32, reflect.Float64: - _, err := bindBasic(db, s, ordinal, val.Float()) - return err - case reflect.String: - _, err := bindBasic(db, s, ordinal, val.String()) - return err - } - - return fmt.Errorf("sqlitepool.bind:%d: unknown value type %T (try a string or TextMarshaler)", ordinal, v) -} - -// bindBasic, from the driver in sqlite.go. -func bindBasic(db sqliteh.DB, s sqliteh.Stmt, ordinal int, v any) (found bool, err error) { - defer func() { - if err != nil { - err = fmt.Errorf("sqlitepool.bind:%d:%T: %w: %v", ordinal, v, err, db.ErrMsg()) - } - }() - switch v := v.(type) { - case nil: - return true, s.BindNull(ordinal) - case string: - return true, s.BindText64(ordinal, v) - case int: - return true, s.BindInt64(ordinal, int64(v)) - case int64: - return true, s.BindInt64(ordinal, v) - case float64: - return true, s.BindDouble(ordinal, v) - case []byte: - if len(v) == 0 { - return true, s.BindZeroBlob64(ordinal, 0) - } else { - return true, s.BindBlob64(ordinal, v) - } - case time.Time: - // Shortest of: - // YYYY-MM-DD HH:MM - // YYYY-MM-DD HH:MM:SS - // YYYY-MM-DD HH:MM:SS.SSS - str := v.Format(timeFormat) - str = strings.TrimSuffix(str, "-0000") - str = strings.TrimSuffix(str, ".000") - str = strings.TrimSuffix(str, ":00") - return true, s.BindText64(ordinal, str) - default: - return false, nil - } -} - -// timeFormat from the driver in sqlite.go. -const timeFormat = "2006-01-02 15:04:05.000-0700" diff --git a/sqlitepool/queryglue_test.go b/sqlitepool/queryglue_test.go deleted file mode 100644 index e29f798..0000000 --- a/sqlitepool/queryglue_test.go +++ /dev/null @@ -1,101 +0,0 @@ -//go:build cgo - -package sqlitepool - -import ( - "context" - "database/sql" - "testing" - - "github.com/tailscale/sqlite/sqliteh" - "github.com/tailscale/sqlite/sqlstats" -) - -func TestQueryGlue(t *testing.T) { - ctx := context.Background() - initFn := func(db sqliteh.DB) error { return ExecScript(db, "PRAGMA synchronous=OFF;") } - tracer := &sqlstats.Tracer{} - tempDir := t.TempDir() - p, err := NewPool("file:"+tempDir+"/sqlitepool_queryglue_test", 2, initFn, tracer) - if err != nil { - t.Fatal(err) - } - - tx, err := p.BeginTx(ctx, "insert-1") - if err != nil { - t.Fatal(err) - } - if err := tx.Exec("CREATE TABLE t (id INTEGER PRIMARY KEY, val TEXT)"); err != nil { - t.Fatal(err) - } - if err := Exec(tx.DB(), "INSERT INTO t VALUES (?, ?)", 10, "skip"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 100, "a"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 200, "b"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 300, "c"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 400, "d"); err != nil { - t.Fatal(err) - } - if err := tx.Exec("INSERT INTO t VALUES (?, ?)", 401, "skip"); err != nil { - t.Fatal(err) - } - - var count int - if err := tx.QueryRow("SELECT count(*) FROM t WHERE id >= ? AND id <= ?", 100, 400).Scan(&count); err != nil { - t.Fatal(err) - } - if count != 4 { - t.Fatalf("count=%d, want 4", count) - } - if err := tx.QueryRow("SELECT id FROM t WHERE id >= ?", 900).Scan(&count); err != sql.ErrNoRows { - t.Fatalf("QueryRow err=%v, want ErrNoRows", err) - } - - rows, err := tx.Query("SELECT * FROM t WHERE id >= ? AND id <= ?", 100, 400) - if err != nil { - t.Fatal(err) - } - for i := 0; i < 4; i++ { - if !rows.Next() { - t.Fatalf("pass %d: Next=false", i) - } - var id int64 - var val string - if err := rows.Scan(&id, &val); err != nil { - t.Fatalf("pass %d: Scan: %v", i, err) - } - if want := int64(i+1) * 100; id != want { - t.Fatalf("pass %d: id=%d, want %d", i, id, want) - } - if want := string([]byte{'a' + byte(i)}); val != want { - t.Fatalf("pass %d: val=%q want %q", i, val, want) - } - } - if rows.Next() { - t.Fatal("too many rows") - } - if err := rows.Err(); err != nil { - t.Fatal(err) - } - if err := rows.Close(); err != nil { - t.Fatal(err) - } - - var concat sql.RawBytes - if err := tx.QueryRow("SELECT val FROM t WHERE id = 401").Scan(&concat); err != nil { - t.Fatal(err) - } - if got, want := string(concat), "skip"; got != want { - t.Fatalf("concat=%q, want %q", got, want) - } - - tx.Rollback() - p.Close() -} diff --git a/sqlitepool/sqlitepool.go b/sqlitepool/sqlitepool.go deleted file mode 100644 index d699be6..0000000 --- a/sqlitepool/sqlitepool.go +++ /dev/null @@ -1,323 +0,0 @@ -//go:build cgo - -// Package sqlitepool implements a pool of SQLite database connections. -package sqlitepool - -import ( - "context" - "errors" - "fmt" - "strings" - - "github.com/tailscale/sqlite/cgosqlite" - "github.com/tailscale/sqlite/sqliteh" -) - -// A Pool is a fixed-size pool of SQLite database connections. -// One is reserved for writable transactions, the others are -// used for read-only transactions. -type Pool struct { - poolSize int - rwConnFree chan *conn // cap == 1 - roConnsFree chan *conn // cap == poolSize-1 - tracer sqliteh.Tracer - closed chan struct{} -} - -type conn struct { - pool *Pool - db sqliteh.DB - stmts map[string]sqliteh.Stmt // persistent statements on db - id sqliteh.TraceConnID -} - -// NewPool creates a Pool of poolSize database connections. -// -// For each connection, initFn is called to initialize the connection. -// Tracer is used to report statistics about the use of the Pool. -func NewPool(filename string, poolSize int, initFn func(sqliteh.DB) error, tracer sqliteh.Tracer) (_ *Pool, err error) { - p := &Pool{ - poolSize: poolSize, - rwConnFree: make(chan *conn, 1), - roConnsFree: make(chan *conn, poolSize-1), - tracer: tracer, - closed: make(chan struct{}), - } - defer func() { - if err != nil { - err = fmt.Errorf("sqlitepool.NewPool: %w", err) - select { - case conn := <-p.rwConnFree: - conn.db.Close() - default: - } - close(p.roConnsFree) - for conn := range p.roConnsFree { - conn.db.Close() - } - } - }() - if poolSize < 2 { - return nil, fmt.Errorf("poolSize=%d is too small", poolSize) - } - for i := 0; i < poolSize; i++ { - db, err := cgosqlite.Open(filename, sqliteh.OpenFlagsDefault, "") - if err != nil { - return nil, err - } - if err := initFn(db); err != nil { - return nil, err - } - c := &conn{ - pool: p, - db: db, - stmts: make(map[string]sqliteh.Stmt), - id: sqliteh.TraceConnID(i), - } - if i == 0 { - p.rwConnFree <- c - } else { - if err := ExecScript(c.db, "PRAGMA query_only=true"); err != nil { - return nil, err - } - p.roConnsFree <- c - } - } - - return p, nil -} - -func (c *conn) close() error { - if c.db == nil { - return errors.New("sqlitepool conn already closed") - } - for _, stmt := range c.stmts { - stmt.Finalize() - } - c.stmts = nil - err := c.db.Close() - c.db = nil - return err -} - -func (p *Pool) Close() error { - select { - case <-p.closed: - return errors.New("pool already closed") - default: - } - close(p.closed) - - c := <-p.rwConnFree - err := c.close() - - for i := 0; i < p.poolSize-1; i++ { - c := <-p.roConnsFree - err2 := c.close() - if err == nil { - err = err2 - } - } - return err -} - -var errPoolClosed = fmt.Errorf("%w: sqlitepool closed", context.Canceled) - -// BeginTx creates a writable transaction using BEGIN IMMEDIATE. -// The parameter why is passed to the Tracer for debugging. -func (p *Pool) BeginTx(ctx context.Context, why string) (*Tx, error) { - select { - case <-p.closed: - return nil, errPoolClosed - case <-ctx.Done(): - return nil, ctx.Err() - case conn := <-p.rwConnFree: - tx := &Tx{Rx: &Rx{conn: conn, inTx: true}} - err := tx.Exec("BEGIN IMMEDIATE;") - if p.tracer != nil { - p.tracer.BeginTx(ctx, conn.id, why, false, err) - } - if err != nil { - p.rwConnFree <- conn // can't block, buffer is big enough - return nil, err - } - return tx, nil - } -} - -// BeginRx creates a read-only transaction. -// The parameter why is passed to the Tracer for debugging. -func (p *Pool) BeginRx(ctx context.Context, why string) (*Rx, error) { - select { - case <-p.closed: - return nil, errPoolClosed - case <-ctx.Done(): - return nil, ctx.Err() - case conn := <-p.roConnsFree: - rx := &Rx{conn: conn} - err := rx.Exec("BEGIN;") - if p.tracer != nil { - p.tracer.BeginTx(ctx, conn.id, why, true, err) - } - if err != nil { - p.roConnsFree <- conn // can't block, buffer is big enough - return nil, err - } - return &Rx{conn: conn}, nil - } -} - -// Rx is a read-only transaction. -// -// It is *not* safe for concurrent use. -type Rx struct { - conn *conn - inTx bool // true if this Rx is embedded in a writable Tx - - // OnRollback is an optional function called after rollback. - // If Rx is part of a Tx and it is committed, then OnRollback - // is not called. - OnRollback func() -} - -// Exec executes an SQL statement with no result. -func (rx *Rx) Exec(sql string) error { - _, _, _, _, err := rx.Prepare(sql).StepResult() - if err != nil { - return fmt.Errorf("%w: %v", err, rx.conn.db.ErrMsg()) - } - return nil -} - -// Prepare prepares an SQL statement. -// The Stmt is cached on the connection, so subsequent calls are fast. -func (rx *Rx) Prepare(sql string) sqliteh.Stmt { - stmt := rx.conn.stmts[sql] - if stmt != nil { - return stmt - } - stmt, _, err := rx.conn.db.Prepare(sql, sqliteh.SQLITE_PREPARE_PERSISTENT) - if err != nil { - // Persistent statements are constant strings hardcoded into - // programs. Failing to prepare one means the string is bad. - // Ideally we would detect this at compile time, but barring - // that, there is no point returning the error because this - // is not something the program can recover from or handle. - panic(fmt.Sprintf("%v: %v", err, rx.conn.db.ErrMsg())) - } - rx.conn.stmts[sql] = stmt - return stmt -} - -// DB returns the underlying database connection. -// -// Be careful: a transaction is in progress. Any use of BEGIN/COMMIT/ROLLBACK -// should be modelled as a nested transaction, and when done the original -// outer transaction should be left in-progress. -func (rx *Rx) DB() sqliteh.DB { - return rx.conn.db -} - -// ExecScript executes a series of SQL statements against a database connection. -// It is intended for one-off scripts, so the prepared Stmt objects are not -// cached for future calls. -func ExecScript(db sqliteh.DB, queries string) error { - for { - queries = strings.TrimSpace(queries) - if queries == "" { - return nil - } - stmt, rem, err := db.Prepare(queries, 0) - if err != nil { - return fmt.Errorf("ExecScript: %w: %v, in remaining script: %s", err, db.ErrMsg(), queries) - } - queries = rem - _, err = stmt.Step(nil) - if err != nil { - err = fmt.Errorf("ExecScript: %w: %s: %v", err, stmt.SQL(), db.ErrMsg()) - } - stmt.Finalize() - if err != nil { - return err - } - } -} - -// Rollback executes ROLLBACK and cleans up the Rx. -// It is a no-op if Rx is already rolled back. -func (rx *Rx) Rollback() { - if rx.conn == nil { - return - } - if rx.inTx { - panic("Tx.Rx.Rollback called, only call Rollback on the Tx object") - } - err := rx.Exec("ROLLBACK;") - if rx.conn.pool.tracer != nil { - rx.conn.pool.tracer.Rollback(rx.conn.id, err) - } - rx.conn.pool.roConnsFree <- rx.conn - rx.conn = nil - if rx.OnRollback != nil { - rx.OnRollback() - rx.OnRollback = nil - } - if err != nil { - panic(err) - } -} - -// Tx is a writable SQLite database transaction. -// -// It is *not* safe for concurrent use. -// -// A Tx contains an embedded Rx, which can be used to pass to functions -// that want to perform read-only queries on the writable Tx. -type Tx struct { - *Rx - - // OnCommit is an optional function called after successful commit. - OnCommit func() -} - -// Rollback executes ROLLBACK and cleans up the Tx. -// It is a no-op if the Tx is already rolled back or committed. -func (tx *Tx) Rollback() { - if tx.conn == nil { - return - } - err := tx.Exec("ROLLBACK;") - if tx.conn.pool.tracer != nil { - tx.conn.pool.tracer.Rollback(tx.conn.id, err) - } - tx.conn.pool.rwConnFree <- tx.conn - tx.conn = nil - if tx.OnRollback != nil { - tx.OnRollback() - tx.OnRollback = nil - tx.OnCommit = nil - } - if err != nil { - panic(err) - } -} - -// Commit executes COMMIT and cleans up the Tx. -// It is an error to call if the Tx is already rolled back or committed. -func (tx *Tx) Commit() error { - if tx.conn == nil { - return errors.New("tx already done") - } - err := tx.Exec("COMMIT;") - if tx.conn.pool.tracer != nil { - tx.conn.pool.tracer.Commit(tx.conn.id, err) - } - tx.conn.pool.rwConnFree <- tx.conn - tx.conn = nil - if tx.OnCommit != nil { - tx.OnCommit() - tx.OnCommit = nil - tx.OnRollback = nil - } - return err -} diff --git a/sqlitepool/sqlitepool_test.go b/sqlitepool/sqlitepool_test.go deleted file mode 100644 index f476d66..0000000 --- a/sqlitepool/sqlitepool_test.go +++ /dev/null @@ -1,178 +0,0 @@ -//go:build cgo - -package sqlitepool - -import ( - "context" - "errors" - "testing" - - "github.com/tailscale/sqlite/sqliteh" - "github.com/tailscale/sqlite/sqlstats" -) - -func TestPool(t *testing.T) { - ctx := context.Background() - initFn := func(db sqliteh.DB) error { - err := ExecScript(db, ` - PRAGMA synchronous=OFF; - PRAGMA journal_mode=WAL; - `) - return err - } - tracer := &sqlstats.Tracer{} - tempDir := t.TempDir() - p, err := NewPool("file:"+tempDir+"/sqlitepool_test", 3, initFn, tracer) - if err != nil { - t.Fatal(err) - } - - tx, err := p.BeginTx(ctx, "insert-1") - if err != nil { - t.Fatal(err) - } - if err := tx.Exec("CREATE TABLE t (c);"); err != nil { - t.Fatal(err) - } - stmt := tx.Prepare("INSERT INTO t (c) VALUES (?);") - stmt.BindInt64(1, 1) - if _, _, _, _, err := stmt.StepResult(); err != nil { - t.Fatal(err) - } - var onCommitCalled, onRollbackCalled bool - tx.OnCommit = func() { onCommitCalled = true } - tx.OnRollback = func() { onRollbackCalled = true } - if err := tx.Commit(); err != nil { - t.Fatal(err) - } - tx.Rollback() // no-op, does not call OnRollback - if !onCommitCalled { - t.Fatal("onCommit not called") - } - if onRollbackCalled { - t.Fatal("onRollback called") - } - if err := tx.Commit(); err == nil { - t.Fatalf("want error on second commit, got: %v", err) - } - - tx, err = p.BeginTx(ctx, "insert-2") - if err != nil { - t.Fatal(err) - } - stmt2 := tx.Prepare("INSERT INTO t (c) VALUES (?);") - if stmt != stmt2 { - t.Fatalf("second call to prepare returned a different stmt: %p vs. %p", stmt, stmt2) - } - stmt = stmt2 - stmt.BindInt64(1, 2) - if _, _, _, _, err := stmt.StepResult(); err != nil { - t.Fatal(err) - } - func() { - defer func() { - const want = `SQLITE_ERROR: near "INVALID": syntax error` - if r := recover(); r == nil { - t.Fatal("no panic from invalid prepare") - } else if r != want { - t.Fatalf("invalid sql recover: %q, want %q", r, want) - } - }() - tx.Prepare("INVALID SQL") - }() - onCommitCalled = false - onRollbackCalled = false - tx.OnCommit = func() { onCommitCalled = true } - tx.OnRollback = func() { onRollbackCalled = true } - tx.Rollback() - if onCommitCalled { - t.Fatal("onCommit called") - } - if !onRollbackCalled { - t.Fatal("onRollback not called") - } - if err := tx.Commit(); err == nil { - t.Fatalf("want error on commit after rollback, got: %v", err) - } - tx.Rollback() // no-op - - rx1, err := p.BeginRx(ctx, "read-1") - if err != nil { - t.Fatal(err) - } - defer rx1.Rollback() - rx2, err := p.BeginRx(ctx, "read-2") - if err != nil { - t.Fatal(err) - } - defer rx2.Rollback() - - ctxCancel, cancel := context.WithCancel(ctx) - rx3Err := make(chan error, 1) - go func() { - rx3, err := p.BeginRx(ctxCancel, "read-3") - if err != nil { - rx3Err <- err - return - } - rx3.Rollback() - rx3Err <- errors.New("BeginRx(read-3) did not fail") - }() - cancel() - if err := <-rx3Err; err != context.Canceled { - t.Fatalf("read-3, not context canceled: %v", err) - } - - stmt = rx1.Prepare("SELECT count(*) FROM t") - if row, err := stmt.Step(nil); err != nil { - t.Fatal(err) - } else if !row { - t.Fatal("no row from select count") - } - if got, want := int(stmt.ColumnInt64(0)), 1; got != want { - t.Fatalf("got=%d, want %d", got, want) - } - rx1.Rollback() - rx1.Rollback() // no-op - - rx1, err = p.BeginRx(ctx, "read-1") // now another rx is available - if err != nil { - t.Fatal(err) - } - rx1.Rollback() - rx2.Rollback() - - tx, err = p.BeginTx(ctx, "insert-3") - if err != nil { - t.Fatal(err) - } - if err := ExecScript(tx.DB(), "PRAGMA user_version=5"); err != nil { - t.Fatal(err) - } - func() { - defer func() { - if r := recover(); r != "Tx.Rx.Rollback called, only call Rollback on the Tx object" { - t.Fatalf("expected panic from Tx.Rx.Rollback, got: %q", r) - } - }() - tx.Rx.Rollback() - }() - if err := tx.Commit(); err != nil { - t.Fatal(err) - } - if err := tx.Commit(); err == nil { - t.Fatalf("second commit did not fail, want 'already done'") - } - - if err := p.Close(); err != nil { - t.Fatal(err) - } - p.Close() // no-op - - if _, err := p.BeginTx(ctx, "after-close"); err == nil { - t.Fatal("tx-after-close did not fail") - } - if _, err := p.BeginRx(ctx, "after-close"); err == nil { - t.Fatal("rx-after-close did not fail") - } -} diff --git a/sqlitepool/util.go b/sqlitepool/util.go deleted file mode 100644 index d17c370..0000000 --- a/sqlitepool/util.go +++ /dev/null @@ -1,88 +0,0 @@ -//go:build cgo - -package sqlitepool - -import ( - "fmt" - "strings" - - "github.com/tailscale/sqlite/sqliteh" -) - -// CopyAll copies the contents of one database to another. -// -// Traditionally this is done in sqlite by closing the database and copying -// the file. However it can be useful to do it online: a single exclusive -// transaction can cross multiple databases, and if multiple processes are -// using a file, this lets one replace the database without first -// communicating with the other processes, asking them to close the DB first. -// -// The dstSchemaName and srcSchemaName parameters follow the SQLite PRAMGA -// schema-name conventions: https://sqlite.org/pragma.html#syntax -func CopyAll(db sqliteh.DB, dstSchemaName, srcSchemaName string) (err error) { - defer func() { - if err != nil { - err = fmt.Errorf("sqlitepool.CopyAll: %w", err) - } - }() - if dstSchemaName == "" { - dstSchemaName = "main" - } - if srcSchemaName == "" { - srcSchemaName = "main" - } - if dstSchemaName == srcSchemaName { - return fmt.Errorf("source matches destination: %q", srcSchemaName) - } - // Filter on sql to avoid auto indexes. - // See https://www.sqlite.org/schematab.html for sqlite_schema docs. - rows, err := Query(db, fmt.Sprintf("SELECT name, type, sql FROM %q.sqlite_schema WHERE sql != ''", srcSchemaName)) - if err != nil { - return err - } - defer rows.Close() - for rows.Next() { - var name, sqlType, sqlText string - if err := rows.Scan(&name, &sqlType, &sqlText); err != nil { - return err - } - // Regardless of the case or whitespace used in the original - // create statement (or whether or not "if not exists" is used), - // the SQL text in the sqlite_schema table always reads: - // "CREATE (TABLE|VIEW|INDEX|TRIGGER) name". - // We take advantage of that here to rewrite the create - // statement for a different schema. - switch sqlType { - case "index": - sqlText = strings.TrimPrefix(sqlText, "CREATE INDEX ") - sqlText = fmt.Sprintf("CREATE INDEX %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - case "table": - sqlText = strings.TrimPrefix(sqlText, "CREATE TABLE ") - sqlText = fmt.Sprintf("CREATE TABLE %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - if err := ExecScript(db, fmt.Sprintf("INSERT INTO %q.%q SELECT * FROM %q.%q;", dstSchemaName, name, srcSchemaName, name)); err != nil { - return err - } - case "trigger": - sqlText = strings.TrimPrefix(sqlText, "CREATE TRIGGER ") - sqlText = fmt.Sprintf("CREATE TRIGGER %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - case "view": - sqlText = strings.TrimPrefix(sqlText, "CREATE VIEW ") - sqlText = fmt.Sprintf("CREATE VIEW %q.%s", dstSchemaName, sqlText) - if err := ExecScript(db, sqlText); err != nil { - return err - } - default: - return fmt.Errorf("unknown sqlite schema type %q for %q", sqlType, name) - } - } - return rows.Err() -}