diff --git a/Makefile b/Makefile index af6e84220..a41df6e7b 100644 --- a/Makefile +++ b/Makefile @@ -47,7 +47,7 @@ cmd_%: go build $(BUILDFLAGS) -o $(OUTPUT) $(SOURCE) golangci-lint: - GOBIN=$(GOBIN) go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.61.0 + GOBIN=$(GOBIN) GOTOOLCHAIN=go1.25.8 go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.64.8 go-header: GOBIN=$(GOBIN) go install github.com/denis-tingaikin/go-header/cmd/go-header@latest diff --git a/docker/Dockerfile b/docker/Dockerfile index 395c052fa..f28ba35db 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -1,6 +1,6 @@ -FROM alpine:edge as builder +FROM golang:1.25.8-alpine AS builder -RUN apk add --no-cache --progress git make go +RUN apk add --no-cache --progress git make ARG VERSION ARG BRANCH ARG COMMIT diff --git a/go.mod b/go.mod index 0d2ed496d..51531e51f 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/pingcap/tiproxy -go 1.21 +go 1.25.8 require ( github.com/BurntSushi/toml v1.2.1 diff --git a/lib/cli/util.go b/lib/cli/util.go index cf04682a0..e3a1390a6 100644 --- a/lib/cli/util.go +++ b/lib/cli/util.go @@ -47,9 +47,10 @@ func doRequest(ctx context.Context, bctx *Context, method string, url string, rd res, err := bctx.Client.Do(req) if err != nil { if errors.Is(err, io.EOF) { - if req.URL.Scheme == "https" { + switch req.URL.Scheme { + case "https": req.URL.Scheme = "http" - } else if req.URL.Scheme == "http" { + case "http": req.URL.Scheme = "https" } // probably server did not enable TLS, try again with plain http diff --git a/lib/config/proxy.go b/lib/config/proxy.go index d45264a3e..4e49372e7 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -102,7 +102,7 @@ type TLSConfig struct { } func (c TLSConfig) HasCert() bool { - return !(c.Cert == "" && c.Key == "") + return c.Cert != "" || c.Key != "" } func (c TLSConfig) HasCA() bool { diff --git a/lib/go.mod b/lib/go.mod index 9f68fe9a4..ef9671af8 100644 --- a/lib/go.mod +++ b/lib/go.mod @@ -1,6 +1,6 @@ module github.com/pingcap/tiproxy/lib -go 1.21 +go 1.25.8 require ( github.com/cenkalti/backoff/v4 v4.2.1 diff --git a/lib/util/security/cert.go b/lib/util/security/cert.go index 30146fb5f..c14ca286c 100644 --- a/lib/util/security/cert.go +++ b/lib/util/security/cert.go @@ -41,10 +41,10 @@ func NewCert(server bool) *CertInfo { func (ci *CertInfo) Reload(lg *zap.Logger) (tlsConfig *tls.Config, err error) { // Some methods to rotate server config: // - For certs: customize GetCertificate / GetConfigForClient. - // - For CA: customize ClientAuth + VerifyPeerCertificate / GetConfigForClient + // - For CA: customize ClientAuth + VerifyConnection / GetConfigForClient // Some methods to rotate client config: // - For certs: customize GetClientCertificate - // - For CA: customize InsecureSkipVerify + VerifyPeerCertificate + // - For CA: customize InsecureSkipVerify + VerifyConnection prevExpireTime := ci.getExpireTime() if ci.server { lg = lg.With(zap.String("tls", "server"), zap.Any("cfg", ci.cfg.Load())) @@ -82,20 +82,11 @@ func (ci *CertInfo) getClientCert(*tls.CertificateRequestInfo) (*tls.Certificate return cert, nil } -func (ci *CertInfo) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certificate) error { - if len(rawCerts) == 0 { +func (ci *CertInfo) verifyPeerCertificates(certs []*x509.Certificate) error { + if len(certs) == 0 { return nil } - certs := make([]*x509.Certificate, len(rawCerts)) - for i, asn1Data := range rawCerts { - cert, err := x509.ParseCertificate(asn1Data) - if err != nil { - return errors.New("tls: failed to parse certificate from server: " + err.Error()) - } - certs[i] = cert - } - cas := ci.ca.Load() if cas == nil { cas = x509.NewCertPool() @@ -120,6 +111,10 @@ func (ci *CertInfo) verifyPeerCertificate(rawCerts [][]byte, _ [][]*x509.Certifi return err } +func (ci *CertInfo) verifyConnection(cs tls.ConnectionState) error { + return ci.verifyPeerCertificates(cs.PeerCertificates) +} + func (ci *CertInfo) loadCA(pemCerts []byte) (*x509.CertPool, error) { pool := x509.NewCertPool() for len(pemCerts) > 0 { @@ -155,10 +150,10 @@ func (ci *CertInfo) buildServerConfig(lg *zap.Logger) (*tls.Config, error) { } tcfg := &tls.Config{ - MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg), - GetCertificate: ci.getCert, - GetClientCertificate: ci.getClientCert, - VerifyPeerCertificate: ci.verifyPeerCertificate, + MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg), + GetCertificate: ci.getCert, + GetClientCertificate: ci.getClientCert, + VerifyConnection: ci.verifyConnection, } var certPEM, keyPEM []byte @@ -239,11 +234,11 @@ func (ci *CertInfo) buildClientConfig(lg *zap.Logger) (*tls.Config, error) { } tcfg := &tls.Config{ - MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg), - GetCertificate: ci.getCert, - GetClientCertificate: ci.getClientCert, - InsecureSkipVerify: true, - VerifyPeerCertificate: ci.verifyPeerCertificate, + MinVersion: GetMinTLSVer(cfg.MinTLSVersion, lg), + GetCertificate: ci.getCert, + GetClientCertificate: ci.getClientCert, + InsecureSkipVerify: true, + VerifyConnection: ci.verifyConnection, } certBytes, err := os.ReadFile(cfg.CA) diff --git a/pkg/balance/metricsreader/backend_reader_test.go b/pkg/balance/metricsreader/backend_reader_test.go index 93f402fec..e4eed6b24 100644 --- a/pkg/balance/metricsreader/backend_reader_test.go +++ b/pkg/balance/metricsreader/backend_reader_test.go @@ -832,7 +832,7 @@ func TestQueryBackendConcurrently(t *testing.T) { const initialRules, initialBackends = 3, 3 var buf strings.Builder for i := 0; i < initialRules+1; i++ { - buf.WriteString(fmt.Sprintf("name%d 100\n", i)) + _, _ = fmt.Fprintf(&buf, "name%d 100\n", i) } resp := buf.String() diff --git a/pkg/balance/observer/backend_health.go b/pkg/balance/observer/backend_health.go index 6091f7195..8434a4fbd 100644 --- a/pkg/balance/observer/backend_health.go +++ b/pkg/balance/observer/backend_health.go @@ -53,14 +53,14 @@ func (bh *BackendHealth) String() string { _, _ = sb.WriteString("down") } if bh.PingErr != nil { - _, _ = sb.WriteString(fmt.Sprintf(", err: %s", bh.PingErr.Error())) + _, _ = fmt.Fprintf(&sb, ", err: %s", bh.PingErr) } if len(bh.ServerVersion) > 0 { _, _ = sb.WriteString(", version: ") _, _ = sb.WriteString(bh.ServerVersion) } if bh.Labels != nil { - _, _ = sb.WriteString(fmt.Sprintf(", labels: %v", bh.Labels)) + _, _ = fmt.Fprintf(&sb, ", labels: %v", bh.Labels) } return sb.String() } diff --git a/pkg/balance/observer/health_check_test.go b/pkg/balance/observer/health_check_test.go index 905eab5f4..6a4626eb2 100644 --- a/pkg/balance/observer/health_check_test.go +++ b/pkg/balance/observer/health_check_test.go @@ -35,7 +35,7 @@ func TestReadServerVersion(t *testing.T) { backend.stopSQLServer() //test for respBody not ok - backend.mockHttpHandler.setHTTPRespBody("") + backend.setHTTPRespBody("") backend.startSQLServer() health = hc.Check(context.Background(), backend.sqlAddr, info) require.False(t, health.Healthy) @@ -120,7 +120,7 @@ func (srv *backendServer) setServerVersion(version string) { GitHash: "", } body, _ := json.Marshal(resp) - srv.mockHttpHandler.setHTTPRespBody(string(body)) + srv.setHTTPRespBody(string(body)) } func (srv *backendServer) startHTTPServer() { diff --git a/pkg/balance/router/router.go b/pkg/balance/router/router.go index e93c0c7a9..408cd1cce 100644 --- a/pkg/balance/router/router.go +++ b/pkg/balance/router/router.go @@ -159,7 +159,7 @@ func (b *backendWrapper) GetBackendInfo() observer.BackendInfo { func (b *backendWrapper) Equals(health observer.BackendHealth) bool { b.mu.RLock() - equal := b.mu.BackendHealth.Equals(health) + equal := b.mu.Equals(health) b.mu.RUnlock() return equal } diff --git a/pkg/manager/config/config_test.go b/pkg/manager/config/config_test.go index a7d92e477..e2a5c98ff 100644 --- a/pkg/manager/config/config_test.go +++ b/pkg/manager/config/config_test.go @@ -192,7 +192,7 @@ func TestFilePath(t *testing.T) { // For linux, it creates another file. For macOS, it doesn't touch the file. f, err = os.Create(filepath.Join(tmpdir, "cfg")) require.NoError(t, err) - _, err = f.WriteString(fmt.Sprintf("proxy.pd-addrs = \"%s\"", pdAddr1)) + _, err = fmt.Fprintf(f, "proxy.pd-addrs = \"%s\"", pdAddr1) require.NoError(t, err) require.NoError(t, f.Close()) }, @@ -219,7 +219,7 @@ func TestFilePath(t *testing.T) { } f, err := os.Create("_tmp/cfg") require.NoError(t, err) - _, err = f.WriteString(fmt.Sprintf("proxy.pd-addrs = \"%s\"", pdAddr1)) + _, err = fmt.Fprintf(f, "proxy.pd-addrs = \"%s\"", pdAddr1) require.NoError(t, err) require.NoError(t, f.Close()) }, @@ -233,7 +233,7 @@ func TestFilePath(t *testing.T) { require.NoError(t, os.Mkdir("_tmp", 0755)) f, err := os.Create("_tmp/cfg") require.NoError(t, err) - _, err = f.WriteString(fmt.Sprintf("proxy.pd-addrs = \"%s\"", pdAddr3)) + _, err = fmt.Fprintf(f, "proxy.pd-addrs = \"%s\"", pdAddr3) require.NoError(t, err) require.NoError(t, f.Close()) t.Log("write _tmp") @@ -250,7 +250,7 @@ func TestFilePath(t *testing.T) { f, err := os.Create(filename) require.NoError(t, err) - _, err = f.WriteString(fmt.Sprintf("proxy.pd-addrs = \"%s\"", pdAddr3)) + _, err = fmt.Fprintf(f, "proxy.pd-addrs = \"%s\"", pdAddr3) require.NoError(t, err) require.NoError(t, f.Close()) require.Eventually(t, func() bool { @@ -267,7 +267,7 @@ func TestFilePath(t *testing.T) { } else { f, err := os.Create(test.filename) require.NoError(t, err) - _, err = f.WriteString(fmt.Sprintf("proxy.pd-addrs = \"%s\"", pdAddr1)) + _, err = fmt.Fprintf(f, "proxy.pd-addrs = \"%s\"", pdAddr1) require.NoError(t, err) require.NoError(t, f.Close()) } diff --git a/pkg/manager/elect/election.go b/pkg/manager/elect/election.go index 189a7086b..66d13cffd 100644 --- a/pkg/manager/elect/election.go +++ b/pkg/manager/elect/election.go @@ -199,7 +199,7 @@ func (m *election) onRetired() { m.member.OnRetired() m.isOwner = false // Delete the metric so that it doesn't show on Grafana. - metrics.OwnerGauge.MetricVec.DeletePartialMatch(map[string]string{metrics.LblType: m.trimedKey}) + metrics.OwnerGauge.DeletePartialMatch(map[string]string{metrics.LblType: m.trimedKey}) } // waitRetire retires after another member becomes the owner so that there will always be an owner. diff --git a/pkg/manager/vip/manager_test.go b/pkg/manager/vip/manager_test.go index 72e5f3555..0774e00f0 100644 --- a/pkg/manager/vip/manager_test.go +++ b/pkg/manager/vip/manager_test.go @@ -208,7 +208,8 @@ func TestMultiVIP(t *testing.T) { require.NoError(t, err) require.Eventually(t, func() bool { return strings.Count(text.String(), "adding VIP success") >= 2 || - strings.Count(text.String(), "ip: command not found") >= 2 + strings.Count(text.String(), "ip: command not found") >= 2 || + strings.Count(text.String(), "executable file not found") >= 2 }, 3*time.Second, 10*time.Millisecond) vm1.PreClose() vm2.PreClose() diff --git a/pkg/manager/vip/network_test.go b/pkg/manager/vip/network_test.go index cc6586b04..2b00420b4 100644 --- a/pkg/manager/vip/network_test.go +++ b/pkg/manager/vip/network_test.go @@ -45,7 +45,9 @@ func TestAddDelIP(t *testing.T) { } isOtherErr := func(err error) bool { - return strings.Contains(err.Error(), "command not found") || strings.Contains(err.Error(), "not in the sudoers file") + return strings.Contains(err.Error(), "command not found") || + strings.Contains(err.Error(), "not in the sudoers file") || + strings.Contains(err.Error(), "executable file not found") } for i, test := range tests { diff --git a/pkg/proxy/backend/authenticator_test.go b/pkg/proxy/backend/authenticator_test.go index 679dc3862..243589fab 100644 --- a/pkg/proxy/backend/authenticator_test.go +++ b/pkg/proxy/backend/authenticator_test.go @@ -68,11 +68,11 @@ func TestUnsupportedCapability(t *testing.T) { for _, cfgs := range cfgOverriders { ts, clean := newTestSuite(t, tc, cfgs...) ts.authenticateFirstTime(t, func(t *testing.T, _ *testSuite) { - if ts.mc.clientConfig.capability&requiredFrontendCaps != requiredFrontendCaps { + if ts.mc.capability&requiredFrontendCaps != requiredFrontendCaps { require.ErrorIs(t, ts.mp.err, ErrClientCap) require.Nil(t, ErrToClient(ts.mp.err)) require.Equal(t, SrcClientHandshake, Error2Source(ts.mp.err)) - } else if ts.mb.backendConfig.capability&defRequiredBackendCaps != defRequiredBackendCaps { + } else if ts.mb.capability&defRequiredBackendCaps != defRequiredBackendCaps { require.ErrorIs(t, ts.mp.err, ErrBackendCap) require.Equal(t, ErrBackendCap, ErrToClient(ts.mp.err)) require.Equal(t, SrcBackendHandshake, Error2Source(ts.mp.err)) @@ -583,8 +583,8 @@ func TestUpgradeBackendCap(t *testing.T) { require.Equal(t, pnet.Capability(0), ts.mb.capability&pnet.ClientCompress) }) // After upgrade, the backend also supports compression. - ts.mb.backendConfig.capability |= pnet.ClientCompress - ts.mb.backendConfig.capability |= pnet.ClientZstdCompressionAlgorithm + ts.mb.capability |= pnet.ClientCompress + ts.mb.capability |= pnet.ClientZstdCompressionAlgorithm ts.authenticateSecondTime(t, func(t *testing.T, ts *testSuite) { require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mc.capability&pnet.ClientCompress) require.Equal(t, referCfg.clientConfig.capability&pnet.ClientCompress, ts.mp.authenticator.capability&pnet.ClientCompress) diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index fd62199d6..2b4addbf3 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -227,7 +227,9 @@ func (mgr *BackendConnManager) Connect(ctx context.Context, clientIO pnet.Packet mgr.cmdProcessor.capability = mgr.authenticator.capability childCtx, cancelFunc := context.WithCancel(ctx) - mgr.cancelFunc = cancelFunc + mgr.cancelFunc = func() { + cancelFunc() + } mgr.lastActiveTime = endTime if mgr.cpt != nil && !reflect.ValueOf(mgr.cpt).IsNil() { mgr.cpt.InitConn(endTime, mgr.connectionID, mgr.authenticator.dbname) diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index f89eee344..004a3e359 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -1209,7 +1209,7 @@ func TestCloseWhileConnect(t *testing.T) { client: ts.mc.authenticate, proxy: func(clientIO, backendIO pnet.PacketIO) error { go func() { - require.NoError(ts.t, ts.mp.BackendConnManager.Close()) + require.NoError(ts.t, ts.mp.Close()) }() err := ts.mp.Connect(context.Background(), clientIO, ts.mp.frontendTLSConfig, ts.mp.backendTLSConfig, "", "") if err == nil { @@ -1244,7 +1244,7 @@ func TestCloseWhileExecute(t *testing.T) { return err } go func() { - require.NoError(ts.t, ts.mp.BackendConnManager.Close()) + require.NoError(ts.t, ts.mp.Close()) }() return ts.mp.ExecuteCmd(context.Background(), request) }, @@ -1268,7 +1268,7 @@ func TestCloseWhileGracefulClose(t *testing.T) { { proxy: func(clientIO, backendIO pnet.PacketIO) error { go func() { - require.NoError(ts.t, ts.mp.BackendConnManager.Close()) + require.NoError(ts.t, ts.mp.Close()) }() ts.mp.GracefulClose() return nil diff --git a/pkg/proxy/backend/testsuite_test.go b/pkg/proxy/backend/testsuite_test.go index d5d8e3888..87bb791e3 100644 --- a/pkg/proxy/backend/testsuite_test.go +++ b/pkg/proxy/backend/testsuite_test.go @@ -199,7 +199,7 @@ func (ts *testSuite) authenticateFirstTime(t *testing.T, c checker) { // The proxy reconnects to the proxy using preserved client data. // This must be called after authenticateFirstTime. func (ts *testSuite) authenticateSecondTime(t *testing.T, c checker) { - ts.mb.backendConfig.authSucceed = true + ts.mb.authSucceed = true ts.tc.reconnectBackend(t) ts.runAndCheck(t, c, nil, ts.mb.authenticate, ts.mp.authenticateSecondTime) if c == nil { diff --git a/pkg/proxy/net/compress.go b/pkg/proxy/net/compress.go index 63ef43173..5d46b1084 100644 --- a/pkg/proxy/net/compress.go +++ b/pkg/proxy/net/compress.go @@ -89,7 +89,7 @@ func (crw *compressedReadWriter) ResetSequence() { // the client/server begins reading or writing. func (crw *compressedReadWriter) BeginRW(status rwStatus) { if crw.rwStatus != status { - crw.packetReadWriter.SetSequence(crw.sequence) + crw.SetSequence(crw.sequence) crw.rwStatus = status } } diff --git a/pkg/proxy/net/packetio.go b/pkg/proxy/net/packetio.go index 8e27b2755..64104724d 100644 --- a/pkg/proxy/net/packetio.go +++ b/pkg/proxy/net/packetio.go @@ -162,14 +162,14 @@ func (brw *basicReadWriter) TLSConnectionState() tls.ConnectionState { // This function normally costs 1ms, so don't call it too frequently. // This function may incorrectly return true if the system is extremely slow. func (brw *basicReadWriter) IsPeerActive() bool { - if err := brw.Conn.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil { + if err := brw.SetReadDeadline(time.Now().Add(time.Millisecond)); err != nil { return false } active := true - if _, err := brw.ReadWriter.Peek(1); err != nil { + if _, err := brw.Peek(1); err != nil { active = !errors.Is(err, io.EOF) } - if err := brw.Conn.SetReadDeadline(time.Time{}); err != nil { + if err := brw.SetReadDeadline(time.Time{}); err != nil { return false } return active diff --git a/pkg/proxy/net/proxy.go b/pkg/proxy/net/proxy.go index c7b45de9b..33808363f 100644 --- a/pkg/proxy/net/proxy.go +++ b/pkg/proxy/net/proxy.go @@ -109,7 +109,7 @@ func (prw *proxyReadWriter) writeProxy() error { return errors.Wrap(err, ErrWriteConn) } // according to the spec, we better flush to avoid server hanging - if err := prw.packetReadWriter.Flush(); err != nil { + if err := prw.Flush(); err != nil { return err } prw.proxyInited.Store(true) @@ -133,7 +133,7 @@ func (prw *proxyReadWriter) parseProxyV2() (*proxyprotocol.Proxy, error) { } // yes, it is proxyV2 - _, err = prw.packetReadWriter.Discard(len(proxyprotocol.MagicV2)) + _, err = prw.Discard(len(proxyprotocol.MagicV2)) if err != nil { return nil, errors.WithStack(errors.Wrap(err, ErrReadConn)) } diff --git a/pkg/proxy/net/tls.go b/pkg/proxy/net/tls.go index ff79d5706..b05da0e41 100644 --- a/pkg/proxy/net/tls.go +++ b/pkg/proxy/net/tls.go @@ -20,7 +20,7 @@ type tlsInternalConn struct { } func (br *tlsInternalConn) Write(p []byte) (n int, err error) { - return br.packetReadWriter.DirectWrite(p) + return br.DirectWrite(p) } func (p *packetIO) ServerTLSHandshake(tlsConfig *tls.Config) (tls.ConnectionState, error) { diff --git a/pkg/util/lex/lex.go b/pkg/util/lex/lex.go index c9a2b0acf..848d0824e 100644 --- a/pkg/util/lex/lex.go +++ b/pkg/util/lex/lex.go @@ -36,15 +36,17 @@ func (l *Lexer) NextToken() string { } } case inSingleQuote: - if char == '\\' { + switch char { + case '\\': l.curIdx++ - } else if char == '\'' { + case '\'': inSingleQuote = false } case inDoubleQuote: - if char == '\\' { + switch char { + case '\\': l.curIdx++ - } else if char == '"' { + case '"': inDoubleQuote = false } case char == '-' && l.curIdx+1 < len(l.sql) && l.sql[l.curIdx+1] == '-':