Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions src/api/routers/community.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,12 +620,19 @@ def _check_model_cost(model: str, key_source: str) -> None:

pricing = MODEL_PRICING.get(model)
if pricing is None:
logger.warning(
"Model %s not in pricing table; allowing without cost check. "
logger.error(
"Model %s not in pricing table; blocking on platform/community key. "
"Add this model to MODEL_PRICING in src/metrics/cost.py.",
model,
)
return
raise HTTPException(
status_code=403,
detail=(
f"Model '{model}' is not in the approved pricing list and cannot be used "
"with platform or community keys. To use this model, provide your own "
"API key via the X-OpenRouter-Key header."
),
)
input_rate = pricing.input_per_1m

if input_rate >= COST_BLOCK_THRESHOLD:
Expand Down
82 changes: 75 additions & 7 deletions src/api/routers/mirrors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@
from pydantic import BaseModel, Field, field_validator

from src.api.security import RequireAuth
from src.core.validation import is_safe_identifier
from src.knowledge.db import active_mirror_context
from src.knowledge.mirror import (
CorruptMirrorError,
MirrorInfo,
create_mirror,
delete_mirror,
get_mirror,
get_mirror_db_path,
is_safe_identifier,
list_mirrors,
refresh_mirror,
)
Expand Down Expand Up @@ -93,6 +94,16 @@ class RefreshMirrorRequest(BaseModel):
default=None, description="Specific communities to refresh, or null for all"
)

@field_validator("community_ids")
@classmethod
def validate_community_ids(cls, v: list[str] | None) -> list[str] | None:
    """Validate community IDs and drop duplicates, preserving order.

    Every ID must pass ``is_safe_identifier`` (path-traversal guard);
    the first offending ID aborts validation with a ``ValueError``.
    ``None`` (meaning "refresh all communities") passes through untouched.
    """
    if v is None:
        return v
    # Validate everything first so the error points at the raw input,
    # then deduplicate while keeping first-seen order.
    for cid in v:
        if not is_safe_identifier(cid):
            raise ValueError(f"Invalid community ID: {cid!r}")
    seen: set[str] = set()
    unique_ids: list[str] = []
    for cid in v:
        if cid not in seen:
            seen.add(cid)
            unique_ids.append(cid)
    return unique_ids


SyncType = Literal["github", "papers", "docstrings", "mailman", "faq", "beps", "all"]

Expand Down Expand Up @@ -135,6 +146,12 @@ async def create_mirror_endpoint(
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except OSError as e:
logger.error("Failed to create mirror: %s", e, exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to create mirror due to a server filesystem error.",
) from e

logger.info(
"Mirror created: %s (communities=%s, owner=%s)",
Expand All @@ -161,7 +178,18 @@ async def get_mirror_endpoint(
_auth: RequireAuth,
) -> MirrorResponse:
"""Get metadata for a specific mirror."""
info = get_mirror(mirror_id)
try:
info = get_mirror(mirror_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid mirror ID format: '{mirror_id}'",
)
except CorruptMirrorError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
)
if not info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -182,6 +210,11 @@ async def delete_mirror_endpoint(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Mirror '{mirror_id}' not found",
)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid mirror ID format: '{mirror_id}'",
)
except OSError as e:
logger.error("Failed to delete mirror %s: %s", mirror_id, e, exc_info=True)
raise HTTPException(
Expand All @@ -205,6 +238,11 @@ async def refresh_mirror_endpoint(
info = refresh_mirror(mirror_id, community_ids=body.community_ids)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
except CorruptMirrorError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
)

logger.info("Mirror refreshed via API: %s", mirror_id)
return MirrorResponse.from_info(info)
Expand All @@ -222,7 +260,18 @@ async def sync_mirror_endpoint(
databases instead of production. Supports sync types: github, papers,
docstrings, mailman, faq, beps, or all.
"""
info = get_mirror(mirror_id)
try:
info = get_mirror(mirror_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid mirror ID format: '{mirror_id}'",
)
except CorruptMirrorError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
)
if not info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand All @@ -235,8 +284,9 @@ async def sync_mirror_endpoint(
)

# Run sync in a thread with the mirror context explicitly copied.
# asyncio.to_thread copies ContextVars on Python 3.12+ but not 3.11,
# so we capture the context and run within it for compatibility.
# We use run_in_executor with an explicit context copy instead of
# asyncio.to_thread because to_thread only copies ContextVars
# automatically on Python 3.12+.
from src.api.scheduler import run_sync_now

ctx = contextvars.copy_context()
Expand All @@ -246,14 +296,21 @@ def _run_sync_in_mirror() -> dict[str, int]:
return run_sync_now(body.sync_type)

try:
results = await asyncio.get_event_loop().run_in_executor(None, ctx.run, _run_sync_in_mirror)
loop = asyncio.get_running_loop()
results = await loop.run_in_executor(None, ctx.run, _run_sync_in_mirror)
total = sum(results.values())
return MirrorSyncResponse(
message=f"Sync completed: {total} items synced into mirror {mirror_id}",
items_synced=results,
)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e)) from e
except OSError as e:
logger.error("Mirror sync I/O error for %s: %s", mirror_id, e, exc_info=True)
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Sync failed due to a filesystem error: {e}",
) from e
except Exception as e:
logger.error("Mirror sync failed for %s: %s", mirror_id, e, exc_info=True)
raise HTTPException(
Expand All @@ -274,7 +331,18 @@ async def download_mirror_db(
"""
from fastapi.responses import FileResponse

info = get_mirror(mirror_id)
try:
info = get_mirror(mirror_id)
except ValueError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Invalid mirror ID format: '{mirror_id}'",
)
except CorruptMirrorError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Mirror '{mirror_id}' has corrupt metadata. Delete and recreate it.",
)
if not info:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
Expand Down
6 changes: 3 additions & 3 deletions src/api/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,14 +292,14 @@ def _run_beps_sync_for_community(community_id: str) -> bool:
def _cleanup_mirrors() -> None:
"""Remove expired ephemeral database mirrors."""
global _mirror_cleanup_failures
try:
from src.knowledge.mirror import cleanup_expired_mirrors
from src.knowledge.mirror import CorruptMirrorError, cleanup_expired_mirrors

