diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py
index 3e700f3f..6cbd71c9 100644
--- a/src/adcp/server/__init__.py
+++ b/src/adcp/server/__init__.py
@@ -139,10 +139,12 @@ async def get_products(params, context=None):
 )
 from adcp.server.sponsored_intelligence import SponsoredIntelligenceHandler
 from adcp.server.tenant_router import (
+    CallableSubdomainTenantRouter,
     InMemorySubdomainTenantRouter,
     SubdomainTenantMiddleware,
     SubdomainTenantRouter,
     Tenant,
+    TenantResolver,
     current_tenant,
 )
 from adcp.server.test_controller import (
@@ -204,10 +206,12 @@ async def get_products(params, context=None):
     "IdempotencyStore",
     "MemoryBackend",
     # Subdomain tenant routing
+    "CallableSubdomainTenantRouter",
     "InMemorySubdomainTenantRouter",
     "SubdomainTenantMiddleware",
     "SubdomainTenantRouter",
     "Tenant",
+    "TenantResolver",
     "current_tenant",
     # Multi-agent discovery manifest (/.well-known/adcp-agents.json)
     "DISCOVERY_PATH",
diff --git a/src/adcp/server/tenant_router.py b/src/adcp/server/tenant_router.py
index beb8cce6..a01d8c78 100644
--- a/src/adcp/server/tenant_router.py
+++ b/src/adcp/server/tenant_router.py
@@ -15,8 +15,13 @@
 * :class:`SubdomainTenantRouter` — runtime-checkable Protocol with one
   async ``resolve(host: str) -> Tenant | None`` method.
 * :class:`InMemorySubdomainTenantRouter` — reference impl for
-  dev/test backed by a static ``host → Tenant`` dict. Production
-  adopters back the Protocol with their tenant table.
+  dev/test backed by a static ``host → Tenant`` dict.
+* :class:`CallableSubdomainTenantRouter` — adopter-callable router
+  for DB-backed lookups. Adopter writes a single sync-or-async
+  callable mapping a normalized host to a :class:`Tenant`; the
+  framework owns host normalization. Optional bounded TTL cache
+  for hot-path lookups. **Recommended for production multi-tenant
+  deployments** — replaces ~25 LOC of adopter glue with ~5.
 * :class:`SubdomainTenantMiddleware` — Starlette ASGI middleware that
   calls the router, stashes the result in a
   :class:`contextvars.ContextVar`, and ``404`` s on unknown hosts.
@@ -84,6 +89,10 @@ def build_context(meta):
 from __future__ import annotations
 
 import contextvars
+import inspect
+import time
+from collections import OrderedDict
+from collections.abc import Awaitable, Callable
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
 
@@ -160,6 +169,189 @@ async def resolve(self, host: str) -> Tenant | None:
         return self._tenants.get(_normalize_host(host))
 
 
+# Type alias for adopter-supplied lookup callables. Either sync (returns
+# Tenant | None) or async (returns Awaitable[Tenant | None]) is accepted —
+# CallableSubdomainTenantRouter awaits at call time. Receives the
+# already-normalized (lower-cased + port-stripped) host so adopters don't
+# reimplement the parser.
+TenantResolver = Callable[[str], "Tenant | None | Awaitable[Tenant | None]"]
+
+
+class CallableSubdomainTenantRouter:
+    """Adopter-callable :class:`SubdomainTenantRouter` for DB-backed lookups.
+
+    The adopter passes a single callable mapping a normalized host to a
+    :class:`Tenant` (or ``None`` for 404). The framework owns host
+    normalization (lower-case + port-strip), so adopters write only the
+    lookup itself — typically a single SQL query against their tenant
+    table.
+
+    The callable may be sync or async; the router awaits at call time.
+
+    Example::
+
+        from sqlalchemy import select
+        from adcp.server import CallableSubdomainTenantRouter, Tenant
+
+        async def lookup(host: str) -> Tenant | None:
+            subdomain = host.split(".", 1)[0]  # 'acme.example.com' -> 'acme'
+            async with my_db.session() as s:
+                row = await s.scalar(
+                    select(TenantRow).filter_by(subdomain=subdomain, is_active=True)
+                )
+            return Tenant(id=row.tenant_id, display_name=row.name) if row else None
+
+        router = CallableSubdomainTenantRouter(lookup)
+
+    Optional bounded TTL cache absorbs hot-path lookups without adopters
+    reimplementing — useful when the resolver hits a remote DB on every
+    request. Defaults to **no caching** (``cache_size=0``); adopters opt
+    in with explicit bounds:
+
+    ::
+
+        router = CallableSubdomainTenantRouter(
+            lookup,
+            cache_size=1024,  # bounded LRU; never grows beyond this
+            cache_ttl_seconds=60.0,  # expire entries after 60s
+        )
+
+    Cache bounds are mandatory when caching is enabled — there is no
+    "cache forever, unbounded size" mode by design. Tenants come and go
+    (suspension, deactivation); long-lived caches without TTL hand
+    adopters a stale-cache footgun. The ``cache_ttl_seconds`` ceiling is
+    the explicit knob.
+
+    **Negative-cache + tenant onboarding race.** When caching is enabled,
+    ``None`` results are cached too (to absorb probing for unknown hosts).
+    This creates a race on tenant creation: if a probe for
+    ``acme.example.com`` hits at T=0 (host doesn't exist yet) and the
+    tenant is provisioned at T=1, the cached ``None`` causes 404s for up
+    to ``cache_ttl_seconds`` afterward. Call ``invalidate(host)`` from
+    your tenant *creation* path — not only deactivation — to clear the
+    negative entry immediately::
+
+        # on tenant create / re-activate
+        router.invalidate("acme.example.com")
+
+    Memory profile
+    --------------
+    Without caching: zero state held by the router. Each ``resolve()``
+    call awaits the adopter callable directly.
+
+    With caching: bounded by ``cache_size`` entries. Maximum memory is
+    ``cache_size × (sizeof(host_str) + sizeof(your_Tenant) + 16)``
+    where ``sizeof(your_Tenant)`` depends on what you store in
+    :attr:`Tenant.ext` — the router can't predict it. The cache never
+    grows beyond ``cache_size`` entries regardless of payload size.
+    """
+
+    def __init__(
+        self,
+        resolver: TenantResolver,
+        *,
+        cache_size: int = 0,
+        cache_ttl_seconds: float = 0.0,
+    ) -> None:
+        """Construct the router.
+
+        :param resolver: Callable taking a normalized host string and
+            returning ``Tenant | None`` (sync or async). Receives
+            already-normalized hosts — lower-cased with any
+            ``:port`` suffix stripped.
+        :param cache_size: Maximum number of cached lookups. ``0``
+            disables caching entirely (the adopter callable is awaited
+            on every request). Must be ``>= 0``.
+        :param cache_ttl_seconds: Per-entry TTL in seconds. Must be a
+            finite value ``> 0`` when ``cache_size > 0``. There is no "cache
+            forever" mode — see the class docstring for rationale.
+        :raises ValueError: If ``cache_size < 0``, or if ``cache_size > 0``
+            and ``cache_ttl_seconds`` is not a finite value ``> 0``.
+        """
+        if cache_size < 0:
+            raise ValueError(f"cache_size must be >= 0, got {cache_size}")
+        # A chained comparison rejects NaN (every comparison is False) and
+        # +inf as well as <= 0; either would make entries never expire.
+        if cache_size > 0 and not 0 < cache_ttl_seconds < float("inf"):
+            raise ValueError(
+                "cache_ttl_seconds must be a finite value > 0 when cache_size > 0; "
+                "explicit TTL prevents stale-tenant footguns. Pass a "
+                "value like 60.0 (one-minute cache) to opt in."
+            )
+        self._resolver = resolver
+        self._cache_size = cache_size
+        self._cache_ttl = cache_ttl_seconds
+        # OrderedDict gives us LRU-by-move-to-end for free; bounded by
+        # popitem(last=False) when over cache_size. Each entry is
+        # (Tenant | None, expires_at_monotonic). Negative results are
+        # cached too so DOS-style probing doesn't bypass the cache.
+        self._cache: OrderedDict[str, tuple[Tenant | None, float]] = OrderedDict()
+
+    async def resolve(self, host: str) -> Tenant | None:
+        normalized = _normalize_host(host)
+
+        if self._cache_size > 0:
+            cached = self._cache_get(normalized)
+            if cached is not _CACHE_MISS:
+                return cached  # type: ignore[return-value]
+
+        result = self._resolver(normalized)
+        if inspect.isawaitable(result):
+            result = await result
+
+        if self._cache_size > 0:
+            self._cache_put(normalized, result)
+
+        return result
+
+    # ----- cache internals (request-path; keep tight) ---------------------
+
+    def _cache_get(self, host: str) -> Tenant | None | object:
+        entry = self._cache.get(host)
+        if entry is None:
+            return _CACHE_MISS
+        tenant, expires_at = entry
+        if time.monotonic() > expires_at:
+            # Expired — drop and miss. Don't await a fresh resolve here;
+            # the caller does that. Avoids holding the entry through the
+            # adopter callable's network round-trip.
+            self._cache.pop(host, None)
+            return _CACHE_MISS
+        # LRU touch
+        self._cache.move_to_end(host)
+        return tenant
+
+    def _cache_put(self, host: str, tenant: Tenant | None) -> None:
+        expires_at = time.monotonic() + self._cache_ttl
+        self._cache[host] = (tenant, expires_at)
+        self._cache.move_to_end(host)
+        # Bound size — evict oldest until under limit.
+        while len(self._cache) > self._cache_size:
+            self._cache.popitem(last=False)
+
+    def invalidate(self, host: str | None = None) -> None:
+        """Drop a cached entry (or all entries when ``host`` is ``None``).
+
+        Adopters call this from their tenant-creation, -deactivation, and
+        -modification flows to evict stale entries before the TTL fires.
+        Creation matters because negative results (``None``) are cached —
+        see the class docstring for details. Safe to call even when caching
+        is disabled (no-op).
+
+        :param host: Specific host to evict (raw or normalized — the
+            method normalizes internally). ``None`` clears the entire
+            cache.
+        """
+        if host is None:
+            self._cache.clear()
+            return
+        self._cache.pop(_normalize_host(host), None)
+
+
+# Sentinel for cache miss vs. cached-None (negative result)
+_CACHE_MISS: object = object()
+
+
 # Module-level contextvar — request-scoped via the ASGI middleware's
 # per-call `set()`. ASGI guarantees per-task context isolation, so
 # concurrent requests on the same process see only their own tenant.
@@ -303,9 +495,11 @@ async def _send_404(send: Send, *, reason: str) -> None:
 
 
 __all__ = [
+    "CallableSubdomainTenantRouter",
     "InMemorySubdomainTenantRouter",
     "SubdomainTenantMiddleware",
     "SubdomainTenantRouter",
     "Tenant",
+    "TenantResolver",
     "current_tenant",
 ]
diff --git a/tests/test_subdomain_tenant_router.py b/tests/test_subdomain_tenant_router.py
index c0d1d250..f2439f28 100644
--- a/tests/test_subdomain_tenant_router.py
+++ b/tests/test_subdomain_tenant_router.py
@@ -28,6 +28,7 @@
 from starlette.testclient import TestClient  # noqa: E402
 
 from adcp.server import (  # noqa: E402
+    CallableSubdomainTenantRouter,
     InMemorySubdomainTenantRouter,
     SubdomainTenantMiddleware,
     SubdomainTenantRouter,
@@ -114,6 +115,257 @@ def test_in_memory_router_satisfies_protocol() -> None:
     assert isinstance(router, SubdomainTenantRouter)
 
 
+# ----- CallableSubdomainTenantRouter ---------------------------------------
+
+
+def test_callable_router_passes_normalized_host_to_resolver() -> None:
+    """Adopter callable receives the lower-cased + port-stripped host."""
+    received: list[str] = []
+
+    async def lookup(host: str) -> Tenant | None:
+        received.append(host)
+        return Tenant(id="acme", display_name="Acme") if host == "acme.example.com" else None
+
+    router = CallableSubdomainTenantRouter(lookup)
+    result = asyncio.run(router.resolve("ACME.Example.COM:8080"))
+
+    assert received == ["acme.example.com"]
+    assert result is not None
+    assert result.id == "acme"
+
+
+def test_callable_router_supports_sync_callables() -> None:
+    """Adopter may pass a plain sync function — no `async def` required."""
+
+    def lookup(host: str) -> Tenant | None:
+        return Tenant(id="acme") if host == "acme.example.com" else None
+
+    router = CallableSubdomainTenantRouter(lookup)
+    result = asyncio.run(router.resolve("acme.example.com"))
+    assert result is not None
+    assert result.id == "acme"
+
+
+def test_callable_router_returns_none_for_unknown_host() -> None:
+    async def lookup(host: str) -> Tenant | None:
+        return None
+
+    router = CallableSubdomainTenantRouter(lookup)
+    assert asyncio.run(router.resolve("unknown.example.com")) is None
+
+
+def test_callable_router_satisfies_protocol() -> None:
+    async def lookup(host: str) -> Tenant | None:
+        return None
+
+    router = CallableSubdomainTenantRouter(lookup)
+    assert isinstance(router, SubdomainTenantRouter)
+
+
+def test_callable_router_default_no_caching() -> None:
+    """Default ``cache_size=0`` — every resolve calls the resolver."""
+    call_count = 0
+
+    async def lookup(host: str) -> Tenant | None:
+        nonlocal call_count
+        call_count += 1
+        return Tenant(id="acme")
+
+    router = CallableSubdomainTenantRouter(lookup)
+    asyncio.run(router.resolve("acme.example.com"))
+    asyncio.run(router.resolve("acme.example.com"))
+    asyncio.run(router.resolve("acme.example.com"))
+    assert call_count == 3
+
+
+def test_callable_router_caching_dedupes_within_ttl() -> None:
+    """Within ``cache_ttl_seconds`` the resolver is only called once per host."""
+    call_count = 0
+
+    async def lookup(host: str) -> Tenant | None:
+        nonlocal call_count
+        call_count += 1
+        return Tenant(id="acme")
+
+    router = CallableSubdomainTenantRouter(lookup, cache_size=8, cache_ttl_seconds=60.0)
+    asyncio.run(router.resolve("acme.example.com"))
+    asyncio.run(router.resolve("acme.example.com"))
+    asyncio.run(router.resolve("acme.example.com"))
+    assert call_count == 1
+
+
+def test_callable_router_caching_negative_results_too() -> None:
+    """Cached ``None`` is honored — DOS-style probing for unknown hosts
+    doesn't bypass the cache."""
+    call_count = 0
+
+    async def lookup(host: str) -> Tenant | None:
+        nonlocal call_count
+        call_count += 1
+        return None
+
+    router = CallableSubdomainTenantRouter(lookup, cache_size=8, cache_ttl_seconds=60.0)
+    asyncio.run(router.resolve("attacker.example.com"))
+    asyncio.run(router.resolve("attacker.example.com"))
+    assert call_count == 1
+
+
+def test_callable_router_caching_evicts_after_ttl(monkeypatch) -> None:
+    """Entries older than ``cache_ttl_seconds`` re-query the resolver."""
+    call_count = 0
+
+    async def lookup(host: str) -> Tenant | None:
+        nonlocal call_count
+        call_count += 1
+        return Tenant(id="acme")
+
+    fake_clock = [1000.0]
+
+    def fake_monotonic() -> float:
+        return fake_clock[0]
+
+    monkeypatch.setattr("adcp.server.tenant_router.time.monotonic", fake_monotonic)
+
+    router = CallableSubdomainTenantRouter(lookup, cache_size=8, cache_ttl_seconds=60.0)
+    asyncio.run(router.resolve("acme.example.com"))
+    fake_clock[0] += 30  # within TTL
+    asyncio.run(router.resolve("acme.example.com"))
+    assert call_count == 1
+
+    fake_clock[0] += 31  # past TTL (61s total)
+    asyncio.run(router.resolve("acme.example.com"))
+    assert call_count == 2
+
+
+def test_callable_router_cache_bounded_by_size() -> None:
+    """``cache_size`` is a hard ceiling — oldest entries evicted on overflow."""
+
+    def lookup(host: str) -> Tenant | None:
+        return Tenant(id=host.split(".")[0])
+
+    router = CallableSubdomainTenantRouter(lookup, cache_size=2, cache_ttl_seconds=60.0)
+    asyncio.run(router.resolve("a.example.com"))
+    asyncio.run(router.resolve("b.example.com"))
+    asyncio.run(router.resolve("c.example.com"))  # evicts 'a'
+    # Cache still bounded — never grows beyond cache_size
+    assert len(router._cache) == 2  # noqa: SLF001 — testing bound directly
+    assert "a.example.com" not in router._cache
+    assert "b.example.com" in router._cache
+    assert "c.example.com" in router._cache
+
+
+def test_callable_router_invalidate_specific_host() -> None:
+    """``invalidate(host)`` drops a cached entry; next call re-queries."""
+    call_count = 0
+
+    async def lookup(host: str) -> Tenant | None:
+        nonlocal call_count
+        call_count += 1
+        return Tenant(id="acme")
+
+    router = CallableSubdomainTenantRouter(lookup, cache_size=8, cache_ttl_seconds=60.0)
+    asyncio.run(router.resolve("acme.example.com"))
+    asyncio.run(router.resolve("acme.example.com"))
+    assert call_count == 1
+
+    router.invalidate("ACME.Example.COM:8080")  # any-case + port form works
+    asyncio.run(router.resolve("acme.example.com"))
+    assert call_count == 2
+
+
+def test_callable_router_invalidate_all() -> None:
+    """``invalidate()`` with no arg clears every entry."""
+
+    def lookup(host: str) -> Tenant | None:
+        return Tenant(id=host.split(".")[0])
+
+    router = CallableSubdomainTenantRouter(lookup, cache_size=8, cache_ttl_seconds=60.0)
+    asyncio.run(router.resolve("a.example.com"))
+    asyncio.run(router.resolve("b.example.com"))
+    assert len(router._cache) == 2  # noqa: SLF001
+
+    router.invalidate()
+    assert len(router._cache) == 0  # noqa: SLF001
+
+
+def test_callable_router_invalidate_no_op_without_caching() -> None:
+    """Invalidating a router with caching disabled is a safe no-op."""
+
+    async def lookup(host: str) -> Tenant | None:
+        return None
+
+    router = CallableSubdomainTenantRouter(lookup)  # cache_size=0
+    router.invalidate("anything.example.com")
+    router.invalidate()
+    # No exception — cache stays empty
+    assert len(router._cache) == 0  # noqa: SLF001
+
+
+def test_callable_router_case_and_port_variants_share_cache_entry() -> None:
+    """Case variants and port suffix all normalize to the same cache key.
+
+    ``Acme.localhost:3001`` and ``acme.localhost`` must hit the resolver
+    exactly once — a second probe after the cache is warm must not call
+    the resolver again, regardless of how the host was presented.
+    """
+    call_count = 0
+
+    async def lookup(host: str) -> Tenant | None:
+        nonlocal call_count
+        call_count += 1
+        return Tenant(id="acme")
+
+    router = CallableSubdomainTenantRouter(lookup, cache_size=8, cache_ttl_seconds=60.0)
+    asyncio.run(router.resolve("Acme.localhost:3001"))
+    asyncio.run(router.resolve("acme.localhost"))
+    assert call_count == 1
+
+
+def test_callable_router_rejects_cache_without_ttl() -> None:
+    """Cache requires explicit TTL — no 'cache forever' mode."""
+    with pytest.raises(ValueError, match="TTL"):
+        CallableSubdomainTenantRouter(
+            lambda host: None,
+            cache_size=8,
+            # cache_ttl_seconds defaults to 0 — invalid when caching enabled
+        )
+
+
+def test_callable_router_rejects_non_finite_ttl() -> None:
+    """``inf`` / ``nan`` TTLs would never expire — rejected like ``0``."""
+    for bad_ttl in (float("inf"), float("nan")):
+        with pytest.raises(ValueError, match="TTL"):
+            CallableSubdomainTenantRouter(
+                lambda host: None,
+                cache_size=8,
+                cache_ttl_seconds=bad_ttl,
+            )
+
+
+def test_callable_router_rejects_negative_cache_size() -> None:
+    with pytest.raises(ValueError, match="cache_size"):
+        CallableSubdomainTenantRouter(lambda host: None, cache_size=-1)
+
+
+def test_callable_router_through_middleware() -> None:
+    """End-to-end: callable router behind the standard middleware."""
+
+    async def lookup(host: str) -> Tenant | None:
+        if host == "acme.example.com":
+            return Tenant(id="acme", display_name="Acme")
+        return None
+
+    router = CallableSubdomainTenantRouter(lookup)
+    client = TestClient(_build_app(router))
+
+    resp = client.get("/whoami", headers={"Host": "acme.example.com"})
+    assert resp.status_code == 200
+    assert resp.json() == {"tenant_id": "acme", "display_name": "Acme"}
+
+    resp = client.get("/whoami", headers={"Host": "unknown.example.com"})
+    assert resp.status_code == 404
+
+
 # ----- middleware: known host happy path ------------------------------
 
 