From 04ad8ca9a1648584e4cb868f64a669343ead6089 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Mon, 16 Mar 2026 05:52:31 +0000
Subject: [PATCH 01/13] fix

---
 lightllm/server/router/batch.py   |  6 +++-
 lightllm/server/router/manager.py |  5 ++-
 lightllm/server/router/stats.py   | 56 +++++++++----------------------
 3 files changed, 25 insertions(+), 42 deletions(-)

diff --git a/lightllm/server/router/batch.py b/lightllm/server/router/batch.py
index e9524466f0..24d0b9b824 100644
--- a/lightllm/server/router/batch.py
+++ b/lightllm/server/router/batch.py
@@ -3,6 +3,7 @@
 from typing import Dict, List, Optional, Tuple, Union
 from lightllm.server.core.objs import ShmReqManager, Req
 from lightllm.utils.log_utils import init_logger
+from .stats import RouterStatics
 
 logger = init_logger(__name__)
 
@@ -49,11 +50,14 @@ def get_all_dp_req_num(self) -> List[int]:
                 all_dp_req_num[req.sample_params.suggested_dp_index] += 1
         return all_dp_req_num
 
-    def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
+    def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics):
        unfinished_req_ids = []
         for req in self.reqs:
             if req.shm_infer_released:
                 logger.info(f"router release req id {req.request_id}")
+                if not req.is_aborted:
+                    router_statics.update(req.candetoken_out_len)
+
                 shm_req_manager.put_back_req_obj(req)
                 req = None
             else:

diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 6d40c28ddb..0d2705fab2 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -33,6 +33,7 @@
 from lightllm.utils.graceful_utils import graceful_registry
 from lightllm.utils.process_check import start_parent_check_thread
 from lightllm.utils.envs_utils import get_unique_server_name
+from .stats import RouterStatics
 
 logger = init_logger(__name__)
 
@@ -105,6 +106,7 @@ def __init__(self, args: StartArgs):
             if not self.args.enable_cpu_cache
             else CpuKvCacheClient(only_create_meta_data=True, init_shm_data=False)
         )
+        self.router_statics = RouterStatics(self.args)
         return
 
     async def wait_to_model_ready(self):
@@ -256,6 +258,7 @@ async def loop_for_fwd(
                     f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n"
                     f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token"
                 )
+                logger.debug(self.router_statics.log_str())
                 self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num)
             # pd decode mode need to update token_load more frequently
             self.req_queue.update_token_load(self.running_batch, force_update=self.is_pd_decode_mode)
@@ -347,7 +350,7 @@ def _add_new_batch_to_running_batch(self, new_batch: Batch):
 
     def _filter_reqs_from_running_batch(self):
         if self.running_batch is not None:
-            self.running_batch.filter_out_finished_req(self.shm_req_manager)
+            self.running_batch.filter_out_finished_req(self.shm_req_manager, self.router_statics)
             if self.running_batch.is_clear():
                 self.running_batch = None
         return

diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index d50c4e7ca5..e4069bdc3b 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -1,47 +1,23 @@
-import time
 from lightllm.utils.log_utils import init_logger
-from .batch import Batch
+from lightllm.server.core.objs import StartArgs
 
 logger = init_logger(__name__)
 
 
-class Stats:
-    def __init__(self, log_status, log_stats_interval) -> None:
-        self.log_stats = log_status
-        self.log_stats_interval = log_stats_interval
-        self.last_log_time = time.time()
-        self.all_tokens = 0
-        self.output_tokens = 0
-        self.prompt_tokens = 0
-        return
+class RouterStatics:
+    def __init__(self, args: StartArgs):
+        self.busy_token_used_ratio = args.router_token_ratio
+        self.ema_req_out_put_len = 2048
+        self.ema_params = 0.04
 
-    def count_prompt_tokens(self, run_batch: Batch):
-        if self.log_stats and run_batch is not None:
-            tokens = run_batch.input_tokens()
-            self.prompt_tokens += tokens
-            self.all_tokens += tokens
-        return
+    def update(self, req_out_len: int):
+        # Filter out very short outputs so the estimate cannot collapse, which would cause frequent scheduling pauses and hurt system throughput.
+        req_out_len = max(req_out_len, 64)
+        self.ema_req_out_put_len = int(self.ema_req_out_put_len * (1 - self.ema_params) + req_out_len * self.ema_params)
+        self.ema_req_out_put_len = max(64, self.ema_req_out_put_len)
 
-    def count_output_tokens(self, run_batch: Batch):
-        if self.log_stats and run_batch is not None:
-            tokens = len(run_batch.reqs)
-            self.output_tokens += tokens
-            self.all_tokens += tokens
-        return
-
-    def print_stats(self):
-        if not self.log_stats:
-            return
-
-        now = time.time()
-        if now - self.last_log_time > self.log_stats_interval:
-            logger.debug(
-                f"Avg tokens(prompt+generate) throughput: {self.all_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
-                f"Avg prompt tokens throughput: {self.prompt_tokens/(now-self.last_log_time):8.3f} tokens/s\n"
-                f"Avg generate tokens throughput: {self.output_tokens/(now-self.last_log_time):8.3f} tokens/s"
-            )
-            self.all_tokens = 0
-            self.output_tokens = 0
-            self.prompt_tokens = 0
-            self.last_log_time = now
-        return
+    def log_str(self) -> str:
+        return (
+            f"RouterStatics busy_token_used_ratio: {self.busy_token_used_ratio} "
+            f"ema_req_out_put_len: {self.ema_req_out_put_len}"
+        )
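The RouterStatics object introduced above replaces the old throughput-logging Stats class: instead of printing token-rate averages, the router now keeps an exponential moving average (EMA) of the output length of finished requests (candetoken_out_len), skipping aborted requests so they do not bias the estimate. A minimal standalone sketch of the same update rule, using the field name as it ends up after the rename in PATCH 02, with made-up sample lengths:

    # Standalone sketch of the RouterStatics EMA update (illustrative only).
    ema_req_out_len = 2048  # initial guess for per-request output length
    ema_params = 0.04       # smoothing factor: each new sample contributes 4%

    def update(req_out_len: int) -> None:
        global ema_req_out_len
        # Clamp very short outputs so the estimate cannot collapse and make
        # the scheduler over-admit, which would cause pause/recompute cycles.
        req_out_len = max(req_out_len, 64)
        ema_req_out_len = int(ema_req_out_len * (1 - ema_params) + req_out_len * ema_params)
        ema_req_out_len = max(64, ema_req_out_len)

    for sample in (120, 90, 300):  # hypothetical finished-request output lengths
        update(sample)
    print(ema_req_out_len)  # drifts from 2048 toward the observed range

With ema_params at 0.04, each sample shifts the estimate by only 4%, so a burst of unusually short or long requests moves the scheduler's expectation gradually rather than abruptly.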
From 87e54c8b45c986041bc56f2abd80cfbd273861eb Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Mon, 16 Mar 2026 06:29:07 +0000
Subject: [PATCH 02/13] fix

---
 docs/CN/source/tutorial/api_server_args.rst              |  5 -----
 docs/EN/source/tutorial/api_server_args.rst              |  5 -----
 lightllm/server/api_cli.py                               |  3 ---
 lightllm/server/core/objs/req.py                         |  8 +++-----
 lightllm/server/core/objs/start_args_type.py             |  1 -
 lightllm/server/router/req_queue/base_queue.py           |  1 -
 .../server/router/req_queue/chunked_prefill/beam_impl.py |  4 ++--
 lightllm/server/router/req_queue/chunked_prefill/impl.py |  6 ++++--
 .../router/req_queue/chunked_prefill/impl_for_nixl_pd.py |  4 +++-
 .../req_queue/chunked_prefill/impl_for_pd_decode.py      |  2 +-
 lightllm/server/router/stats.py                          |  8 ++++----
 11 files changed, 17 insertions(+), 30 deletions(-)

diff --git a/docs/CN/source/tutorial/api_server_args.rst b/docs/CN/source/tutorial/api_server_args.rst
index 2354818efa..29ab549045 100644
--- a/docs/CN/source/tutorial/api_server_args.rst
+++ b/docs/CN/source/tutorial/api_server_args.rst
@@ -191,11 +191,6 @@ PD 分离模式参数
 
    判断服务是否繁忙的阈值,默认为 ``0.0``,一旦kv cache 使用率超过此值,则会直接变为保守调度。
 
-.. option:: --router_max_new_token_len
-
-   调度器评估请求kv占用时,使用的请求输出长度,默认为 ``1024``,一般低于用户设置的max_new_tokens。该参数只在 --router_token_ratio 大于0时生效。
-   设置改参数,会使请求调度更为激进,系统同时处理的请求数会更多,同时也会不可避免的造成请求的暂停重计算。
-
 .. option:: --router_max_wait_tokens
 
    每 router_max_wait_tokens 解码步骤后触发一次调度新请求,默认为 ``6``

diff --git a/docs/EN/source/tutorial/api_server_args.rst b/docs/EN/source/tutorial/api_server_args.rst
index ac4c1b87ec..dd6a9bca04 100644
--- a/docs/EN/source/tutorial/api_server_args.rst
+++ b/docs/EN/source/tutorial/api_server_args.rst
@@ -190,11 +190,6 @@ Scheduling Parameters
 
    Threshold for determining if the service is busy, default is ``0.0``. Once the kv cache usage exceeds this value, it will directly switch to conservative scheduling.
 
-.. option:: --router_max_new_token_len
-
-   The request output length used by the scheduler when evaluating request kv usage, default is ``1024``, generally lower than the max_new_tokens set by the user. This parameter only takes effect when --router_token_ratio is greater than 0.
-   Setting this parameter will make request scheduling more aggressive, allowing the system to process more requests simultaneously, but will inevitably cause request pause and recalculation.
-
 .. option:: --router_max_wait_tokens
 
    Trigger scheduling of new requests every router_max_wait_tokens decoding steps, default is ``6``

diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 762d84575b..dfe2933db4 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -238,9 +238,6 @@ def make_argument_parser() -> argparse.ArgumentParser:
     )
 
     parser.add_argument("--router_token_ratio", type=float, default=0.0, help="token ratio to control router dispatch")
-    parser.add_argument(
-        "--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
-    )
 
     parser.add_argument(
         "--router_max_wait_tokens",

diff --git a/lightllm/server/core/objs/req.py b/lightllm/server/core/objs/req.py
index f489aac9c2..8905248bf8 100644
--- a/lightllm/server/core/objs/req.py
+++ b/lightllm/server/core/objs/req.py
@@ -260,7 +260,7 @@ def can_release(self):
     def get_used_tokens(self):
         return max(0, self.shm_cur_kv_len)
 
-    def get_tuple_tokens(self, is_busy, router_max_new_token_len):
+    def get_tuple_tokens(self, is_busy, ema_req_out_len):
         raise NotImplementedError("Subclasses should implement this method")
 
     def get_decode_need_tokens(self):
@@ -311,7 +311,7 @@ def print_time_log(self, log_info: str):
 class ChunkedPrefillReq(Req):
     _pack_ = 4
 
-    def get_tuple_tokens(self, is_busy, router_max_new_token_len):
+    def get_tuple_tokens(self, is_busy, ema_req_out_len):
         args = get_env_start_args()
         # During chunked prefill inference there are several modes of delayed-step control, used to
         # ensure better inter-token latency or to improve prefill efficiency in dp mode; but when estimating token memory
@@ -327,9 +327,7 @@
         elif is_busy:
             cur_max_new_token_len = self.sample_params.max_new_tokens
         else:
-            cur_max_new_token_len = min(
-                self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), router_max_new_token_len)
-            )
+            cur_max_new_token_len = min(self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), ema_req_out_len))
 
         a_len = max(self.input_len + has_out_len + 1, self.shm_cur_kv_len + 1)
         b_len = (
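get_tuple_tokens() now receives the live EMA instead of the static --router_max_new_token_len, so the scheduler's per-request footprint estimate tracks the observed workload. A rough sketch of the estimate under the branches shown above; b_len is simplified here, since the real computation (cut off in the hunk context) also accounts for the delayed-step inference modes mentioned in the comment:

    def estimate_tuple_tokens(input_len: int, has_out_len: int, shm_cur_kv_len: int,
                              max_new_tokens: int, ema_req_out_len: int, is_busy: bool):
        # Sketch of ChunkedPrefillReq.get_tuple_tokens after this patch; not the
        # full implementation.
        if is_busy:
            cur_max_new_token_len = max_new_tokens
        else:
            # Optimistic cap: 110% of what is already generated, or the EMA of
            # recently finished requests, whichever is larger.
            cur_max_new_token_len = min(max_new_tokens, max(int(1.1 * has_out_len), ema_req_out_len))
        a_len = max(input_len + has_out_len + 1, shm_cur_kv_len + 1)  # tokens held now
        b_len = input_len + cur_max_new_token_len                     # simplified peak estimate
        return a_len, b_len

When the system is busy, the estimate falls back to the user's full max_new_tokens, which is exactly the conservative behavior the busy threshold is meant to enforce.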
diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py
index 4b54cdccef..d3dc849664 100644
--- a/lightllm/server/core/objs/start_args_type.py
+++ b/lightllm/server/core/objs/start_args_type.py
@@ -68,7 +68,6 @@ class StartArgs:
     disable_log_stats: bool = field(default=False)
     log_stats_interval: int = field(default=10)
     router_token_ratio: float = field(default=0.0)
-    router_max_new_token_len: int = field(default=1024)
     router_max_wait_tokens: int = field(default=1)
     disable_aggressive_schedule: bool = field(default=False)
     disable_dynamic_prompt_cache: bool = field(default=False)

diff --git a/lightllm/server/router/req_queue/base_queue.py b/lightllm/server/router/req_queue/base_queue.py
index 36aefae6e7..73113a59b8 100644
--- a/lightllm/server/router/req_queue/base_queue.py
+++ b/lightllm/server/router/req_queue/base_queue.py
@@ -25,7 +25,6 @@ def __init__(self, args: StartArgs, router, dp_index, dp_size_in_node) -> None:
         self.running_max_req_size = args.running_max_req_size  # Maximum number of concurrent requests
         self.waiting_req_list: List[Req] = []  # List of queued requests
         self.router_token_ratio = args.router_token_ratio  # ratio to determine whether the router is busy
-        self.router_max_new_token_len = args.router_max_new_token_len
 
     def free_aborted_req_cpu_cache_pages(self, req: Req):
         if self.args.enable_cpu_cache:

diff --git a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
index e1bdf5ab8d..23f94de704 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/beam_impl.py
@@ -16,7 +16,7 @@ def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
     def _init_cache_list(self, current_batch: Batch, is_busy):
         if current_batch is not None:
             self.cache_len_list = [
-                (req, req.get_tuple_tokens(is_busy, self.router_max_new_token_len))
+                (req, req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len))
                 for req in current_batch.reqs
                 if req.sample_params.suggested_dp_index == self.dp_index
             ]
@@ -28,7 +28,7 @@ def _can_add_new_group_reqs(self, cur_handle_group_reqs: List[Req], is_busy, new
         for req in cur_handle_group_reqs:
             self.cache_len_list.append(
-                (req, req.get_tuple_tokens(is_busy, self.router_max_new_token_len))
+                (req, req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len))
             )  # hard to analysis
 
         self.cache_len_list.sort(key=lambda x: -x[1][1])

diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py
index 03a99642bd..884b5930b0 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py
@@ -16,7 +16,7 @@ def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
     def _init_cache_list(self, current_batch: Batch, is_busy):
         if current_batch is not None:
             self.cache_len_list = [
-                req.get_tuple_tokens(is_busy, self.router_max_new_token_len)
+                req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len)
                 for req in current_batch.reqs
                 if req.sample_params.suggested_dp_index == self.dp_index
             ]
@@ -26,7 +26,9 @@
     # @calculate_time(show=True, min_cost_ms=0.1)
     def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens):
-        self.cache_len_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))  # hard to analysis
+        self.cache_len_list.append(
+            req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len)
+        )  # hard to analysis
         self.cache_len_list.sort(key=lambda x: -x[1])
 
         left_out_len_array = np.array([e[1] for e in self.cache_len_list])
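Every queue implementation now reads the estimate from router.router_statics, but the admission test itself is unchanged: sort the (current, peak) tuples by remaining output length and verify that the worst-case cumulative footprint fits within the token budget. A condensed sketch of that check, following the cache_len_list / left_out_len_array pattern visible above (the real _can_add_new_req also handles batch_max_tokens and the new-batch prefill budget, which are omitted here):

    import numpy as np

    def can_admit(cache_len_list: list, max_total_tokens: int) -> bool:
        # cache_len_list holds (has_run_len, left_output_len) tuples as produced
        # by get_tuple_tokens(); sketch of the peak-memory admission test.
        cache_len_list = sorted(cache_len_list, key=lambda x: -x[1])
        left_out_len_array = np.array([e[1] for e in cache_len_list])
        has_run_len_array = np.array([e[0] for e in cache_len_list])
        cum_run_len_array = np.cumsum(has_run_len_array)
        size_array = np.arange(1, len(cache_len_list) + 1, 1)
        # Worst case at each prefix: the i requests with the longest remaining
        # outputs all run to completion while the rest hold their current tokens.
        need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max()
        return need_max_token_num <= max_total_tokens

Because ema_req_out_len shrinks the assumed peak for typical workloads, the same budget admits more concurrent requests than the old fixed router_max_new_token_len of 1024 would under long-output workloads, at the cost of occasional pauses when the estimate is exceeded.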
diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
index ba981a95bb..3b831c92a6 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_nixl_pd.py
@@ -38,7 +38,9 @@ def _caclu_batch_estimated_peak_token_num(self, batch: Batch):
         for req in batch.reqs:
             if req.sample_params.suggested_dp_index == self.dp_index:
                 if req.is_infer_decode():
-                    decoding_req_list.append(req.get_tuple_tokens(is_busy, self.router_max_new_token_len))
+                    decoding_req_list.append(
+                        req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len)
+                    )
                 else:
                     estimated_peak_token_num += req.input_len + req.sample_params.max_new_tokens

diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
index d546d34004..4c2ebf7c00 100644
--- a/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
+++ b/lightllm/server/router/req_queue/chunked_prefill/impl_for_pd_decode.py
@@ -18,7 +18,7 @@ def __init__(self, args, router, dp_index, dp_size_in_node) -> None:
     def _init_cache_list(self, current_batch: Batch, is_busy):
         if current_batch is not None:
             self.cache_len_list = [
-                req.get_tuple_tokens(is_busy, self.router_max_new_token_len)
+                req.get_tuple_tokens(is_busy, self.router.router_statics.ema_req_out_len)
                 for req in current_batch.reqs
                 if req.sample_params.suggested_dp_index == self.dp_index
             ]

diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py
index e4069bdc3b..9f4ef8bee7 100644
--- a/lightllm/server/router/stats.py
+++ b/lightllm/server/router/stats.py
@@ -7,17 +7,17 @@
 class RouterStatics:
     def __init__(self, args: StartArgs):
         self.busy_token_used_ratio = args.router_token_ratio
-        self.ema_req_out_put_len = 2048
+        self.ema_req_out_len = 2048
         self.ema_params = 0.04
 
     def update(self, req_out_len: int):
         # Filter out very short outputs so the estimate cannot collapse, which would cause frequent scheduling pauses and hurt system throughput.
         req_out_len = max(req_out_len, 64)
-        self.ema_req_out_put_len = int(self.ema_req_out_put_len * (1 - self.ema_params) + req_out_len * self.ema_params)
-        self.ema_req_out_put_len = max(64, self.ema_req_out_put_len)
+        self.ema_req_out_len = int(self.ema_req_out_len * (1 - self.ema_params) + req_out_len * self.ema_params)
+        self.ema_req_out_len = max(64, self.ema_req_out_len)
 
     def log_str(self) -> str:
         return (
             f"RouterStatics busy_token_used_ratio: {self.busy_token_used_ratio} "
-            f"ema_req_out_put_len: {self.ema_req_out_put_len}"
+            f"ema_req_out_len: {self.ema_req_out_len}"
         )

From 3830b45627c3ea1525baa7f077ea9e592a023b7a Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Mon, 16 Mar 2026 06:43:50 +0000
Subject: [PATCH 03/13] fix

---
 lightllm/server/api_cli.py | 14 ++++++++++++--
 1 file changed, 12 insertions(+), 2 deletions(-)

diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index dfe2933db4..67c6bd982b 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -237,8 +237,18 @@ def make_argument_parser() -> argparse.ArgumentParser:
         This setting allows you to turn off these warning checks.""",
     )
 
-    parser.add_argument("--router_token_ratio", type=float, default=0.0, help="token ratio to control router dispatch")
-
+    parser.add_argument(
+        "--router_token_ratio",
+        type=float,
+        default=None,
+        help="""Token usage ratio to control router dispatch, in the range [0.0, 1.0].
+        When the token VRAM usage ratio is higher than this value,
+        the dispatch strategy tends to be conservative.
+        When the token VRAM usage ratio is lower than this value,
+        the dispatching of requests tends to be aggressive.
+        The default value is None, meaning it will be automatically
+        determined by the internal system based on other startup parameters.""",
+    )
     parser.add_argument(
         "--router_max_wait_tokens",

From 43e88be030147b421a01dec08ffade5168d0dc4b Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Mon, 16 Mar 2026 07:04:08 +0000
Subject: [PATCH 04/13] fix

---
 lightllm/server/api_start.py | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index 0db786d0bf..f8bec5779c 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -105,6 +105,17 @@ def normal_or_p_d_start(args):
     if args.enable_multimodal:
         args.multi_modal_cache_shm_id = uuid.uuid1().int % 123456789
 
+    # Automatic setup of scheduling parameters; explicit user settings take precedence
+    if args.router_token_ratio is None:
+        if args.run_mode in ["normal"]:
+            args.router_token_ratio = 0.8
+        else:
+            # in pd disaggregation mode, advanced scheduling is not enabled
+            args.router_token_ratio = 0.0
+    # Some modes cannot yet cooperate with the advanced dynamic scheduling algorithm, to do.
+    if args.diverse_mode:
+        assert args.router_token_ratio == 0.0
+
     if not args.disable_shm_warning:
         check_recommended_shm_size(args)
 
@@ -146,10 +157,6 @@
         assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache"
         assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill"
 
-    # Some modes cannot yet cooperate with the advanced dynamic scheduling algorithm, to do.
-    if args.diverse_mode:
-        assert args.router_token_ratio == 0.0
-
     if args.enable_dp_prefill_balance:
         assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1"
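After this patch --router_token_ratio defaults to None and is resolved at startup: normal mode gets 0.8 (raised to 0.85 later in this series), while the PD-disaggregation modes stay at 0.0, that is, purely conservative scheduling. A hedged sketch of the resolution plus the busy gate it feeds; the is_busy predicate below is an assumption based on the documented behavior ("once kv cache usage exceeds this value ... conservative scheduling"), since the series does not show that code:

    def resolve_router_token_ratio(user_value, run_mode: str) -> float:
        # Mirrors the startup resolution added here (0.85 after PATCH 07):
        # an explicit user setting wins, otherwise pick by run mode.
        if user_value is not None:
            return user_value
        return 0.85 if run_mode == "normal" else 0.0

    def is_busy(used_tokens: int, max_total_tokens: int, router_token_ratio: float) -> bool:
        # Assumption: busy means the KV-cache usage ratio is at or above the
        # threshold; the patch series only shows the threshold, not this check.
        return used_tokens / max_total_tokens >= router_token_ratio

A threshold of 0.0 makes every step "busy", so requests are always costed at their full max_new_tokens, which is the conservative behavior the PD modes require.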
From d60e7f741694a022ff2b86d3e8fbf53f9907a15c Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 17 Mar 2026 06:20:27 +0000
Subject: [PATCH 05/13] split pd split max new tokens.

---
 .../httpserver_for_pd_master/manager.py       | 85 ++++++++++++++-----
 lightllm/utils/envs_utils.py                  |  5 ++
 2 files changed, 68 insertions(+), 22 deletions(-)

diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index fda1387f3b..585f4ddcc1 100644
--- a/lightllm/server/httpserver_for_pd_master/manager.py
+++ b/lightllm/server/httpserver_for_pd_master/manager.py
@@ -20,6 +20,7 @@
 from lightllm.utils.statics_utils import MovingAverage
 from lightllm.server.httpserver.manager import AsyncQueue
 from lightllm.utils.error_utils import ServerBusyError
+from lightllm.utils.envs_utils import get_pd_split_max_new_tokens
 from .pd_selector import create_selector
 
 logger = init_logger(__name__)
 
@@ -92,44 +93,76 @@ async def generate(
         multimodal_params: MultimodalParams,
         request: Request,
     ):
+        assert isinstance(prompt, str), "prompt must be str"
         start_time = time.time()
-        group_request_id = self.id_gen.generate_id()
+        # First split the request into blocks by max_new_tokens. This mainly targets the pd disaggregation
+        # scenario, where only conservative scheduling can be used: if every user sets a very large
+        # max_new_tokens, a huge amount of KV memory is reserved and system throughput drops. We therefore
+        # split a request into several segments for inference; as long as the blocks are sized reasonably,
+        # segmented inference rarely actually happens and throughput is not affected.
+        origin_sampling_params = SamplingParams.from_buffer_copy(sampling_params)
+        origin_group_request_id = self.id_gen.generate_id()
+        max_new_tokens_list = self._split_max_new_tokens(max_new_tokens=origin_sampling_params.max_new_tokens)
+
         try:
-            sampling_params.group_request_id = group_request_id
             # Log information about the request arrival
-            await self._log_req_header(request, group_request_id)
+            await self._log_req_header(request, origin_group_request_id)
 
             # Monitoring
             self.metric_client.counter_inc("lightllm_request_count")
-            self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)
+            self.metric_client.histogram_observe(
+                "lightllm_request_max_new_tokens", origin_sampling_params.max_new_tokens
+            )
+
+            p_node, d_node = await self.select_p_d_node(prompt, origin_sampling_params, multimodal_params)
 
-            p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)
+            history_gen_token_strs = []
 
             if not p_node or not d_node:
-                logger.error(f"{group_request_id}: No p_node or d_node found")
-                raise Exception(f"{group_request_id}: No p_node or d_node found")
-
-            results_generator = self._wait_to_token_package(
-                p_node,
-                d_node,
-                start_time,
-                prompt,
-                sampling_params,
-                multimodal_params,
-                request,
-            )
-            async for sub_req_id, request_output, metadata, finish_status in results_generator:
-                yield sub_req_id, request_output, metadata, finish_status
+                logger.error(f"{origin_group_request_id}: No p_node or d_node found")
+                raise Exception(f"{origin_group_request_id}: No p_node or d_node found")
+
+            for iter_index, block_max_new_tokens in enumerate(max_new_tokens_list):
+                sampling_params = SamplingParams.from_buffer_copy(origin_sampling_params)
+                block_group_request_id = self.id_gen.generate_id()
+                sampling_params.group_request_id = block_group_request_id
+                logger.info(
+                    f"pd log gen sub req id {block_group_request_id}" f" for main req id {origin_group_request_id}"
+                )
+                sampling_params.max_new_tokens = block_max_new_tokens
+
+                results_generator = self._wait_to_token_package(
+                    p_node,
+                    d_node,
+                    start_time,
+                    prompt + "".join(history_gen_token_strs),
+                    sampling_params,
+                    multimodal_params,
+                    request,
+                )
+                is_last_block = iter_index == len(max_new_tokens_list) - 1
+                prompt_tokens = sys.maxsize  # because the request is segmented
+                async for sub_req_id, request_output, metadata, finish_status in results_generator:
+                    # In pd disaggregation mode, the sequence numbering in the returned metadata may be inaccurate.
+                    assert sub_req_id == block_group_request_id
+                    if finish_status.get_finish_reason() == "length" and (not is_last_block):
+                        finish_status = FinishStatus()  # convert to not-finished
+                    history_gen_token_strs.append(request_output)
+                    prompt_tokens = min(prompt_tokens, metadata["prompt_tokens"])
+                    metadata["prompt_tokens"] = prompt_tokens
+                    yield origin_group_request_id, request_output, metadata, finish_status
+
+                await self.remove_req(group_request_id=block_group_request_id)
 
         except BaseException as e:
             logger.error(f"has exception {str(e)}")
             try:
-                await self.abort(group_request_id, p_node=p_node, d_node=d_node)
+                await self.abort(block_group_request_id, p_node=p_node, d_node=d_node)
             except:
-                await self.abort(group_request_id)
+                await self.abort(block_group_request_id)
             raise e
         finally:
-            await self.remove_req(group_request_id)
+            await self.remove_req(block_group_request_id)
         return
 
     async def _log_req_header(self, request: Request, group_request_id: int):
@@ -465,6 +498,14 @@ async def handle_loop(self):
                 logger.exception(str(e))
         return
 
+    def _split_max_new_tokens(self, max_new_tokens: int) -> List[int]:
+        block_max_new_tokens = get_pd_split_max_new_tokens()
+        ans_list = [block_max_new_tokens for _ in range(max_new_tokens // block_max_new_tokens)]
+        left_token = max_new_tokens - (max_new_tokens // block_max_new_tokens) * block_max_new_tokens
+        if left_token > 0:
+            ans_list.append(left_token)
+        return ans_list
+
 
 class ReqStatus:
     def __init__(self, req_id, p_node, d_node) -> None:

diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index 7a7a9be121..7d8b7e527a 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -232,3 +232,8 @@ def get_added_mtp_kv_layer_num() -> int:
         added_mtp_layer_num += get_env_start_args().mtp_step
 
     return added_mtp_layer_num
+
+
+@lru_cache(maxsize=None)
+def get_pd_split_max_new_tokens() -> int:
+    return int(os.getenv("LIGHTLLM_PD_SPLIT_MAX_NEW_TOKENS", 2048))
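_split_max_new_tokens() cuts a request's output budget into fixed-size blocks (LIGHTLLM_PD_SPLIT_MAX_NEW_TOKENS, default 2048) plus a remainder, so the conservative PD scheduler only ever reserves one block's worth of KV cache at a time; later segments re-prefill on the prompt plus the text generated so far. The same arithmetic as the new helper, with the block size passed in instead of read from the environment:

    def split_max_new_tokens(max_new_tokens: int, block: int = 2048) -> list:
        # Sketch of _split_max_new_tokens: full blocks plus an optional remainder.
        ans_list = [block for _ in range(max_new_tokens // block)]
        left_token = max_new_tokens - (max_new_tokens // block) * block
        if left_token > 0:
            ans_list.append(left_token)
        return ans_list

    print(split_max_new_tokens(5000))  # [2048, 2048, 904]
    print(split_max_new_tokens(1500))  # [1500] -> single segment, no re-prefill
    print(split_max_new_tokens(4096))  # [2048, 2048]

Most requests finish well inside their first block (the EMA from PATCH 01 suggests typical outputs are far below 2048), so the extra segments are reserved capacity that is rarely paid for.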
From 2c191ff9b1e002b59d435ef14d065ce3d3521f3c Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 17 Mar 2026 08:11:15 +0000
Subject: [PATCH 06/13] fix

---
 lightllm/server/router/model_infer/infer_batch.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py
index 1b4a1ca5cb..0a83b101be 100644
--- a/lightllm/server/router/model_infer/infer_batch.py
+++ b/lightllm/server/router/model_infer/infer_batch.py
@@ -204,6 +204,7 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool):
             req.paused = True
             if is_master_in_dp:
                 req.shm_req.is_paused = True
+            logger.debug(f"infer paused req id {req.req_id}")
 
         if len(free_token_index) != 0:
             free_token_index = custom_cat(free_token_index)
@@ -225,6 +226,7 @@ def recover_paused_reqs(self, paused_reqs: List["InferReq"], is_master_in_dp: bo
                 req.paused = False
                 if is_master_in_dp:
                     req.shm_req.is_paused = False
+                logger.debug(f"infer recover paused req id {req.req_id}")
                 can_alloc_token_num -= prefill_need_token_num
 
         g_infer_state_lock.release()

From 456e189c48a5d72cde180287be9d0af8fcae3870 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 17 Mar 2026 08:17:53 +0000
Subject: [PATCH 07/13] fix

---
 lightllm/server/api_start.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index f8bec5779c..c0b9905b16 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -108,7 +108,7 @@ def normal_or_p_d_start(args):
     # Automatic setup of scheduling parameters; explicit user settings take precedence
     if args.router_token_ratio is None:
         if args.run_mode in ["normal"]:
-            args.router_token_ratio = 0.8
+            args.router_token_ratio = 0.85
         else:
             # in pd disaggregation mode, advanced scheduling is not enabled
             args.router_token_ratio = 0.0

From 8954561fda41f896ba7e1465223791ce00d4d4d8 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 17 Mar 2026 09:08:50 +0000
Subject: [PATCH 08/13] fix

---
 lightllm/server/api_start.py | 9 +++++++--
 lightllm/utils/net_utils.py  | 4 ++--
 2 files changed, 9 insertions(+), 4 deletions(-)

diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index c0b9905b16..ef907e075a 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -242,7 +242,7 @@
     node_world_size = args.tp // args.nnodes
     can_use_ports = alloc_can_use_network_port(
-        num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports
+        num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
     )
     logger.info(f"alloced ports: {can_use_ports}")
     (
@@ -416,7 +416,12 @@ def pd_master_start(args):
     logger.info(f"use tgi api: {args.use_tgi_api}")
     logger.info(f"all start args:{args}")
 
-    can_use_ports = alloc_can_use_network_port(num=1, used_nccl_ports=[args.nccl_port, args.port])
+    can_use_ports = alloc_can_use_network_port(
+        num=1,
+        used_ports=[
+            args.port,
+        ],
+    )
     metric_port = can_use_ports[0]
     args.metric_port = metric_port

diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py
index 20b9888753..b87096d945 100644
--- a/lightllm/utils/net_utils.py
+++ b/lightllm/utils/net_utils.py
@@ -7,12 +7,12 @@ logger = init_logger(__name__)
 
-def alloc_can_use_network_port(num=3, used_nccl_ports=None, from_port_num=10000):
+def alloc_can_use_network_port(num=3, used_ports=None, from_port_num=10000):
     port_list = []
     for port in range(from_port_num, 65536):
         with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
             result = s.connect_ex(("localhost", port))
-            if result != 0 and port not in used_nccl_ports:
+            if result != 0 and port not in used_ports:
                 port_list.append(port)
         if len(port_list) > num * 30:
             break
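alloc_can_use_network_port() probes ports upward from from_port_num with connect_ex (a nonzero result means nothing accepted the connection, so the port is presumed free), skips anything in used_ports, and collects a generous surplus (num * 30 candidates) before stopping; the tail of the function lies outside the hunk. Note the None default is never exercised: both call sites shown pass a list, and `port not in used_ports` would raise a TypeError on None. A usage sketch against the renamed signature (the port values and the consumers are hypothetical):

    from lightllm.utils.net_utils import alloc_can_use_network_port

    # Reserve a few free ports while excluding ones the server already owns.
    reserved = [8080, 9000]  # hypothetical ports already taken
    ports = alloc_can_use_network_port(num=3, used_ports=reserved)
    metric_port, cache_port, rpc_port = ports[:3]  # hypothetical consumers

Renaming used_nccl_ports to used_ports matters because the pd_master call site no longer passes args.nccl_port at all; the parameter is now a generic exclusion list rather than something tied to NCCL.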
From 2bbc0134854a6c4e4a593af09e4cf425ca1456af Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 17 Mar 2026 09:22:07 +0000
Subject: [PATCH 09/13] fix set unique_server_name

---
 lightllm/utils/envs_utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index 7d8b7e527a..caf1d7bd9a 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -1,6 +1,7 @@
 import os
 import json
 import torch
+import uuid
 from easydict import EasyDict
 from functools import lru_cache
 from lightllm.utils.log_utils import init_logger
@@ -13,7 +14,8 @@ def set_unique_server_name(args):
     if args.run_mode == "pd_master":
         os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master"
     else:
-        os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.nccl_port) + "_" + str(args.node_rank)
+        node_uuid = uuid.uuid1().hex[0:8]
+        os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(node_uuid) + "_" + str(args.node_rank)
     return

From 63fcf1f91965a984f93b919a07d22132a6d629f2 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 17 Mar 2026 09:23:40 +0000
Subject: [PATCH 10/13] fix

---
 lightllm/utils/envs_utils.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py
index caf1d7bd9a..41089e7612 100644
--- a/lightllm/utils/envs_utils.py
+++ b/lightllm/utils/envs_utils.py
@@ -11,10 +11,11 @@
 
 def set_unique_server_name(args):
+    node_uuid = uuid.uuid1().hex[0:8]
+
     if args.run_mode == "pd_master":
-        os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(args.port) + "_pd_master"
+        os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(node_uuid) + "_pd_master"
     else:
-        node_uuid = uuid.uuid1().hex[0:8]
         os.environ["LIGHTLLM_UNIQUE_SERVICE_NAME_ID"] = str(node_uuid) + "_" + str(args.node_rank)
     return

From a52831e6de68be92a80694b4656b9e4a475bc8d5 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Tue, 17 Mar 2026 10:27:53 +0000
Subject: [PATCH 11/13] fix

---
 lightllm/server/api_start.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py
index ef907e075a..77355f0d06 100644
--- a/lightllm/server/api_start.py
+++ b/lightllm/server/api_start.py
@@ -447,7 +447,6 @@ def pd_master_start(args):
         "-",
         "--error-logfile",
         "-",
-        "--preload",
         "lightllm.server.api_http:app",
         "--keep-alive",
         f"{get_lightllm_gunicorn_keep_alive()}",
@@ -485,7 +484,6 @@ def config_server_start(args):
         "-",
         "--error-logfile",
         "-",
-        "--preload",
         "lightllm.server.config_server.api_http:app",
         "--keep-alive",
         f"{get_lightllm_gunicorn_keep_alive()}",

From 962930c3fbe990ac3f692ff80d750ae278921d52 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Wed, 18 Mar 2026 05:47:18 +0000
Subject: [PATCH 12/13] fix

---
 lightllm/server/api_cli.py                   |  7 ++++++-
 lightllm/server/httpserver/manager.py        | 19 +++++++++++++----
 .../httpserver_for_pd_master/manager.py      | 21 +++++++++++++++++++
 3 files changed, 42 insertions(+), 5 deletions(-)

diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py
index 67c6bd982b..c8a82d3239 100644
--- a/lightllm/server/api_cli.py
+++ b/lightllm/server/api_cli.py
@@ -200,7 +200,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
         help="the dp balancer type, default is bs_balancer",
     )
     parser.add_argument(
-        "--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len"
+        "--max_req_total_len",
+        type=int,
+        default=16384,
+        help="Maximum allowed length for a request (input tokens + output tokens). "
+        "In PD (Prefill-Decode) mode, this value must be synchronized across the "
+        "PD master, prefill, and decode nodes.",
     )
     parser.add_argument(
         "--nccl_host",

diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 6481098eb9..e28e4c93ad 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -480,10 +480,21 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
         if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
             # use long_truncation_mode to truncate long input len req.
             if self.args.long_truncation_mode is None:
-                raise ValueError(
-                    f"the input prompt token len {prompt_tokens} + max_new_tokens \
-                    {sampling_params.max_new_tokens} > {self.max_req_total_len}"
-                )
+                # Changed default behavior: if prompt_tokens + max_new_tokens exceeds the total
+                # allowed length, shrink max_new_tokens so the request satisfies the constraint.
+                new_max_new_tokens = self.max_req_total_len - prompt_tokens
+                if new_max_new_tokens > 0:
+                    logger.debug(
+                        f"the input prompt token len {prompt_tokens} + max_new_tokens "
+                        f"{sampling_params.max_new_tokens} > {self.max_req_total_len}, "
+                        f"so change max_new_tokens to {new_max_new_tokens}"
+                    )
+                    sampling_params.max_new_tokens = new_max_new_tokens
+                else:
+                    raise ValueError(
+                        f"the input prompt token len {prompt_tokens} + max_new_tokens \
+                        {sampling_params.max_new_tokens} > {self.max_req_total_len}"
+                    )
             elif self.args.long_truncation_mode == "head":
                 prompt_ids = prompt_ids[-(self.max_req_total_len - sampling_params.max_new_tokens) :]
             elif self.args.long_truncation_mode == "center":

diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py
index 585f4ddcc1..8d9443e933 100644
--- a/lightllm/server/httpserver_for_pd_master/manager.py
+++ b/lightllm/server/httpserver_for_pd_master/manager.py
@@ -32,6 +32,8 @@ def __init__(
         args: StartArgs,
     ):
         self.args = args
+        self.max_req_total_len = args.max_req_total_len
+        assert self.max_req_total_len is not None
         self.metric_client = MetricClient(args.metric_port)
         self.id_gen = ReqIDGenerator()
 
@@ -95,6 +97,16 @@ async def generate(
     ):
         assert isinstance(prompt, str), "prompt must be str"
         start_time = time.time()
+        # Compute the input token count and validate it; if the configured input + output
+        # length is too large, repair the sampling_params accordingly.
+        input_token_num = self.tokens(prompt, multimodal_params, sampling_params)
+        fake_prompt_ids = [0 for _ in range(input_token_num)]
+        from lightllm.server.httpserver.manager import HttpServerManager
+
+        await HttpServerManager._check_and_repair_length(
+            self, prompt_ids=fake_prompt_ids, sampling_params=sampling_params
+        )
+
         # First split the request into blocks by max_new_tokens. This mainly targets the pd disaggregation
         # scenario, where only conservative scheduling can be used: if every user sets a very large
         # max_new_tokens, a huge amount of KV memory is reserved and system throughput drops. We therefore
@@ -551,6 +563,15 @@ def __init__(self, args: StartArgs):
 
     def register_pd(self, pd_info_json, websocket):
         pd_client = PD_Client_Obj(**pd_info_json)
+        client_max_req_total_len = pd_client.start_args["max_req_total_len"]
+        if client_max_req_total_len != self.args.max_req_total_len:
+            logger.error(
+                f"client does not have the same max_req_total_len param: pd master is {self.args.max_req_total_len}, "
+                f"client is {client_max_req_total_len}, "
+                f"client info {pd_info_json}"
+            )
+            assert False
+
         pd_client.websocket = websocket
         self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client
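The repaired length check now prefers shrinking max_new_tokens over rejecting the request, and the PD master runs the same routine against a fake prompt of the right length before splitting. With the default --max_req_total_len of 16384, a 15000-token prompt requesting 4096 new tokens is clamped to 1384, while a prompt that already fills the budget still raises. A sketch of just the clamping arithmetic from the new default branch:

    def repair_max_new_tokens(prompt_tokens: int, max_new_tokens: int,
                              max_req_total_len: int = 16384) -> int:
        # Sketch of the default branch of _check_and_repair_length
        # (long_truncation_mode is None); other truncation modes are untouched.
        if prompt_tokens + max_new_tokens <= max_req_total_len:
            return max_new_tokens
        new_max_new_tokens = max_req_total_len - prompt_tokens
        if new_max_new_tokens > 0:
            return new_max_new_tokens  # shrink instead of rejecting
        raise ValueError(
            f"the input prompt token len {prompt_tokens} + max_new_tokens "
            f"{max_new_tokens} > {max_req_total_len}"
        )

    assert repair_max_new_tokens(15000, 4096) == 1384
    assert repair_max_new_tokens(1000, 4096) == 4096

This also explains why max_req_total_len must now be synchronized across PD master, prefill, and decode nodes: the master clamps against its own limit, so a smaller limit on a worker would otherwise reject requests the master already accepted.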
{self.args.max_req_total_len}" + f"client is {client_max_req_total_len}" + f"client info {pd_info_json}" + ) + assert False + pd_client.websocket = websocket self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client From 5976fd10197698450d29c6c99b82456b71d5dd40 Mon Sep 17 00:00:00 2001 From: wangzaijun Date: Wed, 18 Mar 2026 06:25:54 +0000 Subject: [PATCH 13/13] fix --- lightllm/server/httpserver_for_pd_master/manager.py | 4 +--- lightllm/server/router/stats.py | 8 ++++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/lightllm/server/httpserver_for_pd_master/manager.py b/lightllm/server/httpserver_for_pd_master/manager.py index 8d9443e933..d6a1a58b05 100644 --- a/lightllm/server/httpserver_for_pd_master/manager.py +++ b/lightllm/server/httpserver_for_pd_master/manager.py @@ -137,9 +137,7 @@ async def generate( sampling_params = SamplingParams.from_buffer_copy(origin_sampling_params) block_group_request_id = self.id_gen.generate_id() sampling_params.group_request_id = block_group_request_id - logger.info( - f"pd log gen sub req id {block_group_request_id}" f" for main req id {origin_group_request_id}" - ) + logger.info(f"pd log gen sub req id {block_group_request_id} for main req id {origin_group_request_id}") sampling_params.max_new_tokens = block_max_new_tokens results_generator = self._wait_to_token_package( diff --git a/lightllm/server/router/stats.py b/lightllm/server/router/stats.py index 9f4ef8bee7..b715c5bcb3 100644 --- a/lightllm/server/router/stats.py +++ b/lightllm/server/router/stats.py @@ -8,13 +8,17 @@ class RouterStatics: def __init__(self, args: StartArgs): self.busy_token_used_ratio = args.router_token_ratio self.ema_req_out_len = 2048 - self.ema_params = 0.04 + self.cur_ema_params = 0.5 + self.min_ema_params = 0.04 def update(self, req_out_len: int): # 过滤掉输出特别短的情况,防止计算得过于短,导致调度频繁引发暂停,导致系统吞吐下降。 req_out_len = max(req_out_len, 64) - self.ema_req_out_len = int(self.ema_req_out_len * (1 - self.ema_params) + req_out_len * self.ema_params) + self.ema_req_out_len = int(self.ema_req_out_len * (1 - self.cur_ema_params) + req_out_len * self.cur_ema_params) self.ema_req_out_len = max(64, self.ema_req_out_len) + # 不断的调整ema 的计算参数,这样可以在早期,快速将 ema_req_out_len 调整到接近 + # 当前分布的水平,然后后期趋于稳定调整。 + self.cur_ema_params = max(self.min_ema_params, self.cur_ema_params * 0.8) def log_str(self) -> str: return (