Skip to content

Commit 489d34b

Browse files
committed
[XPU] refactor moe ffn
- remove BKCL_DISPATCH_ALL_GATHER
- support sparse mode
- support moe quant_method
1 parent 438c9f7 commit 489d34b

File tree

12 files changed

+402
-131
lines changed

12 files changed

+402
-131
lines changed

custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
5757
const int64_t ep_size = 1;
5858
const int64_t ep_rank = 0;
5959

60-
if (std::is_same<TY, int8_t>::value) {
60+
if (std::is_same<TY, int8_t>::value && !std::is_same<TX, int8_t>::value) {
6161
permute_input =
6262
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
6363
if (token_nums_this_rank > 0) {
@@ -99,7 +99,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
9999
block_num,
100100
ep_size,
101101
ep_rank,
102-
token_nums_this_rank);
102+
token_nums_this_rank,
103+
std::is_same<TX, int8_t>::value
104+
? input_scales.get_ptr()->data<float>()
105+
: nullptr,
106+
expand_input_scales.data<float>());
103107
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
104108
}
105109
}
@@ -138,10 +142,12 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
138142
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
139143
quant_method != "w4a8") {
140144
APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
145+
} else if (input_dtype == paddle::DataType::INT8) {
146+
APPLY_KERNEL(int8_t, int8_t);
141147
} else {
142148
PD_THROW("EPMoeExpertDispatch not support input_dtype=",
143149
static_cast<int>(input_dtype),
144-
"quant_method=",
150+
", quant_method=",
145151
quant_method);
146152
return {};
147153
}

custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc

Lines changed: 154 additions & 90 deletions
Large diffs are not rendered by default.

custom_ops/xpu_ops/src/ops/pybind/pybind.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ std::vector<paddle::Tensor> WeightOnlyLinear(
143143
const int arch,
144144
const int group_size);
145145

146+
std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x);
147+
146148
std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
147149
const paddle::Tensor& moe_index,
148150
const paddle::Tensor& weights,
@@ -1252,6 +1254,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
12521254
py::arg("arch"),
12531255
py::arg("group_size") = -1);
12541256

