Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions docs/CN/source/tutorial/api_server_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,6 @@ PD 分离模式参数

判断服务是否繁忙的阈值,默认为 ``0.0``,一旦kv cache 使用率超过此值,则会直接变为保守调度。

.. option:: --router_max_new_token_len

调度器评估请求kv占用时,使用的请求输出长度,默认为 ``1024``,一般低于用户设置的max_new_tokens。该参数只在 --router_token_ratio 大于0时生效。
设置该参数,会使请求调度更为激进,系统同时处理的请求数会更多,同时也会不可避免的造成请求的暂停重计算。

.. option:: --router_max_wait_tokens

每 router_max_wait_tokens 解码步骤后触发一次调度新请求,默认为 ``6``
Expand Down
5 changes: 0 additions & 5 deletions docs/EN/source/tutorial/api_server_args.rst
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,6 @@ Scheduling Parameters

Threshold for determining if the service is busy, default is ``0.0``. Once the kv cache usage exceeds this value, it will directly switch to conservative scheduling.

.. option:: --router_max_new_token_len

The request output length used by the scheduler when evaluating request kv usage, default is ``1024``, generally lower than the max_new_tokens set by the user. This parameter only takes effect when --router_token_ratio is greater than 0.
Setting this parameter will make request scheduling more aggressive, allowing the system to process more requests simultaneously, but will inevitably cause request pause and recalculation.

.. option:: --router_max_wait_tokens

