|
1 | 1 | import logging |
2 | 2 | from typing import Any, Protocol, overload |
| 3 | +from types import TracebackType |
3 | 4 |
|
4 | 5 | import anyio.lowlevel |
5 | 6 | from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream |
@@ -140,11 +141,33 @@ def __init__( |
140 | 141 | self._tool_output_schemas: dict[str, dict[str, Any] | None] = {} |
141 | 142 | self._server_capabilities: types.ServerCapabilities | None = None |
142 | 143 | self._experimental_features: ExperimentalClientFeatures | None = None |
| 144 | + self._entered = False |
143 | 145 |
|
144 | 146 | # Experimental: Task handlers (use defaults if not provided) |
145 | 147 | self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers() |
146 | 148 |
|
| 149 | + async def __aenter__(self) -> "ClientSession": |
| 150 | + self._entered = True |
| 151 | + await super().__aenter__() |
| 152 | + return self |
| 153 | + |
| 154 | + async def __aexit__( |
| 155 | + self, |
| 156 | + exc_type: type[BaseException] | None, |
| 157 | + exc_value: BaseException | None, |
| 158 | + traceback: TracebackType | None, |
| 159 | + ) -> None: |
| 160 | + self._entered = False |
| 161 | + await super().__aexit__(exc_type, exc_value, traceback) |
| 162 | + |
| 163 | + def _check_is_active(self) -> None: |
| 164 | + if not self._entered: |
| 165 | + raise RuntimeError( |
| 166 | + "ClientSession must be used within an 'async with' block." |
| 167 | + ) |
| 168 | + |
147 | 169 | async def initialize(self) -> types.InitializeResult: |
| 170 | + self._check_is_active() |
148 | 171 | sampling = ( |
149 | 172 | (self._sampling_capabilities or types.SamplingCapability()) |
150 | 173 | if self._sampling_callback is not _default_sampling_callback |
|
0 commit comments