From 6c28faacf5cd3e3794ba907a28fb9038cd122282 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 9 Mar 2026 12:56:11 -0700 Subject: [PATCH 1/3] Address PR #252 review findings: error handling, types, tests Error handling: - Add CorruptMirrorError/ValueError handling to all mirror endpoints - Block unknown models on platform/community keys (fail-closed) - Add OSError handling to create_mirror_endpoint - Make cleanup_expired_mirrors resilient to per-mirror failures - Narrow scheduler cleanup catch to expected exception types - Add field_validator to RefreshMirrorRequest.community_ids Type design: - Make MirrorInfo a frozen dataclass with tuple community_ids - Move is_safe_identifier to src/core/validation.py (shared utility) - Add non-negativity validation to MODEL_PRICING at import time - Expand SecureFormatter key patterns for Anthropic/OpenAI keys Code quality: - Replace deprecated asyncio.get_event_loop() with get_running_loop() - Fix ContextVar comment accuracy (request lifecycle, not per-task) - Use get_active_mirror() instead of _active_mirror_id.get() - Fix docstring inaccuracies (caching, asyncio, model names) Tests: - Add active_mirror_context tests (set/reset, exception safety) - Add MirrorInfo invariant tests (empty ids, invalid id, immutability) - Add serialization round-trip test - Add TTL clamping test - Add run_sync_now invalid sync_type test - Update cost protection test for fail-closed behavior Closes #256 --- src/api/routers/community.py | 13 ++- src/api/routers/mirrors.py | 77 +++++++++++++-- src/api/scheduler.py | 6 +- src/core/logging.py | 20 ++-- src/core/services/litellm_llm.py | 5 +- src/core/services/llm.py | 2 +- src/core/validation.py | 13 +++ src/knowledge/db.py | 10 +- src/knowledge/mirror.py | 61 +++++++----- src/metrics/cost.py | 11 +++ tests/test_api/test_cost_protection.py | 9 +- tests/test_knowledge/test_mirror.py | 124 ++++++++++++++++++++++++- 12 files changed, 297 insertions(+), 54 deletions(-) create mode 100644 src/core/validation.py diff --git a/src/api/routers/community.py b/src/api/routers/community.py index a13ca55..3e16e5a 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -620,12 +620,19 @@ def _check_model_cost(model: str, key_source: str) -> None: pricing = MODEL_PRICING.get(model) if pricing is None: - logger.warning( - "Model %s not in pricing table; allowing without cost check. " + logger.error( + "Model %s not in pricing table; blocking on platform/community key. " "Add this model to MODEL_PRICING in src/metrics/cost.py.", model, ) - return + raise HTTPException( + status_code=403, + detail=( + f"Model '{model}' is not in the approved pricing list and cannot be used " + "with platform or community keys. To use this model, provide your own " + "API key via the X-OpenRouter-Key header." + ), + ) input_rate = pricing.input_per_1m if input_rate >= COST_BLOCK_THRESHOLD: diff --git a/src/api/routers/mirrors.py b/src/api/routers/mirrors.py index fef751f..be11ad5 100644 --- a/src/api/routers/mirrors.py +++ b/src/api/routers/mirrors.py @@ -15,14 +15,15 @@ from pydantic import BaseModel, Field, field_validator from src.api.security import RequireAuth +from src.core.validation import is_safe_identifier from src.knowledge.db import active_mirror_context from src.knowledge.mirror import ( + CorruptMirrorError, MirrorInfo, create_mirror, delete_mirror, get_mirror, get_mirror_db_path, - is_safe_identifier, list_mirrors, refresh_mirror, ) @@ -93,6 +94,16 @@ class RefreshMirrorRequest(BaseModel): default=None, description="Specific communities to refresh, or null for all" ) + @field_validator("community_ids") + @classmethod + def validate_community_ids(cls, v: list[str] | None) -> list[str] | None: + if v is None: + return v + for cid in v: + if not is_safe_identifier(cid): + raise ValueError(f"Invalid community ID: {cid!r}") + return list(dict.fromkeys(v)) + SyncType = Literal["github", "papers", "docstrings", "mailman", "faq", "beps", "all"] @@ -135,6 +146,12 @@ async def create_mirror_endpoint( ) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except OSError as e: + logger.error("Failed to create mirror: %s", e, exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create mirror due to a server filesystem error.", + ) from e logger.info( "Mirror created: %s (communities=%s, owner=%s)", @@ -161,7 +178,18 @@ async def get_mirror_endpoint( _auth: RequireAuth, ) -> MirrorResponse: """Get metadata for a specific mirror.""" - info = get_mirror(mirror_id) + try: + info = get_mirror(mirror_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.", + ) if not info: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -205,6 +233,11 @@ async def refresh_mirror_endpoint( info = refresh_mirror(mirror_id, community_ids=body.community_ids) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.", + ) logger.info("Mirror refreshed via API: %s", mirror_id) return MirrorResponse.from_info(info) @@ -222,7 +255,18 @@ async def sync_mirror_endpoint( databases instead of production. Supports sync types: github, papers, docstrings, mailman, faq, beps, or all. """ - info = get_mirror(mirror_id) + try: + info = get_mirror(mirror_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.", + ) if not info: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, @@ -235,8 +279,9 @@ async def sync_mirror_endpoint( ) # Run sync in a thread with the mirror context explicitly copied. - # asyncio.to_thread copies ContextVars on Python 3.12+ but not 3.11, - # so we capture the context and run within it for compatibility. + # We use run_in_executor with an explicit context copy instead of + # asyncio.to_thread because to_thread only copies ContextVars + # automatically on Python 3.12+. from src.api.scheduler import run_sync_now ctx = contextvars.copy_context() @@ -246,7 +291,8 @@ def _run_sync_in_mirror() -> dict[str, int]: return run_sync_now(body.sync_type) try: - results = await asyncio.get_event_loop().run_in_executor(None, ctx.run, _run_sync_in_mirror) + loop = asyncio.get_running_loop() + results = await loop.run_in_executor(None, ctx.run, _run_sync_in_mirror) total = sum(results.values()) return MirrorSyncResponse( message=f"Sync completed: {total} items synced into mirror {mirror_id}", @@ -254,6 +300,12 @@ def _run_sync_in_mirror() -> dict[str, int]: ) except ValueError as e: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except OSError as e: + logger.error("Mirror sync I/O error for %s: %s", mirror_id, e, exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Sync failed due to a filesystem error: {e}", + ) from e except Exception as e: logger.error("Mirror sync failed for %s: %s", mirror_id, e, exc_info=True) raise HTTPException( @@ -274,7 +326,18 @@ async def download_mirror_db( """ from fastapi.responses import FileResponse - info = get_mirror(mirror_id) + try: + info = get_mirror(mirror_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.", + ) if not info: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, diff --git a/src/api/scheduler.py b/src/api/scheduler.py index 65b4c5e..555f278 100644 --- a/src/api/scheduler.py +++ b/src/api/scheduler.py @@ -292,14 +292,14 @@ def _run_beps_sync_for_community(community_id: str) -> bool: def _cleanup_mirrors() -> None: """Remove expired ephemeral database mirrors.""" global _mirror_cleanup_failures - try: - from src.knowledge.mirror import cleanup_expired_mirrors + from src.knowledge.mirror import CorruptMirrorError, cleanup_expired_mirrors + try: deleted = cleanup_expired_mirrors() if deleted: logger.info("Mirror cleanup: removed %d expired mirrors", deleted) _mirror_cleanup_failures = 0 - except Exception: + except (OSError, ValueError, CorruptMirrorError): _mirror_cleanup_failures += 1 logger.error( "Mirror cleanup failed (consecutive failures: %d)", diff --git a/src/core/logging.py b/src/core/logging.py index 46e24eb..25114f2 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -1,7 +1,8 @@ """Secure logging configuration with API key redaction. -Provides a custom log formatter that automatically redacts OpenRouter API keys -from log messages to prevent credential exposure in centralized logging systems. +Provides a custom log formatter that automatically redacts API keys +(OpenRouter, Anthropic, OpenAI) from log messages to prevent credential +exposure in centralized logging systems. Supports both text and JSON-structured logging formats. """ @@ -17,12 +18,19 @@ class SecureFormatter(logging.Formatter): """Custom log formatter that redacts API keys from log messages. - Automatically detects and redacts OpenRouter API keys in the format - sk-or-v1-[64 hex chars] to prevent accidental credential exposure. + Automatically detects and redacts API keys from OpenRouter, Anthropic, + and OpenAI to prevent accidental credential exposure. """ - # Pattern to match OpenRouter API keys: sk-or-v1-[64 hex chars] - API_KEY_PATTERN = re.compile(r"sk-or-v1-[0-9a-f]{64}", re.IGNORECASE) + # Patterns for API keys from various providers. + # IGNORECASE as defense-in-depth; real keys use lowercase hex. + API_KEY_PATTERN = re.compile( + r"sk-or-v1-[0-9a-f]{64}" # OpenRouter: sk-or-v1-[64 hex chars] + r"|sk-ant-[a-zA-Z0-9_-]{80,}" # Anthropic: sk-ant-... + r"|sk-proj-[a-zA-Z0-9_-]{40,}" # OpenAI project keys: sk-proj-... + r"|sk-[a-zA-Z0-9]{48,}", # Generic OpenAI keys: sk-... + re.IGNORECASE, + ) def format(self, record: logging.LogRecord) -> str: """Format log record and redact any API keys. diff --git a/src/core/services/litellm_llm.py b/src/core/services/litellm_llm.py index 291292a..1f43b47 100644 --- a/src/core/services/litellm_llm.py +++ b/src/core/services/litellm_llm.py @@ -69,8 +69,9 @@ def create_openrouter_llm( provider: Specific provider to use (e.g., "Cerebras", "DeepInfra/FP8"). Ignored for Anthropic models, which always use "Anthropic" provider. user_id: User identifier for cache optimization (sticky routing) - enable_caching: Enable prompt caching. If None (default), enabled for all models. - OpenRouter/LiteLLM gracefully handles models that don't support caching. + enable_caching: Enable prompt caching. If None (default), caching is requested + for all models. Models that do not support caching will ignore the + cache_control markers without error. Returns: LLM instance configured for OpenRouter diff --git a/src/core/services/llm.py b/src/core/services/llm.py index cf40697..bce2e70 100644 --- a/src/core/services/llm.py +++ b/src/core/services/llm.py @@ -143,7 +143,7 @@ def get_model( model_name: Model name. Supports: - OpenRouter format: 'creator/model' (e.g., 'openai/gpt-oss-120b', 'qwen/qwen3-235b') - Direct OpenAI: 'gpt-4o', 'gpt-4o-mini', etc. - - Direct Anthropic: 'claude-3-5-sonnet', etc. + - Direct Anthropic: 'claude-3.5-sonnet', etc. If not provided, uses settings.default_model. api_key: Optional API key override (for BYOK). temperature: Model temperature. If not provided, uses settings.llm_temperature. diff --git a/src/core/validation.py b/src/core/validation.py new file mode 100644 index 0000000..778ee86 --- /dev/null +++ b/src/core/validation.py @@ -0,0 +1,13 @@ +"""Shared input validation utilities. + +Provides common validation functions used across modules for +preventing path traversal and ensuring safe identifiers. +""" + + +def is_safe_identifier(value: str) -> bool: + """Check if a string is a safe identifier (alphanumeric, hyphens, underscores). + + Used for both mirror IDs and community IDs to prevent path traversal. + """ + return bool(value) and value.replace("-", "").replace("_", "").isalnum() diff --git a/src/knowledge/db.py b/src/knowledge/db.py index a505699..5c9166d 100644 --- a/src/knowledge/db.py +++ b/src/knowledge/db.py @@ -21,14 +21,16 @@ from pathlib import Path from src.cli.config import get_data_dir -from src.knowledge.mirror import _validate_mirror_id, is_safe_identifier +from src.core.validation import is_safe_identifier +from src.knowledge.mirror import _validate_mirror_id logger = logging.getLogger(__name__) # ContextVar for transparent mirror routing. When set, get_db_path() returns # the mirror's database path instead of the production path. -# ContextVar is safe for concurrent async tasks; each task gets its own copy, -# so mirror routing in one request does not affect other requests. +# Safe for concurrent requests because the middleware sets and resets the +# value around each request's lifecycle. Nested async calls within the +# same request inherit the value. _active_mirror_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( "_active_mirror_id", default=None ) @@ -448,7 +450,7 @@ def get_db_path(project: str = "hed") -> Path: "Use only alphanumeric characters, hyphens, and underscores." ) - mirror_id = _active_mirror_id.get() + mirror_id = get_active_mirror() if mirror_id: # mirror_id was already validated when set via set_active_mirror() return get_data_dir() / "mirrors" / mirror_id / f"{project}.db" diff --git a/src/knowledge/mirror.py b/src/knowledge/mirror.py index 71e48d9..87d2c10 100644 --- a/src/knowledge/mirror.py +++ b/src/knowledge/mirror.py @@ -4,6 +4,8 @@ developers to read, write, and re-sync without affecting production data. Each mirror gets its own directory under data/mirrors/{mirror_id}/ containing copies of the relevant community SQLite databases. + +Default TTL is 48 hours; maximum is 168 hours (7 days). """ import json @@ -15,9 +17,13 @@ from uuid import uuid4 from src.cli.config import get_data_dir +from src.core.validation import is_safe_identifier logger = logging.getLogger(__name__) +# Re-export for backward compatibility +__all__ = ["is_safe_identifier"] + MIRRORS_DIR_NAME = "mirrors" METADATA_FILE = "_metadata.json" @@ -28,20 +34,16 @@ MAX_MIRRORS_PER_USER = 2 -def is_safe_identifier(value: str) -> bool: - """Check if a string is a safe identifier (alphanumeric, hyphens, underscores). +@dataclass(frozen=True) +class MirrorInfo: + """Metadata for an ephemeral database mirror. - Used for both mirror IDs and community IDs to prevent path traversal. + Frozen dataclass: all fields are immutable after construction. + Default TTL is 48 hours; maximum is 168 hours (7 days). """ - return bool(value) and value.replace("-", "").replace("_", "").isalnum() - - -@dataclass -class MirrorInfo: - """Metadata for an ephemeral database mirror.""" mirror_id: str - community_ids: list[str] + community_ids: tuple[str, ...] created_at: datetime expires_at: datetime owner_id: str | None = None @@ -67,7 +69,7 @@ def to_dict(self) -> dict: """ return { "mirror_id": self.mirror_id, - "community_ids": self.community_ids, + "community_ids": list(self.community_ids), "created_at": self.created_at.isoformat(), "expires_at": self.expires_at.isoformat(), "owner_id": self.owner_id, @@ -75,15 +77,16 @@ def to_dict(self) -> dict: } @classmethod - def from_dict(cls, data: dict) -> "MirrorInfo": + def from_dict(cls, data: dict, size_bytes: int = 0) -> "MirrorInfo": """Deserialize from dictionary.""" return cls( mirror_id=data["mirror_id"], - community_ids=data["community_ids"], + community_ids=tuple(data["community_ids"]), created_at=datetime.fromisoformat(data["created_at"]), expires_at=datetime.fromisoformat(data["expires_at"]), owner_id=data.get("owner_id"), label=data.get("label"), + size_bytes=size_bytes, ) @@ -165,8 +168,7 @@ def _read_metadata(mirror_id: str) -> MirrorInfo | None: return None try: data = json.loads(path.read_text()) - info = MirrorInfo.from_dict(data) - info.size_bytes = _calculate_mirror_size(mirror_id) + info = MirrorInfo.from_dict(data, size_bytes=_calculate_mirror_size(mirror_id)) return info except (json.JSONDecodeError, KeyError, UnicodeDecodeError, ValueError) as e: logger.error( @@ -276,7 +278,7 @@ def create_mirror( now = datetime.now(UTC) info = MirrorInfo( mirror_id=mirror_id, - community_ids=copied_communities, + community_ids=tuple(copied_communities), created_at=now, expires_at=now + timedelta(hours=ttl_hours), owner_id=owner_id, @@ -401,17 +403,32 @@ def refresh_mirror( f"No production databases found for communities: {targets}. Nothing was refreshed." ) - # Update size in metadata - info.size_bytes = _calculate_mirror_size(mirror_id) - return info + # Return a new MirrorInfo with updated size (frozen dataclass) + return MirrorInfo( + mirror_id=info.mirror_id, + community_ids=info.community_ids, + created_at=info.created_at, + expires_at=info.expires_at, + owner_id=info.owner_id, + label=info.label, + size_bytes=_calculate_mirror_size(mirror_id), + ) def cleanup_expired_mirrors() -> int: - """Delete all expired mirrors. Returns the count of mirrors deleted.""" + """Delete all expired mirrors. Returns the count of mirrors deleted. + + Continues past individual deletion failures so one stuck mirror + does not block cleanup of the rest. + """ deleted = 0 for info in list_mirrors(): - if info.is_expired() and delete_mirror(info.mirror_id): - deleted += 1 + if info.is_expired(): + try: + if delete_mirror(info.mirror_id): + deleted += 1 + except OSError: + logger.error("Failed to delete expired mirror %s", info.mirror_id, exc_info=True) if deleted: logger.info("Cleaned up %d expired mirrors", deleted) return deleted diff --git a/src/metrics/cost.py b/src/metrics/cost.py index a31b290..de88fd5 100644 --- a/src/metrics/cost.py +++ b/src/metrics/cost.py @@ -3,6 +3,9 @@ Model pricing table with per-token costs (USD per million tokens). Pricing is from OpenRouter; models added incrementally so individual prices may have different verification dates. + +Also defines cost protection thresholds for blocking expensive models +on platform/community keys (not BYOK). """ import logging @@ -80,6 +83,14 @@ class ModelRate(NamedTuple): "meta-llama/llama-3.3-70b-instruct": ModelRate(0.10, 0.32), } +# Validate all pricing entries at import time to catch typos +for _model_name, _rate in MODEL_PRICING.items(): + if _rate.input_per_1m < 0 or _rate.output_per_1m < 0: + raise ValueError( + f"Negative rate for model {_model_name}: " + f"input={_rate.input_per_1m}, output={_rate.output_per_1m}" + ) + # Fallback rate for models not in the pricing table _FALLBACK_RATE = ModelRate(input_per_1m=1.00, output_per_1m=3.00) diff --git a/tests/test_api/test_cost_protection.py b/tests/test_api/test_cost_protection.py index 8bf518e..a2b0030 100644 --- a/tests/test_api/test_cost_protection.py +++ b/tests/test_api/test_cost_protection.py @@ -54,9 +54,12 @@ def test_expensive_model_allowed_with_byok(self) -> None: _check_model_cost(expensive_models[0], "byok") - def test_unknown_model_allowed_on_platform_key(self) -> None: - """Unknown models (not in pricing table) should be allowed.""" - _check_model_cost("unknown/made-up-model-xyz", "platform") + def test_unknown_model_blocked_on_platform_key(self) -> None: + """Unknown models (not in pricing table) should be blocked on platform keys.""" + with pytest.raises(HTTPException) as exc_info: + _check_model_cost("unknown/made-up-model-xyz", "platform") + assert exc_info.value.status_code == 403 + assert "not in the approved pricing list" in exc_info.value.detail def test_unknown_model_allowed_with_byok(self) -> None: """BYOK users with unknown models should also be allowed.""" diff --git a/tests/test_knowledge/test_mirror.py b/tests/test_knowledge/test_mirror.py index 332be8a..4e711e7 100644 --- a/tests/test_knowledge/test_mirror.py +++ b/tests/test_knowledge/test_mirror.py @@ -3,11 +3,14 @@ Tests cover: - Mirror CRUD lifecycle (create, get, list, delete) - ContextVar-based DB routing (get_db_path returns mirror path when set) +- active_mirror_context context manager (set/reset, exception safety) +- MirrorInfo invariants (frozen dataclass, validation, serialization) - Mirror refresh (re-copy from production) -- TTL expiration and cleanup +- TTL expiration, clamping, and cleanup - Resource limits (max mirrors, per-user limits) - Path traversal prevention - Corrupt metadata resilience +- run_sync_now input validation """ import json @@ -234,7 +237,7 @@ def test_mirror_expired(self): """Mirror with past expiration is expired.""" info = MirrorInfo( mirror_id="test", - community_ids=["testcommunity"], + community_ids=("testcommunity",), created_at=datetime.now(UTC), expires_at=datetime.now(UTC) - timedelta(hours=1), ) @@ -405,5 +408,120 @@ class TestCreateMirrorCleanup: def test_create_mirror_partial_communities(self): """Creating with mix of valid and invalid communities only copies valid ones.""" info = create_mirror(community_ids=["testcommunity", "nonexistent"]) - assert info.community_ids == ["testcommunity"] + assert info.community_ids == ("testcommunity",) assert "nonexistent" not in info.community_ids + + +class TestActiveMirrorContext: + """Tests for the active_mirror_context context manager.""" + + def test_context_manager_sets_and_resets(self): + """Context manager sets mirror ID and resets it after the block.""" + from src.knowledge.db import active_mirror_context + + assert get_active_mirror() is None + with active_mirror_context("abc123"): + assert get_active_mirror() == "abc123" + assert get_active_mirror() is None + + def test_context_manager_resets_on_exception(self): + """Context manager resets mirror ID even if an exception occurs.""" + from src.knowledge.db import active_mirror_context + + assert get_active_mirror() is None + with pytest.raises(RuntimeError), active_mirror_context("abc123"): + assert get_active_mirror() == "abc123" + raise RuntimeError("test error") + assert get_active_mirror() is None + + def test_context_manager_validates_mirror_id(self): + """Context manager rejects invalid mirror IDs.""" + from src.knowledge.db import active_mirror_context + + with pytest.raises(ValueError), active_mirror_context("../invalid"): + pass + + +class TestMirrorInfoInvariants: + """Tests for MirrorInfo construction validation.""" + + def test_empty_community_ids_rejected(self): + """Constructing MirrorInfo with empty community_ids raises ValueError.""" + with pytest.raises(ValueError, match="community_ids must not be empty"): + MirrorInfo( + mirror_id="valid", + community_ids=(), + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + + def test_invalid_mirror_id_at_construction(self): + """Constructing MirrorInfo with path-traversal mirror_id raises ValueError.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + MirrorInfo( + mirror_id="../etc", + community_ids=("test",), + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + + def test_frozen_dataclass_is_immutable(self): + """MirrorInfo fields cannot be modified after construction.""" + info = MirrorInfo( + mirror_id="test", + community_ids=("testcommunity",), + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + with pytest.raises(AttributeError): + info.mirror_id = "changed" # type: ignore[misc] + with pytest.raises(AttributeError): + info.size_bytes = 999 # type: ignore[misc] + + def test_serialization_roundtrip(self): + """MirrorInfo to_dict/from_dict preserves all fields.""" + now = datetime.now(UTC) + original = MirrorInfo( + mirror_id="test123", + community_ids=("hed", "bids"), + created_at=now, + expires_at=now + timedelta(hours=24), + owner_id="user1", + label="my mirror", + ) + data = original.to_dict() + restored = MirrorInfo.from_dict(data) + + assert restored.mirror_id == original.mirror_id + assert restored.community_ids == original.community_ids + assert restored.created_at == original.created_at + assert restored.expires_at == original.expires_at + assert restored.owner_id == original.owner_id + assert restored.label == original.label + # size_bytes is excluded from serialization + assert "size_bytes" not in data + assert restored.size_bytes == 0 + + +class TestTTLClamping: + """Tests for TTL clamping in create_mirror.""" + + def test_ttl_clamped_to_max(self): + """create_mirror clamps TTL to MAX_TTL_HOURS.""" + from src.knowledge.mirror import MAX_TTL_HOURS + + info = create_mirror(community_ids=["testcommunity"], ttl_hours=999) + actual_ttl = (info.expires_at - info.created_at).total_seconds() / 3600 + assert actual_ttl <= MAX_TTL_HOURS + assert actual_ttl == pytest.approx(MAX_TTL_HOURS, abs=0.01) + + +class TestRunSyncNowValidation: + """Tests for run_sync_now input validation.""" + + def test_invalid_sync_type_raises_valueerror(self): + """run_sync_now raises ValueError for unknown sync types.""" + from src.api.scheduler import run_sync_now + + with pytest.raises(ValueError, match="Unknown sync_type"): + run_sync_now("invalid_type") From bf62b434e5634faa3f6ba6c52064b7f98af921ae Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 9 Mar 2026 13:20:28 -0700 Subject: [PATCH 2/3] Use generic redaction placeholder, remove misleading __all__ - Change redaction string from "sk-or-v1-***[redacted]" to "***[key-redacted]" since the pattern now covers multiple providers - Remove __all__ from mirror.py since no callers use wildcard imports from that module (is_safe_identifier now lives in core.validation) --- src/core/logging.py | 6 +++--- src/knowledge/mirror.py | 3 --- tests/test_core/test_logging.py | 16 ++++++++-------- 3 files changed, 11 insertions(+), 14 deletions(-) diff --git a/src/core/logging.py b/src/core/logging.py index 25114f2..f4d7401 100644 --- a/src/core/logging.py +++ b/src/core/logging.py @@ -63,7 +63,7 @@ def format(self, record: logging.LogRecord) -> str: if len(formatted) > 100_000: # 100KB limit formatted = formatted[:100_000] + "... [truncated for safety]" - formatted = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", formatted) + formatted = self.API_KEY_PATTERN.sub("***[key-redacted]", formatted) except re.error as e: # Regex pattern is broken - this is a code bug print(f"CRITICAL: Redaction regex failed: {e}", file=sys.stderr) @@ -145,7 +145,7 @@ def format(self, record: logging.LogRecord) -> str: json_str = json.dumps(log_entry, default=str) # Redact API keys from the JSON string - json_str = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", json_str) + json_str = self.API_KEY_PATTERN.sub("***[key-redacted]", json_str) return json_str @@ -162,7 +162,7 @@ def format(self, record: logging.LogRecord) -> str: "original_message": safe_msg, } fallback_json = json.dumps(error_entry) - return self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", fallback_json) + return self.API_KEY_PATTERN.sub("***[key-redacted]", fallback_json) except Exception as e: # Unexpected errors - surface to stderr and re-raise print(f"CRITICAL: Unexpected error in SecureJSONFormatter: {e}", file=sys.stderr) diff --git a/src/knowledge/mirror.py b/src/knowledge/mirror.py index 87d2c10..8f654f3 100644 --- a/src/knowledge/mirror.py +++ b/src/knowledge/mirror.py @@ -21,9 +21,6 @@ logger = logging.getLogger(__name__) -# Re-export for backward compatibility -__all__ = ["is_safe_identifier"] - MIRRORS_DIR_NAME = "mirrors" METADATA_FILE = "_metadata.json" diff --git a/tests/test_core/test_logging.py b/tests/test_core/test_logging.py index a6cf249..114ea73 100644 --- a/tests/test_core/test_logging.py +++ b/tests/test_core/test_logging.py @@ -28,7 +28,7 @@ def test_redacts_api_key_in_message(self) -> None: ) formatted = formatter.format(record) - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted assert "aaaa" not in formatted # Original key should not appear def test_redacts_multiple_api_keys(self) -> None: @@ -45,7 +45,7 @@ def test_redacts_multiple_api_keys(self) -> None: ) formatted = formatter.format(record) - assert formatted.count("sk-or-v1-***[redacted]") == 2 + assert formatted.count("***[key-redacted]") == 2 assert "aaaa" not in formatted assert "bbbb" not in formatted @@ -65,7 +65,7 @@ def test_preserves_non_key_content(self) -> None: formatted = formatter.format(record) assert "Starting request with key" in formatted assert "for user john" in formatted - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted def test_handles_message_without_keys(self) -> None: """Should not modify messages without API keys.""" @@ -98,7 +98,7 @@ def test_redacts_case_insensitive(self) -> None: ) formatted = formatter.format(record) - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted def test_uses_standard_format_fields(self) -> None: """Should support standard logging format fields.""" @@ -116,7 +116,7 @@ def test_uses_standard_format_fields(self) -> None: formatted = formatter.format(record) assert "WARNING" in formatted assert "my.logger" in formatted - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted def test_preserves_partial_matches(self) -> None: """Should not redact strings that partially match key pattern.""" @@ -205,7 +205,7 @@ def test_redacts_api_keys_in_exception_tracebacks(self) -> None: formatted = formatter.format(record) # API key in exception message should be redacted - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted assert "aaaa" not in formatted def test_concurrent_logging_thread_safety(self) -> None: @@ -253,7 +253,7 @@ def log_with_key(index: int) -> None: assert len(errors) == 0, f"Concurrent logging errors: {errors}" assert len(formatted_logs) == 10 for _, api_key, log in formatted_logs: - assert "sk-or-v1-***[redacted]" in log + assert "***[key-redacted]" in log # Original key should not appear assert api_key not in log @@ -402,7 +402,7 @@ def test_redacts_api_keys_in_json(self) -> None: formatted = formatter.format(record) log_data = json.loads(formatted) - assert "sk-or-v1-***[redacted]" in log_data["message"] + assert "***[key-redacted]" in log_data["message"] assert "aaaa" not in log_data["message"] def test_configures_json_logging(self) -> None: From 84f8ec8d649aa6f7caae8de48655852a23119e45 Mon Sep 17 00:00:00 2001 From: Seyed Yahya Shirazi Date: Mon, 9 Mar 2026 13:21:40 -0700 Subject: [PATCH 3/3] Add ValueError catch to delete endpoint, validate community IDs - Add missing ValueError handling in delete_mirror_endpoint for consistency with all other mirror endpoints - Add community ID validation in MirrorInfo.__post_init__ so corrupt metadata with path-traversal community IDs is caught at load time - Document CorruptMirrorError in refresh_mirror docstring --- src/api/routers/mirrors.py | 5 +++++ src/knowledge/mirror.py | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/src/api/routers/mirrors.py b/src/api/routers/mirrors.py index be11ad5..5c33adb 100644 --- a/src/api/routers/mirrors.py +++ b/src/api/routers/mirrors.py @@ -210,6 +210,11 @@ async def delete_mirror_endpoint( status_code=status.HTTP_404_NOT_FOUND, detail=f"Mirror '{mirror_id}' not found", ) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) except OSError as e: logger.error("Failed to delete mirror %s: %s", mirror_id, e, exc_info=True) raise HTTPException( diff --git a/src/knowledge/mirror.py b/src/knowledge/mirror.py index 8f654f3..478f536 100644 --- a/src/knowledge/mirror.py +++ b/src/knowledge/mirror.py @@ -53,6 +53,9 @@ def __post_init__(self) -> None: raise ValueError(f"Invalid mirror ID: {self.mirror_id!r}") if not self.community_ids: raise ValueError("community_ids must not be empty") + for cid in self.community_ids: + if not is_safe_identifier(cid): + raise ValueError(f"Invalid community ID in community_ids: {cid!r}") def is_expired(self) -> bool: """Check if the mirror has passed its expiration time.""" @@ -374,6 +377,7 @@ def refresh_mirror( Raises: ValueError: If mirror not found or expired. + CorruptMirrorError: If mirror metadata is corrupt. """ info = get_mirror(mirror_id) if not info: