From c7d378a405f59b95ffbfc2344c06efea824dae57 Mon Sep 17 00:00:00 2001
From: jhcipar
Date: Thu, 12 Mar 2026 09:05:25 -0400
Subject: [PATCH] feat: add request id to logs

---
 src/handler.py             | 35 ++++++++++++++++++++++++++++++++---
 src/log_streamer.py        |  3 +-
 src/logger.py              | 65 +++++++++++++++++++++++++++++++-------
 tests/unit/test_handler.py | 51 +++++++++++++++++++++++++++++++++++-
 tests/unit/test_logger.py  | 49 ++++++++++++++++++++++++++++++++
 5 files changed, 188 insertions(+), 15 deletions(-)
 create mode 100644 tests/unit/test_logger.py

diff --git a/src/handler.py b/src/handler.py
index d51b41c..01f3637 100644
--- a/src/handler.py
+++ b/src/handler.py
@@ -1,10 +1,11 @@
 import importlib.util
 import logging
 import os
+import uuid
 from pathlib import Path
 from typing import Any, Dict, Optional
 
-from logger import setup_logging
+from logger import setup_logging, set_request_id, reset_request_id
 from unpack_volume import maybe_unpack
 from version import format_version_banner
@@ -21,6 +22,25 @@
 logger.info(format_version_banner())
 
 
+def _extract_request_id(event: Dict[str, Any]) -> str:
+    """Extract RunPod job id from event, with safe fallback."""
+    event_id = event.get("id")
+    if isinstance(event_id, str) and event_id.strip():
+        return event_id
+
+    job_id = event.get("job_id")
+    if isinstance(job_id, str) and job_id.strip():
+        return job_id
+
+    job = event.get("job")
+    if isinstance(job, dict):
+        nested_job_id = job.get("id")
+        if isinstance(nested_job_id, str) and nested_job_id.strip():
+            return nested_job_id
+
+    return str(uuid.uuid4())
+
+
 def _load_generated_handler() -> Optional[Any]:
     """Load Flash-generated handler if available (deployed QB mode).
 
@@ -119,7 +139,15 @@
 _generated = _load_generated_handler()
 
 if _generated:
-    handler = _generated
+    generated_handler = _generated
+
+    async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
+        request_id_token = set_request_id(_extract_request_id(event))
+        try:
+            return await generated_handler(event)
+        finally:
+            reset_request_id(request_id_token)
+
 else:
     # Fallback: original FunctionRequest handler (backward compatible)
     from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse
@@ -128,6 +156,7 @@
     async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
         """RunPod serverless function handler with dependency installation."""
         output: FunctionResponse
+        request_id_token = set_request_id(_extract_request_id(event))
 
         try:
             executor = RemoteExecutor()
@@ -139,6 +168,8 @@
                 success=False,
                 error=f"Error in handler: {str(error)}",
             )
+        finally:
+            reset_request_id(request_id_token)
 
         return output.model_dump()  # type: ignore[no-any-return]
 
diff --git a/src/log_streamer.py b/src/log_streamer.py
index 1ec61b8..80c1fe3 100644
--- a/src/log_streamer.py
+++ b/src/log_streamer.py
@@ -11,7 +11,7 @@
 from collections import deque
 from typing import Optional, Deque, Callable
 
-from logger import get_log_format
+from logger import ensure_request_id_filter, get_log_format
 
 
 class LogStreamer:
@@ -58,6 +58,7 @@ def start_streaming(
         # Use same format as main logging
         formatter = logging.Formatter(get_log_format(level))
         self._handler.setFormatter(formatter)
+        ensure_request_id_filter(self._handler)
 
         # Add to root logger
         root_logger = logging.getLogger()
diff --git a/src/logger.py b/src/logger.py
index 9f97000..e939f3b 100644
--- a/src/logger.py
+++ b/src/logger.py
@@ -7,9 +7,47 @@
 import logging
 import os
 import sys
+from contextvars import ContextVar, Token
 from typing import Union, Optional
 
 
+_REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="-")
+
+
+class RequestIdFilter(logging.Filter):
+    """Inject request_id from context into each log record."""
+
+    def filter(self, record: logging.LogRecord) -> bool:
+        record.request_id = _REQUEST_ID.get()
+        return True
+
+
+_REQUEST_ID_FILTER = RequestIdFilter()
+
+
+def set_request_id(request_id: Optional[str]) -> Token[str]:
+    """Set request id in log context and return reset token."""
+    if request_id:
+        normalized = request_id.strip() or "-"
+    else:
+        normalized = "-"
+    return _REQUEST_ID.set(normalized)
+
+
+def reset_request_id(token: Token[str]) -> None:
+    """Reset request id context with token from set_request_id."""
+    _REQUEST_ID.reset(token)
+
+
+def get_request_id() -> str:
+    return _REQUEST_ID.get()
+
+
+def ensure_request_id_filter(handler: logging.Handler) -> None:
+    if not any(isinstance(existing, RequestIdFilter) for existing in handler.filters):
+        handler.addFilter(_REQUEST_ID_FILTER)
+
+
 def get_log_level() -> int:
     """Get log level from environment variable, defaulting to INFO."""
     log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
@@ -19,9 +57,12 @@
 
 
 def get_log_format(level: int) -> str:
     """Get appropriate log format based on level, matching runpod-flash style."""
     if level == logging.DEBUG:
-        return "%(asctime)s | %(levelname)-5s | %(name)s | %(filename)s:%(lineno)d | %(message)s"
+        return (
+            "%(asctime)s | %(levelname)-5s | %(request_id)s | "
+            "%(name)s | %(filename)s:%(lineno)d | %(message)s"
+        )
     else:
-        return "%(asctime)s | %(levelname)-5s | %(message)s"
+        return "%(asctime)s | %(levelname)-5s | %(request_id)s | %(message)s"
 
@@ -38,25 +79,27 @@ def setup_logging(
         stream: Output stream for logs
         fmt: Custom format string (auto-selected based on level if None)
     """