1257+
m.def(
1258+
"quant2d_per_token", &Quant2dPerToken, py::arg("x"), "quant x per token");
1259+
12551260
m.def("xpu_moe_layer",
12561261
&MoeLayer,
12571262
py::arg("x"),
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <core/check.h>
16+
#include <core/context.h>
17+
#include <core/param.h>
18+
#include <infer_ops.h>
19+
#include <xft_api.h>
20+
#include "paddle/extension.h"
21+
#include "paddle/phi/backends/xpu/enforce_xpu.h"
22+
#include "utility/debug.h"
23+
#include "utility/env.h"
24+
25+
#ifndef PD_BUILD_STATIC_OP
26+
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
27+
#endif
28+
29+
namespace xftblock = baidu::xpu::xftblock;
30+
namespace api = baidu::xpu::api;
31+
32+
template <typename TX>
33+
std::vector<paddle::Tensor> Quant2dPerTokenKernel(const paddle::Tensor& x) {
34+
using XPU_TX = typename XPUTypeTrait<TX>::Type;
35+
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
36+
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
37+
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
38+
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
39+
auto rt_guard = xctx.get_rt_guard();
40+
41+
auto input_shape = x.shape();
42+
auto x_scale =
43+
paddle::empty({input_shape[0]}, paddle::DataType::FLOAT32, x.place());
44+
auto quant_x = paddle::empty(
45+
{input_shape[0], input_shape[1]}, paddle::DataType::INT8, x.place());
46+
if (input_shape[0] > 0) {
47+
int ret = infer_ops::quant2d_per_token<XPU_TX, float, int8_t>(
48+
xpu_ctx->x_context(),
49+
reinterpret_cast<const XPU_TX*>(x.data<TX>()),
50+
nullptr,
51+
reinterpret_cast<int8_t*>(quant_x.data<int8_t>()),
52+
reinterpret_cast<float*>(x_scale.data<float>()),
53+
input_shape[0],
54+
input_shape[1]);
55+
PD_CHECK(ret == api::SUCCESS);
56+
}
57+
58+
return {quant_x, x_scale};
59+
}
60+
61+
// Dtype dispatcher for the `quant2d_per_token` op: routes bf16/fp16 inputs to
// the templated kernel, rejects every other dtype.
std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x) {
  const auto dtype = x.dtype();
  switch (dtype) {
    case paddle::DataType::BFLOAT16:
      return Quant2dPerTokenKernel<paddle::bfloat16>(x);
    case paddle::DataType::FLOAT16:
      return Quant2dPerTokenKernel<paddle::float16>(x);
    default:
      PD_THROW("Quant2dPerToken not support x_type=", static_cast<int>(dtype));
      return {};
  }
}
72+
73+
// Infer output shapes for quant2d_per_token.
//
// The op registers TWO outputs ("quant_x", "x_scale"), so this must return
// one shape per output (the original returned only {x_shape}):
//   quant_x  - same shape as x
//   x_scale  - one scale per token, i.e. [x_shape[0]]
std::vector<std::vector<int64_t>> Quant2dPerTokenInferShape(
    const std::vector<int64_t>& x_shape) {
  // Guard against a rank-0 shape so we never index an empty vector; the
  // kernel itself expects a 2-D input.
  const int64_t num_tokens = x_shape.empty() ? 0 : x_shape[0];
  return {x_shape, {num_tokens}};
}
77+
78+
std::vector<paddle::DataType> Quant2dPerTokenInferDtype(
79+
const paddle::DataType& x_dtype) {
80+
return {paddle::DataType::INT8};
81+
}
82+
83+
// Register `quant2d_per_token` as a Paddle custom op: one input tensor `x`,
// two outputs (`quant_x`: int8 quantized data, `x_scale`: per-row float
// scales), wired to the dtype dispatcher and its shape/dtype inference fns.
PD_BUILD_STATIC_OP(quant2d_per_token)
    .Inputs({"x"})
    .Outputs({"quant_x", "x_scale"})
    .SetKernelFn(PD_KERNEL(Quant2dPerToken))
    .SetInferShapeFn(PD_INFER_SHAPE(Quant2dPerTokenInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(Quant2dPerTokenInferDtype));

fastdeploy/model_executor/layers/backends/xpu/moe/ep.py

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -352,9 +352,9 @@ def dispatch(
352352
**kwargs,
353353
):
354354
self.num_combined_tokens = x.shape[0]
355-
x_scale_tensor = kwargs.get("x_scale_tensor", None)
355+
x_scale = kwargs.get("x_scale", None)
356356
dispatch_args = {
357-
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
357+
"x": (x, x_scale) if x_scale is not None else x,
358358
"topk_idx": topk_idx,
359359
"topk_weights": topk_weights,
360360
}
@@ -428,11 +428,27 @@ def dispatch(
428428
dispatch_hook,
429429
valid_token_num,
430430
) = self.ep_engine.low_latency_dispatch(x, topk_idx, expertwise_scale, use_fp8)
431-
# no need to call dispatch_hook here, because it has already been done in xDeepEP
432-
# if dispatch_hook is not None:
433-
# dispatch_hook()
431+
# valid_token_num is optional:
432+
# - if valid_token_num is None, it means that we CANNOT accurately know
433+
# the size of the tensor, but the advantage is that it can reduce
434+
# the overhead of kernel launch.
435+
# - if valid_token_num is NOT None, it means that we CAN accurately know
436+
# the size of the tensor, but the disadvantage is that it will interrupt
437+
# the process of kernel launch.
438+
if valid_token_num is None and dispatch_hook is not None:
439+
dispatch_hook()
440+
441+
if valid_token_num is None:
442+
valid_token_num = -1
443+
444+
if isinstance(recv_hidden_states, tuple):
445+
recv_x = recv_hidden_states[0]
446+
recv_x_scale = recv_hidden_states[1]
447+
else:
448+
recv_x = recv_hidden_states
449+
recv_x_scale = None
434450

435-
return recv_hidden_states, recv_expert_count, handle, valid_token_num
451+
return recv_x, recv_x_scale, recv_expert_count, handle, valid_token_num
436452

437453
def combine(self, ffn_out, topk_idx, topk_weights, handle):
438454
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(

0 commit comments

Comments
 (0)