Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion custom_ops/gpu_ops/cpp_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,9 @@ std::vector<paddle::Tensor> TextImageGatherScatter(
const bool is_scatter);

// Counts how many routed tokens each expert receives (device-side).
// Returns a single int32 tensor of shape [2, num_experts], or
// [3, num_experts] when compute_padded_cumsum is true:
//   row 0: raw per-expert token counts
//   row 1: counts rounded up to a multiple of 128
//   row 2 (optional): inclusive prefix sum of the padded counts
std::vector<paddle::Tensor> count_tokens_per_expert_func(
    const paddle::Tensor& topk_ids,
    int64_t num_experts,
    bool compute_padded_cumsum = false);
void GetPositionIdsAndMaskEncoderBatch(
const paddle::Tensor& seq_lens_encoder,
const paddle::Tensor& seq_lens_decoder,
Expand Down
70 changes: 52 additions & 18 deletions custom_ops/gpu_ops/moe/deepgemm_preprocess.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
#include "helper.h"
#include "paddle/extension.h"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❓ 疑问 PR 描述提到此变更会影响模型输出,但 Accuracy Tests 部分为空。

建议补充精度测试结果,特别是使用 FD_USE_PHI_MOE_PERMUTE=1 时的模型输出对比,确保与原始实现结果一致。


template <typename scalar_t>
template <typename scalar_t, bool kComputeCumsum>
__global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,
int32_t *__restrict__ res,
int32_t *__restrict__ res_padded,
int32_t *__restrict__ res_padded_cumsum,
size_t numel,
int num_experts) {
extern __shared__ int32_t tokens_per_ep[];
Expand All @@ -35,48 +36,81 @@ __global__ void cuda_kernel(const scalar_t *__restrict__ topk_ids,

__syncthreads();

for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
res[i] = tokens_per_ep[i];
res_padded[i] = (res[i] + 127) / 128 * 128;
if constexpr (kComputeCumsum) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议num_experts 较大时,串行计算 cumsum(仅 thread 0 执行)可能成为性能瓶颈。

虽然这不会导致正确性问题,但对于大量专家场景(如 >1000),串行循环可能导致显著的性能延迟。建议考虑并行 cumsum 实现或添加性能监控。

if (threadIdx.x == 0) {
int32_t running_sum = 0;
for (int i = 0; i < num_experts; i++) {
int32_t count = tokens_per_ep[i];
int32_t padded = (count + 127) / 128 * 128;
res[i] = count;
res_padded[i] = padded;
running_sum += padded;
res_padded_cumsum[i] = running_sum;
}
}
} else {
for (size_t i = threadIdx.x; i < num_experts; i += blockDim.x) {
res[i] = tokens_per_ep[i];
res_padded[i] = (tokens_per_ep[i] + 127) / 128 * 128;
}
}
}

