diff --git a/examples/server/custom-method/main.go b/examples/server/custom-method/main.go new file mode 100644 index 00000000..d36bc6eb --- /dev/null +++ b/examples/server/custom-method/main.go @@ -0,0 +1,105 @@ +// Copyright 2025 The Go MCP SDK Authors. All rights reserved. +// Use of this source code is governed by the license +// that can be found in the LICENSE file. + +// The custom-method example demonstrates registering and calling a custom +// JSON-RPC method that is not part of the standard MCP spec. +// +// The server registers a "latin/translate" method that translates simple +// English phrases into Latin. A client connects over an in-memory transport, +// calls the custom method, and prints the result. +package main + +import ( + "context" + "fmt" + "log" + "strings" + + "github.com/modelcontextprotocol/go-sdk/mcp" +) + +type TranslateParams struct { + mcp.ParamsBase + Text string `json:"text"` +} + +type TranslateResult struct { + mcp.ResultBase + Latin string `json:"latin"` +} + +var translations = map[string]string{ + "hello": "salve", + "goodbye": "vale", + "thank you": "gratias tibi ago", + "how are you": "quid agis", + "good morning": "bonum mane", + "good night": "bonam noctem", + "friend": "amicus", + "water": "aqua", + "love": "amor", + "war": "bellum", + "peace": "pax", + "truth": "veritas", + "light": "lux", + "time": "tempus", + "life": "vita", + "death": "mors", + "star": "stella", + "earth": "terra", + "sea": "mare", + "the die is cast": "alea iacta est", + "i came i saw i conquered": "veni vidi vici", + "seize the day": "carpe diem", +} + +func main() { + ctx := context.Background() + + server := mcp.NewServer(&mcp.Implementation{Name: "latin-server", Version: "v1.0.0"}, nil) + + mcp.AddReceivingCustomMethod(server, "latin/translate", + func(ctx context.Context, ss *mcp.ServerSession, params *TranslateParams) (*TranslateResult, error) { + key := strings.ToLower(strings.TrimSpace(params.Text)) + latin, ok := translations[key] + if !ok { + latin = fmt.Sprintf("[unknown: %q — try: %s]", params.Text, knownPhrases()) + } + return &TranslateResult{Latin: latin}, nil + }) + + ct, st := mcp.NewInMemoryTransports() + + ss, err := server.Connect(ctx, st, nil) + if err != nil { + log.Fatal(err) + } + defer ss.Close() + + client := mcp.NewClient(&mcp.Implementation{Name: "latin-client", Version: "v1.0.0"}, nil) + translate := mcp.AddSendingCustomMethod[*TranslateParams, *TranslateResult](client, "latin/translate") + + cs, err := client.Connect(ctx, ct, nil) + if err != nil { + log.Fatal(err) + } + defer cs.Close() + + phrases := []string{"Hello", "Seize the day", "Peace", "Truth", "I came I saw I conquered"} + for _, phrase := range phrases { + result, err := translate(ctx, cs, &TranslateParams{Text: phrase}) + if err != nil { + log.Fatalf("translate %q: %v", phrase, err) + } + fmt.Printf("%-35s → %s\n", phrase, result.Latin) + } +} + +func knownPhrases() string { + phrases := make([]string, 0, len(translations)) + for k := range translations { + phrases = append(phrases, fmt.Sprintf("%q", k)) + } + return strings.Join(phrases, ", ") +} diff --git a/mcp/client.go b/mcp/client.go index 6e24c5a3..093249ef 100644 --- a/mcp/client.go +++ b/mcp/client.go @@ -10,6 +10,8 @@ import ( "fmt" "iter" "log/slog" + "maps" + "reflect" "slices" "strings" "sync" @@ -32,6 +34,7 @@ type Client struct { sessions []*ClientSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler + customSendMethods map[string]methodInfo } // NewClient creates a new [Client]. @@ -64,6 +67,7 @@ func NewClient(impl *Implementation, options *ClientOptions) *Client { roots: newFeatureSet(func(r *Root) string { return r.URI }), sendingMethodHandler_: defaultSendingMethodHandler, receivingMethodHandler_: defaultReceivingMethodHandler[*ClientSession], + customSendMethods: make(map[string]methodInfo), } } @@ -945,7 +949,13 @@ var clientMethodInfos = map[string]methodInfo{ } func (cs *ClientSession) sendingMethodInfos() map[string]methodInfo { - return serverMethodInfos + if len(cs.client.customSendMethods) == 0 { + return serverMethodInfos + } + infos := make(map[string]methodInfo, len(serverMethodInfos)+len(cs.client.customSendMethods)) + maps.Copy(infos, serverMethodInfos) + maps.Copy(infos, cs.client.customSendMethods) + return infos } func (cs *ClientSession) receivingMethodInfos() map[string]methodInfo { @@ -1218,3 +1228,33 @@ func paginate[P listParams, R listResult[T], T any](ctx context.Context, params } } } + +// AddSendingCustomMethod registers a custom method that the client can send +// to the server and returns a typed caller function. +// +// The returned function calls the custom method through the client's sending +// middleware chain, with full type safety on both params and result. +// +// callSearch := mcp.AddSendingCustomMethod[*SearchParams, *SearchResult](c, "acme/search") +// result, err := callSearch(ctx, cs, &SearchParams{Query: "hello"}) +func AddSendingCustomMethod[P paramsPtr[PT], R Result, PT any]( + c *Client, + method string, +) func(ctx context.Context, cs *ClientSession, params P) (R, error) { + mi := methodInfo{ + newResult: func() Result { + return reflect.New(reflect.TypeFor[R]().Elem()).Interface().(R) + }, + } + + c.mu.Lock() + defer c.mu.Unlock() + c.customSendMethods[method] = mi + + return func(ctx context.Context, cs *ClientSession, params P) (R, error) { + return handleSend[R](ctx, method, &ClientRequest[P]{ + Session: cs, + Params: params, + }) + } +} diff --git a/mcp/mcp_test.go b/mcp/mcp_test.go index 14173231..810b1e58 100644 --- a/mcp/mcp_test.go +++ b/mcp/mcp_test.go @@ -2372,3 +2372,72 @@ func TestSetErrorPreservesContent(t *testing.T) { } var ctrCmpOpts = []cmp.Option{cmp.AllowUnexported(CallToolResult{})} + +func TestCustomMethods(t *testing.T) { + type searchParams struct { + ParamsBase + Query string `json:"query"` + Limit int `json:"limit,omitempty"` + } + + type searchResult struct { + ResultBase + Hits []string `json:"hits"` + Total int `json:"total"` + } + + callCustom := func(ctx context.Context, conn *jsonrpc2.Connection, method string, params, result any) error { + return conn.Call(ctx, method, params).Await(ctx, result) + } + + ctx := context.Background() + s := NewServer(testImpl, nil) + + AddReceivingCustomMethod(s, "acme/search", func(ctx context.Context, ss *ServerSession, params *searchParams) (*searchResult, error) { + hits := []string{"result for " + params.Query} + return &searchResult{ + Hits: hits, + Total: len(hits), + }, nil + }) + + ct, st := NewInMemoryTransports() + ss, err := s.Connect(ctx, st, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = ss.Close() }) + + c := NewClient(testImpl, nil) + callSearch := AddSendingCustomMethod[*searchParams, *searchResult](c, "acme/search") + + cs, err := c.Connect(ctx, ct, nil) + if err != nil { + t.Fatal(err) + } + t.Cleanup(func() { _ = cs.Close() }) + + // Test raw JSON-RPC call + var result1 searchResult + if err := callCustom(ctx, cs.getConn(), "acme/search", &searchParams{Query: "hello", Limit: 10}, &result1); err != nil { + t.Fatal(err) + } + if len(result1.Hits) != 1 || result1.Hits[0] != "result for hello" { + t.Errorf("raw call: unexpected hits: %v", result1.Hits) + } + if result1.Total != 1 { + t.Errorf("raw call: unexpected total: %d", result1.Total) + } + + // Test typed caller + result2, err := callSearch(ctx, cs, &searchParams{Query: "hello"}) + if err != nil { + t.Fatal(err) + } + if len(result2.Hits) != 1 || result2.Hits[0] != "result for hello" { + t.Errorf("typed call: unexpected hits: %v", result2.Hits) + } + if result2.Total != 1 { + t.Errorf("typed call: unexpected total: %d", result2.Total) + } +} diff --git a/mcp/server.go b/mcp/server.go index d25c7922..549645a6 100644 --- a/mcp/server.go +++ b/mcp/server.go @@ -52,6 +52,7 @@ type Server struct { sessions []*ServerSession sendingMethodHandler_ MethodHandler receivingMethodHandler_ MethodHandler + customMethods map[string]methodInfo resourceSubscriptions map[string]map[*ServerSession]bool // uri -> session -> bool pendingNotifications map[string]*time.Timer // notification name -> timer for pending notification send } @@ -195,6 +196,7 @@ func NewServer(impl *Implementation, options *ServerOptions) *Server { receivingMethodHandler_: defaultReceivingMethodHandler[*ServerSession], resourceSubscriptions: make(map[string]map[*ServerSession]bool), pendingNotifications: make(map[string]*time.Timer), + customMethods: make(map[string]methodInfo), } } @@ -1422,7 +1424,19 @@ func initializeMethodInfo() methodInfo { func (ss *ServerSession) sendingMethodInfos() map[string]methodInfo { return clientMethodInfos } -func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { return serverMethodInfos } +func (s *Server) receivingMethodInfos() map[string]methodInfo { + if len(s.customMethods) == 0 { + return serverMethodInfos + } + infos := make(map[string]methodInfo, len(serverMethodInfos)+len(s.customMethods)) + maps.Copy(infos, serverMethodInfos) + maps.Copy(infos, s.customMethods) + return infos +} + +func (ss *ServerSession) receivingMethodInfos() map[string]methodInfo { + return ss.server.receivingMethodInfos() +} func (ss *ServerSession) sendingMethodHandler() MethodHandler { s := ss.server @@ -1630,3 +1644,44 @@ func paginateList[P listParams, R listResult[T], T any](fs *featureSet[T], pageS *res.nextCursorPtr() = nextCursor return res, nil } + +// AddReceivingCustomMethod registers a handler for a custom (non-standard) +// JSON-RPC method on the server. +// +// When a client sends a request with the given method name, the params will be +// unmarshaled into P, the handler will be called, and the returned R will be +// marshaled as the JSON-RPC result. +// +// Custom methods go through the server's middleware chain just like standard +// MCP methods (tools/call, prompts/list, etc.). +// +// P and R must implement [Params] and [Result] respectively, which is most +// easily done by embedding [ParamsBase] and [ResultBase]: +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +// +// mcp.AddReceivingCustomMethod(server, "acme/search", +// func(ctx context.Context, ss *mcp.ServerSession, params *SearchParams) (*SearchResult, error) { +// return &SearchResult{Hits: []string{"result"}}, nil +// }) +func AddReceivingCustomMethod[P paramsPtr[T], R Result, T any]( + s *Server, + method string, + handler func(ctx context.Context, ss *ServerSession, params P) (R, error), +) { + typed := typedServerMethodHandler[P, R](func(ctx context.Context, req *ServerRequest[P]) (R, error) { + return handler(ctx, req.Session, req.Params) + }) + + s.mu.Lock() + defer s.mu.Unlock() + s.customMethods[method] = newServerMethodInfo(typed, missingParamsOK) +} diff --git a/mcp/shared.go b/mcp/shared.go index 078b401b..42472704 100644 --- a/mcp/shared.go +++ b/mcp/shared.go @@ -544,6 +544,20 @@ type Params interface { isParams() } +// ParamsBase can be embedded in custom parameter structs to satisfy the +// [Params] interface. It provides the required [Meta] field and the unexported +// isParams marker method. +// +// type SearchParams struct { +// mcp.ParamsBase +// Query string `json:"query"` +// } +type ParamsBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ParamsBase) isParams() {} + // RequestParams is a parameter (input) type for an MCP request. type RequestParams interface { Params @@ -568,6 +582,20 @@ type Result interface { SetMeta(map[string]any) } +// ResultBase can be embedded in custom result structs to satisfy the +// [Result] interface. It provides the required [Meta] field and the unexported +// isResult marker method. +// +// type SearchResult struct { +// mcp.ResultBase +// Hits []string `json:"hits"` +// } +type ResultBase struct { + Meta `json:"_meta,omitempty"` +} + +func (*ResultBase) isResult() {} + // emptyResult is returned by methods that have no result, like ping. // Those methods cannot return nil, because jsonrpc2 cannot handle nils. type emptyResult struct{} diff --git a/mcp/streamable.go b/mcp/streamable.go index b8e36553..a4a34a52 100644 --- a/mcp/streamable.go +++ b/mcp/streamable.go @@ -396,6 +396,7 @@ func connectStreamable(ctx context.Context, server *Server, transport *Streamabl if err != nil { return nil, err } + transport.connection.server = server transport.connection.toolLookup = server.getServerTool return s, nil } @@ -734,6 +735,7 @@ type streamableServerConn struct { logger *slog.Logger + server *Server toolLookup func(name string) (*serverTool, bool) incoming chan jsonrpc.Message // messages from the client to the server @@ -1241,7 +1243,7 @@ func (c *streamableServerConn) servePOST(w http.ResponseWriter, req *http.Reques // Preemptively check that this is a valid request, so that we can fail // the HTTP request. If we didn't do this, a request with a bad method or // missing ID could be silently swallowed. - if _, err := checkRequest(jreq, serverMethodInfos); err != nil { + if _, err := checkRequest(jreq, c.server.receivingMethodInfos()); err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return }