29 changes: 21 additions & 8 deletions custom_ops/xpu_ops/src/ops/block_attn.cc
@@ -89,8 +89,8 @@ std::vector<paddle::Tensor> BlockAttnKernel(
    const paddle::optional<paddle::Tensor>& smooth,
    const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
    const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
-   const std::string& pos_emb_type,
-   bool rope_3d) {
+   const bool use_neox_rotary_style,
+   const bool rope_3d) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  auto xpu_ctx = static_cast<const phi::XPUContext*>(dev_ctx);
@@ -134,12 +134,25 @@ std::vector<paddle::Tensor> BlockAttnKernel(
  int prefix_block_num_per_seq = len_info_cpu.data<int32_t>()[5];

  int rope_max_seqlen = 0;
  int rope_3d_num_seqs = 1;
+ int rope_head_dim = 0;
  if (rope_3d) {
+   PD_CHECK(rotary_embs.dims().size() == 6,
+            "rotary_embs dim size should be 6 in multi-modal model");
    rope_max_seqlen = rotary_embs.dims()[3];
    rope_3d_num_seqs = rotary_embs.dims()[0];
+   rope_head_dim = rotary_embs.dims()[5];
  } else {
+   PD_CHECK(rotary_embs.dims().size() == 5,
+            "rotary_embs dim size should be 5 in language model");
    rope_max_seqlen = rotary_embs.dims()[2];
+   rope_head_dim = rotary_embs.dims()[4];
  }
+ std::string pos_emb_type;
+ if (use_neox_rotary_style == true) {
+   pos_emb_type = "NEOX";
+ } else if (rope_head_dim == head_dim / 2) {
+   pos_emb_type = "HALF_HEAD_DIM";
+ } else {
+   pos_emb_type = "NORMAL";
+ }

  auto block_attn_out =
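
Review note: the string attribute is gone from the op interface; BlockAttnKernel now derives it from the bool flag plus the rotary-table width. A minimal Python sketch of the mapping implemented by the C++ above (the helper name is ours, not part of the op):

```python
def derive_pos_emb_type(use_neox_rotary_style: bool,
                        rope_head_dim: int,
                        head_dim: int) -> str:
    """Mirror of the new branch in BlockAttnKernel: NEOX wins outright;
    otherwise a half-width rotary table selects HALF_HEAD_DIM."""
    if use_neox_rotary_style:
        return "NEOX"
    if rope_head_dim == head_dim // 2:  # matches the C++ head_dim / 2
        return "HALF_HEAD_DIM"
    return "NORMAL"

# Illustrative values with head_dim = 128:
assert derive_pos_emb_type(True, 128, 128) == "NEOX"
assert derive_pos_emb_type(False, 64, 128) == "HALF_HEAD_DIM"
assert derive_pos_emb_type(False, 128, 128) == "NORMAL"
```
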
@@ -992,8 +1005,8 @@ std::vector<paddle::Tensor> BlockAttn(
    const paddle::optional<paddle::Tensor>& smooth,
    const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
    const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
-   const std::string& pos_emb_type = "NORMAL",
-   bool rope_3d = false) {
+   const bool use_neox_rotary_style,
+   const bool rope_3d = false) {
#define APPLY_KERNEL(TX, TC, TS) \
  return BlockAttnKernel<TX, TC, TS>(qkv, \
                                     key_cache, \
@@ -1021,7 +1034,7 @@ std::vector<paddle::Tensor> BlockAttn(
                                     smooth, \
                                     kv_signal_data_cpu, \
                                     cachekv_signal_thread_cpu, \
-                                    pos_emb_type, \
+                                    use_neox_rotary_style, \
                                     rope_3d);

  const auto cache_dtype = key_cache.dtype();
@@ -1087,7 +1100,7 @@ PD_BUILD_STATIC_OP(block_attn)
             paddle::Optional("smooth"),
             paddle::Optional("kv_signal_data_cpu"),
             paddle::Optional("cachekv_signal_thread_cpu")})
-   .Attrs({"pos_emb_type:std::string", "rope_3d:bool"})
+   .Attrs({"use_neox_rotary_style:bool", "rope_3d:bool"})
    .Outputs({"block_attn_out"})
    .SetKernelFn(PD_KERNEL(BlockAttn))
    .SetInferShapeFn(PD_INFER_SHAPE(BlockAttnInferShape))
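For reference, the rotary_embs layouts the new PD_CHECKs enforce, reduced to the indices the kernel actually reads (the dimension names are our labels; only the positions come from the code above):

```python
def rope_dims(shape, rope_3d):
    # Sketch: pick out the dims BlockAttnKernel reads from rotary_embs.
    if rope_3d:  # multi-modal model: 6-D table
        assert len(shape) == 6, "rotary_embs dim size should be 6 in multi-modal model"
        return {"rope_3d_num_seqs": shape[0],
                "rope_max_seqlen": shape[3],
                "rope_head_dim": shape[5]}
    # language model: 5-D table
    assert len(shape) == 5, "rotary_embs dim size should be 5 in language model"
    return {"rope_max_seqlen": shape[2],
            "rope_head_dim": shape[4]}
```
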
6 changes: 3 additions & 3 deletions custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -85,8 +85,8 @@ std::vector<paddle::Tensor> BlockAttn(
    const paddle::optional<paddle::Tensor>& smooth,
    const paddle::optional<paddle::Tensor>& kv_signal_data_cpu,
    const paddle::optional<paddle::Tensor>& cachekv_signal_thread_cpu,
-   const std::string& pos_emb_type = "NORMAL",
-   bool rope_3d = false);
+   const bool use_neox_rotary_style,
+   const bool rope_3d = false);

std::vector<paddle::Tensor> MoeLayer(
    const paddle::Tensor& x,
@@ -616,7 +616,7 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
        py::arg("smooth"),
        py::arg("kv_signal_data_cpu"),
        py::arg("cachekv_signal_thread_cpu"),
-       py::arg("pos_emb_type") = "NORMAL",
+       py::arg("use_neox_rotary_style"),
        py::arg("rope_3d") = false,
        "block attention in XPU");

2 changes: 0 additions & 2 deletions fastdeploy/model_executor/forward_meta.py
@@ -253,8 +253,6 @@ class XPUForwardMeta(ForwardMeta):
    dec_batch: Optional[paddle.Tensor] = None
    #
    total_enc_len: Optional[paddle.Tensor] = None
-   # position embedding type in rope, supports 'NORMAL' or 'HALF_HEAD_DIM'
-   pos_emb_type: Optional[str] = "NORMAL"
    # for pd_disaggregation
    kv_signal_sender: Optional[paddle.Tensor] = None

@@ -213,7 +213,7 @@ def forward_mixed(
            None,  # smooth
            metadata.kv_signal_data_list[layer.layer_id],  # kv_signal_data
            forward_meta.kv_signal_sender,  # kv_signal_sender
-           forward_meta.pos_emb_type,
+           layer.use_neox_rotary_style,
            self.rope_3d,
        )

1 change: 0 additions & 1 deletion fastdeploy/spec_decode/mtp.py
@@ -726,7 +726,6 @@ def _initialize_forward_meta_xpu(self):
        self.forward_meta.kv_batch_ids = (self.model_inputs["kv_batch_ids"],)
        self.forward_meta.kv_tile_ids_per_batch = (self.model_inputs["kv_tile_ids_per_batch"],)
        self.forward_meta.kv_num_blocks_x_cpu = (self.model_inputs["kv_num_blocks_x_cpu"],)
-       self.forward_meta.pos_emb_type = "NORMAL"
        self.forward_meta.attn_backend = self.attn_backends[0]

        # Initialize attention meta data
4 changes: 0 additions & 4 deletions fastdeploy/worker/xpu_model_runner.py
@@ -822,10 +822,8 @@ def _init_share_inputs(self, max_num_seqs: int):
        head_dim = self.model_config.head_dim
        if "paddleocr" in self.model_config.model_type:  # neox style = True
            rope_head_dim = head_dim
-           self.share_inputs["pos_emb_type"] = "NEOX"
        else:  # neox style = False
            rope_head_dim = head_dim // 2
-           self.share_inputs["pos_emb_type"] = "HALF_HEAD_DIM"

        self.share_inputs["rope_emb"] = paddle.full(
            shape=[
@@ -918,8 +916,6 @@ def _prepare_inputs(self, is_dummy_run=False) -> None:
        # Update bad tokens len
        max_bad_tokens_len = paddle.max(self.share_inputs["bad_tokens_len"])

-       if self.enable_mm:
-           self.forward_meta.pos_emb_type = self.share_inputs["pos_emb_type"]
        self.forward_meta.attn_backend = self.attn_backends[0]
        self.initialize_attention_backend()
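
Taken together with the kernel-side mapping sketched earlier, the host no longer needs to ship a pos_emb_type string: the width of the rope table it allocates is enough for the kernel to recover the old behavior. A hedged end-to-end check (model_type values are illustrative):

```python
def host_rope_head_dim(model_type: str, head_dim: int) -> int:
    # Mirrors _init_share_inputs above: paddleocr models use neox-style rope.
    return head_dim if "paddleocr" in model_type else head_dim // 2

assert host_rope_head_dim("paddleocr", 128) == 128  # full width; kernel sees NEOX
assert host_rope_head_dim("ernie", 128) == 64       # half width; kernel infers HALF_HEAD_DIM
```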
