diff --git a/reticulum_openapi/service.py b/reticulum_openapi/service.py index 3ce7cd0..90c17cc 100644 --- a/reticulum_openapi/service.py +++ b/reticulum_openapi/service.py @@ -13,6 +13,7 @@ from typing import Optional from typing import Tuple from typing import Type +from threading import Lock import LXMF import RNS @@ -165,6 +166,7 @@ def __init__( self._link_keepalive_interval = link_keepalive_interval self._active_links: Dict[bytes, RNS.Link] = {} self._link_keepalive_tasks: Dict[bytes, asyncio.Task] = {} + self._links_lock = Lock() self.link_destination: Optional[RNS.Destination] = None if self._links_enabled: self._initialise_link_destination() @@ -203,7 +205,9 @@ def set_link_handler( def _link_established(self, link: RNS.Link) -> None: """Handle a newly established link from a remote node.""" - self._active_links[link.link_id] = link + lock = self._get_links_lock() + with lock: + self._active_links[link.link_id] = link if hasattr(link, "set_link_closed_callback"): link.set_link_closed_callback(self._link_closed) @@ -217,7 +221,9 @@ def _schedule_handler() -> None: if self._link_keepalive_interval and self._link_keepalive_interval > 0: def _start_keepalive() -> None: task = asyncio.create_task(self._link_keepalive(link)) - self._link_keepalive_tasks[link.link_id] = task + lock_inner = self._get_links_lock() + with lock_inner: + self._link_keepalive_tasks[link.link_id] = task self._loop.call_soon_threadsafe(_start_keepalive) @@ -225,8 +231,10 @@ def _link_closed(self, link: RNS.Link) -> None: """Cleanup when a link is closed by either party.""" def _cleanup() -> None: - self._active_links.pop(link.link_id, None) - task = self._link_keepalive_tasks.pop(link.link_id, None) + lock = self._get_links_lock() + with lock: + self._active_links.pop(link.link_id, None) + task = self._link_keepalive_tasks.pop(link.link_id, None) if task is not None: task.cancel() @@ -236,8 +244,15 @@ async def _link_keepalive(self, link: RNS.Link) -> None: """Periodically send keep-alive packets for an active link.""" try: - while link.link_id in self._active_links: + lock = self._get_links_lock() + while True: + with lock: + if link.link_id not in self._active_links: + break await asyncio.sleep(self._link_keepalive_interval) + with lock: + if link.link_id not in self._active_links: + break try: link.send_keepalive() except Exception: @@ -245,6 +260,15 @@ async def _link_keepalive(self, link: RNS.Link) -> None: except asyncio.CancelledError: return + def _get_links_lock(self) -> Lock: + """Return the lock protecting link bookkeeping structures.""" + + lock = getattr(self, "_links_lock", None) + if lock is None: + lock = Lock() + self._links_lock = lock + return lock + @staticmethod def _extract_command_from_path(path: Any) -> Optional[str]: """Return the command name encoded within a link request path.""" @@ -851,17 +875,27 @@ async def __aexit__(self, exc_type, exc, tb): async def _shutdown_links(self) -> None: """Close any active links and cancel keep-alive tasks.""" - active_links = getattr(self, "_active_links", {}) - for link in list(active_links.values()): + lock = self._get_links_lock() + with lock: + active_dict = getattr(self, "_active_links", None) + keepalive_dict = getattr(self, "_link_keepalive_tasks", None) + if active_dict is None: + active_links = [] + else: + active_links = list(active_dict.values()) + active_dict.clear() + if keepalive_dict is None: + keepalive_tasks = [] + else: + keepalive_tasks = list(keepalive_dict.values()) + keepalive_dict.clear() + + for link in active_links: try: link.close() except Exception: pass - active_links.clear() - keepalive_tasks = getattr(self, "_link_keepalive_tasks", {}) - tasks = list(keepalive_tasks.values()) - keepalive_tasks.clear() - for task in tasks: + for task in keepalive_tasks: task.cancel() - if tasks: - await asyncio.gather(*tasks, return_exceptions=True) + if keepalive_tasks: + await asyncio.gather(*keepalive_tasks, return_exceptions=True) diff --git a/tests/test_service.py b/tests/test_service.py index 548c748..2ed66ae 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -348,6 +348,98 @@ def _invoke_handler() -> Optional[bytes]: assert msgpack_from_bytes(payload) == {"status": "ok"} +@pytest.mark.asyncio +async def test_parallel_link_requests_use_isolated_state() -> None: + """Concurrent link requests should receive independent responses.""" + + loop = asyncio.get_running_loop() + service = LXMFService.__new__(LXMFService) + service._loop = loop + service.auth_token = None + service.max_payload_size = 32000 + service._active_links = {} + service._link_keepalive_tasks = {} + service._link_keepalive_interval = 0.01 + service._link_handler = None + + handled_values: list[int] = [] + + async def handler(payload: dict[str, Any]) -> dict[str, Any]: + await asyncio.sleep(0.01) + value = payload["value"] + handled_values.append(value) + return {"value": value, "handled": True} + + service._routes = {"Echo": (handler, None, None)} + + class FakeLink: + def __init__(self, link_id: bytes) -> None: + self.link_id = link_id + self.closed_callback: Optional[Callable[[Any], None]] = None + self.keepalives = 0 + + def set_link_closed_callback(self, callback: Callable[[Any], None]) -> None: + self.closed_callback = callback + + def send_keepalive(self) -> None: + self.keepalives += 1 + + def close(self) -> None: + if self.closed_callback is not None: + self.closed_callback(self) + + links = [FakeLink(bytes([idx])) for idx in range(1, 4)] + + for link in links: + service._link_established(link) + + await asyncio.sleep(0.05) + + link_ids = {link.link_id for link in links} + assert set(service._active_links) == link_ids + assert set(service._link_keepalive_tasks) == link_ids + + keepalive_tasks = list(service._link_keepalive_tasks.values()) + assert len(keepalive_tasks) == len(links) + assert len({id(task) for task in keepalive_tasks}) == len(links) + + payloads = [{"value": idx} for idx in range(1, 4)] + + async def issue_request(link: FakeLink, payload: dict[str, Any]) -> dict[str, Any]: + response_bytes = await asyncio.to_thread( + service._handle_registered_link_request, + "/commands/Echo", + dataclass_to_msgpack(payload), + None, + link.link_id, + SimpleNamespace(hash=b"remote" + link.link_id), + 123.0, + ) + assert response_bytes is not None + return msgpack_from_bytes(response_bytes) + + responses = await asyncio.gather( + *(issue_request(link, payload) for link, payload in zip(links, payloads)) + ) + + assert handled_values and sorted(handled_values) == [1, 2, 3] + + for expected, response in zip(payloads, responses): + assert response["value"] == expected["value"] + assert response["handled"] is True + + for link in links: + link.close() + + await asyncio.sleep(0.05) + + assert service._active_links == {} + assert service._link_keepalive_tasks == {} + + if keepalive_tasks: + await asyncio.gather(*keepalive_tasks, return_exceptions=True) + + @pytest.mark.asyncio async def test_lxmf_callback_dispatches_response(): """Handler return values are sent back via _send_lxmf."""