From 370304adc84d6166d02dae35634f7dfcbcccaa6d Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Tue, 12 May 2026 16:12:11 +0000 Subject: [PATCH 1/5] feat: add support for registering and calling custom JSON-RPC methods with type safety --- examples/server/custom-method/main.go | 105 ++++++++++++ mcp/client.go | 11 +- mcp/custom.go | 109 ++++++++++++ mcp/custom_test.go | 235 ++++++++++++++++++++++++++ mcp/server.go | 16 +- mcp/streamable.go | 4 +- 6 files changed, 477 insertions(+), 3 deletions(-) create mode 100644 examples/server/custom-method/main.go create mode 100644 mcp/custom.go create mode 100644 mcp/custom_test.go diff --git a/examples/server/custom-method/main.go b/examples/server/custom-method/main.go new file mode 100644 index 00000000..d36bc6eb --- /dev/null +++ b/examples/server/custom-method/main.go @@ -0,0 +1,105 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +// The custom-method example demonstrates registering and calling a custom +// JSON-RPC method that is not part of the standard MCP spec. +// +// The server registers a "latin/translate" method that translates simple +// English phrases into Latin. A client connects over an in-memory transport, +// calls the custom method, and prints the result. +package main + +import ( + "context" + "fmt" + "log" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type TranslateParams struct { + mcp.ParamsBase + Text string `json:"text"` +} + +type TranslateResult struct { + mcp.ResultBase + Latin string `json:"latin"` +} + +var translations = map[string]string{ + "hello": "salve", + "goodbye": "vale", + "thank you": "gratias tibi ago", + "how are you": "quid agis", + "good morning": "bonum mane", + "good night": "bonam noctem", + "friend": "amicus", + "water": "aqua", + "love": "amor", + "war": "bellum", + "peace": "pax", + "truth": "veritas", + "light": "lux", + "time": "tempus", + "life": "vita", + "death": "mors", + "star": "stella", + "earth": "terra", + "sea": "mare", + "the die is cast": "alea iacta est", + "i came i saw i conquered": "veni vidi vici", + "seize the day": "carpe diem", +} + +func main() { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "latin-server", Version: "v1.0.0"}, nil) + + mcp.AddReceivingCustomMethod(server, "latin/translate", + func(ctx context.Context, ss *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) { + key := strings.ToLower(strings.TrimSpace(params.Text)) + latin, ok := translations[key] + if !ok { + latin = fmt.Sprintf("[unknown: %q — try: %s]", params.Text, knownPhrases()) + } + return &TranslateResult{Latin: latin}, nil + }) + + ct, st := mcp.NewInMemoryTransports() + + ss, err := server.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil) + translate := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate") + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + phrases := []string{"Hello", "Seize the day", "Peace", "Truth", "I came I saw I conquered"} + for _, phrase := range phrases { + result, err := translate(ctx, cs, &TranslateParams{Text: phrase}) + if err != nil { + log.Fatalf("translate %q: %v", phrase, err) + } + fmt.Printf("%-35s → %s\n", phrase, result.Latin) + } +} + +func knownPhrases() string { + phrases := make([]string, 0, len(translations)) + for k := range translations { + phrases = append(phrases, fmt.Sprintf("%q", k)) + } + return strings.Join(phrases, ", ") +} diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..47686481 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -10,6 +10,7 @@ import ( "fmt" "iter" "log/slog" + "maps" "slices" "strings" "sync" @@ -32,6 +33,7 @@ type Client struct { sessions []*ClientSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler + customSendMethods map[string]methodInfo } // NewClient creates a new [Client]. @@ -64,6 +66,7 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], + customSendMethods: make(map[string]methodInfo), } } @@ -945,7 +948,13 @@ var clientMethodInfos = map[string]methodInfo{ } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { - return serverMethodInfos + if len(cs.client.customSendMethods) == 0 { + return serverMethodInfos + } + infos := make(map[string]methodInfo, len(serverMethodInfos)+len(cs.client.customSendMethods)) + maps.Copy(infos, serverMethodInfos) + maps.Copy(infos, cs.client.customSendMethods) + return infos } func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { diff --git a/mcp/custom.go b/mcp/custom.go new file mode 100644 index 00000000..70b451ba --- /dev/null +++ b/mcp/custom.go @@ -0,0 +1,109 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "reflect" +) + +// ParamsBase can be embedded in custom parameter structs to satisfy the +// [Params] interface. It provides the required [Meta] field and the unexported +// isParams marker method. +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +type ParamsBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ParamsBase) isParams() {} + +// ResultBase can be embedded in custom result structs to satisfy the +// [Result] interface. It provides the required [Meta] field and the unexported +// isResult marker method. +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +type ResultBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ResultBase) isResult() {} + +// AddReceivingCustomMethod registers a handler for a custom (non-standard) +// JSON-RPC method on the server. +// +// When a client sends a request with the given method name, the params will be +// unmarshaled into P, the handler will be called, and the returned R will be +// marshaled as the JSON-RPC result. +// +// Custom methods go through the server's middleware chain just like standard +// MCP methods (tools/call, prompts/list, etc.). +// +// P and R must implement [Params] and [Result] respectively, which is most +// easily done by embedding [ParamsBase] and [ResultBase]: +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +// +// mcp.AddReceivingCustomMethod(server, "acme/search", +// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { +// return &SearchResult{Hits: []string{"result"}}, nil +// }) +func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( + s *Server, + method string, + handler func(ctx context.Context, ss *ServerSession, params P) (R, error), +) { + typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return handler(ctx, req.Session, req.Params) + }) + + s.mu.Lock() + defer s.mu.Unlock() + s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) +} + +// AddSendingCustomMethod registers a custom method that the client can send +// to the server and returns a typed caller function. +// +// The returned function calls the custom method through the client's sending +// middleware chain, with full type safety on both params and result. +// +// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") +// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) +func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( + c *Client, + method string, +) func(ctx context.Context, cs *ClientSession, params P) (R, error) { + mi := methodInfo{ + newResult: func() Result { + return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) + }, + } + + c.mu.Lock() + defer c.mu.Unlock() + c.customSendMethods[method] = mi + + return func(ctx context.Context, cs *ClientSession, params P) (R, error) { + return handleSend[R](ctx, method, &ClientRequest[P]{ + Session: cs, + Params: params, + }) + } +} diff --git a/mcp/custom_test.go b/mcp/custom_test.go new file mode 100644 index 00000000..5a6d76d8 --- /dev/null +++ b/mcp/custom_test.go @@ -0,0 +1,235 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +package mcp + +import ( + "context" + "net/http" + "testing" + + "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" +) + +type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` +} + +type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` +} + +// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, +// bypassing the SDK's typed method dispatch. +func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) +} + +func TestAddReceivingCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { + t.Fatal(err) + } + + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} + +func TestCustomMethodGoesThoughMiddleware(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + var middlewareCalled bool + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == "acme/ping" { + middlewareCalled = true + } + return next(ctx, method, req) + } + }) + + type pingParams struct { + ParamsBase + } + type pingResult struct { + ResultBase + Pong bool `json:"pong"` + } + AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { + return &pingResult{Pong: true}, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result pingResult + if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { + t.Fatal(err) + } + + if !result.Pong { + t.Error("expected Pong to be true") + } + if !middlewareCalled { + t.Error("middleware was not called for custom method") + } +} + +func TestCustomMethodNotOnOtherServers(t *testing.T) { + ctx := context.Background() + + type emptyParams struct{ ParamsBase } + type emptyResult struct{ ResultBase } + + // Server 1 has the custom method. + s1 := NewServer(testImpl, nil) + AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { + return &emptyResult{}, nil + }) + + // Server 2 does NOT have the custom method. + s2 := NewServer(testImpl, nil) + + // Test s1: custom method should work. + ct1, st1 := NewInMemoryTransports() + ss1, err := s1.Connect(ctx, st1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss1.Close() }) + + c1 := NewClient(testImpl, nil) + cs1, err := c1.Connect(ctx, ct1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs1.Close() }) + + if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { + t.Fatalf("custom method on s1 should work: %v", err) + } + + // Test s2: custom method should fail. + ct2, st2 := NewInMemoryTransports() + ss2, err := s2.Connect(ctx, st2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss2.Close() }) + + c2 := NewClient(testImpl, nil) + cs2, err := c2.Connect(ctx, ct2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs2.Close() }) + + err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) + if err == nil { + t.Fatal("expected error calling custom method on s2") + } +} + +func TestCallCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + return &searchResult{ + Hits: []string{"result for " + params.Query}, + Total: 1, + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + if err != nil { + t.Fatal(err) + } + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} + +func TestCustomMethodStreamableHTTP(t *testing.T) { + s := NewServer(testImpl, nil) + + type echoParams struct { + ParamsBase + Msg string `json:"msg"` + } + type echoResult struct { + ResultBase + Reply string `json:"reply"` + } + AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { + return &echoResult{Reply: "echo: " + params.Msg}, nil + }) + + handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) + _ = handler +} diff --git a/mcp/server.go b/mcp/server.go index d25c7922..98c0c985 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -52,6 +52,7 @@ type Server struct { sessions []*ServerSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler + customMethods map[string]methodInfo resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } @@ -195,6 +196,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), + customMethods: make(map[string]methodInfo), } } @@ -1422,7 +1424,19 @@ func initializeMethodInfo() methodInfo { func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } -func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } +func (s *Server) receivingMethodInfos() map[string]methodInfo { + if len(s.customMethods) == 0 { + return serverMethodInfos + } + infos := make(map[string]methodInfo, len(serverMethodInfos)+len(s.customMethods)) + maps.Copy(infos, serverMethodInfos) + maps.Copy(infos, s.customMethods) + return infos +} + +func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { + return ss.server.receivingMethodInfos() +} func (ss *ServerSession) sendingMethodHandler() MethodHandler { s := ss.server diff --git a/mcp/streamable.go b/mcp/streamable.go index b8e36553..a4a34a52 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -396,6 +396,7 @@ func connectStreamable(ctx context.Context, server *Server, transport *Streamabl if err != nil { return nil, err } + transport.connection.server = server transport.connection.toolLookup = server.getServerTool return s, nil } @@ -734,6 +735,7 @@ type streamableServerConn struct { logger *slog.Logger + server *Server toolLookup func(name string) (*serverTool, bool) incoming chan jsonrpc.Message // messages from the client to the server @@ -1241,7 +1243,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // Preemptively check that this is a valid request, so that we can fail // the HTTP request. If we didn't do this, a request with a bad method or // missing ID could be silently swallowed. - if _, err := checkRequest(jreq, serverMethodInfos); err != nil { + if _, err := checkRequest(jreq, c.server.receivingMethodInfos()); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return } From c4b8630abd1ce2b02cbc17ee679d1401d8356e49 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 12:34:12 +0000 Subject: [PATCH 2/5] feat: add support for registering and calling type-safe custom MCP methods --- mcp/client.go | 31 ++++++ mcp/client_test.go | 39 ++++++++ mcp/custom.go | 104 -------------------- mcp/custom_test.go | 230 --------------------------------------------- mcp/server.go | 41 ++++++++ mcp/server_test.go | 184 ++++++++++++++++++++++++++++++++++++ mcp/shared.go | 28 ++++++ 7 files changed, 323 insertions(+), 334 deletions(-) diff --git a/mcp/client.go b/mcp/client.go index 47686481..093249ef 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -11,6 +11,7 @@ import ( "iter" "log/slog" "maps" + "reflect" "slices" "strings" "sync" @@ -1227,3 +1228,33 @@ func paginate[P listParams, R listResult[T], T any](ctx context.Context, params } } } + +// AddSendingCustomMethod registers a custom method that the client can send +// to the server and returns a typed caller function. +// +// The returned function calls the custom method through the client's sending +// middleware chain, with full type safety on both params and result. +// +// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") +// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) +func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( + c *Client, + method string, +) func(ctx context.Context, cs *ClientSession, params P) (R, error) { + mi := methodInfo{ + newResult: func() Result { + return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) + }, + } + + c.mu.Lock() + defer c.mu.Unlock() + c.customSendMethods[method] = mi + + return func(ctx context.Context, cs *ClientSession, params P) (R, error) { + return handleSend[R](ctx, method, &ClientRequest[P]{ + Session: cs, + Params: params, + }) + } +} diff --git a/mcp/client_test.go b/mcp/client_test.go index 609fd501..6bb5ee8e 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -617,3 +617,42 @@ func TestClientCapabilitiesOverWire(t *testing.T) { }) } } + +func TestCallCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + return &searchResult{ + Hits: []string{"result for " + params.Query}, + Total: 1, + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + if err != nil { + t.Fatal(err) + } + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} diff --git a/mcp/custom.go b/mcp/custom.go index 70b451ba..e5aef305 100644 --- a/mcp/custom.go +++ b/mcp/custom.go @@ -3,107 +3,3 @@ // that can be found in the LICENSE file. package mcp - -import ( - "context" - "reflect" -) - -// ParamsBase can be embedded in custom parameter structs to satisfy the -// [Params] interface. It provides the required [Meta] field and the unexported -// isParams marker method. -// -// type SearchParams struct { -// mcp.ParamsBase -// Query string `json:"query"` -// } -type ParamsBase struct { - Meta `json:"_meta,omitempty"` -} - -func (*ParamsBase) isParams() {} - -// ResultBase can be embedded in custom result structs to satisfy the -// [Result] interface. It provides the required [Meta] field and the unexported -// isResult marker method. -// -// type SearchResult struct { -// mcp.ResultBase -// Hits []string `json:"hits"` -// } -type ResultBase struct { - Meta `json:"_meta,omitempty"` -} - -func (*ResultBase) isResult() {} - -// AddReceivingCustomMethod registers a handler for a custom (non-standard) -// JSON-RPC method on the server. -// -// When a client sends a request with the given method name, the params will be -// unmarshaled into P, the handler will be called, and the returned R will be -// marshaled as the JSON-RPC result. -// -// Custom methods go through the server's middleware chain just like standard -// MCP methods (tools/call, prompts/list, etc.). -// -// P and R must implement [Params] and [Result] respectively, which is most -// easily done by embedding [ParamsBase] and [ResultBase]: -// -// type SearchParams struct { -// mcp.ParamsBase -// Query string `json:"query"` -// } -// -// type SearchResult struct { -// mcp.ResultBase -// Hits []string `json:"hits"` -// } -// -// mcp.AddReceivingCustomMethod(server, "acme/search", -// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { -// return &SearchResult{Hits: []string{"result"}}, nil -// }) -func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( - s *Server, - method string, - handler func(ctx context.Context, ss *ServerSession, params P) (R, error), -) { - typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { - return handler(ctx, req.Session, req.Params) - }) - - s.mu.Lock() - defer s.mu.Unlock() - s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) -} - -// AddSendingCustomMethod registers a custom method that the client can send -// to the server and returns a typed caller function. -// -// The returned function calls the custom method through the client's sending -// middleware chain, with full type safety on both params and result. -// -// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") -// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) -func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( - c *Client, - method string, -) func(ctx context.Context, cs *ClientSession, params P) (R, error) { - mi := methodInfo{ - newResult: func() Result { - return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) - }, - } - - c.mu.Lock() - defer c.mu.Unlock() - c.customSendMethods[method] = mi - - return func(ctx context.Context, cs *ClientSession, params P) (R, error) { - return handleSend[R](ctx, method, &ClientRequest[P]{ - Session: cs, - Params: params, - }) - } -} diff --git a/mcp/custom_test.go b/mcp/custom_test.go index 5a6d76d8..e5aef305 100644 --- a/mcp/custom_test.go +++ b/mcp/custom_test.go @@ -3,233 +3,3 @@ // that can be found in the LICENSE file. package mcp - -import ( - "context" - "net/http" - "testing" - - "github.com/modelcontextprotocol/go-sdk/internal/jsonrpc2" -) - -type searchParams struct { - ParamsBase - Query string `json:"query"` - Limit int `json:"limit,omitempty"` -} - -type searchResult struct { - ResultBase - Hits []string `json:"hits"` - Total int `json:"total"` -} - -// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, -// bypassing the SDK's typed method dispatch. -func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { - return conn.Call(ctx, method, params).Await(ctx, result) -} - -func TestAddReceivingCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - hits := []string{"result for " + params.Query} - return &searchResult{ - Hits: hits, - Total: len(hits), - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result searchResult - if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { - t.Fatal(err) - } - - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} - -func TestCustomMethodGoesThoughMiddleware(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - var middlewareCalled bool - s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == "acme/ping" { - middlewareCalled = true - } - return next(ctx, method, req) - } - }) - - type pingParams struct { - ParamsBase - } - type pingResult struct { - ResultBase - Pong bool `json:"pong"` - } - AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { - return &pingResult{Pong: true}, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result pingResult - if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { - t.Fatal(err) - } - - if !result.Pong { - t.Error("expected Pong to be true") - } - if !middlewareCalled { - t.Error("middleware was not called for custom method") - } -} - -func TestCustomMethodNotOnOtherServers(t *testing.T) { - ctx := context.Background() - - type emptyParams struct{ ParamsBase } - type emptyResult struct{ ResultBase } - - // Server 1 has the custom method. - s1 := NewServer(testImpl, nil) - AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { - return &emptyResult{}, nil - }) - - // Server 2 does NOT have the custom method. - s2 := NewServer(testImpl, nil) - - // Test s1: custom method should work. - ct1, st1 := NewInMemoryTransports() - ss1, err := s1.Connect(ctx, st1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss1.Close() }) - - c1 := NewClient(testImpl, nil) - cs1, err := c1.Connect(ctx, ct1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs1.Close() }) - - if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { - t.Fatalf("custom method on s1 should work: %v", err) - } - - // Test s2: custom method should fail. - ct2, st2 := NewInMemoryTransports() - ss2, err := s2.Connect(ctx, st2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss2.Close() }) - - c2 := NewClient(testImpl, nil) - cs2, err := c2.Connect(ctx, ct2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs2.Close() }) - - err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) - if err == nil { - t.Fatal("expected error calling custom method on s2") - } -} - -func TestCallCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - return &searchResult{ - Hits: []string{"result for " + params.Query}, - Total: 1, - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") - - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) - if err != nil { - t.Fatal(err) - } - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} - -func TestCustomMethodStreamableHTTP(t *testing.T) { - s := NewServer(testImpl, nil) - - type echoParams struct { - ParamsBase - Msg string `json:"msg"` - } - type echoResult struct { - ResultBase - Reply string `json:"reply"` - } - AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { - return &echoResult{Reply: "echo: " + params.Msg}, nil - }) - - handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) - _ = handler -} diff --git a/mcp/server.go b/mcp/server.go index 98c0c985..549645a6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -1644,3 +1644,44 @@ func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageS *res.nextCursorPtr() = nextCursor return res, nil } + +// AddReceivingCustomMethod registers a handler for a custom (non-standard) +// JSON-RPC method on the server. +// +// When a client sends a request with the given method name, the params will be +// unmarshaled into P, the handler will be called, and the returned R will be +// marshaled as the JSON-RPC result. +// +// Custom methods go through the server's middleware chain just like standard +// MCP methods (tools/call, prompts/list, etc.). +// +// P and R must implement [Params] and [Result] respectively, which is most +// easily done by embedding [ParamsBase] and [ResultBase]: +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +// +// mcp.AddReceivingCustomMethod(server, "acme/search", +// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { +// return &SearchResult{Hits: []string{"result"}}, nil +// }) +func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( + s *Server, + method string, + handler func(ctx context.Context, ss *ServerSession, params P) (R, error), +) { + typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return handler(ctx, req.Session, req.Params) + }) + + s.mu.Lock() + defer s.mu.Unlock() + s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) +} diff --git a/mcp/server_test.go b/mcp/server_test.go index 2937ea2b..c3e7a4b6 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -11,6 +11,7 @@ import ( "fmt" "log" "log/slog" + "net/http" "slices" "strings" "testing" @@ -1007,3 +1008,186 @@ func TestServerCapabilitiesOverWire(t *testing.T) { }) } } + +type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` +} + +type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` +} + +// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, +// bypassing the SDK's typed method dispatch. +func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) +} + +func TestAddReceivingCustomMethod(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { + t.Fatal(err) + } + + if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { + t.Errorf("unexpected hits: %v", result.Hits) + } + if result.Total != 1 { + t.Errorf("unexpected total: %d", result.Total) + } +} + +func TestCustomMethodGoesThoughMiddleware(t *testing.T) { + ctx := context.Background() + s := NewServer(testImpl, nil) + + var middlewareCalled bool + s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { + return func(ctx context.Context, method string, req Request) (Result, error) { + if method == "acme/ping" { + middlewareCalled = true + } + return next(ctx, method, req) + } + }) + + type pingParams struct { + ParamsBase + } + type pingResult struct { + ResultBase + Pong bool `json:"pong"` + } + AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { + return &pingResult{Pong: true}, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + var result pingResult + if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { + t.Fatal(err) + } + + if !result.Pong { + t.Error("expected Pong to be true") + } + if !middlewareCalled { + t.Error("middleware was not called for custom method") + } +} + +func TestCustomMethodNotOnOtherServers(t *testing.T) { + ctx := context.Background() + + type emptyParams struct{ ParamsBase } + type emptyResult struct{ ResultBase } + + // Server 1 has the custom method. + s1 := NewServer(testImpl, nil) + AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { + return &emptyResult{}, nil + }) + + // Server 2 does NOT have the custom method. + s2 := NewServer(testImpl, nil) + + // Test s1: custom method should work. + ct1, st1 := NewInMemoryTransports() + ss1, err := s1.Connect(ctx, st1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss1.Close() }) + + c1 := NewClient(testImpl, nil) + cs1, err := c1.Connect(ctx, ct1, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs1.Close() }) + + if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { + t.Fatalf("custom method on s1 should work: %v", err) + } + + // Test s2: custom method should fail. + ct2, st2 := NewInMemoryTransports() + ss2, err := s2.Connect(ctx, st2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss2.Close() }) + + c2 := NewClient(testImpl, nil) + cs2, err := c2.Connect(ctx, ct2, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs2.Close() }) + + err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) + if err == nil { + t.Fatal("expected error calling custom method on s2") + } +} + +func TestCustomMethodStreamableHTTP(t *testing.T) { + s := NewServer(testImpl, nil) + + type echoParams struct { + ParamsBase + Msg string `json:"msg"` + } + type echoResult struct { + ResultBase + Reply string `json:"reply"` + } + AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { + return &echoResult{Reply: "echo: " + params.Msg}, nil + }) + + handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) + _ = handler +} diff --git a/mcp/shared.go b/mcp/shared.go index 078b401b..42472704 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -544,6 +544,20 @@ type Params interface { isParams() } +// ParamsBase can be embedded in custom parameter structs to satisfy the +// [Params] interface. It provides the required [Meta] field and the unexported +// isParams marker method. +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +type ParamsBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ParamsBase) isParams() {} + // RequestParams is a parameter (input) type for an MCP request. type RequestParams interface { Params @@ -568,6 +582,20 @@ type Result interface { SetMeta(map[string]any) } +// ResultBase can be embedded in custom result structs to satisfy the +// [Result] interface. It provides the required [Meta] field and the unexported +// isResult marker method. +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +type ResultBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ResultBase) isResult() {} + // emptyResult is returned by methods that have no result, like ping. // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} From 7edd01cff5949d27a18bbdb220dd2c8f7a5eef5e Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 12:38:55 +0000 Subject: [PATCH 3/5] refactor: remove unused custom MCP implementation files --- mcp/custom.go | 5 ----- mcp/custom_test.go | 5 ----- 2 files changed, 10 deletions(-) delete mode 100644 mcp/custom.go delete mode 100644 mcp/custom_test.go diff --git a/mcp/custom.go b/mcp/custom.go deleted file mode 100644 index e5aef305..00000000 --- a/mcp/custom.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by the license -// that can be found in the LICENSE file. - -package mcp diff --git a/mcp/custom_test.go b/mcp/custom_test.go deleted file mode 100644 index e5aef305..00000000 --- a/mcp/custom_test.go +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2025 The Go MCP SDK Authors. All rights reserved. -// Use of this source code is governed by the license -// that can be found in the LICENSE file. - -package mcp From 51f3a2b26438645601424b75491fe03f5a7b2303 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 13:19:59 +0000 Subject: [PATCH 4/5] refactor: consolidate custom method tests into mcp_test.go --- mcp/client_test.go | 37 --------- mcp/mcp_test.go | 69 +++++++++++++++++ mcp/server_test.go | 186 +-------------------------------------------- 3 files changed, 70 insertions(+), 222 deletions(-) diff --git a/mcp/client_test.go b/mcp/client_test.go index 6bb5ee8e..f36d7d00 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -618,41 +618,4 @@ func TestClientCapabilitiesOverWire(t *testing.T) { } } -func TestCallCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - return &searchResult{ - Hits: []string{"result for " + params.Query}, - Total: 1, - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - result, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) - if err != nil { - t.Fatal(err) - } - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 14173231..810b1e58 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2372,3 +2372,72 @@ func TestSetErrorPreservesContent(t *testing.T) { } var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} + +func TestCustomMethods(t *testing.T) { + type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` + } + + type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` + } + + callCustom := func(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) + } + + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + // Test raw JSON-RPC call + var result1 searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result1); err != nil { + t.Fatal(err) + } + if len(result1.Hits) != 1 || result1.Hits[0] != "result for hello" { + t.Errorf("raw call: unexpected hits: %v", result1.Hits) + } + if result1.Total != 1 { + t.Errorf("raw call: unexpected total: %d", result1.Total) + } + + // Test typed caller + result2, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + if err != nil { + t.Fatal(err) + } + if len(result2.Hits) != 1 || result2.Hits[0] != "result for hello" { + t.Errorf("typed call: unexpected hits: %v", result2.Hits) + } + if result2.Total != 1 { + t.Errorf("typed call: unexpected total: %d", result2.Total) + } +} diff --git a/mcp/server_test.go b/mcp/server_test.go index c3e7a4b6..609f8e76 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -11,7 +11,6 @@ import ( "fmt" "log" "log/slog" - "net/http" "slices" "strings" "testing" @@ -1007,187 +1006,4 @@ func TestServerCapabilitiesOverWire(t *testing.T) { } }) } -} - -type searchParams struct { - ParamsBase - Query string `json:"query"` - Limit int `json:"limit,omitempty"` -} - -type searchResult struct { - ResultBase - Hits []string `json:"hits"` - Total int `json:"total"` -} - -// callCustom calls a custom JSON-RPC method via the raw jsonrpc2 connection, -// bypassing the SDK's typed method dispatch. -func callCustom(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { - return conn.Call(ctx, method, params).Await(ctx, result) -} - -func TestAddReceivingCustomMethod(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { - hits := []string{"result for " + params.Query} - return &searchResult{ - Hits: hits, - Total: len(hits), - }, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result searchResult - if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result); err != nil { - t.Fatal(err) - } - - if len(result.Hits) != 1 || result.Hits[0] != "result for hello" { - t.Errorf("unexpected hits: %v", result.Hits) - } - if result.Total != 1 { - t.Errorf("unexpected total: %d", result.Total) - } -} - -func TestCustomMethodGoesThoughMiddleware(t *testing.T) { - ctx := context.Background() - s := NewServer(testImpl, nil) - - var middlewareCalled bool - s.AddReceivingMiddleware(func(next MethodHandler) MethodHandler { - return func(ctx context.Context, method string, req Request) (Result, error) { - if method == "acme/ping" { - middlewareCalled = true - } - return next(ctx, method, req) - } - }) - - type pingParams struct { - ParamsBase - } - type pingResult struct { - ResultBase - Pong bool `json:"pong"` - } - AddReceivingCustomMethod(s, "acme/ping", func(ctx context.Context, ss *ServerSession, params *pingParams) (*pingResult, error) { - return &pingResult{Pong: true}, nil - }) - - ct, st := NewInMemoryTransports() - ss, err := s.Connect(ctx, st, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss.Close() }) - - c := NewClient(testImpl, nil) - cs, err := c.Connect(ctx, ct, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs.Close() }) - - var result pingResult - if err := callCustom(ctx, cs.getConn(), "acme/ping", &pingParams{}, &result); err != nil { - t.Fatal(err) - } - - if !result.Pong { - t.Error("expected Pong to be true") - } - if !middlewareCalled { - t.Error("middleware was not called for custom method") - } -} - -func TestCustomMethodNotOnOtherServers(t *testing.T) { - ctx := context.Background() - - type emptyParams struct{ ParamsBase } - type emptyResult struct{ ResultBase } - - // Server 1 has the custom method. - s1 := NewServer(testImpl, nil) - AddReceivingCustomMethod(s1, "acme/custom", func(ctx context.Context, ss *ServerSession, params *emptyParams) (*emptyResult, error) { - return &emptyResult{}, nil - }) - - // Server 2 does NOT have the custom method. - s2 := NewServer(testImpl, nil) - - // Test s1: custom method should work. - ct1, st1 := NewInMemoryTransports() - ss1, err := s1.Connect(ctx, st1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss1.Close() }) - - c1 := NewClient(testImpl, nil) - cs1, err := c1.Connect(ctx, ct1, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs1.Close() }) - - if err := callCustom(ctx, cs1.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}); err != nil { - t.Fatalf("custom method on s1 should work: %v", err) - } - - // Test s2: custom method should fail. - ct2, st2 := NewInMemoryTransports() - ss2, err := s2.Connect(ctx, st2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = ss2.Close() }) - - c2 := NewClient(testImpl, nil) - cs2, err := c2.Connect(ctx, ct2, nil) - if err != nil { - t.Fatal(err) - } - t.Cleanup(func() { _ = cs2.Close() }) - - err = callCustom(ctx, cs2.getConn(), "acme/custom", &emptyParams{}, &emptyResult{}) - if err == nil { - t.Fatal("expected error calling custom method on s2") - } -} - -func TestCustomMethodStreamableHTTP(t *testing.T) { - s := NewServer(testImpl, nil) - - type echoParams struct { - ParamsBase - Msg string `json:"msg"` - } - type echoResult struct { - ResultBase - Reply string `json:"reply"` - } - AddReceivingCustomMethod(s, "acme/echo", func(ctx context.Context, ss *ServerSession, params *echoParams) (*echoResult, error) { - return &echoResult{Reply: "echo: " + params.Msg}, nil - }) - - handler := NewStreamableHTTPHandler(func(r *http.Request) *Server { return s }, nil) - _ = handler -} +} \ No newline at end of file From e5748ee6ea53b19647d05429ace4b5cd7b6336b8 Mon Sep 17 00:00:00 2001 From: guglielmoc Date: Wed, 13 May 2026 13:21:31 +0000 Subject: [PATCH 5/5] refactor: remove trailing whitespace and empty lines from client and server test files --- mcp/client_test.go | 2 -- mcp/server_test.go | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/mcp/client_test.go b/mcp/client_test.go index f36d7d00..609fd501 100644 --- a/mcp/client_test.go +++ b/mcp/client_test.go @@ -617,5 +617,3 @@ func TestClientCapabilitiesOverWire(t *testing.T) { }) } } - - diff --git a/mcp/server_test.go b/mcp/server_test.go index 609f8e76..2937ea2b 100644 --- a/mcp/server_test.go +++ b/mcp/server_test.go @@ -1006,4 +1006,4 @@ func TestServerCapabilitiesOverWire(t *testing.T) { } }) } -} \ No newline at end of file +}