diff --git a/apps/api/app.py b/apps/api/app.py
index 00061393c..cce4c0137 100644
--- a/apps/api/app.py
+++ b/apps/api/app.py
@@ -47,6 +47,7 @@
     NormalizedVEX,
 )
 from .pipeline import PipelineOrchestrator
+from .rate_limiter import create_rate_limiter
 from .routes.enhanced import router as enhanced_router
 from .upload_manager import ChunkUploadManager
 
@@ -54,18 +55,59 @@
 JWT_ALGORITHM = "HS256"
 JWT_EXP_MINUTES = int(os.getenv("FIXOPS_JWT_EXP_MINUTES", "120"))
 
-
-_jwt_secret_env = os.getenv("FIXOPS_JWT_SECRET")
-if _jwt_secret_env:
-    JWT_SECRET = _jwt_secret_env
-else:
+_JWT_SECRET_FILE = Path(os.getenv("FIXOPS_DATA_DIR", ".fixops_data")) / ".jwt_secret"
+
+
+def _load_or_generate_jwt_secret() -> str:
+    """
+    Load the JWT secret from the environment or a file, or generate and persist a new one.
+
+    Priority:
+    1. FIXOPS_JWT_SECRET environment variable
+    2. Persisted secret file
+    3. Generate a new secret and persist it to file (demo mode only)
+
+    Returns:
+        str: The JWT secret key
+
+    Raises:
+        ValueError: If no secret is available in non-demo mode
+    """
+    # Priority 1: Environment variable
+    env_secret = os.getenv("FIXOPS_JWT_SECRET")
+    if env_secret:
+        logger.info("Using JWT secret from FIXOPS_JWT_SECRET environment variable")
+        return env_secret
+
+    # Priority 2: Persisted file
+    try:
+        _JWT_SECRET_FILE.parent.mkdir(parents=True, exist_ok=True)
+        if _JWT_SECRET_FILE.exists():
+            secret = _JWT_SECRET_FILE.read_text().strip()
+            if secret:
+                logger.info(f"Loaded persisted JWT secret from {_JWT_SECRET_FILE}")
+                return secret
+    except Exception as e:
+        logger.warning(f"Failed to read JWT secret file: {e}")
+
+    # Priority 3: Generate and persist (demo mode only)
     mode = os.getenv("FIXOPS_MODE", "").lower()
     if mode == "demo":
-        JWT_SECRET = secrets.token_hex(32)
-        logger.warning(
-            "JWT_SECRET not set - using auto-generated secret. "
-            "Tokens will be invalid after restart. Set FIXOPS_JWT_SECRET for persistence."
-        )
+        secret = secrets.token_hex(32)
+        try:
+            _JWT_SECRET_FILE.touch(mode=0o600, exist_ok=True)  # owner-only perms before writing
+            _JWT_SECRET_FILE.write_text(secret)
+            logger.warning(
+                f"Generated and persisted new JWT secret to {_JWT_SECRET_FILE}. "
+                "For production, set the FIXOPS_JWT_SECRET environment variable."
+            )
+            return secret
+        except Exception as e:
+            logger.error(f"Failed to persist JWT secret: {e}")
+            logger.warning(
+                "Using a non-persisted secret. Tokens will be invalid after restart."
+            )
+            return secret
     else:
         raise ValueError(
             "FIXOPS_JWT_SECRET environment variable must be set in non-demo mode. "
" @@ -73,6 +115,9 @@ ) +JWT_SECRET = _load_or_generate_jwt_secret() + + def generate_access_token(data: Dict[str, Any]) -> str: """Generate a signed JWT access token with an expiry.""" @@ -114,6 +159,19 @@ def create_app() -> FastAPI: allow_methods=["*"], allow_headers=["*"], ) + + # Add rate limiting middleware + rate_limit_enabled = os.getenv("FIXOPS_RATE_LIMIT_ENABLED", "true").lower() == "true" + rate_limit_requests = int(os.getenv("FIXOPS_RATE_LIMIT_REQUESTS", "100")) + rate_limit_window = int(os.getenv("FIXOPS_RATE_LIMIT_WINDOW_SECONDS", "60")) + + app.add_middleware( + create_rate_limiter( + requests_per_window=rate_limit_requests, + window_seconds=rate_limit_window, + enabled=rate_limit_enabled + ) + ) normalizer = InputNormalizer() orchestrator = PipelineOrchestrator() @@ -688,6 +746,14 @@ async def upload_chunk( raise HTTPException( status_code=404, detail=f"Stage '{stage}' not recognised" ) + + # Validate offset parameter + if offset is not None and offset < 0: + raise HTTPException( + status_code=400, + detail=f"Invalid offset: {offset}. Offset must be non-negative." + ) + data = await chunk.read() try: session = upload_manager.append_chunk(session_id, data, offset=offset) diff --git a/apps/api/rate_limiter.py b/apps/api/rate_limiter.py new file mode 100644 index 000000000..cca0b5260 --- /dev/null +++ b/apps/api/rate_limiter.py @@ -0,0 +1,199 @@ +""" +Rate limiting middleware for FastAPI to prevent brute force attacks and API abuse. + +This module provides a simple in-memory rate limiter that can be used to protect +API endpoints from excessive requests. +""" + +from __future__ import annotations + +import time +from collections import defaultdict +from dataclasses import dataclass, field +from threading import Lock +from typing import Callable, Dict, Tuple + +from fastapi import HTTPException, Request, Response +from starlette.middleware.base import BaseHTTPMiddleware + + +@dataclass +class RateLimitConfig: + """Configuration for rate limiting.""" + + requests_per_window: int = 10 # Maximum requests per time window + window_seconds: int = 60 # Time window in seconds + enabled: bool = True # Enable/disable rate limiting + + +@dataclass +class ClientRequestTracker: + """Track request counts and timestamps for a single client.""" + + request_count: int = 0 + window_start: float = field(default_factory=time.time) + + def is_rate_limited(self, config: RateLimitConfig) -> bool: + """Check if client has exceeded rate limit.""" + current_time = time.time() + + # Reset window if expired + if current_time - self.window_start >= config.window_seconds: + self.window_start = current_time + self.request_count = 0 + + # Check if limit exceeded + if self.request_count >= config.requests_per_window: + return True + + return False + + def increment(self): + """Increment request count.""" + self.request_count += 1 + + +class RateLimitMiddleware(BaseHTTPMiddleware): + """ + Middleware to enforce rate limiting on API requests. + + Tracks requests per IP address and enforces configurable rate limits. + Uses in-memory storage with periodic cleanup of stale entries. + """ + + def __init__(self, app, config: RateLimitConfig): + super().__init__(app) + self.config = config + self._trackers: Dict[str, ClientRequestTracker] = defaultdict(ClientRequestTracker) + self._lock = Lock() + self._last_cleanup = time.time() + self._cleanup_interval = 300 # Cleanup every 5 minutes + + def _get_client_identifier(self, request: Request) -> str: + """ + Extract client identifier from request. 
+
+        Uses the X-Forwarded-For header if present (for proxied requests),
+        otherwise falls back to the client IP.
+        """
+        forwarded = request.headers.get("X-Forwarded-For")
+        if forwarded:
+            # Take the first IP in the chain
+            return forwarded.split(",")[0].strip()
+
+        if request.client:
+            return request.client.host
+
+        return "unknown"
+
+    def _cleanup_stale_trackers(self) -> None:
+        """Remove trackers that haven't been used recently."""
+        current_time = time.time()
+
+        if current_time - self._last_cleanup < self._cleanup_interval:
+            return
+
+        with self._lock:
+            stale_keys = [
+                key for key, tracker in self._trackers.items()
+                if current_time - tracker.window_start > self.config.window_seconds * 2
+            ]
+            for key in stale_keys:
+                del self._trackers[key]
+
+            self._last_cleanup = current_time
+
+    async def dispatch(self, request: Request, call_next: Callable) -> Response:
+        """
+        Process the request and enforce rate limiting.
+
+        Args:
+            request: Incoming HTTP request
+            call_next: Next middleware/handler in the chain
+
+        Returns:
+            The downstream response, or a 429 JSON response if the rate
+            limit is exceeded
+        """
+        if not self.config.enabled:
+            return await call_next(request)
+
+        # Periodic cleanup
+        self._cleanup_stale_trackers()
+
+        # Get the client identifier
+        client_id = self._get_client_identifier(request)
+
+        # Check the rate limit
+        with self._lock:
+            tracker = self._trackers[client_id]
+
+            if tracker.is_rate_limited(self.config):
+                # Calculate when the current window resets
+                time_until_reset = self.config.window_seconds - (
+                    time.time() - tracker.window_start
+                )
+
+                # Return the 429 directly: an HTTPException raised inside
+                # BaseHTTPMiddleware bypasses FastAPI's exception handlers,
+                # which sit deeper in the middleware stack, and would surface
+                # as a 500 instead.
+                return JSONResponse(
+                    status_code=429,
+                    content={
+                        "error": "Rate limit exceeded",
+                        "retry_after_seconds": int(time_until_reset) + 1,
+                        "limit": self.config.requests_per_window,
+                        "window_seconds": self.config.window_seconds,
+                    },
+                    headers={
+                        "Retry-After": str(int(time_until_reset) + 1),
+                        "X-RateLimit-Limit": str(self.config.requests_per_window),
+                        "X-RateLimit-Remaining": "0",
+                        "X-RateLimit-Reset": str(
+                            int(tracker.window_start + self.config.window_seconds)
+                        ),
+                    },
+                )
+
+            # Count this request
+            tracker.increment()
+
+            # Snapshot values under the lock for the response headers
+            remaining = self.config.requests_per_window - tracker.request_count
+            window_start = tracker.window_start
+
+        # Process the request outside the lock
+        response = await call_next(request)
+
+        # Add rate limit headers to the response
+        response.headers["X-RateLimit-Limit"] = str(self.config.requests_per_window)
+        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
+        response.headers["X-RateLimit-Reset"] = str(
+            int(window_start + self.config.window_seconds)
+        )
+
+        return response
+
+
+def create_rate_limiter(
+    requests_per_window: int = 100,
+    window_seconds: int = 60,
+    enabled: bool = True
+) -> Callable[..., RateLimitMiddleware]:
+    """
+    Factory function to create a rate limiter middleware.
+
+    Args:
+        requests_per_window: Maximum requests allowed per time window
+        window_seconds: Time window duration in seconds
+        enabled: Whether rate limiting is enabled
+
+    Returns:
+        A factory that Starlette calls with the ASGI app to build the
+        configured RateLimitMiddleware
+    """
+    config = RateLimitConfig(
+        requests_per_window=requests_per_window,
+        window_seconds=window_seconds,
+        enabled=enabled
+    )
+
+    def middleware_factory(app):
+        return RateLimitMiddleware(app, config)
+
+    return middleware_factory
diff --git a/core/configuration.py b/core/configuration.py
index 191f41d1d..a1a56d324 100644
--- a/core/configuration.py
+++ b/core/configuration.py
@@ -60,16 +60,31 @@ def _parse_overlay(text: str) -> Dict[str, Any]:
 def _deep_merge(
     base: MutableMapping[str, Any], overrides: Mapping[str, Any]
 ) -> MutableMapping[str, Any]:
+    """
+    Deep-merge two mappings, returning a new dictionary without mutating the base.
+
+    Args:
+        base: Base configuration mapping (not modified)
+        overrides: Override values to merge in
+
+    Returns:
+        New dictionary with the merged values
+    """
+    import copy
+
+    # Work on a deep copy so the caller's base mapping is never mutated
+    result = copy.deepcopy(base)
+
    for key, value in overrides.items():
         if (
-            key in base
-            and isinstance(base[key], MutableMapping)
+            key in result
+            and isinstance(result[key], MutableMapping)
             and isinstance(value, Mapping)
         ):
-            base[key] = _deep_merge(base[key], value)  # type: ignore[assignment]
+            result[key] = _deep_merge(result[key], value)  # type: ignore[assignment]
         else:
-            base[key] = value  # type: ignore[assignment]
-    return base
+            result[key] = copy.deepcopy(value)  # type: ignore[assignment]
+    return result
 
 
 _DEFAULT_GUARDRAIL_MATURITY = "scaling"
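
A few quick ways to sanity-check this patch outside the test suite. First, the secret loader's persistence path: the sketch below assumes demo mode and a throwaway data directory (the /tmp path is made up). Since _JWT_SECRET_FILE is resolved at import time, the environment must be set before apps.api.app is imported.

import os

os.environ["FIXOPS_MODE"] = "demo"
os.environ["FIXOPS_DATA_DIR"] = "/tmp/fixops-demo"  # hypothetical data dir
os.environ.pop("FIXOPS_JWT_SECRET", None)

# Import only after the environment is set: _JWT_SECRET_FILE is a module-level Path.
from apps.api.app import _load_or_generate_jwt_secret

first = _load_or_generate_jwt_secret()
second = _load_or_generate_jwt_secret()
assert first == second  # the secret was persisted to .jwt_secret, so it survives reloads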
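
Second, the rate limiter can be smoke-tested end to end with FastAPI's TestClient. The /ping endpoint and the tiny two-request window are invented for the demonstration; the factory and the response headers are the ones defined in apps/api/rate_limiter.py above.

from fastapi import FastAPI
from fastapi.testclient import TestClient

from apps.api.rate_limiter import create_rate_limiter

app = FastAPI()
# Two requests per window, so the third call should be rejected.
app.add_middleware(create_rate_limiter(requests_per_window=2, window_seconds=60))


@app.get("/ping")  # hypothetical endpoint, purely for the demo
def ping() -> dict:
    return {"ok": True}


client = TestClient(app)
assert client.get("/ping").status_code == 200
assert client.get("/ping").status_code == 200

third = client.get("/ping")
assert third.status_code == 429
assert "Retry-After" in third.headers
assert third.headers["X-RateLimit-Remaining"] == "0"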
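
Finally, the non-mutating contract that the reworked _deep_merge now documents can be exercised with a small example (sample data invented):

from core.configuration import _deep_merge

base = {"auth": {"mode": "demo", "ttl": 120}, "features": ["sbom"]}
overrides = {"auth": {"mode": "enterprise"}}

merged = _deep_merge(base, overrides)

assert merged["auth"] == {"mode": "enterprise", "ttl": 120}  # nested keys merged
assert base["auth"]["mode"] == "demo"  # the base mapping is left untouched
assert merged["features"] is not base["features"]  # values are deep-copied, not aliased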