Skip to content

Commit f0bb718

Browse files
committed
Flatten the methods in Client
1 parent 34e66a3 commit f0bb718

File tree

8 files changed

+227
-305
lines changed

8 files changed

+227
-305
lines changed

CLAUDE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ This document contains critical information about working with this codebase. Fo
1717
- Functions must be focused and small
1818
- Follow existing patterns exactly
1919
- Line length: 120 chars maximum
20+
- FORBIDDEN: imports inside functions
2021

2122
3. Testing Requirements
2223
- Framework: `uv run --frozen pytest`
@@ -25,6 +26,8 @@ This document contains critical information about working with this codebase. Fo
2526
- Coverage: test edge cases and errors
2627
- New features require tests
2728
- Bug fixes require regression tests
29+
- IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns.
30+
- IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible.
2831

2932
- For commits fixing bugs or adding features based on user reports add:
3033

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,6 @@ xfail_strict = true
171171
addopts = """
172172
--color=yes
173173
--capture=fd
174-
--numprocesses auto
175174
"""
176175
filterwarnings = [
177176
"error",

src/mcp/client/_memory.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,7 @@ class InMemoryTransport:
3636
result = await client.call_tool("my_tool", {...})
3737
"""
3838

39-
def __init__(
40-
self,
41-
server: Server[Any] | FastMCP,
42-
*,
43-
raise_exceptions: bool = False,
44-
) -> None:
39+
def __init__(self, server: Server[Any] | FastMCP, *, raise_exceptions: bool = False) -> None:
4540
"""Initialize the in-memory transport.
4641
4742
Args:

src/mcp/client/client.py

Lines changed: 18 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
self,
5959
server: Server[Any] | FastMCP,
6060
*,
61+
# TODO(Marcelo): When do `raise_exceptions=True` actually raises?
6162
raise_exceptions: bool = False,
6263
read_timeout_seconds: float | None = None,
6364
sampling_callback: SamplingFnT | None = None,
@@ -125,12 +126,7 @@ async def __aenter__(self) -> Client:
125126
self._exit_stack = exit_stack.pop_all()
126127
return self
127128

128-
async def __aexit__(
129-
self,
130-
exc_type: type[BaseException] | None,
131-
exc_val: BaseException | None,
132-
exc_tb: Any,
133-
) -> None:
129+
async def __aexit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: Any) -> None:
134130
"""Exit the async context manager."""
135131
if self._exit_stack:
136132
await self._exit_stack.__aexit__(exc_type, exc_val, exc_tb)
@@ -177,28 +173,22 @@ async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResul
177173
"""Set the logging level on the server."""
178174
return await self.session.set_logging_level(level)
179175

180-
async def list_resources(
181-
self,
182-
params: types.PaginatedRequestParams | None = None,
183-
) -> types.ListResourcesResult:
176+
async def list_resources(self, *, cursor: str | None = None) -> types.ListResourcesResult:
184177
"""List available resources from the server."""
185-
return await self.session.list_resources(params=params)
178+
return await self.session.list_resources(params=types.PaginatedRequestParams(cursor=cursor))
186179

187-
async def list_resource_templates(
188-
self,
189-
params: types.PaginatedRequestParams | None = None,
190-
) -> types.ListResourceTemplatesResult:
180+
async def list_resource_templates(self, *, cursor: str | None = None) -> types.ListResourceTemplatesResult:
191181
"""List available resource templates from the server."""
192-
return await self.session.list_resource_templates(params=params)
182+
return await self.session.list_resource_templates(params=types.PaginatedRequestParams(cursor=cursor))
193183

194184
async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult:
195185
"""Read a resource from the server.
196186
197187
Args:
198-
uri: The URI of the resource to read
188+
uri: The URI of the resource to read.
199189
200190
Returns:
201-
The resource content
191+
The resource content.
202192
"""
203193
return await self.session.read_resource(uri)
204194

@@ -239,26 +229,19 @@ async def call_tool(
239229
meta=meta,
240230
)
241231

242-
async def list_prompts(
243-
self,
244-
params: types.PaginatedRequestParams | None = None,
245-
) -> types.ListPromptsResult:
232+
async def list_prompts(self, *, cursor: str | None = None) -> types.ListPromptsResult:
246233
"""List available prompts from the server."""
247-
return await self.session.list_prompts(params=params)
234+
return await self.session.list_prompts(params=types.PaginatedRequestParams(cursor=cursor))
248235

249-
async def get_prompt(
250-
self,
251-
name: str,
252-
arguments: dict[str, str] | None = None,
253-
) -> types.GetPromptResult:
236+
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
254237
"""Get a prompt from the server.
255238
256239
Args:
257240
name: The name of the prompt
258241
arguments: Arguments to pass to the prompt
259242
260243
Returns:
261-
The prompt content
244+
The prompt content.
262245
"""
263246
return await self.session.get_prompt(name=name, arguments=arguments)
264247

