12 changes: 9 additions & 3 deletions custom_ops/xpu_ops/src/ops/moe_ep_dispatch.cc
@@ -57,7 +57,7 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
const int64_t ep_size = 1;
const int64_t ep_rank = 0;

if (std::is_same<TY, int8_t>::value) {
if (std::is_same<TY, int8_t>::value && !std::is_same<TX, int8_t>::value) {
permute_input =
paddle::empty({token_nums_this_rank, n}, paddle::DataType::INT8, place);
if (token_nums_this_rank > 0) {
@@ -99,7 +99,11 @@ std::vector<paddle::Tensor> EPMoeExpertDispatchKernel(
block_num,
ep_size,
ep_rank,
token_nums_this_rank);
token_nums_this_rank,
std::is_same<TX, int8_t>::value
? input_scales.get_ptr()->data<float>()
: nullptr,
expand_input_scales.data<float>());
PD_CHECK(ret == 0, "moe_ep_ffn_pre_sorted failed");
}
}
@@ -138,10 +142,12 @@ std::vector<paddle::Tensor> EPMoeExpertDispatch(
} else if (input_dtype == paddle::DataType::BFLOAT16 &&
quant_method != "w4a8") {
APPLY_KERNEL(paddle::bfloat16, paddle::bfloat16);
} else if (input_dtype == paddle::DataType::INT8) {
APPLY_KERNEL(int8_t, int8_t);
} else {
PD_THROW("EPMoeExpertDispatch not support input_dtype=",
static_cast<int>(input_dtype),
"quant_method=",
", quant_method=",
quant_method);
return {};
}
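The pre-sort path now forwards the caller-supplied per-token scales when the activations are already int8, and emits `expand_input_scales` for the permuted tokens. As a conceptual sketch only (the XPU kernel behind `moe_ep_ffn_pre_sorted` is not shown in this diff, so the gather below is an assumption about its behavior, and `permute_index` is an illustrative name), the scales simply follow their tokens through the permutation:

```python
import numpy as np

def gather_scales(input_scales: np.ndarray, permute_index: np.ndarray) -> np.ndarray:
    """Assumed behavior: expand_input_scales[i] = input_scales[permute_index[i]],
    i.e. each expanded/permuted token keeps the scale of its source token."""
    return input_scales[permute_index]
```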
244 changes: 154 additions & 90 deletions custom_ops/xpu_ops/src/ops/moe_expert_ffn.cc

Large diffs are not rendered by default.

5 changes: 5 additions & 0 deletions custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -143,6 +143,8 @@ std::vector<paddle::Tensor> WeightOnlyLinear(
const int arch,
const int group_size);

std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x);

std::vector<paddle::Tensor> MoeEPCombine(const paddle::Tensor& ffn_out,
const paddle::Tensor& moe_index,
const paddle::Tensor& weights,
@@ -1252,6 +1254,9 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("arch"),
py::arg("group_size") = -1);

m.def(
"quant2d_per_token", &Quant2dPerToken, py::arg("x"), "quant x per token");

m.def("xpu_moe_layer",
&MoeLayer,
py::arg("x"),
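With the binding above, the new op can be called directly from Python. A minimal usage sketch, assuming the extension builds into the `fastdeploy_ops` module declared by `PYBIND11_MODULE` and that Paddle is running on an XPU device (the import path and shapes here are illustrative):

```python
import paddle
from fastdeploy_ops import quant2d_per_token  # assumes the compiled module is on PYTHONPATH

paddle.set_device("xpu")                        # requires an XPU build of Paddle
x = paddle.randn([8, 4096]).astype("bfloat16")  # [num_tokens, hidden_size]
quant_x, x_scale = quant2d_per_token(x)
# quant_x: int8 tensor of shape [8, 4096]; x_scale: float32, one scale per token.
```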
88 changes: 88 additions & 0 deletions custom_ops/xpu_ops/src/ops/quant2d_per_token.cc
@@ -0,0 +1,88 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <core/check.h>
#include <core/context.h>
#include <core/param.h>
#include <infer_ops.h>
#include <xft_api.h>
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "utility/debug.h"
#include "utility/env.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

namespace xftblock = baidu::xpu::xftblock;
namespace api = baidu::xpu::api;

template <typename TX>
std::vector<paddle::Tensor> Quant2dPerTokenKernel(const paddle::Tensor& x) {
using XPU_TX = typename XPUTypeTrait<TX>::Type;
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
xftblock::XFTContext xctx(xpu_ctx->x_context(), nullptr);
auto rt_guard = xctx.get_rt_guard();

auto input_shape = x.shape();
auto x_scale =
paddle::empty({input_shape[0]}, paddle::DataType::FLOAT32, x.place());
auto quant_x = paddle::empty(
{input_shape[0], input_shape[1]}, paddle::DataType::INT8, x.place());
if (input_shape[0] > 0) {
int ret = infer_ops::quant2d_per_token<XPU_TX, float, int8_t>(
xpu_ctx->x_context(),
reinterpret_cast<const XPU_TX*>(x.data<TX>()),
nullptr,
reinterpret_cast<int8_t*>(quant_x.data<int8_t>()),
reinterpret_cast<float*>(x_scale.data<float>()),
input_shape[0],
input_shape[1]);
PD_CHECK(ret == api::SUCCESS);
}

return {quant_x, x_scale};
}

std::vector<paddle::Tensor> Quant2dPerToken(const paddle::Tensor& x) {
const auto x_type = x.dtype();
if (x_type == paddle::DataType::BFLOAT16) {
return Quant2dPerTokenKernel<paddle::bfloat16>(x);
} else if (x_type == paddle::DataType::FLOAT16) {
return Quant2dPerTokenKernel<paddle::float16>(x);
} else {
PD_THROW("Quant2dPerToken not support x_type=", static_cast<int>(x_type));
return {};
}
}

std::vector<std::vector<int64_t>> Quant2dPerTokenInferShape(
    const std::vector<int64_t>& x_shape) {
  // Two outputs: quant_x keeps x's shape, x_scale holds one value per token.
  return {x_shape, {x_shape[0]}};
}

std::vector<paddle::DataType> Quant2dPerTokenInferDtype(
    const paddle::DataType& x_dtype) {
  return {paddle::DataType::INT8, paddle::DataType::FLOAT32};
}

PD_BUILD_STATIC_OP(quant2d_per_token)
.Inputs({"x"})
.Outputs({"quant_x", "x_scale"})
.SetKernelFn(PD_KERNEL(Quant2dPerToken))
.SetInferShapeFn(PD_INFER_SHAPE(Quant2dPerTokenInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(Quant2dPerTokenInferDtype));
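The actual quantization math lives in `infer_ops::quant2d_per_token`, which is not part of this diff. As a rough reference only, symmetric per-token int8 quantization is commonly computed as below; this is an assumption about the kernel's behavior (rounding, clamping, and zero-row handling may differ):

```python
import numpy as np

def quant2d_per_token_reference(x: np.ndarray):
    """Assumed reference: one float32 scale per row (token), values mapped to int8."""
    x = x.astype(np.float32)
    amax = np.abs(x).max(axis=1)                  # per-token max magnitude
    scale = amax / 127.0                          # one scale per token
    safe = np.where(scale == 0.0, 1.0, scale)     # avoid division by zero for all-zero rows
    q = np.clip(np.rint(x / safe[:, None]), -127, 127).astype(np.int8)
    return q, scale.astype(np.float32)
```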
28 changes: 22 additions & 6 deletions fastdeploy/model_executor/layers/backends/xpu/moe/ep.py
@@ -352,9 +352,9 @@ def dispatch(
**kwargs,
):
self.num_combined_tokens = x.shape[0]
x_scale_tensor = kwargs.get("x_scale_tensor", None)
x_scale = kwargs.get("x_scale", None)
dispatch_args = {
"x": (x, x_scale_tensor) if x_scale_tensor is not None else x,
"x": (x, x_scale) if x_scale is not None else x,
"topk_idx": topk_idx,
"topk_weights": topk_weights,
}
@@ -428,11 +428,27 @@ def dispatch(
dispatch_hook,
valid_token_num,
) = self.ep_engine.low_latency_dispatch(x, topk_idx, expertwise_scale, use_fp8)
# no need to call dispatch_hook here, because it has already been done in xDeepEP
# if dispatch_hook is not None:
# dispatch_hook()
# valid_token_num is optional:
# - if valid_token_num is None, we cannot know the exact size of the
#   received tensor, but we avoid the extra overhead of interrupting
#   kernel launches.
# - if valid_token_num is not None, we know the exact size, but obtaining
#   it interrupts the kernel-launch flow.
if valid_token_num is None and dispatch_hook is not None:
dispatch_hook()

if valid_token_num is None:
valid_token_num = -1

if isinstance(recv_hidden_states, tuple):
recv_x = recv_hidden_states[0]
recv_x_scale = recv_hidden_states[1]
else:
recv_x = recv_hidden_states
recv_x_scale = None

return recv_hidden_states, recv_expert_count, handle, valid_token_num
return recv_x, recv_x_scale, recv_expert_count, handle, valid_token_num

def combine(self, ffn_out, topk_idx, topk_weights, handle):
combined_hidden_states, combine_hook = self.ep_engine.low_latency_combine(
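On the caller side, the dispatch above now returns the received activations and their optional per-token scales separately; a sketch of how a consumer might unpack the new five-element result (object and variable names outside this diff, such as `ep_runner`, are illustrative):

```python
# Illustrative consumer of the new return signature.
recv_x, recv_x_scale, recv_expert_count, handle, valid_token_num = ep_runner.dispatch(
    x, topk_idx, topk_weights
)
if recv_x_scale is not None:
    # Activations arrived pre-quantized (e.g. the low-latency FP8/int8 path), so
    # the per-token scales must be forwarded along with recv_x to the expert FFN.
    pass
```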