diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py
index f28413d04..0372659c2 100644
--- a/src/vllm_router/app.py
+++ b/src/vllm_router/app.py
@@ -192,10 +192,13 @@ def initialize_all(app: FastAPI, args):
     initialize_routing_logic(
         args.routing_logic,
         session_key=args.session_key,
+        tolerate_waiting_requests=args.tolerate_waiting_requests,
         lmcache_controller_port=args.lmcache_controller_port,
         prefill_model_labels=args.prefill_model_labels,
         decode_model_labels=args.decode_model_labels,
         kv_aware_threshold=args.kv_aware_threshold,
+        enable_request_logging=args.enable_request_logging,
+        request_log_dir=args.request_log_dir,
     )
 
     # Initialize feature gates
diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py
index 21ad2c9e1..26337f1af 100644
--- a/src/vllm_router/parsers/parser.py
+++ b/src/vllm_router/parsers/parser.py
@@ -99,6 +99,10 @@ def validate_args(args):
         raise ValueError(
             "Session key must be provided when using session routing logic."
         )
+    if args.routing_logic == "cache_aware_load_balancing" and args.session_key is None:
+        raise ValueError(
+            "Session key must be provided when using cache_aware_load_balancing routing logic."
+        )
     if args.log_stats and args.log_stats_interval <= 0:
         raise ValueError("Log stats interval must be greater than 0.")
     if args.engine_stats_interval <= 0:
@@ -180,12 +184,34 @@ def parse_args():
         choices=[
             "roundrobin",
             "session",
+            "cache_aware_load_balancing",
             "kvaware",
             "prefixaware",
             "disaggregated_prefill",
         ],
         help="The routing logic to use",
     )
+
+    parser.add_argument(
+        "--tolerate-waiting-requests",
+        type=int,
+        default=10,
+        help="The number of waiting requests to tolerate in the cache-aware load balancing router.",
+    )
+
+    parser.add_argument(
+        "--enable-request-logging",
+        action="store_true",
+        help="Enable request logging to record the routing decision and performance data of each request.",
+    )
+
+    parser.add_argument(
+        "--request-log-dir",
+        type=str,
+        default=None,
+        help="The directory to store the request log files; if not provided, the current working directory is used.",
+    )
+
     parser.add_argument(
         "--lmcache-controller-port",
         type=int,
diff --git a/src/vllm_router/request_logger.py b/src/vllm_router/request_logger.py
new file mode 100644
index 000000000..664eea420
--- /dev/null
+++ b/src/vllm_router/request_logger.py
@@ -0,0 +1,150 @@
+# Copyright 2024-2025 The vLLM Production Stack Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import time
+from datetime import datetime
+
+from fastapi import Request
+
+from vllm_router.log import init_logger
+
+logger = init_logger(__name__)
+
+
+class RequestLogger:
+    """Request logger for recording routing decisions and performance data"""
+
+    def __init__(self):
+        """Initialize request logger"""
+        self.log_enabled = False
+        self.log_file = None
+        self.requests_bodies_dir = None
+        # Dictionary for temporarily storing request arrival times
+        self.arrival_times = {}
+
+    def enable_logging(self, enabled: bool = True, log_dir: str = None):
+        """
+        Enable or disable logging
+
+        Args:
+            enabled: Whether to enable logging
+            log_dir: Log directory path; if provided, logs are written there, otherwise the current working directory is used
+        """
+        self.log_enabled = enabled
+        self.log_file = None  # Reset log_file
+
+        if enabled:
+            # If log_dir is not specified, use the current working directory
+            if not log_dir:
+                log_dir = os.getcwd()
+                logger.info(
+                    f"No log directory specified, will use current working directory: {log_dir}"
+                )
+
+            try:
+                # Ensure the log directory exists
+                os.makedirs(log_dir, exist_ok=True)
+
+                # Create the request body storage directory
+                self.requests_bodies_dir = os.path.join(log_dir, "request_bodies")
+                os.makedirs(self.requests_bodies_dir, exist_ok=True)
+
+                # Create the log file name with a timestamp
+                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+                log_file = os.path.join(
+                    log_dir, f"router_requests_logs_{timestamp}.csv"
+                )
+
+                # Create the log file and write the CSV header
+                with open(log_file, "w") as f:
+                    f.write(
+                        "timestamp,request_id,conversation_id,arrival_time,routing_method,target_engine,process_time\n"
+                    )
+
+                self.log_file = log_file
+                logger.info(f"Request routing logs will be written to file: {log_file}")
+            except Exception as e:
+                logger.error(f"Failed to create log file: {str(e)}")
+                self.log_file = None
+
+            logger.info("Request routing logging has been enabled")
+
+    def log_request_routed(
+        self,
+        arrival_time: float,
+        request_id: str,
+        routing_method: str,
+        target_engine: str,
+        session_id: str = None,
+        process_time: float = None,
+    ):
+        """Record the request routing decision and timestamp, and write them to file (if enabled)"""
+        if not self.log_enabled or not self.log_file:
+            return
+
+        # Ensure session_id has a value
+        session_id = session_id or "unknown"
+
+        # Write to file
+        try:
+            timestamp = time.strftime("%Y-%m-%d %H:%M:%S")
+            log_line = f"{timestamp},{request_id},{session_id},{arrival_time},{routing_method},{target_engine},{process_time}\n"
+
+            with open(self.log_file, "a") as f:
+                f.write(log_line)
+
+        except Exception as e:
+            logger.error(f"Failed to write to log file: {str(e)}")
+
+    def log_request_body(self, request_body, request_id=None):
+        """
+        Log the request body to a separate file
+
+        Args:
+            request_body: Request body content
+            request_id: Request ID; if None, try to extract it from the request body
+        """
+        if not self.log_enabled or not self.requests_bodies_dir:
+            return
+
+        if not request_id:
+            # Try to extract request_id from the request body
+            try:
+                body_json = json.loads(request_body)
+                request_id = body_json.get("id", str(int(time.time())))
+            except Exception:
+                request_id = str(int(time.time()))
+
+        # Create the file name
+        file_path = os.path.join(self.requests_bodies_dir, f"{request_id}.json")
+
+        # Write the request body to the file
+        try:
+            with open(file_path, "wb") as f:
+                if isinstance(request_body, bytes):
+                    f.write(request_body)
+                else:
+                    f.write(request_body.encode("utf-8"))
+        except Exception as e:
+            logger.error(f"Failed to write request body file: {str(e)}")
+
+    def clear_logs(self):
+        """Clear temporarily stored arrival times"""
+        self.arrival_times.clear()
+
+
+# Create global request logger instance
+request_logger = RequestLogger()
diff --git a/src/vllm_router/routers/metrics_router.py b/src/vllm_router/routers/metrics_router.py
index 276023949..491c4977a 100644
--- a/src/vllm_router/routers/metrics_router.py
+++ b/src/vllm_router/routers/metrics_router.py
@@ -32,6 +32,7 @@
     num_prefill_requests,
     num_requests_running,
     num_requests_swapped,
+    routing_method_qps,
 )
 from vllm_router.stats.engine_stats import get_engine_stats_scraper
 from vllm_router.stats.request_stats import get_request_stats_monitor
@@ -99,6 +100,27 @@ async def metrics():
         avg_itl.labels(server=server).set(stat.avg_itl)
         num_requests_swapped.labels(server=server).set(stat.num_swapped_requests)
 
+    # -----------------------------------------------------------------------------
+
+    # Routing method QPS
+    routing_methods_qps = get_request_stats_monitor().get_routing_methods_qps()
+    # Save all known routing methods, for resetting
+    known_methods = set()
+    # Find registered routing method labels
+    for labels, _ in routing_method_qps._metrics.items():
+        if labels and len(labels) > 0:
+            method = labels[0]  # Assuming labels is (method,) format
+            known_methods.add(method)
+    # Reset QPS of all known but currently inactive routing methods to 0
+    for method in known_methods:
+        if method not in routing_methods_qps:
+            routing_method_qps.labels(method=method).set(0)
+    # Set QPS of currently active routing methods
+    for method, qps_value in routing_methods_qps.items():
+        routing_method_qps.labels(method=method).set(qps_value)
+
+    # -----------------------------------------------------------------------------
+
     # Engine statistics (GPU prefix cache metrics)
     engine_stats = get_engine_stats_scraper().get_engine_stats()
     for server, engine_stat in engine_stats.items():
diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py
index ef544c9c6..c6f62c547 100644
--- a/src/vllm_router/routers/routing_logic.py
+++ b/src/vllm_router/routers/routing_logic.py
@@ -20,7 +20,7 @@
 import random
 import socket
 import threading
-from typing import Dict, List
+from typing import Dict, List, Tuple
 
 import requests
 from fastapi import Request
@@ -53,6 +53,7 @@ class RoutingLogic(str, enum.Enum):
     ROUND_ROBIN = "roundrobin"
     SESSION_BASED = "session"
+    CACHE_AWARE_LOAD_BALANCING = "cache_aware_load_balancing"
     KVAWARE = "kvaware"
     PREFIXAWARE = "prefixaware"
     DISAGGREGATED_PREFILL = "disaggregated_prefill"
 
@@ -88,6 +89,7 @@ def _update_hash_ring(self, endpoints: List["EndpointInfo"]):
         """
         Update the hash ring with the current list of endpoints.
""" + # logger.debug(f"Updating hash ring with endpoints: {endpoints}") # Extract endpoint URLs endpoint_urls = [endpoint.url for endpoint in endpoints] @@ -158,7 +160,8 @@ def route_request( len_engines = len(endpoints) chosen = sorted(endpoints, key=lambda e: e.url)[self.req_id % len_engines] self.req_id += 1 - return chosen.url + routing_method = "round_robin" + return chosen.url, routing_method class SessionRouter(RoutingInterface): @@ -210,7 +213,203 @@ def route_request( # Use the hash ring to get the endpoint for the session ID url = self.hash_ring.get_node(session_id) - return url + # If the initial engine is not found in engine_stats + if url not in engine_stats: + logger.warning( + f"Engine {url} not found in engine_stats" + ) + + routing_method = "session_based" + return url, routing_method + + +class CacheAwareLoadBalancingRouter(RoutingInterface): + """ + Routing algorithm that combines load balancing with KV Cache hit rate awareness + + This algorithm considers three key factors: + 1. Engine load (number of queuing requests, number of running requests) + 2. Estimated KV cache hit rate (for specific sessions) + """ + + def __init__(self, session_key: str = None, tolerate_waiting_requests: int = 20): + if hasattr(self, "_initialized"): + return + + if session_key is None: + raise ValueError( + "CacheAwareLoadBalancingRouter must be initialized with a session_key" + ) + + self.session_key = session_key + self.tolerate_waiting_requests = tolerate_waiting_requests + + # Initialize hash ring + self.hash_ring = HashRing() + + self.req_id = 0 # Request ID, used for round-robin selection + + self._initialized = True + + def _calculate_engine_load_score( + self, + engine_url: str, + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + ) -> float: + """ + Calculate engine load score + + Lower score indicates lighter engine load + + Load factors: load score (running requests * 0.02 + queuing requests * 0.1) + """ + if engine_url not in engine_stats: + return 0.0 # No statistics available, assume load is 0 + + # Get engine statistics + stats = engine_stats[engine_url] + + # Basic load factors: running requests and queuing requests + running_load = stats.num_running_requests * 0.02 # Running requests weight + queuing_load = ( + stats.num_queuing_requests * 0.1 + ) # Queuing requests weight (slightly higher) + + # Calculate total load score + total_load_score = running_load + queuing_load + + return total_load_score + + def _select_best_engine( + self, + session_id: str, + endpoints: List[EndpointInfo], + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + ) -> Tuple[str, str]: + """ + Select the best engine + 1. First determine which engine the request corresponds to based on hash_ring + 2. Check the queue situation of that engine (num_queuing_requests) + 3. If there are queuing requests (>tolerate_waiting_requests), try to find an engine without queue + 4. If all engines have queues, assign engine based on probability + 5. 
+        5. If the initial engine's queue is below tolerate_waiting_requests, use session-based routing (i.e., the hash ring result)
+        """
+        # Update hash ring to reflect currently available endpoints
+        self._update_hash_ring(endpoints)
+
+        # Use hash_ring to get the initial engine_url
+        initial_engine_url = self.hash_ring.get_node(session_id)
+
+        # If the initial engine is not found in engine_stats
+        if initial_engine_url not in engine_stats:
+            logger.warning(f"Engine {initial_engine_url} not found in engine_stats")
+            return initial_engine_url, "cache_aware"
+
+        # Check the queuing situation of the initial engine
+        if (
+            engine_stats[initial_engine_url].num_queuing_requests
+            < self.tolerate_waiting_requests
+        ):
+            # If queuing requests are fewer than tolerate_waiting_requests, use it directly
+            logger.debug(
+                f"Session {session_id} initial engine waiting requests < {self.tolerate_waiting_requests}, route to: {initial_engine_url}"
+            )
+            return initial_engine_url, "cache_aware"
+
+        # Try to find engines without a queue
+        engines_without_queue = []
+        for info in endpoints:
+            url = info.url
+            # Add boundary check for engine_stats
+            if url in engine_stats and engine_stats[url].num_queuing_requests == 0:
+                engines_without_queue.append(url)
+
+        # If there are engines without a queue, randomly select one
+        if engines_without_queue:
+            selected_engine = random.choice(engines_without_queue)
+            logger.debug(
+                f"Session {session_id} redirect to no queue engine: {selected_engine}"
+            )
+            return selected_engine, "redirect_to_no_queue_engine"
+
+        # All engines have queues, select one based on improved probability calculation
+        routing_method = "probability_based"
+
+        # Filter endpoints that have engine stats
+        valid_endpoints = [info for info in endpoints if info.url in engine_stats]
+        if not valid_endpoints:
+            # Fall back to the initial engine if no valid stats are available
+            logger.warning("No valid engine stats available, falling back to initial engine")
+            return initial_engine_url, "cache_aware_fallback"
+
+        # Calculate total queue length from valid endpoints only
+        total_queue_length = sum(
+            engine_stats[info.url].num_queuing_requests
+            for info in valid_endpoints
+        )
+
+        # Fixed probability calculation: inverse of normalized queue length
+        queue_lengths = [engine_stats[info.url].num_queuing_requests for info in valid_endpoints]
+        max_queue = max(queue_lengths)
+
+        # Calculate inverse weights (higher weight for lower queue length)
+        # Add a small epsilon to avoid division by zero
+        epsilon = 0.1
+        inverse_weights = [(max_queue - queue_len + epsilon) for queue_len in queue_lengths]
+        total_weight = sum(inverse_weights)
+
+        # Normalize to probabilities
+        probabilities = [weight / total_weight for weight in inverse_weights]
+
+        selected_engine = random.choices(
+            [info.url for info in valid_endpoints], weights=probabilities
+        )[0]
+
+        logger.debug(
+            f"Session {session_id} probability based routing to: {selected_engine}, "
+            f"queue_lengths: {queue_lengths}, probabilities: {[f'{p:.3f}' for p in probabilities]}"
+        )
+        return selected_engine, routing_method
+
+    def route_request(
+        self,
+        endpoints: List[EndpointInfo],
+        engine_stats: Dict[str, EngineStats],
+        request_stats: Dict[str, RequestStats],
+        request: Request,
+    ) -> Tuple[str, str]:
+        """
+        Intelligent request routing, combining load awareness and cache hit rate prediction
+
+        For requests with a session ID, the engine is selected based on KV cache hit rate prediction and load conditions
+        For requests without a session ID, engine selection is purely based on load balancing
+        """
+        # Extract session ID
+        session_id = request.headers.get(self.session_key, None)
+        logger.debug(f"Got session id: {session_id}")
+
+        routing_method = "load_balancing"
+
+        if session_id is None:
+            # No session ID, use pure load balancing strategy
+            engine_url = min(
+                [e.url for e in endpoints],
+                key=lambda url: self._calculate_engine_load_score(
+                    url, engine_stats, request_stats
+                ),
+            )
+            routing_method = "load_based"
+        else:
+            # Has session ID, use comprehensive strategy
+            engine_url, routing_method = self._select_best_engine(
+                session_id, endpoints, engine_stats, request_stats
+            )
+
+        return engine_url, routing_method
 
 
 class KvawareRouter(RoutingInterface):
@@ -459,12 +658,27 @@ def route_request(
 def initialize_routing_logic(
     routing_logic: RoutingLogic, *args, **kwargs
 ) -> RoutingInterface:
+
+    from vllm_router.request_logger import request_logger
+
+    if kwargs.get("enable_request_logging"):
+        logger.info("Enabling request logging")
+        request_logger.enable_logging(True, kwargs.get("request_log_dir"))
+
     if routing_logic == RoutingLogic.ROUND_ROBIN:
         logger.info("Initializing round-robin routing logic")
         return RoundRobinRouter()
     elif routing_logic == RoutingLogic.SESSION_BASED:
         logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}")
         return SessionRouter(kwargs.get("session_key"))
+    elif routing_logic == RoutingLogic.CACHE_AWARE_LOAD_BALANCING:
+        logger.info(
+            f"Initializing cache-aware load balancing routing logic with kwargs: {kwargs}"
+        )
+        router = CacheAwareLoadBalancingRouter(
+            kwargs.get("session_key"), kwargs.get("tolerate_waiting_requests")
+        )
+        return router
     elif routing_logic == RoutingLogic.KVAWARE:
         logger.info("Initializing kvaware routing logic")
         router = KvawareRouter(
@@ -498,6 +712,17 @@ def reconfigure_routing_logic(
     ):
         if cls in SingletonABCMeta._instances:
             del SingletonABCMeta._instances[cls]
+
+    # Re-configure request logging
+    from vllm_router.request_logger import request_logger
+
+    if kwargs.get("enable_request_logging"):
+        logger.info("Re-enabling request logging with new configuration")
+        request_logger.enable_logging(True, kwargs.get("request_log_dir"))
+    else:
+        # If request logging is not enabled, disable it
+        request_logger.enable_logging(False)
+
     return initialize_routing_logic(routing_logic, *args, **kwargs)
 
 
@@ -506,6 +731,7 @@ def get_routing_logic() -> RoutingInterface:
     for cls in (
         SessionRouter,
         RoundRobinRouter,
+        CacheAwareLoadBalancingRouter,
         KvawareRouter,
         PrefixAwareRouter,
         DisaggregatedPrefillRouter,
diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py
index 17cfc8844..cfccbf7ee 100644
--- a/src/vllm_router/service_discovery.py
+++ b/src/vllm_router/service_discovery.py
@@ -380,6 +380,20 @@ def _check_pod_ready(container_statuses):
         ready_count = sum(1 for status in container_statuses if status.ready)
         return ready_count == len(container_statuses)
 
+    @staticmethod
+    def _is_pod_terminating(pod):
+        """
+        Check if the pod is in terminating state by checking
+        deletion timestamp and pod phase.
+ """ + # Check if pod has deletion timestamp + if pod.metadata.deletion_timestamp is not None: + return True + # Check if pod phase is failed or succeeded + if pod.status.phase in ["Failed", "Succeeded"]: + return True + return False + def _get_engine_sleep_status(self, pod_ip) -> Optional[bool]: """ Get the engine sleeping status by querying the engine's @@ -554,13 +568,25 @@ def _watch_engines(self): event_type = event["type"] pod_name = pod.metadata.name pod_ip = pod.status.pod_ip - is_pod_ready = self._check_pod_ready(pod.status.container_statuses) + + # Check if pod is terminating + is_pod_terminating = self._is_pod_terminating(pod) + is_container_ready = self._check_pod_ready(pod.status.container_statuses) + + # Pod is ready if container is ready and pod is not terminating + is_pod_ready = is_container_ready and not is_pod_terminating + if is_pod_ready: model_names = self._get_model_names(pod_ip) model_label = self._get_model_label(pod) else: model_names = [] model_label = None + + # Record pod status for debugging + if is_container_ready and is_pod_terminating: + logger.info(f"Pod {pod_name} has ready containers but is terminating - marking as unavailable") + self._on_engine_update( pod_name, pod_ip, diff --git a/src/vllm_router/services/metrics_service/__init__.py b/src/vllm_router/services/metrics_service/__init__.py index 80dbaed9a..8bfa47472 100644 --- a/src/vllm_router/services/metrics_service/__init__.py +++ b/src/vllm_router/services/metrics_service/__init__.py @@ -45,3 +45,10 @@ num_requests_swapped = Gauge( "vllm:num_requests_swapped", "Number of swapped requests", ["server"] ) + +# Routing method metrics +routing_method_qps = Gauge( + "vllm:routing_method_qps", + "Queries Per Second for different routing methods in the current time window", + ["method"] +) diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 46969b2b2..110e8eb9c 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -23,6 +23,7 @@ from fastapi.responses import JSONResponse, StreamingResponse from vllm_router.log import init_logger +from vllm_router.request_logger import request_logger from vllm_router.routers.routing_logic import ( DisaggregatedPrefillRouter, KvawareRouter, @@ -69,6 +70,7 @@ async def process_request( request_id, endpoint, background_tasks: BackgroundTasks, + routing_method=None, debug_request=None, ): """ @@ -80,6 +82,7 @@ async def process_request( backend_url: The URL of the backend to send the request to. request_id: A unique identifier for the request. endpoint: The endpoint to send the request to on the backend. + routing_method: The routing method used to select this backend. debug_request: The original request object from the client, used for optional debug logging. 
@@ -93,7 +96,7 @@ async def process_request(
     total_len = 0
     start_time = time.time()
     request.app.state.request_stats_monitor.on_new_request(
-        backend_url, request_id, start_time
+        backend_url, request_id, start_time, routing_method
     )
     # Check if this is a streaming request
     is_streaming = False
@@ -261,11 +264,11 @@ async def route_general_request(
     elif isinstance(request.app.state.router, KvawareRouter) or isinstance(
         request.app.state.router, PrefixAwareRouter
    ):
-        server_url = await request.app.state.router.route_request(
+        server_url, routing_method = await request.app.state.router.route_request(
             endpoints, engine_stats, request_stats, request, request_json
         )
     else:
-        server_url = request.app.state.router.route_request(
+        server_url, routing_method = request.app.state.router.route_request(
             endpoints, engine_stats, request_stats, request
         )
 
@@ -282,16 +285,31 @@
     session_id_display = session_id if session_id is not None else "None"
 
     # Debug logging to help troubleshoot session ID extraction
-    logger.debug(
-        f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
-    )
-    logger.debug(f"Debug session extraction - Session key config: {session_key}")
-    logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
-    logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
+    # logger.debug(
+    #     f"Debug session extraction - Router type: {type(request.app.state.router).__name__}"
+    # )
+    # logger.debug(f"Debug session extraction - Session key config: {session_key}")
+    # logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}")
+    # logger.debug(f"Debug session extraction - Extracted session ID: {session_id}")
 
     logger.info(
-        f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}"
+        f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, "
+        f"process time = {curr_time - in_router_time:.4f}, "
+        f"routing method = {routing_method}"
+    )
+
+    # record the request
+    request_logger.log_request_routed(
+        in_router_time,
+        request_id,
+        routing_method,
+        server_url,
+        session_id,
+        process_time=curr_time - in_router_time,
     )
+    # record the request body
+    request_logger.log_request_body(request_body, request_id)
+
     stream_generator = process_request(
         request,
         request_body,
@@ -299,6 +317,7 @@
         request_id,
         endpoint,
         background_tasks,
+        routing_method,
     )
     headers, status_code = await anext(stream_generator)
     headers_dict = {key: value for key, value in headers.items()}
diff --git a/src/vllm_router/stats/request_stats.py b/src/vllm_router/stats/request_stats.py
index f0409b912..b209b2f6f 100644
--- a/src/vllm_router/stats/request_stats.py
+++ b/src/vllm_router/stats/request_stats.py
@@ -12,9 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import time
-from collections import deque
+from collections import Counter, deque
 from dataclasses import dataclass
-from typing import Deque, Dict, Tuple
+from typing import Deque, Dict, Optional, Tuple
 
 from vllm_router.log import init_logger
@@ -53,6 +53,16 @@ class RequestStats:
     avg_itl: float
     # Number of swapped requests (moved from GPU to CPU)
     num_swapped_requests: int
+    # Dictionary counting different routing methods used
+    routing_methods: Dict[str, int] = None
+
+@dataclass
+class RoutingMethodEntry:
+    """
+    A class to record the routing method used for a request.
+    """
+    method: str
+    timestamp: float
 
 
 class MovingAverageMonitor:
@@ -138,11 +148,18 @@ def __init__(self, sliding_window_size: float = None):
         # Counter for swapped requests
         self.swapped_requests: Dict[str, int] = {}
 
+        # Counting different routing methods used by each engine
+        self.routing_methods: Dict[str, Counter] = {}
+
+        # Counting different routing methods used in the current time window
+        self.routing_methods_window: Deque[RoutingMethodEntry] = deque()
+
         self.first_query_time: float = None
         self._initialized = True
 
-    def on_new_request(self, engine_url: str, request_id: str, timestamp: float):
+    def on_new_request(self, engine_url: str, request_id: str,
+                       timestamp: float, routing_method: Optional[str] = None):
         """
         Tell the monitor that a new request has been created.
 
@@ -150,6 +167,7 @@ def on_new_request(self, engine_url: str, request_id: str, timestamp: float):
             engine_url: The URL of the serving engine
             request_id: The global request ID
             timestamp: the timestamp when the request was created
+            routing_method: The routing method used for this request
         """
         self.request_start_time[(engine_url, request_id)] = timestamp
 
@@ -167,9 +185,33 @@ def on_new_request(self, engine_url: str, request_id: str, timestamp: float):
             self.latency_monitors[engine_url] = MovingAverageMonitor(
                 self.sliding_window_size
             )
+
+        # Counting different routing methods used by each engine
+        if routing_method:
+            if engine_url not in self.routing_methods:
+                self.routing_methods[engine_url] = Counter()
+            self.routing_methods[engine_url][routing_method] += 1
+
+            # Counting different routing methods used in the current time window
+            self.routing_methods_window.append(
+                RoutingMethodEntry(method=routing_method, timestamp=timestamp)
+            )
+
+            # Clean up expired routing methods records
+            self._clean_routing_methods_window(timestamp)
 
         if self.first_query_time is None:
             self.first_query_time = timestamp
+
+    def _clean_routing_methods_window(self, current_time: float):
+        """
+        Clean up expired routing methods records
+        """
+        while (
+            self.routing_methods_window and
+            self.routing_methods_window[0].timestamp < current_time - self.sliding_window_size
+        ):
+            self.routing_methods_window.popleft()
 
     def on_request_response(self, engine_url: str, request_id: str, timestamp: float):
         """
@@ -288,6 +330,9 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]:
                 swapped = self.swapped_requests[engine_url]
             else:
                 swapped = 0
+
+            # Get routing methods statistics
+            routing_methods_dict = dict(self.routing_methods.get(engine_url, Counter()))
 
             ret[engine_url] = RequestStats(
                 qps=qps,
@@ -302,8 +347,43 @@
                 avg_latency=avg_lat,
                 avg_itl=avg_itl_val,
                 num_swapped_requests=swapped,
+                routing_methods=routing_methods_dict,
             )
         return ret
+
+    def get_all_routing_methods_in_window(self) -> Dict[str, int]:
+        """
+        Get the usage statistics of all routing methods in the current time window
+
+        Returns:
+            Dict[str, int]: The routing method and its usage count in the current time window
+        """
+        # Clean up expired routing methods records
+        self._clean_routing_methods_window(time.time())
+
+        # Counting different routing methods used in the current time window
+        methods_count = Counter()
+        for entry in self.routing_methods_window:
+            methods_count[entry.method] += 1
+
+        return dict(methods_count)
+
+    def get_routing_methods_qps(self) -> Dict[str, float]:
+        """
+        Get the QPS (Queries Per Second) of routing methods in the current time window
+
+        Returns:
+            Dict[str, float]: The routing method and its QPS in the current time window
+        """
+        # Counting different routing methods used in the current time window
+        methods_count = self.get_all_routing_methods_in_window()
+
+        # Calculating QPS = usage count / time window size
+        methods_qps = {}
+        for method, count in methods_count.items():
+            methods_qps[method] = count / self.sliding_window_size
+
+        return methods_qps
 
 
 def initialize_request_stats_monitor(sliding_window_size: float):
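
Reviewer sketch (not part of the patch): the inverse-queue-length weighting that CacheAwareLoadBalancingRouter._select_best_engine falls back to when every engine has queued requests, reduced to a standalone function so the math is easy to check. The helper name pick_engine, the queue_lengths dict, and the example URLs are illustrative assumptions; the in-tree code additionally filters out endpoints with missing stats, logs the decision, and returns a routing-method label alongside the URL.

import random
from typing import Dict, List


def pick_engine(queue_lengths: Dict[str, int], epsilon: float = 0.1) -> str:
    """Pick an engine URL with probability inversely related to its queue length."""
    urls: List[str] = list(queue_lengths)
    lengths = [queue_lengths[url] for url in urls]
    max_queue = max(lengths)
    # Shorter queues get larger weights; epsilon keeps every weight positive so
    # even the longest queue retains a small, non-zero selection probability.
    weights = [max_queue - length + epsilon for length in lengths]
    total_weight = sum(weights)
    probabilities = [weight / total_weight for weight in weights]
    return random.choices(urls, weights=probabilities)[0]


if __name__ == "__main__":
    # Queues of 5, 10 and 20 waiting requests yield weights 15.1, 10.1 and 0.1,
    # i.e. roughly 60% / 40% / 0.4% selection probability.
    queues = {
        "http://engine-a:8000": 5,
        "http://engine-b:8000": 10,
        "http://engine-c:8000": 20,
    }
    print(pick_engine(queues))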