Trigger scheduling of new requests every router_max_wait_tokens decoding steps, default is ``6``
Expand Down
20 changes: 16 additions & 4 deletions lightllm/server/api_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,12 @@ def make_argument_parser() -> argparse.ArgumentParser:
help="the dp balancer type, default is bs_balancer",
)
parser.add_argument(
"--max_req_total_len", type=int, default=16384, help="the max value for req_input_len + req_output_len"
"--max_req_total_len",
type=int,
default=16384,
help="Maximum allowed length for a request (input tokens + output tokens). "
"In PD (Prefill-Decode) mode, this value must be synchronized across the "
"PD master, prefill, and decode nodes.",
)
parser.add_argument(
"--nccl_host",
Expand Down Expand Up @@ -237,11 +242,18 @@ def make_argument_parser() -> argparse.ArgumentParser:
This setting allows you to turn off these warning checks.""",
)

parser.add_argument("--router_token_ratio", type=float, default=0.0, help="token ratio to control router dispatch")
parser.add_argument(
"--router_max_new_token_len", type=int, default=1024, help="the request max new token len for router"
"--router_token_ratio",
type=float,
default=None,
help="""Token used ratio to control router dispatch, range in [0.0, 1.0].
When the token VRAM usage ratio is higher than this value,
the dispatch strategy tends to be conservative.
When the token VRAM usage ratio is lower than this value,
the dispatching of requests tends to be aggressive.
The default value is None, meaning it will be automatically
determined by the internal system based on other startup parameters.""",
)

parser.add_argument(
"--router_max_wait_tokens",
type=int,
Expand Down
26 changes: 18 additions & 8 deletions lightllm/server/api_start.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,17 @@ def normal_or_p_d_start(args):
if args.enable_multimodal:
args.multi_modal_cache_shm_id = uuid.uuid1().int % 123456789

# 调度参数的自动设置, 人工设置则听人工的
if args.router_token_ratio is None:
if args.run_mode in ["normal"]:
args.router_token_ratio = 0.85
else:
# pd 分离模式下,不开启高级调度
args.router_token_ratio = 0.0
# 部分模式还不能支持与高级动态调度算法协同,to do.
if args.diverse_mode:
assert args.router_token_ratio == 0.0

if not args.disable_shm_warning:
check_recommended_shm_size(args)

Expand Down Expand Up @@ -146,10 +157,6 @@ def normal_or_p_d_start(args):
assert args.disable_dynamic_prompt_cache is True, "need add --disable_dynamic_prompt_cache"
assert args.disable_chunked_prefill is True, "need add --disable_chunked_prefill"

# 部分模式还不能支持与高级动态调度算法协同,to do.
if args.diverse_mode:
assert args.router_token_ratio == 0.0

if args.enable_dp_prefill_balance:
assert args.enable_tpsp_mix_mode and args.dp > 1, "need set --enable_tpsp_mix_mode firstly and --dp > 1"

Expand Down Expand Up @@ -235,7 +242,7 @@ def normal_or_p_d_start(args):

node_world_size = args.tp // args.nnodes
can_use_ports = alloc_can_use_network_port(
num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_nccl_ports=already_uesd_ports
num=10 + node_world_size + args.visual_dp * (args.visual_tp + 1), used_ports=already_uesd_ports
)
logger.info(f"alloced ports: {can_use_ports}")
(
Expand Down Expand Up @@ -409,7 +416,12 @@ def pd_master_start(args):
logger.info(f"use tgi api: {args.use_tgi_api}")
logger.info(f"all start args:{args}")

can_use_ports = alloc_can_use_network_port(num=1, used_nccl_ports=[args.nccl_port, args.port])
can_use_ports = alloc_can_use_network_port(
num=1,
used_ports=[
args.port,
],
)
metric_port = can_use_ports[0]

args.metric_port = metric_port
Expand All @@ -435,7 +447,6 @@ def pd_master_start(args):
"-",
"--error-logfile",
"-",
"--preload",
"lightllm.server.api_http:app",
"--keep-alive",
f"{get_lightllm_gunicorn_keep_alive()}",
Expand Down Expand Up @@ -473,7 +484,6 @@ def config_server_start(args):
"-",
"--error-logfile",
"-",
"--preload",
"lightllm.server.config_server.api_http:app",
"--keep-alive",
f"{get_lightllm_gunicorn_keep_alive()}",
Expand Down
8 changes: 3 additions & 5 deletions lightllm/server/core/objs/req.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def can_release(self):
def get_used_tokens(self):
return max(0, self.shm_cur_kv_len)

def get_tuple_tokens(self, is_busy, router_max_new_token_len):
def get_tuple_tokens(self, is_busy, ema_req_out_len):
raise NotImplementedError("Subclasses should implement this method")

def get_decode_need_tokens(self):
Expand Down Expand Up @@ -311,7 +311,7 @@ def print_time_log(self, log_info: str):
class ChunkedPrefillReq(Req):
_pack_ = 4

def get_tuple_tokens(self, is_busy, router_max_new_token_len):
def get_tuple_tokens(self, is_busy, ema_req_out_len):
args = get_env_start_args()
# chuncked prefill 推理的过程中,存在很多模式的延迟 step 推理的控制, 用于
# 保证更好的包间数据或者是提升 dp 模式下prefill 的效率,但是在估计 token 显存
Expand All @@ -327,9 +327,7 @@ def get_tuple_tokens(self, is_busy, router_max_new_token_len):
elif is_busy:
cur_max_new_token_len = self.sample_params.max_new_tokens
else:
cur_max_new_token_len = min(
self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), router_max_new_token_len)
)
cur_max_new_token_len = min(self.sample_params.max_new_tokens, max(int(1.1 * has_out_len), ema_req_out_len))

a_len = max(self.input_len + has_out_len + 1, self.shm_cur_kv_len + 1)
b_len = (
Expand Down
1 change: 0 additions & 1 deletion lightllm/server/core/objs/start_args_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ class StartArgs:
disable_log_stats: bool = field(default=False)
log_stats_interval: int = field(default=10)
router_token_ratio: float = field(default=0.0)
router_max_new_token_len: int = field(default=1024)
router_max_wait_tokens: int = field(default=1)
disable_aggressive_schedule: bool = field(default=False)
disable_dynamic_prompt_cache: bool = field(default=False)
Expand Down
19 changes: 15 additions & 4 deletions lightllm/server/httpserver/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,10 +480,21 @@ async def _check_and_repair_length(self, prompt_ids: List[int], sampling_params:
if prompt_tokens + sampling_params.max_new_tokens > self.max_req_total_len:
# use long_truncation_mode to truncate long input len req.
if self.args.long_truncation_mode is None:
raise ValueError(
f"the input prompt token len {prompt_tokens} + max_new_tokens \
{sampling_params.max_new_tokens} > {self.max_req_total_len}"
)
# 修改默认逻辑,如果 prompt_tokens + max_new_tokens 长度超过总的允许长度,则将
# 修改 max_new_tokens 的值,使其满足合法约束。
new_max_new_tokens = self.max_req_total_len - prompt_tokens
if new_max_new_tokens > 0:
logger.debug(
f"the input prompt token len {prompt_tokens} + max_new_tokens"
f"{sampling_params.max_new_tokens} > {self.max_req_total_len},"
f"so change max_new_tokens to {new_max_new_tokens}"
)
sampling_params.max_new_tokens = new_max_new_tokens
else:
raise ValueError(
f"the input prompt token len {prompt_tokens} + max_new_tokens \
{sampling_params.max_new_tokens} > {self.max_req_total_len}"
)
elif self.args.long_truncation_mode == "head":
prompt_ids = prompt_ids[-(self.max_req_total_len - sampling_params.max_new_tokens) :]
elif self.args.long_truncation_mode == "center":
Expand Down
104 changes: 82 additions & 22 deletions lightllm/server/httpserver_for_pd_master/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from lightllm.utils.statics_utils import MovingAverage
from lightllm.server.httpserver.manager import AsyncQueue
from lightllm.utils.error_utils import ServerBusyError
from lightllm.utils.envs_utils import get_pd_split_max_new_tokens
from .pd_selector import create_selector

logger = init_logger(__name__)
Expand All @@ -31,6 +32,8 @@ def __init__(
args: StartArgs,
):
self.args = args
self.max_req_total_len = args.max_req_total_len
assert self.max_req_total_len is not None
self.metric_client = MetricClient(args.metric_port)
self.id_gen = ReqIDGenerator()

Expand Down Expand Up @@ -92,44 +95,84 @@ async def generate(
multimodal_params: MultimodalParams,
request: Request,
):
assert isinstance(prompt, str), "prompt must be str"
start_time = time.time()
group_request_id = self.id_gen.generate_id()
# 计算输入的 input_token_num, 进行校验,如果输入+输出参数设置太长,则将
# sampling_params 的参数进行修正。
input_token_num = self.tokens(prompt, multimodal_params, sampling_params)
fake_prompt_ids = [0 for _ in range(input_token_num)]
from lightllm.server.httpserver.manager import HttpServerManager

await HttpServerManager._check_and_repair_length(
self, prompt_ids=fake_prompt_ids, sampling_params=sampling_params
)

# 先将请求根据max_new_tokens 参数进行分块操作,主要是 pd 分离场景中,
# 只能使用保守调度,但是如果用户都设置一个很大的 max_new_tokens 值,会
# 导致极大显存预留,照成系统的吞吐能力下降,所以我们将请求分割成几段进行
# 推理,只要保证分块合理,实际分段推理是极少发生的情况,系统吞吐就不会受
# 到影响。
origin_sampling_params = SamplingParams.from_buffer_copy(sampling_params)
origin_group_request_id = self.id_gen.generate_id()
max_new_tokens_list = self._split_max_new_tokens(max_new_tokens=origin_sampling_params.max_new_tokens)

try:
sampling_params.group_request_id = group_request_id
# 记录请求到达的相关信息
await self._log_req_header(request, group_request_id)
await self._log_req_header(request, origin_group_request_id)
# 监控
self.metric_client.counter_inc("lightllm_request_count")
self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)
self.metric_client.histogram_observe(
"lightllm_request_max_new_tokens", origin_sampling_params.max_new_tokens
)

p_node, d_node = await self.select_p_d_node(prompt, origin_sampling_params, multimodal_params)

p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)
history_gen_token_strs = []

if not p_node or not d_node:
logger.error(f"{group_request_id}: No p_node or d_node found")
raise Exception(f"{group_request_id}: No p_node or d_node found")

results_generator = self._wait_to_token_package(
p_node,
d_node,
start_time,
prompt,
sampling_params,
multimodal_params,
request,
)
async for sub_req_id, request_output, metadata, finish_status in results_generator:
yield sub_req_id, request_output, metadata, finish_status
logger.error(f"{origin_group_request_id}: No p_node or d_node found")
raise Exception(f"{origin_group_request_id}: No p_node or d_node found")

for iter_index, block_max_new_tokens in enumerate(max_new_tokens_list):
sampling_params = SamplingParams.from_buffer_copy(origin_sampling_params)
block_group_request_id = self.id_gen.generate_id()
sampling_params.group_request_id = block_group_request_id
logger.info(f"pd log gen sub req id {block_group_request_id} for main req id {origin_group_request_id}")
sampling_params.max_new_tokens = block_max_new_tokens

results_generator = self._wait_to_token_package(
p_node,
d_node,
start_time,
prompt + "".join(history_gen_token_strs),
sampling_params,
multimodal_params,
request,
)
is_last_block = iter_index == len(max_new_tokens_list) - 1
prompt_tokens = sys.maxsize # 因为分段的原因
async for sub_req_id, request_output, metadata, finish_status in results_generator:
# pd 分离模式下,返回的 metadata 中的序号信息可能存在不准确性。
assert sub_req_id == block_group_request_id
if finish_status.get_finish_reason() == "length" and (not is_last_block):
finish_status = FinishStatus() # 转换为NoFinished
history_gen_token_strs.append(request_output)
prompt_tokens = min(prompt_tokens, metadata["prompt_tokens"])
metadata["prompt_tokens"] = prompt_tokens
yield origin_group_request_id, request_output, metadata, finish_status

await self.remove_req(group_request_id=block_group_request_id)

except BaseException as e:
logger.error(f"has exception {str(e)}")
try:
await self.abort(group_request_id, p_node=p_node, d_node=d_node)
await self.abort(block_group_request_id, p_node=p_node, d_node=d_node)
except:
await self.abort(group_request_id)
await self.abort(block_group_request_id)
raise e

finally:
await self.remove_req(group_request_id)
await self.remove_req(block_group_request_id)
return
Comment on lines 119 to 176
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The exception and finalization logic in the generate method has a couple of critical issues that could lead to runtime errors:

  1. Potential NameError: block_group_request_id is defined inside the for loop. If an exception occurs before the loop starts (e.g., in select_p_d_node), the except and finally blocks will raise a NameError when trying to access block_group_request_id.
  2. Double remove_req call: On successful completion of all chunks, remove_req is called for the last chunk at the end of the for loop (line 154), and then called again in the finally block. This will likely cause an error (e.g., KeyError) as remove_req is probably not idempotent.

To fix this, I suggest initializing block_group_request_id = None before the try block and ensuring cleanup logic handles all edge cases correctly without causing errors. For instance, the finally block could be removed and cleanup handled within the except block for failures, while the success path is already handled inside the loop.


async def _log_req_header(self, request: Request, group_request_id: int):
Expand Down Expand Up @@ -465,6 +508,14 @@ async def handle_loop(self):
logger.exception(str(e))
return

def _split_max_new_tokens(self, max_new_tokens: int) -> List[int]:
block_max_new_tokens = get_pd_split_max_new_tokens()
ans_list = [block_max_new_tokens for _ in range(max_new_tokens // block_max_new_tokens)]
left_token = max_new_tokens - (max_new_tokens // block_max_new_tokens) * block_max_new_tokens
if left_token > 0:
ans_list.append(left_token)
Comment on lines +513 to +516
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This implementation can be made more concise and Pythonic by using the modulo operator (%) for the remainder and list multiplication for creating the list of blocks.

Suggested change
ans_list = [block_max_new_tokens for _ in range(max_new_tokens // block_max_new_tokens)]
left_token = max_new_tokens - (max_new_tokens // block_max_new_tokens) * block_max_new_tokens
if left_token > 0:
ans_list.append(left_token)
ans_list = [block_max_new_tokens] * (max_new_tokens // block_max_new_tokens)
remainder = max_new_tokens % block_max_new_tokens
if remainder > 0:
ans_list.append(remainder)

return ans_list


class ReqStatus:
def __init__(self, req_id, p_node, d_node) -> None:
Expand Down Expand Up @@ -510,6 +561,15 @@ def __init__(self, args: StartArgs):

def register_pd(self, pd_info_json, websocket):
pd_client = PD_Client_Obj(**pd_info_json)
client_max_req_total_len = pd_client.start_args["max_req_total_len"]
if client_max_req_total_len != self.args.max_req_total_len:
logger.error(
f"client dont has same max_req_total_len params, but pd master is {self.args.max_req_total_len}"
f"client is {client_max_req_total_len}"
f"client info {pd_info_json}"
)
assert False

pd_client.websocket = websocket
self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client

Expand Down
6 changes: 5 additions & 1 deletion lightllm/server/router/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Dict, List, Optional, Tuple, Union
from lightllm.server.core.objs import ShmReqManager, Req
from lightllm.utils.log_utils import init_logger
from .stats import RouterStatics

logger = init_logger(__name__)

Expand Down Expand Up @@ -49,11 +50,14 @@ def get_all_dp_req_num(self) -> List[int]:
all_dp_req_num[req.sample_params.suggested_dp_index] += 1
return all_dp_req_num

def filter_out_finished_req(self, shm_req_manager: ShmReqManager):
def filter_out_finished_req(self, shm_req_manager: ShmReqManager, router_statics: RouterStatics):
unfinished_req_ids = []
for req in self.reqs:
if req.shm_infer_released:
logger.info(f"router release req id {req.request_id}")
if not req.is_aborted:
router_statics.update(req.candetoken_out_len)

shm_req_manager.put_back_req_obj(req)
req = None
else:
Expand Down
Loading
Loading