Commit 15a0c2b

brb-nv committed
[TRTLLM-5972][chore] Load balance decode token KV cache with helix parallelism
Signed-off-by: Balaram Buddharaju <169953907+brb-nv@users.noreply.github.com>
1 parent d6f961d commit 15a0c2b

File tree

7 files changed: +20 -10 lines changed


tensorrt_llm/_torch/pyexecutor/executor_request_queue.py

Lines changed: 1 addition & 0 deletions
@@ -694,6 +694,7 @@ def _merge_helix_requests(self, new_requests: list[RequestQueueItem],
                 position_ids=position_ids_this_rank,
             )
             req.total_input_len_cp = input_len
+            req.seqlen_this_rank_cp = len(input_ids_this_rank)
             req_with_children.append(req)
             if req.child_requests:
                 req_with_children.extend(req.child_requests)
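
Note: the new `seqlen_this_rank_cp` records how many prompt tokens this CP rank actually holds after the helix merge, while `total_input_len_cp` keeps the full prompt length seen by all ranks. A minimal sketch of the intended relationship; the even two-way split below is only an illustration, not the actual partitioning done by `_merge_helix_requests`:

# Illustrative only: how the two new bookkeeping fields relate on helix/CP ranks.
# Assume a 10-token prompt split across cp_size=2 ranks (the halves are an example).
input_ids = list(range(10))
input_ids_per_rank = [input_ids[:5], input_ids[5:]]

total_input_len_cp = len(input_ids)                      # same value on every rank
seqlen_this_rank_cp = [len(ids) for ids in input_ids_per_rank]

# Per-rank KV lengths account for the whole prompt exactly once.
assert sum(seqlen_this_rank_cp) == total_input_len_cp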

tensorrt_llm/_torch/pyexecutor/llm_request.py

Lines changed: 2 additions & 0 deletions
@@ -489,6 +489,8 @@ def __init__(
         self.py_max_new_tokens = self.max_new_tokens
         self.py_min_length = self.sampling_config.min_length
         self.py_helix_is_inactive_rank = False
+        self.seqlen_this_rank_cp = 0
+        self.total_input_len_cp = 0
         self.py_batch_idx = None
         self.py_draft_pages_allocated = 0
         self.py_rewind_len = 0

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 5 additions & 5 deletions
@@ -1671,12 +1671,12 @@ def _prepare_tp_inputs(
                 # Warmup doesn't have `total_input_len_cp` set because merge_helix_requests is not called.
                 if not self.is_warmup and not request.is_cuda_graph_dummy:
                     position_id = request.total_input_len_cp + request.py_decoding_iter - 1
-                    # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
-                    if self.mapping.cp_rank == self.mapping.cp_size - 1:
-                        past_seen_token_num = request.orig_prompt_len + request.py_decoding_iter - 1
+                    if request.py_helix_is_inactive_rank:
+                        past_seen_token_num = request.seqlen_this_rank_cp
                     else:
-                        # past_seen_token_num doesn't grow on inactive ranks.
-                        past_seen_token_num = request.orig_prompt_len
+                        # Discount the token added to active rank in resource manager as it hasn't
+                        # been previously seen.
+                        past_seen_token_num = request.seqlen_this_rank_cp - 1

                     position_ids.append(position_id)
                     num_cached_tokens_per_seq.append(past_seen_token_num)
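
Note: with the round-robin scheme, `past_seen_token_num` is now derived from the per-rank KV length instead of the rank index. A rough sketch of the per-rank bookkeeping, assuming `seqlen_this_rank_cp` was already incremented by the resource manager on the active rank for this step; the standalone function is only for illustration, the names mirror the diff:

# Sketch of cached-token / position-id bookkeeping on one CP rank at decode time.
def helix_decode_bookkeeping(total_input_len_cp: int,
                             seqlen_this_rank_cp: int,
                             decoding_iter: int,
                             is_inactive_rank: bool) -> tuple[int, int]:
    # The position id is global: full prompt length plus tokens decoded so far.
    position_id = total_input_len_cp + decoding_iter - 1
    if is_inactive_rank:
        # Inactive ranks received no new KV slot this step; everything they hold
        # has already been attended to.
        past_seen_token_num = seqlen_this_rank_cp
    else:
        # The active rank just had one token added in prepare_resources, so
        # discount it: attention has not "seen" it yet.
        past_seen_token_num = seqlen_this_rank_cp - 1
    return position_id, past_seen_token_num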

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 9 additions & 5 deletions
@@ -468,13 +468,17 @@ def prepare_resources(self, scheduled_batch: ScheduledRequests):
                                              req, block_ids)

         for req in generation_batch:
-            # TODO: [TRTLLM-5972] Lift the limitation that last rank is always the active one for helix.
             if self.mapping.has_cp_helix():
-                if self.mapping.cp_rank != self.mapping.cp_size - 1:
+                # Distribute the decode blocks across CP ranks in a round-robin manner.
+                decode_block_id = (req.py_decoding_iter -
+                                   1) // self.tokens_per_block
+                if decode_block_id % self.mapping.cp_size == self.mapping.cp_rank:
+                    req.py_helix_is_inactive_rank = False
+                    req.seqlen_this_rank_cp += 1
+                else:
                     req.py_helix_is_inactive_rank = True
-                # Skip allocating KV cache at decode for inactive helix ranks.
-                if req.py_helix_is_inactive_rank:
-                    continue
+                    # Skip allocating KV cache at decode for inactive helix ranks.
+                    continue
             self.impl.add_token(req.py_request_id)
             for _ in range(get_draft_token_length(req)):
                 self.impl.add_token(req.py_request_id)
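
Note: the KV block that receives the current decode token is now chosen round-robin across CP ranks, so decode-time KV growth is spread evenly instead of always landing on the last rank. A small standalone sketch of the ownership rule from the diff; the function name is hypothetical, introduced only for illustration:

# Which CP rank owns the KV block for a given decode iteration (1-based)?
def active_helix_rank(decoding_iter: int, tokens_per_block: int, cp_size: int) -> int:
    decode_block_id = (decoding_iter - 1) // tokens_per_block
    return decode_block_id % cp_size

# With tokens_per_block=32 and cp_size=4, decode tokens 1..32 land on rank 0,
# 33..64 on rank 1, and so on, wrapping back to rank 0 after rank 3.
assert active_helix_rank(1, 32, 4) == 0
assert active_helix_rank(33, 32, 4) == 1
assert active_helix_rank(129, 32, 4) == 0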

tests/integration/test_lists/qa/llm_function_core.txt

Lines changed: 1 addition & 0 deletions
@@ -519,6 +519,7 @@ accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[xgrammar-mtp_nextn=2]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=0]
 accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_guided_decoding[llguidance-mtp_nextn=2]
+accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[False]
 accuracy/test_disaggregated_serving.py::TestGemma3_1BInstruct::test_auto_dtype[True]
 accuracy/test_disaggregated_serving.py::TestGPTOSS::test_auto_dtype[True]

tests/integration/test_lists/test-db/l0_dgx_b200.yml

Lines changed: 1 addition & 0 deletions
@@ -221,3 +221,4 @@ l0_dgx_b200:
 - accuracy/test_llm_api_pytorch.py::TestLlama3_1_8BInstruct::test_fp8_4gpus[pp4-fp8kv=True-attn_backend=TRTLLM-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestDeepSeekV3Lite::test_nvfp4_4gpus[moe_backend=CUTEDSL-mtp_nextn=2-ep4-fp8kv=True-attention_dp=True-cuda_graph=True-overlap_scheduler=True-torch_compile=False]
 - accuracy/test_llm_api_pytorch.py::TestLlama3_3_70BInstruct::test_fp4_tp2pp2
+- accuracy/test_disaggregated_serving.py::TestDeepSeekV3Lite::test_auto_dtype_with_helix

tests/unittest/_torch/executor/test_pytorch_model_engine.py

Lines changed: 1 addition & 0 deletions
@@ -407,6 +407,7 @@ def test_prepare_tp_inputs_with_helix_parallelism(self) -> None:
             req.sampling_config.beam_width = 1
             req.py_multimodal_data = {}
             req.total_input_len_cp = prompt_lens[idx] * 2
+            req.seqlen_this_rank_cp = prompt_lens[idx]
             req.py_decoding_iter = 1
             gen_requests.append(req)
         scheduled_requests.generation_requests = gen_requests
