From a37ea698d8884a28fb8848700e9ce395f93b967b Mon Sep 17 00:00:00 2001 From: Omid Astaraki Date: Mon, 16 Mar 2026 12:16:29 +0000 Subject: [PATCH] feat(http): add default timeouts and graceful shutdown for HTTP client and MCP server Set a 30 s timeout on every outbound API HTTP client so requests cannot hang indefinitely. Add ReadHeader/Read/Write/IdleTimeout to the MCP HTTP server to prevent slow-client resource exhaustion. Listen for SIGINT and SIGTERM on the HTTP transport and perform a graceful shutdown with a 30 s deadline. Closes #66 --- cmd/apiutil/client.go | 9 +++++- cmd/mcp/serve.go | 60 +++++++++++++++++++++++++++++++---- tests/apiutil_client_test.go | 21 ++++++++++++ tests/mcp_http_server_test.go | 38 ++++++++++++++++++++++ 4 files changed, 121 insertions(+), 7 deletions(-) create mode 100644 tests/mcp_http_server_test.go diff --git a/cmd/apiutil/client.go b/cmd/apiutil/client.go index 1541fb9..a49a1cb 100644 --- a/cmd/apiutil/client.go +++ b/cmd/apiutil/client.go @@ -7,12 +7,16 @@ import ( "net/http" "net/url" "strings" + "time" api "seerr-cli/pkg/api" "github.com/spf13/viper" ) +// DefaultHTTPClientTimeout is the timeout applied to all outbound API requests. +const DefaultHTTPClientTimeout = 30 * time.Second + // OverrideServerURL is used by tests to redirect API calls to a mock server. var OverrideServerURL string @@ -101,9 +105,12 @@ func NewAPIClientWithKeyAndTransport(apiKey string, transport http.RoundTripper) if key != "" { configuration.AddDefaultHeader("X-Api-Key", key) } + // Always set an explicit timeout so outbound requests cannot hang indefinitely. + httpClient := &http.Client{Timeout: DefaultHTTPClientTimeout} if transport != nil { - configuration.HTTPClient = &http.Client{Transport: transport} + httpClient.Transport = transport } + configuration.HTTPClient = httpClient if OverrideServerURL != "" { configuration.Servers = api.ServerConfigurations{{URL: OverrideServerURL, Description: "Mock Server"}} } diff --git a/cmd/mcp/serve.go b/cmd/mcp/serve.go index c83c18c..3b2f170 100644 --- a/cmd/mcp/serve.go +++ b/cmd/mcp/serve.go @@ -5,7 +5,11 @@ import ( "crypto/subtle" "fmt" "net/http" + "os" + "os/signal" "strings" + "syscall" + "time" "seerr-cli/cmd/apiutil" @@ -14,6 +18,28 @@ import ( "github.com/spf13/viper" ) +// Default timeout values for the MCP HTTP server. +const ( + httpReadHeaderTimeout = 5 * time.Second + httpReadTimeout = 15 * time.Second + httpWriteTimeout = 30 * time.Second + httpIdleTimeout = 60 * time.Second + httpShutdownTimeout = 30 * time.Second +) + +// NewHTTPServer creates an http.Server bound to addr with safe default timeouts. +// It is exported so tests can assert that the server is properly configured. +func NewHTTPServer(addr string, handler http.Handler) *http.Server { + return &http.Server{ + Addr: addr, + Handler: handler, + ReadHeaderTimeout: httpReadHeaderTimeout, + ReadTimeout: httpReadTimeout, + WriteTimeout: httpWriteTimeout, + IdleTimeout: httpIdleTimeout, + } +} + var buildVersion = "dev" // SetVersionInfo injects the linker-set build version so the MCP server can @@ -178,14 +204,36 @@ func runServe(_ *cobra.Command, args []string) error { if cors { handler = corsMiddleware(handler) } - srv := &http.Server{ - Addr: addr, - Handler: handler, - } + srv := NewHTTPServer(addr, handler) + + // Catch SIGINT and SIGTERM so the server shuts down gracefully. + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt, syscall.SIGTERM) + + serveErrCh := make(chan error, 1) if tlsCert != "" && tlsKey != "" { - return srv.ListenAndServeTLS(tlsCert, tlsKey) + go func() { serveErrCh <- srv.ListenAndServeTLS(tlsCert, tlsKey) }() + } else { + go func() { serveErrCh <- srv.ListenAndServe() }() + } + + select { + case err := <-serveErrCh: + // Server exited on its own (e.g. port already in use). + if err != nil && err != http.ErrServerClosed { + return err + } + return nil + case <-sigCh: + mcpLog.Info("shutting down MCP HTTP server") + } + + shutdownCtx, cancel := context.WithTimeout(context.Background(), httpShutdownTimeout) + defer cancel() + if err := srv.Shutdown(shutdownCtx); err != nil { + return fmt.Errorf("graceful shutdown: %w", err) } - return srv.ListenAndServe() + return nil default: return fmt.Errorf("unknown transport %q: must be stdio or http", transport) } diff --git a/tests/apiutil_client_test.go b/tests/apiutil_client_test.go index ba7d68b..80c13ed 100644 --- a/tests/apiutil_client_test.go +++ b/tests/apiutil_client_test.go @@ -1,11 +1,14 @@ package tests import ( + "net/http" "testing" + "time" "seerr-cli/cmd/apiutil" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNormalizeServerURL(t *testing.T) { @@ -56,3 +59,21 @@ func TestNormalizeServerURL(t *testing.T) { }) } } + +func TestDefaultHTTPClientHasTimeout(t *testing.T) { + // Verify the default HTTP client carries a 30 s timeout so requests cannot hang indefinitely. + client := apiutil.NewAPIClientWithKeyAndTransport("", nil) + cfg := client.GetConfig() + require.NotNil(t, cfg.HTTPClient) + assert.Equal(t, 30*time.Second, cfg.HTTPClient.Timeout) +} + +func TestCustomTransportAlsoGetsTimeout(t *testing.T) { + // Verify a custom transport still gets wrapped in a client with the default timeout. + transport := &http.Transport{} + client := apiutil.NewAPIClientWithKeyAndTransport("", transport) + cfg := client.GetConfig() + require.NotNil(t, cfg.HTTPClient) + assert.Equal(t, 30*time.Second, cfg.HTTPClient.Timeout) + assert.Equal(t, transport, cfg.HTTPClient.Transport) +} diff --git a/tests/mcp_http_server_test.go b/tests/mcp_http_server_test.go new file mode 100644 index 0000000..51e2373 --- /dev/null +++ b/tests/mcp_http_server_test.go @@ -0,0 +1,38 @@ +package tests + +import ( + "net/http" + "testing" + "time" + + cmdmcp "seerr-cli/cmd/mcp" + + "github.com/stretchr/testify/assert" +) + +func TestHTTPServerHasNonZeroTimeouts(t *testing.T) { + // Verify the HTTP server is created with non-zero timeout values to prevent + // resource exhaustion from slow or stuck connections. + srv := cmdmcp.NewHTTPServer(":8811", http.NewServeMux()) + assert.NotZero(t, srv.ReadHeaderTimeout) + assert.NotZero(t, srv.ReadTimeout) + assert.NotZero(t, srv.WriteTimeout) + assert.NotZero(t, srv.IdleTimeout) +} + +func TestHTTPServerTimeoutValues(t *testing.T) { + // Verify the HTTP server timeout values match the documented safe defaults. + srv := cmdmcp.NewHTTPServer(":8811", http.NewServeMux()) + assert.Equal(t, 5*time.Second, srv.ReadHeaderTimeout) + assert.Equal(t, 15*time.Second, srv.ReadTimeout) + assert.Equal(t, 30*time.Second, srv.WriteTimeout) + assert.Equal(t, 60*time.Second, srv.IdleTimeout) +} + +func TestHTTPServerAddrAndHandler(t *testing.T) { + // Verify the Addr and Handler are set correctly. + mux := http.NewServeMux() + srv := cmdmcp.NewHTTPServer(":9999", mux) + assert.Equal(t, ":9999", srv.Addr) + assert.NotNil(t, srv.Handler) +}