From d24d23c17f0e37de37857a5abf0cc7b6b69179b7 Mon Sep 17 00:00:00 2001 From: "Hareesh.Veligeti" Date: Fri, 23 Jan 2026 12:30:00 +0530 Subject: [PATCH] ba-proxy-agent: close idle connections to mitigate memory leaks The ba-proxy-agent currently experiences increasing memory consumption, leading to daily or weekly restarts. Previous mitigation attempts were insufficient, and the issue has been isolated to high memory usage on the agent connection side. This CL mitigates the issue by enforcing a timeout on idle connections. Since the root cause remains elusive, forcefully closing unused connections prevents memory accumulation from persistent links. Implementation details: * Introduced `lastActivityTime` property to agent connections, which updates upon usage. * Added a background routine to monitor connection activity. * Configured the routine to explicitly close connections that remain idle for more than 30 seconds. --- agent/agent.go | 4 ++- agent/websockets/connection.go | 45 ++++++++++++++++++++--------- agent/websockets/shim.go | 39 ++++++++++++++++++++----- agent/websockets/websockets_test.go | 3 +- 4 files changed, 69 insertions(+), 22 deletions(-) diff --git a/agent/agent.go b/agent/agent.go index 08f62af..b29b268 100644 --- a/agent/agent.go +++ b/agent/agent.go @@ -73,6 +73,7 @@ var ( injectBanner = flag.String("inject-banner", "", "HTML snippet to inject in served webpages") bannerHeight = flag.String("banner-height", "40px", "Height of the injected banner. This is ignored if no banner is set.") shimWebsockets = flag.Bool("shim-websockets", false, "Whether or not to replace websockets with a shim") + websocketShimTimeout = flag.Duration("websocket-shim-timeout", 60*time.Minute, "Timeout for websocket shim connections to expire due to inactivity.") shimPath = flag.String("shim-path", "", "Path under which to handle websocket shim requests") healthCheckPath = flag.String("health-check-path", "/", "Path on backend host to issue health checks against. Defaults to the root.") healthCheckFreq = flag.Int("health-check-interval-seconds", 0, "Wait time in seconds between health checks. Set to zero to disable health checks. Checks disabled by default.") @@ -126,7 +127,8 @@ func hostProxy(ctx context.Context, host, shimPath string, injectShimCode, force // restricted to a path prefix not equal to "/" will fail for websocket open requests. Passing in the // sessionHandler twice allows the websocket handler to ensure that cookies are applied based on the // correct, restored path. - h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler, metricHandler) + h, err = websockets.Proxy(ctx, h, host, shimPath, *rewriteWebsocketHost, *enableWebsocketsInjection, sessionLRU.SessionHandler, + metricHandler, *websocketShimTimeout) if injectShimCode { shimFunc, err := websockets.ShimBody(shimPath) if err != nil { diff --git a/agent/websockets/connection.go b/agent/websockets/connection.go index c14d445..a674bae 100644 --- a/agent/websockets/connection.go +++ b/agent/websockets/connection.go @@ -17,16 +17,16 @@ limitations under the License. package websockets import ( + "context" "encoding/base64" "encoding/json" "errors" "fmt" "log" "net/http" + "sync" "time" - "context" - "github.com/gorilla/websocket" ) @@ -57,12 +57,14 @@ func (m *message) Serialize(version int) interface{} { // and encapsulates it in an API that is a little more amenable to how the server side // of our websocket shim is implemented. type Connection struct { - done func() <-chan struct{} - cancel context.CancelFunc - clientMessages chan *message - serverMessages chan *message - protocolVersion int - subprotocol string + done func() <-chan struct{} + cancel context.CancelFunc + clientMessages chan *message + serverMessages chan *message + protocolVersion int + subprotocol string + mu sync.Mutex + lastActivityTime time.Time } // This map defines the set of headers that should be stripped from the WS request, as they @@ -87,6 +89,20 @@ func stripWSHeader(header http.Header) http.Header { return result } +// updateActivity updates the last activity timestamp. +func (conn *Connection) updateActivity() { + conn.mu.Lock() + defer conn.mu.Unlock() + conn.lastActivityTime = time.Now() +} + +// lastActivity returns the last activity timestamp. +func (conn *Connection) lastActivity() time.Time { + conn.mu.Lock() + defer conn.mu.Unlock() + return conn.lastActivityTime +} + // NewConnection creates and returns a new Connection. func NewConnection(ctx context.Context, targetURL string, header http.Header, errCallback func(err error)) (*Connection, error) { ctx, cancel := context.WithCancel(ctx) @@ -162,11 +178,12 @@ func NewConnection(ctx context.Context, targetURL string, header http.Header, er } }() return &Connection{ - done: ctx.Done, - cancel: cancel, - clientMessages: clientMessages, - serverMessages: serverMessages, - subprotocol: serverConn.Subprotocol(), + done: ctx.Done, + cancel: cancel, + clientMessages: clientMessages, + serverMessages: serverMessages, + subprotocol: serverConn.Subprotocol(), + lastActivityTime: time.Now(), }, nil } @@ -184,6 +201,7 @@ func (conn *Connection) Close() { // // The returned error value is non-nill if the connection has been closed. func (conn *Connection) SendClientMessage(msg interface{}, injectionEnabled bool, injectedHeaders map[string]string) error { + conn.updateActivity() var clientMessage *message if textMsg, ok := msg.(string); ok { clientMessage = &message{ @@ -244,6 +262,7 @@ func (conn *Connection) ReadServerMessages() ([]interface{}, error) { // The server messages channel has been closed. return nil, fmt.Errorf("attempt to read a server message from a closed websocket connection") } + conn.updateActivity() msgs = append(msgs, serverMsg.Serialize(conn.protocolVersion)) for { select { diff --git a/agent/websockets/shim.go b/agent/websockets/shim.go index 135a7a2..229ed29 100644 --- a/agent/websockets/shim.go +++ b/agent/websockets/shim.go @@ -18,6 +18,7 @@ package websockets import ( "bytes" + "context" "encoding/json" "fmt" "io" @@ -31,8 +32,8 @@ import ( "sync" "sync/atomic" "text/template" + "time" - "context" "github.com/google/inverting-proxy/agent/metrics" ) @@ -320,9 +321,33 @@ func (c *connectionErrorHandler) ReportError(err error) { } } -func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler) http.Handler { +func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost bool, openWebsocketWrapper func(http.Handler, *metrics.MetricHandler) http.Handler, enableWebsocketInjection bool, metricHandler *metrics.MetricHandler, timeout time.Duration) http.Handler { var connections sync.Map var sessionCount uint64 + + // Background goroutine to clean up inactive websocket shim connections. + go func() { + ticker := time.NewTicker(min(timeout, 30*time.Second)) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + connections.Range(func(key, value any) bool { + sessionID := key.(string) + conn := value.(*Connection) + if time.Since(conn.lastActivity()) > timeout { + log.Printf("Closing inactive websocket shim session %q after timeout", sessionID) + conn.Close() + connections.Delete(sessionID) + } + return true // Continue iteration + }) + } + } + }() + mux := http.NewServeMux() errorHandler := &connectionErrorHandler{} openWebsocketHandler := openWebsocketWrapper(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -351,9 +376,9 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b } } resp := &sessionMessage{ - ID: sessionID, - Message: targetURL.String(), - Version: conn.protocolVersion, + ID: sessionID, + Message: targetURL.String(), + Version: conn.protocolVersion, Subprotocol: conn.Subprotocol(), } respBytes, err := json.Marshal(resp) @@ -548,11 +573,11 @@ func createShimChannel(ctx context.Context, host, shimPath string, rewriteHost b // openWebsocketWrapper is a http.Handler wrapper function that is invoked on websocket open requests after the original // targetURL of the request is restored. It must call the wrapped http.Handler with which it is created after it // is finished processing the request. -func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler) (http.Handler, error) { +func Proxy(ctx context.Context, wrapped http.Handler, host, shimPath string, rewriteHost, enableWebsocketInjection bool, openWebsocketWrapper func(wrapped http.Handler, metricHandler *metrics.MetricHandler) http.Handler, metricHandler *metrics.MetricHandler, timeout time.Duration) (http.Handler, error) { mux := http.NewServeMux() if shimPath != "" { shimPath = path.Clean("/"+shimPath) + "/" - shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler) + shimServer := createShimChannel(ctx, host, shimPath, rewriteHost, openWebsocketWrapper, enableWebsocketInjection, metricHandler, timeout) mux.Handle(shimPath, shimServer) } mux.Handle("/", wrapped) diff --git a/agent/websockets/websockets_test.go b/agent/websockets/websockets_test.go index 0d674a3..ca08e56 100644 --- a/agent/websockets/websockets_test.go +++ b/agent/websockets/websockets_test.go @@ -30,6 +30,7 @@ import ( "strings" "sync" "testing" + "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" @@ -239,7 +240,7 @@ func TestShimHandlers(t *testing.T) { openWrapper := func(h http.Handler, metricHandler *metrics.MetricHandler) http.Handler { return h } - p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil) + p, err := Proxy(context.Background(), h, serverURL.Host, testShimPath, false, false, openWrapper, nil, 60*time.Second) if err != nil { t.Fatalf("Failure creating the websocket shim proxy: %+v", err) }