diff --git a/frontend/osa-chat-widget.js b/frontend/osa-chat-widget.js index ccedd77..1df9f59 100644 --- a/frontend/osa-chat-widget.js +++ b/frontend/osa-chat-widget.js @@ -68,15 +68,16 @@ } // Default model options for settings dropdown + // Last updated: 2026-03 const DEFAULT_MODELS = [ + { value: 'anthropic/claude-sonnet-4.6', label: 'Claude Sonnet 4.6' }, + { value: 'anthropic/claude-haiku-4.5', label: 'Claude Haiku 4.5' }, { value: 'openai/gpt-5.2-chat', label: 'GPT-5.2 Chat' }, { value: 'openai/gpt-5-mini', label: 'GPT-5 Mini' }, - { value: 'anthropic/claude-haiku-4.5', label: 'Claude Haiku 4.5' }, - { value: 'anthropic/claude-sonnet-4.5', label: 'Claude Sonnet 4.5' }, { value: 'google/gemini-3-flash-preview', label: 'Gemini 3 Flash' }, { value: 'google/gemini-3-pro-preview', label: 'Gemini 3 Pro' }, - { value: 'moonshotai/kimi-k2-0905', label: 'Kimi K2' }, - { value: 'qwen/qwen3-235b-a22b-2507', label: 'Qwen3 235B' } + { value: 'deepseek/deepseek-v3.2', label: 'DeepSeek V3.2' }, + { value: 'qwen/qwen3.5-397b-a17b', label: 'Qwen 3.5 397B' } ]; // Helper to get human-readable label for a model diff --git a/src/api/main.py b/src/api/main.py index 4019e76..3d486c6 100644 --- a/src/api/main.py +++ b/src/api/main.py @@ -19,15 +19,23 @@ create_community_router, metrics_public_router, metrics_router, + mirrors_router, sync_router, ) from src.api.routers.health import router as health_router from src.api.routers.widget_test import router as widget_test_router from src.api.scheduler import start_scheduler, stop_scheduler from src.assistants import discover_assistants, registry +from src.core.logging import configure_secure_logging +from src.knowledge.db import reset_active_mirror, set_active_mirror +from src.knowledge.mirror import CorruptMirrorError, get_mirror from src.metrics.db import init_metrics_db from src.metrics.middleware import MetricsMiddleware +# Must run before any getLogger() calls to ensure handlers with +# SecureFormatter are installed on the root 
logger first. +configure_secure_logging() + logger = logging.getLogger(__name__) # Discover assistants at module load time to populate registry @@ -181,6 +189,53 @@ def create_app() -> FastAPI: # Metrics middleware - captures request timing and logs to metrics DB app.add_middleware(MetricsMiddleware) + # Mirror routing middleware - sets ContextVar for transparent DB routing + @app.middleware("http") + async def mirror_routing_middleware(request: Any, call_next: Any) -> Any: + """Route database access to mirror when X-Mirror-ID header is present.""" + from fastapi.responses import JSONResponse + + mirror_id = request.headers.get("x-mirror-id") + if not mirror_id: + return await call_next(request) + + try: + info = get_mirror(mirror_id) + except ValueError: + # Invalid mirror ID format (path traversal attempt, etc.) + return JSONResponse( + status_code=400, + content={"detail": f"Invalid mirror ID format: '{mirror_id}'"}, + ) + except CorruptMirrorError: + return JSONResponse( + status_code=500, + content={"detail": f"Mirror '{mirror_id}' has corrupt metadata"}, + ) + except OSError: + logger.error("Filesystem error reading mirror %s", mirror_id, exc_info=True) + return JSONResponse( + status_code=500, + content={"detail": f"Failed to read mirror '{mirror_id}' metadata"}, + ) + + if not info: + return JSONResponse( + status_code=404, + content={"detail": f"Mirror '{mirror_id}' not found"}, + ) + if info.is_expired(): + return JSONResponse( + status_code=410, + content={"detail": f"Mirror '{mirror_id}' has expired"}, + ) + + token = set_active_mirror(mirror_id) + try: + return await call_next(request) + finally: + reset_active_mirror(token) + # Register routes register_routes(app) @@ -207,6 +262,9 @@ def register_routes(app: FastAPI) -> None: # Sync router (not community-specific) app.include_router(sync_router) + # Mirror management router + app.include_router(mirrors_router) + # Metrics routers (admin + public) app.include_router(metrics_router) 
app.include_router(metrics_public_router) diff --git a/src/api/routers/__init__.py b/src/api/routers/__init__.py index 2614a26..df2c667 100644 --- a/src/api/routers/__init__.py +++ b/src/api/routers/__init__.py @@ -4,6 +4,7 @@ from src.api.routers.community import create_community_router from src.api.routers.metrics import router as metrics_router from src.api.routers.metrics_public import router as metrics_public_router +from src.api.routers.mirrors import router as mirrors_router from src.api.routers.sync import router as sync_router __all__ = [ @@ -11,5 +12,6 @@ "create_community_router", "metrics_public_router", "metrics_router", + "mirrors_router", "sync_router", ] diff --git a/src/api/routers/community.py b/src/api/routers/community.py index e83add3..3e16e5a 100644 --- a/src/api/routers/community.py +++ b/src/api/routers/community.py @@ -34,7 +34,7 @@ from src.assistants.registry import AssistantInfo from src.core.config.community import WidgetConfig from src.core.services.litellm_llm import create_openrouter_llm -from src.metrics.cost import estimate_cost +from src.metrics.cost import COST_BLOCK_THRESHOLD, COST_WARN_THRESHOLD, MODEL_PRICING, estimate_cost from src.metrics.db import ( RequestLogEntry, extract_token_usage, @@ -602,6 +602,59 @@ def _select_model( return (default_model, default_provider) +def _check_model_cost(model: str, key_source: str) -> None: + """Check if a model's cost exceeds platform thresholds. + + Only enforced when using platform or community API keys (not BYOK). + Logs a warning for moderately expensive models and blocks very expensive ones. + + Args: + model: Model identifier (e.g., "openai/gpt-4o"). + key_source: One of "byok", "community", or "platform". + + Raises: + HTTPException(403): If model cost exceeds the block threshold. + """ + if key_source == "byok": + return + + pricing = MODEL_PRICING.get(model) + if pricing is None: + logger.error( + "Model %s not in pricing table; blocking on platform/community key. 
" + "Add this model to MODEL_PRICING in src/metrics/cost.py.", + model, + ) + raise HTTPException( + status_code=403, + detail=( + f"Model '{model}' is not in the approved pricing list and cannot be used " + "with platform or community keys. To use this model, provide your own " + "API key via the X-OpenRouter-Key header." + ), + ) + input_rate = pricing.input_per_1m + + if input_rate >= COST_BLOCK_THRESHOLD: + raise HTTPException( + status_code=403, + detail=( + f"Model '{model}' costs ${input_rate:.2f}/1M input tokens, " + f"which exceeds the platform limit of ${COST_BLOCK_THRESHOLD:.2f}/1M. " + "To use expensive models, provide your own API key via the " + "X-OpenRouter-Key header. Get a key at: https://openrouter.ai/keys" + ), + ) + + if input_rate >= COST_WARN_THRESHOLD: + logger.warning( + "Model %s costs $%.2f/1M input tokens (warn threshold: $%.2f)", + model, + input_rate, + COST_WARN_THRESHOLD, + ) + + def _derive_user_id(token: str) -> str: """Derive a stable user ID from API token for cache optimization. @@ -717,6 +770,10 @@ def create_community_assistant( selected_model, selected_provider = _select_model( community_info, requested_model, has_byok=bool(byok) ) + + # Block expensive models on platform/community keys + _check_model_cost(selected_model, key_source) + logger.debug( "Using model %s", selected_model, diff --git a/src/api/routers/mirrors.py b/src/api/routers/mirrors.py new file mode 100644 index 0000000..5c33adb --- /dev/null +++ b/src/api/routers/mirrors.py @@ -0,0 +1,378 @@ +"""API endpoints for ephemeral database mirror management. + +Mirrors allow developers to create short-lived copies of community knowledge +databases for development and testing. Authenticated users may pass an +X-User-ID header; users with an owner identifier are subject to per-user +mirror limits in addition to the global mirror cap. 
+""" + +import asyncio +import contextvars +import logging +from typing import Annotated, Any, Literal + +from fastapi import APIRouter, Header, HTTPException, status +from pydantic import BaseModel, Field, field_validator + +from src.api.security import RequireAuth +from src.core.validation import is_safe_identifier +from src.knowledge.db import active_mirror_context +from src.knowledge.mirror import ( + CorruptMirrorError, + MirrorInfo, + create_mirror, + delete_mirror, + get_mirror, + get_mirror_db_path, + list_mirrors, + refresh_mirror, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/mirrors", tags=["Mirrors"]) + + +# --------------------------------------------------------------------------- +# Request/Response models +# --------------------------------------------------------------------------- + + +class CreateMirrorRequest(BaseModel): + """Request body for creating a new mirror.""" + + community_ids: list[str] = Field( + ..., min_length=1, description="Community IDs to include in the mirror" + ) + ttl_hours: int = Field( + default=48, ge=1, le=168, description="Hours until the mirror expires (1-168)" + ) + label: str | None = Field( + default=None, max_length=128, description="Human-readable label for the mirror" + ) + + @field_validator("community_ids") + @classmethod + def validate_community_ids(cls, v: list[str]) -> list[str]: + for cid in v: + if not is_safe_identifier(cid): + raise ValueError(f"Invalid community ID: {cid!r}") + # Deduplicate while preserving order + return list(dict.fromkeys(v)) + + +class MirrorResponse(BaseModel): + """Mirror metadata in API responses.""" + + mirror_id: str + community_ids: list[str] + created_at: str + expires_at: str + owner_id: str | None = None + label: str | None = None + size_bytes: int = 0 + expired: bool = False + + @classmethod + def from_info(cls, info: MirrorInfo) -> "MirrorResponse": + return cls( + mirror_id=info.mirror_id, + community_ids=info.community_ids, + 
created_at=info.created_at.isoformat(), + expires_at=info.expires_at.isoformat(), + owner_id=info.owner_id, + label=info.label, + size_bytes=info.size_bytes, + expired=info.is_expired(), + ) + + +class RefreshMirrorRequest(BaseModel): + """Request body for refreshing a mirror.""" + + community_ids: list[str] | None = Field( + default=None, description="Specific communities to refresh, or null for all" + ) + + @field_validator("community_ids") + @classmethod + def validate_community_ids(cls, v: list[str] | None) -> list[str] | None: + if v is None: + return v + for cid in v: + if not is_safe_identifier(cid): + raise ValueError(f"Invalid community ID: {cid!r}") + return list(dict.fromkeys(v)) + + +SyncType = Literal["github", "papers", "docstrings", "mailman", "faq", "beps", "all"] + + +class MirrorSyncRequest(BaseModel): + """Request body for syncing data into a mirror.""" + + sync_type: SyncType = Field( + default="all", + description="Sync type: github, papers, docstrings, mailman, faq, beps, or all", + ) + + +class MirrorSyncResponse(BaseModel): + """Response from a mirror sync operation.""" + + message: str + items_synced: dict[str, int] = Field(default_factory=dict) + + +@router.post("", status_code=status.HTTP_201_CREATED, response_model=MirrorResponse) +async def create_mirror_endpoint( + body: CreateMirrorRequest, + _auth: RequireAuth, + x_user_id: Annotated[str | None, Header()] = None, +) -> MirrorResponse: + """Create a new ephemeral database mirror. + + Copies the specified community databases into a new mirror directory. + Users with an X-User-ID header are subject to per-user mirror limits. 
+ """ + user_id = x_user_id or None + + try: + info = create_mirror( + community_ids=body.community_ids, + ttl_hours=body.ttl_hours, + label=body.label, + owner_id=user_id, + ) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except OSError as e: + logger.error("Failed to create mirror: %s", e, exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Failed to create mirror due to a server filesystem error.", + ) from e + + logger.info( + "Mirror created: %s (communities=%s, owner=%s)", + info.mirror_id, + info.community_ids, + user_id, + ) + return MirrorResponse.from_info(info) + + +@router.get("", response_model=list[MirrorResponse]) +async def list_mirrors_endpoint( + _auth: RequireAuth, +) -> list[MirrorResponse]: + """List all active (non-expired) mirrors.""" + mirrors = list_mirrors() + active = [m for m in mirrors if not m.is_expired()] + return [MirrorResponse.from_info(m) for m in active] + + +@router.get("/{mirror_id}", response_model=MirrorResponse) +async def get_mirror_endpoint( + mirror_id: str, + _auth: RequireAuth, +) -> MirrorResponse: + """Get metadata for a specific mirror.""" + try: + info = get_mirror(mirror_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. 
Delete and recreate it.", + ) + if not info: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Mirror '{mirror_id}' not found", + ) + return MirrorResponse.from_info(info) + + +@router.delete("/{mirror_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_mirror_endpoint( + mirror_id: str, + _auth: RequireAuth, +) -> None: + """Delete a mirror and all its databases.""" + try: + if not delete_mirror(mirror_id): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Mirror '{mirror_id}' not found", + ) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) + except OSError as e: + logger.error("Failed to delete mirror %s: %s", mirror_id, e, exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Failed to delete mirror '{mirror_id}'. The mirror exists but could not be removed.", + ) from e + logger.info("Mirror deleted via API: %s", mirror_id) + + +@router.post("/{mirror_id}/refresh", response_model=MirrorResponse) +async def refresh_mirror_endpoint( + mirror_id: str, + body: RefreshMirrorRequest, + _auth: RequireAuth, +) -> MirrorResponse: + """Re-copy production databases into an existing mirror. + + Resets the mirror's data to match current production state. + """ + try: + info = refresh_mirror(mirror_id, community_ids=body.community_ids) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. 
Delete and recreate it.", + ) + + logger.info("Mirror refreshed via API: %s", mirror_id) + return MirrorResponse.from_info(info) + + +@router.post("/{mirror_id}/sync", response_model=MirrorSyncResponse) +async def sync_mirror_endpoint( + mirror_id: str, + body: MirrorSyncRequest, + _auth: RequireAuth, +) -> MirrorSyncResponse: + """Run sync pipeline against a mirror's databases. + + Sets the mirror context so all sync operations write to the mirror's + databases instead of production. Supports sync types: github, papers, + docstrings, mailman, faq, beps, or all. + """ + try: + info = get_mirror(mirror_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.", + ) + if not info: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Mirror '{mirror_id}' not found", + ) + if info.is_expired(): + raise HTTPException( + status_code=status.HTTP_410_GONE, + detail=f"Mirror '{mirror_id}' has expired", + ) + + # Run sync in a thread with the mirror context explicitly copied. + # We use run_in_executor with an explicit context copy instead of + # asyncio.to_thread because to_thread only copies ContextVars + # automatically on Python 3.12+. 
+ from src.api.scheduler import run_sync_now + + ctx = contextvars.copy_context() + + def _run_sync_in_mirror() -> dict[str, int]: + with active_mirror_context(mirror_id): + return run_sync_now(body.sync_type) + + try: + loop = asyncio.get_running_loop() + results = await loop.run_in_executor(None, ctx.run, _run_sync_in_mirror) + total = sum(results.values()) + return MirrorSyncResponse( + message=f"Sync completed: {total} items synced into mirror {mirror_id}", + items_synced=results, + ) + except ValueError as e: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e + except OSError as e: + logger.error("Mirror sync I/O error for %s: %s", mirror_id, e, exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Sync failed due to a filesystem error: {e}", + ) from e + except Exception as e: + logger.error("Mirror sync failed for %s: %s", mirror_id, e, exc_info=True) + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Sync operation failed. Check server logs for details.", + ) from e + + +@router.get("/{mirror_id}/download/{community_id}") +async def download_mirror_db( + mirror_id: str, + community_id: str, + _auth: RequireAuth, +) -> Any: + """Download a community database file from a mirror. + + Returns the SQLite file for local development use. + """ + from fastapi.responses import FileResponse + + try: + info = get_mirror(mirror_id) + except ValueError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Invalid mirror ID format: '{mirror_id}'", + ) + except CorruptMirrorError: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail=f"Mirror '{mirror_id}' has corrupt metadata. 
Delete and recreate it.", + ) + if not info: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Mirror '{mirror_id}' not found", + ) + if info.is_expired(): + raise HTTPException( + status_code=status.HTTP_410_GONE, + detail=f"Mirror '{mirror_id}' has expired", + ) + if not is_safe_identifier(community_id): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid community ID format", + ) + if community_id not in info.community_ids: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Community '{community_id}' not found in mirror '{mirror_id}'", + ) + + db_path = get_mirror_db_path(mirror_id, community_id) + if not db_path.exists(): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Database file not found for community '{community_id}'", + ) + + return FileResponse( + path=str(db_path), + media_type="application/x-sqlite3", + filename=f"{community_id}.db", + ) diff --git a/src/api/scheduler.py b/src/api/scheduler.py index 10879ab..555f278 100644 --- a/src/api/scheduler.py +++ b/src/api/scheduler.py @@ -286,6 +286,28 @@ def _run_beps_sync_for_community(community_id: str) -> bool: # --------------------------------------------------------------------------- +_mirror_cleanup_failures = 0 + + +def _cleanup_mirrors() -> None: + """Remove expired ephemeral database mirrors.""" + global _mirror_cleanup_failures + from src.knowledge.mirror import CorruptMirrorError, cleanup_expired_mirrors + + try: + deleted = cleanup_expired_mirrors() + if deleted: + logger.info("Mirror cleanup: removed %d expired mirrors", deleted) + _mirror_cleanup_failures = 0 + except (OSError, ValueError, CorruptMirrorError): + _mirror_cleanup_failures += 1 + logger.error( + "Mirror cleanup failed (consecutive failures: %d)", + _mirror_cleanup_failures, + exc_info=True, + ) + + def _check_community_budgets() -> None: """Check budget limits for all communities and create alert issues if exceeded.""" 
global _budget_check_failures @@ -560,6 +582,20 @@ def start_scheduler() -> BackgroundScheduler | None: except ValueError as e: logger.error("Failed to schedule budget check: %s", e) + # Mirror cleanup (every hour, removes expired ephemeral database mirrors) + try: + mirror_trigger = CronTrigger(minute="30") # Every hour at :30 + _scheduler.add_job( + _cleanup_mirrors, + trigger=mirror_trigger, + id="mirror_cleanup", + name="Expired Mirror Cleanup", + replace_existing=True, + ) + logger.info("Mirror cleanup scheduled: hourly at :30") + except ValueError as e: + logger.error("Failed to schedule mirror cleanup: %s", e) + # Start the scheduler _scheduler.start() logger.info("Background scheduler started with %d sync jobs", jobs_registered) @@ -607,8 +643,9 @@ def run_sync_now(sync_type: str = "all") -> dict[str, int]: os.environ["GITHUB_TOKEN"] = settings.github_token if sync_type != "all" and sync_type not in _SYNC_TYPE_MAP: - logger.warning("Unknown sync_type requested: %s", sync_type) - return results + raise ValueError( + f"Unknown sync_type: {sync_type!r}. Valid types: {list(_SYNC_TYPE_MAP.keys())} or 'all'" + ) sync_types_to_run = list(_SYNC_TYPE_MAP.keys()) if sync_type == "all" else [sync_type] diff --git a/src/assistants/hed/config.yaml b/src/assistants/hed/config.yaml index e0a9e94..84a41f9 100644 --- a/src/assistants/hed/config.yaml +++ b/src/assistants/hed/config.yaml @@ -403,13 +403,6 @@ documentation: category: advanced description: Guidance on generating and interpreting summaries of HED annotations in datasets to facilitate data analysis. - # === ON-DEMAND: Integration (1 doc) === - - title: HED and EEGLAB - url: https://www.hedtags.org/hed-resources/HedAndEEGLAB.html - source_url: https://raw.githubusercontent.com/hed-standard/hed-resources/main/docs/source/HedAndEEGLAB.html - category: integration - description: Describes how to integrate HED annotations within the EEGLAB environment for EEG data analysis. 
- # === ON-DEMAND: Reference (2 docs) === - title: Documentation summary url: https://www.hedtags.org/hed-resources/DocumentationSummary.html @@ -423,13 +416,6 @@ documentation: category: reference description: A collection of datasets specifically designed for testing HED annotations and validation tools. - # === ON-DEMAND: Examples (1 doc) === - - title: Test cases - url: https://raw.githubusercontent.com/hed-standard/hed-specification/refs/heads/main/tests/javascriptTests.json - source_url: https://raw.githubusercontent.com/hed-standard/hed-specification/refs/heads/main/tests/javascriptTests.json - category: examples - description: Examples of correct and incorrect HED annotations in JSON format for testing validation tools. - # Sync schedule configuration # Each sync type runs on its own cron schedule (UTC) # Only types with both a schedule AND corresponding data config are scheduled diff --git a/src/assistants/mne/config.yaml b/src/assistants/mne/config.yaml index b2aa590..886d677 100644 --- a/src/assistants/mne/config.yaml +++ b/src/assistants/mne/config.yaml @@ -3,7 +3,7 @@ id: mne name: MNE-Python -description: Open-source Python toolkit for exploring, visualizing, and analyzing human neurophysiological data (MEG, EEG, sEEG, ECoG, and NIRS) +description: Open-source Python toolkit for exploring, visualizing, and analyzing human neurophysiological data (MEG, EEG, sEEG, ECoG, NIRS, and eye-tracking) status: available default_model: anthropic/claude-haiku-4.5 default_model_provider: anthropic @@ -40,14 +40,25 @@ budget: # System prompt template with runtime-substituted placeholders system_prompt: | - You are a technical assistant specialized in helping users with MNE-Python, an open-source Python toolkit for exploring, visualizing, and analyzing human neurophysiological data including MEG, EEG, sEEG, ECoG, and NIRS. 
+ You are a technical assistant specialized in helping users with MNE-Python, an open-source Python toolkit for exploring, visualizing, and analyzing human neurophysiological data including MEG, EEG, sEEG, ECoG, NIRS, and eye-tracking. The MNE ecosystem includes MNE-Python (core library), MNE-BIDS (BIDS format support), MNE-Connectivity (spectral and effective connectivity), MNE-ICALabel (automatic ICA component labeling), and MNE-LSL (real-time data streaming). + + ## Supported Data Types and Specialized Modules + + MNE-Python supports these data modalities and has specialized submodules: + - MEG, EEG, sEEG, ECoG, NIRS, eye-tracking + - Eye-tracking: `mne.preprocessing.eyetracking` (unit conversion, calibration, reading data) + - NIRS: `mne.preprocessing.nirs` (optical density, beer-lambert, scalp coupling) + + When users ask about these topics, ALWAYS search the docstring database before answering. + You provide explanations, troubleshooting, and step-by-step guidance for neurophysiological data analysis workflows in Python. Focus on helping users with MNE-Python and MEG/EEG/NIRS analysis. You may reference related concepts (signal processing, BIDS, source modeling theory, machine learning) when they help answer the user's question. Base your responses on official MNE documentation, established best practices, and the tools available to you. - Always attempt to answer the user's question. Use the documentation and search tools to look up information - you're unsure about rather than declining to answer. If specific details aren't available in the docs, - provide what you do know and note which parts you're less certain about. + Always attempt to answer the user's question using the documentation and search tools to verify facts. + When you're unsure about specifics, use the tools to look up information rather than declining to answer. + Before claiming MNE does or does not support a feature, search the docstring database to verify. 
+ If a search returns partial results, present what you found and note what you couldn't verify. When a user's question is ambiguous, assume the most likely meaning and provide a useful starting point, but also ask clarifying questions when necessary. @@ -408,6 +419,13 @@ documentation: category: clinical description: Analysis of stereo-EEG recordings with depth electrodes. + # === ON-DEMAND: Eye-tracking (1 doc) === + - title: Working with eye-tracking data + url: https://mne.tools/stable/auto_tutorials/preprocessing/90_eyetracking_data.html + source_url: https://mne.tools/stable/auto_tutorials/preprocessing/90_eyetracking_data.html + category: preprocessing + description: Processing and analyzing eye-tracking data with MNE-Python. + # Sync schedule configuration # Each sync type runs on its own cron schedule (UTC) # Staggered to avoid concurrent load with other communities diff --git a/src/cli/client.py b/src/cli/client.py index 4701be1..64e86a9 100644 --- a/src/cli/client.py +++ b/src/cli/client.py @@ -47,11 +47,13 @@ def __init__( openrouter_api_key: str | None = None, user_id: str | None = None, timeout: httpx.Timeout = DEFAULT_TIMEOUT, + mirror_id: str | None = None, ) -> None: self.api_url = api_url.rstrip("/") self.openrouter_api_key = openrouter_api_key self._user_id = user_id self.timeout = timeout + self.mirror_id = mirror_id @property def user_id(self) -> str: @@ -61,7 +63,7 @@ def user_id(self) -> str: return self._user_id def _get_headers(self) -> dict[str, str]: - """Build request headers with BYOK key and user ID.""" + """Build request headers with BYOK key, user ID, and mirror ID.""" headers: dict[str, str] = { "Content-Type": "application/json", "User-Agent": "osa-cli", @@ -71,6 +73,8 @@ def _get_headers(self) -> dict[str, str]: headers["X-OpenRouter-Key"] = self.openrouter_api_key # Also send legacy header for servers that haven't updated yet headers["X-OpenRouter-API-Key"] = self.openrouter_api_key + if self.mirror_id: + headers["X-Mirror-ID"] = 
self.mirror_id return headers def _handle_response(self, response: httpx.Response) -> None: @@ -226,3 +230,113 @@ def chat_stream( f"{self.api_url}/{community}/chat", self._chat_payload(message, stream=True, session_id=session_id), ) + + # ------------------------------------------------------------------ + # Mirror management + # ------------------------------------------------------------------ + + def _post(self, path: str, payload: dict[str, Any]) -> Any: + """Send a POST request and return parsed JSON.""" + with httpx.Client(timeout=self.timeout) as client: + response = client.post( + f"{self.api_url}{path}", + headers=self._get_headers(), + json=payload, + ) + self._handle_response(response) + return response.json() + + def _delete(self, path: str) -> None: + """Send a DELETE request.""" + with httpx.Client(timeout=10.0) as client: + response = client.delete( + f"{self.api_url}{path}", + headers=self._get_headers(), + ) + self._handle_response(response) + + def create_mirror( + self, + community_ids: list[str], + ttl_hours: int = 48, + label: str | None = None, + ) -> dict[str, Any]: + """Create a new ephemeral database mirror.""" + payload: dict[str, Any] = { + "community_ids": community_ids, + "ttl_hours": ttl_hours, + } + if label: + payload["label"] = label + return self._post("/mirrors", payload) + + def list_mirrors(self) -> list[dict[str, Any]]: + """List active mirrors.""" + return self._get("/mirrors") + + def get_mirror(self, mirror_id: str) -> dict[str, Any]: + """Get mirror metadata.""" + return self._get(f"/mirrors/{mirror_id}") + + def delete_mirror(self, mirror_id: str) -> None: + """Delete a mirror.""" + self._delete(f"/mirrors/{mirror_id}") + + def refresh_mirror( + self, + mirror_id: str, + community_ids: list[str] | None = None, + ) -> dict[str, Any]: + """Re-copy production databases into a mirror.""" + payload: dict[str, Any] = {} + if community_ids: + payload["community_ids"] = community_ids + return 
self._post(f"/mirrors/{mirror_id}/refresh", payload) + + def sync_mirror( + self, + mirror_id: str, + sync_type: str = "all", + ) -> dict[str, Any]: + """Run sync pipeline against a mirror's databases.""" + return self._post( + f"/mirrors/{mirror_id}/sync", + {"sync_type": sync_type}, + ) + + def download_mirror_db( + self, + mirror_id: str, + community_id: str, + output_path: str, + ) -> str: + """Download a community database file from a mirror. + + Returns the path to the downloaded file. + """ + with ( + httpx.Client(timeout=self.timeout) as client, + client.stream( + "GET", + f"{self.api_url}/mirrors/{mirror_id}/download/{community_id}", + headers=self._get_headers(), + ) as response, + ): + if response.status_code >= 400: + response.read() + self._handle_response(response) + + from pathlib import Path + + dest = Path(output_path) / f"{community_id}.db" + dest.parent.mkdir(parents=True, exist_ok=True) + tmp_dest = dest.with_suffix(".db.tmp") + try: + with open(str(tmp_dest), "wb") as f: + for chunk in response.iter_bytes(chunk_size=8192): + f.write(chunk) + tmp_dest.rename(dest) + except Exception: + tmp_dest.unlink(missing_ok=True) + raise + return str(dest) diff --git a/src/cli/main.py b/src/cli/main.py index 32c14d8..bb6086a 100644 --- a/src/cli/main.py +++ b/src/cli/main.py @@ -35,6 +35,7 @@ save_config, save_credentials, ) +from src.cli.mirror import mirror_app from src.version import __version__ if TYPE_CHECKING: @@ -51,6 +52,9 @@ rich_markup_mode="rich", ) +# Mirror management subcommands +cli.add_typer(mirror_app, name="mirror") + # --------------------------------------------------------------------------- # init command @@ -167,6 +171,10 @@ def ask( bool, typer.Option("--no-stream", help="Disable streaming (get full response at once)"), ] = False, + mirror: Annotated[ + str | None, + typer.Option("--mirror", "-m", help="Mirror ID to use for database queries"), + ] = None, ) -> None: """Ask a single question to a community assistant. 
@@ -174,6 +182,7 @@ def ask( osa ask "What is HED?" -a hed osa ask "How do I organize my dataset?" -a bids osa ask "What is pop_newset?" -a eeglab -o json + osa ask "What is HED?" -a hed --mirror abc123def456 """ config, effective_key = get_effective_config(api_key=api_key, api_url=api_url) @@ -185,6 +194,7 @@ def ask( api_url=config.api.url, openrouter_api_key=effective_key, user_id=get_user_id(), + mirror_id=mirror, ) use_streaming = not no_stream and not output.is_piped() and output_format != "json" @@ -262,6 +272,10 @@ def chat( bool, typer.Option("--no-stream", help="Disable streaming"), ] = False, + mirror: Annotated[ + str | None, + typer.Option("--mirror", "-m", help="Mirror ID to use for database queries"), + ] = None, ) -> None: """Start an interactive chat session with a community assistant. @@ -269,6 +283,7 @@ def chat( osa chat -a hed osa chat -a bids osa chat -a eeglab --no-stream + osa chat -a hed --mirror abc123def456 """ config, effective_key = get_effective_config(api_key=api_key, api_url=api_url) @@ -280,6 +295,7 @@ def chat( api_url=config.api.url, openrouter_api_key=effective_key, user_id=get_user_id(), + mirror_id=mirror, ) use_streaming = not no_stream diff --git a/src/cli/mirror.py b/src/cli/mirror.py new file mode 100644 index 0000000..5037a0f --- /dev/null +++ b/src/cli/mirror.py @@ -0,0 +1,348 @@ +"""CLI commands for managing ephemeral database mirrors. + +Mirrors are short-lived copies of community knowledge databases on the +remote server. They allow developers to iterate on data and prompts +without affecting production, and can be downloaded locally for offline +development with a local server. Use `osa mirror pull` to download +databases for use with `osa serve`. 
"""CLI commands for managing ephemeral database mirrors.

Mirrors are short-lived copies of community knowledge databases on the
remote server. They allow developers to iterate on data and prompts
without affecting production, and can be downloaded locally for offline
development with a local server. Use `osa mirror pull` to download
databases for use with `osa serve`.
"""

from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Annotated

import httpx
import typer
from rich.table import Table

from src.cli import output
from src.cli.client import APIError
from src.cli.config import get_data_dir, get_effective_config, get_user_id

mirror_app = typer.Typer(
    help="Manage ephemeral database mirrors for development",
    no_args_is_help=True,
)


def _get_client(
    api_key: str | None = None,
    api_url: str | None = None,
) -> tuple:
    """Create an OSAClient with effective config.

    Returns:
        (client, config) tuple.

    Raises:
        typer.Exit: If no API key is configured.
    """
    from src.cli.client import OSAClient

    config, effective_key = get_effective_config(api_key=api_key, api_url=api_url)
    if not effective_key:
        output.print_error(
            "No API key configured.",
            hint="Run 'osa init' to set up your API key, or pass --api-key",
        )
        raise typer.Exit(code=1)

    client = OSAClient(
        api_url=config.api.url,
        openrouter_api_key=effective_key,
        user_id=get_user_id(),
    )
    return client, config


@contextmanager
def _handle_api_errors() -> Iterator[None]:
    """Catch common API and connection errors, print them, and exit."""
    try:
        yield
    except APIError as e:
        output.print_error(str(e), hint=e.detail)
        raise typer.Exit(code=1)
    except (httpx.ConnectError, httpx.TimeoutException) as e:
        output.print_error(f"Connection failed: {e}")
        raise typer.Exit(code=1)


def _format_size(size_bytes: int) -> str:
    """Format a byte count as a human-readable string (B, KB, or MB).

    Sub-kilobyte sizes are shown in bytes instead of rounding to "0 KB".
    """
    if size_bytes < 1024:
        return f"{size_bytes} B"
    size_kb = size_bytes / 1024
    return f"{size_kb:.0f} KB" if size_kb < 1024 else f"{size_kb / 1024:.1f} MB"


@mirror_app.command("create")
def create(
    community: Annotated[
        list[str],
        typer.Option("--community", "-c", help="Community ID to include (repeatable)"),
    ],
    label: Annotated[
        str | None,
        typer.Option("--label", "-l", help="Human-readable label for the mirror"),
    ] = None,
    ttl: Annotated[
        int,
        typer.Option("--ttl", help="Hours until mirror expires (1-168)"),
    ] = 48,
    api_key: Annotated[
        str | None,
        typer.Option("--api-key", "-k", help="OpenRouter API key"),
    ] = None,
    api_url: Annotated[
        str | None,
        typer.Option("--api-url", help="Override API URL"),
    ] = None,
) -> None:
    """Create a new ephemeral database mirror.

    Examples:
        osa mirror create -c hed -c bids
        osa mirror create -c hed --label "testing-new-prompt" --ttl 24
    """
    client, _ = _get_client(api_key, api_url)

    with _handle_api_errors():
        with output.streaming_status("Creating mirror..."):
            result = client.create_mirror(
                community_ids=community,
                ttl_hours=ttl,
                label=label,
            )
        output.print_success(f"Mirror created: {result['mirror_id']}")
        output.print_info(f"  Communities: {', '.join(result['community_ids'])}")
        output.print_info(f"  Expires: {result['expires_at']}")
        if result.get("label"):
            output.print_info(f"  Label: {result['label']}")
        output.console.print()
        # Use a community actually in the mirror for the hint (the server may
        # have skipped some requested communities) rather than hard-coding hed.
        example_community = result["community_ids"][0]
        output.console.print(
            f'[dim]Use with: osa ask "question" -a {example_community} '
            f'--mirror {result["mirror_id"]}[/dim]'
        )


@mirror_app.command("list")
def list_cmd(
    api_key: Annotated[
        str | None,
        typer.Option("--api-key", "-k", help="OpenRouter API key"),
    ] = None,
    api_url: Annotated[
        str | None,
        typer.Option("--api-url", help="Override API URL"),
    ] = None,
) -> None:
    """List active mirrors."""
    client, _ = _get_client(api_key, api_url)

    with _handle_api_errors():
        mirrors = client.list_mirrors()

    if not mirrors:
        output.print_info("No active mirrors.")
        return

    table = Table(title="Active Mirrors")
    table.add_column("ID", style="cyan")
    table.add_column("Communities", style="green")
    table.add_column("Label")
    table.add_column("Expires", style="yellow")
    table.add_column("Size", style="dim")

    for m in mirrors:
        table.add_row(
            m["mirror_id"],
            ", ".join(m["community_ids"]),
            m.get("label") or "",
            # Truncate ISO timestamp to "YYYY-MM-DDTHH:MM:SS" (drop tz/micros)
            m["expires_at"][:19],
            _format_size(m.get("size_bytes", 0)),
        )

    output.console.print(table)


@mirror_app.command("info")
def info(
    mirror_id: Annotated[str, typer.Argument(help="Mirror ID")],
    api_key: Annotated[
        str | None,
        typer.Option("--api-key", "-k", help="OpenRouter API key"),
    ] = None,
    api_url: Annotated[
        str | None,
        typer.Option("--api-url", help="Override API URL"),
    ] = None,
) -> None:
    """Show detailed information about a mirror."""
    client, _ = _get_client(api_key, api_url)

    with _handle_api_errors():
        m = client.get_mirror(mirror_id)

    output.console.print(f"[bold]Mirror:[/bold] {m['mirror_id']}")
    output.console.print(f"  Communities: {', '.join(m['community_ids'])}")
    output.console.print(f"  Created: {m['created_at']}")
    output.console.print(f"  Expires: {m['expires_at']}")
    if m.get("label"):
        output.console.print(f"  Label: {m['label']}")
    if m.get("owner_id"):
        output.console.print(f"  Owner: {m['owner_id']}")
    output.console.print(f"  Size: {_format_size(m.get('size_bytes', 0))}")
    expired = m.get("expired", False)
    mirror_status = "[red]expired[/red]" if expired else "[green]active[/green]"
    output.console.print(f"  Status: {mirror_status}")


@mirror_app.command("delete")
def delete(
    mirror_id: Annotated[str, typer.Argument(help="Mirror ID")],
    confirm: Annotated[
        bool,
        typer.Option("--yes", "-y", help="Skip confirmation"),
    ] = False,
    api_key: Annotated[
        str | None,
        typer.Option("--api-key", "-k", help="OpenRouter API key"),
    ] = None,
    api_url: Annotated[
        str | None,
        typer.Option("--api-url", help="Override API URL"),
    ] = None,
) -> None:
    """Delete a mirror and its databases."""
    # Prompt before creating the client so a declined confirmation
    # doesn't require a configured API key.
    if not confirm:
        confirm = typer.confirm(f"Delete mirror {mirror_id}?")
        if not confirm:
            output.print_info("Cancelled.")
            return

    client, _ = _get_client(api_key, api_url)

    with _handle_api_errors():
        client.delete_mirror(mirror_id)
    output.print_success(f"Mirror {mirror_id} deleted.")


@mirror_app.command("refresh")
def refresh(
    mirror_id: Annotated[str, typer.Argument(help="Mirror ID")],
    community: Annotated[
        list[str] | None,
        typer.Option("--community", "-c", help="Specific community to refresh"),
    ] = None,
    api_key: Annotated[
        str | None,
        typer.Option("--api-key", "-k", help="OpenRouter API key"),
    ] = None,
    api_url: Annotated[
        str | None,
        typer.Option("--api-url", help="Override API URL"),
    ] = None,
) -> None:
    """Re-copy production databases into an existing mirror.

    Resets mirror data to match current production state.
    """
    client, _ = _get_client(api_key, api_url)

    with _handle_api_errors():
        with output.streaming_status("Refreshing mirror..."):
            result = client.refresh_mirror(mirror_id, community_ids=community)
        output.print_success(f"Mirror {mirror_id} refreshed.")
        output.print_info(f"  Communities: {', '.join(result['community_ids'])}")


@mirror_app.command("sync")
def sync(
    mirror_id: Annotated[str, typer.Argument(help="Mirror ID")],
    sync_type: Annotated[
        str,
        typer.Option(
            "--type",
            "-t",
            help="Sync type: github, papers, docstrings, mailman, faq, beps, or all",
        ),
    ] = "all",
    api_key: Annotated[
        str | None,
        typer.Option("--api-key", "-k", help="OpenRouter API key"),
    ] = None,
    api_url: Annotated[
        str | None,
        typer.Option("--api-url", help="Override API URL"),
    ] = None,
) -> None:
    """Run sync pipeline against a mirror's databases.

    Populates or refreshes the mirror's data from public sources
    (GitHub, papers, etc.) using the server's sync pipeline.

    Examples:
        osa mirror sync abc123def456
        osa mirror sync abc123def456 --type github
    """
    client, _ = _get_client(api_key, api_url)

    with _handle_api_errors():
        with output.streaming_status(f"Syncing {sync_type} into mirror..."):
            result = client.sync_mirror(mirror_id, sync_type=sync_type)
        output.print_success(result.get("message", "Sync completed"))
        items = result.get("items_synced", {})
        if items:
            for st, count in items.items():
                output.print_info(f"  {st}: {count} communities synced")


@mirror_app.command("pull")
def pull(
    mirror_id: Annotated[str, typer.Argument(help="Mirror ID")],
    community: Annotated[
        str | None,
        typer.Option("--community", "-c", help="Specific community to download"),
    ] = None,
    output_dir: Annotated[
        str | None,
        typer.Option("--output", "-o", help="Output directory (default: local data/knowledge)"),
    ] = None,
    api_key: Annotated[
        str | None,
        typer.Option("--api-key", "-k", help="OpenRouter API key"),
    ] = None,
    api_url: Annotated[
        str | None,
        typer.Option("--api-url", help="Override API URL"),
    ] = None,
) -> None:
    """Download mirror databases locally for offline development.

    Downloads SQLite files so you can run `osa serve` locally with the
    mirror's data. Useful for testing code changes or using a local LLM.

    Examples:
        osa mirror pull abc123def456
        osa mirror pull abc123def456 -c hed -o ./data/knowledge
    """
    client, _ = _get_client(api_key, api_url)
    dest = output_dir or str(get_data_dir() / "knowledge")

    # Get mirror info to know which communities to download
    with _handle_api_errors():
        mirror_info = client.get_mirror(mirror_id)

    communities = [community] if community else mirror_info["community_ids"]

    failures = 0
    for cid in communities:
        try:
            with output.streaming_status(f"Downloading {cid}.db..."):
                path = client.download_mirror_db(mirror_id, cid, dest)
            output.print_success(f"Downloaded: {path}")
        except APIError as e:
            output.print_error(f"Failed to download {cid}: {e}", hint=e.detail)
            failures += 1
        except (httpx.ConnectError, httpx.TimeoutException) as e:
            output.print_error(f"Connection failed downloading {cid}: {e}")
            failures += 1

    output.console.print()
    if failures:
        output.print_error(f"{failures} download(s) failed. Local data may be incomplete.")
        raise typer.Exit(code=1)
    output.console.print("[dim]Start local server with: osa serve[/dim]")
+ Automatically detects and redacts API keys from OpenRouter, Anthropic, + and OpenAI to prevent accidental credential exposure. """ - # Pattern to match OpenRouter API keys: sk-or-v1-[64 hex chars] - API_KEY_PATTERN = re.compile(r"sk-or-v1-[0-9a-f]{64}", re.IGNORECASE) + # Patterns for API keys from various providers. + # IGNORECASE as defense-in-depth; real keys use lowercase hex. + API_KEY_PATTERN = re.compile( + r"sk-or-v1-[0-9a-f]{64}" # OpenRouter: sk-or-v1-[64 hex chars] + r"|sk-ant-[a-zA-Z0-9_-]{80,}" # Anthropic: sk-ant-... + r"|sk-proj-[a-zA-Z0-9_-]{40,}" # OpenAI project keys: sk-proj-... + r"|sk-[a-zA-Z0-9]{48,}", # Generic OpenAI keys: sk-... + re.IGNORECASE, + ) def format(self, record: logging.LogRecord) -> str: """Format log record and redact any API keys. @@ -49,13 +57,13 @@ def format(self, record: logging.LogRecord) -> str: print(f"CRITICAL: Unexpected error in SecureFormatter: {e}", file=sys.stderr) raise - # Redact API keys with size limit to prevent ReDoS + # Redact API keys with size limit to bound processing time on large messages try: # Limit message size to prevent potential regex issues with extremely large inputs if len(formatted) > 100_000: # 100KB limit formatted = formatted[:100_000] + "... 
[truncated for safety]" - formatted = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", formatted) + formatted = self.API_KEY_PATTERN.sub("***[key-redacted]", formatted) except re.error as e: # Regex pattern is broken - this is a code bug print(f"CRITICAL: Redaction regex failed: {e}", file=sys.stderr) @@ -137,19 +145,28 @@ def format(self, record: logging.LogRecord) -> str: json_str = json.dumps(log_entry, default=str) # Redact API keys from the JSON string - json_str = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", json_str) + json_str = self.API_KEY_PATTERN.sub("***[key-redacted]", json_str) return json_str - except Exception as e: - # Fallback to safe error message + except (ValueError, TypeError, KeyError) as e: + # Expected serialization errors - include context for debugging + safe_msg = str(getattr(record, "msg", ""))[:200] + safe_name = getattr(record, "name", "") error_entry = { "timestamp": datetime.now(UTC).isoformat(), "level": "ERROR", "logger": "logging", - "message": f"[LOGGING ERROR: {type(e).__name__}]", + "message": f"[LOGGING ERROR: {type(e).__name__}: {e}]", + "original_logger": safe_name, + "original_message": safe_msg, } - return json.dumps(error_entry) + fallback_json = json.dumps(error_entry) + return self.API_KEY_PATTERN.sub("***[key-redacted]", fallback_json) + except Exception as e: + # Unexpected errors - surface to stderr and re-raise + print(f"CRITICAL: Unexpected error in SecureJSONFormatter: {e}", file=sys.stderr) + raise def configure_secure_logging( diff --git a/src/core/services/litellm_llm.py b/src/core/services/litellm_llm.py index 6d1b5f8..1f43b47 100644 --- a/src/core/services/litellm_llm.py +++ b/src/core/services/litellm_llm.py @@ -69,8 +69,9 @@ def create_openrouter_llm( provider: Specific provider to use (e.g., "Cerebras", "DeepInfra/FP8"). Ignored for Anthropic models, which always use "Anthropic" provider. user_id: User identifier for cache optimization (sticky routing) - enable_caching: Enable prompt caching. 
If None (default), enabled for all models. - OpenRouter/LiteLLM gracefully handles models that don't support caching. + enable_caching: Enable prompt caching. If None (default), caching is requested + for all models. Models that do not support caching will ignore the + cache_control markers without error. Returns: LLM instance configured for OpenRouter @@ -663,6 +664,8 @@ async def astream( # Caching is enabled by default for all models; OpenRouter/LiteLLM handle # unsupported models gracefully by ignoring cache_control parameters. CACHEABLE_MODELS = { + "claude-opus-4.6": "anthropic/claude-opus-4.6", + "claude-sonnet-4.6": "anthropic/claude-sonnet-4.6", "claude-opus-4.5": "anthropic/claude-opus-4.5", "claude-sonnet-4.5": "anthropic/claude-sonnet-4.5", "claude-haiku-4.5": "anthropic/claude-haiku-4.5", diff --git a/src/core/services/llm.py b/src/core/services/llm.py index 8809416..bce2e70 100644 --- a/src/core/services/llm.py +++ b/src/core/services/llm.py @@ -18,20 +18,29 @@ class LLMService: # Model mappings for direct OpenAI API OPENAI_MODELS = { + "gpt-5.2": "gpt-5.2", + "gpt-5": "gpt-5", + "gpt-5-mini": "gpt-5-mini", + "gpt-4.1": "gpt-4.1", + "gpt-4.1-mini": "gpt-4.1-mini", "gpt-4o": "gpt-4o", "gpt-4o-mini": "gpt-4o-mini", - "gpt-4-turbo": "gpt-4-turbo", - "gpt-4": "gpt-4", - "gpt-3.5-turbo": "gpt-3.5-turbo", + "o4-mini": "o4-mini", + "o3": "o3", + "o3-mini": "o3-mini", } # Model mappings for direct Anthropic API ANTHROPIC_MODELS = { - "claude-3-5-sonnet": "claude-3-5-sonnet-20241022", - "claude-3-5-haiku": "claude-3-5-haiku-20241022", - "claude-3-opus": "claude-3-opus-20240229", - "claude-3-sonnet": "claude-3-sonnet-20240229", - "claude-3-haiku": "claude-3-haiku-20240307", + "claude-sonnet-4.6": "claude-sonnet-4-6-20250610", + "claude-opus-4.6": "claude-opus-4-6-20250610", + "claude-sonnet-4.5": "claude-sonnet-4-5-20250514", + "claude-opus-4.5": "claude-opus-4-5-20250514", + "claude-haiku-4.5": "claude-haiku-4-5-20251001", + "claude-sonnet-4": 
"""Shared input validation utilities.

Provides common validation functions used across modules for
preventing path traversal and ensuring safe identifiers.
"""


def is_safe_identifier(value: str) -> bool:
    """Check if a string is a safe identifier (ASCII alphanumerics, hyphens, underscores).

    Used for both mirror IDs and community IDs to prevent path traversal.

    Args:
        value: Candidate identifier.

    Returns:
        True if *value* is non-empty, contains at least one alphanumeric
        character, and consists only of ASCII letters, digits, hyphens,
        and underscores.
    """
    # str.isalnum() is Unicode-aware, so without the isascii() check
    # non-ASCII "alphanumerics" (accented letters, superscripts, ...)
    # would pass a check meant to guarantee filesystem-safe names.
    stripped = value.replace("-", "").replace("_", "")
    return bool(stripped) and stripped.isascii() and stripped.isalnum()
+# Safe for concurrent requests because the middleware sets and resets the +# value around each request's lifecycle. Nested async calls within the +# same request inherit the value. +_active_mirror_id: contextvars.ContextVar[str | None] = contextvars.ContextVar( + "_active_mirror_id", default=None +) + + +def set_active_mirror(mirror_id: str | None) -> contextvars.Token[str | None]: + """Set the active mirror for the current async context. + + Validates the mirror ID immediately to catch invalid values early. + Returns a token that can be used to reset the context variable. + """ + if mirror_id is not None: + _validate_mirror_id(mirror_id) + return _active_mirror_id.set(mirror_id) + + +def reset_active_mirror(token: contextvars.Token[str | None]) -> None: + """Reset the active mirror to its previous value.""" + _active_mirror_id.reset(token) + + +def get_active_mirror() -> str | None: + """Get the currently active mirror ID, or None if not in mirror mode.""" + return _active_mirror_id.get() + + +@contextmanager +def active_mirror_context(mirror_id: str) -> Iterator[None]: + """Context manager for mirror routing. + + Sets the active mirror for the duration of the block and resets it + afterward, even if an exception occurs. + + Usage: + with active_mirror_context("abc123"): + # All DB operations within this block go to the mirror + ... + """ + token = set_active_mirror(mirror_id) + try: + yield + finally: + reset_active_mirror(token) + + SCHEMA_SQL = """ -- GitHub issues and PRs CREATE TABLE IF NOT EXISTS github_items ( @@ -377,6 +430,8 @@ def get_db_path(project: str = "hed") -> Path: """Get path to knowledge database for a project. Each assistant/project has its own isolated knowledge database. + When a mirror is active (via ContextVar), returns the mirror's + database path instead. Args: project: Assistant/project name (e.g., 'hed', 'bids', 'eeglab'). 
"""Ephemeral database mirror lifecycle management.

Mirrors are short-lived copies of community knowledge databases that allow
developers to read, write, and re-sync without affecting production data.
Each mirror gets its own directory under data/mirrors/{mirror_id}/ containing
copies of the relevant community SQLite databases.

Default TTL is 48 hours; maximum is 168 hours (7 days).
"""

import json
import logging
import shutil
from dataclasses import dataclass, field, replace
from datetime import UTC, datetime, timedelta
from pathlib import Path
from uuid import uuid4

from src.cli.config import get_data_dir
from src.core.validation import is_safe_identifier

logger = logging.getLogger(__name__)

MIRRORS_DIR_NAME = "mirrors"
METADATA_FILE = "_metadata.json"

# Resource limits
DEFAULT_TTL_HOURS = 48
MAX_TTL_HOURS = 168  # 7 days
MAX_MIRRORS_TOTAL = 50
MAX_MIRRORS_PER_USER = 2


@dataclass(frozen=True)
class MirrorInfo:
    """Metadata for an ephemeral database mirror.

    Frozen dataclass: all fields are immutable after construction.
    Default TTL is 48 hours; maximum is 168 hours (7 days).
    """

    mirror_id: str
    community_ids: tuple[str, ...]
    created_at: datetime
    expires_at: datetime
    owner_id: str | None = None
    label: str | None = None
    size_bytes: int = field(default=0, repr=False)

    def __post_init__(self) -> None:
        """Validate invariants at construction time."""
        if not is_safe_identifier(self.mirror_id):
            raise ValueError(f"Invalid mirror ID: {self.mirror_id!r}")
        if not self.community_ids:
            raise ValueError("community_ids must not be empty")
        for cid in self.community_ids:
            if not is_safe_identifier(cid):
                raise ValueError(f"Invalid community ID in community_ids: {cid!r}")

    def is_expired(self) -> bool:
        """Check if the mirror has passed its expiration time."""
        return datetime.now(UTC) >= self.expires_at

    def to_dict(self) -> dict:
        """Serialize to dictionary for JSON storage.

        Note: size_bytes is excluded because it is calculated dynamically
        at read time from the actual database files on disk.
        """
        return {
            "mirror_id": self.mirror_id,
            "community_ids": list(self.community_ids),
            "created_at": self.created_at.isoformat(),
            "expires_at": self.expires_at.isoformat(),
            "owner_id": self.owner_id,
            "label": self.label,
        }

    @classmethod
    def from_dict(cls, data: dict, size_bytes: int = 0) -> "MirrorInfo":
        """Deserialize from dictionary.

        Raises:
            KeyError: If a required key is missing.
            ValueError: If timestamps or identifiers are malformed.
        """
        return cls(
            mirror_id=data["mirror_id"],
            community_ids=tuple(data["community_ids"]),
            created_at=datetime.fromisoformat(data["created_at"]),
            expires_at=datetime.fromisoformat(data["expires_at"]),
            owner_id=data.get("owner_id"),
            label=data.get("label"),
            size_bytes=size_bytes,
        )


def _get_mirrors_dir() -> Path:
    """Get the base directory for all mirrors."""
    return get_data_dir() / MIRRORS_DIR_NAME


def _validate_mirror_id(mirror_id: str) -> None:
    """Validate mirror ID format to prevent path traversal.

    Raises:
        ValueError: If mirror_id contains invalid characters.
    """
    if not mirror_id or len(mirror_id) > 64:
        raise ValueError(f"Invalid mirror ID length: {len(mirror_id) if mirror_id else 0}")
    if not is_safe_identifier(mirror_id):
        raise ValueError(
            f"Invalid mirror ID: {mirror_id}. "
            "Use only alphanumeric characters, hyphens, and underscores."
        )


def _validate_community_id(community_id: str) -> None:
    """Validate community ID format to prevent path traversal.

    Raises:
        ValueError: If community_id contains invalid characters.
    """
    if not is_safe_identifier(community_id):
        raise ValueError(
            f"Invalid community ID: {community_id!r}. "
            "Use only alphanumeric characters, hyphens, and underscores."
        )


def _get_mirror_dir(mirror_id: str) -> Path:
    """Get the directory for a specific mirror.

    Validates mirror_id format to prevent path traversal.
    """
    _validate_mirror_id(mirror_id)
    return _get_mirrors_dir() / mirror_id


def _get_metadata_path(mirror_id: str) -> Path:
    """Get the path to a mirror's metadata file."""
    return _get_mirror_dir(mirror_id) / METADATA_FILE


def _write_metadata(info: MirrorInfo) -> None:
    """Write mirror metadata to disk atomically.

    Writes to a temporary file and replaces the real one, so a crash
    mid-write can never leave a truncated (corrupt) metadata file behind.
    """
    path = _get_metadata_path(info.mirror_id)
    tmp_path = path.with_suffix(".json.tmp")
    tmp_path.write_text(json.dumps(info.to_dict(), indent=2))
    tmp_path.replace(path)


class CorruptMirrorError(Exception):
    """Raised when a mirror's metadata file exists but is corrupt or unreadable."""

    def __init__(self, mirror_id: str, cause: Exception):
        self.mirror_id = mirror_id
        self.cause = cause
        super().__init__(
            f"Mirror '{mirror_id}' has corrupt metadata: {cause}. "
            f"Delete and recreate, or inspect the metadata file."
        )


def _read_metadata(mirror_id: str) -> MirrorInfo | None:
    """Read mirror metadata from disk.

    Returns None if the mirror does not exist.

    Raises:
        CorruptMirrorError: If metadata file exists but is corrupt.
    """
    path = _get_metadata_path(mirror_id)
    if not path.exists():
        return None
    try:
        data = json.loads(path.read_text())
        return MirrorInfo.from_dict(data, size_bytes=_calculate_mirror_size(mirror_id))
    except (json.JSONDecodeError, KeyError, UnicodeDecodeError, ValueError) as e:
        logger.error(
            "Corrupt metadata for mirror '%s': %s",
            mirror_id,
            e,
        )
        raise CorruptMirrorError(mirror_id, e) from e


def _calculate_mirror_size(mirror_id: str) -> int:
    """Calculate total size of database files in a mirror."""
    mirror_dir = _get_mirror_dir(mirror_id)
    if not mirror_dir.exists():
        return 0
    return sum(f.stat().st_size for f in mirror_dir.glob("*.db") if f.is_file())


def _get_production_db_path(community_id: str) -> Path:
    """Get the path to a production community database.

    Validates community_id to prevent path traversal.
    """
    _validate_community_id(community_id)
    return get_data_dir() / "knowledge" / f"{community_id}.db"


def get_mirror_db_path(mirror_id: str, community_id: str) -> Path:
    """Get the path to a community database file within a mirror.

    Args:
        mirror_id: The mirror's identifier.
        community_id: The community whose database to locate.

    Returns:
        Path to the SQLite database file.

    Raises:
        ValueError: If mirror_id or community_id is invalid.
    """
    _validate_community_id(community_id)
    return _get_mirror_dir(mirror_id) / f"{community_id}.db"


def create_mirror(
    community_ids: list[str],
    ttl_hours: int = DEFAULT_TTL_HOURS,
    label: str | None = None,
    owner_id: str | None = None,
) -> MirrorInfo:
    """Create a new mirror by copying production database files.

    Args:
        community_ids: List of community IDs to include in the mirror.
        ttl_hours: Hours until the mirror expires (default 48, max 168).
        label: Optional human-readable label for the mirror.
        owner_id: Optional owner identifier (user_id) for rate limiting.

    Returns:
        MirrorInfo with the new mirror's metadata.

    Raises:
        ValueError: If no valid community databases found or limits exceeded.
    """
    # Clamp TTL into [1, MAX_TTL_HOURS]. Clamping only the upper bound would
    # let ttl_hours <= 0 create a mirror that is born expired.
    ttl_hours = max(1, min(ttl_hours, MAX_TTL_HOURS))

    # Check total mirror count
    existing = list_mirrors()
    active_mirrors = [m for m in existing if not m.is_expired()]
    if len(active_mirrors) >= MAX_MIRRORS_TOTAL:
        raise ValueError(
            f"Maximum number of mirrors ({MAX_MIRRORS_TOTAL}) reached. "
            "Delete existing mirrors or wait for them to expire."
        )

    # Check per-user limit
    if owner_id:
        user_mirrors = [m for m in active_mirrors if m.owner_id == owner_id]
        if len(user_mirrors) >= MAX_MIRRORS_PER_USER:
            raise ValueError(
                f"Maximum mirrors per user ({MAX_MIRRORS_PER_USER}) reached. "
                "Delete an existing mirror first."
            )

    mirror_id = uuid4().hex[:12]
    mirror_dir = _get_mirror_dir(mirror_id)
    # exist_ok=False: in the (astronomically unlikely) event of an ID
    # collision, fail loudly instead of merging into -- and, in the
    # failure-cleanup path below, deleting -- someone else's mirror.
    mirror_dir.mkdir(parents=True, exist_ok=False)

    try:
        copied_communities = []
        for community_id in community_ids:
            source_db = _get_production_db_path(community_id)
            if not source_db.exists():
                logger.warning("No database found for community '%s', skipping", community_id)
                continue
            dest_db = mirror_dir / f"{community_id}.db"
            shutil.copy2(str(source_db), str(dest_db))
            copied_communities.append(community_id)
            logger.info("Copied %s to mirror %s", community_id, mirror_id)

        if not copied_communities:
            raise ValueError(
                f"No databases found for communities: {community_ids}. "
                "Ensure the communities exist and have been synced."
            )

        now = datetime.now(UTC)
        info = MirrorInfo(
            mirror_id=mirror_id,
            community_ids=tuple(copied_communities),
            created_at=now,
            expires_at=now + timedelta(hours=ttl_hours),
            owner_id=owner_id,
            label=label,
            size_bytes=_calculate_mirror_size(mirror_id),
        )
        _write_metadata(info)
    except Exception:
        # Clean up on any failure to avoid orphaned directories
        try:
            shutil.rmtree(str(mirror_dir))
        except OSError as cleanup_err:
            logger.warning(
                "Failed to clean up partial mirror directory %s: %s",
                mirror_dir,
                cleanup_err,
            )
        raise

    logger.info(
        "Created mirror %s with communities %s (expires %s)",
        mirror_id,
        copied_communities,
        info.expires_at,
    )
    return info


def get_mirror(mirror_id: str) -> MirrorInfo | None:
    """Get mirror metadata.

    Returns None if the mirror does not exist.
    Does NOT check expiration; callers should check is_expired().

    Raises:
        CorruptMirrorError: If metadata file exists but is corrupt.
    """
    return _read_metadata(mirror_id)


def list_mirrors() -> list[MirrorInfo]:
    """List all mirrors (including expired ones still on disk).

    Skips mirrors with corrupt metadata (logged as errors).
    """
    mirrors_dir = _get_mirrors_dir()
    if not mirrors_dir.exists():
        return []

    result = []
    for entry in mirrors_dir.iterdir():
        if entry.is_dir() and (entry / METADATA_FILE).exists():
            try:
                info = _read_metadata(entry.name)
            except CorruptMirrorError:
                # Already logged in _read_metadata; skip corrupt mirrors
                continue
            if info:
                result.append(info)

    result.sort(key=lambda m: m.created_at, reverse=True)
    return result


def delete_mirror(mirror_id: str) -> bool:
    """Delete a mirror and all its databases.

    Returns True if the mirror was deleted, False if not found.

    Raises:
        OSError: If the directory exists but deletion fails (e.g. permissions).
    """
    mirror_dir = _get_mirror_dir(mirror_id)
    if not mirror_dir.exists():
        return False

    shutil.rmtree(str(mirror_dir))
    logger.info("Deleted mirror %s", mirror_id)
    return True


def refresh_mirror(
    mirror_id: str,
    community_ids: list[str] | None = None,
) -> MirrorInfo:
    """Re-copy production databases into an existing mirror.

    This resets the mirror's data to match current production.

    Args:
        mirror_id: ID of the mirror to refresh.
        community_ids: Specific communities to refresh, or None for all.

    Returns:
        Updated MirrorInfo.

    Raises:
        ValueError: If mirror not found or expired.
        CorruptMirrorError: If mirror metadata is corrupt.
    """
    info = get_mirror(mirror_id)
    if not info:
        raise ValueError(f"Mirror '{mirror_id}' not found")
    if info.is_expired():
        raise ValueError(f"Mirror '{mirror_id}' has expired")

    targets = community_ids or info.community_ids
    mirror_dir = _get_mirror_dir(mirror_id)

    refreshed = []
    for community_id in targets:
        source_db = _get_production_db_path(community_id)
        if not source_db.exists():
            logger.warning("No production database for '%s', skipping refresh", community_id)
            continue
        dest_db = mirror_dir / f"{community_id}.db"
        shutil.copy2(str(source_db), str(dest_db))
        refreshed.append(community_id)
        logger.info("Refreshed %s in mirror %s", community_id, mirror_id)

    if not refreshed:
        raise ValueError(
            f"No production databases found for communities: {targets}. Nothing was refreshed."
        )

    # Frozen dataclass: produce a copy with the recalculated size
    return replace(info, size_bytes=_calculate_mirror_size(mirror_id))


def cleanup_expired_mirrors() -> int:
    """Delete all expired mirrors. Returns the count of mirrors deleted.

    Continues past individual deletion failures so one stuck mirror
    does not block cleanup of the rest.
    """
    deleted = 0
    for info in list_mirrors():
        if info.is_expired():
            try:
                if delete_mirror(info.mirror_id):
                    deleted += 1
            except OSError:
                logger.error("Failed to delete expired mirror %s", info.mirror_id, exc_info=True)
    if deleted:
        logger.info("Cleaned up %d expired mirrors", deleted)
    return deleted
"anthropic/claude-opus-4.6": ModelRate(5.00, 25.00), + "anthropic/claude-opus-4.5": ModelRate(5.00, 25.00), + "anthropic/claude-opus-4.1": ModelRate(15.00, 75.00), + "anthropic/claude-opus-4": ModelRate(15.00, 75.00), + "anthropic/claude-sonnet-4.6": ModelRate(3.00, 15.00), + "anthropic/claude-sonnet-4.5": ModelRate(3.00, 15.00), + "anthropic/claude-sonnet-4": ModelRate(3.00, 15.00), + "anthropic/claude-haiku-4.5": ModelRate(1.00, 5.00), + "anthropic/claude-3.7-sonnet": ModelRate(3.00, 15.00), + "anthropic/claude-3.5-sonnet": ModelRate(6.00, 30.00), + "anthropic/claude-3.5-haiku": ModelRate(0.80, 4.00), + # OpenAI models + "openai/gpt-5.2": ModelRate(1.75, 14.00), + "openai/gpt-5.2-chat": ModelRate(1.75, 14.00), + "openai/gpt-5.1": ModelRate(1.25, 10.00), + "openai/gpt-5": ModelRate(1.25, 10.00), + "openai/gpt-5-chat": ModelRate(1.25, 10.00), + "openai/gpt-5-mini": ModelRate(0.25, 2.00), + "openai/gpt-5-nano": ModelRate(0.05, 0.40), + "openai/gpt-5-pro": ModelRate(15.00, 120.00), + "openai/gpt-4.1": ModelRate(2.00, 8.00), + "openai/gpt-4.1-mini": ModelRate(0.40, 1.60), + "openai/gpt-4.1-nano": ModelRate(0.10, 0.40), + "openai/gpt-4o": ModelRate(2.50, 10.00), + "openai/gpt-4o-mini": ModelRate(0.15, 0.60), + "openai/o4-mini": ModelRate(1.10, 4.40), + "openai/o3": ModelRate(2.00, 8.00), + "openai/o3-mini": ModelRate(1.10, 4.40), + "openai/o3-pro": ModelRate(20.00, 80.00), + "openai/o1": ModelRate(15.00, 60.00), + "openai/gpt-oss-120b": ModelRate(0.04, 0.19), # Google models - "google/gemini-2.5-pro-preview": (1.25, 10.00), - "google/gemini-2.5-flash-preview": (0.15, 0.60), + "google/gemini-3.1-pro-preview": ModelRate(2.00, 12.00), + "google/gemini-3-pro-preview": ModelRate(2.00, 12.00), + "google/gemini-3-flash-preview": ModelRate(0.50, 3.00), + "google/gemini-2.5-pro": ModelRate(1.25, 10.00), + "google/gemini-2.5-pro-preview": ModelRate(1.25, 10.00), + "google/gemini-2.5-flash": ModelRate(0.30, 2.50), + "google/gemini-2.5-flash-lite": ModelRate(0.10, 0.40), # 
DeepSeek models - "deepseek/deepseek-chat-v3": (0.14, 0.28), - "deepseek/deepseek-r1": (0.55, 2.19), + "deepseek/deepseek-v3.2": ModelRate(0.25, 0.40), + "deepseek/deepseek-chat-v3.1": ModelRate(0.15, 0.75), + "deepseek/deepseek-chat": ModelRate(0.32, 0.89), + "deepseek/deepseek-r1": ModelRate(0.70, 2.50), + "deepseek/deepseek-r1-0528": ModelRate(0.45, 2.15), + # Qwen models + "qwen/qwen3.5-397b-a17b": ModelRate(0.39, 2.34), + "qwen/qwen3-235b-a22b-2507": ModelRate(0.07, 0.10), + "qwen/qwen3-235b-a22b": ModelRate(0.45, 1.82), + "qwen/qwen3-30b-a3b-2507": ModelRate(0.09, 0.30), + "qwen/qwen3-coder": ModelRate(0.22, 1.00), + "qwen/qwen3-max": ModelRate(1.20, 6.00), # Meta models - "meta-llama/llama-4-maverick": (0.16, 0.40), + "meta-llama/llama-4-maverick": ModelRate(0.15, 0.60), + "meta-llama/llama-4-scout": ModelRate(0.08, 0.30), + "meta-llama/llama-3.3-70b-instruct": ModelRate(0.10, 0.32), } +# Validate all pricing entries at import time to catch typos +for _model_name, _rate in MODEL_PRICING.items(): + if _rate.input_per_1m < 0 or _rate.output_per_1m < 0: + raise ValueError( + f"Negative rate for model {_model_name}: " + f"input={_rate.input_per_1m}, output={_rate.output_per_1m}" + ) + # Fallback rate for models not in the pricing table -_FALLBACK_INPUT_RATE = 1.00 # USD per 1M tokens -_FALLBACK_OUTPUT_RATE = 3.00 # USD per 1M tokens +_FALLBACK_RATE = ModelRate(input_per_1m=1.00, output_per_1m=3.00) + +# Cost protection thresholds (USD per 1M input tokens) +# Applied only when using platform/community keys (not BYOK) +COST_WARN_THRESHOLD = 5.0 # Log warning for models above this +COST_BLOCK_THRESHOLD = 15.0 # Block requests for models above this def estimate_cost( @@ -57,11 +116,11 @@ def estimate_cost( Estimated cost in USD, rounded to 6 decimal places. 
""" if model and model in MODEL_PRICING: - input_rate, output_rate = MODEL_PRICING[model] + rate = MODEL_PRICING[model] else: if model: logger.warning("No pricing data for model %s, using fallback rates", model) - input_rate, output_rate = _FALLBACK_INPUT_RATE, _FALLBACK_OUTPUT_RATE + rate = _FALLBACK_RATE - cost = (input_tokens * input_rate + output_tokens * output_rate) / 1_000_000 + cost = (input_tokens * rate.input_per_1m + output_tokens * rate.output_per_1m) / 1_000_000 return round(cost, 6) diff --git a/src/version.py b/src/version.py index 610ea3e..4b68242 100644 --- a/src/version.py +++ b/src/version.py @@ -1,7 +1,7 @@ """Version information for OSA.""" -__version__ = "0.7.1" -__version_info__ = (0, 7, 1) +__version__ = "0.8.0" +__version_info__ = (0, 8, 0) def get_version() -> str: diff --git a/tests/test_api/test_authorization.py b/tests/test_api/test_authorization.py index 297f0f9..d892cff 100644 --- a/tests/test_api/test_authorization.py +++ b/tests/test_api/test_authorization.py @@ -1,137 +1,153 @@ -"""Tests for API key authorization and model selection logic.""" +"""Tests for API key authorization and model selection logic. -from unittest.mock import MagicMock, patch +Uses real community configurations loaded via discover_assistants(). +No mocks -- environment variables are set via monkeypatch (real Settings reads them). 
+""" import pytest from fastapi import HTTPException +from src.api.config import get_settings from src.api.routers.community import _is_authorized_origin, _select_api_key, _select_model +from src.assistants import discover_assistants, registry +from src.assistants.registry import AssistantInfo -@pytest.fixture -def mock_registry(): - """Mock registry with test community configurations.""" - with patch("src.api.routers.community.registry") as mock_reg: - # HED community with CORS origins - hed_info = MagicMock() - hed_info.id = "hed" - hed_info.community_config = MagicMock() - hed_info.community_config.cors_origins = [ - "https://hedtags.org", - "https://www.hedtags.org", - "https://*.pages.dev", - ] - hed_info.community_config.openrouter_api_key_env_var = "OPENROUTER_API_KEY_HED" - hed_info.community_config.default_model = None - hed_info.community_config.default_model_provider = None - - # BIDS community with custom model - bids_info = MagicMock() - bids_info.id = "bids" - bids_info.community_config = MagicMock() - bids_info.community_config.cors_origins = ["https://bids.neuroimaging.io"] - bids_info.community_config.openrouter_api_key_env_var = None - bids_info.community_config.default_model = "anthropic/claude-3.5-sonnet" - bids_info.community_config.default_model_provider = None - - # Community without CORS origins - no_cors_info = MagicMock() - no_cors_info.id = "no-cors" - no_cors_info.community_config = MagicMock() - no_cors_info.community_config.cors_origins = [] - - mock_reg.get.side_effect = lambda id: { - "hed": hed_info, - "bids": bids_info, - "no-cors": no_cors_info, - }.get(id) - - yield mock_reg +@pytest.fixture(autouse=True, scope="module") +def _load_communities(): + """Load real community configurations for all tests in this module.""" + discover_assistants() + + +@pytest.fixture(autouse=True) +def _clear_settings_cache(): + """Clear the lru_cache on get_settings so each test gets fresh Settings.""" + get_settings.cache_clear() + yield + 
get_settings.cache_clear() + + +def _get_config(community_id): + """Get a community's config from the registry.""" + info = registry.get(community_id) + assert info is not None, f"Community '{community_id}' not found in registry" + return info + + +def _get_exact_origin(community_id): + """Get the first non-wildcard CORS origin for a community.""" + info = _get_config(community_id) + for origin in info.community_config.cors_origins: + if "*" not in origin: + return origin + pytest.fail(f"No exact CORS origin found for '{community_id}'") + + +def _get_wildcard_origin(community_id): + """Get a wildcard CORS pattern for a community, returns (pattern, constructed_origin).""" + info = _get_config(community_id) + for origin in info.community_config.cors_origins: + if "*" in origin: + # Replace * with a test subdomain + return origin, origin.replace("*", "test-subdomain") + pytest.skip(f"No wildcard CORS origin for '{community_id}'") class TestIsAuthorizedOrigin: """Tests for _is_authorized_origin helper function.""" - def test_platform_default_origin_always_allowed(self, mock_registry): # noqa: ARG002 + def test_platform_default_origin_always_allowed(self): """Platform default origins should be allowed for all communities.""" - # Primary domain - assert _is_authorized_origin("https://demo.osc.earth", "hed") is True - assert _is_authorized_origin("https://demo.osc.earth", "bids") is True - assert _is_authorized_origin("https://demo.osc.earth", "no-cors") is True + for community_id in ["hed", "bids", "eeglab"]: + assert _is_authorized_origin("https://demo.osc.earth", community_id) is True # Legacy pages.dev assert _is_authorized_origin("https://osa-demo.pages.dev", "hed") is True - def test_platform_wildcard_origin_always_allowed(self, mock_registry): # noqa: ARG002 + def test_platform_wildcard_origin_always_allowed(self): """Platform wildcard origins should be allowed for all communities.""" - # Primary domain single-level subdomains assert 
_is_authorized_origin("https://develop-demo.osc.earth", "hed") is True assert _is_authorized_origin("https://preview-123-demo.osc.earth", "bids") is True # Legacy pages.dev subdomains - assert _is_authorized_origin("https://feature-branch.osa-demo.pages.dev", "no-cors") is True - - def test_exact_origin_match(self, mock_registry): # noqa: ARG002 - """Should return True for exact origin match.""" - assert _is_authorized_origin("https://hedtags.org", "hed") is True - assert _is_authorized_origin("https://www.hedtags.org", "hed") is True - - def test_wildcard_origin_match(self, mock_registry): # noqa: ARG002 - """Should return True for wildcard subdomain match.""" - assert _is_authorized_origin("https://my-app.pages.dev", "hed") is True - assert _is_authorized_origin("https://preview-123.pages.dev", "hed") is True - - def test_wildcard_does_not_match_multiple_levels(self, mock_registry): # noqa: ARG002 + assert _is_authorized_origin("https://feature-branch.osa-demo.pages.dev", "eeglab") is True + + def test_exact_origin_match(self): + """Should return True for exact origin match using real community CORS origins.""" + hed_info = _get_config("hed") + for origin in hed_info.community_config.cors_origins: + if "*" not in origin: + assert _is_authorized_origin(origin, "hed") is True, ( + f"Expected {origin} to be authorized for HED" + ) + + def test_wildcard_origin_match(self): + """Should return True for wildcard subdomain match using real config.""" + _pattern, test_origin = _get_wildcard_origin("mne") + assert _is_authorized_origin(test_origin, "mne") is True + + def test_wildcard_does_not_match_multiple_levels(self): """Wildcard should match single subdomain, not multiple levels.""" - # *.pages.dev should NOT match foo.bar.pages.dev - assert _is_authorized_origin("https://foo.bar.pages.dev", "hed") is False + pattern, _test_origin = _get_wildcard_origin("mne") + # Insert extra subdomain level into the constructed origin + multi_level = pattern.replace("*", "foo.bar") 
+ assert _is_authorized_origin(multi_level, "mne") is False - def test_no_origin_returns_false(self, mock_registry): # noqa: ARG002 + def test_no_origin_returns_false(self): """Should return False when origin is None (CLI, mobile apps).""" assert _is_authorized_origin(None, "hed") is False - def test_unauthorized_origin_returns_false(self, mock_registry): # noqa: ARG002 + def test_unauthorized_origin_returns_false(self): """Should return False for origin not in CORS list.""" assert _is_authorized_origin("https://evil.com", "hed") is False assert _is_authorized_origin("https://example.org", "hed") is False - def test_case_sensitive_origin_matching(self, mock_registry): # noqa: ARG002 + def test_case_sensitive_origin_matching(self): """Origin matching should be case-sensitive.""" - # HTTPS vs https - assert _is_authorized_origin("https://hedtags.org", "hed") is True - assert _is_authorized_origin("HTTPS://hedtags.org", "hed") is False - - def test_community_without_cors_origins(self, mock_registry): # noqa: ARG002 - """Should return False for community with empty cors_origins.""" - assert _is_authorized_origin("https://example.com", "no-cors") is False - - def test_unknown_community_returns_false(self, mock_registry): # noqa: ARG002 + origin = _get_exact_origin("hed") + assert _is_authorized_origin(origin, "hed") is True + assert _is_authorized_origin(origin.replace("https://", "HTTPS://"), "hed") is False + + def test_community_cors_origins_from_eeglab(self): + """Verify EEGLAB CORS origins work correctly.""" + eeglab_info = _get_config("eeglab") + for origin in eeglab_info.community_config.cors_origins: + if "*" not in origin: + assert _is_authorized_origin(origin, "eeglab") is True + assert _is_authorized_origin("https://example.com", "eeglab") is False + + def test_unknown_community_returns_false(self): """Should return False for unknown community ID.""" - assert _is_authorized_origin("https://hedtags.org", "unknown") is False + origin = _get_exact_origin("hed") 
+ assert _is_authorized_origin(origin, "nonexistent-community-xyz") is False - def test_domain_case_sensitivity(self, mock_registry): # noqa: ARG002 + def test_domain_case_sensitivity(self): """Domain matching is currently case-sensitive. Note: Per RFC 3986, scheme and host should be case-insensitive, but current implementation uses exact string matching. This test documents current behavior. """ - # Exact case match works - assert _is_authorized_origin("https://hedtags.org", "hed") is True - # Different case in domain currently fails (even though RFC says it should work) - assert _is_authorized_origin("https://HedTags.ORG", "hed") is False - assert _is_authorized_origin("https://HEDTAGS.ORG", "hed") is False + origin = _get_exact_origin("hed") + assert _is_authorized_origin(origin, "hed") is True + assert _is_authorized_origin(origin.upper(), "hed") is False + + def test_cross_community_origins_not_shared(self): + """Community origins should not work for other communities.""" + hed_origin = _get_exact_origin("hed") + bids_origin = _get_exact_origin("bids") + assert _is_authorized_origin(hed_origin, "bids") is False + assert _is_authorized_origin(bids_origin, "hed") is False class TestSelectApiKey: """Tests for _select_api_key authorization logic.""" - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_byok_always_allowed(self, mock_settings, mock_registry): # noqa: ARG002 + def test_byok_always_allowed(self, monkeypatch): """BYOK should always be allowed regardless of origin.""" - mock_settings.return_value.openrouter_api_key = "platform-key" + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") + origin = _get_exact_origin("hed") # BYOK with authorized origin - key, source = _select_api_key("hed", "user-key", "https://hedtags.org") + key, source = _select_api_key("hed", "user-key", origin) assert key == "user-key" assert source == "byok" @@ -145,31 +161,42 @@ def test_byok_always_allowed(self, 
mock_settings, mock_registry): # noqa: ARG00 assert key == "user-key" assert source == "byok" - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_authorized_origin_uses_platform_key(self, mock_settings, mock_registry): # noqa: ARG002 - """Authorized origin should use community or platform key.""" - mock_settings.return_value.openrouter_api_key = "platform-key" - - key, source = _select_api_key("hed", None, "https://hedtags.org") + def test_authorized_origin_uses_platform_key(self, monkeypatch): + """Authorized origin without community key should use platform key.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") + origin = _get_exact_origin("mne") + key, source = _select_api_key("mne", None, origin) assert key == "platform-key" assert source == "platform" - @patch.dict("os.environ", {"OPENROUTER_API_KEY_HED": "community-key"}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_authorized_origin_uses_community_key(self, mock_settings, mock_registry): # noqa: ARG002 + def test_authorized_origin_uses_community_key(self, monkeypatch): """Authorized origin should prefer community key over platform key.""" - mock_settings.return_value.openrouter_api_key = "platform-key" - - key, source = _select_api_key("hed", None, "https://hedtags.org") + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") + hed_info = _get_config("hed") + env_var = hed_info.community_config.openrouter_api_key_env_var + assert env_var, "HED should have openrouter_api_key_env_var configured" + monkeypatch.setenv(env_var, "community-key") + + origin = _get_exact_origin("hed") + key, source = _select_api_key("hed", None, origin) assert key == "community-key" assert source == "community" - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_unauthorized_origin_requires_byok(self, mock_settings, mock_registry): # noqa: ARG002 + def 
test_authorized_origin_falls_back_to_platform_when_community_key_missing(self, monkeypatch): + """When community env var is configured but not set, fall back to platform key.""" + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") + hed_info = _get_config("hed") + env_var = hed_info.community_config.openrouter_api_key_env_var + monkeypatch.delenv(env_var, raising=False) + + origin = _get_exact_origin("hed") + key, source = _select_api_key("hed", None, origin) + assert key == "platform-key" + assert source == "platform" + + def test_unauthorized_origin_requires_byok(self, monkeypatch): """Unauthorized origin without BYOK should raise 403.""" - mock_settings.return_value.openrouter_api_key = "platform-key" + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") with pytest.raises(HTTPException) as exc_info: _select_api_key("hed", None, "https://evil.com") @@ -178,11 +205,9 @@ def test_unauthorized_origin_requires_byok(self, mock_settings, mock_registry): assert "API key required" in exc_info.value.detail assert "openrouter.ai/keys" in exc_info.value.detail - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_cli_without_byok_requires_key(self, mock_settings, mock_registry): # noqa: ARG002 + def test_cli_without_byok_requires_key(self, monkeypatch): """CLI (no origin) without BYOK should raise 403.""" - mock_settings.return_value.openrouter_api_key = "platform-key" + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") with pytest.raises(HTTPException) as exc_info: _select_api_key("hed", None, None) @@ -190,14 +215,17 @@ def test_cli_without_byok_requires_key(self, mock_settings, mock_registry): # n assert exc_info.value.status_code == 403 assert "API key required" in exc_info.value.detail - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_no_platform_key_configured_raises_500(self, mock_settings, mock_registry): # noqa: ARG002 + def 
test_no_platform_key_configured_raises_500(self, monkeypatch): """No platform key configured should raise 500 for authorized origins.""" - mock_settings.return_value.openrouter_api_key = None + monkeypatch.delenv("OPENROUTER_API_KEY", raising=False) + hed_info = _get_config("hed") + env_var = hed_info.community_config.openrouter_api_key_env_var + if env_var: + monkeypatch.delenv(env_var, raising=False) + origin = _get_exact_origin("hed") with pytest.raises(HTTPException) as exc_info: - _select_api_key("hed", None, "https://hedtags.org") + _select_api_key("hed", None, origin) assert exc_info.value.status_code == 500 assert "No API key configured" in exc_info.value.detail @@ -206,49 +234,55 @@ def test_no_platform_key_configured_raises_500(self, mock_settings, mock_registr class TestSelectModel: """Tests for _select_model logic.""" - @patch("src.api.routers.community.get_settings") - def test_uses_platform_default_when_no_community_model(self, mock_settings, mock_registry): - """Should use platform default when community has no default_model.""" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - mock_settings.return_value.default_model_provider = "Cerebras" + def test_uses_community_default_model(self, monkeypatch): + """Should use community default_model when configured.""" + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + monkeypatch.setenv("DEFAULT_MODEL_PROVIDER", "Cerebras") - community_info = mock_registry.get("hed") - model, provider = _select_model(community_info, None, has_byok=False) + community_info = _get_config("hed") + expected_model = community_info.community_config.default_model + expected_provider = community_info.community_config.default_model_provider + assert expected_model, "HED should have a default_model configured" - assert model == "openai/gpt-oss-120b" - assert provider == "Cerebras" + model, provider = _select_model(community_info, None, has_byok=False) - @patch("src.api.routers.community.get_settings") - def 
test_uses_community_default_model(self, mock_settings, mock_registry): - """Should use community default_model when configured.""" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - mock_settings.return_value.default_model_provider = "Cerebras" + assert model == expected_model + assert provider == expected_provider - community_info = mock_registry.get("bids") + def test_uses_platform_default_when_no_community_model(self, monkeypatch): + """Should use platform default when community has no default_model.""" + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + monkeypatch.setenv("DEFAULT_MODEL_PROVIDER", "Cerebras") + + # Real dataclass instance with no community config (not a mock) + community_info = AssistantInfo( + id="test-no-model", + name="Test Community", + description="Community without a default model", + community_config=None, + ) model, provider = _select_model(community_info, None, has_byok=False) - assert model == "anthropic/claude-3.5-sonnet" - assert provider is None # BIDS doesn't specify provider + assert model == "openai/gpt-oss-120b" + assert provider == "Cerebras" - @patch("src.api.routers.community.get_settings") - def test_custom_model_with_byok_allowed(self, mock_settings, mock_registry): + def test_custom_model_with_byok_allowed(self, monkeypatch): """Custom model should be allowed when user has BYOK.""" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - mock_settings.return_value.default_model_provider = "Cerebras" + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + monkeypatch.setenv("DEFAULT_MODEL_PROVIDER", "Cerebras") - community_info = mock_registry.get("hed") + community_info = _get_config("hed") model, provider = _select_model(community_info, "anthropic/claude-opus-4", has_byok=True) assert model == "anthropic/claude-opus-4" assert provider is None # Custom models use default routing - @patch("src.api.routers.community.get_settings") - def 
test_custom_model_without_byok_rejected(self, mock_settings, mock_registry): + def test_custom_model_without_byok_rejected(self, monkeypatch): """Custom model without BYOK should raise 403.""" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - mock_settings.return_value.default_model_provider = "Cerebras" + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + monkeypatch.setenv("DEFAULT_MODEL_PROVIDER", "Cerebras") - community_info = mock_registry.get("hed") + community_info = _get_config("hed") with pytest.raises(HTTPException) as exc_info: _select_model(community_info, "anthropic/claude-opus-4", has_byok=False) @@ -258,95 +292,91 @@ def test_custom_model_without_byok_rejected(self, mock_settings, mock_registry): assert "anthropic/claude-opus-4" in exc_info.value.detail assert "requires your own API key" in exc_info.value.detail - @patch("src.api.routers.community.get_settings") - def test_requesting_default_model_explicitly_allowed(self, mock_settings, mock_registry): - """Explicitly requesting the default model should not require BYOK.""" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - mock_settings.return_value.default_model_provider = "Cerebras" + def test_requesting_default_model_explicitly_allowed(self, monkeypatch): + """Explicitly requesting the community default model should not require BYOK.""" + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + monkeypatch.setenv("DEFAULT_MODEL_PROVIDER", "Cerebras") - community_info = mock_registry.get("hed") - # User explicitly requests platform default - should not be treated as custom - model, provider = _select_model(community_info, "openai/gpt-oss-120b", has_byok=False) + community_info = _get_config("hed") + default_model = community_info.community_config.default_model + default_provider = community_info.community_config.default_model_provider - assert model == "openai/gpt-oss-120b" - assert provider == "Cerebras" + model, provider = 
_select_model(community_info, default_model, has_byok=False) + + assert model == default_model + assert provider == default_provider - @patch("src.api.routers.community.get_settings") - def test_requesting_community_default_model_allowed(self, mock_settings, mock_registry): - """Explicitly requesting the community default should not require BYOK.""" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - mock_settings.return_value.default_model_provider = "Cerebras" + def test_requesting_platform_default_model_allowed(self, monkeypatch): + """Explicitly requesting the platform default should not require BYOK.""" + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + monkeypatch.setenv("DEFAULT_MODEL_PROVIDER", "Cerebras") - community_info = mock_registry.get("bids") - # User explicitly requests BIDS's default - should not be treated as custom - model, provider = _select_model( - community_info, "anthropic/claude-3.5-sonnet", has_byok=False + # Community with no default_model so platform default is the effective default + community_info = AssistantInfo( + id="test-no-model", + name="Test Community", + description="Community without a default model", + community_config=None, ) + model, provider = _select_model(community_info, "openai/gpt-oss-120b", has_byok=False) - assert model == "anthropic/claude-3.5-sonnet" - assert provider is None + assert model == "openai/gpt-oss-120b" + assert provider == "Cerebras" class TestIntegration: """Integration tests for combined authorization + model selection.""" - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_widget_user_default_model(self, mock_settings, mock_registry): + def test_widget_user_default_model(self, monkeypatch): """Widget user on authorized site with default model.""" - mock_settings.return_value.openrouter_api_key = "platform-key" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - 
mock_settings.return_value.default_model_provider = "Cerebras" - - # Select API key - api_key, key_source = _select_api_key("hed", None, "https://hedtags.org") + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + monkeypatch.setenv("DEFAULT_MODEL_PROVIDER", "Cerebras") + hed_info = _get_config("hed") + env_var = hed_info.community_config.openrouter_api_key_env_var + if env_var: + monkeypatch.delenv(env_var, raising=False) + + origin = _get_exact_origin("hed") + api_key, key_source = _select_api_key("hed", None, origin) assert key_source in ["platform", "community"] - # Select model - community_info = mock_registry.get("hed") - model, provider = _select_model(community_info, None, has_byok=False) - assert model == "openai/gpt-oss-120b" + model, provider = _select_model(hed_info, None, has_byok=False) + assert model == hed_info.community_config.default_model - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_widget_user_custom_model_rejected(self, mock_settings, mock_registry): + def test_widget_user_custom_model_rejected(self, monkeypatch): """Widget user trying to use custom model should be rejected.""" - mock_settings.return_value.openrouter_api_key = "platform-key" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" - - # API key is allowed (authorized origin) - api_key, key_source = _select_api_key("hed", None, "https://hedtags.org") + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") + hed_info = _get_config("hed") + env_var = hed_info.community_config.openrouter_api_key_env_var + if env_var: + monkeypatch.delenv(env_var, raising=False) + + origin = _get_exact_origin("hed") + api_key, key_source = _select_api_key("hed", None, origin) assert key_source in ["platform", "community"] - # But custom model is rejected (no BYOK) - community_info = 
mock_registry.get("hed") with pytest.raises(HTTPException) as exc_info: - _select_model(community_info, "anthropic/claude-opus-4", has_byok=False) + _select_model(hed_info, "anthropic/claude-opus-4", has_byok=False) assert exc_info.value.status_code == 403 - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_cli_user_with_byok_and_custom_model(self, mock_settings, mock_registry): + def test_cli_user_with_byok_and_custom_model(self, monkeypatch): """CLI user with BYOK can use custom model.""" - mock_settings.return_value.default_model = "openai/gpt-oss-120b" + monkeypatch.setenv("DEFAULT_MODEL", "openai/gpt-oss-120b") - # CLI provides BYOK api_key, key_source = _select_api_key("hed", "user-key", None) assert api_key == "user-key" assert key_source == "byok" - # Can use custom model - community_info = mock_registry.get("hed") + community_info = _get_config("hed") model, provider = _select_model(community_info, "anthropic/claude-opus-4", has_byok=True) assert model == "anthropic/claude-opus-4" - @patch.dict("os.environ", {}, clear=True) - @patch("src.api.routers.community.get_settings") - def test_cli_user_without_byok_rejected(self, mock_settings, mock_registry): # noqa: ARG002 + def test_cli_user_without_byok_rejected(self, monkeypatch): """CLI user without BYOK should be rejected.""" - mock_settings.return_value.openrouter_api_key = "platform-key" + monkeypatch.setenv("OPENROUTER_API_KEY", "platform-key") - # CLI without BYOK is rejected with pytest.raises(HTTPException) as exc_info: _select_api_key("hed", None, None) assert exc_info.value.status_code == 403 diff --git a/tests/test_api/test_cost_protection.py b/tests/test_api/test_cost_protection.py new file mode 100644 index 0000000..a2b0030 --- /dev/null +++ b/tests/test_api/test_cost_protection.py @@ -0,0 +1,92 @@ +"""Tests for model cost protection. 
+ +Verifies that expensive models are blocked when using platform/community keys, +but allowed when users provide their own API key (BYOK). +""" + +import pytest +from fastapi import HTTPException + +from src.api.routers.community import _check_model_cost +from src.metrics.cost import COST_BLOCK_THRESHOLD, COST_WARN_THRESHOLD, MODEL_PRICING + + +def _models_by_cost(min_rate: float = 0.0, max_rate: float = float("inf")) -> list[str]: + """Return model names with input rates in [min_rate, max_rate).""" + return [m for m, (inp, _) in MODEL_PRICING.items() if min_rate <= inp < max_rate] + + +class TestCheckModelCost: + """Tests for _check_model_cost() pre-invocation cost guard.""" + + def test_cheap_model_on_platform_key_allowed(self) -> None: + """Cheap models should be allowed on platform keys without error.""" + cheap_models = _models_by_cost(max_rate=COST_WARN_THRESHOLD) + assert cheap_models, "Test requires at least one cheap model in MODEL_PRICING" + + _check_model_cost(cheap_models[0], "platform") + _check_model_cost(cheap_models[0], "community") + + def test_expensive_model_blocked_on_platform_key(self) -> None: + """Models above block threshold should be rejected with 403 on platform keys.""" + expensive_models = _models_by_cost(min_rate=COST_BLOCK_THRESHOLD) + assert expensive_models, "Test requires at least one expensive model in MODEL_PRICING" + + with pytest.raises(HTTPException) as exc_info: + _check_model_cost(expensive_models[0], "platform") + assert exc_info.value.status_code == 403 + assert "exceeds the platform limit" in exc_info.value.detail + assert "openrouter.ai/keys" in exc_info.value.detail + + def test_expensive_model_blocked_on_community_key(self) -> None: + """Models above block threshold should also be rejected on community keys.""" + expensive_models = _models_by_cost(min_rate=COST_BLOCK_THRESHOLD) + assert expensive_models, "Test requires at least one expensive model in MODEL_PRICING" + + with pytest.raises(HTTPException) as exc_info: + 
_check_model_cost(expensive_models[0], "community") + assert exc_info.value.status_code == 403 + + def test_expensive_model_allowed_with_byok(self) -> None: + """BYOK users should be able to use any model, even expensive ones.""" + expensive_models = _models_by_cost(min_rate=COST_BLOCK_THRESHOLD) + assert expensive_models, "Test requires at least one expensive model in MODEL_PRICING" + + _check_model_cost(expensive_models[0], "byok") + + def test_unknown_model_blocked_on_platform_key(self) -> None: + """Unknown models (not in pricing table) should be blocked on platform keys.""" + with pytest.raises(HTTPException) as exc_info: + _check_model_cost("unknown/made-up-model-xyz", "platform") + assert exc_info.value.status_code == 403 + assert "not in the approved pricing list" in exc_info.value.detail + + def test_unknown_model_allowed_with_byok(self) -> None: + """BYOK users with unknown models should also be allowed.""" + _check_model_cost("unknown/made-up-model-xyz", "byok") + + def test_warn_threshold_model_not_blocked(self) -> None: + """Models between warn and block thresholds should be allowed (just warned).""" + warn_only_models = _models_by_cost( + min_rate=COST_WARN_THRESHOLD, max_rate=COST_BLOCK_THRESHOLD + ) + if not warn_only_models: + pytest.skip("No models between warn and block thresholds in current pricing") + + _check_model_cost(warn_only_models[0], "platform") + + def test_model_at_exact_block_threshold_is_blocked(self) -> None: + """A model priced exactly at the block threshold should be blocked.""" + exact_models = [m for m, (inp, _) in MODEL_PRICING.items() if inp == COST_BLOCK_THRESHOLD] + if not exact_models: + pytest.skip("No model priced exactly at block threshold") + + with pytest.raises(HTTPException) as exc_info: + _check_model_cost(exact_models[0], "platform") + assert exc_info.value.status_code == 403 + + def test_thresholds_are_sane(self) -> None: + """Sanity check: warn threshold should be lower than block threshold.""" + assert 
COST_WARN_THRESHOLD < COST_BLOCK_THRESHOLD + assert COST_WARN_THRESHOLD > 0 + assert COST_BLOCK_THRESHOLD > 0 diff --git a/tests/test_assistants/test_community_yaml_generic.py b/tests/test_assistants/test_community_yaml_generic.py index 0af13b1..402a3c8 100644 --- a/tests/test_assistants/test_community_yaml_generic.py +++ b/tests/test_assistants/test_community_yaml_generic.py @@ -96,7 +96,6 @@ def test_documentation_urls_valid_format(self, community_id): ) @pytest.mark.slow - @pytest.mark.skip(reason="Disabled: upstream HED URL broken (404). See #139") def test_documentation_urls_accessible(self, community_id): """All documentation source URLs should return HTTP 200. diff --git a/tests/test_core/test_llm_service.py b/tests/test_core/test_llm_service.py index 3798e7b..914c639 100644 --- a/tests/test_core/test_llm_service.py +++ b/tests/test_core/test_llm_service.py @@ -23,15 +23,15 @@ class TestLLMServiceModelMappings: def test_openai_models_mapping(self) -> None: """LLMService should have OpenAI model mappings.""" + assert "gpt-5.2" in LLMService.OPENAI_MODELS + assert "gpt-5-mini" in LLMService.OPENAI_MODELS assert "gpt-4o" in LLMService.OPENAI_MODELS - assert "gpt-4o-mini" in LLMService.OPENAI_MODELS - assert "gpt-4-turbo" in LLMService.OPENAI_MODELS def test_anthropic_models_mapping(self) -> None: """LLMService should have Anthropic model mappings.""" - assert "claude-3-5-sonnet" in LLMService.ANTHROPIC_MODELS - assert "claude-3-5-haiku" in LLMService.ANTHROPIC_MODELS - assert "claude-3-opus" in LLMService.ANTHROPIC_MODELS + assert "claude-sonnet-4.6" in LLMService.ANTHROPIC_MODELS + assert "claude-haiku-4.5" in LLMService.ANTHROPIC_MODELS + assert "claude-3.5-sonnet" in LLMService.ANTHROPIC_MODELS def test_default_model_from_settings(self) -> None: """LLMService should get default model from settings.""" @@ -82,7 +82,7 @@ def test_get_model_anthropic_without_key_raises(self) -> None: settings = Settings(anthropic_api_key=None) service = 
LLMService(settings=settings) with pytest.raises(ValueError, match="Anthropic API key required"): - service.get_model("claude-3-5-sonnet") + service.get_model("claude-3.5-sonnet") def test_get_model_with_api_key_override(self) -> None: """get_model should use provided API key over settings.""" @@ -104,7 +104,7 @@ def test_get_model_anthropic(self) -> None: """get_model should return Anthropic model when configured.""" settings = Settings(anthropic_api_key="test-anthropic-key") service = LLMService(settings=settings) - model = service.get_model("claude-3-5-haiku") + model = service.get_model("claude-3.5-haiku") assert model is not None def test_get_model_default_openrouter(self) -> None: diff --git a/tests/test_core/test_logging.py b/tests/test_core/test_logging.py index a6cf249..114ea73 100644 --- a/tests/test_core/test_logging.py +++ b/tests/test_core/test_logging.py @@ -28,7 +28,7 @@ def test_redacts_api_key_in_message(self) -> None: ) formatted = formatter.format(record) - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted assert "aaaa" not in formatted # Original key should not appear def test_redacts_multiple_api_keys(self) -> None: @@ -45,7 +45,7 @@ def test_redacts_multiple_api_keys(self) -> None: ) formatted = formatter.format(record) - assert formatted.count("sk-or-v1-***[redacted]") == 2 + assert formatted.count("***[key-redacted]") == 2 assert "aaaa" not in formatted assert "bbbb" not in formatted @@ -65,7 +65,7 @@ def test_preserves_non_key_content(self) -> None: formatted = formatter.format(record) assert "Starting request with key" in formatted assert "for user john" in formatted - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted def test_handles_message_without_keys(self) -> None: """Should not modify messages without API keys.""" @@ -98,7 +98,7 @@ def test_redacts_case_insensitive(self) -> None: ) formatted = formatter.format(record) - assert "sk-or-v1-***[redacted]" in 
formatted + assert "***[key-redacted]" in formatted def test_uses_standard_format_fields(self) -> None: """Should support standard logging format fields.""" @@ -116,7 +116,7 @@ def test_uses_standard_format_fields(self) -> None: formatted = formatter.format(record) assert "WARNING" in formatted assert "my.logger" in formatted - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted def test_preserves_partial_matches(self) -> None: """Should not redact strings that partially match key pattern.""" @@ -205,7 +205,7 @@ def test_redacts_api_keys_in_exception_tracebacks(self) -> None: formatted = formatter.format(record) # API key in exception message should be redacted - assert "sk-or-v1-***[redacted]" in formatted + assert "***[key-redacted]" in formatted assert "aaaa" not in formatted def test_concurrent_logging_thread_safety(self) -> None: @@ -253,7 +253,7 @@ def log_with_key(index: int) -> None: assert len(errors) == 0, f"Concurrent logging errors: {errors}" assert len(formatted_logs) == 10 for _, api_key, log in formatted_logs: - assert "sk-or-v1-***[redacted]" in log + assert "***[key-redacted]" in log # Original key should not appear assert api_key not in log @@ -402,7 +402,7 @@ def test_redacts_api_keys_in_json(self) -> None: formatted = formatter.format(record) log_data = json.loads(formatted) - assert "sk-or-v1-***[redacted]" in log_data["message"] + assert "***[key-redacted]" in log_data["message"] assert "aaaa" not in log_data["message"] def test_configures_json_logging(self) -> None: diff --git a/tests/test_knowledge/test_mirror.py b/tests/test_knowledge/test_mirror.py new file mode 100644 index 0000000..4e711e7 --- /dev/null +++ b/tests/test_knowledge/test_mirror.py @@ -0,0 +1,527 @@ +"""Tests for ephemeral database mirror system. 
+ +Tests cover: +- Mirror CRUD lifecycle (create, get, list, delete) +- ContextVar-based DB routing (get_db_path returns mirror path when set) +- active_mirror_context context manager (set/reset, exception safety) +- MirrorInfo invariants (frozen dataclass, validation, serialization) +- Mirror refresh (re-copy from production) +- TTL expiration, clamping, and cleanup +- Resource limits (max mirrors, per-user limits) +- Path traversal prevention +- Corrupt metadata resilience +- run_sync_now input validation +""" + +import json +from datetime import UTC, datetime, timedelta +from pathlib import Path +from unittest.mock import patch + +import pytest + +from src.knowledge.db import ( + get_active_mirror, + get_db_path, + init_db, + reset_active_mirror, + set_active_mirror, +) +from src.knowledge.mirror import ( + CorruptMirrorError, + MirrorInfo, + _get_metadata_path, + _validate_mirror_id, + cleanup_expired_mirrors, + create_mirror, + delete_mirror, + get_mirror, + list_mirrors, + refresh_mirror, +) + + +@pytest.fixture +def data_dir(tmp_path: Path): + """Set up a temporary data directory with a production database.""" + knowledge_dir = tmp_path / "knowledge" + knowledge_dir.mkdir() + + # Create a small production database for "testcommunity" + with ( + patch("src.cli.config.get_data_dir", return_value=tmp_path), + patch("src.knowledge.db.get_data_dir", return_value=tmp_path), + ): + init_db("testcommunity") + assert (knowledge_dir / "testcommunity.db").exists() + + yield tmp_path + + +@pytest.fixture(autouse=True) +def patch_data_dir(data_dir: Path): + """Patch get_data_dir for all tests to use the temp directory.""" + with ( + patch("src.cli.config.get_data_dir", return_value=data_dir), + patch("src.knowledge.db.get_data_dir", return_value=data_dir), + patch("src.knowledge.mirror.get_data_dir", return_value=data_dir), + ): + # Ensure no mirror context leaks between tests + token = set_active_mirror(None) + try: + yield data_dir + finally: + 
reset_active_mirror(token) + + +class TestContextVar: + """Tests for the ContextVar-based mirror routing.""" + + def test_get_db_path_default(self): + """Without mirror context, returns production path.""" + path = get_db_path("testcommunity") + assert "knowledge" in str(path) + assert "mirrors" not in str(path) + assert path.name == "testcommunity.db" + + def test_get_db_path_with_mirror(self): + """With mirror context set, returns mirror path.""" + token = set_active_mirror("abc123") + try: + path = get_db_path("testcommunity") + assert "mirrors" in str(path) + assert "abc123" in str(path) + assert path.name == "testcommunity.db" + finally: + reset_active_mirror(token) + + def test_get_db_path_resets_after_token(self): + """After resetting the token, returns production path again.""" + token = set_active_mirror("abc123") + path_mirror = get_db_path("testcommunity") + assert "mirrors" in str(path_mirror) + + reset_active_mirror(token) + path_prod = get_db_path("testcommunity") + assert "knowledge" in str(path_prod) + assert "mirrors" not in str(path_prod) + + def test_get_active_mirror_default_none(self): + """Default mirror context is None.""" + assert get_active_mirror() is None + + def test_set_and_get_active_mirror(self): + """set/get active mirror round-trip.""" + token = set_active_mirror("test123") + try: + assert get_active_mirror() == "test123" + finally: + reset_active_mirror(token) + assert get_active_mirror() is None + + def test_invalid_mirror_id_rejected(self): + """Mirror IDs with path traversal chars are rejected at set time.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + set_active_mirror("../etc/passwd") + + +class TestMirrorLifecycle: + """Tests for creating, listing, and deleting mirrors.""" + + def test_create_mirror(self): + """Create a mirror and verify it copies the database.""" + info = create_mirror(community_ids=["testcommunity"], ttl_hours=24) + assert info.mirror_id + assert "testcommunity" in info.community_ids + 
assert info.size_bytes > 0 + assert not info.is_expired() + + assert info.created_at + assert info.expires_at + + def test_create_mirror_with_label(self): + """Create a mirror with a label.""" + info = create_mirror( + community_ids=["testcommunity"], + label="test-prompt-v2", + ) + assert info.label == "test-prompt-v2" + + def test_create_mirror_with_owner(self): + """Create a mirror with an owner ID.""" + info = create_mirror( + community_ids=["testcommunity"], + owner_id="user123", + ) + assert info.owner_id == "user123" + + def test_create_mirror_nonexistent_community(self): + """Creating a mirror for a nonexistent community raises ValueError.""" + with pytest.raises(ValueError, match="No databases found"): + create_mirror(community_ids=["nonexistent"]) + + def test_get_mirror(self): + """Get mirror by ID returns correct metadata.""" + info = create_mirror(community_ids=["testcommunity"]) + retrieved = get_mirror(info.mirror_id) + assert retrieved is not None + assert retrieved.mirror_id == info.mirror_id + assert retrieved.community_ids == info.community_ids + + def test_get_mirror_nonexistent(self): + """Getting a nonexistent mirror returns None.""" + assert get_mirror("nonexistent") is None + + def test_list_mirrors(self): + """List mirrors returns all created mirrors.""" + info1 = create_mirror(community_ids=["testcommunity"], label="first") + info2 = create_mirror(community_ids=["testcommunity"], label="second") + + mirrors = list_mirrors() + ids = [m.mirror_id for m in mirrors] + assert info1.mirror_id in ids + assert info2.mirror_id in ids + + def test_delete_mirror(self): + """Delete a mirror removes it from disk.""" + info = create_mirror(community_ids=["testcommunity"]) + assert get_mirror(info.mirror_id) is not None + + result = delete_mirror(info.mirror_id) + assert result is True + assert get_mirror(info.mirror_id) is None + + def test_delete_nonexistent_mirror(self): + """Deleting a nonexistent mirror returns False.""" + assert 
delete_mirror("nonexistent") is False + + +class TestMirrorRefresh: + """Tests for refreshing mirror data from production.""" + + def test_refresh_mirror(self): + """Refresh re-copies production databases.""" + info = create_mirror(community_ids=["testcommunity"]) + refreshed = refresh_mirror(info.mirror_id) + assert refreshed.mirror_id == info.mirror_id + assert refreshed.size_bytes > 0 + + def test_refresh_expired_mirror(self): + """Refreshing an expired mirror raises ValueError.""" + info = create_mirror(community_ids=["testcommunity"], ttl_hours=1) + + # Manually expire the mirror + meta_path = _get_metadata_path(info.mirror_id) + meta = json.loads(meta_path.read_text()) + meta["expires_at"] = (datetime.now(UTC) - timedelta(hours=1)).isoformat() + meta_path.write_text(json.dumps(meta)) + + with pytest.raises(ValueError, match="has expired"): + refresh_mirror(info.mirror_id) + + def test_refresh_nonexistent_mirror(self): + """Refreshing a nonexistent mirror raises ValueError.""" + with pytest.raises(ValueError, match="not found"): + refresh_mirror("nonexistent") + + +class TestTTLAndCleanup: + """Tests for mirror expiration and cleanup.""" + + def test_mirror_not_expired(self): + """Newly created mirror is not expired.""" + info = create_mirror(community_ids=["testcommunity"], ttl_hours=24) + assert not info.is_expired() + + def test_mirror_expired(self): + """Mirror with past expiration is expired.""" + info = MirrorInfo( + mirror_id="test", + community_ids=("testcommunity",), + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) - timedelta(hours=1), + ) + assert info.is_expired() + + def test_cleanup_removes_expired(self): + """cleanup_expired_mirrors removes expired mirrors.""" + info = create_mirror(community_ids=["testcommunity"], ttl_hours=1) + + # Manually expire the mirror + meta_path = _get_metadata_path(info.mirror_id) + meta = json.loads(meta_path.read_text()) + meta["expires_at"] = (datetime.now(UTC) - timedelta(hours=1)).isoformat() + 
meta_path.write_text(json.dumps(meta)) + + deleted = cleanup_expired_mirrors() + assert deleted == 1 + assert get_mirror(info.mirror_id) is None + + def test_cleanup_preserves_active(self): + """cleanup_expired_mirrors preserves non-expired mirrors.""" + info = create_mirror(community_ids=["testcommunity"], ttl_hours=48) + deleted = cleanup_expired_mirrors() + assert deleted == 0 + assert get_mirror(info.mirror_id) is not None + + +class TestResourceLimits: + """Tests for mirror resource limits.""" + + def test_per_user_limit(self): + """BYOK users are limited to MAX_MIRRORS_PER_USER mirrors.""" + from src.knowledge.mirror import MAX_MIRRORS_PER_USER + + # Create max mirrors for user + for i in range(MAX_MIRRORS_PER_USER): + create_mirror( + community_ids=["testcommunity"], + owner_id="user1", + label=f"mirror-{i}", + ) + + # Next one should fail + with pytest.raises(ValueError, match="Maximum mirrors per user"): + create_mirror( + community_ids=["testcommunity"], + owner_id="user1", + ) + + def test_different_users_independent(self): + """Different users have independent mirror limits.""" + from src.knowledge.mirror import MAX_MIRRORS_PER_USER + + for _i in range(MAX_MIRRORS_PER_USER): + create_mirror(community_ids=["testcommunity"], owner_id="user1") + + # Different user should still be able to create mirrors + info = create_mirror(community_ids=["testcommunity"], owner_id="user2") + assert info.owner_id == "user2" + + def test_no_owner_no_per_user_limit(self): + """Mirrors without owner_id (admin) are not subject to per-user limits.""" + from src.knowledge.mirror import MAX_MIRRORS_PER_USER + + for _i in range(MAX_MIRRORS_PER_USER + 1): + info = create_mirror(community_ids=["testcommunity"]) + assert info.owner_id is None + + +class TestPathTraversal: + """Tests for path traversal prevention in mirror IDs.""" + + def test_empty_mirror_id_rejected(self): + """Empty string mirror ID is rejected.""" + with pytest.raises(ValueError, match="Invalid mirror ID 
length"): + _validate_mirror_id("") + + def test_dots_only_rejected(self): + """Mirror ID of just dots is rejected.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + _validate_mirror_id("..") + + def test_single_dot_rejected(self): + """Mirror ID of a single dot is rejected.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + _validate_mirror_id(".") + + def test_backslash_traversal_rejected(self): + """Mirror ID with backslash path traversal is rejected.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + _validate_mirror_id("..\\etc\\passwd") + + def test_slash_traversal_rejected(self): + """Mirror ID with forward slash is rejected.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + _validate_mirror_id("../etc/passwd") + + def test_too_long_mirror_id_rejected(self): + """Mirror ID exceeding 64 chars is rejected.""" + with pytest.raises(ValueError, match="Invalid mirror ID length"): + _validate_mirror_id("a" * 65) + + def test_valid_mirror_id_accepted(self): + """Valid alphanumeric mirror ID passes validation.""" + _validate_mirror_id("abc123def456") + _validate_mirror_id("mirror-1_test") + + def test_delete_mirror_validates_id(self): + """delete_mirror rejects path traversal mirror IDs.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + delete_mirror("../../etc") + + def test_get_mirror_validates_id(self): + """get_mirror rejects path traversal mirror IDs.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + get_mirror("../../etc") + + +class TestCorruptMetadata: + """Tests for resilience against corrupt metadata files.""" + + def test_corrupt_json_raises_error(self, data_dir: Path): + """Corrupt JSON metadata raises CorruptMirrorError.""" + mirror_dir = data_dir / "mirrors" / "corrupt123" + mirror_dir.mkdir(parents=True) + (mirror_dir / "_metadata.json").write_text("not valid json{{{") + + with pytest.raises(CorruptMirrorError, match="corrupt metadata"): + 
get_mirror("corrupt123") + + def test_missing_keys_raises_error(self, data_dir: Path): + """Metadata with missing required keys raises CorruptMirrorError.""" + mirror_dir = data_dir / "mirrors" / "missingkeys" + mirror_dir.mkdir(parents=True) + (mirror_dir / "_metadata.json").write_text('{"mirror_id": "missingkeys"}') + + with pytest.raises(CorruptMirrorError, match="corrupt metadata"): + get_mirror("missingkeys") + + def test_corrupt_metadata_does_not_break_list(self, data_dir: Path): + """One corrupt metadata file does not break list_mirrors.""" + # Create a valid mirror + info = create_mirror(community_ids=["testcommunity"], label="valid") + + # Create a corrupt mirror directory + corrupt_dir = data_dir / "mirrors" / "corrupt456" + corrupt_dir.mkdir(parents=True) + (corrupt_dir / "_metadata.json").write_text("garbage data") + + mirrors = list_mirrors() + ids = [m.mirror_id for m in mirrors] + assert info.mirror_id in ids + + def test_corrupt_metadata_does_not_break_cleanup(self, data_dir: Path): + """Corrupt metadata does not crash cleanup_expired_mirrors.""" + corrupt_dir = data_dir / "mirrors" / "corrupt789" + corrupt_dir.mkdir(parents=True) + (corrupt_dir / "_metadata.json").write_text("not json") + + # Should not raise + deleted = cleanup_expired_mirrors() + assert deleted == 0 + + +class TestCreateMirrorCleanup: + """Tests for create_mirror error handling and cleanup.""" + + def test_create_mirror_partial_communities(self): + """Creating with mix of valid and invalid communities only copies valid ones.""" + info = create_mirror(community_ids=["testcommunity", "nonexistent"]) + assert info.community_ids == ("testcommunity",) + assert "nonexistent" not in info.community_ids + + +class TestActiveMirrorContext: + """Tests for the active_mirror_context context manager.""" + + def test_context_manager_sets_and_resets(self): + """Context manager sets mirror ID and resets it after the block.""" + from src.knowledge.db import active_mirror_context + + assert 
get_active_mirror() is None + with active_mirror_context("abc123"): + assert get_active_mirror() == "abc123" + assert get_active_mirror() is None + + def test_context_manager_resets_on_exception(self): + """Context manager resets mirror ID even if an exception occurs.""" + from src.knowledge.db import active_mirror_context + + assert get_active_mirror() is None + with pytest.raises(RuntimeError), active_mirror_context("abc123"): + assert get_active_mirror() == "abc123" + raise RuntimeError("test error") + assert get_active_mirror() is None + + def test_context_manager_validates_mirror_id(self): + """Context manager rejects invalid mirror IDs.""" + from src.knowledge.db import active_mirror_context + + with pytest.raises(ValueError), active_mirror_context("../invalid"): + pass + + +class TestMirrorInfoInvariants: + """Tests for MirrorInfo construction validation.""" + + def test_empty_community_ids_rejected(self): + """Constructing MirrorInfo with empty community_ids raises ValueError.""" + with pytest.raises(ValueError, match="community_ids must not be empty"): + MirrorInfo( + mirror_id="valid", + community_ids=(), + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + + def test_invalid_mirror_id_at_construction(self): + """Constructing MirrorInfo with path-traversal mirror_id raises ValueError.""" + with pytest.raises(ValueError, match="Invalid mirror ID"): + MirrorInfo( + mirror_id="../etc", + community_ids=("test",), + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + + def test_frozen_dataclass_is_immutable(self): + """MirrorInfo fields cannot be modified after construction.""" + info = MirrorInfo( + mirror_id="test", + community_ids=("testcommunity",), + created_at=datetime.now(UTC), + expires_at=datetime.now(UTC) + timedelta(hours=1), + ) + with pytest.raises(AttributeError): + info.mirror_id = "changed" # type: ignore[misc] + with pytest.raises(AttributeError): + info.size_bytes = 
999 # type: ignore[misc] + + def test_serialization_roundtrip(self): + """MirrorInfo to_dict/from_dict preserves all fields.""" + now = datetime.now(UTC) + original = MirrorInfo( + mirror_id="test123", + community_ids=("hed", "bids"), + created_at=now, + expires_at=now + timedelta(hours=24), + owner_id="user1", + label="my mirror", + ) + data = original.to_dict() + restored = MirrorInfo.from_dict(data) + + assert restored.mirror_id == original.mirror_id + assert restored.community_ids == original.community_ids + assert restored.created_at == original.created_at + assert restored.expires_at == original.expires_at + assert restored.owner_id == original.owner_id + assert restored.label == original.label + # size_bytes is excluded from serialization + assert "size_bytes" not in data + assert restored.size_bytes == 0 + + +class TestTTLClamping: + """Tests for TTL clamping in create_mirror.""" + + def test_ttl_clamped_to_max(self): + """create_mirror clamps TTL to MAX_TTL_HOURS.""" + from src.knowledge.mirror import MAX_TTL_HOURS + + info = create_mirror(community_ids=["testcommunity"], ttl_hours=999) + actual_ttl = (info.expires_at - info.created_at).total_seconds() / 3600 + assert actual_ttl <= MAX_TTL_HOURS + assert actual_ttl == pytest.approx(MAX_TTL_HOURS, abs=0.01) + + +class TestRunSyncNowValidation: + """Tests for run_sync_now input validation.""" + + def test_invalid_sync_type_raises_valueerror(self): + """run_sync_now raises ValueError for unknown sync types.""" + from src.api.scheduler import run_sync_now + + with pytest.raises(ValueError, match="Unknown sync_type"): + run_sync_now("invalid_type") diff --git a/tests/test_metrics/test_cost.py b/tests/test_metrics/test_cost.py index 6286fd0..2efe20c 100644 --- a/tests/test_metrics/test_cost.py +++ b/tests/test_metrics/test_cost.py @@ -50,8 +50,8 @@ def test_qwen_model_cost(self): input_tokens=1_000_000, output_tokens=1_000_000, ) - # input: 0.14, output: 0.34, total: 0.48 - assert cost == 0.48 + # input: 0.07, 
output: 0.10, total: 0.17 + assert cost == 0.17 def test_expensive_model(self): """Verify cost for an expensive model (Claude Opus 4)."""