diff --git a/.github/workflows/CI.yaml b/.github/workflows/CI.yaml index 6cd622e..1e7e362 100644 --- a/.github/workflows/CI.yaml +++ b/.github/workflows/CI.yaml @@ -67,6 +67,9 @@ jobs: - name: Lint run: make lint + - name: Test + run: make test + - name: Build Docker images run: make docker diff --git a/CLAUDE.md b/CLAUDE.md index 5de6db1..dfaf53c 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -32,7 +32,7 @@ make build-image # Build Docker image skywalking-mcp:latest make clean # Remove build artifacts ``` -No unit tests exist yet. CI runs license checks, lint, and docker build. +Unit tests exist for selected transport/context behavior. CI runs license checks, lint, and docker build. ## Architecture @@ -41,12 +41,11 @@ No unit tests exist yet. CI runs license checks, lint, and docker build. Three MCP transport modes as cobra subcommands: `stdio`, `sse`, `streamable`. The SkyWalking OAP URL is resolved in priority order: -- **stdio**: `set_skywalking_url` session tool > `--sw-url` flag > `http://localhost:12800/graphql` -- **SSE/HTTP**: `SW-URL` HTTP header > `--sw-url` flag > `http://localhost:12800/graphql` +- **All transports**: `--sw-url` flag > `http://localhost:12800/graphql` -The `set_skywalking_url` tool is only available in stdio mode (single client, well-defined session). SSE and HTTP transports use per-request headers instead. +SSE and HTTP transports always use the configured server URL. -Basic auth is configured via `--sw-username` / `--sw-password` flags. Both flags (and the `set_skywalking_url` tool) support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`). +Basic auth is configured via `--sw-username` / `--sw-password` flags. The startup flags support `${ENV_VAR}` syntax to resolve credentials from environment variables (e.g. `--sw-password ${MY_SECRET}`). Each transport injects the OAP URL and auth into the request context via `WithSkyWalkingURLAndInsecure()` and `WithSkyWalkingAuth()`. Tools extract them downstream using `skywalking-cli`'s `contextkey.BaseURL{}`, `contextkey.Username{}`, and `contextkey.Password{}`. @@ -99,4 +98,4 @@ Tool handlers should return `(mcp.NewToolResultError(...), nil)` for expected qu ## CI & Merge Policy -Squash-merge only. PRs to `main` require 1 approval and passing `Required` status check (license + lint + docker build). Go 1.25. \ No newline at end of file +Squash-merge only. PRs to `main` require 1 approval and passing `Required` status check (license + lint + docker build). Go 1.25. diff --git a/Makefile b/Makefile index 982d91a..a28039a 100644 --- a/Makefile +++ b/Makefile @@ -31,6 +31,8 @@ PLATFORMS ?= linux/amd64 MULTI_PLATFORMS ?= linux/amd64,linux/arm64 OUTPUT ?= --load IMAGE_TAGS ?= -t $(IMAGE):$(VERSION) -t $(IMAGE):latest +GO_TEST_FLAGS ?= +GO_TEST_PKGS ?= ./... .PHONY: all all: build ; @@ -48,6 +50,14 @@ build: ## Build the binary. -X ${VERSION_PATH}.date=${BUILD_DATE}" \ -o bin/swmcp cmd/skywalking-mcp/main.go +.PHONY: test +test: ## Run unit tests. + go test $(GO_TEST_FLAGS) $(GO_TEST_PKGS) + +.PHONY: test-cover +test-cover: ## Run unit tests with coverage output in coverage.txt. + go test $(GO_TEST_FLAGS) -coverprofile=coverage.txt $(GO_TEST_PKGS) + $(GO_LINT): @$(GO_LINT) version > /dev/null 2>&1 || go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.64.0 $(LICENSE_EYE): @@ -139,7 +149,7 @@ PUSH_RELEASE_SCRIPTS := ./scripts/push-release.sh release-push-candidate: ${PUSH_RELEASE_SCRIPTS} -.PHONY: lint fix-lint +.PHONY: lint fix-lint test test-cover .PHONY: license-header fix-license-header dependency-license fix-dependency-license .PHONY: release-binary release-source release-sign release-assembly .PHONY: release-push-candidate docker-build-multi diff --git a/README.md b/README.md index a269b79..af91904 100644 --- a/README.md +++ b/README.md @@ -65,6 +65,11 @@ bin/swmcp stdio --sw-url http://localhost:12800 --sw-username admin --sw-passwor bin/swmcp sse --sse-address localhost:8000 --base-path /mcp --sw-url http://localhost:12800 ``` +Transport URL behavior: + +- `stdio`, `sse`, and `streamable` all use the configured `--sw-url` value (or the default `http://localhost:12800/graphql`). +- `sse` and `streamable` ignore request-level URL override headers. + ### Usage with Cursor, Copilot, Claude Code ```json @@ -128,7 +133,6 @@ SkyWalking MCP provides the following tools to query and analyze SkyWalking OAP | Category | Tool Name | Description | |--------------|--------------------------------|---------------------------------------------------------------------------------------------------| -| **Session** | `set_skywalking_url` | Set the SkyWalking OAP server URL and optional basic auth credentials for the current session (stdio mode only). Supports `${ENV_VAR}` syntax for credentials. | | **Trace** | `query_traces` | Query traces with multi-condition filtering (service, endpoint, state, tags, and time range via start/end/step). Supports `full`, `summary`, and `errors_only` views with performance insights. | | **Log** | `query_logs` | Query logs with filters for service, instance, endpoint, trace ID, tags, and time range. Supports cold storage and pagination. | | **MQE** | `execute_mqe_expression` | Execute MQE (Metrics Query Expression) to query and calculate metrics data. Supports calculations, aggregations, TopN, trend analysis, and multiple result types. | @@ -176,4 +180,4 @@ SkyWalking MCP provides the following prompts for guided analysis workflows: [Apache 2.0 License.](/LICENSE) -[mcp]: https://modelcontextprotocol.io/ \ No newline at end of file +[mcp]: https://modelcontextprotocol.io/ diff --git a/internal/swmcp/server.go b/internal/swmcp/server.go index 8b4a740..0cdba97 100644 --- a/internal/swmcp/server.go +++ b/internal/swmcp/server.go @@ -37,18 +37,13 @@ import ( ) // newMCPServer creates a new MCP server with all tools, resources, and prompts registered. -// When stdio is true, session management tools (set_skywalking_url) are also registered, -// since stdio has a single client and session semantics are well-defined. -func newMCPServer(stdio bool) *server.MCPServer { +func newMCPServer() *server.MCPServer { s := server.NewMCPServer( "skywalking-mcp", "0.1.0", server.WithResourceCapabilities(true, true), server.WithPromptCapabilities(true), server.WithLogging(), ) - if stdio { - AddSessionTools(s) - } tools.AddTraceTools(s) tools.AddLogTools(s) tools.AddMQETools(s) @@ -131,63 +126,31 @@ func withConfiguredAuth(ctx context.Context) context.Context { return ctx } -// urlFromHeaders extracts URL for a request. -// URL is sourced from Header > configured value > Default. -func urlFromHeaders(req *http.Request) string { - urlStr := req.Header.Get("SW-URL") - if urlStr == "" { - return configuredSkyWalkingURL() - } - - return tools.FinalizeURL(urlStr) -} - -// applySessionOverrides checks for a session in the context and applies any -// URL or auth overrides that were set via the set_skywalking_url tool. -func applySessionOverrides(ctx context.Context) context.Context { - session := SessionFromContext(ctx) - if session == nil { - return ctx - } - if url := session.URL(); url != "" { - ctx = context.WithValue(ctx, contextkey.BaseURL{}, url) - } - if username := session.Username(); username != "" { - ctx = WithSkyWalkingAuth(ctx, username, session.Password()) - } - return ctx -} - // EnhanceStdioContextFunc returns a StdioContextFunc that enriches the context -// with SkyWalking settings from the global configuration and a per-session store. +// with SkyWalking settings from the global configuration. func EnhanceStdioContextFunc() server.StdioContextFunc { - session := &Session{} return func(ctx context.Context) context.Context { - ctx = WithSession(ctx, session) ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) ctx = withConfiguredAuth(ctx) - ctx = applySessionOverrides(ctx) return ctx } } // EnhanceSSEContextFunc returns a SSEContextFunc that enriches the context -// with SkyWalking settings from SSE request headers and CLI-configured auth. +// with SkyWalking settings from the CLI configuration and configured auth. func EnhanceSSEContextFunc() server.SSEContextFunc { - return func(ctx context.Context, req *http.Request) context.Context { - urlStr := urlFromHeaders(req) - ctx = WithSkyWalkingURLAndInsecure(ctx, urlStr, false) + return func(ctx context.Context, _ *http.Request) context.Context { + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) ctx = withConfiguredAuth(ctx) return ctx } } // EnhanceHTTPContextFunc returns a HTTPContextFunc that enriches the context -// with SkyWalking settings from HTTP request headers and CLI-configured auth. +// with SkyWalking settings from the CLI configuration and configured auth. func EnhanceHTTPContextFunc() server.HTTPContextFunc { - return func(ctx context.Context, req *http.Request) context.Context { - urlStr := urlFromHeaders(req) - ctx = WithSkyWalkingURLAndInsecure(ctx, urlStr, false) + return func(ctx context.Context, _ *http.Request) context.Context { + ctx = WithSkyWalkingURLAndInsecure(ctx, configuredSkyWalkingURL(), false) ctx = withConfiguredAuth(ctx) return ctx } diff --git a/internal/swmcp/server_registry_test.go b/internal/swmcp/server_registry_test.go new file mode 100644 index 0000000..55d1384 --- /dev/null +++ b/internal/swmcp/server_registry_test.go @@ -0,0 +1,280 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package swmcp + +import ( + "reflect" + "sort" + "testing" + "unsafe" + + "github.com/mark3labs/mcp-go/mcp" + "github.com/mark3labs/mcp-go/server" +) + +// These registry tests verify that newMCPServer wires up the expected tools, +// prompts, and resources. mcp-go v0.45.0 does not expose a public inventory API +// for MCPServer, so the tests read server internals through a single helper +// layer below. If mcp-go changes its internal field layout, update only the +// helpers in this file rather than spreading reflect/unsafe access across tests. + +func TestNewMCPServerRegistersExpectedTools(t *testing.T) { + srv := newMCPServer() + + got := sortedToolNames(srv) + want := []string{ + "execute_mqe_expression", + "get_mqe_metric_type", + "list_endpoints", + "list_instances", + "list_layers", + "list_mqe_metrics", + "list_processes", + "list_services", + "query_alarms", + "query_endpoints_topology", + "query_events", + "query_instances_topology", + "query_logs", + "query_processes_topology", + "query_services_topology", + "query_traces", + } + + assertStringSlicesEqual(t, got, want) +} + +func TestNewMCPServerRegistersExpectedPrompts(t *testing.T) { + srv := newMCPServer() + + got := sortedPromptNames(srv) + want := []string{ + "analyze-logs", + "analyze-performance", + "build-mqe-query", + "compare-services", + "explore-metrics", + "explore-service-topology", + "generate_duration", + "investigate-traces", + "top-services", + "trace-deep-dive", + } + + assertStringSlicesEqual(t, got, want) +} + +func TestNewMCPServerRegistersExpectedResources(t *testing.T) { + srv := newMCPServer() + + resources := resourceMap(srv) + got := make([]string, 0, len(resources)) + for uri := range resources { + got = append(got, uri) + } + sort.Strings(got) + + want := []string{ + "mqe://docs/ai_prompt", + "mqe://docs/examples", + "mqe://docs/syntax", + "mqe://metrics/available", + } + + assertStringSlicesEqual(t, got, want) +} + +func TestPromptMetadataIncludesExpectedArguments(t *testing.T) { + srv := newMCPServer() + prompts := promptMap(srv) + + prompt, ok := prompts["generate_duration"] + if !ok { + t.Fatal("generate_duration prompt not registered") + } + if prompt.Description == "" { + t.Fatal("generate_duration prompt description is empty") + } + if len(prompt.Arguments) != 1 { + t.Fatalf("generate_duration prompt arguments = %d, want 1", len(prompt.Arguments)) + } + if prompt.Arguments[0].Name != "time_range" || !prompt.Arguments[0].Required { + t.Fatalf("unexpected generate_duration argument: %+v", prompt.Arguments[0]) + } + + tracePrompt, ok := prompts["trace-deep-dive"] + if !ok { + t.Fatal("trace-deep-dive prompt not registered") + } + if len(tracePrompt.Arguments) != 2 { + t.Fatalf("trace-deep-dive prompt arguments = %d, want 2", len(tracePrompt.Arguments)) + } + if tracePrompt.Arguments[0].Name != "trace_id" || !tracePrompt.Arguments[0].Required { + t.Fatalf("unexpected first trace-deep-dive argument: %+v", tracePrompt.Arguments[0]) + } +} + +func TestResourceMetadataIncludesExpectedMIMETypes(t *testing.T) { + srv := newMCPServer() + resources := resourceMap(srv) + + tests := []struct { + uri string + name string + mimeType string + }{ + {uri: "mqe://docs/syntax", name: "MQE Detailed Syntax Rules", mimeType: "text/markdown"}, + {uri: "mqe://docs/examples", name: "MQE Examples", mimeType: "application/json"}, + {uri: "mqe://metrics/available", name: "Available Metrics", mimeType: "application/json"}, + {uri: "mqe://docs/ai_prompt", name: "MQE AI Understanding Guide", mimeType: "text/markdown"}, + } + + for _, tc := range tests { + t.Run(tc.uri, func(t *testing.T) { + resource, ok := resources[tc.uri] + if !ok { + t.Fatalf("resource %q not registered", tc.uri) + } + if resource.Name != tc.name { + t.Fatalf("resource name = %q, want %q", resource.Name, tc.name) + } + if resource.MIMEType != tc.mimeType { + t.Fatalf("resource MIME type = %q, want %q", resource.MIMEType, tc.mimeType) + } + }) + } +} + +func TestToolMetadataIncludesExpectedDescriptionsAndSchemas(t *testing.T) { + srv := newMCPServer() + tools := toolMap(srv) + + tests := []struct { + name string + expectDesc bool + expectProperties []string + }{ + {name: "query_traces", expectDesc: true, expectProperties: []string{"service_id", "trace_id", "view"}}, + {name: "execute_mqe_expression", expectDesc: true, expectProperties: []string{"expression", "service_name", "debug"}}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + tool, ok := tools[tc.name] + if !ok { + t.Fatalf("tool %q not registered", tc.name) + } + if tc.expectDesc && tool.Description == "" { + t.Fatalf("tool %q description is empty", tc.name) + } + properties := tool.InputSchema.Properties + for _, property := range tc.expectProperties { + if _, ok := properties[property]; !ok { + t.Fatalf("tool %q missing input schema property %q", tc.name, property) + } + } + }) + } +} + +func toolMap(srv *server.MCPServer) map[string]mcp.Tool { + serverTools := mustReadServerField(testedServerValue(srv), "tools") + result := make(map[string]mcp.Tool, serverTools.Len()) + + iter := serverTools.MapRange() + for iter.Next() { + name := iter.Key().String() + toolValue := copyReflectValue(iter.Value()) + result[name] = toolValue.FieldByName("Tool").Interface().(mcp.Tool) + } + + return result +} + +func promptMap(srv *server.MCPServer) map[string]mcp.Prompt { + serverPrompts := mustReadServerField(testedServerValue(srv), "prompts") + result := make(map[string]mcp.Prompt, serverPrompts.Len()) + + iter := serverPrompts.MapRange() + for iter.Next() { + result[iter.Key().String()] = copyReflectValue(iter.Value()).Interface().(mcp.Prompt) + } + + return result +} + +func resourceMap(srv *server.MCPServer) map[string]mcp.Resource { + serverResources := mustReadServerField(testedServerValue(srv), "resources") + result := make(map[string]mcp.Resource, serverResources.Len()) + + iter := serverResources.MapRange() + for iter.Next() { + resourceField := copyReflectValue(iter.Value()).FieldByName("resource") + result[iter.Key().String()] = readPrivateField(resourceField).Interface().(mcp.Resource) + } + + return result +} + +func sortedToolNames(srv *server.MCPServer) []string { + tools := toolMap(srv) + names := make([]string, 0, len(tools)) + for name := range tools { + names = append(names, name) + } + sort.Strings(names) + return names +} + +func sortedPromptNames(srv *server.MCPServer) []string { + prompts := promptMap(srv) + names := make([]string, 0, len(prompts)) + for name := range prompts { + names = append(names, name) + } + sort.Strings(names) + return names +} + +func assertStringSlicesEqual(t *testing.T, got, want []string) { + t.Helper() + if !reflect.DeepEqual(got, want) { + t.Fatalf("values mismatch:\n got: %v\nwant: %v", got, want) + } +} + +func readPrivateField(v reflect.Value) reflect.Value { + return reflect.NewAt(v.Type(), unsafe.Pointer(v.UnsafeAddr())).Elem() +} + +func testedServerValue(srv *server.MCPServer) reflect.Value { + return reflect.ValueOf(srv).Elem() +} + +func mustReadServerField(srv reflect.Value, fieldName string) reflect.Value { + field := srv.FieldByName(fieldName) + if !field.IsValid() { + panic("mcp-go MCPServer no longer has field " + fieldName) + } + return readPrivateField(field) +} + +func copyReflectValue(v reflect.Value) reflect.Value { + cloned := reflect.New(v.Type()).Elem() + cloned.Set(v) + return cloned +} diff --git a/internal/swmcp/server_test.go b/internal/swmcp/server_test.go new file mode 100644 index 0000000..28056c8 --- /dev/null +++ b/internal/swmcp/server_test.go @@ -0,0 +1,171 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +package swmcp + +import ( + "context" + "net/http" + "testing" + + "github.com/apache/skywalking-cli/pkg/contextkey" + "github.com/spf13/viper" + + "github.com/apache/skywalking-mcp/internal/config" +) + +const ( + configuredHTTPOAPURL = "http://configured-oap:12800/graphql" + configuredHTTPSOAPURL = "https://configured-oap.example.com/graphql" +) + +func TestConfiguredSkyWalkingURLUsesDefaultWhenUnset(t *testing.T) { + t.Cleanup(viper.Reset) + + got := configuredSkyWalkingURL() + if got != config.DefaultSWURL { + t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, config.DefaultSWURL) + } +} + +func TestConfiguredSkyWalkingURLFinalizesConfiguredValue(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "https://configured-oap.example.com:12800/") + + got := configuredSkyWalkingURL() + want := "https://configured-oap.example.com:12800/graphql" + if got != want { + t.Fatalf("configuredSkyWalkingURL() = %q, want %q", got, want) + } +} + +func TestResolveEnvVar(t *testing.T) { + t.Setenv("SW_TEST_SECRET", "resolved-secret") + + tests := []struct { + name string + value string + want string + }{ + {name: "raw", value: "raw-value", want: "raw-value"}, + {name: "env", value: "${SW_TEST_SECRET}", want: "resolved-secret"}, + {name: "trimmed env", value: " ${SW_TEST_SECRET} ", want: "resolved-secret"}, + {name: "missing env", value: "${SW_TEST_MISSING}", want: ""}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := resolveEnvVar(tc.value); got != tc.want { + t.Fatalf("resolveEnvVar(%q) = %q, want %q", tc.value, got, tc.want) + } + }) + } +} + +func TestWithConfiguredAuth(t *testing.T) { + t.Cleanup(viper.Reset) + t.Setenv("SW_TEST_USER", "env-user") + t.Setenv("SW_TEST_PASS", "env-pass") + viper.Set("username", "${SW_TEST_USER}") + viper.Set("password", "${SW_TEST_PASS}") + + ctx := withConfiguredAuth(context.Background()) + + gotUser, _ := ctx.Value(contextkey.Username{}).(string) + if gotUser != "env-user" { + t.Fatalf("username = %q, want %q", gotUser, "env-user") + } + + gotPass, _ := ctx.Value(contextkey.Password{}).(string) + if gotPass != "env-pass" { + t.Fatalf("password = %q, want %q", gotPass, "env-pass") + } +} + +func TestWithConfiguredAuthSkipsEmptyUsername(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("password", "password-only") + + ctx := withConfiguredAuth(context.Background()) + + if got, ok := ctx.Value(contextkey.Username{}).(string); ok || got != "" { + t.Fatalf("username unexpectedly set to %q", got) + } + if got, ok := ctx.Value(contextkey.Password{}).(string); ok || got != "" { + t.Fatalf("password unexpectedly set to %q", got) + } +} + +func TestEnhanceStdioContextFuncUsesConfiguredURLAndAuth(t *testing.T) { + t.Cleanup(viper.Reset) + t.Setenv("SW_STDIO_PASS", "stdio-pass") + viper.Set("url", "https://configured-oap.example.com") + viper.Set("username", "stdio-user") + viper.Set("password", "${SW_STDIO_PASS}") + + ctx := EnhanceStdioContextFunc()(context.Background()) + + gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string) + if gotURL != configuredHTTPSOAPURL { + t.Fatalf("base URL = %q", gotURL) + } + + gotUser, _ := ctx.Value(contextkey.Username{}).(string) + if gotUser != "stdio-user" { + t.Fatalf("username = %q", gotUser) + } + + gotPass, _ := ctx.Value(contextkey.Password{}).(string) + if gotPass != "stdio-pass" { + t.Fatalf("password = %q", gotPass) + } +} + +func TestEnhanceHTTPContextFuncDoesNotUseSWURLHeader(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "http://configured-oap:12800") + + req, err := http.NewRequest(http.MethodPost, "http://client/request", http.NoBody) + if err != nil { + t.Fatalf("create request: %v", err) + } + req.Header.Set("SW-URL", "http://attacker.invalid:8080") + + ctx := EnhanceHTTPContextFunc()(context.Background(), req) + + gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string) + if gotURL != configuredHTTPOAPURL { + t.Fatalf("base URL = %q", gotURL) + } +} + +func TestEnhanceSSEContextFuncDoesNotUseSWURLHeader(t *testing.T) { + t.Cleanup(viper.Reset) + viper.Set("url", "https://configured-oap.example.com") + + req, err := http.NewRequest(http.MethodGet, "http://client/events", http.NoBody) + if err != nil { + t.Fatalf("create request: %v", err) + } + req.Header.Set("SW-URL", "https://attacker.invalid") + + ctx := EnhanceSSEContextFunc()(context.Background(), req) + + gotURL, _ := ctx.Value(contextkey.BaseURL{}).(string) + if gotURL != configuredHTTPSOAPURL { + t.Fatalf("base URL = %q", gotURL) + } +} diff --git a/internal/swmcp/session.go b/internal/swmcp/session.go deleted file mode 100644 index 5a684c1..0000000 --- a/internal/swmcp/session.go +++ /dev/null @@ -1,138 +0,0 @@ -// Licensed to Apache Software Foundation (ASF) under one or more contributor -// license agreements. See the NOTICE file distributed with -// this work for additional information regarding copyright -// ownership. Apache Software Foundation (ASF) licenses this file to you under -// the Apache License, Version 2.0 (the "License"); you may -// not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package swmcp - -import ( - "context" - "fmt" - "sync" - - "github.com/mark3labs/mcp-go/mcp" - "github.com/mark3labs/mcp-go/server" - - "github.com/apache/skywalking-mcp/internal/tools" -) - -// sessionKey is the context key for looking up the session store. -type sessionKey struct{} - -// Session holds per-session SkyWalking connection configuration. -type Session struct { - mu sync.RWMutex - url string - username string - password string -} - -// SetConnection updates the session's connection parameters. -func (s *Session) SetConnection(url, username, password string) { - s.mu.Lock() - defer s.mu.Unlock() - s.url = url - s.username = username - s.password = password -} - -// URL returns the session's configured URL, or empty if not set. -func (s *Session) URL() string { - s.mu.RLock() - defer s.mu.RUnlock() - return s.url -} - -// Username returns the session's configured username. -func (s *Session) Username() string { - s.mu.RLock() - defer s.mu.RUnlock() - return s.username -} - -// Password returns the session's configured password. -func (s *Session) Password() string { - s.mu.RLock() - defer s.mu.RUnlock() - return s.password -} - -// SessionFromContext retrieves the session from the context, or nil if not present. -func SessionFromContext(ctx context.Context) *Session { - s, _ := ctx.Value(sessionKey{}).(*Session) - return s -} - -// WithSession attaches a session to the context. -func WithSession(ctx context.Context, s *Session) context.Context { - return context.WithValue(ctx, sessionKey{}, s) -} - -// SetSkyWalkingURLRequest represents the request for the set_skywalking_url tool. -type SetSkyWalkingURLRequest struct { - URL string `json:"url"` - Username string `json:"username,omitempty"` - Password string `json:"password,omitempty"` -} - -func setSkyWalkingURL(ctx context.Context, req *SetSkyWalkingURLRequest) (*mcp.CallToolResult, error) { - if req.URL == "" { - return mcp.NewToolResultError("url is required"), nil - } - - session := SessionFromContext(ctx) - if session == nil { - return mcp.NewToolResultError("session not available"), nil - } - - finalURL := tools.FinalizeURL(req.URL) - session.SetConnection(finalURL, resolveEnvVar(req.Username), resolveEnvVar(req.Password)) - - msg := fmt.Sprintf("SkyWalking URL set to %s", finalURL) - if req.Username != "" { - msg += " with basic auth credentials" - } - return mcp.NewToolResultText(msg), nil -} - -// AddSessionTools registers session management tools with the MCP server. -func AddSessionTools(s *server.MCPServer) { - tool := tools.NewTool( - "set_skywalking_url", - `Set the SkyWalking OAP server URL and optional basic auth credentials for this session. -This tool is only available in stdio transport mode. - -This tool configures the connection to SkyWalking OAP for all subsequent tool calls in the current session. -The URL and credentials persist for the lifetime of the session. - -Priority: session URL (set by this tool) > --sw-url flag > default (http://localhost:12800/graphql) -For SSE/HTTP transports, use the SW-URL HTTP header or --sw-url flag instead. - -Credentials support raw values or environment variable references using ${ENV_VAR} syntax. - -Examples: -- {"url": "http://demo.skywalking.apache.org:12800"}: Connect without auth -- {"url": "http://oap.internal:12800", "username": "admin", "password": "admin"}: Connect with basic auth -- {"url": "https://skywalking.example.com:443", "username": "${SW_USER}", "password": "${SW_PASS}"}: Auth via env vars`, - setSkyWalkingURL, - mcp.WithString("url", mcp.Required(), - mcp.Description("SkyWalking OAP server URL (required). Example: http://localhost:12800")), - mcp.WithString("username", - mcp.Description("Username for basic auth (optional). Supports ${ENV_VAR} syntax.")), - mcp.WithString("password", - mcp.Description("Password for basic auth (optional). Supports ${ENV_VAR} syntax.")), - ) - tool.Register(s) -} diff --git a/internal/swmcp/sse.go b/internal/swmcp/sse.go index 1e9a04e..14365a9 100644 --- a/internal/swmcp/sse.go +++ b/internal/swmcp/sse.go @@ -72,7 +72,7 @@ func runSSEServer(ctx context.Context, cfg *config.SSEServerConfig) error { } sseServer := server.NewSSEServer( - newMCPServer(false), + newMCPServer(), server.WithStaticBasePath(cfg.BasePath), server.WithSSEContextFunc(EnhanceSSEContextFunc()), ) diff --git a/internal/swmcp/stdio.go b/internal/swmcp/stdio.go index 9dd58df..02abb4a 100644 --- a/internal/swmcp/stdio.go +++ b/internal/swmcp/stdio.go @@ -60,7 +60,7 @@ func runStdioServer(ctx context.Context, cfg *config.StdioServerConfig) error { ctx, stop := signal.NotifyContext(ctx, os.Interrupt, syscall.SIGTERM) defer stop() - stdioServer := server.NewStdioServer(newMCPServer(true)) + stdioServer := server.NewStdioServer(newMCPServer()) logger, err := initLogger(cfg.LogFilePath) if err != nil { diff --git a/internal/swmcp/streamable.go b/internal/swmcp/streamable.go index 8f41194..0500352 100644 --- a/internal/swmcp/streamable.go +++ b/internal/swmcp/streamable.go @@ -57,7 +57,7 @@ func NewStreamable() *cobra.Command { // runStreamableServer starts the Streamable server with the provided configuration. func runStreamableServer(cfg *config.StreamableServerConfig) error { httpServer := server.NewStreamableHTTPServer( - newMCPServer(false), + newMCPServer(), server.WithStateLess(true), server.WithLogger(log.StandardLogger()), server.WithHTTPContextFunc(EnhanceHTTPContextFunc()), diff --git a/internal/tools/mqe.go b/internal/tools/mqe.go index 0f4189a..3a45396 100644 --- a/internal/tools/mqe.go +++ b/internal/tools/mqe.go @@ -25,8 +25,10 @@ import ( "fmt" "io" "net/http" - "strings" + "regexp" "time" + "unicode" + "unicode/utf8" "github.com/mark3labs/mcp-go/mcp" "github.com/mark3labs/mcp-go/server" @@ -41,6 +43,16 @@ func AddMQETools(mcp *server.MCPServer) { MQEMetricsTypeTool.Register(mcp) } +const ( + maxMQEExpressionLength = 2048 + maxMQEExpressionDepth = 12 + maxMQEEntityFieldLen = 256 + maxMQERegexLength = 256 + maxMetricNameLength = 128 +) + +var metricNamePattern = regexp.MustCompile(`^[A-Za-z0-9_.:-]+$`) + // GraphQLRequest represents a GraphQL request type GraphQLRequest struct { Query string `json:"query"` @@ -101,8 +113,8 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - bodyBytes, _ := io.ReadAll(resp.Body) - return nil, fmt.Errorf("HTTP request failed with status: %d, body: %s", resp.StatusCode, string(bodyBytes)) + _, _ = io.ReadAll(resp.Body) + return nil, fmt.Errorf("GraphQL request failed with HTTP status %d", resp.StatusCode) } var graphqlResp GraphQLResponse @@ -111,11 +123,7 @@ func executeGraphQLWithContext(ctx context.Context, query string, variables map[ } if len(graphqlResp.Errors) > 0 { - var errorMsgs []string - for _, err := range graphqlResp.Errors { - errorMsgs = append(errorMsgs, err.Message) - } - return nil, fmt.Errorf("GraphQL errors: %s", strings.Join(errorMsgs, ", ")) + return nil, fmt.Errorf("GraphQL query failed") } return &graphqlResp, nil @@ -307,6 +315,9 @@ func executeMQEExpression(ctx context.Context, req *MQEExpressionRequest) (*mcp. if req.Expression == "" { return mcp.NewToolResultError("expression is required"), nil } + if err := validateMQEExpressionRequest(req); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } entity := buildMQEEntity(ctx, req) timeCtx := GetTimeContext(ctx) @@ -386,6 +397,10 @@ func executeMQEExpression(ctx context.Context, req *MQEExpressionRequest) (*mcp. // listMQEMetrics lists available metrics func listMQEMetrics(ctx context.Context, req *MQEMetricsListRequest) (*mcp.CallToolResult, error) { + if err := validateMQEMetricsListRequest(req); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } + // GraphQL query for listing metrics query := ` query listMetrics($regex: String) { @@ -438,6 +453,9 @@ func getMQEMetricsType(ctx context.Context, req *MQEMetricsTypeRequest) (*mcp.Ca if req.MetricName == "" { return mcp.NewToolResultError("metric_name must be provided"), nil } + if err := validateMetricName(req.MetricName); err != nil { + return mcp.NewToolResultError(err.Error()), nil + } // GraphQL query for getting metric type query := ` @@ -462,6 +480,114 @@ func getMQEMetricsType(ctx context.Context, req *MQEMetricsTypeRequest) (*mcp.Ca return mcp.NewToolResultText(string(jsonBytes)), nil } +func validateMQEExpressionRequest(req *MQEExpressionRequest) error { + if err := validateMQEExpression(req.Expression); err != nil { + return err + } + + for fieldName, value := range map[string]string{ + "service_name": req.ServiceName, + "layer": req.Layer, + "service_instance_name": req.ServiceInstanceName, + "endpoint_name": req.EndpointName, + "process_name": req.ProcessName, + "dest_service_name": req.DestServiceName, + "dest_layer": req.DestLayer, + "dest_service_instance_name": req.DestServiceInstanceName, + "dest_endpoint_name": req.DestEndpointName, + "dest_process_name": req.DestProcessName, + } { + if err := validateMQETextField(fieldName, value, maxMQEEntityFieldLen); err != nil { + return err + } + } + + return nil +} + +func validateMQEMetricsListRequest(req *MQEMetricsListRequest) error { + if req == nil || req.Regex == "" { + return nil + } + if err := validateMQETextField("regex", req.Regex, maxMQERegexLength); err != nil { + return err + } + if _, err := regexp.Compile(req.Regex); err != nil { + return fmt.Errorf("regex is invalid") + } + return nil +} + +func validateMetricName(metricName string) error { + if err := validateMQETextField("metric_name", metricName, maxMetricNameLength); err != nil { + return err + } + if !metricNamePattern.MatchString(metricName) { + return fmt.Errorf("metric_name contains invalid characters") + } + return nil +} + +func validateMQEExpression(expression string) error { + if !utf8.ValidString(expression) { + return fmt.Errorf("expression must be valid UTF-8") + } + if len(expression) > maxMQEExpressionLength { + return fmt.Errorf("expression exceeds maximum length of %d characters", maxMQEExpressionLength) + } + if containsUnsafeControlChars(expression) { + return fmt.Errorf("expression contains invalid control characters") + } + if nestingDepth(expression) > maxMQEExpressionDepth { + return fmt.Errorf("expression exceeds maximum nesting depth of %d", maxMQEExpressionDepth) + } + return nil +} + +func validateMQETextField(fieldName, value string, maxLen int) error { + if value == "" { + return nil + } + if !utf8.ValidString(value) { + return fmt.Errorf("%s must be valid UTF-8", fieldName) + } + if len(value) > maxLen { + return fmt.Errorf("%s exceeds maximum length of %d characters", fieldName, maxLen) + } + if containsUnsafeControlChars(value) { + return fmt.Errorf("%s contains invalid control characters", fieldName) + } + return nil +} + +func containsUnsafeControlChars(value string) bool { + for _, r := range value { + if unicode.IsControl(r) && r != '\n' && r != '\r' && r != '\t' { + return true + } + } + return false +} + +func nestingDepth(value string) int { + depth := 0 + maxDepth := 0 + for _, r := range value { + switch r { + case '(', '{', '[': + depth++ + if depth > maxDepth { + maxDepth = depth + } + case ')', '}', ']': + if depth > 0 { + depth-- + } + } + } + return maxDepth +} + var MQEExpressionTool = NewTool( "execute_mqe_expression", `Execute MQE (Metrics Query Expression) to query and calculate metrics data. diff --git a/internal/tools/mqe_test.go b/internal/tools/mqe_test.go new file mode 100644 index 0000000..37d08af --- /dev/null +++ b/internal/tools/mqe_test.go @@ -0,0 +1,119 @@ +// Licensed to Apache Software Foundation (ASF) under one or more contributor +// license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright +// ownership. Apache Software Foundation (ASF) licenses this file to you under +// the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tools + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/apache/skywalking-cli/pkg/contextkey" + "github.com/mark3labs/mcp-go/mcp" +) + +func TestValidateMQEExpressionRequestRejectsDeeplyNestedExpression(t *testing.T) { + req := &MQEExpressionRequest{ + Expression: strings.Repeat("(", maxMQEExpressionDepth+1) + "service_cpm" + strings.Repeat(")", maxMQEExpressionDepth+1), + } + + err := validateMQEExpressionRequest(req) + if err == nil || !strings.Contains(err.Error(), "maximum nesting depth") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateMQEMetricsListRequestRejectsInvalidRegex(t *testing.T) { + err := validateMQEMetricsListRequest(&MQEMetricsListRequest{Regex: "("}) + if err == nil || err.Error() != "regex is invalid" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestValidateMetricNameRejectsInvalidCharacters(t *testing.T) { + err := validateMetricName("service cpm") + if err == nil || err.Error() != "metric_name contains invalid characters" { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestExecuteMQEExpressionRejectsOverlongEntityField(t *testing.T) { + req := &MQEExpressionRequest{ + Expression: "service_cpm", + ServiceName: strings.Repeat("a", maxMQEEntityFieldLen+1), + } + + result, err := executeMQEExpression(context.Background(), req) + if err != nil { + t.Fatalf("executeMQEExpression returned error: %v", err) + } + if !result.IsError { + t.Fatal("expected tool error result") + } + assertToolResultContains(t, result, "service_name exceeds maximum length") +} + +func TestExecuteGraphQLWithContextSanitizesHTTPErrorBody(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + http.Error(w, "sensitive backend details", http.StatusBadGateway) + })) + defer ts.Close() + + ctx := context.WithValue(context.Background(), contextkey.BaseURL{}, ts.URL) + _, err := executeGraphQLWithContext(ctx, "query { ping }", map[string]interface{}{}) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "HTTP status 502") { + t.Fatalf("unexpected error: %v", err) + } + if strings.Contains(err.Error(), "sensitive backend details") { + t.Fatalf("backend body leaked in error: %v", err) + } +} + +func TestExecuteGraphQLWithContextSanitizesGraphQLErrors(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`{"errors":[{"message":"database stack trace"}]}`)) + })) + defer ts.Close() + + ctx := context.WithValue(context.Background(), contextkey.BaseURL{}, ts.URL) + _, err := executeGraphQLWithContext(ctx, "query { ping }", map[string]interface{}{}) + if err == nil { + t.Fatal("expected error") + } + if err.Error() != "GraphQL query failed" { + t.Fatalf("unexpected error: %v", err) + } +} + +func assertToolResultContains(t *testing.T, result *mcp.CallToolResult, want string) { + t.Helper() + if len(result.Content) == 0 { + t.Fatal("tool result had no content") + } + text, ok := result.Content[0].(mcp.TextContent) + if !ok { + t.Fatalf("unexpected content type: %T", result.Content[0]) + } + if !strings.Contains(text.Text, want) { + t.Fatalf("tool result text %q does not contain %q", text.Text, want) + } +}