diff --git a/AGENTS.md b/AGENTS.md
index eb83b0350..dd553c8dc 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -47,6 +47,7 @@ When adding or modifying features, prefer extending existing packages before cre
 - `pkg/manager/config/` - Auto-reloads configuration files and provides interfaces to query them.
 - `pkg/manager/elect/` - Manages TiProxy owner elections (for example, metrics reader and VIP modules need an owner).
 - `pkg/manager/id/` - Generates global IDs.
+- `pkg/manager/backendcluster/` - Manages cluster-scoped backend runtimes and shared resources such as PD or etcd clients.
 - `pkg/manager/infosync/` - Queries the topology of TiDB and Prometheus from PD and updates TiProxy information to PD.
 - `pkg/manager/logger/` - Manages the logger service.
 - `pkg/manager/memory/` - Records heap and goroutine profiles when memory usage is high.
diff --git a/go.mod b/go.mod
index 421407d70..b257836fd 100644
--- a/go.mod
+++ b/go.mod
@@ -38,6 +38,7 @@ require (
 	go.uber.org/mock v0.5.2
 	go.uber.org/ratelimit v0.2.0
 	go.uber.org/zap v1.27.0
+	golang.org/x/net v0.48.0
 	google.golang.org/grpc v1.63.2
 )
@@ -272,7 +273,6 @@ require (
 	golang.org/x/crypto v0.47.0 // indirect
 	golang.org/x/exp v0.0.0-20240909161429-701f63a606c0 // indirect
 	golang.org/x/mod v0.31.0 // indirect
-	golang.org/x/net v0.48.0 // indirect
 	golang.org/x/oauth2 v0.30.0 // indirect
 	golang.org/x/sync v0.19.0 // indirect
 	golang.org/x/sys v0.40.0 // indirect
diff --git a/lib/config/label.go b/lib/config/label.go
index 2a7a858f0..b34615c2c 100644
--- a/lib/config/label.go
+++ b/lib/config/label.go
@@ -6,9 +6,10 @@ package config
 const (
 	// LocationLabelName indicates the label name that decides the location of TiProxy and backends.
 	// We use `zone` because the follower read in TiDB also uses `zone` to decide location.
-	LocationLabelName = "zone"
-	KeyspaceLabelName = "keyspace"
-	CidrLabelName     = "cidr"
+	LocationLabelName    = "zone"
+	KeyspaceLabelName    = "keyspace"
+	CidrLabelName        = "cidr"
+	TiProxyPortLabelName = "tiproxy-port"
 )
 
 func (cfg *Config) GetLocation() string {
diff --git a/lib/config/proxy.go b/lib/config/proxy.go
index 43fc0573e..c6c5987af 100644
--- a/lib/config/proxy.go
+++ b/lib/config/proxy.go
@@ -24,6 +24,8 @@ var (
 	ErrInvalidConfigValue = errors.New("invalid config value")
 )
 
+const DefaultBackendClusterName = "default"
+
 type Config struct {
 	Proxy ProxyServer `yaml:"proxy,omitempty" toml:"proxy,omitempty" json:"proxy,omitempty"`
 	API   API         `yaml:"api,omitempty" toml:"api,omitempty" json:"api,omitempty"`
@@ -249,7 +251,7 @@ func (cfg *Config) GetBackendClusters() []BackendCluster {
 		return nil
 	}
 	return []BackendCluster{{
-		Name:    "default",
+		Name:    DefaultBackendClusterName,
 		PDAddrs: cfg.Proxy.PDAddrs,
 	}}
 }
diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go
index a8824643a..fdd41cc98 100644
--- a/lib/config/proxy_test.go
+++ b/lib/config/proxy_test.go
@@ -314,7 +314,7 @@ func TestGetBackendClusters(t *testing.T) {
 	clusters := cfg.GetBackendClusters()
 	require.Len(t, clusters, 1)
-	require.Equal(t, "default", clusters[0].Name)
+	require.Equal(t, DefaultBackendClusterName, clusters[0].Name)
 	require.Equal(t, cfg.Proxy.PDAddrs, clusters[0].PDAddrs)
 
 	cfg.Proxy.BackendClusters = []BackendCluster{
diff --git a/pkg/balance/factor/factor_balance.go b/pkg/balance/factor/factor_balance.go
index a1ac47e87..f50d83c98 100644
--- a/pkg/balance/factor/factor_balance.go
+++ b/pkg/balance/factor/factor_balance.go
@@ -30,7 +30,7 @@ type FactorBasedBalance struct {
 	factors []Factor
 	// to reduce memory allocation
 	cachedList   []scoredBackend
-	mr           metricsreader.MetricsReader
+	mr           metricsreader.MetricsQuerier
 	lg           *zap.Logger
 	factorStatus *FactorStatus
 	factorLabel  *FactorLabel
@@ -44,7 +44,7 @@ type FactorBasedBalance struct {
 	routePolicy string
 }
 
-func NewFactorBasedBalance(lg *zap.Logger, mr metricsreader.MetricsReader) *FactorBasedBalance {
+func NewFactorBasedBalance(lg *zap.Logger, mr metricsreader.MetricsQuerier) *FactorBasedBalance {
 	return &FactorBasedBalance{
 		lg: lg,
 		mr: mr,
diff --git a/pkg/balance/factor/factor_cpu.go b/pkg/balance/factor/factor_cpu.go
index 0c319e459..cf56f7232 100644
--- a/pkg/balance/factor/factor_cpu.go
+++ b/pkg/balance/factor/factor_cpu.go
@@ -95,13 +95,13 @@ type FactorCPU struct {
 	lastMetricTime time.Time
 	// The estimated average CPU usage used by one connection.
 	usagePerConn        float64
-	mr                  metricsreader.MetricsReader
+	mr                  metricsreader.MetricsQuerier
 	bitNum              int
 	migrationsPerSecond float64
 	lg                  *zap.Logger
 }
 
-func NewFactorCPU(mr metricsreader.MetricsReader, lg *zap.Logger) *FactorCPU {
+func NewFactorCPU(mr metricsreader.MetricsQuerier, lg *zap.Logger) *FactorCPU {
 	fc := &FactorCPU{
 		mr:     mr,
 		bitNum: 5,
diff --git a/pkg/balance/factor/factor_health.go b/pkg/balance/factor/factor_health.go
index 868145c94..36ae846b5 100644
--- a/pkg/balance/factor/factor_health.go
+++ b/pkg/balance/factor/factor_health.go
@@ -187,13 +187,13 @@ type errIndicator struct {
 type FactorHealth struct {
 	snapshot            map[string]healthBackendSnapshot
 	indicators          []errIndicator
-	mr                  metricsreader.MetricsReader
+	mr                  metricsreader.MetricsQuerier
 	bitNum              int
 	migrationsPerSecond float64
 	lg                  *zap.Logger
 }
 
-func NewFactorHealth(mr metricsreader.MetricsReader, lg *zap.Logger) *FactorHealth {
+func NewFactorHealth(mr metricsreader.MetricsQuerier, lg *zap.Logger) *FactorHealth {
 	return &FactorHealth{
 		mr:       mr,
 		snapshot: make(map[string]healthBackendSnapshot),
@@ -203,7 +203,7 @@ func NewFactorHealth(mr metricsreader.MetricsReader, lg *zap.Logger) *FactorHeal
 	}
 }
 
-func initErrIndicator(mr metricsreader.MetricsReader) []errIndicator {
+func initErrIndicator(mr metricsreader.MetricsQuerier) []errIndicator {
 	indicators := make([]errIndicator, 0, len(errDefinitions))
 	for _, def := range errDefinitions {
 		indicator := errIndicator{
diff --git a/pkg/balance/factor/factor_memory.go b/pkg/balance/factor/factor_memory.go
index f577b7d27..cdf14a00a 100644
--- a/pkg/balance/factor/factor_memory.go
+++ b/pkg/balance/factor/factor_memory.go
@@ -107,13 +107,13 @@ type FactorMemory struct {
 	snapshot map[string]memBackendSnapshot
 	// The updated time of the metric that we've read last time.
 	lastMetricTime      time.Time
-	mr                  metricsreader.MetricsReader
+	mr                  metricsreader.MetricsQuerier
 	bitNum              int
 	migrationsPerSecond float64
 	lg                  *zap.Logger
 }
 
-func NewFactorMemory(mr metricsreader.MetricsReader, lg *zap.Logger) *FactorMemory {
+func NewFactorMemory(mr metricsreader.MetricsQuerier, lg *zap.Logger) *FactorMemory {
 	bitNum := 0
 	for levels := len(oomRiskLevels); ; bitNum++ {
 		if levels == 0 {
diff --git a/pkg/balance/metricsreader/backend_reader.go b/pkg/balance/metricsreader/backend_reader.go
index 49f3130f0..505aaeafd 100644
--- a/pkg/balance/metricsreader/backend_reader.go
+++ b/pkg/balance/metricsreader/backend_reader.go
@@ -9,6 +9,7 @@ import (
 	"fmt"
 	"math"
 	"net"
+	"net/url"
 	"slices"
 	"strconv"
 	"strings"
@@ -35,18 +36,15 @@ import (
 const (
 	// readerOwnerKeyPrefix is the key prefix in etcd for backend reader owner election.
-	// For global owner, the key is "/tiproxy/metric_reader/owner".
-	// For zonal owner, the key is "/tiproxy/metric_reader/{zone}/owner".
-	readerOwnerKeyPrefix = "/tiproxy/metric_reader"
+	// For the default cluster, the key is "/tiproxy/metric_reader/owner".
+	// For a named cluster, the key is "/tiproxy/metric_reader/{cluster}/owner".
 	readerOwnerKeySuffix = "owner"
 	// sessionTTL is the session's TTL in seconds for backend reader owner election.
 	sessionTTL = 15
 	// backendMetricPath is the path of backend HTTP API to read metrics.
 	backendMetricPath = "/metrics"
-	// ownerMetricPath is the path of reading backend metrics from the backend reader owner.
-	ownerMetricPath = "/api/backend/metrics"
-	goPoolSize      = 100
-	goMaxIdle       = time.Minute
+	goPoolSize = 100
+	goMaxIdle  = time.Minute
 )
 
 var (
@@ -72,6 +70,7 @@ type BackendReader struct {
 	marshalledHistory []byte
 	cfgGetter         config.ConfigGetter
 	backendFetcher    TopologyFetcher
+	clusterName       string
 	lastZone          string
 	electionCfg       elect.ElectionConfig
 	election          elect.Election
@@ -84,6 +83,11 @@ type BackendReader struct {
 }
 
 func NewBackendReader(lg *zap.Logger, cfgGetter config.ConfigGetter, httpCli *http.Client, etcdCli *clientv3.Client,
+	backendFetcher TopologyFetcher, cfg *config.HealthCheck) *BackendReader {
+	return NewClusterBackendReader(lg, "", cfgGetter, httpCli, etcdCli, backendFetcher, cfg)
+}
+
+func NewClusterBackendReader(lg *zap.Logger, clusterName string, cfgGetter config.ConfigGetter, httpCli *http.Client, etcdCli *clientv3.Client,
 	backendFetcher TopologyFetcher, cfg *config.HealthCheck) *BackendReader {
 	return &BackendReader{
 		queryRules: make(map[string]QueryRule),
@@ -92,6 +96,7 @@ func NewBackendReader(lg *zap.Logger, cfgGetter config.ConfigGetter, httpCli *ht
 		lg:             lg,
 		cfgGetter:      cfgGetter,
 		backendFetcher: backendFetcher,
+		clusterName:    strings.TrimSpace(clusterName),
 		cfg:            cfg,
 		wgp:            waitgroup.NewWaitGroupPool(goPoolSize, goMaxIdle),
 		electionCfg:    elect.DefaultElectionConfig(sessionTTL),
@@ -118,9 +123,9 @@ func (br *BackendReader) initElection(ctx context.Context, cfg *config.Config) e
 	br.lastZone = cfg.GetLocation()
 	if len(br.lastZone) > 0 {
 		// Zonal owners are responsible for the backends in the same zone or not in any TiProxy zone.
-		key = fmt.Sprintf("%s/%s/%s", readerOwnerKeyPrefix, br.lastZone, readerOwnerKeySuffix)
+		key = fmt.Sprintf("%s/%s/%s", readerOwnerKeyPrefix(br.clusterName), br.lastZone, readerOwnerKeySuffix)
 	} else {
-		key = fmt.Sprintf("%s/%s", readerOwnerKeyPrefix, readerOwnerKeySuffix)
+		key = fmt.Sprintf("%s/%s", readerOwnerKeyPrefix(br.clusterName), readerOwnerKeySuffix)
 	}
 	br.election = elect.NewElection(br.lg.Named("elect"), br.etcdCli, br.electionCfg, id, key, br)
 	br.election.Start(ctx)
@@ -213,7 +218,8 @@ func (br *BackendReader) queryAllOwners(ctx context.Context) (zones, owners []st
 	// Get all owner keys.
 	opts := []clientv3.OpOption{clientv3.WithPrefix()}
 	var kvs []*mvccpb.KeyValue
-	kvs, err = etcd.GetKVs(ctx, br.etcdCli, readerOwnerKeyPrefix, opts, br.electionCfg.Timeout, br.electionCfg.RetryIntvl, br.electionCfg.RetryCnt)
+	keyPrefix := readerOwnerKeyPrefix(br.clusterName)
+	kvs, err = etcd.GetKVs(ctx, br.etcdCli, keyPrefix, opts, br.electionCfg.Timeout, br.electionCfg.RetryIntvl, br.electionCfg.RetryCnt)
 	if err != nil {
 		return
 	}
@@ -227,7 +233,7 @@ func (br *BackendReader) queryAllOwners(ctx context.Context) (zones, owners []st
 	ownerMap := make(map[string]ownerInfo)
 	for _, kv := range kvs {
 		key := hack.String(kv.Key)
-		key = key[len(readerOwnerKeyPrefix):]
+		key = key[len(keyPrefix):]
 		if len(key) == 0 || key[0] != '/' {
 			continue
 		}
@@ -466,7 +472,7 @@ func (br *BackendReader) GetBackendMetrics() []byte {
 // If every member queries directly from backends, the backends may suffer from too much pressure.
 func (br *BackendReader) readFromOwner(ctx context.Context, ownerAddr string) error {
 	b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(br.cfg.RetryInterval), uint64(br.cfg.MaxRetries)), ctx)
-	resp, err := br.httpCli.Get(ownerAddr, ownerMetricPath, b, br.cfg.DialTimeout)
+	resp, err := br.httpCli.Get(ownerAddr, backendMetricOwnerPath(br.clusterName), b, br.cfg.DialTimeout)
 	if err != nil {
 		return err
 	}
@@ -570,6 +576,22 @@ func (br *BackendReader) Close() {
 	}
 }
 
+func readerOwnerKeyPrefix(clusterName string) string {
+	clusterName = strings.TrimSpace(clusterName)
+	if clusterName == "" || clusterName == config.DefaultBackendClusterName {
+		return "/tiproxy/metric_reader"
+	}
+	return fmt.Sprintf("/tiproxy/metric_reader/%s", clusterName)
+}
+
+func backendMetricOwnerPath(clusterName string) string {
+	clusterName = strings.TrimSpace(clusterName)
+	if clusterName == "" {
+		return "/api/backend/metrics"
+	}
+	return fmt.Sprintf("/api/backend/metrics?cluster=%s", url.QueryEscape(clusterName))
+}
+
 func purgeHistory(history []model.SamplePair, retention time.Duration, now time.Time) []model.SamplePair {
 	idx := -1
 	for i := range history {
diff --git a/pkg/balance/metricsreader/backend_reader_test.go b/pkg/balance/metricsreader/backend_reader_test.go
index 0245d72d5..87cd2cc83 100644
--- a/pkg/balance/metricsreader/backend_reader_test.go
+++ b/pkg/balance/metricsreader/backend_reader_test.go
@@ -1149,7 +1149,7 @@ func TestQueryAllOwners(t *testing.T) {
 	br := NewBackendReader(lg, nil, nil, suite.client, nil, nil)
 	for i, test := range tests {
 		for i, key := range test.keys {
-			key = fmt.Sprintf("%s%s", readerOwnerKeyPrefix, key)
+			key = fmt.Sprintf("%s%s", readerOwnerKeyPrefix(""), key)
 			suite.putKV(key, test.values[i])
 		}
 		zones, owners, err := br.queryAllOwners(context.Background())
@@ -1166,7 +1166,7 @@ func TestQueryAllOwners(t *testing.T) {
 			slices.Sort(zones)
 			require.Equal(t, test.zones, zones, "case %d", i)
 		}
-		suite.delKV(readerOwnerKeyPrefix)
+		suite.delKV(readerOwnerKeyPrefix(""))
 	}
 }
@@ -1184,7 +1184,7 @@ func TestUpdateLabel(t *testing.T) {
 	defer br.Close()
 
 	checkKeyPrefix := func(prefix string) bool {
-		kvs := suite.getKV(readerOwnerKeyPrefix)
+		kvs := suite.getKV(readerOwnerKeyPrefix(""))
 		if len(kvs) != 1 {
 			return false
 		}
@@ -1192,7 +1192,7 @@ func TestUpdateLabel(t *testing.T) {
 	}
 
 	// campaign for the global owner
-	prefix := fmt.Sprintf("%s/%s", readerOwnerKeyPrefix, readerOwnerKeySuffix)
+	prefix := fmt.Sprintf("%s/%s", readerOwnerKeyPrefix(""), readerOwnerKeySuffix)
 	require.Eventually(t, func() bool {
 		return checkKeyPrefix(prefix)
 	}, 3*time.Second, 10*time.Millisecond)
@@ -1201,7 +1201,7 @@ func TestUpdateLabel(t *testing.T) {
 	cfg.Labels = map[string]string{config.LocationLabelName: "east"}
 	err = br.ReadMetrics(context.Background())
 	require.NoError(t, err)
-	prefix = fmt.Sprintf("%s/east/%s", readerOwnerKeyPrefix, readerOwnerKeySuffix)
+	prefix = fmt.Sprintf("%s/east/%s", readerOwnerKeyPrefix(""), readerOwnerKeySuffix)
 	require.Eventually(t, func() bool {
 		return checkKeyPrefix(prefix)
 	}, 3*time.Second, 10*time.Millisecond)
@@ -1255,7 +1255,7 @@ func TestElection(t *testing.T) {
 	// setup etcd
 	suite := newEtcdTestSuite(t)
 	t.Cleanup(suite.close)
-	ownerKey := fmt.Sprintf("%s/%s", readerOwnerKeyPrefix, readerOwnerKeySuffix)
+	ownerKey := fmt.Sprintf("%s/%s", readerOwnerKeyPrefix(""), readerOwnerKeySuffix)
 	suite.putKV(ownerKey, addr)
 	require.Eventually(t, func() bool {
 		kvs := suite.getKV(ownerKey)
@@ -1330,3 +1330,9 @@ func setupTypicalBackendListener(t *testing.T, respBody string) (backendPort int
 	t.Cleanup(backendHttpHandler.Close)
 	return
 }
+
+func TestBackendMetricOwnerPath(t *testing.T) {
+	require.Equal(t, "/api/backend/metrics", backendMetricOwnerPath(""))
+	require.Equal(t, "/api/backend/metrics?cluster=cluster-a", backendMetricOwnerPath("cluster-a"))
+	require.Equal(t, "/api/backend/metrics?cluster=cluster+a%2Fb", backendMetricOwnerPath("cluster a/b"))
+}
diff --git a/pkg/balance/metricsreader/metrics_reader.go b/pkg/balance/metricsreader/metrics_reader.go
index 0f1507f05..24c070e8d 100644
--- a/pkg/balance/metricsreader/metrics_reader.go
+++ b/pkg/balance/metricsreader/metrics_reader.go
@@ -32,19 +32,25 @@ type TopologyFetcher interface {
 	GetTiDBTopology(ctx context.Context) (map[string]*infosync.TiDBTopologyInfo, error)
 }
 
-type MetricsReader interface {
-	Start(ctx context.Context) error
+type MetricsQuerier interface {
 	AddQueryExpr(key string, queryExpr QueryExpr, queryRule QueryRule)
 	RemoveQueryExpr(key string)
 	GetQueryResult(key string) QueryResult
 	GetBackendMetrics() []byte
+}
+
+type MetricsReader interface {
+	MetricsQuerier
+
+	Start(ctx context.Context) error
 	PreClose()
 	Close()
 }
 
-var _ MetricsReader = (*DefaultMetricsReader)(nil)
+var _ MetricsReader = (*ClusterReader)(nil)
 
-type DefaultMetricsReader struct {
+// ClusterReader is the metrics reader owned by one backend cluster.
+type ClusterReader struct {
 	source        atomic.Int32
 	backendReader *BackendReader
 	promReader    *PromReader
@@ -54,17 +60,22 @@ type DefaultMetricsReader struct {
 	cfg *config.HealthCheck
 }
 
-func NewDefaultMetricsReader(lg *zap.Logger, promFetcher PromInfoFetcher, backendFetcher TopologyFetcher, httpCli *http.Client,
-	etcdCli *clientv3.Client, cfg *config.HealthCheck, cfgGetter config.ConfigGetter) *DefaultMetricsReader {
-	return &DefaultMetricsReader{
+func NewClusterReader(lg *zap.Logger, clusterName string, promFetcher PromInfoFetcher, backendFetcher TopologyFetcher, httpCli *http.Client,
+	etcdCli *clientv3.Client, cfg *config.HealthCheck, cfgGetter config.ConfigGetter) *ClusterReader {
+	return &ClusterReader{
 		lg:            lg,
 		cfg:           cfg,
 		promReader:    NewPromReader(lg.Named("prom_reader"), promFetcher, cfg),
-		backendReader: NewBackendReader(lg.Named("backend_reader"), cfgGetter, httpCli, etcdCli, backendFetcher, cfg),
+		backendReader: NewClusterBackendReader(lg.Named("backend_reader"), clusterName, cfgGetter, httpCli, etcdCli, backendFetcher, cfg),
 	}
 }
 
-func (dmr *DefaultMetricsReader) Start(ctx context.Context) error {
+func NewDefaultMetricsReader(lg *zap.Logger, promFetcher PromInfoFetcher, backendFetcher TopologyFetcher, httpCli *http.Client,
+	etcdCli *clientv3.Client, cfg *config.HealthCheck, cfgGetter config.ConfigGetter) *ClusterReader {
+	return NewClusterReader(lg, "", promFetcher, backendFetcher, httpCli, etcdCli, cfg, cfgGetter)
+}
+
+func (dmr *ClusterReader) Start(ctx context.Context) error {
 	if err := dmr.backendReader.Start(ctx); err != nil {
 		return err
 	}
@@ -86,7 +97,7 @@ func (dmr *DefaultMetricsReader) Start(ctx context.Context) error {
 }
 
 // readMetrics reads from Prometheus first. If it fails, fall back to read backends.
-func (dmr *DefaultMetricsReader) readMetrics(ctx context.Context) {
+func (dmr *ClusterReader) readMetrics(ctx context.Context) {
 	if ctx.Err() != nil {
 		return
 	}
@@ -107,7 +118,7 @@ func (dmr *DefaultMetricsReader) readMetrics(ctx context.Context) {
 	dmr.lg.Warn("read metrics failed", zap.NamedError("prometheus", promErr), zap.NamedError("backends", backendErr))
 }
 
-func (dmr *DefaultMetricsReader) setSource(source int32, err error) {
+func (dmr *ClusterReader) setSource(source int32, err error) {
 	old := dmr.source.Load()
 	if old != source {
 		dmr.source.Store(source)
@@ -120,18 +131,18 @@ func (dmr *DefaultMetricsReader) setSource(source int32, err error) {
 	}
 }
 
-func (dmr *DefaultMetricsReader) AddQueryExpr(key string, queryExpr QueryExpr, queryRule QueryRule) {
+func (dmr *ClusterReader) AddQueryExpr(key string, queryExpr QueryExpr, queryRule QueryRule) {
 	dmr.promReader.AddQueryExpr(key, queryExpr)
 	dmr.backendReader.AddQueryRule(key, queryRule)
 }
 
-func (dmr *DefaultMetricsReader) RemoveQueryExpr(key string) {
+func (dmr *ClusterReader) RemoveQueryExpr(key string) {
 	dmr.promReader.RemoveQueryExpr(key)
 	dmr.backendReader.RemoveQueryRule(key)
 }
 
 // GetQueryResult returns an empty result if the key or the result is not found.
-func (dmr *DefaultMetricsReader) GetQueryResult(key string) QueryResult {
+func (dmr *ClusterReader) GetQueryResult(key string) QueryResult {
 	switch dmr.source.Load() {
 	case sourceProm:
 		return dmr.promReader.GetQueryResult(key)
@@ -142,11 +153,11 @@ func (dmr *DefaultMetricsReader) GetQueryResult(key string) QueryResult {
 	}
 }
 
-func (dmr *DefaultMetricsReader) GetBackendMetrics() []byte {
+func (dmr *ClusterReader) GetBackendMetrics() []byte {
 	return dmr.backendReader.GetBackendMetrics()
 }
 
-func (dmr *DefaultMetricsReader) PreClose() {
+func (dmr *ClusterReader) PreClose() {
 	// No need to update results in the graceful shutdown.
 	// Stop the loop before pre-closing the backend reader to avoid data race.
 	if dmr.cancel != nil {
@@ -157,7 +168,7 @@ func (dmr *DefaultMetricsReader) PreClose() {
 	dmr.backendReader.PreClose()
 }
 
-func (dmr *DefaultMetricsReader) Close() {
+func (dmr *ClusterReader) Close() {
 	if dmr.cancel != nil {
 		dmr.cancel()
 		dmr.cancel = nil
diff --git a/pkg/balance/observer/backend_fetcher.go b/pkg/balance/observer/backend_fetcher.go
index c29f7ab5e..543fa02af 100644
--- a/pkg/balance/observer/backend_fetcher.go
+++ b/pkg/balance/observer/backend_fetcher.go
@@ -25,6 +25,10 @@ type BackendFetcher interface {
 // TopologyFetcher is an interface to fetch the tidb topology from ETCD.
 type TopologyFetcher interface {
 	GetTiDBTopology(ctx context.Context) (map[string]*infosync.TiDBTopologyInfo, error)
+	// HasBackendClusters reports whether dynamic PD-backed clusters are configured at all.
+	// PDFetcher uses it to preserve the legacy behavior that static backend.instances still work
+	// when TiProxy starts without any PD cluster and clusters are added later through the API.
+	HasBackendClusters() bool
 }
 
 // PDFetcher fetches backend list from PD.
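With clusters configured, the topology map keys backends as "{cluster}/{addr}" (the test fixtures below use keys such as "cluster-a/shared.tidb:4000"), while the legacy static flow keeps plain addresses. A minimal standalone sketch of splitting such an ID back apart; splitBackendID is a hypothetical name for illustration, not a helper in this patch:

package main

import (
	"fmt"
	"strings"
)

// splitBackendID is hypothetical: it assumes composite backend IDs keep the
// "{cluster}/{addr}" shape used by the test fixtures, and that plain addresses
// from the static flow contain no '/'.
func splitBackendID(id string) (cluster, addr string) {
	if i := strings.LastIndex(id, "/"); i >= 0 {
		return id[:i], id[i+1:]
	}
	return "", id
}

func main() {
	for _, id := range []string{"cluster-a/shared.tidb:4000", "127.0.0.1:4000"} {
		cluster, addr := splitBackendID(id)
		fmt.Printf("id=%q cluster=%q addr=%q\n", id, cluster, addr)
	}
}

Splitting on the last slash keeps the result stable even when a cluster name itself contains a slash (TestBackendMetricOwnerPath above exercises "cluster a/b"), assuming the address part never does.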
@@ -32,25 +36,35 @@ type PDFetcher struct {
 	tpFetcher TopologyFetcher
 	logger    *zap.Logger
 	config    *config.HealthCheck
+	static    *StaticFetcher
 }
 
-func NewPDFetcher(tpFetcher TopologyFetcher, logger *zap.Logger, config *config.HealthCheck) *PDFetcher {
+func NewPDFetcher(tpFetcher TopologyFetcher, staticAddrs []string, logger *zap.Logger, config *config.HealthCheck) *PDFetcher {
 	config.Check()
 	return &PDFetcher{
 		tpFetcher: tpFetcher,
 		logger:    logger,
 		config:    config,
+		static:    NewStaticFetcher(staticAddrs),
 	}
 }
 
 func (pf *PDFetcher) GetBackendList(ctx context.Context) (map[string]*BackendInfo, error) {
+	// Keep backward compatibility with the legacy static-namespace flow: before any backend cluster
+	// is configured, backend.instances must still be routable even though namespace now always sees
+	// a non-nil topology fetcher from the cluster manager.
+	if !pf.tpFetcher.HasBackendClusters() {
+		return pf.static.GetBackendList(ctx)
+	}
 	backends := pf.fetchBackendList(ctx)
 
 	infos := make(map[string]*BackendInfo, len(backends))
-	for addr, backend := range backends {
-		infos[addr] = &BackendInfo{
-			Labels:     backend.Labels,
-			IP:         backend.IP,
-			StatusPort: backend.StatusPort,
+	for key, backend := range backends {
+		infos[key] = &BackendInfo{
+			Addr:        backend.Addr,
+			ClusterName: backend.ClusterName,
+			Labels:      backend.Labels,
+			IP:          backend.IP,
+			StatusPort:  backend.StatusPort,
 		}
 	}
 	return infos, nil
@@ -98,7 +112,7 @@ func (sf *StaticFetcher) GetBackendList(context.Context) (map[string]*BackendInf
 func backendListToMap(addrs []string) map[string]*BackendInfo {
 	backends := make(map[string]*BackendInfo, len(addrs))
 	for _, addr := range addrs {
-		backends[addr] = &BackendInfo{}
+		backends[addr] = &BackendInfo{Addr: addr}
 	}
 	return backends
 }
diff --git a/pkg/balance/observer/backend_fetcher_test.go b/pkg/balance/observer/backend_fetcher_test.go
index 5ce882d12..c56cb9ffe 100644
--- a/pkg/balance/observer/backend_fetcher_test.go
+++ b/pkg/balance/observer/backend_fetcher_test.go
@@ -26,6 +26,7 @@ func TestPDFetcher(t *testing.T) {
 		{
 			infos: map[string]*infosync.TiDBTopologyInfo{
 				"1.1.1.1:4000": {
+					Addr:       "1.1.1.1:4000",
 					Labels:     map[string]string{"k1": "v1"},
 					IP:         "1.1.1.1",
 					StatusPort: 10080,
@@ -34,6 +35,7 @@ func TestPDFetcher(t *testing.T) {
 			check: func(m map[string]*BackendInfo) {
 				require.Len(t, m, 1)
 				require.NotNil(t, m["1.1.1.1:4000"])
+				require.Equal(t, "1.1.1.1:4000", m["1.1.1.1:4000"].Addr)
 				require.Equal(t, "1.1.1.1", m["1.1.1.1:4000"].IP)
 				require.Equal(t, uint(10080), m["1.1.1.1:4000"].StatusPort)
 				require.Equal(t, map[string]string{"k1": "v1"}, m["1.1.1.1:4000"].Labels)
@@ -42,10 +44,12 @@ func TestPDFetcher(t *testing.T) {
 		{
 			infos: map[string]*infosync.TiDBTopologyInfo{
 				"1.1.1.1:4000": {
+					Addr:       "1.1.1.1:4000",
 					IP:         "1.1.1.1",
 					StatusPort: 10080,
 				},
 				"2.2.2.2:4000": {
+					Addr:       "2.2.2.2:4000",
 					IP:         "2.2.2.2",
 					StatusPort: 10080,
 				},
@@ -53,13 +57,30 @@ func TestPDFetcher(t *testing.T) {
 			check: func(m map[string]*BackendInfo) {
 				require.Len(t, m, 2)
 				require.NotNil(t, m["1.1.1.1:4000"])
+				require.Equal(t, "1.1.1.1:4000", m["1.1.1.1:4000"].Addr)
 				require.Equal(t, "1.1.1.1", m["1.1.1.1:4000"].IP)
 				require.Equal(t, uint(10080), m["1.1.1.1:4000"].StatusPort)
 				require.NotNil(t, m["2.2.2.2:4000"])
+				require.Equal(t, "2.2.2.2:4000", m["2.2.2.2:4000"].Addr)
 				require.Equal(t, "2.2.2.2", m["2.2.2.2:4000"].IP)
 				require.Equal(t, uint(10080), m["2.2.2.2:4000"].StatusPort)
 			},
 		},
+		{
+			infos: map[string]*infosync.TiDBTopologyInfo{
+				"cluster-a/shared.tidb:4000": {
+					Addr:       "shared.tidb:4000",
+					IP:         "10.0.0.1",
+					StatusPort: 10080,
+				},
+			},
+			check: func(m map[string]*BackendInfo) {
+				require.Len(t, m, 1)
+				require.NotNil(t, m["cluster-a/shared.tidb:4000"])
+				require.Equal(t, "shared.tidb:4000", m["cluster-a/shared.tidb:4000"].Addr)
+				require.Equal(t, "10.0.0.1", m["cluster-a/shared.tidb:4000"].IP)
+			},
+		},
 		{
 			ctx: func() context.Context {
 				ctx, cancel := context.WithCancel(context.Background())
@@ -74,9 +95,10 @@ func TestPDFetcher(t *testing.T) {
 
 	tpFetcher := newMockTpFetcher(t)
 	lg, _ := logger.CreateLoggerForTest(t)
-	pf := NewPDFetcher(tpFetcher, lg, newHealthCheckConfigForTest())
+	pf := NewPDFetcher(tpFetcher, nil, lg, newHealthCheckConfigForTest())
 	for _, test := range tests {
 		tpFetcher.infos = test.infos
+		tpFetcher.hasClusters = true
 		if test.ctx == nil {
 			test.ctx = context.Background()
 		}
@@ -85,3 +107,27 @@ func TestPDFetcher(t *testing.T) {
 		require.NoError(t, err)
 	}
 }
+
+func TestPDFetcherFallbackToStaticWithoutBackendClusters(t *testing.T) {
+	tpFetcher := newMockTpFetcher(t)
+	lg, _ := logger.CreateLoggerForTest(t)
+	fetcher := NewPDFetcher(tpFetcher, []string{"127.0.0.1:4000"}, lg, newHealthCheckConfigForTest())
+
+	backends, err := fetcher.GetBackendList(context.Background())
+	require.NoError(t, err)
+	require.Len(t, backends, 1)
+	require.Contains(t, backends, "127.0.0.1:4000")
+
+	tpFetcher.hasClusters = true
+	tpFetcher.infos = map[string]*infosync.TiDBTopologyInfo{
+		"cluster-a/10.0.0.1:4000": {
+			Addr:        "10.0.0.1:4000",
+			ClusterName: "cluster-a",
+		},
+	}
+	backends, err = fetcher.GetBackendList(context.Background())
+	require.NoError(t, err)
+	require.Len(t, backends, 1)
+	require.Equal(t, "10.0.0.1:4000", backends["cluster-a/10.0.0.1:4000"].Addr)
+	require.Equal(t, "cluster-a", backends["cluster-a/10.0.0.1:4000"].ClusterName)
+}
diff --git a/pkg/balance/observer/backend_health.go b/pkg/balance/observer/backend_health.go
index 2fec40bbb..45e1755d7 100644
--- a/pkg/balance/observer/backend_health.go
+++ b/pkg/balance/observer/backend_health.go
@@ -76,13 +76,17 @@ func (bh *BackendHealth) String() string {
 
 // BackendInfo stores the status info of each backend.
 type BackendInfo struct {
-	Labels     map[string]string
-	IP         string
-	StatusPort uint
+	Addr        string
+	ClusterName string
+	Labels      map[string]string
+	IP          string
+	StatusPort  uint
 }
 
 func (bi BackendInfo) Equals(other BackendInfo) bool {
-	return bi.IP == other.IP &&
+	return bi.Addr == other.Addr &&
+		bi.ClusterName == other.ClusterName &&
+		bi.IP == other.IP &&
 		bi.StatusPort == other.StatusPort &&
 		maps.Equal(bi.Labels, other.Labels)
 }
diff --git a/pkg/balance/observer/backend_health_test.go b/pkg/balance/observer/backend_health_test.go
index a6fa0ae05..1b90b1391 100644
--- a/pkg/balance/observer/backend_health_test.go
+++ b/pkg/balance/observer/backend_health_test.go
@@ -15,6 +15,7 @@ func TestBackendHealthToString(t *testing.T) {
 		{},
 		{
 			BackendInfo: BackendInfo{
+				Addr:       "127.0.0.1:4000",
 				IP:         "127.0.0.1",
 				StatusPort: 1,
 				Labels:     map[string]string{"k1": "v1", "k2": "v2"},
@@ -45,6 +46,7 @@ func TestBackendHealthEquals(t *testing.T) {
 		{
 			a: BackendHealth{
 				BackendInfo: BackendInfo{
+					Addr:       "127.0.0.1:4000",
 					IP:         "127.0.0.1",
 					StatusPort: 1,
 					Labels:     map[string]string{"k1": "v1", "k2": "v2"},
@@ -52,6 +54,7 @@ func TestBackendHealthEquals(t *testing.T) {
 			},
 			b: BackendHealth{
 				BackendInfo: BackendInfo{
+					Addr:       "127.0.0.1:4000",
 					IP:         "127.0.0.1",
 					StatusPort: 1,
 				},
@@ -61,6 +64,7 @@ func TestBackendHealthEquals(t *testing.T) {
 		{
 			a: BackendHealth{
 				BackendInfo: BackendInfo{
+					Addr:       "127.0.0.1:4000",
 					IP:         "127.0.0.1",
 					StatusPort: 1,
 					Labels:     map[string]string{"k1": "v1", "k2": "v2"},
@@ -68,6 +72,7 @@ func TestBackendHealthEquals(t *testing.T) {
 			},
 			b: BackendHealth{
 				BackendInfo: BackendInfo{
+					Addr:       "127.0.0.1:4000",
 					IP:         "127.0.0.1",
 					StatusPort: 1,
 					Labels:     map[string]string{"k1": "v1", "k2": "v2"},
@@ -78,6 +83,7 @@ func TestBackendHealthEquals(t *testing.T) {
 		{
 			a: BackendHealth{
 				BackendInfo: BackendInfo{
+					Addr:       "127.0.0.1:4000",
 					IP:         "127.0.0.1",
 					StatusPort: 1,
 					Labels:     map[string]string{"k1": "v1", "k2": "v2"},
diff --git a/pkg/balance/observer/backend_observer_test.go b/pkg/balance/observer/backend_observer_test.go
index 8c3b9b221..ebb7fda22 100644
--- a/pkg/balance/observer/backend_observer_test.go
+++ b/pkg/balance/observer/backend_observer_test.go
@@ -279,6 +279,7 @@ func (ts *observerTestSuite) addBackend() (string, BackendInfo) {
 	ts.backendIdx++
 	addr := fmt.Sprintf("%d", ts.backendIdx)
 	info := &BackendInfo{
+		Addr:       addr,
 		IP:         "127.0.0.1",
 		StatusPort: uint(ts.backendIdx),
 	}
diff --git a/pkg/balance/observer/health_check.go b/pkg/balance/observer/health_check.go
index 5e7578fa0..7c81627cf 100644
--- a/pkg/balance/observer/health_check.go
+++ b/pkg/balance/observer/health_check.go
@@ -19,9 +19,14 @@ import (
 	"go.uber.org/zap"
 )
 
+type BackendNetwork interface {
+	HTTPClient(clusterName string) *http.Client
+	DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error)
+}
+
 // HealthCheck is used to check the health of one backend. One can pass a customized health check function to the observer.
 type HealthCheck interface {
-	Check(ctx context.Context, addr string, info *BackendInfo, lastHealth *BackendHealth) *BackendHealth
+	Check(ctx context.Context, backendID string, info *BackendInfo, lastHealth *BackendHealth) *BackendHealth
 }
 
 const (
@@ -48,21 +53,45 @@ type security struct {
 type DefaultHealthCheck struct {
 	cfg     *config.HealthCheck
 	logger  *zap.Logger
-	httpCli *http.Client
+	network BackendNetwork
 }
 
 func NewDefaultHealthCheck(httpCli *http.Client, cfg *config.HealthCheck, logger *zap.Logger) *DefaultHealthCheck {
-	if httpCli == nil {
-		httpCli = http.NewHTTPClient(func() *tls.Config { return nil })
+	return NewDefaultHealthCheckWithNetwork(newDefaultBackendNetwork(httpCli), cfg, logger)
+}
+
+func NewDefaultHealthCheckWithNetwork(network BackendNetwork, cfg *config.HealthCheck, logger *zap.Logger) *DefaultHealthCheck {
+	if network == nil {
+		network = newDefaultBackendNetwork(nil)
 	}
 	return &DefaultHealthCheck{
-		httpCli: httpCli,
+		network: network,
 		cfg:     cfg,
 		logger:  logger,
 	}
 }
 
-func (dhc *DefaultHealthCheck) Check(ctx context.Context, addr string, info *BackendInfo, lastBh *BackendHealth) *BackendHealth {
+type defaultBackendNetwork struct {
+	httpCli *http.Client
+}
+
+func newDefaultBackendNetwork(httpCli *http.Client) *defaultBackendNetwork {
+	if httpCli == nil {
+		httpCli = http.NewHTTPClient(func() *tls.Config { return nil })
+	}
+	return &defaultBackendNetwork{httpCli: httpCli}
+}
+
+func (n *defaultBackendNetwork) HTTPClient(string) *http.Client {
+	return n.httpCli
+}
+
+func (n *defaultBackendNetwork) DialContext(ctx context.Context, network, addr, _ string) (net.Conn, error) {
+	var dialer net.Dialer
+	return dialer.DialContext(ctx, network, addr)
+}
+
+func (dhc *DefaultHealthCheck) Check(ctx context.Context, _ string, info *BackendInfo, lastBh *BackendHealth) *BackendHealth {
 	bh := &BackendHealth{
 		BackendInfo: *info,
 		Healthy:     true,
@@ -80,7 +109,7 @@ func (dhc *DefaultHealthCheck) Check(ctx context.Context, addr string, info *Bac
 	if !bh.Healthy {
 		return bh
 	}
-	dhc.checkSqlPort(ctx, addr, bh)
+	dhc.checkSqlPort(ctx, info, bh)
 	if !bh.Healthy {
 		return bh
 	}
@@ -88,12 +117,21 @@ func (dhc *DefaultHealthCheck) Check(ctx context.Context, addr string, info *Bac
 	return bh
 }
 
-func (dhc *DefaultHealthCheck) checkSqlPort(ctx context.Context, addr string, bh *BackendHealth) {
+func (dhc *DefaultHealthCheck) checkSqlPort(ctx context.Context, info *BackendInfo, bh *BackendHealth) {
 	// Also dial the SQL port just in case that the SQL port hangs.
+	if info == nil || info.Addr == "" {
+		bh.Healthy = false
+		bh.PingErr = errors.New("backend address is empty")
+		return
+	}
+	addr := info.Addr
+	clusterName := info.ClusterName
 	b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(dhc.cfg.RetryInterval), uint64(dhc.cfg.MaxRetries)), ctx)
 	err := http.ConnectWithRetry(func() error {
 		startTime := time.Now()
-		conn, err := net.DialTimeout("tcp", addr, dhc.cfg.DialTimeout)
+		dialCtx, cancel := context.WithTimeout(ctx, dhc.cfg.DialTimeout)
+		conn, err := dhc.network.DialContext(dialCtx, "tcp", addr, clusterName)
+		cancel()
 		setPingBackendMetrics(addr, startTime)
 		if err != nil {
 			return err
@@ -128,7 +166,8 @@ func (dhc *DefaultHealthCheck) checkStatusPort(ctx context.Context, info *Backen
 	addr := net.JoinHostPort(info.IP, strconv.Itoa(int(info.StatusPort)))
 
 	b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(dhc.cfg.RetryInterval), uint64(dhc.cfg.MaxRetries)), ctx)
-	resp, err := dhc.httpCli.Get(addr, statusPathSuffix, b, dhc.cfg.DialTimeout)
+	clusterName := info.ClusterName
+	resp, err := dhc.network.HTTPClient(clusterName).Get(addr, statusPathSuffix, b, dhc.cfg.DialTimeout)
 	if err == nil {
 		var respBody backendHttpStatusRespBody
 		err = json.Unmarshal(resp, &respBody)
@@ -170,7 +209,8 @@ func (dhc *DefaultHealthCheck) queryConfig(ctx context.Context, info *BackendInf
 	b := backoff.WithContext(backoff.WithMaxRetries(backoff.NewConstantBackOff(dhc.cfg.RetryInterval), uint64(dhc.cfg.MaxRetries)), ctx)
 
 	var resp []byte
-	if resp, err = dhc.httpCli.Get(addr, configPathSuffix, b, dhc.cfg.DialTimeout); err != nil {
+	clusterName := info.ClusterName
+	if resp, err = dhc.network.HTTPClient(clusterName).Get(addr, configPathSuffix, b, dhc.cfg.DialTimeout); err != nil {
 		return
 	}
 	var respBody backendHttpConfigRespBody
diff --git a/pkg/balance/observer/health_check_test.go b/pkg/balance/observer/health_check_test.go
index 98b2be67c..4a7b9ad4b 100644
--- a/pkg/balance/observer/health_check_test.go
+++ b/pkg/balance/observer/health_check_test.go
@@ -5,10 +5,12 @@ package observer
 
 import (
 	"context"
+	"crypto/tls"
 	"encoding/json"
 	"net"
 	"net/http"
 	"strings"
+	"sync"
 	"sync/atomic"
 	"testing"
 	"time"
@@ -17,6 +19,7 @@ import (
 	"github.com/pingcap/tiproxy/lib/util/logger"
 	"github.com/pingcap/tiproxy/lib/util/waitgroup"
 	"github.com/pingcap/tiproxy/pkg/testkit"
+	httputil "github.com/pingcap/tiproxy/pkg/util/http"
 	"github.com/stretchr/testify/require"
 )
 
@@ -120,6 +123,59 @@ func TestSupportRedirection(t *testing.T) {
 	require.False(t, health.SupportRedirection)
 }
 
+func TestHealthCheckUsesClusterNetwork(t *testing.T) {
+	lg, _ := logger.CreateLoggerForTest(t)
+	cfg := newHealthCheckConfigForTest()
+	backend, info := newBackendServer(t)
+	defer backend.close()
+	backend.setServerVersion("1.0")
+	backend.setHasSigningCert(true)
+	info.ClusterName = "cluster-a"
+
+	network := &mockBackendNetwork{
+		httpCli: httputil.NewHTTPClient(func() *tls.Config { return nil }),
+	}
+	hc := NewDefaultHealthCheckWithNetwork(network, cfg, lg)
+	health := hc.Check(context.Background(), backend.sqlAddr, info, nil)
+	require.True(t, health.Healthy)
+	require.Contains(t, network.httpClusters(), "cluster-a")
+	require.Contains(t, network.dialClusters(), "cluster-a")
+}
+
+type mockBackendNetwork struct {
+	httpCli *httputil.Client
+	mu      sync.Mutex
+	https   []string
+	dials   []string
+}
+
+func (n *mockBackendNetwork) HTTPClient(clusterName string) *httputil.Client {
+	n.mu.Lock()
+	n.https = append(n.https, clusterName)
+	n.mu.Unlock()
+	return n.httpCli
+}
+
+func (n *mockBackendNetwork) DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error) {
+	n.mu.Lock()
+	n.dials = append(n.dials, clusterName)
+	n.mu.Unlock()
+	var dialer net.Dialer
+	return dialer.DialContext(ctx, network, addr)
+}
+
+func (n *mockBackendNetwork) httpClusters() []string {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	return append([]string(nil), n.https...)
+}
+
+func (n *mockBackendNetwork) dialClusters() []string {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	return append([]string(nil), n.dials...)
+}
+
 type backendServer struct {
 	t           *testing.T
 	sqlListener net.Listener
@@ -143,6 +199,7 @@ func newBackendServer(t *testing.T) (*backendServer, *BackendInfo) {
 	backend.setSqlResp(true)
 	backend.startSQLServer()
 	return backend, &BackendInfo{
+		Addr:       backend.sqlAddr,
 		IP:         backend.ip,
 		StatusPort: backend.statusPort,
 	}
diff --git a/pkg/balance/observer/mock_test.go b/pkg/balance/observer/mock_test.go
index c8dbfd7d2..19d3e0edb 100644
--- a/pkg/balance/observer/mock_test.go
+++ b/pkg/balance/observer/mock_test.go
@@ -19,9 +19,10 @@ import (
 )
 
 type mockTpFetcher struct {
-	t     *testing.T
-	infos map[string]*infosync.TiDBTopologyInfo
-	err   error
+	t           *testing.T
+	infos       map[string]*infosync.TiDBTopologyInfo
+	err         error
+	hasClusters bool
 }
 
 func newMockTpFetcher(t *testing.T) *mockTpFetcher {
@@ -34,6 +35,10 @@ func (ft *mockTpFetcher) GetTiDBTopology(ctx context.Context) (map[string]*infos
 	return ft.infos, ft.err
 }
 
+func (ft *mockTpFetcher) HasBackendClusters() bool {
+	return ft.hasClusters
+}
+
 type mockBackendFetcher struct {
 	sync.Mutex
 	backends map[string]*BackendInfo
@@ -82,11 +87,11 @@ func newMockHealthCheck() *mockHealthCheck {
 	}
 }
 
-func (mhc *mockHealthCheck) Check(_ context.Context, addr string, info *BackendInfo, _ *BackendHealth) *BackendHealth {
+func (mhc *mockHealthCheck) Check(_ context.Context, backendID string, info *BackendInfo, _ *BackendHealth) *BackendHealth {
 	mhc.Lock()
 	defer mhc.Unlock()
-	mhc.backends[addr].BackendInfo = *info
-	return mhc.backends[addr]
+	mhc.backends[backendID].BackendInfo = *info
+	return mhc.backends[backendID]
 }
 
 func (mhc *mockHealthCheck) setBackend(addr string, health *BackendHealth) {
diff --git a/pkg/balance/router/backend_selector.go b/pkg/balance/router/backend_selector.go
index 25225cfbf..c9e979c05 100644
--- a/pkg/balance/router/backend_selector.go
+++ b/pkg/balance/router/backend_selector.go
@@ -8,6 +8,8 @@ import "net"
 type ClientInfo struct {
 	ClientAddr net.Addr
 	ProxyAddr  net.Addr
+	// ListenerAddr is the SQL listener address that accepted the connection.
+	ListenerAddr string
 	// TODO: username, database, etc.
 }
diff --git a/pkg/balance/router/group.go b/pkg/balance/router/group.go
index 43341054b..90586e42f 100644
--- a/pkg/balance/router/group.go
+++ b/pkg/balance/router/group.go
@@ -30,6 +30,8 @@ const (
 	MatchClientCIDR
 	// Match connections based on proxy CIDR. If proxy-protocol is disabled, route by the client CIDR.
 	MatchProxyCIDR
+	// Match connections based on the local SQL listener port.
+	MatchPort
 )
 
 var _ ConnEventReceiver = (*Group)(nil)
@@ -104,7 +106,7 @@ func (g *Group) Match(clientInfo ClientInfo) bool {
 
 func (g *Group) EqualValues(values []string) bool {
 	switch g.matchType {
-	case MatchClientCIDR, MatchProxyCIDR:
+	case MatchClientCIDR, MatchProxyCIDR, MatchPort:
 		if len(g.values) != len(values) {
 			return false
 		}
@@ -123,7 +125,7 @@ func (g *Group) EqualValues(values []string) bool {
 // E.g. enable public endpoint (3 cidrs) -> enable private endpoint (6 cidrs) -> disable public endpoint (3 cidrs).
 func (g *Group) Intersect(values []string) bool {
 	switch g.matchType {
-	case MatchClientCIDR, MatchProxyCIDR:
+	case MatchClientCIDR, MatchProxyCIDR, MatchPort:
 		for _, v := range g.values {
 			if slices.Contains(values, v) {
 				return true
@@ -158,17 +160,17 @@ func (g *Group) RefreshCidr() {
 	}
 }
 
-func (g *Group) AddBackend(addr string, backend *backendWrapper) {
+func (g *Group) AddBackend(backendID string, backend *backendWrapper) {
 	g.Lock()
 	defer g.Unlock()
-	g.backends[addr] = backend
+	g.backends[backendID] = backend
 	backend.group = g
 }
 
-func (g *Group) RemoveBackend(addr string) {
+func (g *Group) RemoveBackend(backendID string) {
 	g.Lock()
 	defer g.Unlock()
-	delete(g.backends, addr)
+	delete(g.backends, backendID)
 }
 
 func (g *Group) Empty() bool {
@@ -200,7 +202,7 @@ func (g *Group) Route(excluded []BackendInst) (policy.BackendCtx, error) {
 		// Exclude the backends that are already tried.
 		found := false
 		for _, e := range excluded {
-			if backend.Addr() == e.Addr() {
+			if backend.ID() == e.ID() {
 				found = true
 				break
 			}
@@ -273,7 +275,7 @@ func (g *Group) Balance(ctx context.Context) {
 func (g *Group) onCreateConn(backendInst BackendInst, conn RedirectableConn, succeed bool) {
 	g.Lock()
 	defer g.Unlock()
-	backend := g.ensureBackend(backendInst.Addr())
+	backend := g.ensureBackend(backendInst.ID())
 	if succeed {
 		connWrapper := &connWrapper{
 			RedirectableConn: conn,
@@ -319,23 +321,18 @@ func (g *Group) RedirectConnections() error {
 	return nil
 }
 
-func (g *Group) ensureBackend(addr string) *backendWrapper {
-	backend, ok := g.backends[addr]
+func (g *Group) ensureBackend(backendID string) *backendWrapper {
+	backend, ok := g.backends[backendID]
 	if ok {
 		return backend
 	}
 	// The backend should always exist if it will be needed. Add a warning and add it back.
-	g.lg.Warn("backend is not found in the router", zap.String("backend_addr", addr), zap.Stack("stack"))
-	ip, _, _ := net.SplitHostPort(addr)
-	backend = newBackendWrapper(addr, observer.BackendHealth{
-		BackendInfo: observer.BackendInfo{
-			IP:         ip,
-			StatusPort: 10080, // impossible anyway
-		},
+	g.lg.Warn("backend is not found in the router", zap.String("backend_id", backendID), zap.Stack("stack"))
+	backend = newBackendWrapper(backendID, observer.BackendHealth{
 		SupportRedirection: true,
 		Healthy:            false,
 	})
-	g.backends[addr] = backend
+	g.backends[backendID] = backend
 	return backend
 }
 
@@ -375,16 +372,16 @@ func (g *Group) onRedirectFinished(from, to string, conn RedirectableConn, succe
 }
 
 // OnConnClosed implements ConnEventReceiver.OnConnClosed interface.
-func (g *Group) OnConnClosed(addr, redirectingAddr string, conn RedirectableConn) error {
+func (g *Group) OnConnClosed(backendID, redirectingBackendID string, conn RedirectableConn) error {
 	g.Lock()
 	defer g.Unlock()
-	backend := g.ensureBackend(addr)
+	backend := g.ensureBackend(backendID)
 	connWrapper := getConnWrapper(conn)
 	// If this connection has not redirected yet, decrease the score of the target backend.
-	if redirectingAddr != "" {
-		redirectingBackend := g.ensureBackend(redirectingAddr)
+	if redirectingBackendID != "" {
+		redirectingBackend := g.ensureBackend(redirectingBackendID)
 		redirectingBackend.connScore--
-		metrics.PendingMigrateGuage.WithLabelValues(addr, redirectingAddr, connWrapper.Value.redirectReason).Dec()
+		metrics.PendingMigrateGuage.WithLabelValues(backendID, redirectingBackendID, connWrapper.Value.redirectReason).Dec()
 	} else {
 		backend.connScore--
 	}
diff --git a/pkg/balance/router/mock_test.go b/pkg/balance/router/mock_test.go
index d8eb98950..b701657ee 100644
--- a/pkg/balance/router/mock_test.go
+++ b/pkg/balance/router/mock_test.go
@@ -69,13 +69,13 @@ func (conn *mockRedirectableConn) Redirect(inst BackendInst) bool {
 	return true
 }
 
-func (conn *mockRedirectableConn) GetRedirectingAddr() string {
+func (conn *mockRedirectableConn) GetRedirectingBackendID() string {
 	conn.Lock()
 	defer conn.Unlock()
 	if conn.to == nil {
 		return ""
 	}
-	return conn.to.Addr()
+	return conn.to.ID()
 }
 
 func (conn *mockRedirectableConn) ConnectionID() uint64 {
@@ -86,14 +86,14 @@ func (conn *mockRedirectableConn) ConnInfo() []zap.Field {
 	return nil
 }
 
-func (conn *mockRedirectableConn) getAddr() (string, string) {
+func (conn *mockRedirectableConn) getBackendIDs() (string, string) {
 	conn.Lock()
 	defer conn.Unlock()
 	var to string
 	if conn.to != nil && !reflect.ValueOf(conn.to).IsNil() {
-		to = conn.to.Addr()
+		to = conn.to.ID()
 	}
-	return conn.from.Addr(), to
+	return conn.from.ID(), to
 }
 
 func (conn *mockRedirectableConn) redirectSucceed() {
@@ -133,16 +133,28 @@ func (mbo *mockBackendObserver) toggleBackendHealth(addr string) {
 }
 
 func (mbo *mockBackendObserver) addBackend(addr string, labels map[string]string) {
+	mbo.addBackendWithCluster(addr, "", labels)
+}
+
+func (mbo *mockBackendObserver) addBackendWithCluster(addr, clusterName string, labels map[string]string) {
 	mbo.healthLock.Lock()
 	defer mbo.healthLock.Unlock()
 	mbo.healths[addr] = &observer.BackendHealth{
 		Healthy: true,
 		BackendInfo: observer.BackendInfo{
-			Labels: labels,
+			Addr:        addr,
+			ClusterName: clusterName,
+			Labels:      labels,
 		},
 	}
 }
 
+func (mbo *mockBackendObserver) setLabels(addr string, labels map[string]string) {
+	mbo.healthLock.Lock()
+	defer mbo.healthLock.Unlock()
+	mbo.healths[addr].Labels = labels
+}
+
 func (mbo *mockBackendObserver) Start(ctx context.Context) {
 }
 
@@ -182,8 +194,9 @@ func (mbo *mockBackendObserver) notify(err error) {
 func (mbo *mockBackendObserver) Close() {
 	mbo.subscriberLock.Lock()
 	defer mbo.subscriberLock.Unlock()
-	for _, subscriber := range mbo.subscribers {
+	for name, subscriber := range mbo.subscribers {
 		close(subscriber)
+		delete(mbo.subscribers, name)
 	}
 }
diff --git a/pkg/balance/router/port_conflict_detector.go b/pkg/balance/router/port_conflict_detector.go
new file mode 100644
index 000000000..f625d6446
--- /dev/null
+++ b/pkg/balance/router/port_conflict_detector.go
@@ -0,0 +1,49 @@
+// Copyright 2026 PingCAP, Inc.
+// SPDX-License-Identifier: Apache-2.0
+
+package router
+
+import "github.com/pingcap/tiproxy/lib/util/errors"
+
+type portConflictDetector struct {
+	routes  map[string]*Group
+	blocked map[string]error
+	owners  map[string]string
+}
+
+func newPortConflictDetector() *portConflictDetector {
+	return &portConflictDetector{
+		routes:  make(map[string]*Group),
+		blocked: make(map[string]error),
+		owners:  make(map[string]string),
+	}
+}
+
+func (v *portConflictDetector) bind(port, clusterName string, group *Group) {
+	if v == nil || port == "" || group == nil {
+		return
+	}
+	if _, blocked := v.blocked[port]; blocked {
+		return
+	}
+	if owner, ok := v.owners[port]; !ok {
+		v.owners[port] = clusterName
+		v.routes[port] = group
+		return
+	} else if owner != clusterName {
+		v.blocked[port] = errors.Wrapf(ErrPortConflict, "listener port %s is claimed by multiple backend clusters", port)
+		delete(v.routes, port)
+		return
+	}
+	v.routes[port] = group
+}
+
+func (v *portConflictDetector) groupFor(port string) (*Group, error) {
+	if v == nil || port == "" {
+		return nil, nil
+	}
+	if err, ok := v.blocked[port]; ok {
+		return nil, err
+	}
+	return v.routes[port], nil
+}
diff --git a/pkg/balance/router/router.go b/pkg/balance/router/router.go
index ee844bf9e..9f62833ad 100644
--- a/pkg/balance/router/router.go
+++ b/pkg/balance/router/router.go
@@ -16,14 +16,15 @@ import (
 )
 
 var (
-	ErrNoBackend = errors.New("no available backend")
+	ErrNoBackend    = errors.New("no available backend")
+	ErrPortConflict = errors.New("port routing conflict")
 )
 
 // ConnEventReceiver receives connection events.
 type ConnEventReceiver interface {
 	OnRedirectSucceed(from, to string, conn RedirectableConn) error
 	OnRedirectFail(from, to string, conn RedirectableConn) error
-	OnConnClosed(addr, redirectingAddr string, conn RedirectableConn) error
+	OnConnClosed(backendID, redirectingBackendID string, conn RedirectableConn) error
 }
 
 // Router routes client connections to backends.
@@ -74,10 +75,12 @@ type RedirectableConn interface {
 
 // BackendInst defines a backend that a connection is redirecting to.
 type BackendInst interface {
+	ID() string
 	Addr() string
 	Healthy() bool
 	Local() bool
 	Keyspace() string
+	ClusterName() string
 }
 
 // backendWrapper contains the connections on the backend.
@@ -86,6 +89,7 @@ type backendWrapper struct {
 		sync.RWMutex
 		observer.BackendHealth
 	}
+	id   string
 	addr string
 	// connScore is used for calculating backend scores and check if the backend can be removed from the list.
 	// connScore = connList.Len() + incoming connections - outgoing connections.
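The portConflictDetector above follows three small rules: the first cluster to claim a listener port owns it, a claim from a different cluster blocks the port permanently, and a blocked port returns the wrapped ErrPortConflict instead of a group. A self-contained toy that mirrors those rules, with plain strings standing in for *Group and the error values; it is a behavioral sketch, not code from this patch:

package main

import "fmt"

type detector struct {
	owners  map[string]string // port -> first claiming cluster
	blocked map[string]bool   // port -> claimed by more than one cluster
}

func (d *detector) bind(port, cluster string) {
	if d.blocked[port] {
		return // stays blocked once a conflict has been seen
	}
	owner, ok := d.owners[port]
	if !ok {
		d.owners[port] = cluster // first claim wins
		return
	}
	if owner != cluster {
		d.blocked[port] = true // cross-cluster claim: block instead of guessing
	}
}

func main() {
	d := &detector{owners: map[string]string{}, blocked: map[string]bool{}}
	d.bind("6000", "cluster-a")
	d.bind("6000", "cluster-b") // conflict: 6000 becomes unroutable
	d.bind("6001", "cluster-b")
	fmt.Println(d.blocked["6000"], d.blocked["6001"]) // true false
}

Blocking on conflict rather than letting the last writer win keeps port routing deterministic, since rebuildPortConflictDetector recomputes the whole table on every group update.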
@@ -97,9 +101,10 @@ type backendWrapper struct { group *Group } -func newBackendWrapper(addr string, health observer.BackendHealth) *backendWrapper { +func newBackendWrapper(id string, health observer.BackendHealth) *backendWrapper { wrapper := &backendWrapper{ - addr: addr, + id: id, + addr: health.Addr, connList: glist.New[*connWrapper](), } wrapper.setHealth(health) @@ -123,6 +128,10 @@ func (b *backendWrapper) ConnScore() int { return b.connScore } +func (b *backendWrapper) ID() string { + return b.id +} + func (b *backendWrapper) Addr() string { return b.addr } @@ -176,6 +185,12 @@ func (b *backendWrapper) Keyspace() string { return labels[config.KeyspaceLabelName] } +func (b *backendWrapper) ClusterName() string { + b.mu.RLock() + defer b.mu.RUnlock() + return b.mu.BackendHealth.ClusterName +} + func (b *backendWrapper) Cidr() []string { labels := b.getHealth().Labels if len(labels) == 0 { @@ -197,6 +212,14 @@ func (b *backendWrapper) Cidr() []string { return cidrs } +func (b *backendWrapper) TiProxyPort() string { + labels := b.getHealth().Labels + if len(labels) == 0 { + return "" + } + return strings.TrimSpace(labels[config.TiProxyPortLabelName]) +} + func (b *backendWrapper) String() string { b.mu.RLock() str := b.mu.String() diff --git a/pkg/balance/router/router_score.go b/pkg/balance/router/router_score.go index d70afa4a9..efc7e1121 100644 --- a/pkg/balance/router/router_score.go +++ b/pkg/balance/router/router_score.go @@ -5,6 +5,8 @@ package router import ( "context" + "fmt" + "net" "slices" "strings" "sync" @@ -41,6 +43,8 @@ type ScoreBasedRouter struct { backends map[string]*backendWrapper // TODO: sort the groups to leverage binary search. groups []*Group + // portConflictDetector dispatches listener ports to cluster-scoped backend groups. + portConflictDetector *portConflictDetector // The routing rule for categorizing backends to groups. matchType MatchType observeError error @@ -74,6 +78,8 @@ func (r *ScoreBasedRouter) Init(ctx context.Context, ob observer.BackendObserver r.matchType = MatchClientCIDR case config.MatchProxyCIDRStr: r.matchType = MatchProxyCIDR + case config.MatchPortStr: + r.matchType = MatchPort case "": default: r.logger.Error("unsupported routing rule, use the default rule", zap.String("rule", cfg.Balance.RoutingRule)) @@ -110,13 +116,20 @@ func (router *ScoreBasedRouter) GetBackendSelector(clientInfo ClientInfo) Backen return } // The group may change from round to round because the backends are updated. - group = router.routeToGroup(clientInfo) + group, err = router.routeToGroup(clientInfo) + if err != nil { + return + } if group == nil { err = ErrNoBackend return } // The router may remove this group concurrently, make sure the group can be accessed after it's removed. 
- backend, err = group.Route(excluded) + var backendCtx policy.BackendCtx + backendCtx, err = group.Route(excluded) + if err == nil && backendCtx != nil { + backend = backendCtx.(BackendInst) + } return }, onCreate: func(backend BackendInst, conn RedirectableConn, succeed bool) { @@ -142,14 +155,22 @@ func (router *ScoreBasedRouter) HealthyBackendCount() int { } // called in the lock -func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) *Group { +func (router *ScoreBasedRouter) routeToGroup(clientInfo ClientInfo) (*Group, error) { + if router.matchType == MatchPort { + _, port, err := net.SplitHostPort(clientInfo.ListenerAddr) + if err != nil { + router.logger.Error("checking port failed", zap.String("listener_addr", clientInfo.ListenerAddr), zap.Error(err)) + return nil, nil + } + return router.portConflictDetector.groupFor(port) + } // TODO: binary search for _, group := range router.groups { if group.Match(clientInfo) { - return group + return group, nil } } - return nil + return nil, nil } // RefreshBackend implements Router.GetBackendSelector interface. @@ -181,11 +202,12 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt // `backends` contain all the backends, not only the updated ones. backends := healthResults.Backends() // If some backends are removed from the list, add them to `backends`. - for addr, backend := range router.backends { - if _, ok := backends[addr]; !ok { + for backendID, backend := range router.backends { + if _, ok := backends[backendID]; !ok { health := backend.getHealth() - router.logger.Debug("backend is removed from the list, add it back to router", zap.String("addr", addr), zap.Stringer("health", &health)) - backends[addr] = &observer.BackendHealth{ + router.logger.Debug("backend is removed from the list, add it back to router", + zap.String("backend_id", backendID), zap.String("addr", backend.Addr()), zap.Stringer("health", &health)) + backends[backendID] = &observer.BackendHealth{ BackendInfo: backend.GetBackendInfo(), SupportRedirection: backend.SupportRedirection(), Healthy: false, @@ -195,22 +217,25 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt } var serverVersion string supportRedirection := true - for addr, health := range backends { - backend, ok := router.backends[addr] + for backendID, health := range backends { + backend, ok := router.backends[backendID] if !ok && health.Healthy { - router.logger.Debug("add new backend to router", zap.String("addr", addr), zap.Stringer("health", health)) - router.backends[addr] = newBackendWrapper(addr, *health) + router.logger.Debug("add new backend to router", + zap.String("backend_id", backendID), zap.String("addr", health.Addr), zap.Stringer("health", health)) + router.backends[backendID] = newBackendWrapper(backendID, *health) serverVersion = health.ServerVersion } else if ok { if !health.Equals(backend.getHealth()) { - router.logger.Debug("update backend in router", zap.String("addr", addr), zap.Stringer("health", health)) + router.logger.Debug("update backend in router", + zap.String("backend_id", backendID), zap.String("addr", health.Addr), zap.Stringer("health", health)) } backend.setHealth(*health) if health.Healthy { serverVersion = health.ServerVersion } } else { - router.logger.Debug("unhealthy backend is not in router", zap.String("addr", addr), zap.Stringer("health", health)) + router.logger.Debug("unhealthy backend is not in router", + zap.String("backend_id", backendID), zap.String("addr", health.Addr), 
zap.Stringer("health", health)) } supportRedirection = health.SupportRedirection && supportRedirection } @@ -225,32 +250,74 @@ func (router *ScoreBasedRouter) updateBackendHealth(healthResults observer.Healt } } -// Update the groups after the backend list is updated. +func matchPortValue(clusterName, port string) string { + if clusterName == "" { + return port + } + return fmt.Sprintf("%s:%s", clusterName, port) +} + +func (router *ScoreBasedRouter) rebuildPortConflictDetector() { + if router.matchType != MatchPort { + router.portConflictDetector = nil + return + } + detector := newPortConflictDetector() + for _, group := range router.groups { + for _, value := range group.values { + clusterName, port, ok := strings.Cut(value, ":") + if !ok { + port = value + clusterName = "" + } + detector.bind(port, clusterName, group) + } + } + router.portConflictDetector = detector +} + // called in the lock. func (router *ScoreBasedRouter) updateGroups() { for _, backend := range router.backends { // If connList.Len() == 0, there won't be any outgoing connections. // And if also connScore == 0, there won't be any incoming connections. if !backend.Healthy() && backend.connList.Len() == 0 && backend.connScore <= 0 { - delete(router.backends, backend.addr) + delete(router.backends, backend.id) if backend.group != nil { - backend.group.RemoveBackend(backend.addr) - // remove empty groups + backend.group.RemoveBackend(backend.id) if backend.group.Empty() { router.groups = slices.DeleteFunc(router.groups, func(g *Group) bool { return g == backend.group }) } + backend.group = nil } continue } - // If the labels were correctly set, we won't update its group even if the labels change. if backend.group != nil { + switch router.matchType { + case MatchClientCIDR, MatchProxyCIDR, MatchPort: + var values []string + switch router.matchType { + case MatchClientCIDR, MatchProxyCIDR: + values = backend.Cidr() + case MatchPort: + port := backend.TiProxyPort() + if port != "" { + values = []string{matchPortValue(backend.ClusterName(), port)} + } + } + if !backend.group.EqualValues(values) { + router.logger.Warn("backend routing values changed, keep the existing group until it is removed", + zap.String("backend_id", backend.id), + zap.String("addr", backend.Addr()), + zap.Strings("current_values", values), + zap.Strings("group_values", backend.group.values)) + } + } continue } - // If the backend is not in any group, add it to a new group if its label is set. - // In operator deployment, the labels are set dynamically. 
var group *Group switch router.matchType { case MatchAll: @@ -259,33 +326,43 @@ func (router *ScoreBasedRouter) updateGroups() { router.groups = append(router.groups, group) } group = router.groups[0] - case MatchClientCIDR, MatchProxyCIDR: - cidrs := backend.Cidr() - if len(cidrs) == 0 { + case MatchClientCIDR, MatchProxyCIDR, MatchPort: + var values []string + switch router.matchType { + case MatchClientCIDR, MatchProxyCIDR: + values = backend.Cidr() + case MatchPort: + port := backend.TiProxyPort() + if port != "" { + values = []string{matchPortValue(backend.ClusterName(), port)} + } + } + if len(values) == 0 { break } for _, g := range router.groups { - if g.Intersect(cidrs) { + if g.Intersect(values) { group = g break } } if group == nil { - g, err := NewGroup(cidrs, router.bpCreator, router.matchType, router.logger) + g, err := NewGroup(values, router.bpCreator, router.matchType, router.logger) if err == nil { group = g router.groups = append(router.groups, group) } - // maybe too many logs, ignore the error now } } - if group != nil { - group.AddBackend(backend.addr, backend) + if group == nil { + continue } + group.AddBackend(backend.id, backend) } for _, group := range router.groups { group.RefreshCidr() } + router.rebuildPortConflictDetector() } func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { @@ -295,9 +372,19 @@ func (router *ScoreBasedRouter) rebalanceLoop(ctx context.Context) { case <-ctx.Done(): ticker.Stop() return - case healthResults := <-router.healthCh: + case healthResults, ok := <-router.healthCh: + if !ok { + router.logger.Warn("health channel is closed, stop watching channel") + router.healthCh = nil + continue + } router.updateBackendHealth(healthResults) - case cfg := <-router.cfgCh: + case cfg, ok := <-router.cfgCh: + if !ok { + router.logger.Warn("config channel is closed, stop watching channel") + router.cfgCh = nil + continue + } router.setConfig(cfg) case <-ticker.C: router.rebalance(ctx) diff --git a/pkg/balance/router/router_score_test.go b/pkg/balance/router/router_score_test.go index 1891f5eab..10136e750 100644 --- a/pkg/balance/router/router_score_test.go +++ b/pkg/balance/router/router_score_test.go @@ -8,6 +8,7 @@ import ( "math" "math/rand" "reflect" + "slices" "strconv" "testing" "time" @@ -71,6 +72,9 @@ func (tester *routerTester) addBackends(num int) { tester.backendID++ addr := strconv.Itoa(tester.backendID) tester.backends[addr] = &observer.BackendHealth{ + BackendInfo: observer.BackendInfo{ + Addr: addr, + }, Healthy: true, SupportRedirection: true, } @@ -114,6 +118,9 @@ func (tester *routerTester) updateBackendStatusByAddr(addr string, healthy bool) health.Healthy = healthy } else { tester.backends[addr] = &observer.BackendHealth{ + BackendInfo: observer.BackendInfo{ + Addr: addr, + }, SupportRedirection: true, Healthy: healthy, } @@ -136,7 +143,11 @@ func (tester *routerTester) getBackendByIndex(index int) *backendWrapper { } func (tester *routerTester) simpleRoute(conn RedirectableConn) BackendInst { - selector := tester.router.GetBackendSelector(ClientInfo{}) + return tester.route(conn, ClientInfo{}) +} + +func (tester *routerTester) route(conn RedirectableConn, ci ClientInfo) BackendInst { + selector := tester.router.GetBackendSelector(ci) backend, err := selector.Next() if err != ErrNoBackend { require.NoError(tester.t, err) @@ -159,11 +170,11 @@ func (tester *routerTester) closeConnections(num int, redirecting bool) { conns := make(map[uint64]*mockRedirectableConn, num) for id, conn := range tester.conns { if 
redirecting { - if len(conn.GetRedirectingAddr()) == 0 { + if len(conn.GetRedirectingBackendID()) == 0 { continue } } else { - if len(conn.GetRedirectingAddr()) > 0 { + if len(conn.GetRedirectingBackendID()) > 0 { continue } } @@ -173,7 +184,7 @@ func (tester *routerTester) closeConnections(num int, redirecting bool) { } } for _, conn := range conns { - err := tester.router.groups[0].OnConnClosed(conn.from.Addr(), conn.GetRedirectingAddr(), conn) + err := tester.router.groups[0].OnConnClosed(conn.from.ID(), conn.GetRedirectingBackendID(), conn) require.NoError(tester.t, err) delete(tester.conns, conn.connID) } @@ -191,7 +202,7 @@ func (tester *routerTester) rebalance(num int) { func (tester *routerTester) redirectFinish(num int, succeed bool) { i := 0 for _, conn := range tester.conns { - if len(conn.GetRedirectingAddr()) == 0 { + if len(conn.GetRedirectingBackendID()) == 0 { continue } @@ -199,11 +210,11 @@ func (tester *routerTester) redirectFinish(num int, succeed bool) { prevCount, err := readMigrateCounter(from.Addr(), to.Addr(), succeed) require.NoError(tester.t, err) if succeed { - err = tester.router.groups[0].OnRedirectSucceed(from.Addr(), to.Addr(), conn) + err = tester.router.groups[0].OnRedirectSucceed(from.ID(), to.ID(), conn) require.NoError(tester.t, err) conn.redirectSucceed() } else { - err = tester.router.groups[0].OnRedirectFail(from.Addr(), to.Addr(), conn) + err = tester.router.groups[0].OnRedirectFail(from.ID(), to.ID(), conn) require.NoError(tester.t, err) conn.redirectFail() } @@ -239,7 +250,7 @@ func (tester *routerTester) checkBalanced() { func (tester *routerTester) checkRedirectingNum(num int) { redirectingNum := 0 for _, conn := range tester.conns { - if len(conn.GetRedirectingAddr()) > 0 { + if len(conn.GetRedirectingBackendID()) > 0 { redirectingNum++ } } @@ -626,13 +637,13 @@ func TestConcurrency(t *testing.T) { require.NoError(t, err) selector.Finish(conn, true) conn.from = backend - } else if len(conn.GetRedirectingAddr()) > 0 { + } else if len(conn.GetRedirectingBackendID()) > 0 { // redirecting, 70% success, 20% fail, 10% close i := rand.Intn(10) - from, to := conn.getAddr() + from, to := conn.getBackendIDs() var err error if i < 1 { - err = router.groups[0].OnConnClosed(from, conn.GetRedirectingAddr(), conn) + err = router.groups[0].OnConnClosed(from, conn.GetRedirectingBackendID(), conn) conn = nil } else if i < 3 { conn.redirectFail() @@ -647,8 +658,8 @@ func TestConcurrency(t *testing.T) { i := rand.Intn(10) if i < 2 { // The balancer may happen to redirect it concurrently - that's exactly the race this test exercises. - from, _ := conn.getAddr() - err := router.groups[0].OnConnClosed(from, conn.GetRedirectingAddr(), conn) + from, _ := conn.getBackendIDs() + err := router.groups[0].OnConnClosed(from, conn.GetRedirectingBackendID(), conn) require.NoError(t, err) conn = nil } @@ -730,10 +741,16 @@ func TestGetServerVersion(t *testing.T) { t.Cleanup(rt.Close) backends := map[string]*observer.BackendHealth{ "0": { + BackendInfo: observer.BackendInfo{ + Addr: "0", + }, Healthy: true, ServerVersion: "1.0", }, "1": { + BackendInfo: observer.BackendInfo{ + Addr: "1", + }, Healthy: true, ServerVersion: "2.0", }, @@ -794,8 +811,8 @@ func TestUpdateBackendHealth(t *testing.T) { // Test that the locality of some backends is changed.
tester.updateBackendLocalityByAddr(tester.getBackendByIndex(0).Addr(), false) tester.updateBackendLocalityByAddr(tester.getBackendByIndex(1).Addr(), true) - require.Equal(t, false, tester.router.backends[tester.getBackendByIndex(0).Addr()].Local()) - require.Equal(t, true, tester.router.backends[tester.getBackendByIndex(1).Addr()].Local()) + require.Equal(t, false, tester.router.backends[tester.getBackendByIndex(0).ID()].Local()) + require.Equal(t, true, tester.router.backends[tester.getBackendByIndex(1).ID()].Local()) // Test some backends are not in the list anymore. tester.removeBackends(1) tester.checkBackendNum(2) @@ -840,6 +857,55 @@ func TestWatchConfig(t *testing.T) { }, 3*time.Second, 10*time.Millisecond) } +func TestChannelClosed(t *testing.T) { + tests := []struct { + name string + closeChannel func(cfgCh chan *config.Config, bo *mockBackendObserver) + }{ + { + name: "config", + closeChannel: func(cfgCh chan *config.Config, _ *mockBackendObserver) { + close(cfgCh) + }, + }, + { + name: "health", + closeChannel: func(_ chan *config.Config, bo *mockBackendObserver) { + bo.Close() + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + router := NewScoreBasedRouter(lg) + cfgCh := make(chan *config.Config) + cfg := &config.Config{} + cfgGetter := newMockConfigGetter(cfg) + p := &mockBalancePolicy{} + bpCreator := func(lg *zap.Logger) policy.BalancePolicy { + p.Init(cfg) + return p + } + bo := newMockBackendObserver() + router.Init(context.Background(), bo, bpCreator, cfgGetter, cfgCh) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackend("0", nil) + bo.notify(nil) + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + return len(router.groups) == 1 + }, 3*time.Second, 10*time.Millisecond) + + tt.closeChannel(cfgCh, bo) + time.Sleep(100 * time.Millisecond) + }) + } +} + func TestControlSpeed(t *testing.T) { tests := []struct { balanceCount float64 @@ -978,10 +1044,16 @@ func TestSkipRedirection(t *testing.T) { tester := newRouterTester(t, nil) backends := map[string]*observer.BackendHealth{ "0": { + BackendInfo: observer.BackendInfo{ + Addr: "0", + }, Healthy: true, SupportRedirection: false, }, "1": { + BackendInfo: observer.BackendInfo{ + Addr: "1", + }, Healthy: true, SupportRedirection: true, }, @@ -1109,3 +1181,492 @@ func TestGroupBackends(t *testing.T) { }, 3*time.Second, 10*time.Millisecond, "test %d", i) } } + +func TestGroupBackendsByPort(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + router := NewScoreBasedRouter(lg) + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + tests := []struct { + addr string + labels map[string]string + groupCount int + backendCount int + port string + }{ + { + addr: "0", + labels: nil, + groupCount: 0, + backendCount: 1, + }, + { + addr: "1", + labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + groupCount: 1, + backendCount: 2, + port: "10080", + }, + { + addr: "2", + labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + groupCount: 1, + backendCount: 3, + port: "10080", + }, + { + addr: "3", + labels: map[string]string{config.TiProxyPortLabelName: "10081"}, + groupCount: 2, + backendCount: 4, + port: "10081", + }, + } + + for i, test := 
range tests { + bo.addBackend(test.addr, test.labels) + bo.notify(nil) + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + if len(router.groups) != test.groupCount { + return false + } + if len(router.backends) != test.backendCount { + return false + } + group := router.backends[test.addr].group + if test.port == "" { + return group == nil + } + return group != nil && slices.Equal(group.values, []string{test.port}) + }, 3*time.Second, 10*time.Millisecond, "test %d", i) + } +} + +func TestRouteAndRebalanceByPort(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + bp := &mockBalancePolicy{} + tester := newRouterTester(t, bp) + tester.router.matchType = MatchPort + bp.backendToRoute = func(backends []policy.BackendCtx) policy.BackendCtx { + if len(backends) == 0 { + return nil + } + return backends[0] + } + bp.backendsToBalance = func(backends []policy.BackendCtx) (from policy.BackendCtx, to policy.BackendCtx, balanceCount float64, reason string, logFields []zapcore.Field) { + if len(backends) < 2 { + return nil, nil, 0, "", nil + } + var busiest, idlest policy.BackendCtx + for _, backend := range backends { + if busiest == nil || backend.ConnCount() > busiest.ConnCount() { + busiest = backend + } + if idlest == nil || backend.ConnCount() < idlest.ConnCount() { + idlest = backend + } + } + if busiest == nil || idlest == nil || busiest == idlest { + return nil, nil, 0, "", nil + } + return busiest, idlest, 100, "conn", nil + } + tester.router.cfgGetter = newMockConfigGetter(cfg) + + tester.backends["1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "1", + Labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + }, + } + tester.backends["2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "2", + Labels: map[string]string{config.TiProxyPortLabelName: "10080"}, + }, + } + tester.backends["3"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "3", + Labels: map[string]string{config.TiProxyPortLabelName: "10081"}, + }, + } + tester.notifyHealth() + + for range 10 { + conn := tester.createConn() + backend := tester.route(conn, ClientInfo{ListenerAddr: "127.0.0.1:10080"}) + require.NotNil(t, backend) + conn.from = backend + tester.conns[conn.connID] = conn + } + for _, conn := range tester.conns { + require.Equal(t, "10080", tester.router.backends[conn.from.ID()].TiProxyPort()) + } + + tester.rebalance(10) + redirecting := 0 + for _, conn := range tester.conns { + if conn.to == nil || reflect.ValueOf(conn.to).IsNil() { + continue + } + redirecting++ + require.Equal(t, "10080", tester.router.backends[conn.to.ID()].TiProxyPort()) + require.NotEqual(t, "3", conn.to.Addr()) + } + require.Greater(t, redirecting, 0) +} + +func TestRouteByPortBlocksConflictingClusters(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + router := NewScoreBasedRouter(zap.NewNop()) + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackendWithCluster("a1", "cluster-a", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.addBackendWithCluster("b1", "cluster-b", 
map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.notify(nil) + + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + return len(router.groups) == 2 && router.portConflictDetector != nil + }, 3*time.Second, 10*time.Millisecond) + + selector := router.GetBackendSelector(ClientInfo{ListenerAddr: "127.0.0.1:10080"}) + _, err := selector.Next() + require.Error(t, err) + require.True(t, errors.Is(err, ErrPortConflict)) +} + +func TestRouteByPortRecoversAfterConflictIsRemoved(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + router := NewScoreBasedRouter(zap.NewNop()) + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackendWithCluster("a1", "cluster-a", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.addBackendWithCluster("b1", "cluster-b", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.notify(nil) + + require.Eventually(t, func() bool { + selector := router.GetBackendSelector(ClientInfo{ListenerAddr: "127.0.0.1:10080"}) + _, err := selector.Next() + return errors.Is(err, ErrPortConflict) + }, 3*time.Second, 10*time.Millisecond) + + bo.healthLock.Lock() + delete(bo.healths, "b1") + bo.healthLock.Unlock() + bo.notify(nil) + + require.Eventually(t, func() bool { + selector := router.GetBackendSelector(ClientInfo{ListenerAddr: "127.0.0.1:10080"}) + backend, err := selector.Next() + return err == nil && backend != nil && backend.ID() == "a1" + }, 3*time.Second, 10*time.Millisecond) +} + +func TestKeepExistingPortGroupWhenPortLabelChanges(t *testing.T) { + cfg := &config.Config{ + Balance: config.Balance{ + RoutingRule: config.MatchPortStr, + }, + } + cfgGetter := newMockConfigGetter(cfg) + bo := newMockBackendObserver() + lg, text := logger.CreateLoggerForTest(t) + router := NewScoreBasedRouter(lg) + router.Init(context.Background(), bo, simpleBpCreator, cfgGetter, make(<-chan *config.Config)) + t.Cleanup(bo.Close) + t.Cleanup(router.Close) + + bo.addBackendWithCluster("backend-1", "cluster-a", map[string]string{ + config.TiProxyPortLabelName: "10080", + }) + bo.notify(nil) + + var oldGroup *Group + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + oldGroup = router.backends["backend-1"].group + return oldGroup != nil && slices.Equal(oldGroup.values, []string{"cluster-a:10080"}) + }, 3*time.Second, 10*time.Millisecond) + + conn := newMockRedirectableConn(t, 1) + selector := router.GetBackendSelector(ClientInfo{ListenerAddr: "127.0.0.1:10080"}) + backend, err := selector.Next() + require.NoError(t, err) + selector.Finish(conn, true) + conn.from = backend + + bo.healthLock.Lock() + bo.healths["backend-1"].ClusterName = "cluster-a" + bo.healthLock.Unlock() + bo.setLabels("backend-1", map[string]string{ + config.TiProxyPortLabelName: "10081", + }) + bo.notify(nil) + + require.Eventually(t, func() bool { + router.Lock() + defer router.Unlock() + return router.backends["backend-1"].group == oldGroup + }, 3*time.Second, 10*time.Millisecond) + require.Contains(t, text.String(), "backend routing values changed, keep the existing group until it is removed") + + conn.Lock() + require.Equal(t, oldGroup, conn.receiver) + conn.Unlock() + + oldSelector := router.GetBackendSelector(ClientInfo{ListenerAddr: "127.0.0.1:10080"}) + backend, err = oldSelector.Next() 
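These tests pin down the selector's error contract: a port claimed by two different clusters surfaces `ErrPortConflict`, while an empty or exhausted group surfaces `ErrNoBackend`, and both are matched with `errors.Is`. A sketch of how a caller might branch on the sentinels; `pickBackend`, the error texts, and the port values here are stand-ins for illustration, not the router's real API:

```go
package main

import (
	"errors"
	"fmt"
)

// Stand-ins for the router's sentinel errors.
var (
	ErrNoBackend    = errors.New("no available backend")
	ErrPortConflict = errors.New("port is claimed by multiple clusters")
)

// pickBackend is a hypothetical selector call used only for illustration.
func pickBackend(port string) (string, error) {
	switch port {
	case "10080":
		return "", fmt.Errorf("port %s: %w", port, ErrPortConflict)
	case "10081":
		return "", ErrNoBackend
	default:
		return "tidb-1:4000", nil
	}
}

func main() {
	for _, port := range []string{"10080", "10081", "4000"} {
		backend, err := pickBackend(port)
		switch {
		case errors.Is(err, ErrPortConflict):
			// Refuse the connection: routing is ambiguous until the
			// conflicting cluster releases the port.
			fmt.Println(port, "=> reject:", err)
		case errors.Is(err, ErrNoBackend):
			fmt.Println(port, "=> retry later:", err)
		case err == nil:
			fmt.Println(port, "=> route to", backend)
		}
	}
}
```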
+ require.NoError(t, err) + require.Equal(t, "backend-1", backend.ID()) + + newSelector := router.GetBackendSelector(ClientInfo{ListenerAddr: "127.0.0.1:10081"}) + _, err = newSelector.Next() + require.ErrorIs(t, err, ErrNoBackend) +} + +func TestPortConflictGroupsStayClusterScoped(t *testing.T) { + tester := newRouterTester(t, nil) + tester.router.matchType = MatchPort + tester.backends["a1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-1:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["a2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-2:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["b1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-b-1:4000", + ClusterName: "cluster-b", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["b2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-b-2:4000", + ClusterName: "cluster-b", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.notifyHealth() + + groupA := findGroupByValues(t, tester.router, []string{"cluster-a:10080"}) + groupB := findGroupByValues(t, tester.router, []string{"cluster-b:10080"}) + require.NotSame(t, groupA, groupB) + for _, backend := range groupA.backends { + require.Equal(t, "cluster-a", backend.ClusterName()) + } + for _, backend := range groupB.backends { + require.Equal(t, "cluster-b", backend.ClusterName()) + } +} + +func TestPortConflictBlocksRoutingButAllowsIntraClusterRebalance(t *testing.T) { + bp := &mockBalancePolicy{} + tester := newRouterTester(t, bp) + tester.router.matchType = MatchPort + bp.backendsToBalance = func(backends []policy.BackendCtx) (from policy.BackendCtx, to policy.BackendCtx, balanceCount float64, reason string, logFields []zapcore.Field) { + if len(backends) < 2 { + return nil, nil, 0, "", nil + } + var busiest, idlest policy.BackendCtx + for _, backend := range backends { + if busiest == nil || backend.ConnCount() > busiest.ConnCount() { + busiest = backend + } + if idlest == nil || backend.ConnCount() < idlest.ConnCount() { + idlest = backend + } + } + if busiest == nil || idlest == nil || busiest == idlest { + return nil, nil, 0, "", nil + } + return busiest, idlest, 100, "conn", nil + } + + tester.backends["a1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-1:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["a2"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-a-2:4000", + ClusterName: "cluster-a", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.backends["b1"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared-b-1:4000", + ClusterName: "cluster-b", + Labels: map[string]string{ + config.TiProxyPortLabelName: "10080", + }, + }, + } + tester.notifyHealth() + + selector 
:= tester.router.GetBackendSelector(ClientInfo{ListenerAddr: "127.0.0.1:10080"}) + _, err := selector.Next() + require.Error(t, err) + require.True(t, errors.Is(err, ErrPortConflict)) + + groupA := findGroupByValues(t, tester.router, []string{"cluster-a:10080"}) + backendA1 := tester.router.backends["a1"] + for range 6 { + conn := tester.createConn() + groupA.onCreateConn(backendA1, conn, true) + conn.from = backendA1 + tester.conns[conn.connID] = conn + } + + groupA.lastRedirectTime = time.Time{} + groupA.Balance(context.Background()) + + redirecting := 0 + for _, conn := range tester.conns { + if conn.to == nil || reflect.ValueOf(conn.to).IsNil() { + continue + } + redirecting++ + require.Equal(t, "cluster-a", tester.router.backends[conn.to.ID()].ClusterName()) + require.Equal(t, "a2", conn.to.ID()) + } + require.Greater(t, redirecting, 0) +} + +func TestRouteBackendsWithSameAddrDifferentIDs(t *testing.T) { + tester := newRouterTester(t, nil) + tester.router.matchType = MatchAll + tester.backends["cluster-a/shared:4000"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared:4000", + ClusterName: "cluster-a", + }, + } + tester.backends["cluster-b/shared:4000"] = &observer.BackendHealth{ + Healthy: true, + SupportRedirection: true, + BackendInfo: observer.BackendInfo{ + Addr: "shared:4000", + ClusterName: "cluster-b", + }, + } + tester.notifyHealth() + + selector := tester.router.GetBackendSelector(ClientInfo{}) + first, err := selector.Next() + require.NoError(t, err) + second, err := selector.Next() + require.NoError(t, err) + + require.Equal(t, "shared:4000", first.Addr()) + require.Equal(t, "shared:4000", second.Addr()) + require.NotEqual(t, first.ID(), second.ID()) +} + +func findGroupByValues(t *testing.T, router *ScoreBasedRouter, values []string) *Group { + t.Helper() + router.Lock() + defer router.Unlock() + for _, group := range router.groups { + if group.matchType == MatchPort { + if slices.Equal(group.values, values) { + return group + } + continue + } + if group.EqualValues(values) { + return group + } + } + require.FailNow(t, "group not found", "values=%v", values) + return nil +} diff --git a/pkg/balance/router/router_static.go b/pkg/balance/router/router_static.go index b4230bd2f..00385fdd7 100644 --- a/pkg/balance/router/router_static.go +++ b/pkg/balance/router/router_static.go @@ -26,7 +26,7 @@ func (r *StaticRouter) GetBackendSelector(_ ClientInfo) BackendSelector { for _, backend := range r.backends { found := false for _, e := range excluded { - if e.Addr() == backend.Addr() { + if e.ID() == backend.ID() { found = true break } @@ -74,7 +74,7 @@ func (r *StaticRouter) OnRedirectFail(from, to string, conn RedirectableConn) er return nil } -func (r *StaticRouter) OnConnClosed(addr, redirectingAddr string, conn RedirectableConn) error { +func (r *StaticRouter) OnConnClosed(backendID, redirectingBackendID string, conn RedirectableConn) error { r.cnt-- return nil } @@ -82,6 +82,7 @@ func (r *StaticRouter) OnConnClosed(addr, redirectingAddr string, conn Redirecta type StaticBackend struct { addr string keyspace string + cluster string healthy atomic.Bool } @@ -97,6 +98,10 @@ func (b *StaticBackend) Addr() string { return b.addr } +func (b *StaticBackend) ID() string { + return b.addr +} + func (b *StaticBackend) Healthy() bool { return b.healthy.Load() } @@ -116,3 +121,7 @@ func (b *StaticBackend) Keyspace() string { func (b *StaticBackend) SetKeyspace(k string) { b.keyspace = k } + +func (b 
*StaticBackend) ClusterName() string { + return b.cluster +} diff --git a/pkg/manager/backendcluster/backend_id.go b/pkg/manager/backendcluster/backend_id.go new file mode 100644 index 000000000..b5177d59e --- /dev/null +++ b/pkg/manager/backendcluster/backend_id.go @@ -0,0 +1,12 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import "fmt" + +// backendID returns the opaque identity key for one backend in one backend cluster. +// It is only used as an in-memory map key and must not be parsed or used as a network address. +func backendID(clusterName, addr string) string { + return fmt.Sprintf("%s/%s", clusterName, addr) +} diff --git a/pkg/manager/backendcluster/manager.go b/pkg/manager/backendcluster/manager.go new file mode 100644 index 000000000..a5ee14b4f --- /dev/null +++ b/pkg/manager/backendcluster/manager.go @@ -0,0 +1,385 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "context" + "crypto/tls" + "maps" + "net" + "slices" + "strings" + "sync" + + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/lib/util/errors" + "github.com/pingcap/tiproxy/pkg/balance/metricsreader" + "github.com/pingcap/tiproxy/pkg/manager/infosync" + "github.com/pingcap/tiproxy/pkg/util/etcd" + httputil "github.com/pingcap/tiproxy/pkg/util/http" + "github.com/pingcap/tiproxy/pkg/util/netutil" + "github.com/pingcap/tiproxy/pkg/util/waitgroup" + clientv3 "go.etcd.io/etcd/client/v3" + "go.uber.org/zap" +) + +// Cluster is the cluster-scoped container for one backend PD cluster. +type Cluster struct { + cfg config.BackendCluster + etcdCli *clientv3.Client + infoSyncer *infosync.InfoSyncer + metrics *metricsreader.ClusterReader + httpCli *httputil.Client + dialer *netutil.DNSDialer +} + +func (c *Cluster) Config() config.BackendCluster { + return c.cfg +} + +func (c *Cluster) EtcdClient() *clientv3.Client { + return c.etcdCli +} + +func (c *Cluster) GetTiDBTopology(ctx context.Context) (map[string]*infosync.TiDBTopologyInfo, error) { + return c.infoSyncer.GetTiDBTopology(ctx) +} + +func (c *Cluster) GetPromInfo(ctx context.Context) (*infosync.PrometheusInfo, error) { + return c.infoSyncer.GetPromInfo(ctx) +} + +func (c *Cluster) HTTPClient() *httputil.Client { + return c.httpCli +} + +func (c *Cluster) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + return c.dialer.DialContext(ctx, network, addr) +} + +type Manager struct { + lg *zap.Logger + clusterTLS func() *tls.Config + cfgGetter config.ConfigGetter + + wg waitgroup.WaitGroup + cancel context.CancelFunc + metrics *MetricsQuerier + network *NetworkRouter + + mu struct { + sync.RWMutex + clusters map[string]*Cluster + } +} + +func NewManager(lg *zap.Logger, clusterTLS func() *tls.Config) *Manager { + mgr := &Manager{ + lg: lg, + clusterTLS: clusterTLS, + } + mgr.mu.clusters = make(map[string]*Cluster) + mgr.metrics = NewMetricsQuerier(mgr) + mgr.network = NewNetworkRouter(mgr, clusterTLS) + return mgr +} + +func (m *Manager) Start(ctx context.Context, cfgGetter config.ConfigGetter, cfgCh <-chan *config.Config) error { + m.cfgGetter = cfgGetter + if err := m.syncClusters(ctx, cfgGetter.GetConfig()); err != nil { + return err + } + childCtx, cancel := context.WithCancel(ctx) + m.cancel = cancel + m.wg.Run(func() { + m.watchConfig(childCtx, cfgCh) + }, m.lg) + return nil +} + +func (m *Manager) watchConfig(ctx context.Context, cfgCh <-chan *config.Config) { + if cfgCh == nil { + return + } + for { 
+ select { + case <-ctx.Done(): + return + case cfg, ok := <-cfgCh: + if !ok { + m.lg.Warn("config channel is closed, stop watching backend clusters") + return + } + if cfg == nil { + continue + } + if err := m.syncClusters(ctx, cfg); err != nil { + m.lg.Error("sync backend clusters failed", zap.Error(err)) + } + } + } +} + +func (m *Manager) syncClusters(ctx context.Context, cfg *config.Config) error { + if cfg == nil { + return nil + } + desiredClusters := cfg.GetBackendClusters() + desiredMap := make(map[string]config.BackendCluster, len(desiredClusters)) + for _, cluster := range desiredClusters { + desiredMap[cluster.Name] = cluster + } + + m.mu.Lock() + oldClusters := m.mu.clusters + newClusters := make(map[string]*Cluster, len(desiredClusters)) + closeList := make([]*Cluster, 0, len(oldClusters)) + + for _, clusterCfg := range desiredClusters { + oldCluster, ok := oldClusters[clusterCfg.Name] + if ok && clusterReusable(oldCluster, clusterCfg) { + newClusters[clusterCfg.Name] = oldCluster + delete(oldClusters, clusterCfg.Name) + continue + } + + cluster, err := m.buildCluster(ctx, cfg, clusterCfg) + if err != nil { + if ok { + m.lg.Warn("failed to update backend cluster, keep the old one", + zap.String("cluster", clusterCfg.Name), zap.Error(err)) + newClusters[clusterCfg.Name] = oldCluster + delete(oldClusters, clusterCfg.Name) + continue + } + m.lg.Error("failed to add backend cluster", + zap.String("cluster", clusterCfg.Name), zap.Error(err)) + continue + } + newClusters[clusterCfg.Name] = cluster + if ok { + closeList = append(closeList, oldCluster) + delete(oldClusters, clusterCfg.Name) + m.lg.Info("updated backend cluster", + zap.String("cluster", clusterCfg.Name), zap.String("pd_addrs", clusterCfg.PDAddrs)) + } else { + m.lg.Info("added backend cluster", + zap.String("cluster", clusterCfg.Name), zap.String("pd_addrs", clusterCfg.PDAddrs)) + } + } + + for name, cluster := range oldClusters { + if _, ok := desiredMap[name]; ok { + continue + } + closeList = append(closeList, cluster) + m.lg.Info("removed backend cluster", + zap.String("cluster", name), zap.String("pd_addrs", cluster.cfg.PDAddrs)) + } + + m.mu.clusters = newClusters + m.mu.Unlock() + + for _, cluster := range closeList { + if err := m.closeCluster(cluster); err != nil { + m.lg.Warn("close backend cluster failed", + zap.String("cluster", cluster.cfg.Name), zap.Error(err)) + } + } + return nil +} + +func normalizeCluster(cluster config.BackendCluster) config.BackendCluster { + cluster.Name = strings.TrimSpace(cluster.Name) + cluster.PDAddrs = strings.TrimSpace(cluster.PDAddrs) + return cluster +} + +func clusterReusable(cluster *Cluster, cfg config.BackendCluster) bool { + if cluster == nil { + return false + } + left := normalizeCluster(cluster.cfg) + right := normalizeCluster(cfg) + return left.Name == right.Name && + left.PDAddrs == right.PDAddrs && + slices.Equal(left.NSServers, right.NSServers) +} + +func (m *Manager) buildCluster(ctx context.Context, cfg *config.Config, clusterCfg config.BackendCluster) (*Cluster, error) { + clusterCfg = normalizeCluster(clusterCfg) + nameServers, err := config.ParseNSServers(clusterCfg.NSServers) + if err != nil { + return nil, err + } + dialer := netutil.NewDNSDialer(nameServers) + httpCli := httputil.NewHTTPClientWithDialContext(m.clusterTLS, dialer.DialContext) + + etcdCli, err := etcd.InitEtcdClientWithAddrsAndDialer( + m.lg.With(zap.String("cluster", clusterCfg.Name)).Named("etcd"), + clusterCfg.PDAddrs, + m.clusterTLS(), + dialer, + ) + if err != nil { + return nil, err 
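`syncClusters` follows a reuse/rebuild/close reconciliation shape: unchanged clusters keep their live runtimes, changed ones are rebuilt (keeping the old runtime if the rebuild fails), and anything no longer desired is closed only after the lock is released. A condensed sketch of that shape with toy types, not the actual implementation:

```go
package main

import "fmt"

type clusterCfg struct{ name, pdAddrs string }

type cluster struct{ cfg clusterCfg }

// reconcile mirrors the syncClusters shape: reuse unchanged clusters,
// rebuild changed ones (keeping the old runtime if the rebuild fails),
// and collect everything that is no longer desired for closing.
func reconcile(old map[string]*cluster, desired []clusterCfg,
	build func(clusterCfg) (*cluster, error)) (map[string]*cluster, []*cluster) {
	next := make(map[string]*cluster, len(desired))
	var toClose []*cluster
	for _, cfg := range desired {
		if prev, ok := old[cfg.name]; ok && prev.cfg == cfg {
			next[cfg.name] = prev // unchanged: reuse the live runtime
			delete(old, cfg.name)
			continue
		}
		c, err := build(cfg)
		if err != nil {
			if prev, ok := old[cfg.name]; ok {
				next[cfg.name] = prev // rebuild failed: keep the old one
				delete(old, cfg.name)
			}
			continue
		}
		if prev, ok := old[cfg.name]; ok {
			toClose = append(toClose, prev) // replaced: close the old runtime
			delete(old, cfg.name)
		}
		next[cfg.name] = c
	}
	for _, prev := range old {
		toClose = append(toClose, prev) // removed from the config
	}
	return next, toClose
}

func main() {
	old := map[string]*cluster{
		"a": {cfg: clusterCfg{"a", "pd-a:2379"}},
		"b": {cfg: clusterCfg{"b", "pd-b:2379"}},
	}
	desired := []clusterCfg{{"a", "pd-a:2379"}, {"c", "pd-c:2379"}}
	next, toClose := reconcile(old, desired, func(cfg clusterCfg) (*cluster, error) {
		return &cluster{cfg: cfg}, nil
	})
	fmt.Println(len(next), "live,", len(toClose), "to close") // 2 live, 1 to close
}
```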
+ } + + infoSyncer := infosync.NewInfoSyncer(m.lg.With(zap.String("cluster", clusterCfg.Name)).Named("infosync"), etcdCli) + if err := infoSyncer.Init(ctx, cfg); err != nil { + if closeErr := etcdCli.Close(); closeErr != nil { + m.lg.Warn("close cluster etcd client failed after infosync init error", + zap.String("cluster", clusterCfg.Name), zap.Error(closeErr)) + } + return nil, err + } + + cluster := &Cluster{ + cfg: clusterCfg, + etcdCli: etcdCli, + infoSyncer: infoSyncer, + httpCli: httpCli, + dialer: dialer, + } + cluster.metrics = metricsreader.NewClusterReader( + m.lg.With(zap.String("cluster", clusterCfg.Name)).Named("metrics"), + clusterCfg.Name, + cluster, + cluster, + httpCli, + etcdCli, + config.NewDefaultHealthCheckConfig(), + m.cfgGetter, + ) + for key, query := range m.metrics.snapshot() { + cluster.metrics.AddQueryExpr(key, query.expr, query.rule) + } + if err := cluster.metrics.Start(ctx); err != nil { + _ = infoSyncer.Close() + if closeErr := etcdCli.Close(); closeErr != nil { + m.lg.Warn("close cluster etcd client failed after metrics init error", + zap.String("cluster", clusterCfg.Name), zap.Error(closeErr)) + } + return nil, err + } + + return cluster, nil +} + +func (m *Manager) closeCluster(cluster *Cluster) error { + if cluster == nil { + return nil + } + errs := make([]error, 0, 2) + if cluster.metrics != nil { + cluster.metrics.Close() + } + if cluster.infoSyncer != nil { + errs = append(errs, cluster.infoSyncer.Close()) + } + if cluster.etcdCli != nil { + errs = append(errs, cluster.etcdCli.Close()) + } + return errors.Collect(errors.New("close backend cluster"), errs...) +} + +func (m *Manager) Snapshot() map[string]*Cluster { + m.mu.RLock() + snapshot := make(map[string]*Cluster, len(m.mu.clusters)) + maps.Copy(snapshot, m.mu.clusters) + m.mu.RUnlock() + return snapshot +} + +func (m *Manager) HasBackendClusters() bool { + m.mu.RLock() + defer m.mu.RUnlock() + return len(m.mu.clusters) > 0 +} + +func (m *Manager) MetricsQuerier() *MetricsQuerier { + return m.metrics +} + +func (m *Manager) NetworkRouter() *NetworkRouter { + return m.network +} + +// PrimaryCluster returns the only configured cluster when the cluster count is exactly one. +// It exists for features that are only well-defined in the single-cluster case, such as VIP. 
+func (m *Manager) PrimaryCluster() *Cluster { + m.mu.RLock() + defer m.mu.RUnlock() + if len(m.mu.clusters) != 1 { + return nil + } + for _, cluster := range m.mu.clusters { + return cluster + } + return nil +} + +func (m *Manager) PreClose() { + for _, cluster := range m.Snapshot() { + if cluster == nil || cluster.metrics == nil { + continue + } + cluster.metrics.PreClose() + } +} + +func (m *Manager) GetTiDBTopology(ctx context.Context) (map[string]*infosync.TiDBTopologyInfo, error) { + clusters := m.Snapshot() + merged := make(map[string]*infosync.TiDBTopologyInfo, 128) + errs := make([]error, 0, len(clusters)) + for clusterName, cluster := range clusters { + infos, err := cluster.GetTiDBTopology(ctx) + if err != nil { + errs = append(errs, err) + continue + } + for _, info := range infos { + cloned := *info + backendID := backendID(clusterName, cloned.Addr) + if oldInfo, ok := merged[backendID]; ok { + m.lg.Warn("duplicate backend in cluster, keep the first one", + zap.String("backend_id", backendID), + zap.String("addr", cloned.Addr), + zap.String("cluster", clusterName), + zap.String("first_cluster", oldInfo.ClusterName)) + continue + } + cloned.Labels = info.Labels + cloned.ClusterName = clusterName + merged[backendID] = &cloned + } + } + if len(merged) == 0 && len(errs) > 0 { + return nil, errors.Collect(errors.New("fetch from backend clusters"), errs...) + } + return merged, nil +} + +func (m *Manager) Close() error { + if m.cancel != nil { + m.cancel() + } + m.wg.Wait() + + m.mu.Lock() + clusters := m.mu.clusters + m.mu.clusters = make(map[string]*Cluster) + m.mu.Unlock() + + errs := make([]error, 0, len(clusters)) + for _, cluster := range clusters { + if err := m.closeCluster(cluster); err != nil { + errs = append(errs, err) + } + } + if len(errs) == 0 { + return nil + } + return errors.Collect(errors.New("close backend cluster manager"), errs...) +} diff --git a/pkg/manager/backendcluster/manager_test.go b/pkg/manager/backendcluster/manager_test.go new file mode 100644 index 000000000..2b4d21b41 --- /dev/null +++ b/pkg/manager/backendcluster/manager_test.go @@ -0,0 +1,361 @@ +// Copyright 2026 PingCAP, Inc. 
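The merged topology is keyed by `backendID(clusterName, addr)`, so two clusters that advertise the same SQL address remain distinct entries, and each entry is stamped with its owning cluster. A small illustration of the keying; `topoInfo` is a trimmed stand-in for `TiDBTopologyInfo`:

```go
package main

import "fmt"

// backendID mirrors the "<cluster>/<addr>" identity key from backend_id.go.
// It is an opaque map key: consumers must use the Addr field, never the key,
// when they need a dialable address.
func backendID(clusterName, addr string) string {
	return fmt.Sprintf("%s/%s", clusterName, addr)
}

type topoInfo struct {
	Addr        string
	ClusterName string
}

func main() {
	perCluster := map[string][]string{
		"cluster-a": {"shared.tidb:4000"},
		"cluster-b": {"shared.tidb:4000"}, // same address, different cluster
	}
	merged := make(map[string]topoInfo)
	for cluster, addrs := range perCluster {
		for _, addr := range addrs {
			// Stamp the owning cluster so routing stays cluster-scoped.
			merged[backendID(cluster, addr)] = topoInfo{Addr: addr, ClusterName: cluster}
		}
	}
	for id, info := range merged {
		fmt.Printf("%s -> dial %s (cluster %s)\n", id, info.Addr, info.ClusterName)
	}
	// Both entries survive: cluster-a/shared.tidb:4000 and cluster-b/shared.tidb:4000.
}
```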
+// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "context" + "crypto/tls" + "encoding/json" + "net" + "path" + "sync" + "testing" + "time" + + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/lib/util/logger" + "github.com/pingcap/tiproxy/pkg/manager/infosync" + "github.com/pingcap/tiproxy/pkg/testkit" + "github.com/pingcap/tiproxy/pkg/util/etcd" + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/server/v3/embed" + "go.uber.org/zap" +) + +const ( + testTiDBTopologyPath = "/topology/tidb" + testInfoSuffix = "info" + testTTLSuffix = "ttl" +) + +func nilClusterTLS() *tls.Config { + return nil +} + +func TestManagerFetchesAllClusters(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + {Name: "cluster-b", PDAddrs: clusterB.addr}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, nil)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 2 { + return false + } + return topology[backendID("cluster-a", "10.0.0.1:4000")].ClusterName == "cluster-a" && + topology[backendID("cluster-b", "10.0.0.2:4000")].ClusterName == "cluster-b" + }, 5*time.Second, 100*time.Millisecond) +} + +func TestManagerDynamicClusterUpdate(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.PDAddrs = "" + cfg.Proxy.BackendClusters = nil + cfgGetter := newManagerTestConfigGetter(cfg) + cfgCh := make(chan *config.Config, 4) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, cfgCh)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + topology, err := mgr.GetTiDBTopology(context.Background()) + require.NoError(t, err) + require.Empty(t, topology) + + nextCfg := cfg.Clone() + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + } + cfgGetter.setConfig(nextCfg) + cfgCh <- nextCfg.Clone() + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 1 { + return false + } + info, ok := topology[backendID("cluster-a", "10.0.0.1:4000")] + return ok && info.ClusterName == "cluster-a" + }, 5*time.Second, 100*time.Millisecond) + + nextCfg = cfg.Clone() + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-b", PDAddrs: clusterB.addr}, + } + cfgGetter.setConfig(nextCfg) + cfgCh <- 
nextCfg.Clone() + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 1 { + return false + } + info, ok := topology[backendID("cluster-b", "10.0.0.2:4000")] + return ok && info.ClusterName == "cluster-b" + }, 5*time.Second, 100*time.Millisecond) +} + +func TestManagerUsesClusterNameServersForPD(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + dnsA := testkit.StartDNSServer(t, map[string][]string{"pd-a.test": {"127.0.0.1"}}) + dnsB := testkit.StartDNSServer(t, map[string][]string{"pd-b.test": {"127.0.0.1"}}) + _, portA, err := net.SplitHostPort(clusterA.addr) + require.NoError(t, err) + _, portB, err := net.SplitHostPort(clusterB.addr) + require.NoError(t, err) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: net.JoinHostPort("pd-a.test", portA), NSServers: []string{dnsA.Addr()}}, + {Name: "cluster-b", PDAddrs: net.JoinHostPort("pd-b.test", portB), NSServers: []string{dnsB.Addr()}}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, nil)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 2 { + return false + } + return topology[backendID("cluster-a", "10.0.0.1:4000")].ClusterName == "cluster-a" && + topology[backendID("cluster-b", "10.0.0.2:4000")].ClusterName == "cluster-b" + }, 5*time.Second, 100*time.Millisecond) + require.Greater(t, dnsA.QueryCount("pd-a.test"), 0) + require.Greater(t, dnsB.QueryCount("pd-b.test"), 0) +} + +func TestManagerKeepsOldClusterWhenUpdateFails(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "10.0.0.2:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, nil)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 1 { + return false + } + _, ok := topology[backendID("cluster-a", "10.0.0.1:4000")] + return ok + }, 5*time.Second, 100*time.Millisecond) + + originalCluster := mgr.Snapshot()["cluster-a"] + require.NotNil(t, originalCluster) + + nextCfg := cfg.Clone() + nextCfg.Proxy.Addr = "invalid" + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterB.addr}, + } + require.NoError(t, mgr.syncClusters(context.Background(), nextCfg)) + + 
currentCluster := mgr.Snapshot()["cluster-a"] + require.Same(t, originalCluster, currentCluster) + + topology, err := mgr.GetTiDBTopology(context.Background()) + require.NoError(t, err) + require.Contains(t, topology, backendID("cluster-a", "10.0.0.1:4000")) + require.NotContains(t, topology, backendID("cluster-a", "10.0.0.2:4000")) +} +func TestManagerUpdatesClusterNameServersForPD(t *testing.T) { + cluster := newManagerTestEtcdCluster(t) + t.Cleanup(func() { cluster.close(t) }) + + cluster.putTopology(t, "10.0.0.1:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + + dnsA := testkit.StartDNSServer(t, map[string][]string{"pd.test": {"127.0.0.1"}}) + dnsB := testkit.StartDNSServer(t, map[string][]string{"pd.test": {"127.0.0.1"}}) + _, port, err := net.SplitHostPort(cluster.addr) + require.NoError(t, err) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: net.JoinHostPort("pd.test", port), NSServers: []string{dnsA.Addr()}}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + cfgCh := make(chan *config.Config, 1) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, cfgCh)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + return dnsA.QueryCount("pd.test") > 0 + }, 5*time.Second, 100*time.Millisecond) + + originalCluster := mgr.Snapshot()["cluster-a"] + require.NotNil(t, originalCluster) + + nextCfg := cfg.Clone() + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: net.JoinHostPort("pd.test", port), NSServers: []string{dnsB.Addr()}}, + } + cfgGetter.setConfig(nextCfg) + cfgCh <- nextCfg.Clone() + + require.Eventually(t, func() bool { + currentCluster := mgr.Snapshot()["cluster-a"] + return currentCluster != nil && + currentCluster != originalCluster && + dnsB.QueryCount("pd.test") > 0 + }, 5*time.Second, 100*time.Millisecond) +} +func TestManagerKeepsDuplicateBackendAddrsAcrossClusters(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + clusterA.putTopology(t, "shared.tidb:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.1", StatusPort: 10080}) + clusterB.putTopology(t, "shared.tidb:4000", &infosync.TiDBTopologyInfo{IP: "10.0.0.2", StatusPort: 10080}) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + {Name: "cluster-b", PDAddrs: clusterB.addr}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + + mgr := NewManager(zapLoggerForTest(t), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, nil)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + require.Eventually(t, func() bool { + topology, err := mgr.GetTiDBTopology(context.Background()) + if err != nil || len(topology) != 2 { + return false + } + infoA, okA := topology[backendID("cluster-a", "shared.tidb:4000")] + infoB, okB := topology[backendID("cluster-b", "shared.tidb:4000")] + return okA && okB && infoA.Addr == "shared.tidb:4000" && infoB.Addr == "shared.tidb:4000" + }, 5*time.Second, 100*time.Millisecond) +} + +type managerTestConfigGetter struct { + mu sync.RWMutex + cfg *config.Config +} + +func newManagerTestConfigGetter(cfg *config.Config) *managerTestConfigGetter { + return &managerTestConfigGetter{cfg: cfg} +} + +func (g 
*managerTestConfigGetter) GetConfig() *config.Config { + g.mu.RLock() + defer g.mu.RUnlock() + return g.cfg +} + +func (g *managerTestConfigGetter) setConfig(cfg *config.Config) { + g.mu.Lock() + g.cfg = cfg + g.mu.Unlock() +} + +type managerTestEtcdCluster struct { + etcd *embed.Etcd + client *clientv3.Client + kv clientv3.KV + addr string +} + +func newManagerTestEtcdCluster(t *testing.T) *managerTestEtcdCluster { + lg, _ := logger.CreateLoggerForTest(t) + etcdSrv, err := etcd.CreateEtcdServer("127.0.0.1:0", t.TempDir(), lg) + require.NoError(t, err) + addr := etcdSrv.Clients[0].Addr().String() + cli, err := etcd.InitEtcdClientWithAddrs(lg, addr, nil) + require.NoError(t, err) + return &managerTestEtcdCluster{ + etcd: etcdSrv, + client: cli, + kv: clientv3.NewKV(cli), + addr: addr, + } +} + +func (tec *managerTestEtcdCluster) close(t *testing.T) { + require.NoError(t, tec.client.Close()) + tec.etcd.Close() +} + +func (tec *managerTestEtcdCluster) putTopology(t *testing.T, sqlAddr string, info *infosync.TiDBTopologyInfo) { + data, err := json.Marshal(info) + require.NoError(t, err) + _, err = tec.kv.Put(context.Background(), path.Join(testTiDBTopologyPath, sqlAddr, testInfoSuffix), string(data)) + require.NoError(t, err) + _, err = tec.kv.Put(context.Background(), path.Join(testTiDBTopologyPath, sqlAddr, testTTLSuffix), "1") + require.NoError(t, err) +} + +func newManagerTestConfig() *config.Config { + cfg := config.NewConfig() + cfg.Proxy.Addr = "127.0.0.1:6000" + cfg.API.Addr = "127.0.0.1:3080" + cfg.Proxy.PDAddrs = "" + cfg.Proxy.BackendClusters = nil + return cfg +} + +func zapLoggerForTest(t *testing.T) *zap.Logger { + lg, _ := logger.CreateLoggerForTest(t) + return lg +} diff --git a/pkg/manager/backendcluster/metrics_querier.go b/pkg/manager/backendcluster/metrics_querier.go new file mode 100644 index 000000000..0f364b0d8 --- /dev/null +++ b/pkg/manager/backendcluster/metrics_querier.go @@ -0,0 +1,135 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "reflect" + "sync" + + "github.com/pingcap/tiproxy/pkg/balance/metricsreader" + "github.com/prometheus/common/model" +) + +var _ metricsreader.MetricsQuerier = (*MetricsQuerier)(nil) + +// MetricsQuerier is a thin fan-out and merge view over cluster-scoped metrics readers. +// It does not own any metrics collection lifecycle by itself. 
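The `var _ metricsreader.MetricsQuerier = (*MetricsQuerier)(nil)` line above is the usual compile-time conformance check: it costs nothing at runtime and breaks the build as soon as the type drifts from the interface. In miniature, with hypothetical names:

```go
package main

import "fmt"

type Querier interface{ Query(key string) string }

type fanOut struct{}

func (fanOut) Query(key string) string { return "merged:" + key }

// Compile-time assertion: if fanOut ever stops satisfying Querier,
// this declaration fails to build, long before any runtime use.
var _ Querier = (*fanOut)(nil)

func main() { fmt.Println(fanOut{}.Query("up")) }
```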
+type MetricsQuerier struct { + manager *Manager + mu sync.RWMutex + exprs map[string]metricsreader.QueryExpr + rules map[string]metricsreader.QueryRule +} + +func NewMetricsQuerier(manager *Manager) *MetricsQuerier { + return &MetricsQuerier{ + manager: manager, + exprs: make(map[string]metricsreader.QueryExpr), + rules: make(map[string]metricsreader.QueryRule), + } +} + +func (mq *MetricsQuerier) AddQueryExpr(key string, queryExpr metricsreader.QueryExpr, queryRule metricsreader.QueryRule) { + mq.mu.Lock() + mq.exprs[key] = queryExpr + mq.rules[key] = queryRule + mq.mu.Unlock() + + for _, cluster := range mq.manager.Snapshot() { + cluster.metrics.AddQueryExpr(key, queryExpr, queryRule) + } +} + +func (mq *MetricsQuerier) RemoveQueryExpr(key string) { + mq.mu.Lock() + delete(mq.exprs, key) + delete(mq.rules, key) + mq.mu.Unlock() + + for _, cluster := range mq.manager.Snapshot() { + cluster.metrics.RemoveQueryExpr(key) + } +} + +func (mq *MetricsQuerier) GetQueryResult(key string) metricsreader.QueryResult { + results := make([]metricsreader.QueryResult, 0, len(mq.manager.Snapshot())) + for _, cluster := range mq.manager.Snapshot() { + result := cluster.metrics.GetQueryResult(key) + if result.Empty() { + continue + } + results = append(results, result) + } + return mergeQueryResults(results) +} + +func (mq *MetricsQuerier) GetBackendMetrics() []byte { + return mq.GetBackendMetricsByCluster("") +} + +func (mq *MetricsQuerier) GetBackendMetricsByCluster(clusterName string) []byte { + if clusterName != "" { + snapshot := mq.manager.Snapshot() + cluster := snapshot[clusterName] + if cluster == nil { + return nil + } + return cluster.metrics.GetBackendMetrics() + } + if cluster := mq.manager.PrimaryCluster(); cluster != nil { + return cluster.metrics.GetBackendMetrics() + } + return nil +} + +func (mq *MetricsQuerier) snapshot() map[string]struct { + expr metricsreader.QueryExpr + rule metricsreader.QueryRule +} { + mq.mu.RLock() + snapshot := make(map[string]struct { + expr metricsreader.QueryExpr + rule metricsreader.QueryRule + }, len(mq.exprs)) + for key, expr := range mq.exprs { + snapshot[key] = struct { + expr metricsreader.QueryExpr + rule metricsreader.QueryRule + }{ + expr: expr, + rule: mq.rules[key], + } + } + mq.mu.RUnlock() + return snapshot +} + +func mergeQueryResults(results []metricsreader.QueryResult) metricsreader.QueryResult { + if len(results) == 0 { + return metricsreader.QueryResult{} + } + merged := metricsreader.QueryResult{} + for _, result := range results { + if merged.UpdateTime.Before(result.UpdateTime) { + merged.UpdateTime = result.UpdateTime + } + if result.Value == nil || reflect.ValueOf(result.Value).IsNil() { + continue + } + switch value := result.Value.(type) { + case model.Vector: + vector, _ := merged.Value.(model.Vector) + vector = append(vector, value...) + merged.Value = vector + case model.Matrix: + matrix, _ := merged.Value.(model.Matrix) + matrix = append(matrix, value...) + merged.Value = matrix + } + } + if merged.Value == nil { + return metricsreader.QueryResult{} + } + return merged +} diff --git a/pkg/manager/backendcluster/metrics_querier_test.go b/pkg/manager/backendcluster/metrics_querier_test.go new file mode 100644 index 000000000..8375ab70f --- /dev/null +++ b/pkg/manager/backendcluster/metrics_querier_test.go @@ -0,0 +1,165 @@ +// Copyright 2026 PingCAP, Inc. 
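`AddQueryExpr` pushes new queries to every live cluster, while `snapshot` lets the manager replay the full registry onto clusters built later, so late-joining clusters observe the same query set. A toy version of that record-and-replay pattern, with a plain string PromQL standing in for `QueryExpr`/`QueryRule`:

```go
package main

import "fmt"

// queryRegistry records query expressions so they can be replayed onto
// clusters that join after the queries were registered.
type queryRegistry struct {
	exprs map[string]string // key -> PromQL, a stand-in for QueryExpr
}

func newQueryRegistry() *queryRegistry {
	return &queryRegistry{exprs: make(map[string]string)}
}

func (r *queryRegistry) add(key, promQL string) { r.exprs[key] = promQL }

// replay applies every recorded query to a newly built cluster reader.
func (r *queryRegistry) replay(addQuery func(key, promQL string)) {
	for key, expr := range r.exprs {
		addQuery(key, expr)
	}
}

func main() {
	reg := newQueryRegistry()
	reg.add("cpu", "rate(process_cpu_seconds_total[30s])")
	reg.add("up", "avg(up)")

	// A cluster added later still receives all previously registered queries.
	reg.replay(func(key, promQL string) {
		fmt.Printf("cluster-b: register %s = %s\n", key, promQL)
	})
}
```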
+// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "context" + "reflect" + "testing" + "time" + + "github.com/pingcap/tiproxy/lib/config" + "github.com/pingcap/tiproxy/pkg/balance/metricsreader" + "github.com/prometheus/common/model" + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestMetricsQuerierQueryRegistry(t *testing.T) { + mgr := NewManager(zap.NewNop(), nilClusterTLS) + mq := NewMetricsQuerier(mgr) + + expr := metricsreader.QueryExpr{PromQL: "avg(up)"} + rule := metricsreader.QueryRule{ + Names: []string{"up"}, + ResultType: model.ValVector, + } + + mq.AddQueryExpr("up", expr, rule) + + snapshot := mq.snapshot() + require.Len(t, snapshot, 1) + require.Equal(t, expr, snapshot["up"].expr) + require.Equal(t, rule.Names, snapshot["up"].rule.Names) + require.Equal(t, rule.ResultType, snapshot["up"].rule.ResultType) + + mq.RemoveQueryExpr("up") + require.Empty(t, mq.snapshot()) +} + +func TestMergeQueryResults(t *testing.T) { + ts1 := time.Unix(10, 0) + ts2 := time.Unix(20, 0) + vector1 := model.Vector{ + &model.Sample{ + Metric: model.Metric{model.LabelName("instance"): "tidb-1"}, + Value: model.SampleValue(1), + Timestamp: model.Time(1000), + }, + } + vector2 := model.Vector{ + &model.Sample{ + Metric: model.Metric{model.LabelName("instance"): "tidb-2"}, + Value: model.SampleValue(2), + Timestamp: model.Time(2000), + }, + } + + merged := mergeQueryResults([]metricsreader.QueryResult{ + {UpdateTime: ts1, Value: vector1}, + {UpdateTime: ts2, Value: vector2}, + }) + require.Equal(t, ts2, merged.UpdateTime) + require.Len(t, merged.Value.(model.Vector), 2) + require.ElementsMatch(t, []string{"tidb-1", "tidb-2"}, []string{ + string(merged.Value.(model.Vector)[0].Metric[model.LabelName("instance")]), + string(merged.Value.(model.Vector)[1].Metric[model.LabelName("instance")]), + }) +} + +func TestMergeQueryResultsMatrix(t *testing.T) { + ts1 := time.Unix(10, 0) + ts2 := time.Unix(20, 0) + matrix1 := model.Matrix{ + &model.SampleStream{ + Metric: model.Metric{model.LabelName("instance"): "tidb-1"}, + Values: []model.SamplePair{{Timestamp: model.Time(1000), Value: model.SampleValue(1)}}, + }, + } + matrix2 := model.Matrix{ + &model.SampleStream{ + Metric: model.Metric{model.LabelName("instance"): "tidb-2"}, + Values: []model.SamplePair{{Timestamp: model.Time(2000), Value: model.SampleValue(2)}}, + }, + } + + merged := mergeQueryResults([]metricsreader.QueryResult{ + {UpdateTime: ts1, Value: matrix1}, + {UpdateTime: ts2, Value: matrix2}, + }) + require.Equal(t, ts2, merged.UpdateTime) + require.Len(t, merged.Value.(model.Matrix), 2) + require.ElementsMatch(t, []string{"tidb-1", "tidb-2"}, []string{ + string(merged.Value.(model.Matrix)[0].Metric[model.LabelName("instance")]), + string(merged.Value.(model.Matrix)[1].Metric[model.LabelName("instance")]), + }) +} + +func TestMetricsQuerierPropagatesQueriesToExistingAndNewClusters(t *testing.T) { + clusterA := newManagerTestEtcdCluster(t) + clusterB := newManagerTestEtcdCluster(t) + t.Cleanup(func() { clusterA.close(t) }) + t.Cleanup(func() { clusterB.close(t) }) + + cfg := newManagerTestConfig() + cfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + } + cfgGetter := newManagerTestConfigGetter(cfg) + cfgCh := make(chan *config.Config, 1) + + mgr := NewManager(zap.NewNop(), nilClusterTLS) + require.NoError(t, mgr.Start(context.Background(), cfgGetter, cfgCh)) + t.Cleanup(func() { + require.NoError(t, mgr.Close()) + }) + + expr := 
metricsreader.QueryExpr{PromQL: "sum(up)"} + rule := metricsreader.QueryRule{ + Names: []string{"up"}, + ResultType: model.ValVector, + } + mgr.MetricsQuerier().AddQueryExpr("up", expr, rule) + + require.Eventually(t, func() bool { + return clusterHasBackendQueryRule(mgr.Snapshot()["cluster-a"], "up") + }, 5*time.Second, 100*time.Millisecond) + + nextCfg := cfg.Clone() + nextCfg.Proxy.BackendClusters = []config.BackendCluster{ + {Name: "cluster-a", PDAddrs: clusterA.addr}, + {Name: "cluster-b", PDAddrs: clusterB.addr}, + } + cfgGetter.setConfig(nextCfg) + cfgCh <- nextCfg.Clone() + + require.Eventually(t, func() bool { + snapshot := mgr.Snapshot() + return clusterHasBackendQueryRule(snapshot["cluster-a"], "up") && + clusterHasBackendQueryRule(snapshot["cluster-b"], "up") + }, 5*time.Second, 100*time.Millisecond) + + mgr.MetricsQuerier().RemoveQueryExpr("up") + require.Eventually(t, func() bool { + snapshot := mgr.Snapshot() + return !clusterHasBackendQueryRule(snapshot["cluster-a"], "up") && + !clusterHasBackendQueryRule(snapshot["cluster-b"], "up") + }, 5*time.Second, 100*time.Millisecond) +} + +func clusterHasBackendQueryRule(cluster *Cluster, key string) bool { + if cluster == nil || cluster.metrics == nil { + return false + } + metricsValue := reflect.ValueOf(cluster.metrics).Elem() + backendReaderValue := metricsValue.FieldByName("backendReader") + if !backendReaderValue.IsValid() || backendReaderValue.IsNil() { + return false + } + queryRulesValue := backendReaderValue.Elem().FieldByName("queryRules") + if !queryRulesValue.IsValid() || queryRulesValue.Len() == 0 { + return false + } + return queryRulesValue.MapIndex(reflect.ValueOf(key)).IsValid() +} diff --git a/pkg/manager/backendcluster/network_router.go b/pkg/manager/backendcluster/network_router.go new file mode 100644 index 000000000..4dc173f4a --- /dev/null +++ b/pkg/manager/backendcluster/network_router.go @@ -0,0 +1,60 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "context" + "crypto/tls" + "net" + + "github.com/pingcap/tiproxy/lib/util/errors" + httputil "github.com/pingcap/tiproxy/pkg/util/http" + "github.com/pingcap/tiproxy/pkg/util/netutil" +) + +var ErrBackendClusterNotFound = errors.New("backend cluster not found") + +// NetworkRouter is a thin dispatch view over cluster-scoped dialers and HTTP clients. +// It does not own any cluster lifecycle by itself. 
+type NetworkRouter struct { + manager *Manager + clusterTLS func() *tls.Config + defaultDial *netutil.DNSDialer + defaultHTTP *httputil.Client +} + +func NewNetworkRouter(manager *Manager, clusterTLS func() *tls.Config) *NetworkRouter { + return &NetworkRouter{ + manager: manager, + clusterTLS: clusterTLS, + defaultDial: netutil.NewDNSDialer(nil), + defaultHTTP: httputil.NewHTTPClientWithDialContext(clusterTLS, nil), + } +} + +func (nr *NetworkRouter) missingClusterHTTPClient(clusterName string) *httputil.Client { + return httputil.NewHTTPClientWithDialContext(nr.clusterTLS, func(context.Context, string, string) (net.Conn, error) { + return nil, errors.Wrapf(ErrBackendClusterNotFound, "cluster %s", clusterName) + }) +} + +func (nr *NetworkRouter) HTTPClient(clusterName string) *httputil.Client { + if clusterName != "" { + if cluster := nr.manager.Snapshot()[clusterName]; cluster != nil { + return cluster.HTTPClient() + } + return nr.missingClusterHTTPClient(clusterName) + } + return nr.defaultHTTP +} + +func (nr *NetworkRouter) DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error) { + if clusterName != "" { + if cluster := nr.manager.Snapshot()[clusterName]; cluster != nil { + return cluster.DialContext(ctx, network, addr) + } + return nil, errors.Wrapf(ErrBackendClusterNotFound, "cluster %s", clusterName) + } + return nr.defaultDial.DialContext(ctx, network, addr) +} diff --git a/pkg/manager/backendcluster/network_router_test.go b/pkg/manager/backendcluster/network_router_test.go new file mode 100644 index 000000000..593ef9bbb --- /dev/null +++ b/pkg/manager/backendcluster/network_router_test.go @@ -0,0 +1,58 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package backendcluster + +import ( + "context" + "net" + "testing" + "time" + + "github.com/cenkalti/backoff/v4" + "github.com/pingcap/tiproxy/lib/util/errors" + "github.com/stretchr/testify/require" +) + +func TestNetworkRouterDialContextRejectsMissingCluster(t *testing.T) { + router := NewNetworkRouter(&Manager{}, nilClusterTLS) + _, err := router.DialContext(context.Background(), "tcp", "127.0.0.1:80", "missing") + require.Error(t, err) + require.True(t, errors.Is(err, ErrBackendClusterNotFound)) +} + +func TestNetworkRouterHTTPClientRejectsMissingCluster(t *testing.T) { + router := NewNetworkRouter(&Manager{}, nilClusterTLS) + b := backoff.WithMaxRetries(backoff.NewConstantBackOff(time.Millisecond), 0) + _, err := router.HTTPClient("missing").Get("127.0.0.1:80", "/status", b, time.Second) + require.Error(t, err) + require.True(t, errors.Is(err, ErrBackendClusterNotFound)) +} + +func TestNetworkRouterDialContextFallsBackWithoutClusterName(t *testing.T) { + ln, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, ln.Close()) + }) + + accepted := make(chan struct{}, 1) + go func() { + conn, err := ln.Accept() + if err == nil { + accepted <- struct{}{} + _ = conn.Close() + } + }() + + router := NewNetworkRouter(&Manager{}, nilClusterTLS) + conn, err := router.DialContext(context.Background(), "tcp", ln.Addr().String(), "") + require.NoError(t, err) + require.NoError(t, conn.Close()) + + select { + case <-accepted: + case <-time.After(time.Second): + t.Fatal("listener was not reached through default dialer") + } +} diff --git a/pkg/manager/infosync/info.go b/pkg/manager/infosync/info.go index 4d4e9f3d4..9944abbed 100644 --- a/pkg/manager/infosync/info.go +++ b/pkg/manager/infosync/info.go @@ -93,6 +93,8 @@ type 
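The dispatch rule is: an empty cluster name means the default dialer, a known name means that cluster's dialer, and an unknown name fails fast with the wrapped sentinel rather than dialing anywhere. A stand-alone sketch of that rule with stand-in dial functions (the real router consults the manager's snapshot instead of a plain map):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"net"
)

var errClusterNotFound = errors.New("backend cluster not found")

type dialFunc func(ctx context.Context, network, addr string) (net.Conn, error)

// dispatch picks the per-cluster dialer when a cluster name is given,
// falls back to the default dialer for the empty name, and fails fast
// with a wrapped sentinel for unknown clusters.
func dispatch(clusters map[string]dialFunc, def dialFunc, clusterName string) (dialFunc, error) {
	if clusterName == "" {
		return def, nil
	}
	if dial, ok := clusters[clusterName]; ok {
		return dial, nil
	}
	return nil, fmt.Errorf("cluster %s: %w", clusterName, errClusterNotFound)
}

func main() {
	noop := func(ctx context.Context, network, addr string) (net.Conn, error) {
		fmt.Println("dialing", addr)
		return nil, nil // a real dialer would return a live net.Conn
	}
	clusters := map[string]dialFunc{"cluster-a": noop}

	if dial, err := dispatch(clusters, noop, "cluster-a"); err == nil {
		_, _ = dial(context.Background(), "tcp", "tidb-1:4000")
	}
	_, err := dispatch(clusters, noop, "missing")
	fmt.Println(errors.Is(err, errClusterNotFound)) // true
}
```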
TopologyInfo struct { type TiDBTopologyInfo struct { Version string `json:"version"` GitHash string `json:"git_hash"` + Addr string `json:"-"` + ClusterName string `json:"-"` IP string `json:"ip"` StatusPort uint `json:"status_port"` DeployPath string `json:"deploy_path"` @@ -295,6 +297,7 @@ func (is *InfoSyncer) GetTiDBTopology(ctx context.Context) (map[string]*TiDBTopo zap.String("value", hack.String(kv.Value)), zap.Error(err)) } else { infos[addr] = topology + topology.Addr = addr topology.Keyspace = keyspace } } diff --git a/pkg/manager/infosync/info_test.go b/pkg/manager/infosync/info_test.go index e6ac7d311..966e8fad7 100644 --- a/pkg/manager/infosync/info_test.go +++ b/pkg/manager/infosync/info_test.go @@ -128,6 +128,7 @@ func TestFetchTiDBTopology(t *testing.T) { check: func(info map[string]*TiDBTopologyInfo) { require.Len(ts.t, info, 1) require.NotNil(ts.t, info["1.1.1.1:4000"]) + require.Equal(ts.t, "1.1.1.1:4000", info["1.1.1.1:4000"].Addr) require.Equal(ts.t, "1.1.1.1", info["1.1.1.1:4000"].IP) require.Equal(ts.t, uint(10080), info["1.1.1.1:4000"].StatusPort) }, @@ -144,6 +145,7 @@ func TestFetchTiDBTopology(t *testing.T) { check: func(info map[string]*TiDBTopologyInfo) { require.Len(ts.t, info, 2) require.NotNil(ts.t, info["2.2.2.2:4000"]) + require.Equal(ts.t, "2.2.2.2:4000", info["2.2.2.2:4000"].Addr) require.Equal(ts.t, "2.2.2.2", info["2.2.2.2:4000"].IP) require.Equal(ts.t, uint(10080), info["2.2.2.2:4000"].StatusPort) }, @@ -170,6 +172,7 @@ func TestFetchTiDBTopology(t *testing.T) { check: func(info map[string]*TiDBTopologyInfo) { require.Len(ts.t, info, 2) require.NotNil(ts.t, info["3.3.3.3:4000"]) + require.Equal(ts.t, "3.3.3.3:4000", info["3.3.3.3:4000"].Addr) require.Equal(ts.t, "3.3.3.3", info["3.3.3.3:4000"].IP) require.Equal(ts.t, uint(10080), info["3.3.3.3:4000"].StatusPort) require.Equal(ts.t, "test", info["3.3.3.3:4000"].Keyspace) diff --git a/pkg/manager/namespace/manager.go b/pkg/manager/namespace/manager.go index f66c2301f..06eb8c5b2 100644 --- a/pkg/manager/namespace/manager.go +++ b/pkg/manager/namespace/manager.go @@ -10,7 +10,6 @@ import ( "context" "fmt" "maps" - "reflect" "sync" "github.com/pingcap/tiproxy/lib/config" @@ -25,9 +24,10 @@ import ( ) type NamespaceManager interface { + SetBackendNetwork(backendNetwork observer.BackendNetwork) Init(logger *zap.Logger, nscs []*config.Namespace, tpFetcher observer.TopologyFetcher, promFetcher metricsreader.PromInfoFetcher, httpCli *http.Client, cfgMgr *mconfig.ConfigManager, - metricsReader metricsreader.MetricsReader) error + metricsReader metricsreader.MetricsQuerier) error CommitNamespaces(nss []*config.Namespace, nssDelete []bool) error GetNamespace(nm string) (*Namespace, bool) GetNamespaceByUser(user string) (*Namespace, bool) @@ -38,13 +38,14 @@ type NamespaceManager interface { type namespaceManager struct { sync.RWMutex - nsm map[string]*Namespace - tpFetcher observer.TopologyFetcher - promFetcher metricsreader.PromInfoFetcher - metricsReader metricsreader.MetricsReader - httpCli *http.Client - logger *zap.Logger - cfgMgr *mconfig.ConfigManager + nsm map[string]*Namespace + tpFetcher observer.TopologyFetcher + promFetcher metricsreader.PromInfoFetcher + metricsReader metricsreader.MetricsQuerier + httpCli *http.Client + backendNetwork observer.BackendNetwork + logger *zap.Logger + cfgMgr *mconfig.ConfigManager } func NewNamespaceManager() *namespaceManager { @@ -54,18 +55,15 @@ func NewNamespaceManager() *namespaceManager { func (mgr *namespaceManager) buildNamespace(cfg *config.Namespace) 
(*Namespace, error) { logger := mgr.logger.With(zap.String("namespace", cfg.Namespace)) - // init BackendFetcher - var fetcher observer.BackendFetcher healthCheckCfg := config.NewDefaultHealthCheckConfig() - if mgr.tpFetcher != nil && !reflect.ValueOf(mgr.tpFetcher).IsNil() { - fetcher = observer.NewPDFetcher(mgr.tpFetcher, logger.Named("be_fetcher"), healthCheckCfg) - } else { - fetcher = observer.NewStaticFetcher(cfg.Backend.Instances) - } + // Namespace always receives a topology fetcher from the cluster manager. PDFetcher preserves + // legacy static backend.instances compatibility by falling back internally before any backend + // cluster is configured. + fetcher := observer.NewPDFetcher(mgr.tpFetcher, cfg.Backend.Instances, logger.Named("be_fetcher"), healthCheckCfg) // init Router rt := router.NewScoreBasedRouter(logger.Named("router")) - hc := observer.NewDefaultHealthCheck(mgr.httpCli, healthCheckCfg, logger.Named("hc")) + hc := observer.NewDefaultHealthCheckWithNetwork(mgr.backendNetwork, healthCheckCfg, logger.Named("hc")) bo := observer.NewDefaultBackendObserver(logger.Named("observer"), healthCheckCfg, fetcher, hc, mgr.cfgMgr) bo.Start(context.Background()) bpCreator := func(lg *zap.Logger) policy.BalancePolicy { @@ -110,7 +108,7 @@ func (mgr *namespaceManager) CommitNamespaces(nss []*config.Namespace, nssDelete func (mgr *namespaceManager) Init(logger *zap.Logger, nscs []*config.Namespace, tpFetcher observer.TopologyFetcher, promFetcher metricsreader.PromInfoFetcher, httpCli *http.Client, cfgMgr *mconfig.ConfigManager, - metricsReader metricsreader.MetricsReader) error { + metricsReader metricsreader.MetricsQuerier) error { mgr.Lock() mgr.tpFetcher = tpFetcher mgr.promFetcher = promFetcher @@ -122,6 +120,12 @@ func (mgr *namespaceManager) Init(logger *zap.Logger, nscs []*config.Namespace, return mgr.CommitNamespaces(nscs, nil) } +func (mgr *namespaceManager) SetBackendNetwork(backendNetwork observer.BackendNetwork) { + mgr.Lock() + mgr.backendNetwork = backendNetwork + mgr.Unlock() +} + func (mgr *namespaceManager) GetNamespace(nm string) (*Namespace, bool) { mgr.RLock() defer mgr.RUnlock() diff --git a/pkg/manager/namespace/manager_test.go b/pkg/manager/namespace/manager_test.go index fb2609975..fbe6f15fd 100644 --- a/pkg/manager/namespace/manager_test.go +++ b/pkg/manager/namespace/manager_test.go @@ -4,16 +4,28 @@ package namespace import ( + "context" "testing" "github.com/pingcap/tiproxy/pkg/balance/router" + "github.com/pingcap/tiproxy/pkg/manager/infosync" "github.com/stretchr/testify/require" "go.uber.org/zap" ) +type mockTopologyFetcher struct{} + +func (*mockTopologyFetcher) GetTiDBTopology(context.Context) (map[string]*infosync.TiDBTopologyInfo, error) { + return nil, nil +} + +func (*mockTopologyFetcher) HasBackendClusters() bool { + return false +} + func TestReady(t *testing.T) { nsMgr := NewNamespaceManager() - require.NoError(t, nsMgr.Init(zap.NewNop(), nil, nil, nil, nil, nil, nil)) + require.NoError(t, nsMgr.Init(zap.NewNop(), nil, &mockTopologyFetcher{}, nil, nil, nil, nil)) require.False(t, nsMgr.Ready()) rt := router.NewStaticRouter([]string{}) diff --git a/pkg/proxy/backend/backend_conn_mgr.go b/pkg/proxy/backend/backend_conn_mgr.go index 9619d60cd..e2f48d298 100644 --- a/pkg/proxy/backend/backend_conn_mgr.go +++ b/pkg/proxy/backend/backend_conn_mgr.go @@ -91,6 +91,7 @@ type BCConfig struct { HealthyKeepAlive config.KeepAlive UnhealthyKeepAlive config.KeepAlive FromPublicEndpoints func(addr net.Addr) bool + DialContext func(ctx context.Context, backend 
router.BackendInst, addr string) (net.Conn, error) TickerInterval time.Duration CheckBackendInterval time.Duration DialTimeout time.Duration @@ -287,6 +288,9 @@ func (mgr *BackendConnManager) getBackendIO(ctx context.Context, cctx ConnContex ci.ClientAddr = mgr.clientIO.RemoteAddr() ci.ProxyAddr = mgr.clientIO.ProxyAddr() } + if addr, ok := cctx.Value(ConnContextKeyConnAddr).(string); ok { + ci.ListenerAddr = addr + } selector := r.GetBackendSelector(ci) startTime := time.Now() var addr string @@ -306,7 +310,9 @@ func (mgr *BackendConnManager) getBackendIO(ctx context.Context, cctx ConnContex var cn net.Conn addr = backend.Addr() - cn, err = net.DialTimeout("tcp", addr, mgr.config.DialTimeout) + dialCtx, cancel := context.WithTimeout(bctx, mgr.config.DialTimeout) + cn, err = mgr.dialBackend(dialCtx, backend, addr) + cancel() selector.Finish(mgr, err == nil) if err != nil { metrics.DialBackendFailCounter.WithLabelValues(addr).Inc() @@ -591,8 +597,8 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } rs := &redirectResult{ - from: mgr.ServerAddr(), - to: (*backendInst).Addr(), + from: mgr.curBackend.ID(), + to: (*backendInst).ID(), } defer func() { // The `mgr` won't be notified again before it calls `OnRedirectSucceed`, so simply `StorePointer` is also fine. @@ -639,12 +645,14 @@ func (mgr *BackendConnManager) tryRedirect(ctx context.Context) { } var cn net.Conn - cn, rs.err = net.DialTimeout("tcp", rs.to, mgr.config.DialTimeout) + dialCtx, cancel := context.WithTimeout(ctx, mgr.config.DialTimeout) + cn, rs.err = mgr.dialBackend(dialCtx, *backendInst, (*backendInst).Addr()) + cancel() if rs.err != nil { - mgr.handshakeHandler.OnHandshake(mgr, rs.to, rs.err, SrcBackendNetwork) + mgr.handshakeHandler.OnHandshake(mgr, (*backendInst).Addr(), rs.err, SrcBackendNetwork) return } - newBackendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr(rs.to, cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))) + newBackendIO := pnet.PacketIO(pnet.NewPacketIO(cn, mgr.logger, mgr.config.ConnBufferSize, pnet.WithRemoteAddr((*backendInst).Addr(), cn.RemoteAddr()), pnet.WithWrapError(ErrBackendConn))) if rs.err = mgr.authenticator.handshakeSecondTime(mgr.logger, mgr.clientIO, newBackendIO, mgr.backendTLS, sessionToken); rs.err == nil { rs.err = mgr.initSessionStates(newBackendIO, sessionStates) @@ -810,6 +818,14 @@ func (mgr *BackendConnManager) Value(key any) any { return v } +func (mgr *BackendConnManager) dialBackend(ctx context.Context, backend router.BackendInst, addr string) (net.Conn, error) { + if mgr.config.DialContext != nil { + return mgr.config.DialContext(ctx, backend, addr) + } + var dialer net.Dialer + return dialer.DialContext(ctx, "tcp", addr) +} + // Close releases all resources. func (mgr *BackendConnManager) Close() error { // BackendConnMgr may close even before connecting, so protect the members with a lock. @@ -833,9 +849,7 @@ func (mgr *BackendConnManager) Close() error { handErr := mgr.handshakeHandler.OnConnClose(mgr, mgr.quitSource) var connErr error - var addr string if backendIO := mgr.backendIO.Swap(nil); backendIO != nil { - addr = (*backendIO).RemoteAddr().String() connErr = (*backendIO).Close() } @@ -846,13 +860,14 @@ func (mgr *BackendConnManager) Close() error { mgr.notifyRedirectResult(context.Background(), <-mgr.redirectResCh) } // The connection may have just received the redirecting signal. 
- if len(addr) > 0 { - var redirectingAddr string + if mgr.curBackend != nil { + var redirectingBackendID string if redirectingBackend := mgr.redirectInfo.Load(); redirectingBackend != nil { - redirectingAddr = (*redirectingBackend).Addr() + redirectingBackendID = (*redirectingBackend).ID() } - if err := eventReceiver.OnConnClosed(addr, redirectingAddr, mgr); err != nil { - mgr.logger.Error("close connection error", zap.String("backend_addr", addr), zap.NamedError("notify_err", err)) + if err := eventReceiver.OnConnClosed(mgr.curBackend.ID(), redirectingBackendID, mgr); err != nil { + mgr.logger.Error("close connection error", + zap.String("backend_id", mgr.curBackend.ID()), zap.String("backend_addr", mgr.curBackend.Addr()), zap.NamedError("notify_err", err)) } } } diff --git a/pkg/proxy/backend/backend_conn_mgr_test.go b/pkg/proxy/backend/backend_conn_mgr_test.go index 05b95b46e..1d1bfe27e 100644 --- a/pkg/proxy/backend/backend_conn_mgr_test.go +++ b/pkg/proxy/backend/backend_conn_mgr_test.go @@ -62,10 +62,10 @@ func (mer *mockEventReceiver) OnRedirectFail(from, to string, conn router.Redire return nil } -func (mer *mockEventReceiver) OnConnClosed(from, to string, conn router.RedirectableConn) error { +func (mer *mockEventReceiver) OnConnClosed(backendID, redirectingBackendID string, conn router.RedirectableConn) error { mer.eventCh <- event{ - from: from, - to: to, + from: backendID, + to: redirectingBackendID, eventName: eventClose, } return nil @@ -80,6 +80,7 @@ func (mer *mockEventReceiver) checkEvent(t *testing.T, eventName int) event { type mockBackendInst struct { addr string keyspace string + cluster string healthy atomic.Bool local atomic.Bool } @@ -97,6 +98,10 @@ func (mbi *mockBackendInst) Addr() string { return mbi.addr } +func (mbi *mockBackendInst) ID() string { + return mbi.addr +} + func (mbi *mockBackendInst) Healthy() bool { return mbi.healthy.Load() } @@ -121,6 +126,10 @@ func (mbi *mockBackendInst) setKeyspace(k string) { mbi.keyspace = k } +func (mbi *mockBackendInst) ClusterName() string { + return mbi.cluster +} + type runner struct { client func(packetIO pnet.PacketIO) error proxy func(clientIO, backendIO pnet.PacketIO) error @@ -1043,6 +1052,41 @@ func TestGetBackendIO(t *testing.T) { } } +func TestGetBackendIOUsesBackendDialContext(t *testing.T) { + listener, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer func() { require.NoError(t, listener.Close()) }() + + rt := router.NewStaticRouter([]string{"tidb-a.test:4000"}) + handler := &CustomHandshakeHandler{ + getRouter: func(ctx ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { + return rt, nil + }, + } + lg, _ := logger.CreateLoggerForTest(t) + var gotCluster, gotAddr string + mgr := NewBackendConnManager(lg, handler, &mockCapture{}, 0, &BCConfig{ + ConnectTimeout: time.Second, + DialContext: func(ctx context.Context, backendInst router.BackendInst, addr string) (net.Conn, error) { + gotCluster = backendInst.ClusterName() + gotAddr = addr + var dialer net.Dialer + return dialer.DialContext(ctx, "tcp", listener.Addr().String()) + }, + }, nil) + + go func() { + conn, err := listener.Accept() + require.NoError(t, err) + require.NoError(t, conn.Close()) + }() + io, err := mgr.getBackendIO(context.Background(), mgr, nil) + require.NoError(t, err) + require.NoError(t, io.Close()) + require.Empty(t, gotCluster) + require.Equal(t, "tidb-a.test:4000", gotAddr) +} + func TestBackendInactive(t *testing.T) { ts := newBackendMgrTester(t, func(config *testConfig) { 
config.proxyConfig.bcConfig.TickerInterval = time.Millisecond diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index 59c89fcde..bafc3c3a4 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -12,6 +12,7 @@ import ( "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/errors" + "github.com/pingcap/tiproxy/pkg/balance/router" "github.com/pingcap/tiproxy/pkg/manager/cert" "github.com/pingcap/tiproxy/pkg/manager/id" "github.com/pingcap/tiproxy/pkg/metrics" @@ -40,6 +41,10 @@ type serverState struct { gracefulClose int // graceful-close-conn-timeout } +type BackendDialer interface { + DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error) +} + type SQLServer struct { listeners []net.Listener addrs []string @@ -49,6 +54,7 @@ type SQLServer struct { hsHandler backend.HandshakeHandler cpt capture.Capture meter backend.Meter + dialer BackendDialer wg waitgroup.WaitGroup cancelFunc context.CancelFunc @@ -108,6 +114,10 @@ func (s *SQLServer) reset(cfg *config.Config) { s.mu.Unlock() } +func (s *SQLServer) SetBackendDialer(dialer BackendDialer) { + s.dialer = dialer +} + func (s *SQLServer) Run(ctx context.Context, cfgch <-chan *config.Config) { // Create another context because it still needs to run after graceful shutdown. ctx, s.cancelFunc = context.WithCancel(context.Background()) @@ -176,6 +186,13 @@ func (s *SQLServer) onConn(ctx context.Context, conn net.Conn, addr string) { UnhealthyKeepAlive: s.mu.unhealthyKeepAlive, ConnBufferSize: s.mu.connBufferSize, FromPublicEndpoints: s.fromPublicEndpoint, + DialContext: func(ctx context.Context, backendInst router.BackendInst, addr string) (net.Conn, error) { + if s.dialer != nil { + return s.dialer.DialContext(ctx, "tcp", addr, backendInst.ClusterName()) + } + var dialer net.Dialer + return dialer.DialContext(ctx, "tcp", addr) + }, }, s.meter) s.mu.clients[connID] = clientConn logger.Debug("new connection", zap.Bool("proxy-protocol", s.mu.proxyProtocol), zap.Bool("require_backend_tls", s.mu.requireBackendTLS)) diff --git a/pkg/server/api/backend.go b/pkg/server/api/backend.go index e7121e5bf..fdc426a36 100644 --- a/pkg/server/api/backend.go +++ b/pkg/server/api/backend.go @@ -11,11 +11,11 @@ import ( ) type BackendReader interface { - GetBackendMetrics() []byte + GetBackendMetricsByCluster(cluster string) []byte } func (h *Server) BackendMetrics(c *gin.Context) { - metrics := h.mgr.BackendReader.GetBackendMetrics() + metrics := h.mgr.BackendReader.GetBackendMetricsByCluster(c.Query("cluster")) c.Writer.Header().Set("Content-Type", "application/json") c.Writer.WriteHeader(http.StatusOK) if _, err := c.Writer.Write(metrics); err != nil { diff --git a/pkg/server/api/backend_test.go b/pkg/server/api/backend_test.go index 06e8c5be7..2a4dc53d7 100644 --- a/pkg/server/api/backend_test.go +++ b/pkg/server/api/backend_test.go @@ -35,19 +35,37 @@ func TestBackendMetrics(t *testing.T) { mbr := server.mgr.BackendReader.(*mockBackendReader) for _, tt := range tests { mbr.data.Store(string(tt.data)) + mbr.cluster.Store("") doHTTP(t, http.MethodGet, "/api/backend/metrics", httpOpts{}, func(t *testing.T, r *http.Response) { all, err := io.ReadAll(r.Body) require.NoError(t, err) require.Equal(t, tt.expect, all) require.Equal(t, http.StatusOK, r.StatusCode) + require.Empty(t, mbr.cluster.Load()) }) } } +func TestBackendMetricsClusterQuery(t *testing.T) { + server, doHTTP := createServer(t) + mbr := server.mgr.BackendReader.(*mockBackendReader) + mbr.data.Store(`{"key":"value"}`) + + doHTTP(t, 
http.MethodGet, "/api/backend/metrics?cluster=cluster-a", httpOpts{}, func(t *testing.T, r *http.Response) { + all, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.Equal(t, []byte(`{"key":"value"}`), all) + require.Equal(t, http.StatusOK, r.StatusCode) + require.Equal(t, "cluster-a", mbr.cluster.Load()) + }) +} + type mockBackendReader struct { - data atomic.String + data atomic.String + cluster atomic.String } -func (mbr *mockBackendReader) GetBackendMetrics() []byte { +func (mbr *mockBackendReader) GetBackendMetricsByCluster(cluster string) []byte { + mbr.cluster.Store(cluster) return []byte(mbr.data.Load()) } diff --git a/pkg/server/api/mock_test.go b/pkg/server/api/mock_test.go index 758a285a7..9009c9a9d 100644 --- a/pkg/server/api/mock_test.go +++ b/pkg/server/api/mock_test.go @@ -29,8 +29,11 @@ func newMockNamespaceManager() *mockNamespaceManager { return mgr } +func (m *mockNamespaceManager) SetBackendNetwork(_ observer.BackendNetwork) { +} + func (m *mockNamespaceManager) Init(_ *zap.Logger, _ []*config.Namespace, _ observer.TopologyFetcher, - _ metricsreader.PromInfoFetcher, _ *http.Client, _ *mconfig.ConfigManager, _ metricsreader.MetricsReader) error { + _ metricsreader.PromInfoFetcher, _ *http.Client, _ *mconfig.ConfigManager, _ metricsreader.MetricsQuerier) error { return nil } diff --git a/pkg/server/server.go b/pkg/server/server.go index ab80c8912..0a783e086 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -10,11 +10,10 @@ import ( "github.com/pingcap/tiproxy/lib/config" "github.com/pingcap/tiproxy/lib/util/errors" - "github.com/pingcap/tiproxy/pkg/balance/metricsreader" + "github.com/pingcap/tiproxy/pkg/manager/backendcluster" "github.com/pingcap/tiproxy/pkg/manager/cert" mgrcfg "github.com/pingcap/tiproxy/pkg/manager/config" "github.com/pingcap/tiproxy/pkg/manager/id" - "github.com/pingcap/tiproxy/pkg/manager/infosync" "github.com/pingcap/tiproxy/pkg/manager/logger" "github.com/pingcap/tiproxy/pkg/manager/memory" "github.com/pingcap/tiproxy/pkg/manager/meter" @@ -26,7 +25,6 @@ import ( "github.com/pingcap/tiproxy/pkg/sctx" "github.com/pingcap/tiproxy/pkg/server/api" mgrrp "github.com/pingcap/tiproxy/pkg/sqlreplay/manager" - "github.com/pingcap/tiproxy/pkg/util/etcd" "github.com/pingcap/tiproxy/pkg/util/http" "github.com/pingcap/tiproxy/pkg/util/versioninfo" "github.com/pingcap/tiproxy/pkg/util/waitgroup" @@ -43,14 +41,11 @@ type Server struct { metricsManager *metrics.MetricsManager loggerManager *logger.LoggerManager certManager *cert.CertManager + clusterManager *backendcluster.Manager vipManager vip.VIPManager - infoSyncer *infosync.InfoSyncer - metricsReader metricsreader.MetricsReader replay mgrrp.JobManager meter *meter.Meter memManager *memory.MemManager - // etcd client - etcdCli *clientv3.Client // HTTP client httpCli *http.Client // HTTP server @@ -107,36 +102,24 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) return } - // setup etcd client - srv.etcdCli, err = etcd.InitEtcdClient(lg.Named("etcd"), cfg, srv.certManager) - if err != nil { + // setup backend cluster manager + srv.clusterManager = backendcluster.NewManager(lg.Named("backendcluster"), srv.certManager.ClusterTLS) + if err = srv.clusterManager.Start(ctx, srv.configManager, srv.configManager.WatchConfig()); err != nil { return } + var vipEtcdCli *clientv3.Client + if cluster := srv.clusterManager.PrimaryCluster(); cluster != nil { + vipEtcdCli = cluster.EtcdClient() + } // general cluster HTTP client { srv.httpCli = 
http.NewHTTPClient(srv.certManager.ClusterTLS) } - // setup info syncer - if cfg.Proxy.PDAddrs != "" { - srv.infoSyncer = infosync.NewInfoSyncer(lg.Named("infosync"), srv.etcdCli) - if err = srv.infoSyncer.Init(ctx, cfg); err != nil { - return - } - } - - // setup metrics reader - { - healthCheckCfg := config.NewDefaultHealthCheckConfig() - srv.metricsReader = metricsreader.NewDefaultMetricsReader(lg.Named("mr"), srv.infoSyncer, srv.infoSyncer, srv.httpCli, srv.etcdCli, healthCheckCfg, srv.configManager) - if err = srv.metricsReader.Start(ctx); err != nil { - return - } - } - // setup namespace manager { + srv.namespaceManager.SetBackendNetwork(srv.clusterManager.NetworkRouter()) nscs, nerr := srv.configManager.ListAllNamespace(ctx) if nerr != nil { err = nerr @@ -157,7 +140,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) nscs = append(nscs, nsc) } - err = srv.namespaceManager.Init(lg.Named("nsmgr"), nscs, srv.infoSyncer, srv.infoSyncer, srv.httpCli, srv.configManager, srv.metricsReader) + err = srv.namespaceManager.Init(lg.Named("nsmgr"), nscs, srv.clusterManager, nil, srv.httpCli, srv.configManager, srv.clusterManager.MetricsQuerier()) if err != nil { return } @@ -192,6 +175,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) if err != nil { return } + srv.proxy.SetBackendDialer(srv.clusterManager.NetworkRouter()) srv.proxy.Run(ctx, srv.configManager.WatchConfig()) } @@ -200,7 +184,7 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) CfgMgr: srv.configManager, NsMgr: srv.namespaceManager, CertMgr: srv.certManager, - BackendReader: srv.metricsReader, + BackendReader: srv.clusterManager.MetricsQuerier(), ReplayJobMgr: srv.replay, } if srv.apiServer, err = api.NewServer(cfg.API, lg.Named("api"), mgrs, handler, ready); err != nil { @@ -214,8 +198,12 @@ func NewServer(ctx context.Context, sctx *sctx.Context) (srv *Server, err error) return } if srv.vipManager != nil && !reflect.ValueOf(srv.vipManager).IsNil() { - if err = srv.vipManager.Start(ctx, srv.etcdCli); err != nil { - return + if vipEtcdCli != nil { + if err = srv.vipManager.Start(ctx, vipEtcdCli); err != nil { + return + } + } else { + lg.Info("VIP is disabled because backend cluster count is not 1") } } } @@ -247,9 +235,8 @@ func (s *Server) preClose() { if s.apiServer != nil { s.apiServer.PreClose() } - // Resign the metric reader owner to make other members campaign ASAP. - if s.metricsReader != nil && !reflect.ValueOf(s.metricsReader).IsNil() { - s.metricsReader.PreClose() + if s.clusterManager != nil { + s.clusterManager.PreClose() } // Gracefully drain clients. if s.proxy != nil { @@ -277,15 +264,9 @@ func (s *Server) Close() error { if s.namespaceManager != nil { errs = append(errs, s.namespaceManager.Close()) } - if s.metricsReader != nil && !reflect.ValueOf(s.metricsReader).IsNil() { - s.metricsReader.Close() - } if s.memManager != nil { s.memManager.Close() } - if s.infoSyncer != nil { - errs = append(errs, s.infoSyncer.Close()) - } if s.configManager != nil { errs = append(errs, s.configManager.Close()) } @@ -295,8 +276,8 @@ func (s *Server) Close() error { if s.loggerManager != nil { errs = append(errs, s.loggerManager.Close()) } - if s.etcdCli != nil { - errs = append(errs, s.etcdCli.Close()) + if s.clusterManager != nil { + errs = append(errs, s.clusterManager.Close()) } s.wg.Wait() return errors.Collect(ErrCloseServer, errs...) 
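The server wiring above funnels every backend dial through the cluster manager's NetworkRouter and keeps a plain dialer as the fallback for connections that carry no cluster name. Below is a minimal, self-contained sketch of that dispatch pattern, using hypothetical simplified types (the real interface is the BackendDialer added to pkg/proxy/proxy.go above):

package main

import (
	"context"
	"errors"
	"fmt"
	"net"
	"time"
)

// BackendDialer mirrors the interface added to pkg/proxy/proxy.go: the extra
// clusterName argument lets the SQL server stay agnostic of cluster wiring.
type BackendDialer interface {
	DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error)
}

var errClusterNotFound = errors.New("backend cluster not found")

// mapRouter is a hypothetical stand-in for backendcluster.NetworkRouter: it
// dispatches to per-cluster dialers and falls back to net.Dialer for "".
type mapRouter struct {
	clusters map[string]*net.Dialer // per-cluster dialers (simplified)
	fallback net.Dialer
}

func (r *mapRouter) DialContext(ctx context.Context, network, addr, clusterName string) (net.Conn, error) {
	if clusterName == "" {
		return r.fallback.DialContext(ctx, network, addr)
	}
	d, ok := r.clusters[clusterName]
	if !ok {
		return nil, fmt.Errorf("cluster %s: %w", clusterName, errClusterNotFound)
	}
	return d.DialContext(ctx, network, addr)
}

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		panic(err)
	}
	defer ln.Close()
	go func() {
		for {
			c, err := ln.Accept()
			if err != nil {
				return
			}
			c.Close()
		}
	}()

	var dialer BackendDialer = &mapRouter{clusters: map[string]*net.Dialer{"cluster-a": {}}}
	ctx, cancel := context.WithTimeout(context.Background(), time.Second)
	defer cancel()

	// A known cluster and an empty cluster name both succeed; unknown names fail fast.
	for _, name := range []string{"cluster-a", ""} {
		conn, err := dialer.DialContext(ctx, "tcp", ln.Addr().String(), name)
		if err != nil {
			panic(err)
		}
		conn.Close()
	}
	_, err = dialer.DialContext(ctx, "tcp", ln.Addr().String(), "cluster-b")
	fmt.Println(errors.Is(err, errClusterNotFound)) // true
}

Failing fast on unknown cluster names, rather than silently falling back, matches the router's ErrBackendClusterNotFound behavior and keeps misrouted dials observable.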
diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 98f6bcd6a..e45a3f840 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -12,10 +12,14 @@ import ( "github.com/pingcap/tiproxy/lib/util/logger" "github.com/pingcap/tiproxy/pkg/sctx" "github.com/pingcap/tiproxy/pkg/util/etcd" + "github.com/prometheus/client_golang/prometheus" "github.com/stretchr/testify/require" ) func TestServer(t *testing.T) { + restore := resetPromRegistry() + defer restore() + dir := t.TempDir() lg, _ := logger.CreateLoggerForTest(t) etcdServer, err := etcd.CreateEtcdServer("0.0.0.0:0", dir, lg) @@ -34,3 +38,30 @@ func TestServer(t *testing.T) { require.NoError(t, server.Close()) etcdServer.Close() } + +func TestServerWithoutBackendCluster(t *testing.T) { + restore := resetPromRegistry() + defer restore() + + dir := t.TempDir() + configFile := dir + "/config.toml" + require.NoError(t, os.WriteFile(configFile, []byte("[proxy]\npd-addrs = \"\"\n"), 0o644)) + + server, err := NewServer(context.Background(), &sctx.Context{ + ConfigFile: configFile, + }) + require.NoError(t, err) + require.NoError(t, server.Close()) +} + +func resetPromRegistry() func() { + registry := prometheus.NewRegistry() + oldRegisterer := prometheus.DefaultRegisterer + oldGatherer := prometheus.DefaultGatherer + prometheus.DefaultRegisterer = registry + prometheus.DefaultGatherer = registry + return func() { + prometheus.DefaultRegisterer = oldRegisterer + prometheus.DefaultGatherer = oldGatherer + } +} diff --git a/pkg/testkit/dns_server.go b/pkg/testkit/dns_server.go new file mode 100644 index 000000000..686dd15db --- /dev/null +++ b/pkg/testkit/dns_server.go @@ -0,0 +1,137 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package testkit + +import ( + "net" + "strings" + "sync" + "testing" + + "github.com/pingcap/tiproxy/lib/util/waitgroup" + "github.com/stretchr/testify/require" + "golang.org/x/net/dns/dnsmessage" +) + +type DNSServer struct { + conn *net.UDPConn + records map[string][]net.IP + mu sync.Mutex + queries map[string]int + wg waitgroup.WaitGroup +} + +func StartDNSServer(t *testing.T, records map[string][]string) *DNSServer { + t.Helper() + conn, err := net.ListenUDP("udp", &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 0}) + require.NoError(t, err) + + server := &DNSServer{ + conn: conn, + records: make(map[string][]net.IP, len(records)), + queries: make(map[string]int), + } + for name, ips := range records { + key := normalizeDNSName(name) + server.records[key] = make([]net.IP, 0, len(ips)) + for _, ip := range ips { + server.records[key] = append(server.records[key], net.ParseIP(ip)) + } + } + server.wg.Run(func() { + server.serve() + }) + t.Cleanup(func() { + require.NoError(t, server.Close()) + }) + return server +} + +func (s *DNSServer) Addr() string { + return s.conn.LocalAddr().String() +} + +func (s *DNSServer) QueryCount(name string) int { + s.mu.Lock() + defer s.mu.Unlock() + return s.queries[normalizeDNSName(name)] +} + +func (s *DNSServer) Close() error { + if s.conn != nil { + err := s.conn.Close() + s.wg.Wait() + return err + } + return nil +} + +func (s *DNSServer) serve() { + buf := make([]byte, 1500) + for { + n, addr, err := s.conn.ReadFromUDP(buf) + if err != nil { + return + } + resp, err := s.handleQuery(buf[:n]) + if err != nil { + continue + } + _, _ = s.conn.WriteToUDP(resp, addr) + } +} + +func (s *DNSServer) handleQuery(pkt []byte) ([]byte, error) { + var parser dnsmessage.Parser + header, err := parser.Start(pkt) + if err != nil { + 
return nil, err + } + question, err := parser.Question() + if err != nil { + return nil, err + } + name := normalizeDNSName(question.Name.String()) + s.mu.Lock() + s.queries[name]++ + s.mu.Unlock() + + respHeader := dnsmessage.Header{ + ID: header.ID, + Response: true, + RecursionAvailable: true, + } + builder := dnsmessage.NewBuilder(nil, respHeader) + builder.EnableCompression() + if err := builder.StartQuestions(); err != nil { + return nil, err + } + if err := builder.Question(question); err != nil { + return nil, err + } + if err := builder.StartAnswers(); err != nil { + return nil, err + } + for _, ip := range s.records[name] { + if ipv4 := ip.To4(); ipv4 != nil && question.Type == dnsmessage.TypeA { + resource := dnsmessage.Resource{ + Header: dnsmessage.ResourceHeader{ + Name: question.Name, + Type: dnsmessage.TypeA, + Class: dnsmessage.ClassINET, + TTL: 60, + }, + Body: &dnsmessage.AResource{A: [4]byte(ipv4)}, + } + if err := builder.AResource(resource.Header, *resource.Body.(*dnsmessage.AResource)); err != nil { + return nil, err + } + } + } + return builder.Finish() +} + +func normalizeDNSName(name string) string { + return strings.TrimSuffix(strings.ToLower(name), ".") +} diff --git a/pkg/util/etcd/etcd.go b/pkg/util/etcd/etcd.go index 1508e9141..de1e51430 100644 --- a/pkg/util/etcd/etcd.go +++ b/pkg/util/etcd/etcd.go @@ -5,7 +5,9 @@ package etcd import ( "context" + "crypto/tls" "fmt" + "net" "net/url" "strings" "time" @@ -14,6 +16,7 @@ import ( "github.com/pingcap/tiproxy/lib/util/errors" "github.com/pingcap/tiproxy/lib/util/retry" "github.com/pingcap/tiproxy/pkg/manager/cert" + "github.com/pingcap/tiproxy/pkg/util/netutil" "go.etcd.io/etcd/api/v3/mvccpb" "go.etcd.io/etcd/client/pkg/v3/transport" clientv3 "go.etcd.io/etcd/client/v3" @@ -31,31 +34,53 @@ func InitEtcdClient(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertMa // use tidb server addresses directly return nil, nil } - pdEndpoints := strings.Split(pdAddr, ",") + return InitEtcdClientWithAddrs(logger, pdAddr, certMgr.ClusterTLS()) +} + +// InitEtcdClientWithAddrs initializes an etcd client that connects to PD ETCD servers. 
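+// pdAddrs is a comma-separated endpoint list; tlsConfig may be nil when the
+// endpoints are plaintext.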
+func InitEtcdClientWithAddrs(logger *zap.Logger, pdAddrs string, tlsConfig *tls.Config) (*clientv3.Client, error) { + return InitEtcdClientWithAddrsAndDialer(logger, pdAddrs, tlsConfig, nil) +} + +func InitEtcdClientWithAddrsAndDialer(logger *zap.Logger, pdAddrs string, tlsConfig *tls.Config, + dnsDialer *netutil.DNSDialer) (*clientv3.Client, error) { + pdEndpoints := strings.Split(pdAddrs, ",") logger.Info("connect ETCD servers", zap.Strings("addrs", pdEndpoints)) + dialOptions := []grpc.DialOption{ + grpc.WithKeepaliveParams(keepalive.ClientParameters{ + Time: 10 * time.Second, + Timeout: 3 * time.Second, + }), + grpc.WithConnectParams(grpc.ConnectParams{ + Backoff: backoff.Config{ + BaseDelay: time.Second, + Multiplier: 1.1, + Jitter: 0.1, + MaxDelay: 3 * time.Second, + }, + MinConnectTimeout: 3 * time.Second, + }), + } + if dnsDialer != nil { + dialOptions = append(dialOptions, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) { + return dnsDialer.DialContext(ctx, "tcp", addr) + })) + } etcdClient, err := clientv3.New(clientv3.Config{ Endpoints: pdEndpoints, - TLS: certMgr.ClusterTLS(), + TLS: tlsConfig, Logger: logger.Named("etcdcli"), AutoSyncInterval: 30 * time.Second, DialTimeout: 5 * time.Second, - DialOptions: []grpc.DialOption{ - grpc.WithKeepaliveParams(keepalive.ClientParameters{ - Time: 10 * time.Second, - Timeout: 3 * time.Second, - }), - grpc.WithConnectParams(grpc.ConnectParams{ - Backoff: backoff.Config{ - BaseDelay: time.Second, - Multiplier: 1.1, - Jitter: 0.1, - MaxDelay: 3 * time.Second, - }, - MinConnectTimeout: 3 * time.Second, - }), - }, + DialOptions: dialOptions, }) - return etcdClient, errors.Wrapf(err, "init etcd client failed") + if err != nil { + return nil, errors.Wrapf(err, "init etcd client failed") + } + if err := syncEtcdClient(context.Background(), etcdClient); err != nil { + logger.Warn("sync ETCD member endpoints after init failed", zap.Error(err)) + } + return etcdClient, nil } func GetKVs(ctx context.Context, etcdCli *clientv3.Client, key string, opts []clientv3.OpOption, timeout, retryIntvl time.Duration, retryCnt uint64) ([]*mvccpb.KeyValue, error) { @@ -75,7 +100,15 @@ func GetKVs(ctx context.Context, etcdCli *clientv3.Client, key string, opts []cl // CreateEtcdServer creates an etcd server and is only used for testing. func CreateEtcdServer(addr, dir string, lg *zap.Logger) (*embed.Etcd, error) { - serverURL, err := url.Parse(fmt.Sprintf("http://%s", addr)) + listenAddr, advertiseAddr, err := allocEtcdServerAddr(addr) + if err != nil { + return nil, err + } + serverURL, err := url.Parse(fmt.Sprintf("http://%s", listenAddr)) + if err != nil { + return nil, err + } + advertiseURL, err := url.Parse(fmt.Sprintf("http://%s", advertiseAddr)) if err != nil { return nil, err } @@ -83,6 +116,9 @@ func CreateEtcdServer(addr, dir string, lg *zap.Logger) (*embed.Etcd, error) { cfg.Dir = dir cfg.ListenClientUrls = []url.URL{*serverURL} cfg.ListenPeerUrls = []url.URL{*serverURL} + cfg.AdvertiseClientUrls = []url.URL{*advertiseURL} + cfg.AdvertisePeerUrls = []url.URL{*advertiseURL} + cfg.InitialCluster = fmt.Sprintf("%s=%s", cfg.Name, advertiseURL.String()) cfg.ZapLoggerBuilder = embed.NewZapLoggerBuilder(lg) cfg.LogLevel = "fatal" // Reuse port so that it can reboot with the same port immediately. 
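A hedged usage sketch of the dialer-aware constructor above; the name server and PD addresses are placeholders, and the nil TLS config assumes plaintext endpoints:

package main

import (
	"github.com/pingcap/tiproxy/pkg/util/etcd"
	"github.com/pingcap/tiproxy/pkg/util/netutil"
	"go.uber.org/zap"
)

func main() {
	// Resolve PD hostnames via the configured name server instead of the
	// system resolver; placeholder addresses throughout.
	dialer := netutil.NewDNSDialer([]string{"10.0.0.53:53"})
	cli, err := etcd.InitEtcdClientWithAddrsAndDialer(
		zap.NewNop(), "pd-0.pd.test:2379,pd-1.pd.test:2379", nil, dialer)
	if err != nil {
		panic(err)
	}
	defer func() { _ = cli.Close() }()
}

Passing a nil dialer preserves the old behavior, since the gRPC WithContextDialer option is only appended when dnsDialer is non-nil.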
@@ -98,6 +134,30 @@ func CreateEtcdServer(addr, dir string, lg *zap.Logger) (*embed.Etcd, error) { return etcd, err } +func allocEtcdServerAddr(addr string) (listenAddr, advertiseAddr string, err error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return "", "", err + } + if host == "" || host == "0.0.0.0" || host == "::" { + host = "127.0.0.1" + } + if port != "0" { + return net.JoinHostPort(host, port), net.JoinHostPort(host, port), nil + } + ln, err := net.Listen("tcp", net.JoinHostPort(host, "0")) + if err != nil { + return "", "", err + } + defer func() { + closeErr := ln.Close() + if err == nil && closeErr != nil { + err = closeErr + } + }() + return ln.Addr().String(), ln.Addr().String(), nil +} + func ConfigForEtcdTest(endpoint string) *config.Config { return &config.Config{ Proxy: config.ProxyServer{ @@ -109,3 +169,13 @@ func ConfigForEtcdTest(endpoint string) *config.Config { }, } } + +type etcdSyncer interface { + Sync(ctx context.Context) error +} + +func syncEtcdClient(ctx context.Context, cli etcdSyncer) error { + syncCtx, cancel := context.WithTimeout(ctx, 5*time.Second) + defer cancel() + return errors.WithStack(cli.Sync(syncCtx)) +} diff --git a/pkg/util/etcd/etcd_test.go b/pkg/util/etcd/etcd_test.go index a80b7eba6..98e58bb12 100644 --- a/pkg/util/etcd/etcd_test.go +++ b/pkg/util/etcd/etcd_test.go @@ -35,3 +35,26 @@ func TestEtcdClient(t *testing.T) { require.NoError(t, client.Close()) server.Close() } + +func TestSyncEtcdClient(t *testing.T) { + err := syncEtcdClient(context.Background(), &mockEtcdSyncer{}) + require.NoError(t, err) +} + +func TestSyncEtcdClientTimeout(t *testing.T) { + err := syncEtcdClient(context.Background(), &mockEtcdSyncer{block: true}) + require.Error(t, err) + require.ErrorIs(t, err, context.DeadlineExceeded) +} + +type mockEtcdSyncer struct { + block bool +} + +func (m *mockEtcdSyncer) Sync(ctx context.Context) error { + if !m.block { + return nil + } + <-ctx.Done() + return ctx.Err() +} diff --git a/pkg/util/http/http.go b/pkg/util/http/http.go index c4f5890f7..d9a2e150b 100644 --- a/pkg/util/http/http.go +++ b/pkg/util/http/http.go @@ -4,9 +4,11 @@ package http import ( + "context" "crypto/tls" "fmt" "io" + "net" "net/http" "time" @@ -21,11 +23,18 @@ type Client struct { } func NewHTTPClient(getTLSConfig func() *tls.Config) *Client { + return NewHTTPClientWithDialContext(getTLSConfig, nil) +} + +func NewHTTPClientWithDialContext(getTLSConfig func() *tls.Config, dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) *Client { // Since TLS config will hot reload, `TLSClientConfig` need update by `getTLSConfig()` // to obtain the latest TLS config. return &Client{ cli: &http.Client{ - Transport: &http.Transport{TLSClientConfig: getTLSConfig()}, + Transport: &http.Transport{ + TLSClientConfig: getTLSConfig(), + DialContext: dialContext, + }, }, getTLSConfig: getTLSConfig, } diff --git a/pkg/util/netutil/dns.go b/pkg/util/netutil/dns.go new file mode 100644 index 000000000..a26bc5b61 --- /dev/null +++ b/pkg/util/netutil/dns.go @@ -0,0 +1,109 @@ +// Copyright 2026 PingCAP, Inc. +// SPDX-License-Identifier: Apache-2.0 + +package netutil + +import ( + "context" + "net" + "strings" + "sync" + "sync/atomic" + "time" +) + +const defaultDNSCacheTTL = 5 * time.Second + +type dnsCacheEntry struct { + ips []net.IP + deadline time.Time +} + +// DNSDialer routes DNS lookups to configured name servers and caches lookup results briefly. 
+// If no name servers are configured, it falls back to the system resolver and dialer. +type DNSDialer struct { + cacheTTL time.Duration + nameServer []string + resolver *net.Resolver + dialer net.Dialer + nextServer atomic.Uint64 + mu struct { + sync.Mutex + cacheMap map[string]dnsCacheEntry + } +} + +func NewDNSDialer(nameServers []string) *DNSDialer { + d := &DNSDialer{ + cacheTTL: defaultDNSCacheTTL, + nameServer: append([]string(nil), nameServers...), + mu: struct { + sync.Mutex + cacheMap map[string]dnsCacheEntry + }{ + cacheMap: make(map[string]dnsCacheEntry), + }, + } + if len(nameServers) == 0 { + return d + } + d.resolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, _ string) (net.Conn, error) { + server := d.nameServer[int(d.nextServer.Add(1)-1)%len(d.nameServer)] + return d.dialer.DialContext(ctx, network, server) + }, + } + return d +} + +func (d *DNSDialer) DialContext(ctx context.Context, network, addr string) (net.Conn, error) { + host, port, err := net.SplitHostPort(addr) + if err != nil { + return nil, err + } + if ip := net.ParseIP(host); ip != nil || d.resolver == nil { + return d.dialer.DialContext(ctx, network, addr) + } + ips, err := d.lookupNetIP(ctx, host) + if err != nil { + return nil, err + } + var dialErr error + for _, ip := range ips { + conn, err := d.dialer.DialContext(ctx, network, net.JoinHostPort(ip.String(), port)) + if err == nil { + return conn, nil + } + dialErr = err + } + return nil, dialErr +} + +func (d *DNSDialer) lookupNetIP(ctx context.Context, host string) ([]net.IP, error) { + key := strings.TrimSuffix(strings.ToLower(host), ".") + now := time.Now() + d.mu.Lock() + if entry, ok := d.mu.cacheMap[key]; ok && now.Before(entry.deadline) { + ips := entry.ips + d.mu.Unlock() + return ips, nil + } + d.mu.Unlock() + + ips, err := d.resolver.LookupNetIP(ctx, "ip", host) + if err != nil { + return nil, err + } + ipList := make([]net.IP, 0, len(ips)) + for _, ip := range ips { + ipList = append(ipList, append(net.IP(nil), ip.AsSlice()...)) + } + d.mu.Lock() + d.mu.cacheMap[key] = dnsCacheEntry{ + ips: ipList, + deadline: now.Add(d.cacheTTL), + } + d.mu.Unlock() + return ipList, nil +} diff --git a/pkg/util/netutil/dns_test.go b/pkg/util/netutil/dns_test.go new file mode 100644 index 000000000..2870b1537 --- /dev/null +++ b/pkg/util/netutil/dns_test.go @@ -0,0 +1,119 @@ +// Copyright 2026 PingCAP, Inc. 
+// SPDX-License-Identifier: Apache-2.0 + +package netutil + +import ( + "context" + "net" + "strconv" + "testing" + "time" + + "github.com/pingcap/tiproxy/pkg/testkit" + "github.com/stretchr/testify/require" +) + +func TestDNSDialerUsesConfiguredNameServerAndCache(t *testing.T) { + listener, addr := testkit.StartListener(t, "127.0.0.1:0") + t.Cleanup(func() { require.NoError(t, listener.Close()) }) + _, port := testkit.ParseHostPort(t, addr) + dns := testkit.StartDNSServer(t, map[string][]string{ + "tidb.test": {"127.0.0.1"}, + }) + + accepted := make(chan error, 2) + for range 2 { + go func() { + conn, err := listener.Accept() + if err != nil { + accepted <- err + return + } + accepted <- conn.Close() + }() + } + + dialer := NewDNSDialer([]string{dns.Addr()}) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("tidb.test", strconv.Itoa(int(port)))) + require.NoError(t, err) + require.NoError(t, conn.Close()) + queryCount := dns.QueryCount("tidb.test") + require.Greater(t, queryCount, 0) + + conn, err = dialer.DialContext(ctx, "tcp", net.JoinHostPort("tidb.test", strconv.Itoa(int(port)))) + require.NoError(t, err) + require.NoError(t, conn.Close()) + require.Equal(t, queryCount, dns.QueryCount("tidb.test")) + require.NoError(t, <-accepted) + require.NoError(t, <-accepted) +} + +func TestDNSDialerFallbackToSystemResolver(t *testing.T) { + listener, addr := testkit.StartListener(t, "127.0.0.1:0") + t.Cleanup(func() { require.NoError(t, listener.Close()) }) + _, port := testkit.ParseHostPort(t, addr) + accepted := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + accepted <- err + return + } + accepted <- conn.Close() + }() + + dialer := NewDNSDialer(nil) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("localhost", strconv.Itoa(int(port)))) + require.NoError(t, err) + require.NoError(t, conn.Close()) + require.NoError(t, <-accepted) +} + +func TestDNSDialerTriesAllResolvedIPs(t *testing.T) { + listener, addr := testkit.StartListener(t, "127.0.0.1:0") + t.Cleanup(func() { require.NoError(t, listener.Close()) }) + _, port := testkit.ParseHostPort(t, addr) + dns := testkit.StartDNSServer(t, map[string][]string{ + "tidb.test": {"127.0.0.2", "127.0.0.1"}, + }) + + accepted := make(chan struct{}, 1) + acceptErr := make(chan error, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + acceptErr <- err + return + } + if err := conn.Close(); err != nil { + acceptErr <- err + return + } + accepted <- struct{}{} + }() + + dialer := NewDNSDialer([]string{dns.Addr()}) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + conn, err := dialer.DialContext(ctx, "tcp", net.JoinHostPort("tidb.test", strconv.Itoa(int(port)))) + require.NoError(t, err) + require.NoError(t, conn.Close()) + + select { + case <-accepted: + case <-time.After(time.Second): + t.Fatal("listener was not reached through resolved fallback IP") + } + select { + case err := <-acceptErr: + require.NoError(t, err) + default: + } +}
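The DNSDialer composes with the HTTP client change in pkg/util/http the same way NetworkRouter does internally. A short sketch, assuming a placeholder name server and a nil TLS getter (plaintext transport):

package main

import (
	"crypto/tls"

	httputil "github.com/pingcap/tiproxy/pkg/util/http"
	"github.com/pingcap/tiproxy/pkg/util/netutil"
)

func main() {
	// All lookups made by this client's transport go through 10.0.0.53
	// (placeholder) and hit the DNSDialer's short-lived cache.
	dialer := netutil.NewDNSDialer([]string{"10.0.0.53:53"})
	noTLS := func() *tls.Config { return nil }
	cli := httputil.NewHTTPClientWithDialContext(noTLS, dialer.DialContext)
	_ = cli // e.g. cli.Get(addr, path, backoff, timeout) as in network_router_test.go
}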