|
38 | 38 | from ._context import UiPathServerType |
39 | 39 | from ._exception import McpErrorCode, UiPathMcpRuntimeError |
40 | 40 | from ._session import BaseSessionServer, StdioSessionServer, StreamableHttpSessionServer |
| 41 | +from ._token_refresh import TokenRefresher |
41 | 42 |
|
42 | 43 | logger = logging.getLogger(__name__) |
43 | 44 | tracer = trace.get_tracer(__name__) |
@@ -85,6 +86,7 @@ def __init__( |
85 | 86 | self._http_stderr_drain_task: asyncio.Task[None] | None = None |
86 | 87 | self._http_server_stderr_lines: list[str] = [] |
87 | 88 | self._uipath = UiPath() |
| 89 | + self._token_refresher: TokenRefresher | None = None |
88 | 90 | self._cleanup_done = False |
89 | 91 |
|
90 | 92 | # Context fields from UiPathConfig |
@@ -207,6 +209,8 @@ async def _run_server(self) -> UiPathRuntimeResult: |
207 | 209 | root_span.set_attribute("args", json.dumps(self._server.args)) |
208 | 210 | root_span.set_attribute("span_type", "MCP Server") |
209 | 211 | bearer_token = self._uipath._config.secret |
| 212 | + self._token_refresher = TokenRefresher(self._uipath) |
| 213 | + |
210 | 214 | self._signalr_client = SignalRClient( |
211 | 215 | signalr_url, |
212 | 216 | headers={ |
@@ -236,6 +240,7 @@ async def _run_server(self) -> UiPathRuntimeResult: |
236 | 240 | run_task = asyncio.create_task(self._signalr_client.run()) |
237 | 241 | cancel_task = asyncio.create_task(self._cancel_event.wait()) |
238 | 242 | self._keep_alive_task = asyncio.create_task(self._keep_alive()) |
| 243 | + self._token_refresher.start() |
239 | 244 |
|
240 | 245 | try: |
241 | 246 | # Wait for either the run to complete or cancellation |
@@ -297,6 +302,9 @@ async def _cleanup(self) -> None: |
297 | 302 |
|
298 | 303 | await self._on_runtime_abort() |
299 | 304 |
|
| 305 | + if self._token_refresher: |
| 306 | + await self._token_refresher.stop() |
| 307 | + |
300 | 308 | if self._keep_alive_task: |
301 | 309 | self._keep_alive_task.cancel() |
302 | 310 | try: |
@@ -374,11 +382,11 @@ async def _handle_signalr_message(self, args: list[str]) -> None: |
374 | 382 | session_server: BaseSessionServer |
375 | 383 | if self._server.is_streamable_http: |
376 | 384 | session_server = StreamableHttpSessionServer( |
377 | | - self._server, self.slug, session_id |
| 385 | + self._server, self.slug, session_id, self._uipath |
378 | 386 | ) |
379 | 387 | else: |
380 | 388 | session_server = StdioSessionServer( |
381 | | - self._server, self.slug, session_id |
| 389 | + self._server, self.slug, session_id, self._uipath |
382 | 390 | ) |
383 | 391 | try: |
384 | 392 | await session_server.start() |
|
0 commit comments