Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 52 additions & 2 deletions src/handler.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import importlib.util
import logging
import os
import uuid
from pathlib import Path
from typing import Any, Dict, Optional

from logger import setup_logging
from logger import setup_logging, set_request_id, reset_request_id
from unpack_volume import maybe_unpack
from version import format_version_banner

Expand All @@ -21,6 +22,44 @@
logger.info(format_version_banner())


def _extract_request_id(event: Dict[str, Any]) -> str:
"""Extract RunPod job id from event, with safe fallback."""
event_id = event.get("id")
if isinstance(event_id, str) and event_id.strip():
return event_id

job_id = event.get("job_id")
if isinstance(job_id, str) and job_id.strip():
return job_id

job = event.get("job")
if isinstance(job, dict):
nested_job_id = job.get("id")
if isinstance(nested_job_id, str) and nested_job_id.strip():
return nested_job_id

return str(uuid.uuid4())


def _extract_request_id(event: Dict[str, Any]) -> str:
"""Extract RunPod job id from event, with safe fallback."""
event_id = event.get("id")
if isinstance(event_id, str) and event_id.strip():
return event_id

job_id = event.get("job_id")
if isinstance(job_id, str) and job_id.strip():
return job_id

job = event.get("job")
if isinstance(job, dict):
nested_job_id = job.get("id")
if isinstance(nested_job_id, str) and nested_job_id.strip():
return nested_job_id

return str(uuid.uuid4())


def _load_generated_handler() -> Optional[Any]:
"""Load Flash-generated handler if available (deployed QB mode).

Expand Down Expand Up @@ -119,7 +158,15 @@ def _load_generated_handler() -> Optional[Any]:
# Resolve the Flash-generated handler once, at import time (deployed QB mode).
_generated = _load_generated_handler()

if _generated:
    # Bind to a distinct name so the wrapper below is not self-referential.
    generated_handler = _generated

    async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
        """Delegate to the Flash-generated handler with request-id context.

        The RunPod job id is bound to the logging context for the duration
        of the call and always reset afterwards — even when the delegate
        raises — so ids never leak across requests.
        """
        request_id_token = set_request_id(_extract_request_id(event))
        try:
            return await generated_handler(event)
        finally:
            reset_request_id(request_id_token)
else:
# Fallback: original FunctionRequest handler (backward compatible)
from runpod_flash.protos.remote_execution import FunctionRequest, FunctionResponse
Expand All @@ -128,6 +175,7 @@ def _load_generated_handler() -> Optional[Any]:
async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
"""RunPod serverless function handler with dependency installation."""
output: FunctionResponse
request_id_token = set_request_id(_extract_request_id(event))

try:
executor = RemoteExecutor()
Expand All @@ -139,6 +187,8 @@ async def handler(event: Dict[str, Any]) -> Dict[str, Any]:
success=False,
error=f"Error in handler: {str(error)}",
)
finally:
reset_request_id(request_id_token)

return output.model_dump() # type: ignore[no-any-return]

Expand Down
3 changes: 2 additions & 1 deletion src/log_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from collections import deque
from typing import Optional, Deque, Callable

from logger import get_log_format
from logger import ensure_request_id_filter, get_log_format


class LogStreamer:
Expand Down Expand Up @@ -58,6 +58,7 @@ def start_streaming(
# Use same format as main logging
formatter = logging.Formatter(get_log_format(level))
self._handler.setFormatter(formatter)
ensure_request_id_filter(self._handler)

# Add to root logger
root_logger = logging.getLogger()
Expand Down
65 changes: 54 additions & 11 deletions src/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,47 @@
import logging
import os
import sys
from contextvars import ContextVar, Token
from typing import Union, Optional


# Context-local request id for log correlation; "-" means "no request bound".
_REQUEST_ID: ContextVar[str] = ContextVar("request_id", default="-")


class RequestIdFilter(logging.Filter):
    """Inject request_id from context into each log record."""

    def filter(self, record: logging.LogRecord) -> bool:
        # Annotate only — always let the record through.
        record.request_id = _REQUEST_ID.get()
        return True


# One shared filter instance; handlers never need more than one.
_REQUEST_ID_FILTER = RequestIdFilter()


def set_request_id(request_id: Optional[str]) -> "Token[str]":
    """Set request id in log context and return a reset token.

    Blank or missing ids normalize to "-" so format strings never break.
    The return annotation is quoted because ``contextvars.Token`` is not
    subscriptable at runtime on older Python versions.
    """
    normalized = (request_id or "").strip() or "-"
    return _REQUEST_ID.set(normalized)


def reset_request_id(token: "Token[str]") -> None:
    """Reset request id context with the token from set_request_id."""
    _REQUEST_ID.reset(token)


def get_request_id() -> str:
    """Return the request id currently bound to the log context."""
    return _REQUEST_ID.get()


def ensure_request_id_filter(handler: logging.Handler) -> None:
    """Attach the shared RequestIdFilter to *handler* at most once."""
    if not any(isinstance(existing, RequestIdFilter) for existing in handler.filters):
        handler.addFilter(_REQUEST_ID_FILTER)


def get_log_level() -> int:
"""Get log level from environment variable, defaulting to INFO."""
log_level = os.environ.get("LOG_LEVEL", "INFO").upper()
Expand All @@ -19,9 +57,12 @@ def get_log_level() -> int:
def get_log_format(level: int) -> str:
    """Get appropriate log format based on level, matching runpod-flash style.

    Both variants carry the context-local %(request_id)s field; DEBUG adds
    the logger name and source location for easier tracing.
    """
    if level == logging.DEBUG:
        return (
            "%(asctime)s | %(levelname)-5s | %(request_id)s | "
            "%(name)s | %(filename)s:%(lineno)d | %(message)s"
        )
    return "%(asctime)s | %(levelname)-5s | %(request_id)s | %(message)s"


def setup_logging(
Expand All @@ -38,25 +79,27 @@ def setup_logging(
stream: Output stream for logs
fmt: Custom format string (auto-selected based on level if None)
"""
# Determine log level
if level is None:
level = get_log_level()
resolved_level = get_log_level()
elif isinstance(level, str):
level = getattr(logging, level.upper(), logging.INFO)
resolved_level = getattr(logging, level.upper(), logging.INFO)
else:
resolved_level = level

