diff --git a/cmd/apiutil/client.go b/cmd/apiutil/client.go index 3ebf8f8..1541fb9 100644 --- a/cmd/apiutil/client.go +++ b/cmd/apiutil/client.go @@ -76,13 +76,23 @@ func RawGet(ctx context.Context, client *api.APIClient, path string, params url. return body, nil } +// NormalizeServerURL trims trailing slashes from raw and appends /api/v1 exactly +// once. Returns an empty string when raw is blank or all slashes. +func NormalizeServerURL(raw string) string { + s := strings.TrimRight(raw, "/") + if s == "" { + return "" + } + if !strings.HasSuffix(s, "/api/v1") { + s += "/api/v1" + } + return s +} + // NewAPIClientWithKeyAndTransport is the base constructor used by all other helpers. func NewAPIClientWithKeyAndTransport(apiKey string, transport http.RoundTripper) *api.APIClient { configuration := api.NewConfiguration() - serverURL := viper.GetString("seerr.server") - if !strings.HasSuffix(serverURL, "/api/v1") { - serverURL = strings.TrimSuffix(serverURL, "/") + "/api/v1" - } + serverURL := NormalizeServerURL(viper.GetString("seerr.server")) configuration.Servers = api.ServerConfigurations{{URL: serverURL, Description: "Configured Server"}} key := apiKey if key == "" { diff --git a/cmd/mcp/serve.go b/cmd/mcp/serve.go index a5c79d1..266e54a 100644 --- a/cmd/mcp/serve.go +++ b/cmd/mcp/serve.go @@ -7,6 +7,8 @@ import ( "net/http" "strings" + "seerr-cli/cmd/apiutil" + "github.com/mark3labs/mcp-go/server" "github.com/spf13/cobra" "github.com/spf13/viper" @@ -75,6 +77,10 @@ func runServe(_ *cobra.Command, args []string) error { return err } + if err := ValidateServeConfig(); err != nil { + return err + } + if transport == "http" && authToken == "" && routeToken == "" && !noAuth { return fmt.Errorf("HTTP transport requires --auth-token, --route-token, or --no-auth (insecure) to be set explicitly") } @@ -185,6 +191,16 @@ func runServe(_ *cobra.Command, args []string) error { } } +// ValidateServeConfig checks that the Seerr server URL is configured. It is +// exported so that tests can verify the fail-fast behaviour without starting +// the server. +func ValidateServeConfig() error { + if apiutil.NormalizeServerURL(viper.GetString("seerr.server")) == "" { + return fmt.Errorf("seerr.server is not configured; set it with --server or add seerr.server to ~/.seerr-cli.yaml") + } + return nil +} + // HealthCheckHandler responds to GET /health with a JSON status payload. // It is exported so that it can be tested directly from the tests package. func HealthCheckHandler(w http.ResponseWriter, r *http.Request) { diff --git a/tests/apiutil_client_test.go b/tests/apiutil_client_test.go new file mode 100644 index 0000000..ba7d68b --- /dev/null +++ b/tests/apiutil_client_test.go @@ -0,0 +1,58 @@ +package tests + +import ( + "testing" + + "seerr-cli/cmd/apiutil" + + "github.com/stretchr/testify/assert" +) + +func TestNormalizeServerURL(t *testing.T) { + tests := []struct { + name string + input string + want string + }{ + { + name: "bare host", + input: "https://host", + want: "https://host/api/v1", + }, + { + name: "trailing slash", + input: "https://host/", + want: "https://host/api/v1", + }, + { + name: "multiple trailing slashes", + input: "https://host///", + want: "https://host/api/v1", + }, + { + name: "already has api/v1", + input: "https://host/api/v1", + want: "https://host/api/v1", + }, + { + name: "api/v1 with trailing slash", + input: "https://host/api/v1/", + want: "https://host/api/v1", + }, + { + name: "empty string", + input: "", + want: "", + }, + { + name: "only slashes", + input: "///", + want: "", + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, apiutil.NormalizeServerURL(tc.input)) + }) + } +} diff --git a/tests/mcp_serve_validation_test.go b/tests/mcp_serve_validation_test.go new file mode 100644 index 0000000..fd41a47 --- /dev/null +++ b/tests/mcp_serve_validation_test.go @@ -0,0 +1,58 @@ +package tests + +import ( + "testing" + + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + cmdmcp "seerr-cli/cmd/mcp" +) + +func TestMCPServeFailsFastWithoutSeerrServer(t *testing.T) { + original := viper.GetString("seerr.server") + t.Cleanup(func() { viper.Set("seerr.server", original) }) + + tests := []struct { + name string + seerrServer string + wantErr bool + errContains string + }{ + { + name: "missing server returns error", + seerrServer: "", + wantErr: true, + errContains: "seerr.server", + }, + { + name: "only slashes returns error", + seerrServer: "///", + wantErr: true, + errContains: "seerr.server", + }, + { + name: "valid server passes validation", + seerrServer: "http://localhost:5055", + wantErr: false, + }, + { + name: "valid server with trailing slash passes validation", + seerrServer: "http://localhost:5055/", + wantErr: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + viper.Set("seerr.server", tc.seerrServer) + err := cmdmcp.ValidateServeConfig() + if tc.wantErr { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.errContains) + } else { + require.NoError(t, err) + } + }) + } +}