From 8f5282100a28100f1e4164e66399ea87f8538b11 Mon Sep 17 00:00:00 2001 From: Yang Keao Date: Tue, 17 Mar 2026 23:11:39 +0800 Subject: [PATCH] config: add backend cluster schema and multi-port listener config Introduce backend cluster config, compatibility helpers, validation, and SQL listener support for proxy.port-range. --- lib/config/balance.go | 4 +- lib/config/balance_test.go | 7 ++ lib/config/proxy.go | 164 +++++++++++++++++++++++++++++++++- lib/config/proxy_test.go | 153 +++++++++++++++++++++++++++++++ pkg/proxy/proxy.go | 7 +- pkg/proxy/proxy_test.go | 115 +++++++++++++++++++++++- pkg/server/api/config_test.go | 83 +++++++++++++++++ 7 files changed, 526 insertions(+), 7 deletions(-) diff --git a/lib/config/balance.go b/lib/config/balance.go index f4b03aeb9..fd0424856 100644 --- a/lib/config/balance.go +++ b/lib/config/balance.go @@ -18,6 +18,8 @@ const ( MatchClientCIDRStr = "client_cidr" // MatchProxyCIDRStr is used for MatchProxyCIDR. MatchProxyCIDRStr = "proxy_cidr" + // MatchPortStr is used for port-based routing. + MatchPortStr = "port" ) type Balance struct { @@ -52,7 +54,7 @@ func (b *Balance) Check() error { } switch b.RoutingRule { - case MatchClientCIDRStr, MatchProxyCIDRStr, "": + case MatchClientCIDRStr, MatchProxyCIDRStr, MatchPortStr, "": default: return errors.Wrapf(ErrInvalidConfigValue, "invalid balance.routing-rule") } diff --git a/lib/config/balance_test.go b/lib/config/balance_test.go index 7940ad413..d246ff764 100644 --- a/lib/config/balance_test.go +++ b/lib/config/balance_test.go @@ -48,4 +48,11 @@ func TestCheckBalance(t *testing.T) { require.NoError(t, (&balance).Check()) balance = DefaultBalance() require.NoError(t, (&balance).Check()) + + balance = DefaultBalance() + balance.RoutingRule = MatchPortStr + require.NoError(t, balance.Check()) + + balance.RoutingRule = "unknown" + require.ErrorIs(t, balance.Check(), ErrInvalidConfigValue) } diff --git a/lib/config/proxy.go b/lib/config/proxy.go index d35932afb..43fc0573e 100644 --- a/lib/config/proxy.go +++ b/lib/config/proxy.go @@ -9,6 +9,8 @@ import ( "net" "os" "path/filepath" + "slices" + "strconv" "strings" "time" @@ -63,15 +65,25 @@ type ProxyServerOnline struct { GracefulCloseConnTimeout int `yaml:"graceful-close-conn-timeout,omitempty" toml:"graceful-close-conn-timeout,omitempty" json:"graceful-close-conn-timeout,omitempty" reloadable:"true"` // Public and private traffic are metered separately. PublicEndpoints []string `yaml:"public-endpoints,omitempty" toml:"public-endpoints,omitempty" json:"public-endpoints,omitempty" reloadable:"true"` + // BackendClusters represents multiple backend clusters that the proxy can route to. It can be reloaded + // online. + BackendClusters []BackendCluster `yaml:"backend-clusters,omitempty" toml:"backend-clusters,omitempty" json:"backend-clusters,omitempty" reloadable:"true"` } type ProxyServer struct { Addr string `yaml:"addr,omitempty" toml:"addr,omitempty" json:"addr,omitempty" reloadable:"false"` AdvertiseAddr string `yaml:"advertise-addr,omitempty" toml:"advertise-addr,omitempty" json:"advertise-addr,omitempty" reloadable:"false"` PDAddrs string `yaml:"pd-addrs,omitempty" toml:"pd-addrs,omitempty" json:"pd-addrs,omitempty" reloadable:"false"` + PortRange []int `yaml:"port-range,omitempty" toml:"port-range,omitempty" json:"port-range,omitempty" reloadable:"false"` ProxyServerOnline `yaml:",inline" toml:",inline" json:",inline"` } +type BackendCluster struct { + Name string `yaml:"name,omitempty" toml:"name,omitempty" json:"name,omitempty" reloadable:"true"` + PDAddrs string `yaml:"pd-addrs,omitempty" toml:"pd-addrs,omitempty" json:"pd-addrs,omitempty" reloadable:"true"` + NSServers []string `yaml:"ns-servers,omitempty" toml:"ns-servers,omitempty" json:"ns-servers,omitempty" reloadable:"true"` +} + type API struct { Addr string `yaml:"addr,omitempty" toml:"addr,omitempty" json:"addr,omitempty" reloadable:"false"` ProxyProtocol string `yaml:"proxy-protocol,omitempty" toml:"proxy-protocol,omitempty" json:"proxy-protocol,omitempty" reloadable:"false"` @@ -146,6 +158,11 @@ func NewConfig() *Config { func (cfg *Config) Clone() *Config { newCfg := *cfg newCfg.Labels = maps.Clone(cfg.Labels) + newCfg.Proxy.PublicEndpoints = slices.Clone(cfg.Proxy.PublicEndpoints) + newCfg.Proxy.BackendClusters = slices.Clone(cfg.Proxy.BackendClusters) + for i := range newCfg.Proxy.BackendClusters { + newCfg.Proxy.BackendClusters[i].NSServers = slices.Clone(newCfg.Proxy.BackendClusters[i].NSServers) + } return &newCfg } @@ -168,6 +185,9 @@ func (cfg *Config) Check() error { if cfg.Proxy.ConnBufferSize > 0 && (cfg.Proxy.ConnBufferSize > 16*1024*1024 || cfg.Proxy.ConnBufferSize < 1024) { return errors.Wrapf(ErrInvalidConfigValue, "conn-buffer-size must be between 1K and 16M") } + if err := cfg.Proxy.Check(); err != nil { + return err + } if err := cfg.Balance.Check(); err != nil { return err @@ -183,15 +203,16 @@ func (cfg *Config) ToBytes() ([]byte, error) { } func (cfg *Config) GetIPPort() (ip, port, statusPort string, err error) { - addrs := strings.Split(cfg.Proxy.Addr, ",") + addrs, err := cfg.Proxy.GetSQLAddrs() + if err != nil { + return + } ip, port, err = net.SplitHostPort(addrs[0]) if err != nil { - err = errors.WithStack(err) return } _, statusPort, err = net.SplitHostPort(cfg.API.Addr) if err != nil { - err = errors.WithStack(err) return } // AdvertiseAddr may be a DNS in k8s and certificate SAN typically contains DNS but not IP. @@ -217,3 +238,140 @@ func (cfg *Config) GetIPPort() (ip, port, statusPort string, err error) { } return } + +// GetBackendClusters returns configured backend clusters. +// It keeps backward compatibility for the legacy `proxy.pd-addrs` setting. +func (cfg *Config) GetBackendClusters() []BackendCluster { + if len(cfg.Proxy.BackendClusters) > 0 { + return cfg.Proxy.BackendClusters + } + if strings.TrimSpace(cfg.Proxy.PDAddrs) == "" { + return nil + } + return []BackendCluster{{ + Name: "default", + PDAddrs: cfg.Proxy.PDAddrs, + }} +} + +func (ps *ProxyServer) Check() error { + if _, err := ps.GetSQLAddrs(); err != nil { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.addr or proxy.port-range: %s", err.Error()) + } + if len(ps.BackendClusters) == 0 { + return nil + } + + clusterNames := make(map[string]struct{}, len(ps.BackendClusters)) + for i, cluster := range ps.BackendClusters { + name := strings.TrimSpace(cluster.Name) + if name == "" { + return errors.Wrapf(ErrInvalidConfigValue, "proxy.backend-clusters[%d].name is empty", i) + } + if _, ok := clusterNames[name]; ok { + return errors.Wrapf(ErrInvalidConfigValue, "duplicate proxy.backend-clusters name %s", name) + } + clusterNames[name] = struct{}{} + if err := validateAddrList(cluster.PDAddrs, "proxy.backend-clusters.pd-addrs"); err != nil { + return err + } + if _, err := ParseNSServers(cluster.NSServers); err != nil { + return errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.backend-clusters.ns-servers: %s", err.Error()) + } + } + return nil +} + +func splitAddrList(addrs string) []string { + parts := strings.Split(addrs, ",") + trimmed := make([]string, 0, len(parts)) + for _, part := range parts { + addr := strings.TrimSpace(part) + if addr != "" { + trimmed = append(trimmed, addr) + } + } + return trimmed +} + +func validateAddrList(addrs, field string) error { + parts := splitAddrList(addrs) + if len(parts) == 0 { + return errors.Wrapf(ErrInvalidConfigValue, "%s is empty", field) + } + for _, addr := range parts { + if _, _, err := net.SplitHostPort(addr); err != nil { + return errors.Wrapf(ErrInvalidConfigValue, "invalid %s address %s", field, addr) + } + } + return nil +} + +func ParseNSServers(nsServers []string) ([]string, error) { + if len(nsServers) == 0 { + return nil, nil + } + normalized := make([]string, 0, len(nsServers)) + for _, server := range nsServers { + addr, err := normalizeNSServer(server) + if err != nil { + return nil, err + } + normalized = append(normalized, addr) + } + return normalized, nil +} + +func normalizeNSServer(server string) (string, error) { + host, port, err := net.SplitHostPort(server) + if err == nil { + if host == "" { + return "", errors.Wrapf(ErrInvalidConfigValue, "host is empty") + } + portNum, err := strconv.Atoi(port) + if err != nil || portNum < 1 || portNum > 65535 { + return "", errors.Wrapf(ErrInvalidConfigValue, "port is invalid") + } + return net.JoinHostPort(host, strconv.Itoa(portNum)), nil + } + + if server == "" { + return "", errors.Wrapf(ErrInvalidConfigValue, "host is empty") + } + if strings.ContainsAny(server, "[]") { + return "", errors.Wrapf(ErrInvalidConfigValue, "host is invalid") + } + return net.JoinHostPort(server, "53"), nil +} + +func (ps *ProxyServer) GetSQLAddrs() ([]string, error) { + addrs := splitAddrList(ps.Addr) + if len(addrs) == 0 { + if len(ps.PortRange) == 0 { + return []string{""}, nil + } + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.addr is empty") + } + if len(ps.PortRange) == 0 { + return addrs, nil + } + if len(ps.PortRange) != 2 { + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.port-range must contain exactly two ports") + } + start, end := ps.PortRange[0], ps.PortRange[1] + if start < 1 || start > 65535 || end < 1 || end > 65535 || start > end { + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.port-range is invalid") + } + if len(addrs) != 1 { + return nil, errors.Wrapf(ErrInvalidConfigValue, "proxy.addr must contain exactly one host when proxy.port-range is set") + } + host, _, err := net.SplitHostPort(addrs[0]) + if err != nil { + return nil, errors.Wrapf(ErrInvalidConfigValue, "invalid proxy.addr: %s", err.Error()) + } + sqlAddrs := make([]string, 0, end-start+1) + for port := start; port <= end; port++ { + sqlAddrs = append(sqlAddrs, net.JoinHostPort(host, strconv.Itoa(port))) + } + return sqlAddrs, nil +} diff --git a/lib/config/proxy_test.go b/lib/config/proxy_test.go index eea4b7f24..a8824643a 100644 --- a/lib/config/proxy_test.go +++ b/lib/config/proxy_test.go @@ -25,6 +25,18 @@ var testProxyConfig = Config{ ProxyProtocol: "v2", GracefulWaitBeforeShutdown: 10, ConnBufferSize: 32 * 1024, + BackendClusters: []BackendCluster{ + { + Name: "cluster-a", + PDAddrs: "127.0.0.1:12379,127.0.0.1:22379", + NSServers: []string{"10.0.0.2", "10.0.0.3"}, + }, + { + Name: "cluster-b", + PDAddrs: "127.0.0.1:32379", + NSServers: []string{"10.0.0.4"}, + }, + }, }, }, API: API{ @@ -112,6 +124,58 @@ func TestProxyCheck(t *testing.T) { }, err: ErrInvalidConfigValue, }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.PortRange = []int{10000} + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.PortRange = []int{10000, 9999} + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.Addr = "0.0.0.0:6000,0.0.0.0:6001" + c.Proxy.PortRange = []int{10000, 10001} + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.BackendClusters = append(c.Proxy.BackendClusters, BackendCluster{}) + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.BackendClusters = []BackendCluster{{Name: "c1", PDAddrs: ""}} + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.BackendClusters = []BackendCluster{{Name: "c1", PDAddrs: "127.0.0.1"}} + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.BackendClusters = []BackendCluster{ + {Name: "c1", PDAddrs: "127.0.0.1:2379"}, + {Name: "c1", PDAddrs: "127.0.0.1:2380"}, + } + }, + err: ErrInvalidConfigValue, + }, + { + pre: func(t *testing.T, c *Config) { + c.Proxy.BackendClusters = []BackendCluster{{Name: "c1", PDAddrs: "127.0.0.1:2379", NSServers: []string{"10.0.0.1:abc"}}} + }, + err: ErrInvalidConfigValue, + }, } for _, tc := range testcases { cfg := testProxyConfig @@ -168,11 +232,100 @@ func TestGetIPPort(t *testing.T) { } } +func TestGetSQLAddrs(t *testing.T) { + cfg := NewConfig() + cfg.Proxy.Addr = "0.0.0.0:6000" + cfg.Proxy.PortRange = nil + addrs, err := cfg.Proxy.GetSQLAddrs() + require.NoError(t, err) + require.Equal(t, []string{"0.0.0.0:6000"}, addrs) + + cfg.Proxy.PortRange = []int{10000, 10002} + addrs, err = cfg.Proxy.GetSQLAddrs() + require.NoError(t, err) + require.Equal(t, []string{"0.0.0.0:10000", "0.0.0.0:10001", "0.0.0.0:10002"}, addrs) +} + +func TestParseNSServers(t *testing.T) { + for _, tc := range []struct { + name string + input []string + expected []string + expectErr bool + }{ + { + name: "ipv4 default port", + input: []string{"10.0.0.1"}, + expected: []string{"10.0.0.1:53"}, + }, + { + name: "hostname explicit port", + input: []string{"dns.example.com:5353"}, + expected: []string{"dns.example.com:5353"}, + }, + { + name: "ipv6 default port", + input: []string{"2001:db8::1"}, + expected: []string{"[2001:db8::1]:53"}, + }, + { + name: "bracketed ipv6 without port is invalid", + input: []string{"[2001:db8::1]"}, + expectErr: true, + }, + { + name: "invalid port", + input: []string{"10.0.0.1:abc"}, + expectErr: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + addrs, err := ParseNSServers(tc.input) + if tc.expectErr { + require.Error(t, err) + return + } + require.NoError(t, err) + require.Equal(t, tc.expected, addrs) + }) + } +} + func TestCloneConfig(t *testing.T) { cfg := testProxyConfig cfg.Labels = map[string]string{"a": "b"} + cfg.Proxy.PublicEndpoints = []string{"1.1.1.0/24"} clone := cfg.Clone() require.Equal(t, cfg, *clone) cfg.Labels["c"] = "d" + cfg.Proxy.PublicEndpoints[0] = "2.2.2.0/24" + cfg.Proxy.BackendClusters[0].Name = "cluster-updated" + cfg.Proxy.BackendClusters[0].NSServers[0] = "10.0.0.9" require.NotContains(t, clone.Labels, "c") + require.Equal(t, []string{"1.1.1.0/24"}, clone.Proxy.PublicEndpoints) + require.Equal(t, "cluster-a", clone.Proxy.BackendClusters[0].Name) + require.Equal(t, []string{"10.0.0.2", "10.0.0.3"}, clone.Proxy.BackendClusters[0].NSServers) +} + +func TestGetBackendClusters(t *testing.T) { + cfg := NewConfig() + cfg.Proxy.PDAddrs = "127.0.0.1:2379,127.0.0.2:2379" + cfg.Proxy.BackendClusters = nil + + clusters := cfg.GetBackendClusters() + require.Len(t, clusters, 1) + require.Equal(t, "default", clusters[0].Name) + require.Equal(t, cfg.Proxy.PDAddrs, clusters[0].PDAddrs) + + cfg.Proxy.BackendClusters = []BackendCluster{ + {Name: "cluster-a", PDAddrs: "127.0.0.3:2379"}, + } + clusters = cfg.GetBackendClusters() + require.Len(t, clusters, 1) + require.Equal(t, "cluster-a", clusters[0].Name) + require.Equal(t, "127.0.0.3:2379", clusters[0].PDAddrs) + + cfg.Proxy.BackendClusters = nil + cfg.Proxy.PDAddrs = "" + require.Nil(t, cfg.GetBackendClusters()) } diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go index e06e33187..59c89fcde 100644 --- a/pkg/proxy/proxy.go +++ b/pkg/proxy/proxy.go @@ -7,7 +7,6 @@ import ( "context" "net" "reflect" - "strings" "sync" "time" @@ -74,13 +73,17 @@ func NewSQLServer(logger *zap.Logger, cfg *config.Config, certMgr *cert.CertMana s.reset(cfg) - s.addrs = strings.Split(cfg.Proxy.Addr, ",") + s.addrs, err = cfg.Proxy.GetSQLAddrs() + if err != nil { + return nil, err + } s.listeners = make([]net.Listener, len(s.addrs)) for i, addr := range s.addrs { s.listeners[i], err = net.Listen("tcp", addr) if err != nil { return nil, err } + s.addrs[i] = s.listeners[i].Addr().String() } return s, nil diff --git a/pkg/proxy/proxy_test.go b/pkg/proxy/proxy_test.go index f7780c197..cd75c4c1d 100644 --- a/pkg/proxy/proxy_test.go +++ b/pkg/proxy/proxy_test.go @@ -8,7 +8,9 @@ import ( "database/sql" "fmt" "net" + "slices" "strings" + "sync" "testing" "time" @@ -219,6 +221,84 @@ func TestMultiAddr(t *testing.T) { certManager.Close() } +func TestPortRange(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + err := certManager.Init(&config.Config{}, lg, nil) + require.NoError(t, err) + start, end := findFreePortRange(t, 3) + server, err := NewSQLServer(lg, &config.Config{ + Proxy: config.ProxyServer{ + Addr: fmt.Sprintf("127.0.0.1:%d", start), + PortRange: []int{start, end}, + }, + }, certManager, id.NewIDManager(), nil, nil, &mockHsHandler{}) + require.NoError(t, err) + server.Run(context.Background(), nil) + + require.Len(t, server.listeners, 3) + ports := make([]int, 0, len(server.listeners)) + for _, listener := range server.listeners { + tcpAddr, ok := listener.Addr().(*net.TCPAddr) + require.True(t, ok) + ports = append(ports, tcpAddr.Port) + + conn, err := net.Dial("tcp", listener.Addr().String()) + require.NoError(t, err) + require.NoError(t, conn.Close()) + } + slices.Sort(ports) + require.Equal(t, []int{start, start + 1, end}, ports) + + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() +} + +func TestConnAddrUsesActualListenerAddr(t *testing.T) { + lg, _ := logger.CreateLoggerForTest(t) + certManager := cert.NewCertManager() + require.NoError(t, certManager.Init(&config.Config{}, lg, nil)) + + var ( + addrMu sync.Mutex + connAddr string + ) + handler := &mockHsHandler{ + getRouter: func(ctx backend.ConnContext, _ *pnet.HandshakeResp) (router.Router, error) { + addrMu.Lock() + connAddr, _ = ctx.Value(backend.ConnContextKeyConnAddr).(string) + addrMu.Unlock() + return nil, errors.New("no router") + }, + } + server, err := NewSQLServer(lg, &config.Config{ + Proxy: config.ProxyServer{ + Addr: "127.0.0.1:0", + }, + }, certManager, id.NewIDManager(), nil, nil, handler) + require.NoError(t, err) + server.Run(context.Background(), nil) + defer func() { + server.PreClose() + require.NoError(t, server.Close()) + certManager.Close() + }() + + _, port, err := net.SplitHostPort(server.listeners[0].Addr().String()) + require.NoError(t, err) + mdb, err := sql.Open("mysql", fmt.Sprintf("root@tcp(127.0.0.1:%s)/test", port)) + require.NoError(t, err) + defer func() { require.NoError(t, mdb.Close()) }() + + require.ErrorContains(t, mdb.Ping(), "no router") + require.Eventually(t, func() bool { + addrMu.Lock() + defer addrMu.Unlock() + return connAddr == server.listeners[0].Addr().String() + }, 3*time.Second, 10*time.Millisecond) +} + func TestWatchCfg(t *testing.T) { lg, _ := logger.CreateLoggerForTest(t) hsHandler := backend.NewDefaultHandshakeHandler(nil) @@ -253,6 +333,35 @@ func TestWatchCfg(t *testing.T) { require.NoError(t, server.Close()) } +func findFreePortRange(t *testing.T, size int) (start, end int) { + t.Helper() + for range 128 { + probe, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + start = probe.Addr().(*net.TCPAddr).Port + require.NoError(t, probe.Close()) + + listeners := make([]net.Listener, 0, size) + ok := true + for port := start; port < start+size; port++ { + listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", port)) + if err != nil { + ok = false + break + } + listeners = append(listeners, listener) + } + for _, listener := range listeners { + require.NoError(t, listener.Close()) + } + if ok { + return start, start + size - 1 + } + } + t.Fatal("failed to find free contiguous ports") + return 0, 0 +} + func TestRecoverPanic(t *testing.T) { lg, text := logger.CreateLoggerForTest(t) certManager := cert.NewCertManager() @@ -327,6 +436,7 @@ func TestPublicEndpoint(t *testing.T) { type mockHsHandler struct { backend.DefaultHandshakeHandler handshakeResp func(ctx backend.ConnContext, _ *pnet.HandshakeResp) error + getRouter func(ctx backend.ConnContext, _ *pnet.HandshakeResp) (router.Router, error) } // HandleHandshakeResp only panics for the first connections. @@ -342,6 +452,9 @@ func (handler *mockHsHandler) GetServerVersion() string { } // GetRouter returns an error for the second connection. -func (handler *mockHsHandler) GetRouter(backend.ConnContext, *pnet.HandshakeResp) (router.Router, error) { +func (handler *mockHsHandler) GetRouter(ctx backend.ConnContext, resp *pnet.HandshakeResp) (router.Router, error) { + if handler.getRouter != nil { + return handler.getRouter(ctx, resp) + } return nil, errors.New("no router") } diff --git a/pkg/server/api/config_test.go b/pkg/server/api/config_test.go index 712ee4c05..105c2eb33 100644 --- a/pkg/server/api/config_test.go +++ b/pkg/server/api/config_test.go @@ -62,6 +62,89 @@ func TestConfig(t *testing.T) { require.NotEqual(t, sum, string(sumreg.Find(all))) require.Equal(t, http.StatusOK, r.StatusCode) }) + doHTTP(t, http.MethodPut, "/api/admin/config", httpOpts{reader: strings.NewReader(` +[[proxy.backend-clusters]] +name = "cluster-a" +pd-addrs = "127.0.0.1:2379" +ns-servers = ["10.0.0.1"] + +[[proxy.backend-clusters]] +name = "cluster-b" +pd-addrs = "127.0.0.2:2379" +ns-servers = ["10.0.0.2", "10.0.0.3"] +`)}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusOK, r.StatusCode) + }) + + doHTTP(t, http.MethodGet, "/api/admin/config?format=json", httpOpts{}, func(t *testing.T, r *http.Response) { + var cfg config.Config + all, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(all, &cfg)) + require.Len(t, cfg.Proxy.BackendClusters, 2) + require.Equal(t, "cluster-a", cfg.Proxy.BackendClusters[0].Name) + require.Equal(t, "127.0.0.1:2379", cfg.Proxy.BackendClusters[0].PDAddrs) + require.Equal(t, []string{"10.0.0.1"}, cfg.Proxy.BackendClusters[0].NSServers) + require.Equal(t, "cluster-b", cfg.Proxy.BackendClusters[1].Name) + require.Equal(t, "127.0.0.2:2379", cfg.Proxy.BackendClusters[1].PDAddrs) + require.Equal(t, []string{"10.0.0.2", "10.0.0.3"}, cfg.Proxy.BackendClusters[1].NSServers) + require.Equal(t, http.StatusOK, r.StatusCode) + }) + + doHTTP(t, http.MethodPut, "/api/admin/config", httpOpts{reader: strings.NewReader(` +[[proxy.backend-clusters]] +name = "cluster-d" +pd-addrs = "127.0.0.4:2379" +ns-servers = ["10.0.0.4:abc"] +`)}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusInternalServerError, r.StatusCode) + }) + + doHTTP(t, http.MethodGet, "/api/admin/config?format=json", httpOpts{}, func(t *testing.T, r *http.Response) { + var cfg config.Config + all, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(all, &cfg)) + require.Len(t, cfg.Proxy.BackendClusters, 2) + require.Equal(t, "cluster-a", cfg.Proxy.BackendClusters[0].Name) + require.Equal(t, "cluster-b", cfg.Proxy.BackendClusters[1].Name) + require.Equal(t, http.StatusOK, r.StatusCode) + }) + + doHTTP(t, http.MethodPut, "/api/admin/config", httpOpts{reader: strings.NewReader(` +[[proxy.backend-clusters]] +name = "cluster-c" +pd-addrs = "127.0.0.3:2379" +`)}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusOK, r.StatusCode) + }) + + doHTTP(t, http.MethodGet, "/api/admin/config?format=json", httpOpts{}, func(t *testing.T, r *http.Response) { + var cfg config.Config + all, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(all, &cfg)) + require.Len(t, cfg.Proxy.BackendClusters, 1) + require.Equal(t, "cluster-c", cfg.Proxy.BackendClusters[0].Name) + require.Equal(t, "127.0.0.3:2379", cfg.Proxy.BackendClusters[0].PDAddrs) + require.Equal(t, http.StatusOK, r.StatusCode) + }) + + doHTTP(t, http.MethodPut, "/api/admin/config", httpOpts{reader: strings.NewReader(` +[proxy] +backend-clusters = [] +`)}, func(t *testing.T, r *http.Response) { + require.Equal(t, http.StatusOK, r.StatusCode) + }) + + doHTTP(t, http.MethodGet, "/api/admin/config?format=json", httpOpts{}, func(t *testing.T, r *http.Response) { + var cfg config.Config + all, err := io.ReadAll(r.Body) + require.NoError(t, err) + require.NoError(t, json.Unmarshal(all, &cfg)) + require.Empty(t, cfg.Proxy.BackendClusters) + require.Equal(t, http.StatusOK, r.StatusCode) + }) } func TestAcceptType(t *testing.T) {