@@ -36,8 +36,6 @@ def __init__(self, username: str, base_url=None, session: aiohttp.ClientSession
3636 self ._default_url = "https://euiot.roborock.com"
3737 self .base_url = base_url
3838 self ._device_identifier = secrets .token_urlsafe (16 )
39- if session is None :
40- session = aiohttp .ClientSession ()
4139 self .session = session
4240
4341 async def _get_base_url (self ) -> str :
@@ -470,18 +468,20 @@ async def download_category_code(self, user_data: UserData):
470468
471469
472470class PreparedRequest :
473- def __init__ (self , base_url : str , session : aiohttp .ClientSession , base_headers : dict | None = None ) -> None :
471+ def __init__ (
472+ self , base_url : str , session : aiohttp .ClientSession | None = None , base_headers : dict | None = None
473+ ) -> None :
474474 self .base_url = base_url
475475 self .base_headers = base_headers or {}
476476 self .session = session
477477
478478 async def request (self , method : str , url : str , params = None , data = None , headers = None , json = None ) -> dict :
479479 _url = "/" .join (s .strip ("/" ) for s in [self .base_url , url ])
480480 _headers = {** self .base_headers , ** (headers or {})}
481+ close_session = self .session is None
482+ session = self .session if self .session is not None else aiohttp .ClientSession ()
481483 try :
482- async with self .session .request (
483- method , _url , params = params , data = data , headers = _headers , json = json
484- ) as resp :
484+ async with session .request (method , _url , params = params , data = data , headers = _headers , json = json ) as resp :
485485 return await resp .json ()
486486 except ContentTypeError as err :
487487 """If we get an error, lets log everything for debugging."""
@@ -494,3 +494,6 @@ async def request(self, method: str, url: str, params=None, data=None, headers=N
494494 _LOGGER .info ("Resp raw: %s" , resp_raw )
495495 # Still raise the err so that it's clear it failed.
496496 raise err
497+ finally :
498+ if close_session :
499+ await session .close ()
0 commit comments