try:
deleted = cleanup_expired_mirrors()
if deleted:
logger.info("Mirror cleanup: removed %d expired mirrors", deleted)
_mirror_cleanup_failures = 0
except Exception:
except (OSError, ValueError, CorruptMirrorError):
_mirror_cleanup_failures += 1
logger.error(
"Mirror cleanup failed (consecutive failures: %d)",
Expand Down
26 changes: 17 additions & 9 deletions src/core/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
"""Secure logging configuration with API key redaction.

Provides a custom log formatter that automatically redacts OpenRouter API keys
from log messages to prevent credential exposure in centralized logging systems.
Provides a custom log formatter that automatically redacts API keys
(OpenRouter, Anthropic, OpenAI) from log messages to prevent credential
exposure in centralized logging systems.

Supports both text and JSON-structured logging formats.
"""
Expand All @@ -17,12 +18,19 @@
class SecureFormatter(logging.Formatter):
"""Custom log formatter that redacts API keys from log messages.

Automatically detects and redacts OpenRouter API keys in the format
sk-or-v1-[64 hex chars] to prevent accidental credential exposure.
Automatically detects and redacts API keys from OpenRouter, Anthropic,
and OpenAI to prevent accidental credential exposure.
"""

# Pattern to match OpenRouter API keys: sk-or-v1-[64 hex chars]
API_KEY_PATTERN = re.compile(r"sk-or-v1-[0-9a-f]{64}", re.IGNORECASE)
# Patterns for API keys from various providers.
# IGNORECASE as defense-in-depth; real keys use lowercase hex.
API_KEY_PATTERN = re.compile(
r"sk-or-v1-[0-9a-f]{64}" # OpenRouter: sk-or-v1-[64 hex chars]
r"|sk-ant-[a-zA-Z0-9_-]{80,}" # Anthropic: sk-ant-...
r"|sk-proj-[a-zA-Z0-9_-]{40,}" # OpenAI project keys: sk-proj-...
r"|sk-[a-zA-Z0-9]{48,}", # Generic OpenAI keys: sk-...
re.IGNORECASE,
)

def format(self, record: logging.LogRecord) -> str:
"""Format log record and redact any API keys.
Expand Down Expand Up @@ -55,7 +63,7 @@ def format(self, record: logging.LogRecord) -> str:
if len(formatted) > 100_000: # 100KB limit
formatted = formatted[:100_000] + "... [truncated for safety]"

formatted = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", formatted)
formatted = self.API_KEY_PATTERN.sub("***[key-redacted]", formatted)
except re.error as e:
# Regex pattern is broken - this is a code bug
print(f"CRITICAL: Redaction regex failed: {e}", file=sys.stderr)
Expand Down Expand Up @@ -137,7 +145,7 @@ def format(self, record: logging.LogRecord) -> str:
json_str = json.dumps(log_entry, default=str)

# Redact API keys from the JSON string
json_str = self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", json_str)
json_str = self.API_KEY_PATTERN.sub("***[key-redacted]", json_str)

return json_str

Expand All @@ -154,7 +162,7 @@ def format(self, record: logging.LogRecord) -> str:
"original_message": safe_msg,
}
fallback_json = json.dumps(error_entry)
return self.API_KEY_PATTERN.sub("sk-or-v1-***[redacted]", fallback_json)
return self.API_KEY_PATTERN.sub("***[key-redacted]", fallback_json)
except Exception as e:
# Unexpected errors - surface to stderr and re-raise
print(f"CRITICAL: Unexpected error in SecureJSONFormatter: {e}", file=sys.stderr)
Expand Down
5 changes: 3 additions & 2 deletions src/core/services/litellm_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def create_openrouter_llm(
provider: Specific provider to use (e.g., "Cerebras", "DeepInfra/FP8").
Ignored for Anthropic models, which always use "Anthropic" provider.
user_id: User identifier for cache optimization (sticky routing)
enable_caching: Enable prompt caching. If None (default), enabled for all models.
OpenRouter/LiteLLM gracefully handles models that don't support caching.
enable_caching: Enable prompt caching. If None (default), caching is requested
for all models. Models that do not support caching will ignore the
cache_control markers without error.

Returns:
LLM instance configured for OpenRouter
Expand Down
2 changes: 1 addition & 1 deletion src/core/services/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_model(
model_name: Model name. Supports:
- OpenRouter format: 'creator/model' (e.g., 'openai/gpt-oss-120b', 'qwen/qwen3-235b')
- Direct OpenAI: 'gpt-4o', 'gpt-4o-mini', etc.
- Direct Anthropic: 'claude-3-5-sonnet', etc.
- Direct Anthropic: 'claude-3.5-sonnet', etc.
If not provided, uses settings.default_model.
api_key: Optional API key override (for BYOK).
temperature: Model temperature. If not provided, uses settings.llm_temperature.
Expand Down
13 changes: 13 additions & 0 deletions src/core/validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""Shared input validation utilities.

Provides common validation functions used across modules for
preventing path traversal and ensuring safe identifiers.
"""


def is_safe_identifier(value: str) -> bool:
    """Check if a string is a safe identifier (alphanumeric, hyphens, underscores).

    Used for both mirror IDs and community IDs to prevent path traversal:
    the allowed charset cannot express ``/``, ``\\``, ``.``, or ``..``.

    The identifier must contain at least one ASCII letter or digit;
    hyphens and underscores may appear anywhere.

    Note: ``str.isalnum()`` on its own accepts non-ASCII "alphanumeric"
    code points (accented letters, superscript digits, etc.), which is
    broader than the filesystem-safe charset these IDs are embedded into,
    so the check is explicitly pinned to ASCII.

    Args:
        value: The candidate identifier.

    Returns:
        True if ``value`` consists only of ASCII letters, digits,
        hyphens, and underscores, with at least one letter or digit.
    """
    stripped = value.replace("-", "").replace("_", "")
    # Non-empty remainder guarantees at least one alphanumeric character;
    # isascii() rejects Unicode look-alikes that isalnum() would accept.
    return bool(stripped) and stripped.isascii() and stripped.isalnum()
10 changes: 6 additions & 4 deletions src/knowledge/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,16 @@
from pathlib import Path

from src.cli.config import get_data_dir
from src.knowledge.mirror import _validate_mirror_id, is_safe_identifier
from src.core.validation import is_safe_identifier
from src.knowledge.mirror import _validate_mirror_id

logger = logging.getLogger(__name__)

# ContextVar for transparent mirror routing. When set, get_db_path() returns
# the mirror's database path instead of the production path.
# ContextVar is safe for concurrent async tasks; each task gets its own copy,
# so mirror routing in one request does not affect other requests.
# Safe for concurrent requests because the middleware sets and resets the
# value around each request's lifecycle. Nested async calls within the
# same request inherit the value.
_active_mirror_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"_active_mirror_id", default=None
)
Expand Down Expand Up @@ -448,7 +450,7 @@ def get_db_path(project: str = "hed") -> Path:
"Use only alphanumeric characters, hyphens, and underscores."
)

mirror_id = _active_mirror_id.get()
mirror_id = get_active_mirror()
if mirror_id:
# mirror_id was already validated when set via set_active_mirror()
return get_data_dir() / "mirrors" / mirror_id / f"{project}.db"
Expand Down
Loading