diff --git a/reticulum_openapi/service.py b/reticulum_openapi/service.py index ea3790a..4c99ff4 100644 --- a/reticulum_openapi/service.py +++ b/reticulum_openapi/service.py @@ -208,8 +208,6 @@ def _link_established(self, link: RNS.Link) -> None: if hasattr(link, "set_link_closed_callback"): link.set_link_closed_callback(self._link_closed) - self._configure_link_request_handler(link) - if self._link_handler is not None: def _schedule_handler() -> None: asyncio.create_task(self._link_handler(link)) @@ -247,35 +245,6 @@ async def _link_keepalive(self, link: RNS.Link) -> None: except asyncio.CancelledError: return - def _configure_link_request_handler(self, link: RNS.Link) -> None: - """Assign the default request handler for established links.""" - - def _handler( - path: Any, - data: Any = None, - respond: Optional[Callable[[bytes], Any]] = None, - ) -> None: - def _dispatch() -> None: - self._loop.create_task( - self._handle_link_request(link, path, data, respond) - ) - - self._loop.call_soon_threadsafe(_dispatch) - - setter = getattr(link, "set_request_handler", None) - if callable(setter): - try: - setter(_handler) - return - except Exception: - logger.debug("set_request_handler unavailable on link", exc_info=True) - try: - setattr(link, "request_handler", _handler) - except Exception: - logger.debug( - "Unable to assign default link request handler", exc_info=True - ) - @staticmethod def _extract_command_from_path(path: Any) -> Optional[str]: """Return the command name encoded within a link request path.""" @@ -296,19 +265,53 @@ def _extract_command_from_path(path: Any) -> Optional[str]: command = cleaned.rsplit("/", 1)[-1] return command or None - async def _handle_link_request( + def _register_link_route(self, command: Optional[str]) -> None: + """Register the request handler for ``command`` on the link destination.""" + + if self.link_destination is None: + return + if not command: + return + + path = f"/commands/{command}" + try: + self.link_destination.deregister_request_handler(path) + except Exception: + logger.debug( + "Failed to deregister existing link handler for %s", path, exc_info=True + ) + try: + self.link_destination.register_request_handler( + path, + self._handle_registered_link_request, + allow=RNS.Destination.ALLOW_ALL, + ) + except Exception: + logger.exception("Unable to register link handler for %s", command) + + def _handle_registered_link_request( self, - link: RNS.Link, path: Any, - data: Any, - respond: Optional[Callable[[bytes], Any]], - ) -> None: - """Decode and dispatch commands received over an ``RNS.Link``.""" + request_data: Any, + request_id: Any, + *extra: Any, + ) -> Optional[bytes]: + """Handle link requests dispatched via ``RNS.Destination`` hooks.""" command_candidate = self._extract_command_from_path(path) if command_candidate is None: logger.warning("Received link request with invalid path: %r", path) - return + return None + + response = self._generate_link_response(command_candidate, request_data) + if response is not None: + logger.info("Sent response for %s over link", command_candidate) + return response + + def _generate_link_response( + self, command_candidate: str, raw_payload: Any + ) -> Optional[bytes]: + """Return response bytes for a link-delivered command.""" command_key = command_candidate normalised = self._normalise_command_title(command_candidate) @@ -318,30 +321,29 @@ async def _handle_link_request( route = self._routes.get(command_key) if route is None: logger.warning("No route found for link command: %s", command_candidate) - return + return None handler, payload_type, payload_schema = route payload_obj, valid = self._decode_command_payload( command_candidate, - data, + raw_payload, payload_type, payload_schema, ) if not valid: - return - - def _responder(response_bytes: bytes) -> Optional[Any]: - if respond is None: - return None - logger.info("Sent response for %s over link", command_candidate) - return respond(response_bytes) + return None - await self._dispatch_handler_response( - command_candidate, - handler, - payload_obj, - _responder, + future = asyncio.run_coroutine_threadsafe( + self._execute_command_handler(command_candidate, handler, payload_obj), + self._loop, ) + try: + return future.result() + except Exception: + logger.exception( + "Failed to handle link command %s", command_candidate, exc_info=True + ) + return None def add_route( self, @@ -365,6 +367,8 @@ def add_route( raise ValueError("Command names must be UTF-8 decodable") self._routes[normalised_command] = (handler, payload_type, payload_schema) RNS.log(f"Route registered: '{normalised_command}' -> {handler}") + if self.link_destination is not None: + self._register_link_route(normalised_command) def _decode_command_payload( self, @@ -481,14 +485,13 @@ def _serialise_handler_result( return None return compress_json(fallback_json) - async def _dispatch_handler_response( + async def _execute_command_handler( self, command: str, handler: Callable[..., Awaitable[Any]], payload_obj: Optional[Any], - responder: Callable[[bytes], Optional[Any]], - ) -> None: - """Execute a handler and forward the serialised response.""" + ) -> Optional[bytes]: + """Execute a handler and return serialised response bytes.""" try: if payload_obj is not None: @@ -497,12 +500,25 @@ async def _dispatch_handler_response( result = await handler() except Exception as exc: logger.exception("Exception in handler for %s: %s", command, exc) - return + return None if result is None: - return + return None + + return self._serialise_handler_result(command, result) + + async def _dispatch_handler_response( + self, + command: str, + handler: Callable[..., Awaitable[Any]], + payload_obj: Optional[Any], + responder: Callable[[bytes], Optional[Any]], + ) -> None: + """Execute a handler and forward the serialised response.""" - response_bytes = self._serialise_handler_result(command, result) + response_bytes = await self._execute_command_handler( + command, handler, payload_obj + ) if response_bytes is None: return diff --git a/tests/test_example_emergency_management.py b/tests/test_example_emergency_management.py index 0a682f7..386d2dc 100644 --- a/tests/test_example_emergency_management.py +++ b/tests/test_example_emergency_management.py @@ -408,6 +408,15 @@ async def fake_retrieve(client, server_id, callsign): fake_retrieve, raising=False, ) + async def immediate_wait(*args, **kwargs): + return None + + monkeypatch.setattr( + module, + "_wait_until_interrupted", + immediate_wait, + raising=False, + ) await module.main() @@ -489,6 +498,15 @@ async def fake_retrieve(client, server_id, callsign): fake_retrieve, raising=False, ) + async def immediate_wait(*args, **kwargs): + return None + + monkeypatch.setattr( + module, + "_wait_until_interrupted", + immediate_wait, + raising=False, + ) await module.main() @@ -578,6 +596,15 @@ async def fake_retrieve(client, server_id, callsign): fake_retrieve, raising=False, ) + async def immediate_wait(*args, **kwargs): + return None + + monkeypatch.setattr( + module, + "_wait_until_interrupted", + immediate_wait, + raising=False, + ) await module.main() diff --git a/tests/test_service.py b/tests/test_service.py index e77a0e8..c67de85 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -1,6 +1,9 @@ import asyncio from dataclasses import dataclass from types import SimpleNamespace +from typing import Any +from typing import Callable +from typing import Optional from unittest.mock import Mock import pytest @@ -277,6 +280,29 @@ async def test_link_request_dispatches_routes() -> None: """Link request handlers should execute command routes and return responses.""" loop = asyncio.get_running_loop() + handlers: dict[str, Callable[..., Optional[bytes]]] = {} + + class DummyDestination: + def __init__(self) -> None: + self.register_calls: list[str] = [] + + def deregister_request_handler(self, path: str) -> None: + handlers.pop(path, None) + + def register_request_handler( + self, + path: str, + response_generator: Callable[..., Optional[bytes]], + allow: Optional[int] = None, + allowed_list: Optional[list[bytes]] = None, + auto_compress: bool | int = True, + ) -> None: + self.register_calls.append(path) + handlers[path] = response_generator + + def set_link_established_callback(self, callback: Callable[[Any], None]) -> None: + self.link_callback = callback + service = LXMFService.__new__(LXMFService) service._loop = loop service.auth_token = None @@ -286,11 +312,12 @@ async def test_link_request_dispatches_routes() -> None: service._link_keepalive_interval = 0 service._active_links = {} service._link_keepalive_tasks = {} + service.link_destination = DummyDestination() async def handler() -> dict: return {"status": "ok"} - service._routes["PING"] = (handler, None, None) + service.add_route("PING", handler) class DummyLink: def __init__(self) -> None: @@ -302,17 +329,15 @@ def set_link_closed_callback(self, callback): link = DummyLink() service._link_established(link) - request_handler = getattr(link, "request_handler", None) - assert callable(request_handler) + assert link.link_id in service._active_links - response_future = loop.create_future() + request_handler = handlers.get("/commands/PING") + assert callable(request_handler) - def respond(payload: bytes) -> None: - if not response_future.done(): - response_future.set_result(payload) + def _invoke_handler() -> Optional[bytes]: + return request_handler("/commands/PING", b"", object()) - request_handler("/commands/PING", b"", respond) - payload = await asyncio.wait_for(response_future, 1) + payload = await loop.run_in_executor(None, _invoke_handler) assert msgpack_from_bytes(payload) == {"status": "ok"} diff --git a/tests/test_service_extra.py b/tests/test_service_extra.py index c9d8c73..5bed97a 100644 --- a/tests/test_service_extra.py +++ b/tests/test_service_extra.py @@ -4,6 +4,8 @@ from types import SimpleNamespace from unittest.mock import Mock +from typing import Callable + import pytest from reticulum_openapi import service as service_module @@ -162,11 +164,13 @@ class Destination: IN = "in" SINGLE = "single" OUT = "out" + ALLOW_ALL = "allow_all" def __init__(self, *args, **kwargs): self.hash = b"h" self.accepts_links_called = [] self.link_callback = None + self.request_handlers: dict[str, Callable[..., bytes]] = {} destinations.append(self) def announce(self): @@ -178,6 +182,19 @@ def accepts_links(self, flag): def set_link_established_callback(self, callback): self.link_callback = callback + def register_request_handler( + self, + path, + response_generator, + allow=None, + allowed_list=None, + auto_compress=True, + ): + self.request_handlers[path] = response_generator + + def deregister_request_handler(self, path): + self.request_handlers.pop(path, None) + class Link: KEEPALIVE = 0.1 @@ -214,8 +231,10 @@ class FakeLXMF: ) svc = service_module.LXMFService() assert isinstance(svc.router, FakeLXMRouter) + assert "/commands/GetSchema" in destinations[-1].request_handlers svc.add_route("PING", lambda: None) assert "PING" in svc._routes + assert "/commands/PING" in destinations[-1].request_handlers assert svc.link_destination is destinations[-1] assert destinations[-1].accepts_links_called == [True] @@ -319,3 +338,31 @@ async def handler(payload): svc._lxmf_delivery_callback(message) await asyncio.sleep(0) svc._send_lxmf.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_registered_link_request_dispatches(): + async def handler(payload): + return {"ok": True, "echo": payload} + + svc = service_module.LXMFService.__new__(service_module.LXMFService) + svc._routes = {"CMD": (handler, None, None)} + svc.max_payload_size = 1024 + svc._loop = asyncio.get_running_loop() + svc.auth_token = None + + payload = dataclass_to_msgpack({"value": 1}) + loop = asyncio.get_running_loop() + response = await loop.run_in_executor( + None, + lambda: svc._handle_registered_link_request( + "/commands/CMD", + payload, + request_id=object(), + ), + ) + + assert response is not None + decoded = service_module.msgpack_from_bytes(response) + assert decoded["ok"] is True + assert decoded["echo"]["value"] == 1