From b5350100748bca1b9da940309222914b112fe0e1 Mon Sep 17 00:00:00 2001 From: yaoge123 Date: Mon, 4 May 2026 17:26:29 +0800 Subject: [PATCH 1/3] server: add Prometheus metrics endpoint --- pkg/server/server.go | 92 +++++++++++++++++++++++++++++++++++++ pkg/server/server_test.go | 95 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 187 insertions(+) diff --git a/pkg/server/server.go b/pkg/server/server.go index a2302e6..18f827e 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -701,6 +701,88 @@ func (s *Server) ListConnectionInfo() (result []*ConnInfo) { return } +func prometheusEscapeLabelValue(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, "\n", `\n`) + s = strings.ReplaceAll(s, `"`, `\"`) + return s +} + +func prometheusLabelValueOrUnknown(s string) string { + if s == "" { + return "unknown" + } + return s +} + +func prometheusLabels(index uint32, module, upstream string) string { + return fmt.Sprintf( + `index="%d",module="%s",upstream="%s"`, + index, + prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(module)), + prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(upstream)), + ) +} + +func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { + connections := s.ListConnectionInfo() + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections gauge") + _, _ = fmt.Fprintf(w, "rsync_proxy_active_connections %d\n", s.GetActiveConnectionCount()) + + connectionCounts := make(map[string]int) + for _, conn := range connections { + module := prometheusLabelValueOrUnknown(conn.Module) + upstream := prometheusLabelValueOrUnknown(conn.UpstreamAddr) + key := module + "\xff" + upstream + connectionCounts[key]++ + } + + keys := make([]string, 0, len(connectionCounts)) + for key := range connectionCounts { + keys = append(keys, key) + } + sort.Strings(keys) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections_by_module Current active rsync proxy connections by module and upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections_by_module gauge") + for _, key := range keys { + parts := strings.SplitN(key, "\xff", 2) + module := prometheusEscapeLabelValue(parts[0]) + upstream := prometheusEscapeLabelValue(parts[1]) + _, _ = fmt.Fprintf(w, "rsync_proxy_active_connections_by_module{module=\"%s\",upstream=\"%s\"} %d\n", module, upstream, connectionCounts[key]) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_sent_bytes Bytes sent to clients for active connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_sent_bytes gauge") + for _, conn := range connections { + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_sent_bytes{%s} %d\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), conn.SentBytes.Load()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_received_bytes Bytes received from clients for active connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_received_bytes gauge") + for _, conn := range connections { + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_received_bytes{%s} %d\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), conn.ReceivedBytes.Load()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_connected_timestamp_seconds Unix timestamp when active connections were established.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_connected_timestamp_seconds gauge") + for _, conn := range connections { + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_connected_timestamp_seconds{%s} %d\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), conn.ConnectedAt.Unix()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_duration_seconds Current duration of active connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_duration_seconds gauge") + for _, conn := range connections { + duration := now.Sub(conn.ConnectedAt).Seconds() + if duration < 0 { + duration = 0 + } + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_duration_seconds{%s} %.0f\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), duration) + } +} + func (s *Server) runHTTPServer() error { hostname, err := os.Hostname() if err != nil { @@ -804,6 +886,16 @@ func (s *Server) runHTTPServer() error { _, _ = fmt.Fprintf(w, "rsync-proxy,host=%s count=%d %d\n", hostname, count, timestamp) }) + mux.HandleFunc("/metrics", func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + w.Header().Set("Content-Type", "text/plain; version=0.0.4; charset=utf-8") + s.writePrometheusMetrics(w, time.Now()) + }) + return http.Serve(s.HTTPListener, &mux) } diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index 291c453..e740dec 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "net/http" "os" "path/filepath" "strings" @@ -340,6 +341,100 @@ func TestStatusIncludesSelectedUpstream(t *testing.T) { wg.Done() } +func TestMetricsEndpointNoConnections(t *testing.T) { + srv := startServer(t) + defer srv.Close() + + resp, err := http.Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + text := string(body) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Equal(t, "text/plain; version=0.0.4; charset=utf-8", resp.Header.Get("Content-Type")) + assert.Contains(t, text, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.") + assert.Contains(t, text, "# TYPE rsync_proxy_active_connections gauge") + assert.Contains(t, text, "rsync_proxy_active_connections 0\n") +} + +func TestMetricsEndpointRejectsNonGET(t *testing.T) { + srv := startServer(t) + defer srv.Close() + + resp, err := http.Post("http://"+srv.HTTPListener.Addr().String()+"/metrics", "text/plain", nil) + require.NoError(t, err) + defer resp.Body.Close() + + assert.Equal(t, http.StatusMethodNotAllowed, resp.StatusCode) +} + +func TestMetricsIncludesActiveConnections(t *testing.T) { + srv := startServer(t) + defer srv.Close() + + var wg sync.WaitGroup + wg.Add(1) + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + require.NoError(t, err) + wg.Wait() + }) + fakeRsync.Start() + defer fakeRsync.Close() + + upstreamAddr := fakeRsync.Listener.Addr().String() + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: upstreamAddr}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + require.NoError(t, err) + conn := rsync.NewConn(rawConn) + defer conn.Close() + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + require.NoError(t, err) + + require.Eventually(t, func() bool { + infos := srv.ListConnectionInfo() + return len(infos) == 1 && infos[0].UpstreamAddr == upstreamAddr + }, time.Second, 10*time.Millisecond) + + resp, err := http.Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + text := string(body) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, text, "rsync_proxy_active_connections 1\n") + assert.Contains(t, text, fmt.Sprintf("rsync_proxy_active_connections_by_module{module=\"fake\",upstream=%q} 1\n", upstreamAddr)) + assert.Contains(t, text, "rsync_proxy_connection_sent_bytes{index=\"") + assert.Contains(t, text, "module=\"fake\"") + assert.Contains(t, text, fmt.Sprintf("upstream=%q", upstreamAddr)) + assert.Contains(t, text, "rsync_proxy_connection_received_bytes{index=\"") + assert.Contains(t, text, "rsync_proxy_connection_connected_timestamp_seconds{index=\"") + assert.Contains(t, text, "rsync_proxy_connection_duration_seconds{index=\"") + assert.NotContains(t, text, rawConn.LocalAddr().String()) + + wg.Done() +} + +func TestPrometheusLabelValueEscaping(t *testing.T) { + assert.Equal(t, `plain`, prometheusEscapeLabelValue("plain")) + assert.Equal(t, `quote\"value`, prometheusEscapeLabelValue(`quote"value`)) + assert.Equal(t, `slash\\value`, prometheusEscapeLabelValue(`slash\value`)) + assert.Equal(t, `line\nbreak`, prometheusEscapeLabelValue("line\nbreak")) + assert.Equal(t, `unknown`, prometheusLabelValueOrUnknown("")) +} + func TestPerUpstreamQueueIsolation(t *testing.T) { srv := startServer(t) defer srv.Close() From e617a493de18ab2a3307202d894ec2f7a0600504 Mon Sep 17 00:00:00 2001 From: yaoge123 Date: Mon, 4 May 2026 17:57:08 +0800 Subject: [PATCH 2/3] server: address metrics review feedback --- pkg/server/metrics.go | 106 +++++++++++++++++++++++++++++++++ pkg/server/server.go | 120 +++++++++----------------------------- pkg/server/server_test.go | 58 ++++++++++++++++-- 3 files changed, 187 insertions(+), 97 deletions(-) create mode 100644 pkg/server/metrics.go diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go new file mode 100644 index 0000000..01bf7b3 --- /dev/null +++ b/pkg/server/metrics.go @@ -0,0 +1,106 @@ +package server + +import ( + "fmt" + "io" + "sort" + "strings" + "time" +) + +func prometheusEscapeLabelValue(s string) string { + s = strings.ReplaceAll(s, `\`, `\\`) + s = strings.ReplaceAll(s, "\n", `\n`) + s = strings.ReplaceAll(s, `"`, `\"`) + return s +} + +func prometheusLabelValueOrUnknown(s string) string { + if s == "" { + return "unknown" + } + return s +} + +func prometheusLabels(index uint32, module, upstream string) string { + return fmt.Sprintf( + `index="%d",module="%s",upstream="%s"`, + index, + prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(module)), + prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(upstream)), + ) +} + +type prometheusConnectionGroup struct { + module string + upstream string +} + +func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { + connections := s.ListConnectionInfo() + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections gauge") + _, _ = fmt.Fprintf(w, "rsync_proxy_active_connections %d\n", s.GetActiveConnectionCount()) + + connectionCounts := make(map[prometheusConnectionGroup]int) + for _, conn := range connections { + _, module, upstream, _, _, _ := conn.snapshot() + key := prometheusConnectionGroup{ + module: prometheusLabelValueOrUnknown(module), + upstream: prometheusLabelValueOrUnknown(upstream), + } + connectionCounts[key]++ + } + + keys := make([]prometheusConnectionGroup, 0, len(connectionCounts)) + for key := range connectionCounts { + keys = append(keys, key) + } + sort.Slice(keys, func(i, j int) bool { + if keys[i].module != keys[j].module { + return keys[i].module < keys[j].module + } + return keys[i].upstream < keys[j].upstream + }) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections_by_module Current active rsync proxy connections by module and upstream.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections_by_module gauge") + for _, key := range keys { + module := prometheusEscapeLabelValue(key.module) + upstream := prometheusEscapeLabelValue(key.upstream) + _, _ = fmt.Fprintf(w, "rsync_proxy_active_connections_by_module{module=\"%s\",upstream=\"%s\"} %d\n", module, upstream, connectionCounts[key]) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_sent_bytes Bytes sent to clients for active connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_sent_bytes gauge") + for _, conn := range connections { + index, module, upstream, _, sentBytes, _ := conn.snapshot() + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_sent_bytes{%s} %d\n", prometheusLabels(index, module, upstream), sentBytes) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_received_bytes Bytes received from clients for active connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_received_bytes gauge") + for _, conn := range connections { + index, module, upstream, _, _, receivedBytes := conn.snapshot() + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_received_bytes{%s} %d\n", prometheusLabels(index, module, upstream), receivedBytes) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_connected_timestamp_seconds Unix timestamp when active connections were established.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_connected_timestamp_seconds gauge") + for _, conn := range connections { + index, module, upstream, connectedAt, _, _ := conn.snapshot() + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_connected_timestamp_seconds{%s} %d\n", prometheusLabels(index, module, upstream), connectedAt.Unix()) + } + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_duration_seconds Current duration of active connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_duration_seconds gauge") + for _, conn := range connections { + index, module, upstream, connectedAt, _, _ := conn.snapshot() + duration := now.Sub(connectedAt).Seconds() + if duration < 0 { + duration = 0 + } + _, _ = fmt.Fprintf(w, "rsync_proxy_connection_duration_seconds{%s} %.3f\n", prometheusLabels(index, module, upstream), duration) + } +} diff --git a/pkg/server/server.go b/pkg/server/server.go index 18f827e..4bb11e1 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -50,6 +50,7 @@ var ( const lineFeed = '\n' type ConnInfo struct { + mu sync.RWMutex Index uint32 LocalAddr string RemoteAddr string @@ -60,7 +61,26 @@ type ConnInfo struct { ReceivedBytes atomic.Int64 } +func (c *ConnInfo) SetModule(module string) { + c.mu.Lock() + defer c.mu.Unlock() + c.Module = module +} + +func (c *ConnInfo) SetUpstreamAddr(upstreamAddr string) { + c.mu.Lock() + defer c.mu.Unlock() + c.UpstreamAddr = upstreamAddr +} + +func (c *ConnInfo) snapshot() (index uint32, module, upstreamAddr string, connectedAt time.Time, sentBytes, receivedBytes int64) { + c.mu.RLock() + defer c.mu.RUnlock() + return c.Index, c.Module, c.UpstreamAddr, c.ConnectedAt, c.SentBytes.Load(), c.ReceivedBytes.Load() +} + func (c *ConnInfo) MarshalJSON() ([]byte, error) { + index, module, upstreamAddr, connectedAt, sentBytes, receivedBytes := c.snapshot() // Handle atomic value (cannot marshal directly) return json.Marshal(struct { Index uint32 `json:"index"` @@ -72,14 +92,14 @@ func (c *ConnInfo) MarshalJSON() ([]byte, error) { SentBytes int64 `json:"sentBytes"` ReceivedBytes int64 `json:"receivedBytes"` }{ - Index: c.Index, + Index: index, LocalAddr: c.LocalAddr, RemoteAddr: c.RemoteAddr, - ConnectedAt: c.ConnectedAt, - Module: c.Module, - UpstreamAddr: c.UpstreamAddr, - SentBytes: c.SentBytes.Load(), - ReceivedBytes: c.ReceivedBytes.Load(), + ConnectedAt: connectedAt, + Module: module, + UpstreamAddr: upstreamAddr, + SentBytes: sentBytes, + ReceivedBytes: receivedBytes, }) } @@ -537,8 +557,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err } moduleName := string(buf[:n-1]) // trim trailing \n - info.Module = moduleName - s.connInfo.Store(index, &info) + info.SetModule(moduleName) targets, ok := s.getTargetsForModule(moduleName) if !ok { @@ -551,8 +570,7 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err target := targets[chooseTargetByClientIP(net.ParseIP(ip), len(targets))] upstreamAddr := target.Addr useProxyProtocol := target.UseProxyProtocol - info.UpstreamAddr = upstreamAddr - s.connInfo.Store(index, &info) + info.SetUpstreamAddr(upstreamAddr) upstreamQueue, ok := s.getQueueForUpstream(target.Upstream) if !ok { @@ -701,88 +719,6 @@ func (s *Server) ListConnectionInfo() (result []*ConnInfo) { return } -func prometheusEscapeLabelValue(s string) string { - s = strings.ReplaceAll(s, `\`, `\\`) - s = strings.ReplaceAll(s, "\n", `\n`) - s = strings.ReplaceAll(s, `"`, `\"`) - return s -} - -func prometheusLabelValueOrUnknown(s string) string { - if s == "" { - return "unknown" - } - return s -} - -func prometheusLabels(index uint32, module, upstream string) string { - return fmt.Sprintf( - `index="%d",module="%s",upstream="%s"`, - index, - prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(module)), - prometheusEscapeLabelValue(prometheusLabelValueOrUnknown(upstream)), - ) -} - -func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { - connections := s.ListConnectionInfo() - - _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.") - _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections gauge") - _, _ = fmt.Fprintf(w, "rsync_proxy_active_connections %d\n", s.GetActiveConnectionCount()) - - connectionCounts := make(map[string]int) - for _, conn := range connections { - module := prometheusLabelValueOrUnknown(conn.Module) - upstream := prometheusLabelValueOrUnknown(conn.UpstreamAddr) - key := module + "\xff" + upstream - connectionCounts[key]++ - } - - keys := make([]string, 0, len(connectionCounts)) - for key := range connectionCounts { - keys = append(keys, key) - } - sort.Strings(keys) - - _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections_by_module Current active rsync proxy connections by module and upstream.") - _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections_by_module gauge") - for _, key := range keys { - parts := strings.SplitN(key, "\xff", 2) - module := prometheusEscapeLabelValue(parts[0]) - upstream := prometheusEscapeLabelValue(parts[1]) - _, _ = fmt.Fprintf(w, "rsync_proxy_active_connections_by_module{module=\"%s\",upstream=\"%s\"} %d\n", module, upstream, connectionCounts[key]) - } - - _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_sent_bytes Bytes sent to clients for active connections.") - _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_sent_bytes gauge") - for _, conn := range connections { - _, _ = fmt.Fprintf(w, "rsync_proxy_connection_sent_bytes{%s} %d\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), conn.SentBytes.Load()) - } - - _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_received_bytes Bytes received from clients for active connections.") - _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_received_bytes gauge") - for _, conn := range connections { - _, _ = fmt.Fprintf(w, "rsync_proxy_connection_received_bytes{%s} %d\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), conn.ReceivedBytes.Load()) - } - - _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_connected_timestamp_seconds Unix timestamp when active connections were established.") - _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_connected_timestamp_seconds gauge") - for _, conn := range connections { - _, _ = fmt.Fprintf(w, "rsync_proxy_connection_connected_timestamp_seconds{%s} %d\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), conn.ConnectedAt.Unix()) - } - - _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_connection_duration_seconds Current duration of active connections.") - _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_connection_duration_seconds gauge") - for _, conn := range connections { - duration := now.Sub(conn.ConnectedAt).Seconds() - if duration < 0 { - duration = 0 - } - _, _ = fmt.Fprintf(w, "rsync_proxy_connection_duration_seconds{%s} %.0f\n", prometheusLabels(conn.Index, conn.Module, conn.UpstreamAddr), duration) - } -} - func (s *Server) runHTTPServer() error { hostname, err := os.Hostname() if err != nil { diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index e740dec..def69b4 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -1,6 +1,7 @@ package server import ( + "bytes" "crypto/tls" "crypto/x509" "fmt" @@ -58,6 +59,10 @@ func startServer(t *testing.T) *Server { return srv } +func testHTTPClient() *http.Client { + return &http.Client{Timeout: time.Second} +} + func doClientHandshake(conn *rsync.Conn, version []byte, module string) (svrVersion string, err error) { _, err = conn.Write(version) if err != nil { @@ -335,7 +340,11 @@ func TestStatusIncludesSelectedUpstream(t *testing.T) { require.Eventually(t, func() bool { infos := srv.ListConnectionInfo() - return len(infos) == 1 && infos[0].UpstreamAddr == upstreamAddr + if len(infos) != 1 { + return false + } + _, _, infoUpstreamAddr, _, _, _ := infos[0].snapshot() + return infoUpstreamAddr == upstreamAddr }, time.Second, 10*time.Millisecond) wg.Done() @@ -345,7 +354,7 @@ func TestMetricsEndpointNoConnections(t *testing.T) { srv := startServer(t) defer srv.Close() - resp, err := http.Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") require.NoError(t, err) defer resp.Body.Close() @@ -364,7 +373,7 @@ func TestMetricsEndpointRejectsNonGET(t *testing.T) { srv := startServer(t) defer srv.Close() - resp, err := http.Post("http://"+srv.HTTPListener.Addr().String()+"/metrics", "text/plain", nil) + resp, err := testHTTPClient().Post("http://"+srv.HTTPListener.Addr().String()+"/metrics", "text/plain", nil) require.NoError(t, err) defer resp.Body.Close() @@ -402,10 +411,14 @@ func TestMetricsIncludesActiveConnections(t *testing.T) { require.Eventually(t, func() bool { infos := srv.ListConnectionInfo() - return len(infos) == 1 && infos[0].UpstreamAddr == upstreamAddr + if len(infos) != 1 { + return false + } + _, _, infoUpstreamAddr, _, _, _ := infos[0].snapshot() + return infoUpstreamAddr == upstreamAddr }, time.Second, 10*time.Millisecond) - resp, err := http.Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") require.NoError(t, err) defer resp.Body.Close() @@ -427,6 +440,41 @@ func TestMetricsIncludesActiveConnections(t *testing.T) { wg.Done() } +func TestPrometheusConnectionGroupingUsesStructuredKey(t *testing.T) { + srv := New() + + first := &ConnInfo{Index: 1, ConnectedAt: time.Unix(100, 0)} + first.Module = "a\xffb" + first.UpstreamAddr = "c" + srv.connInfo.Store(first.Index, first) + + second := &ConnInfo{Index: 2, ConnectedAt: time.Unix(100, 0)} + second.Module = "a" + second.UpstreamAddr = "b\xffc" + srv.connInfo.Store(second.Index, second) + + var buf bytes.Buffer + srv.writePrometheusMetrics(&buf, time.Unix(101, 0)) + text := buf.String() + + assert.Contains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\xffb\",upstream=\"c\"} 1\n") + assert.Contains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\xffc\"} 1\n") + assert.NotContains(t, text, "rsync_proxy_active_connections_by_module{module=\"a\",upstream=\"b\xffc\"} 2\n") +} + +func TestPrometheusDurationIncludesFractionalSeconds(t *testing.T) { + srv := New() + conn := &ConnInfo{Index: 1, ConnectedAt: time.Unix(100, 0)} + conn.Module = "fake" + conn.UpstreamAddr = "127.0.0.1:873" + srv.connInfo.Store(conn.Index, conn) + + var buf bytes.Buffer + srv.writePrometheusMetrics(&buf, time.Unix(100, 250_000_000)) + + assert.Contains(t, buf.String(), "rsync_proxy_connection_duration_seconds{index=\"1\",module=\"fake\",upstream=\"127.0.0.1:873\"} 0.250\n") +} + func TestPrometheusLabelValueEscaping(t *testing.T) { assert.Equal(t, `plain`, prometheusEscapeLabelValue("plain")) assert.Equal(t, `quote\"value`, prometheusEscapeLabelValue(`quote"value`)) From 82cb742972558b55dcadac7712226dd254a4b1d1 Mon Sep 17 00:00:00 2001 From: yaoge123 Date: Mon, 4 May 2026 17:32:39 +0800 Subject: [PATCH 3/3] server: add lifetime Prometheus counters --- pkg/server/metrics.go | 17 ++++++++++ pkg/server/server.go | 18 +++++++++-- pkg/server/server_test.go | 66 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 3 deletions(-) diff --git a/pkg/server/metrics.go b/pkg/server/metrics.go index 01bf7b3..f5292e3 100644 --- a/pkg/server/metrics.go +++ b/pkg/server/metrics.go @@ -38,11 +38,28 @@ type prometheusConnectionGroup struct { func (s *Server) writePrometheusMetrics(w io.Writer, now time.Time) { connections := s.ListConnectionInfo() + acceptedConnections := s.acceptedConnTotal.Load() _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_active_connections Current active rsync proxy connections.") _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_active_connections gauge") _, _ = fmt.Fprintf(w, "rsync_proxy_active_connections %d\n", s.GetActiveConnectionCount()) + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_accepted_connections_total Total accepted rsync proxy connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_accepted_connections_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_accepted_connections_total %d\n", acceptedConnections) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_completed_connections_total Total completed rsync proxy connections that reached upstream relay.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_completed_connections_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_completed_connections_total %d\n", s.completedConns.Load()) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_sent_bytes_total Total bytes sent to clients for completed connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_sent_bytes_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_sent_bytes_total %d\n", s.sentBytesTotal.Load()) + + _, _ = fmt.Fprintln(w, "# HELP rsync_proxy_received_bytes_total Total bytes received from clients for completed connections.") + _, _ = fmt.Fprintln(w, "# TYPE rsync_proxy_received_bytes_total counter") + _, _ = fmt.Fprintf(w, "rsync_proxy_received_bytes_total %d\n", s.recvBytesTotal.Load()) + connectionCounts := make(map[prometheusConnectionGroup]int) for _, conn := range connections { _, module, upstream, _, _, _ := conn.snapshot() diff --git a/pkg/server/server.go b/pkg/server/server.go index 4bb11e1..4283644 100644 --- a/pkg/server/server.go +++ b/pkg/server/server.go @@ -143,9 +143,13 @@ type Server struct { upstreamQueues map[string]*queue.Queue - activeConnCount atomic.Int64 - connIndex atomic.Uint32 - connInfo sync.Map + activeConnCount atomic.Int64 + connIndex atomic.Uint32 + acceptedConnTotal atomic.Uint64 + connInfo sync.Map + completedConns atomic.Int64 + sentBytesTotal atomic.Int64 + recvBytesTotal atomic.Int64 TCPListener net.Listener TLSListener net.Listener @@ -694,9 +698,16 @@ func (s *Server) relay(ctx context.Context, index uint32, downConn net.Conn) err s.errorLog.F("close downstream read: %v", err) } } + _ = upConn.Close() + _ = downConn.Close() + <-sentClosed + <-receivedClosed sentBytes := info.SentBytes.Load() receivedBytes := info.ReceivedBytes.Load() + s.completedConns.Add(1) + s.sentBytesTotal.Add(sentBytes) + s.recvBytesTotal.Add(receivedBytes) duration := time.Since(info.ConnectedAt) s.accessLog.F("client %s finishes module %s (sent: %d, received: %d, duration: %s)", ip, moduleName, sentBytes, receivedBytes, duration) @@ -886,6 +897,7 @@ func (s *Server) Close() { func (s *Server) handleConn(ctx context.Context, conn net.Conn) { s.activeConnCount.Add(1) + s.acceptedConnTotal.Add(1) defer s.activeConnCount.Add(-1) connIndex := s.connIndex.Add(1) diff --git a/pkg/server/server_test.go b/pkg/server/server_test.go index def69b4..9780eb6 100644 --- a/pkg/server/server_test.go +++ b/pkg/server/server_test.go @@ -475,6 +475,72 @@ func TestPrometheusDurationIncludesFractionalSeconds(t *testing.T) { assert.Contains(t, buf.String(), "rsync_proxy_connection_duration_seconds{index=\"1\",module=\"fake\",upstream=\"127.0.0.1:873\"} 0.250\n") } +func TestPrometheusAcceptedConnectionsTotalUsesUint64(t *testing.T) { + srv := New() + srv.acceptedConnTotal.Store((uint64(1) << 32) + 1) + + var buf bytes.Buffer + srv.writePrometheusMetrics(&buf, time.Unix(100, 0)) + + assert.Contains(t, buf.String(), "rsync_proxy_accepted_connections_total 4294967297\n") +} + +func TestMetricsIncludesLifetimeCounters(t *testing.T) { + srv := startServer(t) + defer srv.Close() + + payload := []byte("payload from upstream\n") + fakeRsync := rsync.NewServer(func(conn *rsync.Conn) { + defer conn.Close() + _, _, err := doServerHandshake(conn, RsyncdServerVersion) + require.NoError(t, err) + _, err = conn.Write(payload) + require.NoError(t, err) + }) + fakeRsync.Start() + defer fakeRsync.Close() + + upstreamAddr := fakeRsync.Listener.Addr().String() + srv.modules = map[string][]Target{ + "fake": {{Upstream: "u1", Addr: upstreamAddr}}, + } + srv.upstreamQueues = map[string]*queue.Queue{"u1": queue.New(0, 0)} + + rawConn, err := net.Dial("tcp", srv.TCPListener.Addr().String()) + require.NoError(t, err) + conn := rsync.NewConn(rawConn) + + _, err = doClientHandshake(conn, RsyncdServerVersion, "fake") + require.NoError(t, err) + + data, err := io.ReadAll(conn) + require.NoError(t, err) + require.NoError(t, conn.Close()) + require.Equal(t, payload, data) + + require.Eventually(t, func() bool { + return srv.GetActiveConnectionCount() == 0 + }, time.Second, 10*time.Millisecond) + + resp, err := testHTTPClient().Get("http://" + srv.HTTPListener.Addr().String() + "/metrics") + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + text := string(body) + + assert.Equal(t, http.StatusOK, resp.StatusCode) + assert.Contains(t, text, "# TYPE rsync_proxy_accepted_connections_total counter\n") + assert.Contains(t, text, "rsync_proxy_accepted_connections_total 1\n") + assert.Contains(t, text, "# TYPE rsync_proxy_completed_connections_total counter\n") + assert.Contains(t, text, "rsync_proxy_completed_connections_total 1\n") + assert.Contains(t, text, "# TYPE rsync_proxy_sent_bytes_total counter\n") + assert.Contains(t, text, fmt.Sprintf("rsync_proxy_sent_bytes_total %d\n", len(payload))) + assert.Contains(t, text, "# TYPE rsync_proxy_received_bytes_total counter\n") + assert.Contains(t, text, "rsync_proxy_received_bytes_total 0\n") +} + func TestPrometheusLabelValueEscaping(t *testing.T) { assert.Equal(t, `plain`, prometheusEscapeLabelValue("plain")) assert.Equal(t, `quote\"value`, prometheusEscapeLabelValue(`quote"value`))