88from typing import Any , Generic , Protocol , TypeVar
99
1010import anyio
11+ from anyio .abc import TaskGroup
1112from anyio .streams .memory import MemoryObjectSendStream
1213from opentelemetry .trace import SpanKind
1314from pydantic import BaseModel , TypeAdapter
@@ -201,6 +202,15 @@ def __init__(
201202 self ._progress_callbacks = {}
202203 self ._response_routers = []
203204 self ._exit_stack = AsyncExitStack ()
205+ self ._task_group : TaskGroup = anyio .create_task_group ()
206+ self ._started = False
207+
208+ def _require_started (self ) -> None :
209+ if not self ._started :
210+ raise RuntimeError (
211+ "Session is not running. Use it as an async context manager "
212+ "(e.g. `async with ClientSession(...) as session:`)."
213+ )
204214
205215 def add_response_router (self , router : ResponseRouter ) -> None :
206216 """Register a response router to handle responses for non-standard requests.
@@ -218,8 +228,11 @@ def add_response_router(self, router: ResponseRouter) -> None:
218228 self ._response_routers .append (router )
219229
220230 async def __aenter__ (self ) -> Self :
231+ if self ._started :
232+ raise RuntimeError ("Session is already running" )
221233 self ._task_group = anyio .create_task_group ()
222234 await self ._task_group .__aenter__ ()
235+ self ._started = True
223236 self ._task_group .start_soon (self ._receive_loop )
224237 return self
225238
@@ -234,7 +247,10 @@ async def __aexit__(
234247 # would be very surprising behavior), so make sure to cancel the tasks
235248 # in the task group.
236249 self ._task_group .cancel_scope .cancel ()
237- return await self ._task_group .__aexit__ (exc_type , exc_val , exc_tb )
250+ try :
251+ return await self ._task_group .__aexit__ (exc_type , exc_val , exc_tb )
252+ finally :
253+ self ._started = False
238254
239255 async def send_request (
240256 self ,
@@ -251,6 +267,7 @@ async def send_request(
251267
252268 Do not use this method to emit notifications! Use send_notification() instead.
253269 """
270+ self ._require_started ()
254271 request_id = self ._request_id
255272 self ._request_id = request_id + 1
256273
@@ -313,6 +330,7 @@ async def send_notification(
313330 related_request_id : RequestId | None = None ,
314331 ) -> None :
315332 """Emits a notification, which is a one-way message that does not expect a response."""
333+ self ._require_started ()
316334 # Some transport implementations may need to set the related_request_id
317335 # to attribute to the notifications to the request that triggered them.
318336 jsonrpc_notification = JSONRPCNotification (
0 commit comments