feat(security): Add rate limiting middleware to prevent API abuse #117
**New file** — hunk `@@ -0,0 +1,199 @@`

```python
"""
Rate limiting middleware for FastAPI to prevent brute force attacks and API abuse.

This module provides a simple in-memory rate limiter that can be used to protect
API endpoints from excessive requests.
"""

from __future__ import annotations

import time
from collections import defaultdict
from dataclasses import dataclass, field
from threading import Lock
from typing import Callable, Dict

from fastapi import HTTPException, Request, Response
from starlette.middleware.base import BaseHTTPMiddleware


@dataclass
class RateLimitConfig:
    """Configuration for rate limiting."""

    requests_per_window: int = 10  # Maximum requests per time window
    window_seconds: int = 60       # Time window in seconds
    enabled: bool = True           # Enable/disable rate limiting


@dataclass
class ClientRequestTracker:
    """Track request counts and timestamps for a single client."""

    request_count: int = 0
    window_start: float = field(default_factory=time.time)

    def is_rate_limited(self, config: RateLimitConfig) -> bool:
        """Check if client has exceeded rate limit."""
        current_time = time.time()

        # Reset window if expired
        if current_time - self.window_start >= config.window_seconds:
            self.window_start = current_time
            self.request_count = 0

        # Check if limit exceeded
        if self.request_count >= config.requests_per_window:
            return True

        return False

    def increment(self):
        """Increment request count."""
        self.request_count += 1


class RateLimitMiddleware(BaseHTTPMiddleware):
```
> **Reviewer comment** (on `class RateLimitMiddleware`): The in-memory rate limiter is not effective in the project's default multi-process environment. The deployment configuration specifies multiple workers, but the rate limiter's state is not shared between them, which undermines the feature's goal. A shared store like Redis, which is already in the stack, should be used.
| """ | ||
| Middleware to enforce rate limiting on API requests. | ||
|
|
||
| Tracks requests per IP address and enforces configurable rate limits. | ||
| Uses in-memory storage with periodic cleanup of stale entries. | ||
| """ | ||
|
|
||
| def __init__(self, app, config: RateLimitConfig): | ||
| super().__init__(app) | ||
| self.config = config | ||
| self._trackers: Dict[str, ClientRequestTracker] = defaultdict(ClientRequestTracker) | ||
| self._lock = Lock() | ||
| self._last_cleanup = time.time() | ||
| self._cleanup_interval = 300 # Cleanup every 5 minutes | ||
|
|
||
| def _get_client_identifier(self, request: Request) -> str: | ||
| """ | ||
| Extract client identifier from request. | ||
|
|
||
| Uses X-Forwarded-For header if present (for proxied requests), | ||
| otherwise falls back to client IP. | ||
| """ | ||
| forwarded = request.headers.get("X-Forwarded-For") | ||
| if forwarded: | ||
| # Take the first IP in the chain | ||
| return forwarded.split(",")[0].strip() | ||
|
> **Reviewer comment** (on `_get_client_identifier`): The rate limiter is vulnerable to IP spoofing because it incorrectly parses the `X-Forwarded-For` header.
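The issue the reviewer raises is that the leftmost `X-Forwarded-For` entry is supplied by the client and can be forged. A common mitigation is to count back from the right past the proxies you operate. The helper below is a hypothetical sketch, not part of the PR, and assumes the app sits behind a known number of trusted reverse proxies:

```python
def client_ip_from_xff(xff: str, trusted_proxy_count: int = 1) -> str:
    """Pick the first hop to the left of the trusted proxies.

    Clients can prepend arbitrary values to X-Forwarded-For, so the
    leftmost entry is attacker-controlled. Only the rightmost entries,
    appended by proxies you operate, can be trusted.
    """
    hops = [h.strip() for h in xff.split(",") if h.strip()]
    if len(hops) < trusted_proxy_count + 1:
        # Fewer hops than expected: fall back to the rightmost entry
        return hops[-1] if hops else "unknown"
    return hops[-(trusted_proxy_count + 1)]
```

With one trusted proxy, `"spoofed, 203.0.113.9, 10.0.0.2"` resolves to `203.0.113.9` (the address the proxy actually saw), not the attacker-chosen first entry.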
```python
        if request.client:
            return request.client.host

        return "unknown"

    def _cleanup_stale_trackers(self):
        """Remove trackers that haven't been used recently."""
        current_time = time.time()

        if current_time - self._last_cleanup < self._cleanup_interval:
            return

        with self._lock:
            stale_keys = [
                key for key, tracker in self._trackers.items()
                if current_time - tracker.window_start > self.config.window_seconds * 2
            ]
            for key in stale_keys:
                del self._trackers[key]

        self._last_cleanup = current_time

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Process request and enforce rate limiting.

        Args:
            request: Incoming HTTP request
            call_next: Next middleware/handler in chain

        Returns:
            Response object

        Raises:
            HTTPException: If rate limit is exceeded
        """
        if not self.config.enabled:
            return await call_next(request)

        # Periodic cleanup
        self._cleanup_stale_trackers()

        # Get client identifier
        client_id = self._get_client_identifier(request)

        # Check rate limit
        with self._lock:
            tracker = self._trackers[client_id]

            if tracker.is_rate_limited(self.config):
                # Calculate retry-after time
                time_until_reset = self.config.window_seconds - (
                    time.time() - tracker.window_start
                )

                raise HTTPException(
                    status_code=429,
                    detail={
                        "error": "Rate limit exceeded",
                        "retry_after_seconds": int(time_until_reset) + 1,
                        "limit": self.config.requests_per_window,
                        "window_seconds": self.config.window_seconds,
                    },
                    headers={
                        "Retry-After": str(int(time_until_reset) + 1),
                        "X-RateLimit-Limit": str(self.config.requests_per_window),
                        "X-RateLimit-Remaining": "0",
                        "X-RateLimit-Reset": str(int(tracker.window_start + self.config.window_seconds)),
                    },
                )

            # Increment request count
            tracker.increment()

            # Calculate remaining requests
            remaining = self.config.requests_per_window - tracker.request_count

        # Process request
        response = await call_next(request)

        # Add rate limit headers to response
        response.headers["X-RateLimit-Limit"] = str(self.config.requests_per_window)
        response.headers["X-RateLimit-Remaining"] = str(max(0, remaining))
        response.headers["X-RateLimit-Reset"] = str(
            int(tracker.window_start + self.config.window_seconds)
        )

        return response


def create_rate_limiter(
    requests_per_window: int = 100,
    window_seconds: int = 60,
    enabled: bool = True,
) -> Callable[..., RateLimitMiddleware]:
    """
    Factory function to create a rate limiter middleware.

    Args:
        requests_per_window: Maximum requests allowed per time window
        window_seconds: Time window duration in seconds
        enabled: Whether rate limiting is enabled

    Returns:
        Factory that builds a configured RateLimitMiddleware instance
    """
    config = RateLimitConfig(
        requests_per_window=requests_per_window,
        window_seconds=window_seconds,
        enabled=enabled,
    )

    def middleware_factory(app):
        return RateLimitMiddleware(app, config)

    return middleware_factory
```
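A trade-off of the fixed-window scheme used by `ClientRequestTracker`, worth noting alongside the review: a client can concentrate requests around a window boundary and briefly achieve twice the configured rate. The standalone toy counter below (an illustration, not code from the PR) mirrors the tracker's reset-on-expiry logic:

```python
class Window:
    """Minimal fixed-window counter mirroring ClientRequestTracker."""

    def __init__(self, limit: int, window: float):
        self.limit, self.window = limit, window
        self.start, self.count = 0.0, 0

    def allow(self, now: float) -> bool:
        # Reset the window once it has expired
        if now - self.start >= self.window:
            self.start, self.count = now, 0
        if self.count >= self.limit:
            return False
        self.count += 1
        return True


w = Window(limit=10, window=60.0)
# 10 requests just before the window rolls over, 10 just after:
late = sum(w.allow(59.9) for _ in range(10))   # all 10 allowed
early = sum(w.allow(60.0) for _ in range(10))  # window resets: 10 more allowed
# 20 requests accepted within 0.1 s despite the 10-per-minute limit
```

A sliding-window or token-bucket algorithm smooths this burst out, at the cost of tracking more state per client.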
**Changed file** — hunk `@@ -60,16 +60,31 @@ def _parse_overlay(text: str) -> Dict[str, Any]:`

```diff
 def _deep_merge(
     base: MutableMapping[str, Any], overrides: Mapping[str, Any]
 ) -> MutableMapping[str, Any]:
+    """
+    Deep merge two dictionaries, returning a new dictionary without mutating the base.
+
+    Args:
+        base: Base configuration dictionary (not modified)
+        overrides: Override values to merge in
+
+    Returns:
+        New dictionary with merged values
+    """
+    import copy
+
+    # Create a deep copy to avoid mutating the base dictionary
+    result = copy.deepcopy(base)
```
> **Reviewer comment:** Returning a deep-copied result here stops `_deep_merge` from mutating the provided base mapping; callers like `load_overlay()` still expect in-place merging, so profile overrides are no longer applied and overlay configuration breaks.
```diff
     for key, value in overrides.items():
         if (
-            key in base
-            and isinstance(base[key], MutableMapping)
+            key in result
+            and isinstance(result[key], MutableMapping)
             and isinstance(value, Mapping)
         ):
-            base[key] = _deep_merge(base[key], value)  # type: ignore[assignment]
+            result[key] = _deep_merge(result[key], value)  # type: ignore[assignment]
         else:
-            base[key] = value  # type: ignore[assignment]
-    return base
+            result[key] = copy.deepcopy(value)  # type: ignore[assignment]
+    return result
```
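The behavior change the reviewer describes is easy to demonstrate. The snippet below re-implements the PR's new `_deep_merge` standalone for illustration (`load_overlay()` is the caller named in the review, not reproduced here):

```python
import copy
from collections.abc import Mapping, MutableMapping


def deep_merge(base, overrides):
    """Non-mutating deep merge, as in the PR's new version."""
    result = copy.deepcopy(base)
    for key, value in overrides.items():
        if (
            key in result
            and isinstance(result[key], MutableMapping)
            and isinstance(value, Mapping)
        ):
            result[key] = deep_merge(result[key], value)
        else:
            result[key] = copy.deepcopy(value)
    return result


base = {"db": {"host": "localhost", "port": 5432}}
merged = deep_merge(base, {"db": {"port": 6543}})
# merged carries the override, but `base` is untouched -- a caller
# that ignores the return value and relies on in-place mutation
# (as the reviewer says load_overlay() does) sees no change at all.
```

Either the callers must be updated to use the return value, or the function must keep mutating `base`; mixing the two contracts silently drops overrides.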
```diff
 _DEFAULT_GUARDRAIL_MATURITY = "scaling"
```
**Check failure — Code scanning / CodeQL: Clear-text storage of sensitive information (High)**

Copilot Autofix (7 months ago):

The best way to fix this problem is to ensure that the JWT secret, if persisted to disk, is stored encrypted rather than as cleartext. The recommended approach is to encrypt the secret before writing it out, using a key not persisted with the secret (ideally, sourced from environment variables or a secure store like a vault). If that's not possible, the demo-mode secret should at least be encrypted using a local key derived at runtime (e.g., from a password or entropy unique to the local host).

Detailed steps for this fix:
- Use the `cryptography` module's Fernet symmetric encryption (`cryptography.fernet`) for strong, simple encryption.
- File/region to change: `_load_or_generate_jwt_secret`.
- Requirements: `cryptography.fernet`.
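A minimal sketch of the suggested Fernet approach, assuming the `cryptography` package is available. The helper names and the idea of a separate `wrapping_key` (sourced from an environment variable or vault, never stored next to the secret) are illustrative assumptions, not code from this repository:

```python
from cryptography.fernet import Fernet


def store_secret_encrypted(secret: bytes, wrapping_key: bytes) -> bytes:
    """Encrypt the JWT secret before persisting it to disk.

    wrapping_key must come from the environment or a vault,
    never from the same file as the encrypted secret.
    """
    return Fernet(wrapping_key).encrypt(secret)


def load_secret(token: bytes, wrapping_key: bytes) -> bytes:
    """Decrypt a previously stored secret."""
    return Fernet(wrapping_key).decrypt(token)
```

`Fernet.generate_key()` produces a suitable wrapping key; decrypting with the wrong key raises `InvalidToken`, so a misconfigured deployment fails loudly instead of silently using a bad secret.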