auto set schedule way. #1235
Merged · +187 −107 · 13 commits
Commits (all by hiworldwzj):

- 04ad8ca fix
- 87e54c8 fix
- 3830b45 fix
- 43e88be fix
- d60e7f7 split pd split max new tokens.
- 2c191ff fix
- 456e189 fix
- 8954561 fix
- 2bbc013 fix set unique_server_name
- 63fcf1f fix
- a52831e fix
- 962930c fix
- 5976fd1 fix
```diff
@@ -20,6 +20,7 @@
 from lightllm.utils.statics_utils import MovingAverage
 from lightllm.server.httpserver.manager import AsyncQueue
 from lightllm.utils.error_utils import ServerBusyError
+from lightllm.utils.envs_utils import get_pd_split_max_new_tokens
 from .pd_selector import create_selector


 logger = init_logger(__name__)
```
```diff
@@ -31,6 +32,8 @@ def __init__(
         args: StartArgs,
     ):
         self.args = args
+        self.max_req_total_len = args.max_req_total_len
+        assert self.max_req_total_len is not None
         self.metric_client = MetricClient(args.metric_port)
         self.id_gen = ReqIDGenerator()
```
```diff
@@ -92,44 +95,84 @@ async def generate(
         multimodal_params: MultimodalParams,
         request: Request,
     ):
         assert isinstance(prompt, str), "prompt must be str"
         start_time = time.time()
-        group_request_id = self.id_gen.generate_id()
+        # Compute the input token count and validate it. If the combined
+        # input + output length is set too long, repair the sampling_params.
+        input_token_num = self.tokens(prompt, multimodal_params, sampling_params)
+        fake_prompt_ids = [0 for _ in range(input_token_num)]
+        from lightllm.server.httpserver.manager import HttpServerManager
+
+        await HttpServerManager._check_and_repair_length(
+            self, prompt_ids=fake_prompt_ids, sampling_params=sampling_params
+        )
+
+        # Split the request into blocks by max_new_tokens. In the PD-disaggregation
+        # scenario only conservative scheduling can be used; if every user sets a
+        # very large max_new_tokens, a huge amount of GPU memory is reserved and
+        # system throughput drops. We therefore split a request into several
+        # segments for inference. With a reasonable block size, multi-segment
+        # inference rarely actually happens, so throughput is not affected.
+        origin_sampling_params = SamplingParams.from_buffer_copy(sampling_params)
+        origin_group_request_id = self.id_gen.generate_id()
+        max_new_tokens_list = self._split_max_new_tokens(max_new_tokens=origin_sampling_params.max_new_tokens)

         try:
-            sampling_params.group_request_id = group_request_id
             # Log request-arrival info.
-            await self._log_req_header(request, group_request_id)
+            await self._log_req_header(request, origin_group_request_id)
             # Metrics.
             self.metric_client.counter_inc("lightllm_request_count")
-            self.metric_client.histogram_observe("lightllm_request_max_new_tokens", sampling_params.max_new_tokens)
+            self.metric_client.histogram_observe(
+                "lightllm_request_max_new_tokens", origin_sampling_params.max_new_tokens
+            )

-            p_node, d_node = await self.select_p_d_node(prompt, sampling_params, multimodal_params)
+            p_node, d_node = await self.select_p_d_node(prompt, origin_sampling_params, multimodal_params)
+
+            history_gen_token_strs = []

             if not p_node or not d_node:
-                logger.error(f"{group_request_id}: No p_node or d_node found")
-                raise Exception(f"{group_request_id}: No p_node or d_node found")
+                logger.error(f"{origin_group_request_id}: No p_node or d_node found")
+                raise Exception(f"{origin_group_request_id}: No p_node or d_node found")

-            results_generator = self._wait_to_token_package(
-                p_node,
-                d_node,
-                start_time,
-                prompt,
-                sampling_params,
-                multimodal_params,
-                request,
-            )
-            async for sub_req_id, request_output, metadata, finish_status in results_generator:
-                yield sub_req_id, request_output, metadata, finish_status
+            for iter_index, block_max_new_tokens in enumerate(max_new_tokens_list):
+                sampling_params = SamplingParams.from_buffer_copy(origin_sampling_params)
+                block_group_request_id = self.id_gen.generate_id()
+                sampling_params.group_request_id = block_group_request_id
+                logger.info(f"pd log gen sub req id {block_group_request_id} for main req id {origin_group_request_id}")
+                sampling_params.max_new_tokens = block_max_new_tokens
+
+                results_generator = self._wait_to_token_package(
+                    p_node,
+                    d_node,
+                    start_time,
+                    prompt + "".join(history_gen_token_strs),
+                    sampling_params,
+                    multimodal_params,
+                    request,
+                )
+                is_last_block = iter_index == len(max_new_tokens_list) - 1
+                prompt_tokens = sys.maxsize  # because of chunking
+                async for sub_req_id, request_output, metadata, finish_status in results_generator:
+                    # In PD-disaggregation mode, the ordering info in the returned
+                    # metadata may be inaccurate.
+                    assert sub_req_id == block_group_request_id
+                    if finish_status.get_finish_reason() == "length" and (not is_last_block):
+                        finish_status = FinishStatus()  # convert to not-finished
+                    history_gen_token_strs.append(request_output)
+                    prompt_tokens = min(prompt_tokens, metadata["prompt_tokens"])
+                    metadata["prompt_tokens"] = prompt_tokens
+                    yield origin_group_request_id, request_output, metadata, finish_status
+
+                await self.remove_req(group_request_id=block_group_request_id)

         except BaseException as e:
             logger.error(f"has exception {str(e)}")
             try:
-                await self.abort(group_request_id, p_node=p_node, d_node=d_node)
+                await self.abort(block_group_request_id, p_node=p_node, d_node=d_node)
             except:
-                await self.abort(group_request_id)
+                await self.abort(block_group_request_id)
             raise e

         finally:
-            await self.remove_req(group_request_id)
+            await self.remove_req(block_group_request_id)
         return

     async def _log_req_header(self, request: Request, group_request_id: int):
```
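The loop above splits one logical request into per-block sub-requests and masks the "length" finish reason on every block except the last, so callers only see "finished by length" once the whole budget is exhausted. A toy simulation of that reporting behavior (all names here are illustrative, not part of the PR):

```python
# Toy simulation of chunked generation: a request whose max_new_tokens was
# split into blocks; intermediate blocks that fill up report "" (not
# finished), mirroring the PR's FinishStatus() rewrite; only the final block
# may surface "length".
from typing import List, Tuple


def simulate_chunked_finish(blocks: List[int], total_generated: int) -> List[Tuple[int, str]]:
    """Return (tokens_produced, finish_reason_reported) per executed block."""
    out = []
    remaining = total_generated
    for i, block in enumerate(blocks):
        produced = min(block, remaining)
        remaining -= produced
        # The engine reports "length" when a block fills up, "stop" otherwise.
        raw_reason = "length" if produced == block else "stop"
        is_last = i == len(blocks) - 1
        # Intermediate "length" finishes are hidden from the caller.
        reported = raw_reason if (raw_reason != "length" or is_last) else ""
        out.append((produced, reported))
        if raw_reason == "stop":
            break
    return out


print(simulate_chunked_finish([1024, 1024, 452], 2500))
# → [(1024, ''), (1024, ''), (452, 'length')]
print(simulate_chunked_finish([1024, 1024, 452], 1500))
# → [(1024, ''), (476, 'stop')]
```

If the model stops early ("stop"), later blocks are never scheduled, which is why the PR's comment claims multi-segment inference is rare in practice.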
```diff
@@ -465,6 +508,14 @@ async def handle_loop(self):
             logger.exception(str(e))
         return

+    def _split_max_new_tokens(self, max_new_tokens: int) -> List[int]:
+        block_max_new_tokens = get_pd_split_max_new_tokens()
+        ans_list = [block_max_new_tokens for _ in range(max_new_tokens // block_max_new_tokens)]
+        left_token = max_new_tokens - (max_new_tokens // block_max_new_tokens) * block_max_new_tokens
+        if left_token > 0:
+            ans_list.append(left_token)
+        return ans_list
+

 class ReqStatus:
     def __init__(self, req_id, p_node, d_node) -> None:
```

Contributor review comment on lines +513 to +516: this implementation can be made more concise and Pythonic by using the modulo operator (`%`).
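As the reviewer notes, the remainder can be computed with the modulo operator. A minimal standalone sketch of an equivalent split, assuming a fixed block size of 1024 in place of `get_pd_split_max_new_tokens()`:

```python
# Sketch of the chunking helper; BLOCK stands in for the value returned by
# get_pd_split_max_new_tokens() in the PR.
from typing import List

BLOCK = 1024  # assumed block size, for illustration only


def split_max_new_tokens(max_new_tokens: int, block: int = BLOCK) -> List[int]:
    # Full-size blocks first, then the remainder via the modulo operator.
    ans_list = [block] * (max_new_tokens // block)
    left = max_new_tokens % block
    if left > 0:
        ans_list.append(left)
    return ans_list


print(split_max_new_tokens(2500))  # → [1024, 1024, 452]
print(split_max_new_tokens(2048))  # → [1024, 1024]
print(split_max_new_tokens(7))     # → [7]
```

The blocks always sum back to the original `max_new_tokens`, which is what lets the caller treat the chunks as one logical request.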
```diff
@@ -510,6 +561,15 @@ def __init__(self, args: StartArgs):

     def register_pd(self, pd_info_json, websocket):
         pd_client = PD_Client_Obj(**pd_info_json)
+        client_max_req_total_len = pd_client.start_args["max_req_total_len"]
+        if client_max_req_total_len != self.args.max_req_total_len:
+            logger.error(
+                f"client dont has same max_req_total_len params, but pd master is {self.args.max_req_total_len}"
+                f"client is {client_max_req_total_len}"
+                f"client info {pd_info_json}"
+            )
+            assert False
+
         pd_client.websocket = websocket
         self.url_to_pd_nodes[pd_client.client_ip_port] = pd_client
```
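The guard above rejects clients whose `max_req_total_len` differs from the pd master's, since the master's length-repair logic depends on both sides agreeing. A hypothetical standalone sketch of that validation (the exception class and function name are invented for illustration; the PR itself uses `assert False`):

```python
# Hypothetical standalone sketch of the config-consistency check performed
# in register_pd; raises a typed error instead of assert False.
class ConfigMismatchError(Exception):
    pass


def check_client_config(master_max_req_total_len: int, client_start_args: dict) -> None:
    # A registering pd client must advertise the same max_req_total_len as
    # the pd master, otherwise requests validated on the master could exceed
    # what the client node can actually serve.
    client_len = client_start_args["max_req_total_len"]
    if client_len != master_max_req_total_len:
        raise ConfigMismatchError(
            f"pd master max_req_total_len={master_max_req_total_len}, "
            f"client max_req_total_len={client_len}"
        )


check_client_config(8192, {"max_req_total_len": 8192})  # passes silently
try:
    check_client_config(8192, {"max_req_total_len": 4096})
except ConfigMismatchError as e:
    print(e)  # → pd master max_req_total_len=8192, client max_req_total_len=4096
```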
Contributor review comment:

The exception and finalization logic in the `generate` method has a couple of critical issues that could lead to runtime errors:

1. `NameError`: `block_group_request_id` is defined inside the `for` loop. If an exception occurs before the loop starts (e.g., in `select_p_d_node`), the `except` and `finally` blocks will raise a `NameError` when they try to access `block_group_request_id`.
2. Double `remove_req` call: on successful completion of all chunks, `remove_req` is called for the last chunk at the end of the `for` loop (line 154), and then called again in the `finally` block. This will likely cause an error (e.g., a `KeyError`), as `remove_req` is probably not idempotent.

To fix this, I suggest initializing `block_group_request_id = None` before the `try` block and ensuring the cleanup logic handles all edge cases without raising. For instance, the `finally` block could be removed and cleanup handled within the `except` block for failures, while the success path is already handled inside the loop.
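The reviewer's suggestion can be sketched as follows; this is an illustrative skeleton, not the PR's actual code, and `run_chunk`, `abort`, and `remove_req` stand in for the real methods:

```python
# Illustrative skeleton of the reviewer's suggested pattern: bind
# block_group_request_id before the try so except/finally can never hit a
# NameError, and guard cleanup so it is a no-op when no chunk ever started.
import asyncio


async def generate_skeleton(blocks, run_chunk, abort, remove_req):
    block_group_request_id = None  # defined before anything can fail
    try:
        for i, block in enumerate(blocks):
            block_group_request_id = i  # stand-in for self.id_gen.generate_id()
            await run_chunk(block)
            await remove_req(block_group_request_id)  # success path, per chunk
    except BaseException:
        if block_group_request_id is not None:  # no NameError, no double cleanup
            await abort(block_group_request_id)
        raise


log = []


async def ok(block):
    pass


async def boom(block):
    raise RuntimeError("chunk failed")


async def remove_req(req_id):
    log.append(("remove", req_id))


async def abort(req_id):
    log.append(("abort", req_id))


asyncio.run(generate_skeleton([5, 5], ok, abort, remove_req))
print(log)  # → [('remove', 0), ('remove', 1)]

log.clear()
try:
    asyncio.run(generate_skeleton([5], boom, abort, remove_req))
except RuntimeError:
    pass
print(log)  # → [('abort', 0)]
```

On success each chunk is removed exactly once inside the loop; on failure only the in-flight chunk is aborted, and an exception raised before the first chunk simply propagates with no cleanup call.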