From 908681f887aca321cefee0ed3e8a235704110b23 Mon Sep 17 00:00:00 2001 From: Percy Wegmann Date: Tue, 17 Feb 2026 07:11:15 -0600 Subject: [PATCH] report error when logging statement to ConnLogger Updates tailscale/corp#37338 Signed-off-by: Percy Wegmann --- sqlite.go | 10 +++++--- sqlite_test.go | 62 ++++++++++++++++++++++++++++++++++---------------- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/sqlite.go b/sqlite.go index 7730d9a..a6b349a 100644 --- a/sqlite.go +++ b/sqlite.go @@ -554,7 +554,7 @@ func (s *stmt) Close() error { func (s *stmt) Exec(args []driver.Value) (driver.Result, error) { panic("deprecated, unused") } func (s *stmt) Query(args []driver.Value) (driver.Rows, error) { panic("deprecated, unused") } -func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { +func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (res driver.Result, errRet error) { if s.closed.Load() { UsesAfterClose.Add("stmt.ExecContext", 1) return nil, ErrClosed @@ -566,7 +566,10 @@ func (s *stmt) ExecContext(ctx context.Context, args []driver.NamedValue) (drive return nil, s.reserr("Stmt.Exec(Bind)", err) } if s.conn.logger != nil && !s.conn.readOnly { - s.conn.logger.Statement(s.stmt.ExpandedSQL()) + esql := s.stmt.ExpandedSQL() + defer func() { + s.conn.logger.Statement(esql, errRet) + }() } if ctx.Value(queryCancelKey{}) != nil { @@ -1196,7 +1199,8 @@ type ConnLogger interface { Begin() // Statement is called with evaluated SQL when a statement is executed. - Statement(sql string) + // err is the error (if any) resulting from executing the statement. + Statement(sql string, err error) // Commit is called after a commit statement, with the error resulting // from the attempted commit. diff --git a/sqlite_test.go b/sqlite_test.go index aa2b159..2a2bf45 100644 --- a/sqlite_test.go +++ b/sqlite_test.go @@ -1416,9 +1416,18 @@ func TestDisableFunction(t *testing.T) { } } +type statements struct { + succeeded []string + failed []string +} + +func (s statements) String() string { + return fmt.Sprintf("succeeded\n------------\n%s\n\nfailed\n------------\n%s", strings.Join(s.succeeded, "\n"), strings.Join(s.failed, "\n")) +} + type connLogger struct { - ch chan []string - statements []string + ch chan statements + statements statements panicOnUse bool } @@ -1426,14 +1435,18 @@ func (cl *connLogger) Begin() { if cl.panicOnUse { panic("unexpected connLogger.Begin()") } - cl.statements = nil + cl.statements = statements{} } -func (cl *connLogger) Statement(s string) { +func (cl *connLogger) Statement(s string, err error) { if cl.panicOnUse { panic("unexpected connLogger.Statement: " + s) } - cl.statements = append(cl.statements, s) + if err == nil { + cl.statements.succeeded = append(cl.statements.succeeded, s) + } else { + cl.statements.failed = append(cl.statements.failed, s) + } } func (cl *connLogger) Commit(err error) { @@ -1450,7 +1463,7 @@ func (cl *connLogger) Rollback() { if cl.panicOnUse { panic("unexpected connLogger.Rollback()") } - cl.statements = nil + cl.statements = statements{} } func TestConnLogger_writable(t *testing.T) { @@ -1461,7 +1474,7 @@ func TestConnLogger_writable(t *testing.T) { } t.Run(doneStatement, func(t *testing.T) { ctx := context.Background() - ch := make(chan []string, 1) + ch := make(chan statements, 1) txl := connLogger{ch: ch} makeLogger := func() ConnLogger { return &txl } db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger)) @@ -1471,7 +1484,7 @@ func TestConnLogger_writable(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := tx.Exec("CREATE TABLE T (x INTEGER)"); err != nil { + if _, err := tx.Exec("CREATE TABLE T (x INTEGER UNIQUE)"); err != nil { t.Fatal(err) } if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err != nil { @@ -1480,6 +1493,10 @@ func TestConnLogger_writable(t *testing.T) { if _, err := tx.Query("SELECT x FROM T"); err != nil { t.Fatal(err) } + // the below should fail because T already contains value 1 + if _, err := tx.Exec("INSERT INTO T VALUES (?)", 1); err == nil { + t.Fatal("unique constraint violation should have failed") + } done := tx.Rollback if commit { done = tx.Commit @@ -1490,22 +1507,27 @@ func TestConnLogger_writable(t *testing.T) { if !commit { select { case got := <-ch: - t.Errorf("unexpectedly logged statements for rollback:\n%s", strings.Join(got, "\n")) + t.Errorf("unexpectedly logged statements for rollback:\n%s", got) default: return } } - want := []string{ - "BEGIN IMMEDIATE", - "CREATE TABLE T (x INTEGER)", - "INSERT INTO T VALUES (1)", - doneStatement, + want := statements{ + succeeded: []string{ + "BEGIN IMMEDIATE", + "CREATE TABLE T (x INTEGER UNIQUE)", + "INSERT INTO T VALUES (1)", + doneStatement, + }, + failed: []string{ + "INSERT INTO T VALUES (1)", + }, } select { case got := <-ch: - if !slices.Equal(got, want) { - t.Errorf("unexpected log statements. got:\n%s\n\nwant:\n%s", strings.Join(got, "\n"), strings.Join(want, "\n")) + if !slices.Equal(got.succeeded, want.succeeded) || !slices.Equal(got.failed, want.failed) { + t.Errorf("unexpected log statements. got:\n%s\nwant:\n%s", got, want) } default: t.Fatal("no logged statements after commit") @@ -1516,7 +1538,7 @@ func TestConnLogger_writable(t *testing.T) { func TestConnLogger_commit_error(t *testing.T) { ctx := context.Background() - ch := make(chan []string, 1) + ch := make(chan statements, 1) txl := connLogger{ch: ch} makeLogger := func() ConnLogger { return &txl } db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger)) @@ -1544,7 +1566,7 @@ func TestConnLogger_commit_error(t *testing.T) { } select { case got := <-ch: - t.Errorf("unexpectedly logged statements for errored commit:\n%s", strings.Join(got, "\n")) + t.Errorf("unexpectedly logged statements for errored commit:\n%s", got) default: return } @@ -1552,7 +1574,7 @@ func TestConnLogger_commit_error(t *testing.T) { func TestConnLogger_read_tx(t *testing.T) { ctx := context.Background() - ch := make(chan []string, 1) + ch := make(chan statements, 1) txl := connLogger{ch: ch} makeLogger := func() ConnLogger { return &txl } db := sql.OpenDB(ConnectorWithLogger("file:"+t.TempDir()+"/test.db", nil, nil, makeLogger)) @@ -1573,7 +1595,7 @@ func TestConnLogger_read_tx(t *testing.T) { } select { case got := <-ch: - if len(got) == 0 { + if len(got.succeeded) == 0 { t.Errorf("expected logged statements for write tx") } default: