Skip to content
This repository was archived by the owner on May 3, 2026. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 48 additions & 14 deletions reticulum_openapi/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from typing import Optional
from typing import Tuple
from typing import Type
from threading import Lock

import LXMF
import RNS
Expand Down Expand Up @@ -165,6 +166,7 @@ def __init__(
self._link_keepalive_interval = link_keepalive_interval
self._active_links: Dict[bytes, RNS.Link] = {}
self._link_keepalive_tasks: Dict[bytes, asyncio.Task] = {}
self._links_lock = Lock()
self.link_destination: Optional[RNS.Destination] = None
if self._links_enabled:
self._initialise_link_destination()
Expand Down Expand Up @@ -203,7 +205,9 @@ def set_link_handler(
def _link_established(self, link: RNS.Link) -> None:
"""Handle a newly established link from a remote node."""

self._active_links[link.link_id] = link
lock = self._get_links_lock()
with lock:
self._active_links[link.link_id] = link

if hasattr(link, "set_link_closed_callback"):
link.set_link_closed_callback(self._link_closed)
Expand All @@ -217,16 +221,20 @@ def _schedule_handler() -> None:
if self._link_keepalive_interval and self._link_keepalive_interval > 0:
def _start_keepalive() -> None:
task = asyncio.create_task(self._link_keepalive(link))
self._link_keepalive_tasks[link.link_id] = task
lock_inner = self._get_links_lock()
with lock_inner:
self._link_keepalive_tasks[link.link_id] = task

self._loop.call_soon_threadsafe(_start_keepalive)

def _link_closed(self, link: RNS.Link) -> None:
"""Cleanup when a link is closed by either party."""

def _cleanup() -> None:
self._active_links.pop(link.link_id, None)
task = self._link_keepalive_tasks.pop(link.link_id, None)
lock = self._get_links_lock()
with lock:
self._active_links.pop(link.link_id, None)
task = self._link_keepalive_tasks.pop(link.link_id, None)
if task is not None:
task.cancel()

Expand All @@ -236,15 +244,31 @@ async def _link_keepalive(self, link: RNS.Link) -> None:
"""Periodically send keep-alive packets for an active link."""

try:
while link.link_id in self._active_links:
lock = self._get_links_lock()
while True:
with lock:
if link.link_id not in self._active_links:
break
await asyncio.sleep(self._link_keepalive_interval)
with lock:
if link.link_id not in self._active_links:
break
try:
link.send_keepalive()
except Exception:
logger.debug("Failed to send keepalive on link %r", link.link_id)
except asyncio.CancelledError:
return

def _get_links_lock(self) -> Lock:
"""Return the lock protecting link bookkeeping structures."""

lock = getattr(self, "_links_lock", None)
if lock is None:
lock = Lock()
self._links_lock = lock
return lock

@staticmethod
def _extract_command_from_path(path: Any) -> Optional[str]:
"""Return the command name encoded within a link request path."""
Expand Down Expand Up @@ -851,17 +875,27 @@ async def __aexit__(self, exc_type, exc, tb):
async def _shutdown_links(self) -> None:
    """Close any active links and cancel keep-alive tasks.

    Snapshots and clears the bookkeeping dicts while holding the links
    lock, then performs the potentially slow close/cancel/await work
    outside the lock so link callbacks cannot deadlock against it.
    """

    lock = self._get_links_lock()
    with lock:
        # getattr guards instances built via __new__ (e.g. in tests)
        # where __init__ never populated these attributes.
        active_dict = getattr(self, "_active_links", None)
        keepalive_dict = getattr(self, "_link_keepalive_tasks", None)
        if active_dict is None:
            active_links = []
        else:
            active_links = list(active_dict.values())
            active_dict.clear()
        if keepalive_dict is None:
            keepalive_tasks = []
        else:
            keepalive_tasks = list(keepalive_dict.values())
            keepalive_dict.clear()

    for link in active_links:
        try:
            link.close()
        except Exception:
            # Best-effort shutdown: one link failing to close must not
            # prevent the remaining links and tasks from being cleaned up.
            pass
    for task in keepalive_tasks:
        task.cancel()
    if keepalive_tasks:
        # Wait for the cancelled keep-alive tasks to actually finish;
        # return_exceptions swallows their CancelledError.
        await asyncio.gather(*keepalive_tasks, return_exceptions=True)
92 changes: 92 additions & 0 deletions tests/test_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,98 @@ def _invoke_handler() -> Optional[bytes]:
assert msgpack_from_bytes(payload) == {"status": "ok"}


@pytest.mark.asyncio
async def test_parallel_link_requests_use_isolated_state() -> None:
    """Concurrent link requests should receive independent responses."""

    loop = asyncio.get_running_loop()
    # Build the service via __new__ to skip __init__ and avoid touching
    # real RNS/LXMF transport; populate only the attributes the link
    # machinery reads.
    service = LXMFService.__new__(LXMFService)
    service._loop = loop
    service.auth_token = None
    service.max_payload_size = 32000
    service._active_links = {}
    service._link_keepalive_tasks = {}
    service._link_keepalive_interval = 0.01
    service._link_handler = None

    handled_values: list[int] = []

    async def handler(payload: dict[str, Any]) -> dict[str, Any]:
        # Sleep briefly so the three in-flight requests overlap in time,
        # exercising the isolation of per-request state.
        await asyncio.sleep(0.01)
        value = payload["value"]
        handled_values.append(value)
        return {"value": value, "handled": True}

    service._routes = {"Echo": (handler, None, None)}

    class FakeLink:
        # Minimal stand-in for RNS.Link: records keep-alives and invokes
        # the registered closed-callback when close() is called.
        def __init__(self, link_id: bytes) -> None:
            self.link_id = link_id
            self.closed_callback: Optional[Callable[[Any], None]] = None
            self.keepalives = 0

        def set_link_closed_callback(self, callback: Callable[[Any], None]) -> None:
            self.closed_callback = callback

        def send_keepalive(self) -> None:
            self.keepalives += 1

        def close(self) -> None:
            if self.closed_callback is not None:
                self.closed_callback(self)

    links = [FakeLink(bytes([idx])) for idx in range(1, 4)]

    for link in links:
        service._link_established(link)

    # Let the loop run the threadsafe callbacks that register the
    # keep-alive tasks before asserting on the bookkeeping dicts.
    await asyncio.sleep(0.05)

    link_ids = {link.link_id for link in links}
    assert set(service._active_links) == link_ids
    assert set(service._link_keepalive_tasks) == link_ids

    # Each link must own a distinct keep-alive task (no shared state).
    keepalive_tasks = list(service._link_keepalive_tasks.values())
    assert len(keepalive_tasks) == len(links)
    assert len({id(task) for task in keepalive_tasks}) == len(links)

    payloads = [{"value": idx} for idx in range(1, 4)]

    async def issue_request(link: FakeLink, payload: dict[str, Any]) -> dict[str, Any]:
        # Run the synchronous request handler off the event loop (as the
        # transport would) so the requests genuinely execute concurrently.
        response_bytes = await asyncio.to_thread(
            service._handle_registered_link_request,
            "/commands/Echo",
            dataclass_to_msgpack(payload),
            None,
            link.link_id,
            SimpleNamespace(hash=b"remote" + link.link_id),
            123.0,
        )
        assert response_bytes is not None
        return msgpack_from_bytes(response_bytes)

    responses = await asyncio.gather(
        *(issue_request(link, payload) for link, payload in zip(links, payloads))
    )

    assert handled_values and sorted(handled_values) == [1, 2, 3]

    # Responses must pair up with their own request payloads — no
    # cross-talk between the concurrent link requests.
    for expected, response in zip(payloads, responses):
        assert response["value"] == expected["value"]
        assert response["handled"] is True

    for link in links:
        link.close()

    # Give the closed-callbacks time to run their cleanup on the loop.
    await asyncio.sleep(0.05)

    assert service._active_links == {}
    assert service._link_keepalive_tasks == {}

    # Drain the cancelled keep-alive tasks so no pending-task warnings
    # leak out of the test.
    if keepalive_tasks:
        await asyncio.gather(*keepalive_tasks, return_exceptions=True)


@pytest.mark.asyncio
async def test_lxmf_callback_dispatches_response():
"""Handler return values are sent back via _send_lxmf."""
Expand Down