|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +import time |
| 4 | +from collections.abc import Awaitable, Callable |
| 5 | +from typing import TYPE_CHECKING, ClassVar |
| 6 | + |
| 7 | +from prometheus_client import Gauge, Histogram |
| 8 | +from starlette.middleware.base import BaseHTTPMiddleware |
| 9 | +from starlette.routing import Match |
| 10 | + |
| 11 | +from archipy.configs.base_config import BaseConfig |
| 12 | +from archipy.helpers.utils.base_utils import BaseUtils |
| 13 | + |
| 14 | +if TYPE_CHECKING: |
| 15 | + from fastapi import Request, Response |
| 16 | + from starlette.types import ASGIApp |
| 17 | + |
| 18 | + |
| 19 | +class FastAPIMetricInterceptor(BaseHTTPMiddleware): |
| 20 | + """A FastAPI interceptor for collecting and reporting metrics using Prometheus. |
| 21 | +
|
| 22 | + This interceptor measures the response time of HTTP requests and records it in a Prometheus histogram. |
| 23 | + It also tracks the number of active requests using a Prometheus gauge. |
| 24 | + The interceptor captures errors and logs them for monitoring purposes. |
| 25 | + """ |
| 26 | + |
| 27 | + ZERO_TO_ONE_SECONDS_BUCKETS: ClassVar[list[float]] = [i / 1000 for i in range(0, 1000, 5)] |
| 28 | + ONE_TO_FIVE_SECONDS_BUCKETS: ClassVar[list[float]] = [i / 100 for i in range(100, 500, 20)] |
| 29 | + FIVE_TO_THIRTY_SECONDS_BUCKETS: ClassVar[list[float]] = [i / 100 for i in range(500, 3000, 50)] |
| 30 | + TOTAL_BUCKETS: ClassVar[list[float]] = ( |
| 31 | + ZERO_TO_ONE_SECONDS_BUCKETS + ONE_TO_FIVE_SECONDS_BUCKETS + FIVE_TO_THIRTY_SECONDS_BUCKETS + [float("inf")] |
| 32 | + ) |
| 33 | + |
| 34 | + RESPONSE_TIME_SECONDS: ClassVar[Histogram] = Histogram( |
| 35 | + "fastapi_response_time_seconds", |
| 36 | + "Time spent processing HTTP request", |
| 37 | + labelnames=("method", "status_code", "path_template"), |
| 38 | + buckets=TOTAL_BUCKETS, |
| 39 | + ) |
| 40 | + |
| 41 | + ACTIVE_REQUESTS: ClassVar[Gauge] = Gauge( |
| 42 | + "fastapi_active_requests", |
| 43 | + "Number of active HTTP requests", |
| 44 | + labelnames=("method", "path_template"), |
| 45 | + ) |
| 46 | + |
| 47 | + _path_template_cache: ClassVar[dict[str, str]] = {} |
| 48 | + |
| 49 | + def __init__(self, app: ASGIApp) -> None: |
| 50 | + """Initialize the FastAPI metric interceptor. |
| 51 | +
|
| 52 | + Args: |
| 53 | + app (ASGIApp): The ASGI application to wrap. |
| 54 | + """ |
| 55 | + super().__init__(app) |
| 56 | + |
| 57 | + async def dispatch(self, request: Request, call_next: Callable[[Request], Awaitable[Response]]) -> Response: |
| 58 | + """Intercept HTTP requests to measure response time and track active requests. |
| 59 | +
|
| 60 | + Args: |
| 61 | + request (Request): The incoming HTTP request. |
| 62 | + call_next (Callable[[Request], Awaitable[Response]]): The next interceptor or endpoint to call. |
| 63 | +
|
| 64 | + Returns: |
| 65 | + Response: The HTTP response from the endpoint. |
| 66 | +
|
| 67 | + Raises: |
| 68 | + Exception: If an exception occurs during request processing, it is captured and re-raised. |
| 69 | + """ |
| 70 | + if not BaseConfig.global_config().PROMETHEUS.IS_ENABLED: |
| 71 | + return await call_next(request) |
| 72 | + |
| 73 | + path_template = self._get_path_template(request) |
| 74 | + method = request.method |
| 75 | + |
| 76 | + self.ACTIVE_REQUESTS.labels(method=method, path_template=path_template).inc() |
| 77 | + |
| 78 | + start_time = time.time() |
| 79 | + status_code = 500 |
| 80 | + |
| 81 | + try: |
| 82 | + response = await call_next(request) |
| 83 | + status_code = response.status_code |
| 84 | + except Exception as exception: |
| 85 | + BaseUtils.capture_exception(exception) |
| 86 | + raise |
| 87 | + else: |
| 88 | + return response |
| 89 | + finally: |
| 90 | + duration = time.time() - start_time |
| 91 | + self.RESPONSE_TIME_SECONDS.labels( |
| 92 | + method=method, |
| 93 | + status_code=status_code, |
| 94 | + path_template=path_template, |
| 95 | + ).observe(duration) |
| 96 | + self.ACTIVE_REQUESTS.labels(method=method, path_template=path_template).dec() |
| 97 | + |
| 98 | + def _get_path_template(self, request: Request) -> str: |
| 99 | + """Extract path template from request by matching against app routes with in-memory caching. |
| 100 | +
|
| 101 | + Args: |
| 102 | + request (Request): The FastAPI request object. |
| 103 | +
|
| 104 | + Returns: |
| 105 | + str: Path template (e.g., /users/{id}) or raw path if no route found. |
| 106 | + """ |
| 107 | + path = request.url.path |
| 108 | + method = request.method |
| 109 | + cache_key = f"{method}:{path}" |
| 110 | + |
| 111 | + if cache_key in self._path_template_cache: |
| 112 | + return self._path_template_cache[cache_key] |
| 113 | + |
| 114 | + for route in request.app.routes: |
| 115 | + match, _ = route.matches(request.scope) |
| 116 | + if match == Match.FULL: |
| 117 | + path_template = route.path |
| 118 | + self._path_template_cache[cache_key] = path_template |
| 119 | + return path_template |
| 120 | + |
| 121 | + self._path_template_cache[cache_key] = path |
| 122 | + return path |
0 commit comments