# Determine format based on requested level
if fmt is None:
fmt = get_log_format(level)
fmt = get_log_format(resolved_level)

# Configure root logger
root_logger = logging.getLogger()
root_logger.setLevel(level)
root_logger.setLevel(resolved_level)

if not root_logger.hasHandlers():
handler = logging.StreamHandler(stream)
handler.setFormatter(logging.Formatter(fmt))
ensure_request_id_filter(handler)
root_logger.addHandler(handler)

# When DEBUG is requested, silence the noisy module
if level == logging.DEBUG:
for handler in root_logger.handlers:
ensure_request_id_filter(handler)

if resolved_level == logging.DEBUG:
logging.getLogger("filelock").setLevel(logging.INFO)
51 changes: 50 additions & 1 deletion tests/unit/test_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import base64
import cloudpickle
from unittest.mock import patch, AsyncMock
from handler import handler, _load_generated_handler
from handler import handler, _load_generated_handler, _extract_request_id
from runpod_flash.protos.remote_execution import FunctionResponse


Expand Down Expand Up @@ -115,6 +115,38 @@ async def test_handler_response_serialization(self):
assert "result" in result
assert result["stdout"] == "Test output"

@pytest.mark.asyncio
async def test_handler_uses_runpod_job_id_for_log_context(self):
    """The handler must bind the RunPod job id to the logging context."""
    payload = {
        "function_name": "test_func",
        "function_code": "def test_func(): return 'success'",
        "args": [],
        "kwargs": {},
    }
    event = {"id": "job-123", "input": payload}

    with (
        patch("handler.set_request_id", return_value="token") as set_mock,
        patch("handler.reset_request_id") as reset_mock,
        patch("handler.RemoteExecutor") as executor_cls,
    ):
        executor = AsyncMock()
        executor_cls.return_value = executor
        encoded = base64.b64encode(cloudpickle.dumps("success")).decode("utf-8")
        executor.ExecuteFunction.return_value = FunctionResponse(
            success=True,
            result=encoded,
            stdout="Function executed successfully",
        )

        outcome = await handler(event)

    # Mock call records survive the context manager exit.
    assert outcome["success"] is True
    set_mock.assert_called_once_with("job-123")
    reset_mock.assert_called_once_with("token")

@pytest.mark.asyncio
async def test_handler_class_execution(self):
"""Test handler with class execution request."""
Expand Down Expand Up @@ -254,3 +286,20 @@ def test_returns_none_when_handler_not_callable(self, tmp_path):
result = _load_generated_handler()

assert result is None


class TestExtractRequestId:
    """Tests for request-id extraction in QB handler."""

    def test_extract_uses_event_id(self):
        result = _extract_request_id({"id": "job-main"})
        assert result == "job-main"

    def test_extract_uses_job_id_fallback(self):
        result = _extract_request_id({"job_id": "job-fallback"})
        assert result == "job-fallback"

    def test_extract_uses_nested_job_id_fallback(self):
        result = _extract_request_id({"job": {"id": "job-nested"}})
        assert result == "job-nested"

    def test_extract_generates_uuid_when_missing(self):
        # Pin uuid4 so the fallback path is deterministic.
        with patch("handler.uuid.uuid4", return_value="generated-id"):
            assert _extract_request_id({}) == "generated-id"
49 changes: 49 additions & 0 deletions tests/unit/test_logger.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import logging

from logger import (
RequestIdFilter,
ensure_request_id_filter,
get_log_format,
reset_request_id,
set_request_id,
)


def test_log_format_includes_request_id_for_info():
    """The INFO-level format string must carry the request_id field."""
    assert "%(request_id)s" in get_log_format(logging.INFO)


def test_log_format_includes_request_id_for_debug():
    """The DEBUG-level format string must carry the request_id field."""
    assert "%(request_id)s" in get_log_format(logging.DEBUG)


def test_request_id_filter_injects_context_value():
    """RequestIdFilter must copy the context request id onto the record."""
    record = logging.LogRecord(
        name="test",
        level=logging.INFO,
        pathname=__file__,
        lineno=1,
        msg="hello",
        args=(),
        exc_info=None,
    )
    token = set_request_id("job-abc")
    try:
        # The filter annotates the record and never drops it.
        assert RequestIdFilter().filter(record) is True
        assert record.request_id == "job-abc"
    finally:
        reset_request_id(token)


def test_ensure_request_id_filter_attaches_only_once():
    """Repeated ensure_request_id_filter calls must not duplicate the filter."""
    stream_handler = logging.StreamHandler()

    for _ in range(2):
        ensure_request_id_filter(stream_handler)

    attached = [f for f in stream_handler.filters if isinstance(f, RequestIdFilter)]
    assert len(attached) == 1