@@ -276,21 +259,15 @@ async def complete(
276259
context_arguments: Additional context arguments
277260
278261
Returns:
279-
Completion suggestions
262+
Completion suggestions.
280263
"""
281-
return await self.session.complete(
282-
ref=ref,
283-
argument=argument,
284-
context_arguments=context_arguments,
285-
)
264+
return await self.session.complete(ref=ref, argument=argument, context_arguments=context_arguments)
286265

287-
async def list_tools(
288-
self,
289-
params: types.PaginatedRequestParams | None = None,
290-
) -> types.ListToolsResult:
266+
async def list_tools(self, *, cursor: str | None = None) -> types.ListToolsResult:
291267
"""List available tools from the server."""
292-
return await self.session.list_tools(params=params)
268+
return await self.session.list_tools(params=types.PaginatedRequestParams(cursor=cursor))
293269

294270
async def send_roots_list_changed(self) -> None:
295271
"""Send a notification that the roots list has changed."""
296-
await self.session.send_roots_list_changed()
272+
# TODO(Marcelo): Currently, there is no way for the server to handle this. We should add support.
273+
await self.session.send_roots_list_changed() # pragma: no cover

src/mcp/client/session.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -349,10 +349,7 @@ async def list_prompts(self, *, params: types.PaginatedRequestParams | None = No
349349
Args:
350350
params: Full pagination parameters including cursor and any future fields
351351
"""
352-
return await self.send_request(
353-
types.ListPromptsRequest(params=params),
354-
types.ListPromptsResult,
355-
)
352+
return await self.send_request(types.ListPromptsRequest(params=params), types.ListPromptsResult)
356353

357354
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
358355
"""Send a prompts/get request."""

src/mcp/shared/session.py

Lines changed: 12 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -234,12 +234,12 @@ async def send_request(
234234
metadata: MessageMetadata = None,
235235
progress_callback: ProgressFnT | None = None,
236236
) -> ReceiveResultT:
237-
"""Sends a request and wait for a response. Raises an McpError if the
238-
response contains an error. If a request read timeout is provided, it
239-
will take precedence over the session read timeout.
237+
"""Sends a request and wait for a response.
240238
241-
Do not use this method to emit notifications! Use send_notification()
242-
instead.
239+
Raises an McpError if the response contains an error. If a request read timeout is provided, it will take
240+
precedence over the session read timeout.
241+
242+
Do not use this method to emit notifications! Use send_notification() instead.
243243
"""
244244
request_id = self._request_id
245245
self._request_id = request_id + 1
@@ -261,15 +261,10 @@ async def send_request(
261261

262262
try:
263263
jsonrpc_request = JSONRPCRequest(jsonrpc="2.0", id=request_id, **request_data)
264-
265264
await self._write_stream.send(SessionMessage(message=jsonrpc_request, metadata=metadata))
266265

267266
# request read timeout takes precedence over session read timeout
268-
timeout = None
269-
if request_read_timeout_seconds is not None: # pragma: no cover
270-
timeout = request_read_timeout_seconds
271-
elif self._session_read_timeout_seconds is not None: # pragma: no cover
272-
timeout = self._session_read_timeout_seconds
267+
timeout = request_read_timeout_seconds or self._session_read_timeout_seconds
273268

274269
try:
275270
with anyio.fail_after(timeout):
@@ -279,9 +274,8 @@ async def send_request(
279274
ErrorData(
280275
code=httpx.codes.REQUEST_TIMEOUT,
281276
message=(
282-
f"Timed out while waiting for response to "
283-
f"{request.__class__.__name__}. Waited "
284-
f"{timeout} seconds."
277+
f"Timed out while waiting for response to {request.__class__.__name__}. "
278+
f"Waited {timeout} seconds."
285279
),
286280
)
287281
)
@@ -302,9 +296,7 @@ async def send_notification(
302296
notification: SendNotificationT,
303297
related_request_id: RequestId | None = None,
304298
) -> None:
305-
"""Emits a notification, which is a one-way message that does not expect
306-
a response.
307-
"""
299+
"""Emits a notification, which is a one-way message that does not expect a response."""
308300
# Some transport implementations may need to set the related_request_id
309301
# to attribute to the notifications to the request that triggered them.
310302
jsonrpc_notification = JSONRPCNotification(
@@ -373,11 +365,7 @@ async def _receive_loop(self) -> None:
373365
error_response = JSONRPCError(
374366
jsonrpc="2.0",
375367
id=message.message.id,
376-
error=ErrorData(
377-
code=INVALID_PARAMS,
378-
message="Invalid request parameters",
379-
data="",
380-
),
368+
error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""),
381369
)
382370
session_message = SessionMessage(message=error_response)
383371
await self._write_stream.send(session_message)
@@ -518,13 +506,9 @@ async def send_progress_notification(
518506
total: float | None = None,
519507
message: str | None = None,
520508
) -> None:
521-
"""Sends a progress notification for a request that is currently being
522-
processed.
523-
"""
509+
"""Sends a progress notification for a request that is currently being processed."""
524510

525511
async def _handle_incoming(
526-
self,
527-
req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception,
512+
self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception
528513
) -> None:
529514
"""A generic handler for incoming messages. Overwritten by subclasses."""
530-
pass # pragma: no cover

0 commit comments

Comments
 (0)