diff --git a/reticulum_openapi/client.py b/reticulum_openapi/client.py index 9e3d1f5..b94a59d 100644 --- a/reticulum_openapi/client.py +++ b/reticulum_openapi/client.py @@ -27,6 +27,51 @@ logger = logging.getLogger(__name__) +def _prepare_config_directory(config_path: Optional[str]) -> Optional[str]: + """Normalise a Reticulum configuration path to an existing directory. + + Args: + config_path (Optional[str]): User supplied configuration path. Can + reference either the configuration directory or the ``config`` + file inside that directory. + + Returns: + Optional[str]: Directory path suitable for ``RNS.Reticulum``. When + ``config_path`` is falsy, ``None`` is returned to preserve the + default Reticulum discovery behaviour. + """ + + if not config_path: + return None + + candidate = Path(config_path).expanduser() + + if candidate.exists(): + if candidate.is_file(): + directory = candidate.parent + elif candidate.is_dir(): + directory = candidate + else: + directory = candidate.parent + else: + if candidate.suffix: + directory = candidate.parent + elif candidate.name == "config": + directory = candidate.parent + else: + directory = candidate + + if directory is None or str(directory) == "": + return None + + try: + directory.mkdir(parents=True, exist_ok=True) + except OSError: + pass + + return str(directory) + + class _AnnounceHandler: """Adapter that forwards Reticulum announces into an asyncio queue.""" @@ -63,17 +108,26 @@ def __init__( timeout: float = 10.0, shared_instance_rpc_key: Optional[str] = None, ): - self.reticulum = RNS.Reticulum(config_path) + config_directory = _prepare_config_directory(config_path) + self.reticulum = RNS.Reticulum(config_directory) self._shared_instance_rpc_key: Optional[bytes] = None if shared_instance_rpc_key is not None: key_bytes = self._decode_shared_instance_rpc_key(shared_instance_rpc_key) self.reticulum.rpc_key = key_bytes self._shared_instance_rpc_key = key_bytes - storage_path = storage_path or (RNS.Reticulum.storagepath + "/lxmf_client") - self.router = LXMF.LXMRouter(storagepath=storage_path) + if storage_path: + resolved_storage = Path(storage_path).expanduser() + else: + resolved_storage = Path(RNS.Reticulum.storagepath) / "lxmf_client" + try: + resolved_storage.mkdir(parents=True, exist_ok=True) + except OSError: + pass + self.router = LXMF.LXMRouter(storagepath=str(resolved_storage)) self.router.register_delivery_callback(self._callback) if identity is None: - identity = load_or_create_identity(config_path) + identity_base = config_directory or config_path + identity = load_or_create_identity(identity_base) self.identity = identity self.source_identity = self.router.register_delivery_identity( identity, display_name=display_name, stamp_cost=0 diff --git a/tests/test_client_extra.py b/tests/test_client_extra.py index cd1cc9e..496f370 100644 --- a/tests/test_client_extra.py +++ b/tests/test_client_extra.py @@ -70,6 +70,72 @@ def fake_register(handler): assert register_calls["handler"].aspect_filter == "lxmf" +@pytest.mark.asyncio +async def test_client_normalises_config_file_path(monkeypatch, tmp_path): + config_dir = tmp_path / "reticulum" + config_dir.mkdir() + config_file = config_dir / "config" + config_file.write_text("interfaces {}\n") + + captured = {} + + class DummyReticulum: + storagepath = str(tmp_path / "existing_storage") + + def __init__(self, config_path=None): + captured["config_path"] = config_path + + class DummyIdentity: + def __init__(self): + self.hash = b"h" + self.announce = Mock() + + class DummyRouter: + def __init__(self, storagepath=None): + captured["storage_path"] = storagepath + + def register_delivery_callback(self, cb): + self.cb = cb + + def register_delivery_identity(self, ident, display_name=None, stamp_cost=0): + return ident + + class DummyDestination: + OUT = object() + SINGLE = object() + + def __init__(self, *a, **k): + pass + + def fake_register(handler): + captured["handler"] = handler + + monkeypatch.setattr(client_module.RNS, "Reticulum", DummyReticulum) + monkeypatch.setattr(client_module.RNS, "Identity", DummyIdentity) + monkeypatch.setattr(client_module.RNS, "Destination", DummyDestination) + monkeypatch.setattr( + client_module.RNS.Transport, "register_announce_handler", fake_register + ) + monkeypatch.setattr(client_module.LXMF, "LXMRouter", DummyRouter) + + def fake_load(path, *args, **kwargs): + captured["identity_path"] = path + return DummyIdentity() + + monkeypatch.setattr(client_module, "load_or_create_identity", fake_load) + + storage_dir = tmp_path / "custom_storage" + client_module.LXMFClient( + config_path=str(config_file), + storage_path=str(storage_dir), + ) + + assert captured["config_path"] == str(config_dir) + assert captured["identity_path"] == str(config_dir) + assert captured["storage_path"] == str(storage_dir) + assert storage_dir.is_dir() + + @pytest.mark.asyncio async def test_send_command_bytes_payload(monkeypatch): loop = asyncio.get_running_loop()