-    # Determine log level
     if level is None:
-        level = get_log_level()
+        resolved_level = get_log_level()
     elif isinstance(level, str):
-        level = getattr(logging, level.upper(), logging.INFO)
+        resolved_level = getattr(logging, level.upper(), logging.INFO)
+    else:
+        resolved_level = level
 
-    # Determine format based on requested level
     if fmt is None:
-        fmt = get_log_format(level)
+        fmt = get_log_format(resolved_level)
 
-    # Configure root logger
     root_logger = logging.getLogger()
-    root_logger.setLevel(level)
+    root_logger.setLevel(resolved_level)
 
     if not root_logger.hasHandlers():
         handler = logging.StreamHandler(stream)
         handler.setFormatter(logging.Formatter(fmt))
+        ensure_request_id_filter(handler)
         root_logger.addHandler(handler)
 
-    # When DEBUG is requested, silence the noisy module
-    if level == logging.DEBUG:
+    for handler in root_logger.handlers:
+        ensure_request_id_filter(handler)
+
+    if resolved_level == logging.DEBUG:
         logging.getLogger("filelock").setLevel(logging.INFO)
diff --git a/tests/unit/test_handler.py b/tests/unit/test_handler.py
index ab63135..351d943 100644
--- a/tests/unit/test_handler.py
+++ b/tests/unit/test_handler.py
@@ -4,7 +4,7 @@
 import base64
 import cloudpickle
 from unittest.mock import patch, AsyncMock
 
-from handler import handler, _load_generated_handler
+from handler import handler, _load_generated_handler, _extract_request_id
 from runpod_flash.protos.remote_execution import FunctionResponse
 
@@ -115,6 +115,38 @@
         assert "result" in result
         assert result["stdout"] == "Test output"
 
+    @pytest.mark.asyncio
+    async def test_handler_uses_runpod_job_id_for_log_context(self):
+        """Handler should set logger request_id using RunPod job id."""
+        event = {
+            "id": "job-123",
+            "input": {
+                "function_name": "test_func",
+                "function_code": "def test_func(): return 'success'",
+                "args": [],
+                "kwargs": {},
+            },
+        }
+
+        with (
+            patch("handler.set_request_id", return_value="token") as mock_set,
+            patch("handler.reset_request_id") as mock_reset,
+            patch("handler.RemoteExecutor") as mock_executor_class,
+        ):
+            mock_executor = AsyncMock()
+            mock_executor_class.return_value = mock_executor
+            mock_executor.ExecuteFunction.return_value = FunctionResponse(
+                success=True,
+                result=base64.b64encode(cloudpickle.dumps("success")).decode("utf-8"),
+                stdout="Function executed successfully",
+            )
+
+            result = await handler(event)
+
+        assert result["success"] is True
+        mock_set.assert_called_once_with("job-123")
+        mock_reset.assert_called_once_with("token")
+
     @pytest.mark.asyncio
     async def test_handler_class_execution(self):
         """Test handler with class execution request."""
@@ -254,3 +286,20 @@ def test_returns_none_when_handler_not_callable(self, tmp_path):
         result = _load_generated_handler()
 
         assert result is None
+
+
+class TestExtractRequestId:
+    """Tests for request-id extraction in QB handler."""
+
+    def test_extract_uses_event_id(self):
+        assert _extract_request_id({"id": "job-main"}) == "job-main"
+
+    def test_extract_uses_job_id_fallback(self):
+        assert _extract_request_id({"job_id": "job-fallback"}) == "job-fallback"
+
+    def test_extract_uses_nested_job_id_fallback(self):
+        assert _extract_request_id({"job": {"id": "job-nested"}}) == "job-nested"
+
+    def test_extract_generates_uuid_when_missing(self):
+        with patch("handler.uuid.uuid4", return_value="generated-id"):
+            assert _extract_request_id({}) == "generated-id"
diff --git a/tests/unit/test_logger.py b/tests/unit/test_logger.py
new file mode 100644
index 0000000..05fa2fc
--- /dev/null
+++ b/tests/unit/test_logger.py
@@ -0,0 +1,49 @@
+import logging
+
+from logger import (
+    RequestIdFilter,
+    ensure_request_id_filter,
+    get_log_format,
+    reset_request_id,
+    set_request_id,
+)
+
+
+def test_log_format_includes_request_id_for_info():
+    fmt = get_log_format(logging.INFO)
+    assert "%(request_id)s" in fmt
+
+
+def test_log_format_includes_request_id_for_debug():
+    fmt = get_log_format(logging.DEBUG)
+    assert "%(request_id)s" in fmt
+
+
+def test_request_id_filter_injects_context_value():
+    log_record = logging.LogRecord(
+        name="test",
+        level=logging.INFO,
+        pathname=__file__,
+        lineno=1,
+        msg="hello",
+        args=(),
+        exc_info=None,
+    )
+    token = set_request_id("job-abc")
+
+    try:
+        request_id_filter = RequestIdFilter()
+        assert request_id_filter.filter(log_record) is True
+        assert log_record.request_id == "job-abc"
+    finally:
+        reset_request_id(token)
+
+
+def test_ensure_request_id_filter_attaches_only_once():
+    handler = logging.StreamHandler()
+
+    ensure_request_id_filter(handler)
+    ensure_request_id_filter(handler)
+
+    request_id_filters = [f for f in handler.filters if isinstance(f, RequestIdFilter)]
+    assert len(request_id_filters) == 1