From 6cb9c07e34f96ae5cbb6e0def77c416345b5d4c3 Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Tue, 8 Jul 2025 04:18:54 +0000 Subject: [PATCH 01/11] [Feat] Add cache_aware_load_balancing router --- src/vllm_router/app.py | 3 + src/vllm_router/parsers/parser.py | 26 ++++ src/vllm_router/routers/routing_logic.py | 155 +++++++++++++++++++++++ 3 files changed, 184 insertions(+) diff --git a/src/vllm_router/app.py b/src/vllm_router/app.py index f28413d04..0372659c2 100644 --- a/src/vllm_router/app.py +++ b/src/vllm_router/app.py @@ -192,10 +192,13 @@ def initialize_all(app: FastAPI, args): initialize_routing_logic( args.routing_logic, session_key=args.session_key, + tolerate_waiting_requests=args.tolerate_waiting_requests, lmcache_controller_port=args.lmcache_controller_port, prefill_model_labels=args.prefill_model_labels, decode_model_labels=args.decode_model_labels, kv_aware_threshold=args.kv_aware_threshold, + enable_request_logging=args.enable_request_logging, + request_log_dir=args.request_log_dir, ) # Initialize feature gates diff --git a/src/vllm_router/parsers/parser.py b/src/vllm_router/parsers/parser.py index 21ad2c9e1..26337f1af 100644 --- a/src/vllm_router/parsers/parser.py +++ b/src/vllm_router/parsers/parser.py @@ -99,6 +99,10 @@ def validate_args(args): raise ValueError( "Session key must be provided when using session routing logic." ) + if args.routing_logic == "cache_aware_load_balancing" and args.session_key is None: + raise ValueError( + "Session key must be provided when using cache_aware_load_balancing routing logic." + ) if args.log_stats and args.log_stats_interval <= 0: raise ValueError("Log stats interval must be greater than 0.") if args.engine_stats_interval <= 0: @@ -180,12 +184,34 @@ def parse_args(): choices=[ "roundrobin", "session", + "cache_aware_load_balancing", "kvaware", "prefixaware", "disaggregated_prefill", ], help="The routing logic to use", ) + + parser.add_argument( + "--tolerate-waiting-requests", + type=int, + default=10, + help="The number of waiting requests to tolerate in cache-aware load balancing router.", + ) + + parser.add_argument( + "--enable-request-logging", + action="store_true", + help="Enable request logging, record the routing decision and performance data of each request", + ) + + parser.add_argument( + "--request-log-dir", + type=str, + default=None, + help="The directory to store the request log file, if provided, the log will be written to this file", + ) + parser.add_argument( "--lmcache-controller-port", type=int, diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index ef544c9c6..1220d88e8 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -213,6 +213,161 @@ def route_request( return url +class CacheAwareLoadBalancingRouter(RoutingInterface): + """ + Routing algorithm that combines load balancing with KV Cache hit rate awareness + + This algorithm considers three key factors: + 1. Engine load (number of queuing requests, number of running requests) + 2. 
Estimated KV cache hit rate (for specific sessions) + """ + + def __init__(self, session_key: str = None, tolerate_waiting_requests: int = 20): + if hasattr(self, "_initialized"): + return + + if session_key is None: + raise ValueError("CacheAwareLoadBalancingRouter must be initialized with a session_key") + + self.session_key = session_key + self.tolerate_waiting_requests = tolerate_waiting_requests + + # Initialize hash ring + self.hash_ring = HashRing() + + self.req_id = 0 # Request ID, used for round-robin selection + + self._initialized = True + + def _calculate_engine_load_score(self, engine_url: str, + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats]) -> float: + """ + Calculate engine load score + + Lower score indicates lighter engine load + + Load factors: load score (running requests * 0.02 + queuing requests * 0.1) + """ + if engine_url not in engine_stats: + return 0.0 # No statistics available, assume load is 0 + + # Get engine statistics + stats = engine_stats[engine_url] + + # Basic load factors: running requests and queuing requests + running_load = stats.num_running_requests * 0.02 # Running requests weight + queuing_load = stats.num_queuing_requests * 0.1 # Queuing requests weight (slightly higher) + + # Calculate total load score + total_load_score = running_load + queuing_load + + return total_load_score + + def _update_hash_ring(self, endpoints: List["EndpointInfo"]): + """ + Update the hash ring with the current list of endpoints. + """ + # Extract endpoint URLs + endpoint_urls = [endpoint.url for endpoint in endpoints] + + # Get the current nodes in the hash ring + current_nodes = set(self.hash_ring.get_nodes()) + + # Convert the new endpoint URLs to a set for easy comparison + new_nodes = set(endpoint_urls) + + # Remove nodes that are no longer in the list + for node in current_nodes - new_nodes: + self.hash_ring.remove_node(node) + + # Add new nodes that are not already in the hash ring + for node in new_nodes - current_nodes: + self.hash_ring.add_node(node) + + def _select_best_engine(self, session_id: str, endpoints: List[EndpointInfo], + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats]) -> Tuple[str, str]: + """ + Select the best engine + 1. First determine which engine the request corresponds to based on hash_ring + 2. Check the queue situation of that engine (num_queuing_requests) + 3. If there are queuing requests (>tolerate_waiting_requests), try to find an engine without queue + 4. If all engines have queues, assign engine based on probability + 5. 
If the initial engine has no queuing requests, use session-based routing (i.e., hash_ring result) + """ + # Update hash ring to reflect currently available endpoints + self._update_hash_ring(endpoints) + + # Use hash_ring to get the initial engine_url + initial_engine_url = self.hash_ring.get_node(session_id) + + routing_method = "cache_aware" + + # Check the queuing situation of the initial engine + if engine_stats[initial_engine_url].num_queuing_requests < self.tolerate_waiting_requests: + # If queuing requests are less than tolerate_waiting_requests, use it directly + logger.debug(f"Session {session_id} initial engine waitting requests < {self.tolerate_waiting_requests}, route to: {initial_engine_url}") + return initial_engine_url, routing_method + + # Try to find engines without queue + engines_without_queue = [] + for info in endpoints: + url = info.url + if engine_stats[url].num_queuing_requests == 0: + engines_without_queue.append(url) + + # If there are engines without queue, randomly select one + if engines_without_queue: + routing_method = "redirect_to_no_queue_engine" + selected_engine = random.choice(engines_without_queue) + logger.debug(f"Session {session_id} redirect to no queue engine: {selected_engine}") + return selected_engine, routing_method + + # All engines have queues, select one based on probability + routing_method = "probability_based" + total_queue_length = sum(engine_stats[url].num_queuing_requests for url in [info.url for info in endpoints]) + probabilities = [1 / (engine_stats[url].num_queuing_requests / total_queue_length) for url in [info.url for info in endpoints]] + probabilities = [p / sum(probabilities) for p in probabilities] + + selected_engine = random.choices([info.url for info in endpoints], weights=probabilities)[0] + + logger.debug(f"Session {session_id} probability based routing to: {selected_engine}") + return selected_engine, routing_method + + def route_request( + self, + endpoints: List[EndpointInfo], + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + request: Request, + ) -> str: + """ + Intelligent request routing, combining load awareness and cache hit rate prediction + + For requests with session ID, intelligent selection is made based on KV cache hit rate prediction and load conditions + For requests without session ID, engine selection is purely based on load balancing + """ + # Extract session ID + session_id = request.headers.get(self.session_key, None) + logger.debug(f"Got session id: {session_id}") + + routing_method = "load_balancing" + + if session_id is None: + # No session ID, use pure load balancing strategy + engine_url = min( + [e.url for e in endpoints], + key=lambda url: self._calculate_engine_load_score(url, engine_stats, request_stats) + ) + routing_method = "load_based" + else: + # Has session ID, use comprehensive strategy + engine_url, routing_method = self._select_best_engine(session_id, endpoints, engine_stats, request_stats) + + return engine_url, routing_method + + class KvawareRouter(RoutingInterface): """ Route the request to the appropriate engine URL by where the KV cache From e43c1714ea34dd05063af7aed4f747e1679697f0 Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Tue, 8 Jul 2025 04:45:31 +0000 Subject: [PATCH 02/11] fix: remove trailing whitespace --- src/vllm_router/routers/routing_logic.py | 114 +++++++++++++++-------- 1 file changed, 73 insertions(+), 41 deletions(-) diff --git a/src/vllm_router/routers/routing_logic.py 
b/src/vllm_router/routers/routing_logic.py index 1220d88e8..60e343920 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -216,54 +216,61 @@ def route_request( class CacheAwareLoadBalancingRouter(RoutingInterface): """ Routing algorithm that combines load balancing with KV Cache hit rate awareness - + This algorithm considers three key factors: 1. Engine load (number of queuing requests, number of running requests) 2. Estimated KV cache hit rate (for specific sessions) """ - + def __init__(self, session_key: str = None, tolerate_waiting_requests: int = 20): if hasattr(self, "_initialized"): return - + if session_key is None: - raise ValueError("CacheAwareLoadBalancingRouter must be initialized with a session_key") - + raise ValueError( + "CacheAwareLoadBalancingRouter must be initialized with a session_key" + ) + self.session_key = session_key self.tolerate_waiting_requests = tolerate_waiting_requests - + # Initialize hash ring self.hash_ring = HashRing() self.req_id = 0 # Request ID, used for round-robin selection - + self._initialized = True - - def _calculate_engine_load_score(self, engine_url: str, - engine_stats: Dict[str, EngineStats], - request_stats: Dict[str, RequestStats]) -> float: + + def _calculate_engine_load_score( + self, + engine_url: str, + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + ) -> float: """ Calculate engine load score - + Lower score indicates lighter engine load Load factors: load score (running requests * 0.02 + queuing requests * 0.1) """ if engine_url not in engine_stats: return 0.0 # No statistics available, assume load is 0 - + # Get engine statistics stats = engine_stats[engine_url] - + # Basic load factors: running requests and queuing requests running_load = stats.num_running_requests * 0.02 # Running requests weight - queuing_load = stats.num_queuing_requests * 0.1 # Queuing requests weight (slightly higher) - + queuing_load = ( + stats.num_queuing_requests * 0.1 + ) # Queuing requests weight (slightly higher) + # Calculate total load score total_load_score = running_load + queuing_load - + return total_load_score - + def _update_hash_ring(self, endpoints: List["EndpointInfo"]): """ Update the hash ring with the current list of endpoints. @@ -285,9 +292,13 @@ def _update_hash_ring(self, endpoints: List["EndpointInfo"]): for node in new_nodes - current_nodes: self.hash_ring.add_node(node) - def _select_best_engine(self, session_id: str, endpoints: List[EndpointInfo], - engine_stats: Dict[str, EngineStats], - request_stats: Dict[str, RequestStats]) -> Tuple[str, str]: + def _select_best_engine( + self, + session_id: str, + endpoints: List[EndpointInfo], + engine_stats: Dict[str, EngineStats], + request_stats: Dict[str, RequestStats], + ) -> Tuple[str, str]: """ Select the best engine 1. 
First determine which engine the request corresponds to based on hash_ring @@ -298,41 +309,58 @@ def _select_best_engine(self, session_id: str, endpoints: List[EndpointInfo], """ # Update hash ring to reflect currently available endpoints self._update_hash_ring(endpoints) - + # Use hash_ring to get the initial engine_url initial_engine_url = self.hash_ring.get_node(session_id) routing_method = "cache_aware" - + # Check the queuing situation of the initial engine - if engine_stats[initial_engine_url].num_queuing_requests < self.tolerate_waiting_requests: + if ( + engine_stats[initial_engine_url].num_queuing_requests + < self.tolerate_waiting_requests + ): # If queuing requests are less than tolerate_waiting_requests, use it directly - logger.debug(f"Session {session_id} initial engine waitting requests < {self.tolerate_waiting_requests}, route to: {initial_engine_url}") + logger.debug( + f"Session {session_id} initial engine waitting requests < {self.tolerate_waiting_requests}, route to: {initial_engine_url}" + ) return initial_engine_url, routing_method - + # Try to find engines without queue engines_without_queue = [] for info in endpoints: url = info.url if engine_stats[url].num_queuing_requests == 0: engines_without_queue.append(url) - + # If there are engines without queue, randomly select one if engines_without_queue: routing_method = "redirect_to_no_queue_engine" selected_engine = random.choice(engines_without_queue) - logger.debug(f"Session {session_id} redirect to no queue engine: {selected_engine}") + logger.debug( + f"Session {session_id} redirect to no queue engine: {selected_engine}" + ) return selected_engine, routing_method - + # All engines have queues, select one based on probability routing_method = "probability_based" - total_queue_length = sum(engine_stats[url].num_queuing_requests for url in [info.url for info in endpoints]) - probabilities = [1 / (engine_stats[url].num_queuing_requests / total_queue_length) for url in [info.url for info in endpoints]] + total_queue_length = sum( + engine_stats[url].num_queuing_requests + for url in [info.url for info in endpoints] + ) + probabilities = [ + 1 / (engine_stats[url].num_queuing_requests / total_queue_length) + for url in [info.url for info in endpoints] + ] probabilities = [p / sum(probabilities) for p in probabilities] - - selected_engine = random.choices([info.url for info in endpoints], weights=probabilities)[0] - - logger.debug(f"Session {session_id} probability based routing to: {selected_engine}") + + selected_engine = random.choices( + [info.url for info in endpoints], weights=probabilities + )[0] + + logger.debug( + f"Session {session_id} probability based routing to: {selected_engine}" + ) return selected_engine, routing_method def route_request( @@ -344,27 +372,31 @@ def route_request( ) -> str: """ Intelligent request routing, combining load awareness and cache hit rate prediction - + For requests with session ID, intelligent selection is made based on KV cache hit rate prediction and load conditions For requests without session ID, engine selection is purely based on load balancing """ # Extract session ID session_id = request.headers.get(self.session_key, None) logger.debug(f"Got session id: {session_id}") - + routing_method = "load_balancing" - + if session_id is None: # No session ID, use pure load balancing strategy engine_url = min( [e.url for e in endpoints], - key=lambda url: self._calculate_engine_load_score(url, engine_stats, request_stats) + key=lambda url: self._calculate_engine_load_score( + url, 
engine_stats, request_stats + ), ) routing_method = "load_based" else: # Has session ID, use comprehensive strategy - engine_url, routing_method = self._select_best_engine(session_id, endpoints, engine_stats, request_stats) - + engine_url, routing_method = self._select_best_engine( + session_id, endpoints, engine_stats, request_stats + ) + return engine_url, routing_method From e62094319c443505716297ae6bd697855939623a Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Tue, 8 Jul 2025 04:47:42 +0000 Subject: [PATCH 03/11] minor fix spelling errors --- src/vllm_router/routers/routing_logic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 60e343920..31e82c8bc 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -322,7 +322,7 @@ def _select_best_engine( ): # If queuing requests are less than tolerate_waiting_requests, use it directly logger.debug( - f"Session {session_id} initial engine waitting requests < {self.tolerate_waiting_requests}, route to: {initial_engine_url}" + f"Session {session_id} initial engine waiting requests < {self.tolerate_waiting_requests}, route to: {initial_engine_url}" ) return initial_engine_url, routing_method From 19596508877705d3bd792e3d8cb022cad17fe251 Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Tue, 8 Jul 2025 05:53:23 +0000 Subject: [PATCH 04/11] [Feat] Add request logging --- src/vllm_router/request_logger.py | 135 ++++++++++++++++++ src/vllm_router/routers/routing_logic.py | 23 +++ .../services/request_service/request.py | 21 ++- 3 files changed, 175 insertions(+), 4 deletions(-) create mode 100644 src/vllm_router/request_logger.py diff --git a/src/vllm_router/request_logger.py b/src/vllm_router/request_logger.py new file mode 100644 index 000000000..d30b3ce5c --- /dev/null +++ b/src/vllm_router/request_logger.py @@ -0,0 +1,135 @@ +# Copyright 2024-2025 The vLLM Production Stack Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import time +from fastapi import Request +import os +import json +from datetime import datetime + +from vllm_router.log import init_logger + +logger = init_logger(__name__) + +class RequestLogger: + """Request logger for recording routing decisions and performance data""" + + def __init__(self): + """Initialize request logger""" + self.log_enabled = False + self.log_file = None + self.requests_bodies_dir = None + # Dictionary for temporarily storing request arrival times + self.arrival_times = {} + + def enable_logging(self, enabled: bool = True, log_dir: str = None): + """ + Enable or disable logging + + Args: + enabled: Whether to enable logging + log_dir: Log directory path, if provided write to file, otherwise save to project path + """ + self.log_enabled = enabled + self.log_file = None # Reset log_file + + if enabled: + # If log_dir is not specified, use project path + if not log_dir: + log_dir = os.getcwd() + logger.info(f"No log directory specified, will use current working directory: {log_dir}") + + try: + # Ensure log directory exists + os.makedirs(log_dir, exist_ok=True) + + # Create request body storage directory + self.requests_bodies_dir = os.path.join(log_dir, "request_bodies") + os.makedirs(self.requests_bodies_dir, exist_ok=True) + + # Create log file name with timestamp + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + log_file = os.path.join(log_dir, f"router_requests_logs_{timestamp}.csv") + + # Create log file and write header + with open(log_file, 'w') as f: + f.write("timestamp,request_id,conversation_id,arrival_time,routing_method,target_engine,process_time\n") + + self.log_file = log_file + logger.info(f"Request routing logs will be written to file: {log_file}") + except Exception as e: + logger.error(f"Failed to create log file: {str(e)}") + self.log_file = None + + logger.info("Request routing logging has been enabled") + + def log_request_routed(self, arrival_time: float, request_id: str, routing_method: str, target_engine: str, + session_id: str = None, process_time: float = None): + """Record request routing decision and timestamp, and write to file (if enabled)""" + if not self.log_enabled or not self.log_file: + return + + # Ensure session_id has a value + session_id = session_id or "unknown" + + # Write to file + try: + timestamp = time.strftime("%Y-%m-%d %H:%M:%S") + log_line = f"{timestamp},{request_id},{session_id},{arrival_time},{routing_method},{target_engine},{process_time}\n" + + with open(self.log_file, 'a') as f: + f.write(log_line) + + except Exception as e: + logger.error(f"Failed to write to log file: {str(e)}") + + def log_request_body(self, request_body, request_id=None): + """ + Log request body to a separate file + + Args: + request_body: Request body content + request_id: Request ID, if None then try to extract from request body + """ + if not self.log_enabled or not self.requests_bodies_dir: + return + + if not request_id: + # Try to extract request_id from request body + try: + body_json = json.loads(request_body) + request_id = body_json.get("id", str(int(time.time()))) + except: + request_id = str(int(time.time())) + + # Create file name + file_path = os.path.join(self.requests_bodies_dir, f"{request_id}.json") + + # Write request body to file + try: + with open(file_path, "wb") as f: + if isinstance(request_body, bytes): + f.write(request_body) + else: + f.write(request_body.encode('utf-8')) + except Exception as e: + logger.error(f"Failed to write request body file: {str(e)}") + + def clear_logs(self): + """Clear temporarily 
stored arrival times""" + self.arrival_times.clear() + + +# Create global request logger instance +request_logger = RequestLogger() \ No newline at end of file diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 31e82c8bc..444572a96 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -646,12 +646,25 @@ def route_request( def initialize_routing_logic( routing_logic: RoutingLogic, *args, **kwargs ) -> RoutingInterface: + + from vllm_router.request_logger import request_logger + if kwargs.get("enable_request_logging"): + logger.info("Enabling request logging") + request_logger.enable_logging(True, kwargs.get("request_log_dir")) + if routing_logic == RoutingLogic.ROUND_ROBIN: logger.info("Initializing round-robin routing logic") return RoundRobinRouter() elif routing_logic == RoutingLogic.SESSION_BASED: logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}") return SessionRouter(kwargs.get("session_key")) + elif routing_logic == RoutingLogic.CACHE_AWARE_LOAD_BALANCING: + logger.info(f"Initializing cache-aware load balancing routing logic with kwargs: {kwargs}") + router = CacheAwareLoadBalancingRouter( + kwargs.get("session_key"), + kwargs.get("tolerate_waiting_requests") + ) + return router elif routing_logic == RoutingLogic.KVAWARE: logger.info("Initializing kvaware routing logic") router = KvawareRouter( @@ -685,6 +698,16 @@ def reconfigure_routing_logic( ): if cls in SingletonABCMeta._instances: del SingletonABCMeta._instances[cls] + + # Re-configure request logging + from vllm_router.request_logger import request_logger + if kwargs.get("enable_request_logging"): + logger.info("Re-enabling request logging with new configuration") + request_logger.enable_logging(True, kwargs.get("request_log_dir")) + else: + # If request logging is not enabled, disable it + request_logger.enable_logging(False) + return initialize_routing_logic(routing_logic, *args, **kwargs) diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 46969b2b2..7c81d06d8 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -21,6 +21,7 @@ import httpx from fastapi import BackgroundTasks, Request from fastapi.responses import JSONResponse, StreamingResponse +from vllm_router.request_logger import request_logger from vllm_router.log import init_logger from vllm_router.routers.routing_logic import ( @@ -69,6 +70,7 @@ async def process_request( request_id, endpoint, background_tasks: BackgroundTasks, + routing_method=None, debug_request=None, ): """ @@ -80,6 +82,7 @@ async def process_request( backend_url: The URL of the backend to send the request to. request_id: A unique identifier for the request. endpoint: The endpoint to send the request to on the backend. + routing_method: The routing method used to select this backend. debug_request: The original request object from the client, used for optional debug logging. 
@@ -93,7 +96,7 @@ async def process_request( total_len = 0 start_time = time.time() request.app.state.request_stats_monitor.on_new_request( - backend_url, request_id, start_time + backend_url, request_id, start_time, routing_method ) # Check if this is a streaming request is_streaming = False @@ -261,11 +264,11 @@ async def route_general_request( elif isinstance(request.app.state.router, KvawareRouter) or isinstance( request.app.state.router, PrefixAwareRouter ): - server_url = await request.app.state.router.route_request( + server_url, routing_method = await request.app.state.router.route_request( endpoints, engine_stats, request_stats, request, request_json ) else: - server_url = request.app.state.router.route_request( + server_url, routing_method = request.app.state.router.route_request( endpoints, engine_stats, request_stats, request ) @@ -290,8 +293,17 @@ async def route_general_request( logger.debug(f"Debug session extraction - Extracted session ID: {session_id}") logger.info( - f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, process time = {curr_time - in_router_time:.4f}" + f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, " + f"process time = {curr_time - in_router_time:.4f}, " + f"routing method = {routing_method}" ) + + # record the request + request_logger.log_request_routed(in_router_time, request_id, routing_method, + server_url, session_id, process_time=curr_time - in_router_time) + # record the request body + request_logger.log_request_body(request_body, request_id) + stream_generator = process_request( request, request_body, @@ -299,6 +311,7 @@ async def route_general_request( request_id, endpoint, background_tasks, + routing_method, ) headers, status_code = await anext(stream_generator) headers_dict = {key: value for key, value in headers.items()} From b0023a552af68ae0ca9a30e894be6d50808072ef Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Tue, 8 Jul 2025 05:59:49 +0000 Subject: [PATCH 05/11] minor fix import --- src/vllm_router/routers/routing_logic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 444572a96..3ed0c22bb 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -20,7 +20,7 @@ import random import socket import threading -from typing import Dict, List +from typing import Dict, List, Tuple import requests from fastapi import Request From fbbd7821299de774012783d7d37484c4743e343c Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Tue, 8 Jul 2025 06:16:54 +0000 Subject: [PATCH 06/11] minor fix --- src/vllm_router/routers/routing_logic.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 3ed0c22bb..f973053ad 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -53,6 +53,7 @@ class RoutingLogic(str, enum.Enum): ROUND_ROBIN = "roundrobin" SESSION_BASED = "session" + CACHE_AWARE_LOAD_BALANCING = "cache_aware_load_balancing" KVAWARE = "kvaware" PREFIXAWARE = "prefixaware" DISAGGREGATED_PREFILL = "disaggregated_prefill" From 244382096e1cdb75ac9da01ef359141bbe3eb38b Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Tue, 8 Jul 2025 06:41:37 +0000 Subject: [PATCH 07/11] style: fix end-of-file and trailing 
whitespace issues --- src/vllm_router/request_logger.py | 81 +++++++++++-------- src/vllm_router/routers/routing_logic.py | 15 ++-- .../services/request_service/request.py | 14 +++- 3 files changed, 67 insertions(+), 43 deletions(-) diff --git a/src/vllm_router/request_logger.py b/src/vllm_router/request_logger.py index d30b3ce5c..664eea420 100644 --- a/src/vllm_router/request_logger.py +++ b/src/vllm_router/request_logger.py @@ -12,19 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. -import time -from fastapi import Request -import os import json +import os +import time from datetime import datetime +from fastapi import Request + from vllm_router.log import init_logger logger = init_logger(__name__) + class RequestLogger: """Request logger for recording routing decisions and performance data""" - + def __init__(self): """Initialize request logger""" self.log_enabled = False @@ -32,79 +34,92 @@ def __init__(self): self.requests_bodies_dir = None # Dictionary for temporarily storing request arrival times self.arrival_times = {} - + def enable_logging(self, enabled: bool = True, log_dir: str = None): """ Enable or disable logging - + Args: enabled: Whether to enable logging log_dir: Log directory path, if provided write to file, otherwise save to project path """ self.log_enabled = enabled self.log_file = None # Reset log_file - + if enabled: # If log_dir is not specified, use project path if not log_dir: log_dir = os.getcwd() - logger.info(f"No log directory specified, will use current working directory: {log_dir}") - + logger.info( + f"No log directory specified, will use current working directory: {log_dir}" + ) + try: # Ensure log directory exists os.makedirs(log_dir, exist_ok=True) - + # Create request body storage directory self.requests_bodies_dir = os.path.join(log_dir, "request_bodies") os.makedirs(self.requests_bodies_dir, exist_ok=True) - + # Create log file name with timestamp timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") - log_file = os.path.join(log_dir, f"router_requests_logs_{timestamp}.csv") - + log_file = os.path.join( + log_dir, f"router_requests_logs_{timestamp}.csv" + ) + # Create log file and write header - with open(log_file, 'w') as f: - f.write("timestamp,request_id,conversation_id,arrival_time,routing_method,target_engine,process_time\n") - + with open(log_file, "w") as f: + f.write( + "timestamp,request_id,conversation_id,arrival_time,routing_method,target_engine,process_time\n" + ) + self.log_file = log_file logger.info(f"Request routing logs will be written to file: {log_file}") except Exception as e: logger.error(f"Failed to create log file: {str(e)}") self.log_file = None - + logger.info("Request routing logging has been enabled") - - def log_request_routed(self, arrival_time: float, request_id: str, routing_method: str, target_engine: str, - session_id: str = None, process_time: float = None): + + def log_request_routed( + self, + arrival_time: float, + request_id: str, + routing_method: str, + target_engine: str, + session_id: str = None, + process_time: float = None, + ): """Record request routing decision and timestamp, and write to file (if enabled)""" if not self.log_enabled or not self.log_file: return - + # Ensure session_id has a value session_id = session_id or "unknown" - + # Write to file try: timestamp = time.strftime("%Y-%m-%d %H:%M:%S") log_line = f"{timestamp},{request_id},{session_id},{arrival_time},{routing_method},{target_engine},{process_time}\n" - - with 
open(self.log_file, 'a') as f: + + with open(self.log_file, "a") as f: f.write(log_line) - + except Exception as e: logger.error(f"Failed to write to log file: {str(e)}") - + def log_request_body(self, request_body, request_id=None): """ Log request body to a separate file - + Args: request_body: Request body content request_id: Request ID, if None then try to extract from request body """ if not self.log_enabled or not self.requests_bodies_dir: return - + if not request_id: # Try to extract request_id from request body try: @@ -112,24 +127,24 @@ def log_request_body(self, request_body, request_id=None): request_id = body_json.get("id", str(int(time.time()))) except: request_id = str(int(time.time())) - + # Create file name file_path = os.path.join(self.requests_bodies_dir, f"{request_id}.json") - + # Write request body to file try: with open(file_path, "wb") as f: if isinstance(request_body, bytes): f.write(request_body) else: - f.write(request_body.encode('utf-8')) + f.write(request_body.encode("utf-8")) except Exception as e: logger.error(f"Failed to write request body file: {str(e)}") - + def clear_logs(self): """Clear temporarily stored arrival times""" self.arrival_times.clear() # Create global request logger instance -request_logger = RequestLogger() \ No newline at end of file +request_logger = RequestLogger() diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index f973053ad..96c8f5e5a 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -647,8 +647,9 @@ def route_request( def initialize_routing_logic( routing_logic: RoutingLogic, *args, **kwargs ) -> RoutingInterface: - + from vllm_router.request_logger import request_logger + if kwargs.get("enable_request_logging"): logger.info("Enabling request logging") request_logger.enable_logging(True, kwargs.get("request_log_dir")) @@ -660,10 +661,11 @@ def initialize_routing_logic( logger.info(f"Initializing session-based routing logic with kwargs: {kwargs}") return SessionRouter(kwargs.get("session_key")) elif routing_logic == RoutingLogic.CACHE_AWARE_LOAD_BALANCING: - logger.info(f"Initializing cache-aware load balancing routing logic with kwargs: {kwargs}") + logger.info( + f"Initializing cache-aware load balancing routing logic with kwargs: {kwargs}" + ) router = CacheAwareLoadBalancingRouter( - kwargs.get("session_key"), - kwargs.get("tolerate_waiting_requests") + kwargs.get("session_key"), kwargs.get("tolerate_waiting_requests") ) return router elif routing_logic == RoutingLogic.KVAWARE: @@ -699,16 +701,17 @@ def reconfigure_routing_logic( ): if cls in SingletonABCMeta._instances: del SingletonABCMeta._instances[cls] - + # Re-configure request logging from vllm_router.request_logger import request_logger + if kwargs.get("enable_request_logging"): logger.info("Re-enabling request logging with new configuration") request_logger.enable_logging(True, kwargs.get("request_log_dir")) else: # If request logging is not enabled, disable it request_logger.enable_logging(False) - + return initialize_routing_logic(routing_logic, *args, **kwargs) diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 7c81d06d8..46909f60e 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -21,9 +21,9 @@ import httpx from fastapi import BackgroundTasks, Request from fastapi.responses import JSONResponse, StreamingResponse -from 
vllm_router.request_logger import request_logger from vllm_router.log import init_logger +from vllm_router.request_logger import request_logger from vllm_router.routers.routing_logic import ( DisaggregatedPrefillRouter, KvawareRouter, @@ -299,11 +299,17 @@ async def route_general_request( ) # record the request - request_logger.log_request_routed(in_router_time, request_id, routing_method, - server_url, session_id, process_time=curr_time - in_router_time) + request_logger.log_request_routed( + in_router_time, + request_id, + routing_method, + server_url, + session_id, + process_time=curr_time - in_router_time, + ) # record the request body request_logger.log_request_body(request_body, request_id) - + stream_generator = process_request( request, request_body, From 25ed806a5d37743a1ee0faf9e51f40859877c755 Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Wed, 23 Jul 2025 08:00:44 +0000 Subject: [PATCH 08/11] [Bugfix] Fix a bug in rolling restart --- src/vllm_router/routers/routing_logic.py | 44 ++++++++++-------------- 1 file changed, 19 insertions(+), 25 deletions(-) diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 96c8f5e5a..349a2a606 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -89,6 +89,7 @@ def _update_hash_ring(self, endpoints: List["EndpointInfo"]): """ Update the hash ring with the current list of endpoints. """ + # logger.debug(f"Updating hash ring with endpoints: {endpoints}") # Extract endpoint URLs endpoint_urls = [endpoint.url for endpoint in endpoints] @@ -159,7 +160,8 @@ def route_request( len_engines = len(endpoints) chosen = sorted(endpoints, key=lambda e: e.url)[self.req_id % len_engines] self.req_id += 1 - return chosen.url + routing_method = "round_robin" + return chosen.url, routing_method class SessionRouter(RoutingInterface): @@ -211,7 +213,14 @@ def route_request( # Use the hash ring to get the endpoint for the session ID url = self.hash_ring.get_node(session_id) - return url + # If the initial engine is not found in engine_stats + if url not in engine_stats: + logger.warning( + f"Engine {url} not found in engine_stats" + ) + + routing_method = "session_based" + return url, routing_method class CacheAwareLoadBalancingRouter(RoutingInterface): @@ -272,27 +281,6 @@ def _calculate_engine_load_score( return total_load_score - def _update_hash_ring(self, endpoints: List["EndpointInfo"]): - """ - Update the hash ring with the current list of endpoints. 
- """ - # Extract endpoint URLs - endpoint_urls = [endpoint.url for endpoint in endpoints] - - # Get the current nodes in the hash ring - current_nodes = set(self.hash_ring.get_nodes()) - - # Convert the new endpoint URLs to a set for easy comparison - new_nodes = set(endpoint_urls) - - # Remove nodes that are no longer in the list - for node in current_nodes - new_nodes: - self.hash_ring.remove_node(node) - - # Add new nodes that are not already in the hash ring - for node in new_nodes - current_nodes: - self.hash_ring.add_node(node) - def _select_best_engine( self, session_id: str, @@ -314,7 +302,12 @@ def _select_best_engine( # Use hash_ring to get the initial engine_url initial_engine_url = self.hash_ring.get_node(session_id) - routing_method = "cache_aware" + # If the initial engine is not found in engine_stats + if initial_engine_url not in engine_stats: + logger.warning( + f"Engine {initial_engine_url} not found in engine_stats" + ) + return initial_engine_url, "cache_aware" # Check the queuing situation of the initial engine if ( @@ -325,7 +318,7 @@ def _select_best_engine( logger.debug( f"Session {session_id} initial engine waiting requests < {self.tolerate_waiting_requests}, route to: {initial_engine_url}" ) - return initial_engine_url, routing_method + return initial_engine_url, "cache_aware" # Try to find engines without queue engines_without_queue = [] @@ -720,6 +713,7 @@ def get_routing_logic() -> RoutingInterface: for cls in ( SessionRouter, RoundRobinRouter, + CacheAwareLoadBalancingRouter, KvawareRouter, PrefixAwareRouter, DisaggregatedPrefillRouter, From 1aed6017a2874146911e0ae495217c3ca6ec5c32 Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Wed, 23 Jul 2025 08:01:48 +0000 Subject: [PATCH 09/11] [Feat] Add routing method qps metrics --- src/vllm_router/routers/metrics_router.py | 22 +++++ .../services/metrics_service/__init__.py | 7 ++ .../services/request_service/request.py | 12 +-- src/vllm_router/stats/request_stats.py | 86 ++++++++++++++++++- 4 files changed, 118 insertions(+), 9 deletions(-) diff --git a/src/vllm_router/routers/metrics_router.py b/src/vllm_router/routers/metrics_router.py index 276023949..491c4977a 100644 --- a/src/vllm_router/routers/metrics_router.py +++ b/src/vllm_router/routers/metrics_router.py @@ -32,6 +32,7 @@ num_prefill_requests, num_requests_running, num_requests_swapped, + routing_method_qps, ) from vllm_router.stats.engine_stats import get_engine_stats_scraper from vllm_router.stats.request_stats import get_request_stats_monitor @@ -99,6 +100,27 @@ async def metrics(): avg_itl.labels(server=server).set(stat.avg_itl) num_requests_swapped.labels(server=server).set(stat.num_swapped_requests) + # ----------------------------------------------------------------------------- + + # Routing method QPS + routing_methods_qps = get_request_stats_monitor().get_routing_methods_qps() + # Save all known routing methods, for resetting + known_methods = set() + # Find registered routing method labels + for labels, _ in routing_method_qps._metrics.items(): + if labels and len(labels) > 0: + method = labels[0] # Assuming labels is (method,) format + known_methods.add(method) + # Reset QPS of all known but currently inactive routing methods to 0 + for method in known_methods: + if method not in routing_methods_qps: + routing_method_qps.labels(method=method).set(0) + # Set QPS of currently active routing methods + for method, qps_value in routing_methods_qps.items(): + routing_method_qps.labels(method=method).set(qps_value) + + # 
----------------------------------------------------------------------------- + # Engine statistics (GPU prefix cache metrics) engine_stats = get_engine_stats_scraper().get_engine_stats() for server, engine_stat in engine_stats.items(): diff --git a/src/vllm_router/services/metrics_service/__init__.py b/src/vllm_router/services/metrics_service/__init__.py index 80dbaed9a..8bfa47472 100644 --- a/src/vllm_router/services/metrics_service/__init__.py +++ b/src/vllm_router/services/metrics_service/__init__.py @@ -45,3 +45,10 @@ num_requests_swapped = Gauge( "vllm:num_requests_swapped", "Number of swapped requests", ["server"] ) + +# Routing method metrics +routing_method_qps = Gauge( + "vllm:routing_method_qps", + "Queries Per Second for different routing methods in the current time window", + ["method"] +) diff --git a/src/vllm_router/services/request_service/request.py b/src/vllm_router/services/request_service/request.py index 46909f60e..110e8eb9c 100644 --- a/src/vllm_router/services/request_service/request.py +++ b/src/vllm_router/services/request_service/request.py @@ -285,12 +285,12 @@ async def route_general_request( session_id_display = session_id if session_id is not None else "None" # Debug logging to help troubleshoot session ID extraction - logger.debug( - f"Debug session extraction - Router type: {type(request.app.state.router).__name__}" - ) - logger.debug(f"Debug session extraction - Session key config: {session_key}") - logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}") - logger.debug(f"Debug session extraction - Extracted session ID: {session_id}") + # logger.debug( + # f"Debug session extraction - Router type: {type(request.app.state.router).__name__}" + # ) + # logger.debug(f"Debug session extraction - Session key config: {session_key}") + # logger.debug(f"Debug session extraction - Request headers: {dict(request.headers)}") + # logger.debug(f"Debug session extraction - Extracted session ID: {session_id}") logger.info( f"Routing request {request_id} with session id {session_id_display} to {server_url} at {curr_time}, " diff --git a/src/vllm_router/stats/request_stats.py b/src/vllm_router/stats/request_stats.py index f0409b912..b209b2f6f 100644 --- a/src/vllm_router/stats/request_stats.py +++ b/src/vllm_router/stats/request_stats.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. import time -from collections import deque +from collections import deque, Counter from dataclasses import dataclass -from typing import Deque, Dict, Tuple +from typing import Deque, Dict, Tuple, Optional from vllm_router.log import init_logger @@ -53,6 +53,16 @@ class RequestStats: avg_itl: float # Number of swapped requests (moved from GPU to CPU) num_swapped_requests: int + # Dictionary counting different routing methods used + routing_methods: Dict[str, int] = None + +@dataclass +class RoutingMethodEntry: + """ + A class to record the routing method used for a request. 
+ """ + method: str + timestamp: float class MovingAverageMonitor: @@ -138,11 +148,18 @@ def __init__(self, sliding_window_size: float = None): # Counter for swapped requests self.swapped_requests: Dict[str, int] = {} + + # Counting different routing methods used by each engine + self.routing_methods: Dict[str, Counter] = {} + + # Counting different routing methods used in the current time window + self.routing_methods_window: Deque[RoutingMethodEntry] = deque() self.first_query_time: float = None self._initialized = True - def on_new_request(self, engine_url: str, request_id: str, timestamp: float): + def on_new_request(self, engine_url: str, request_id: str, + timestamp: float, routing_method: Optional[str] = None): """ Tell the monitor that a new request has been created. @@ -150,6 +167,7 @@ def on_new_request(self, engine_url: str, request_id: str, timestamp: float): engine_url: The URL of the serving engine request_id: The global request ID timestamp: the timestamp when the request was created + routing_method: The routing method used for this request """ self.request_start_time[(engine_url, request_id)] = timestamp @@ -167,9 +185,33 @@ def on_new_request(self, engine_url: str, request_id: str, timestamp: float): self.latency_monitors[engine_url] = MovingAverageMonitor( self.sliding_window_size ) + + # Counting different routing methods used by each engine + if routing_method: + if engine_url not in self.routing_methods: + self.routing_methods[engine_url] = Counter() + self.routing_methods[engine_url][routing_method] += 1 + + # Counting different routing methods used in the current time window + self.routing_methods_window.append( + RoutingMethodEntry(method=routing_method, timestamp=timestamp) + ) + + # Clean up expired routing methods records + self._clean_routing_methods_window(timestamp) if self.first_query_time is None: self.first_query_time = timestamp + + def _clean_routing_methods_window(self, current_time: float): + """ + Clean up expired routing methods records + """ + while ( + self.routing_methods_window and + self.routing_methods_window[0].timestamp < current_time - self.sliding_window_size + ): + self.routing_methods_window.popleft() def on_request_response(self, engine_url: str, request_id: str, timestamp: float): """ @@ -288,6 +330,9 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: swapped = self.swapped_requests[engine_url] else: swapped = 0 + + # Get routing methods statistics + routing_methods_dict = dict(self.routing_methods.get(engine_url, Counter())) ret[engine_url] = RequestStats( qps=qps, @@ -302,8 +347,43 @@ def get_request_stats(self, current_time: float) -> Dict[str, RequestStats]: avg_latency=avg_lat, avg_itl=avg_itl_val, num_swapped_requests=swapped, + routing_methods=routing_methods_dict, ) return ret + + def get_all_routing_methods_in_window(self) -> Dict[str, int]: + """ + Get the usage statistics of all routing methods in the current time window + + Returns: + Dict[str, int]: The routing method and its usage count in the current time window + """ + # Clean up expired routing methods records + self._clean_routing_methods_window(time.time()) + + # Counting different routing methods used in the current time window + methods_count = Counter() + for entry in self.routing_methods_window: + methods_count[entry.method] += 1 + + return dict(methods_count) + + def get_routing_methods_qps(self) -> Dict[str, float]: + """ + Get the QPS (Queries Per Second) of routing methods in the current time window + + Returns: + Dict[str, float]: 
The routing method and its QPS in the current time window + """ + # Counting different routing methods used in the current time window + methods_count = self.get_all_routing_methods_in_window() + + # Calculating QPS = usage count / time window size + methods_qps = {} + for method, count in methods_count.items(): + methods_qps[method] = count / self.sliding_window_size + + return methods_qps def initialize_request_stats_monitor(sliding_window_size: float): From 2e7cc1a2c9f34bf8ca2305b23f5d28840b780663 Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Thu, 24 Jul 2025 11:15:00 +0000 Subject: [PATCH 10/11] [Feat] Add method to check Pod termination status and update Pod readiness logic --- src/vllm_router/service_discovery.py | 28 +++++++++++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/src/vllm_router/service_discovery.py b/src/vllm_router/service_discovery.py index 17cfc8844..cfccbf7ee 100644 --- a/src/vllm_router/service_discovery.py +++ b/src/vllm_router/service_discovery.py @@ -380,6 +380,20 @@ def _check_pod_ready(container_statuses): ready_count = sum(1 for status in container_statuses if status.ready) return ready_count == len(container_statuses) + @staticmethod + def _is_pod_terminating(pod): + """ + Check if the pod is in terminating state by checking + deletion timestamp and pod phase. + """ + # Check if pod has deletion timestamp + if pod.metadata.deletion_timestamp is not None: + return True + # Check if pod phase is failed or succeeded + if pod.status.phase in ["Failed", "Succeeded"]: + return True + return False + def _get_engine_sleep_status(self, pod_ip) -> Optional[bool]: """ Get the engine sleeping status by querying the engine's @@ -554,13 +568,25 @@ def _watch_engines(self): event_type = event["type"] pod_name = pod.metadata.name pod_ip = pod.status.pod_ip - is_pod_ready = self._check_pod_ready(pod.status.container_statuses) + + # Check if pod is terminating + is_pod_terminating = self._is_pod_terminating(pod) + is_container_ready = self._check_pod_ready(pod.status.container_statuses) + + # Pod is ready if container is ready and pod is not terminating + is_pod_ready = is_container_ready and not is_pod_terminating + if is_pod_ready: model_names = self._get_model_names(pod_ip) model_label = self._get_model_label(pod) else: model_names = [] model_label = None + + # Record pod status for debugging + if is_container_ready and is_pod_terminating: + logger.info(f"Pod {pod_name} has ready containers but is terminating - marking as unavailable") + self._on_engine_update( pod_name, pod_ip, From 33f22495722fbf70272f817f8026061659d55061 Mon Sep 17 00:00:00 2001 From: KevinCheung2259 <2651309292@qq.com> Date: Fri, 1 Aug 2025 07:50:56 +0000 Subject: [PATCH 11/11] [Feat] Improve routing logic with boundary checks and enhanced probability calculation --- src/vllm_router/routers/routing_logic.py | 44 +++++++++++++++++------- 1 file changed, 31 insertions(+), 13 deletions(-) diff --git a/src/vllm_router/routers/routing_logic.py b/src/vllm_router/routers/routing_logic.py index 349a2a606..c6f62c547 100644 --- a/src/vllm_router/routers/routing_logic.py +++ b/src/vllm_router/routers/routing_logic.py @@ -324,36 +324,54 @@ def _select_best_engine( engines_without_queue = [] for info in endpoints: url = info.url - if engine_stats[url].num_queuing_requests == 0: + # Add boundary check for engine_stats + if url in engine_stats and engine_stats[url].num_queuing_requests == 0: engines_without_queue.append(url) # If there are engines without 
queue, randomly select one if engines_without_queue: - routing_method = "redirect_to_no_queue_engine" selected_engine = random.choice(engines_without_queue) logger.debug( f"Session {session_id} redirect to no queue engine: {selected_engine}" ) - return selected_engine, routing_method + return selected_engine, "redirect_to_no_queue_engine" - # All engines have queues, select one based on probability + # All engines have queues, select one based on improved probability calculation routing_method = "probability_based" + + # Filter endpoints that have engine stats + valid_endpoints = [info for info in endpoints if info.url in engine_stats] + if not valid_endpoints: + # Fallback to initial engine if no valid stats available + logger.warning("No valid engine stats available, falling back to initial engine") + return initial_engine_url, "cache_aware_fallback" + + # Calculate total queue length from valid endpoints only total_queue_length = sum( - engine_stats[url].num_queuing_requests - for url in [info.url for info in endpoints] + engine_stats[info.url].num_queuing_requests + for info in valid_endpoints ) - probabilities = [ - 1 / (engine_stats[url].num_queuing_requests / total_queue_length) - for url in [info.url for info in endpoints] - ] - probabilities = [p / sum(probabilities) for p in probabilities] + + # Fixed probability calculation: inverse of normalized queue length + queue_lengths = [engine_stats[info.url].num_queuing_requests for info in valid_endpoints] + max_queue = max(queue_lengths) + + # Calculate inverse weights (higher weight for lower queue length) + # Add small epsilon to avoid division by zero + epsilon = 0.1 + inverse_weights = [(max_queue - queue_len + epsilon) for queue_len in queue_lengths] + total_weight = sum(inverse_weights) + + # Normalize to probabilities + probabilities = [weight / total_weight for weight in inverse_weights] selected_engine = random.choices( - [info.url for info in endpoints], weights=probabilities + [info.url for info in valid_endpoints], weights=probabilities )[0] logger.debug( - f"Session {session_id} probability based routing to: {selected_engine}" + f"Session {session_id} probability based routing to: {selected_engine}, " + f"queue_lengths: {queue_lengths}, probabilities: {[f'{p:.3f}' for p in probabilities]}" ) return selected_engine, routing_method
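
---

Note on the probability-based fallback introduced in PATCH 11: when every engine has queued requests, the router weights each engine by (max_queue - queue_len + epsilon) and normalizes, so shorter queues are proportionally more likely to be selected while the engine with the longest queue still has a small, non-zero chance. The standalone sketch below reproduces that weighting outside the router so it can be sanity-checked in isolation. It is an illustrative sketch only: the engine URLs and queue lengths are made-up example values, and pick_engine is a hypothetical helper that is not part of this patch.

    import random
    from typing import Dict

    def pick_engine(queue_lengths: Dict[str, int], epsilon: float = 0.1) -> str:
        """Pick an engine URL with probability inversely related to its queue length."""
        urls = list(queue_lengths)
        max_queue = max(queue_lengths.values())
        # Shorter queues get larger weights; epsilon keeps the longest queue selectable.
        weights = [max_queue - queue_lengths[url] + epsilon for url in urls]
        total = sum(weights)
        probabilities = [w / total for w in weights]
        return random.choices(urls, weights=probabilities)[0]

    if __name__ == "__main__":
        # Made-up queue lengths: engine-a is the least loaded, engine-b the most.
        example = {
            "http://engine-a:8000": 2,
            "http://engine-b:8000": 8,
            "http://engine-c:8000": 5,
        }
        picks = [pick_engine(example) for _ in range(10_000)]
        for url in example:
            print(url, picks.count(url) / len(picks))

Running the sketch shows the empirical pick rates ordered inversely to queue length, which matches the intent of the "probability_based" branch after the boundary checks added in the final patch.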