From 0d5ad381fd61b1a1cd058da2ddd888694cde9589 Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Thu, 2 Apr 2026 15:29:59 +0800 Subject: [PATCH 1/2] Fix int32 overflow --- .../layers/moe/routing_indices_cache.py | 43 +++++++++---------- 1 file changed, 20 insertions(+), 23 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index da423bd3da2..06914ee2be9 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -54,6 +54,7 @@ def _save_routing_kernel( TOP_K, NUM_HIDDEN_LAYERS, MAX_MODEL_LEN, + MAX_NUM_SEQS, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, ): @@ -63,49 +64,44 @@ def _save_routing_kernel( token_mask = token_offsets < TOKEN_NUM k_offsets = tl.arange(0, BLOCK_SIZE_K) - k_mask = k_offsets < TOP_K topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :] - # [BLOCK_SIZE_M, BLOCK_SIZE_K] - load_mask = token_mask[:, None] & k_mask[None, :] - topk_vals = tl.load(topk_ids_ptrs, mask=load_mask) - - batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask) - pad_mask = token_mask & (batch_ids != -1) - # [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3] - # -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10] - # -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1] - start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask) + topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1) + + batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1) + + batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS) + pad_mask = token_mask & (batch_ids != -1) & batch_mask + + start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0) token_relative_index = token_offsets - start_offsets - # [BLOCK_SIZE_M] - len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask) + len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0) token_seq_pos = len_decoder + token_relative_index - STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K - STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K + STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64) + STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64) STRIDE_BUF_LAYER = TOP_K - # [BLOCK_SIZE_M, BLOCK_SIZE_K] output_ptrs = ( ROUTING_REPLAY_TABLE_PTR - + batch_ids[:, None] * STRIDE_BUF_SEQ - + token_seq_pos[:, None] * STRIDE_BUF_TOKEN - + LAYER_IDX * STRIDE_BUF_LAYER + + tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ + + tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN + + tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER + k_offsets[None, :] ) - pos_mask = token_seq_pos < MAX_MODEL_LEN + pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN) pos_mask = pos_mask & pad_mask - - # [BLOCK_SIZE_M, BLOCK_SIZE_K] pos_mask = pos_mask[:, None] & k_mask[None, :] final_mask = load_mask & pos_mask + batch_valid = (batch_ids[:, None] >= 0) & (batch_ids[:, None] < MAX_NUM_SEQS) + final_mask = final_mask & batch_valid + tl.store(output_ptrs, topk_vals, mask=final_mask) @@ -150,6 +146,7 @@ def save_routing_to_buffer( TOP_K=top_k, NUM_HIDDEN_LAYERS=num_hidden_layers, MAX_MODEL_LEN=max_model_len, + MAX_NUM_SEQS=max_num_seqs, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_K=BLOCK_SIZE_K, ) From a386642b84ccbdb88d0b7fec5482e2e26fc46c63 Mon Sep 17 00:00:00 2001 From: gongshaotian Date: Thu, 2 Apr 2026 19:14:05 +0800 Subject: [PATCH 2/2] refine code --- fastdeploy/model_executor/layers/moe/routing_indices_cache.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py index 06914ee2be9..efd29477f2b 100644 --- a/fastdeploy/model_executor/layers/moe/routing_indices_cache.py +++ b/fastdeploy/model_executor/layers/moe/routing_indices_cache.py @@ -99,9 +99,6 @@ def _save_routing_kernel( final_mask = load_mask & pos_mask - batch_valid = (batch_ids[:, None] >= 0) & (batch_ids[:, None] < MAX_NUM_SEQS) - final_mask = final_mask & batch_valid - tl.store(output_ptrs, topk_vals, mask=final_mask)