// Host wrapper: launches the per-expert token-count kernel.
//
// topk_ids: int64 tensor of shape [num_tokens, top_k] holding routed expert
//           ids (assumed in [0, num_experts) — TODO confirm with caller).
// num_experts: number of experts to histogram over.
// compute_padded_cumsum: when true, also emits an inclusive prefix sum of
//           the 128-padded counts as a third output row.
//
// Returns one int32 tensor of shape [num_rows, num_experts] where
// num_rows is 2 (counts, padded counts) or 3 (+ padded cumsum).
std::vector<paddle::Tensor> count_tokens_per_expert_func(
    const paddle::Tensor &topk_ids,
    int64_t num_experts,
    bool compute_padded_cumsum) {
  // Use 64-bit arithmetic: num_tokens * top_k can exceed INT_MAX for large
  // batches, and the kernel's numel parameter is size_t anyway.
  const int64_t topk_ids_numel = topk_ids.shape()[0] * topk_ids.shape()[1];

  // Row layout: [counts, padded counts] plus an optional cumsum row.
  const int64_t num_rows = compute_padded_cumsum ? 3 : 2;
  auto token_nums_per_expert = paddle::empty(
      {num_rows, num_experts}, paddle::DataType::INT32, topk_ids.place());

  auto stream = topk_ids.stream();
  using scalar_t = int64_t;

  // Single block of 1024 threads; dynamic shared memory holds one int32
  // counter per expert.
  int32_t *counts_ptr = token_nums_per_expert.data<int32_t>();
  if (compute_padded_cumsum) {
    cuda_kernel<scalar_t, true>
        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
            topk_ids.data<scalar_t>(),
            counts_ptr,
            counts_ptr + num_experts,
            counts_ptr + 2 * num_experts,  // cumsum row
            topk_ids_numel,
            num_experts);
  } else {
    cuda_kernel<scalar_t, false>
        <<<1, 1024, num_experts * sizeof(int32_t), stream>>>(
            topk_ids.data<scalar_t>(),
            counts_ptr,
            counts_ptr + num_experts,
            nullptr,  // cumsum output unused in this instantiation
            topk_ids_numel,
            num_experts);
  }
  // Surface launch-configuration errors (e.g. shared-memory over-allocation
  // for very large num_experts) instead of failing silently downstream.
  cudaError_t launch_status = cudaGetLastError();
  PD_CHECK(launch_status == cudaSuccess,
           "count_tokens_per_expert_func kernel launch failed: ",
           cudaGetErrorString(launch_status));

  return {token_nums_per_expert};
}

// InferDtype for count_tokens_per_expert_func: the output is always INT32
// counts, independent of the topk_ids dtype or the attribute values.
std::vector<paddle::DataType> count_tokens_per_expert_func_infer_dtype(
    const paddle::DataType &topk_ids_dtype,
    int64_t num_experts,
    bool compute_padded_cumsum) {
  return {paddle::DataType::INT32};
}

// InferShape for count_tokens_per_expert_func.
// Output shape is [num_rows, num_experts]: two rows (counts, padded counts),
// plus a third cumsum row when compute_padded_cumsum is requested.
std::vector<std::vector<int64_t>> count_tokens_per_expert_func_infer_shape(
    const std::vector<int64_t> &topk_ids_shape,
    int64_t num_experts,
    bool compute_padded_cumsum) {
  const int64_t num_rows = compute_padded_cumsum ? 3 : 2;
  return {{num_rows, num_experts}};
}

// Registers the custom op with both attributes; compute_padded_cumsum selects
// whether the output carries the extra padded-cumsum row.
PD_BUILD_STATIC_OP(count_tokens_per_expert_func)
    .Inputs({"topk_ids"})
    .Outputs({"token_nums_per_expert"})
    .Attrs({"num_experts:int64_t", "compute_padded_cumsum:bool"})
    .SetKernelFn(PD_KERNEL(count_tokens_per_expert_func))
    .SetInferShapeFn(PD_INFER_SHAPE(count_tokens_per_expert_func_infer_shape))
    .SetInferDtypeFn(PD_INFER_DTYPE(count_tokens_per_expert_func_infer_dtype));
200 changes: 152 additions & 48 deletions fastdeploy/model_executor/layers/moe/fused_moe_cutlass_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@
from .fused_moe_backend_base import UnquantizedFusedMoEMethod

if current_platform.is_cuda():
from fastdeploy.model_executor.ops.gpu import moe_expert_dispatch, moe_expert_reduce
from fastdeploy.model_executor.ops.gpu import (
count_tokens_per_expert_func,
moe_expert_dispatch,
moe_expert_reduce,
)

