Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 17 additions & 23 deletions fastdeploy/model_executor/layers/moe/routing_indices_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def _save_routing_kernel(
TOP_K,
NUM_HIDDEN_LAYERS,
MAX_MODEL_LEN,
MAX_NUM_SEQS,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
):
Expand All @@ -63,45 +64,37 @@ def _save_routing_kernel(
token_mask = token_offsets < TOKEN_NUM

k_offsets = tl.arange(0, BLOCK_SIZE_K)

k_mask = k_offsets < TOP_K

topk_ids_ptrs = TOPK_IDS_PTR + token_offsets[:, None] * TOP_K + k_offsets[None, :]
# [BLOCK_SIZE_M, BLOCK_SIZE_K]

load_mask = token_mask[:, None] & k_mask[None, :]
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask)

batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask)
pad_mask = token_mask & (batch_ids != -1)
# [0, 3, 4, 10, 12][0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 3, 3]
# -> [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11] - [0, 0, 0, 0, 4, 4, 4, 4, 4, 4, 10, 10]
# -> [0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 0, 1]
start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask)
topk_vals = tl.load(topk_ids_ptrs, mask=load_mask, other=-1)

batch_ids = tl.load(BATCH_ID_PER_TOKEN_PTR + token_offsets, mask=token_mask, other=-1)

batch_mask = (batch_ids >= 0) & (batch_ids < MAX_NUM_SEQS)
pad_mask = token_mask & (batch_ids != -1) & batch_mask

start_offsets = tl.load(CU_SEQLENS_Q_PTR + batch_ids, mask=pad_mask, other=0)
token_relative_index = token_offsets - start_offsets

# [BLOCK_SIZE_M]
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask)
len_decoder = tl.load(SEQ_LENS_DECODER_PTR + batch_ids, mask=pad_mask, other=0)
token_seq_pos = len_decoder + token_relative_index

STRIDE_BUF_SEQ = MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K
STRIDE_BUF_TOKEN = NUM_HIDDEN_LAYERS * TOP_K
STRIDE_BUF_SEQ = tl.cast(MAX_MODEL_LEN * NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
STRIDE_BUF_TOKEN = tl.cast(NUM_HIDDEN_LAYERS * TOP_K, tl.int64)
STRIDE_BUF_LAYER = TOP_K

# [BLOCK_SIZE_M, BLOCK_SIZE_K]
output_ptrs = (
ROUTING_REPLAY_TABLE_PTR
+ batch_ids[:, None] * STRIDE_BUF_SEQ
+ token_seq_pos[:, None] * STRIDE_BUF_TOKEN
+ LAYER_IDX * STRIDE_BUF_LAYER
+ tl.cast(batch_ids[:, None], tl.int64) * STRIDE_BUF_SEQ
+ tl.cast(token_seq_pos[:, None], tl.int64) * STRIDE_BUF_TOKEN
+ tl.cast(LAYER_IDX, tl.int64) * STRIDE_BUF_LAYER
+ k_offsets[None, :]
)

pos_mask = token_seq_pos < MAX_MODEL_LEN
pos_mask = (token_seq_pos >= 0) & (token_seq_pos < MAX_MODEL_LEN)
pos_mask = pos_mask & pad_mask

# [BLOCK_SIZE_M, BLOCK_SIZE_K]
pos_mask = pos_mask[:, None] & k_mask[None, :]

final_mask = load_mask & pos_mask
Expand Down Expand Up @@ -150,6 +143,7 @@ def save_routing_to_buffer(
TOP_K=top_k,
NUM_HIDDEN_LAYERS=num_hidden_layers,
MAX_MODEL_LEN=max_model_len,
MAX_NUM_SEQS=max_num_seqs,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_K=BLOCK_SIZE_K,
)
Expand Down
Loading