try:
from fastdeploy.model_executor.ops.gpu import (
Expand Down Expand Up @@ -126,14 +130,15 @@ def apply_ep_prefill(
# 1. Select topk experts and weights
topk_idx, topk_weights = self.ep_prefill_runner.moe_select(layer, gate_out)
# 2. EP Dispatch
dispatch_kwargs = {"expert_alignment": 128} if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE else {}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 环境变量 FD_USE_PHI_MOE_PERMUTE 是本次 PR 引入的新功能,但在文档中没有说明。

建议在文档中添加该环境变量的说明,包括:

  • 功能描述(使用 Paddle 官方 moe_permute 算子替代自定义算子)
  • 适用场景(quant_type=w16a16)
  • 使用示例(export FD_USE_PHI_MOE_PERMUTE=1

(
recv_x,
recv_topk_idx,
recv_topk_weights,
recv_num_tokens_per_expert_list,
handle,
event,
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights)
) = self.ep_prefill_runner.dispatch(x, topk_idx, topk_weights, **dispatch_kwargs)

if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)
Expand All @@ -146,54 +151,91 @@ def apply_ep_prefill(
# 3. Compute ffn
if token_all_num > 0:
logger.debug(f"token_all_num {token_all_num}")
(
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
expert_idx_per_token,
dequant_scale,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
recv_topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
# only w4a8 and w4afp8 need expert_idx_per_token
# Other need not this tensor, so we make it None.
expert_idx_per_token = None

if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
# --- moe_permute / moe_unpermute path ---
recv_topk_idx_i32 = recv_topk_idx.astype(paddle.int32)
(permute_input, permute_indices_per_token, dst_weights, _scale_out) = paddle.nn.functional.moe_permute(
hidden_states=recv_x,
scale=None,
expert_routemap_topk=recv_topk_idx_i32,
expert_prob_topk=recv_topk_weights,
num_experts=layer.num_local_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=token_all_num,
)

token_nums_per_expert_cumsum = count_tokens_per_expert_func(
recv_topk_idx, layer.num_local_experts, True
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 此处调用 compute_ffn 时传入了 8 个参数,但原始路径只传 7 个参数。虽然 compute_ffnmax_tokens_per_expert 参数有默认值,但为了代码一致性,建议两种路径保持相同的参数传递方式。

原始路径(第 220-228 行):

ffn_out = self.compute_ffn(
    layer,
    permute_input,
    recv_num_tokens_per_expert_list_cumsum,
    expert_idx_per_token,
    False,
    -1,
    dequant_scale,
)

新路径(第 172-181 行)传入了额外的 None 作为第 8 个参数。

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

默认参数为none,应该不影响

)[2].cast(paddle.int64)
ffn_out = self.compute_ffn(
layer,
permute_input,
token_nums_per_expert_cumsum,
None,
False,
-1,
None,
None,
)

tmp_ffn_out, _out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=recv_topk_idx_i32,
token_prob_unzipped=dst_weights,
total_zipped_tokens=recv_x.shape[0],
num_experts=layer.num_local_experts,
using_weighted_combine=True,
)
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")
# --- original ep_moe_expert_dispatch / combine path ---
(
permute_input,
permute_indices_per_token,
recv_num_tokens_per_expert_list_cumsum,
dst_weights,
dst_indices,
cumsum_idx_gpu,
expert_idx_per_token,
dequant_scale,
) = fastdeploy.model_executor.ops.gpu.ep_moe_expert_dispatch(
recv_x,
recv_topk_idx,
recv_topk_weights,
(layer.up_gate_proj_in_scale if hasattr(layer, "up_gate_proj_in_scale") else None),
recv_num_tokens_per_expert_list,
token_all_num,
self.moe_quant_type,
)
if not layer.with_bias and self.moe_quant_type != "w4a8" and self.moe_quant_type != "w4afp8":
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")

if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None
if hasattr(layer, "up_gate_proj_in_scale"):
dequant_scale = None

ffn_out = self.compute_ffn(
layer,
permute_input,
recv_num_tokens_per_expert_list_cumsum,
expert_idx_per_token,
False,
-1,
dequant_scale,
)
ffn_out = self.compute_ffn(
layer,
permute_input,
recv_num_tokens_per_expert_list_cumsum,
expert_idx_per_token,
False,
-1,
dequant_scale,
)

# prmt back per rank
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)
tmp_ffn_out = fastdeploy.model_executor.ops.gpu.ep_moe_expert_combine(
ffn_out,
dst_weights,
permute_indices_per_token,
dst_indices,
None, # down_proj_bias,
False, # norm_topk_prob
1.0,
)
else:
tmp_ffn_out = recv_x

Expand Down Expand Up @@ -276,6 +318,69 @@ def apply_tp(
"""
gate_out = gate(x)
gate_out = gate_out.cast("float32")
if fastdeploy.envs.FD_USE_PHI_MOE_PERMUTE and self.moe_quant_type == "w16a16":
if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
layer.n_group,
layer.topk_group,
layer.top_k,
layer.routed_scaling_factor,
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)
else:
topk_idx, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
gate_out,
layer.gate_correction_bias,
layer.top_k,
True, # apply_norm_weight
False,
)
topk_idx_i32 = topk_idx.astype(paddle.int32)
override_buffer_size = x.shape[0] * layer.top_k + layer.num_experts * (128 - 1)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟡 建议 override_buffer_size 的计算逻辑 x.shape[0] * layer.top_k + layer.num_experts * (128 - 1) 需要补充说明。

请确认:

  1. layer.num_experts * (128 - 1) 这个额外空间的目的是什么?
  2. 是否会导致内存浪费?

建议添加注释说明该计算公式的推导过程,或者考虑使用更保守的计算方式。

(permute_input, permute_indices_per_token, dst_weights, _scale_out) = ( # zipped_expertwise_rowmap
paddle.nn.functional.moe_permute(
hidden_states=x,
scale=None,
expert_routemap_topk=topk_idx_i32,
expert_prob_topk=topk_weights,
num_experts=layer.num_experts,
tokens_per_expert=[],
padding_alignment=128,
override_buffer_size=override_buffer_size,
)
)

# Row 2 of count_tokens_per_expert_func is the prefix sum token_nums_per_expert.
token_nums_per_expert_cumsum = count_tokens_per_expert_func(topk_idx, layer.num_experts, True)[2].cast(
paddle.int64
)
if topk_ids_hookfunc is not None:
topk_ids_hookfunc(topk_ids=topk_idx)

ffn_out = self.compute_ffn(
layer,
permute_input,
token_nums_per_expert_cumsum,
None, # expert_idx_per_token not needed for w16a16 without bias
False,
-1,
None, # dequant_scale
None, # max_tokens_per_expert
)

fused_moe_out, _out_probs = paddle.nn.functional.moe_unpermute(
hidden_states_unzipped=ffn_out,
zipped_expertwise_rowmap=permute_indices_per_token,
expert_routemap_topk=topk_idx_i32,
token_prob_unzipped=dst_weights,
total_zipped_tokens=x.shape[0],
num_experts=layer.num_experts,
using_weighted_combine=True,
)
return fused_moe_out

if layer.topk_method == "noaux_tc":
gate_out, topk_weights, topk_idx = get_moe_scores(
gate_out,
Expand All @@ -286,6 +391,7 @@ def apply_tp(
layer.gate_correction_bias,
getattr(layer, "renormalize", True),
)

(
permute_input,
token_nums_per_expert,
Expand Down Expand Up @@ -340,7 +446,6 @@ def apply_tp(
expert_idx_per_token = None
else:
expert_idx_per_token = expert_idx_per_token.cast("int64")

ffn_out = self.compute_ffn(
layer,
permute_input,
Expand All @@ -362,7 +467,6 @@ def apply_tp(
norm_topk_prob=False if layer.topk_method == "noaux_tc" else True,
routed_scaling_factor=1.0,
)

return fused_moe_out


Expand Down
Loading
Loading