80 changes: 80 additions & 0 deletions custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
@@ -0,0 +1,80 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>
#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"

#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif

void SpeculateGetLogits(const paddle::Tensor& draft_logits,
const paddle::Tensor& next_token_num,
const paddle::Tensor& batch_token_num,
const paddle::Tensor& cu_next_token_offset,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& logits,
const paddle::Tensor& first_token_logits,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder) {
phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
baidu::xpu::api::Context* ctx =
static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (draft_logits.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
Collaborator: When does this context get destroyed? Could it cause a memory leak?

Collaborator: The CPU path is normally only used for unit-test verification. Apart from this spot, the other operators probably don't release this CPU ctx either; we may need a unified audit of this later.
}
const int vocab_size = logits.shape()[1];
const int real_bsz = seq_lens_this_time.shape()[0];

baidu::xpu::api::plugin::speculate_get_logits(
ctx,
const_cast<float*>(draft_logits.data<float>()),
const_cast<int*>(next_token_num.data<int>()),
const_cast<int*>(batch_token_num.data<int>()),
const_cast<int*>(cu_next_token_offset.data<int>()),
const_cast<int*>(cu_batch_token_offset.data<int>()),
logits.data<float>(),
first_token_logits.data<float>(),
seq_lens_this_time.data<int>(),
seq_lens_encoder.data<int>(),
real_bsz,
vocab_size);
if (draft_logits.is_cpu()) {
delete ctx;
}
}

PD_BUILD_STATIC_OP(speculate_get_logits)
.Inputs({"draft_logits",
"next_token_num",
"batch_token_num",
"cu_next_token_offset",
"cu_batch_token_offset",
"logits",
"first_token_logits",
"seq_lens_this_time",
"seq_lens_encoder"})
.Outputs({"draft_logits_out",
"batch_token_num_out",
"cu_batch_token_offset_out"})
.SetInplaceMap({{"draft_logits", "draft_logits_out"},
{"batch_token_num", "batch_token_num_out"},
{"cu_batch_token_offset", "cu_batch_token_offset_out"}})
.SetKernelFn(PD_KERNEL(SpeculateGetLogits));
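On the memory-leak question raised in the review thread above: one way to make the CPU fallback leak-free on every return path is to hold the temporary context in a smart pointer instead of a raw new/delete pair. A minimal sketch of that pattern, assuming baidu::xpu::api::Context can be owned by std::unique_ptr (illustration only, not the code merged in this PR; requires <memory>):

#include <memory>

// owned_ctx destroys the temporary CPU context automatically when it
// goes out of scope, so the trailing `delete ctx` is no longer needed.
std::unique_ptr<baidu::xpu::api::Context> owned_ctx;
baidu::xpu::api::Context* ctx =
    static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (draft_logits.is_cpu()) {
  owned_ctx =
      std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
  ctx = owned_ctx.get();
}
// ... invoke baidu::xpu::api::plugin::speculate_get_logits(ctx, ...) as before.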
23 changes: 23 additions & 0 deletions custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
const int encoder_decoder_block_num,
const int max_draft_tokens);

void SpeculateGetLogits(const paddle::Tensor& draft_logits,
const paddle::Tensor& next_token_num,
const paddle::Tensor& batch_token_num,
const paddle::Tensor& cu_next_token_offset,
const paddle::Tensor& cu_batch_token_offset,
const paddle::Tensor& logits,
const paddle::Tensor& first_token_logits,
const paddle::Tensor& seq_lens_this_time,
const paddle::Tensor& seq_lens_encoder);

void SaveOutMmsgStatic(const paddle::Tensor& x,
const paddle::Tensor& not_need_stop,
int64_t rank_id,
@@ -1174,6 +1184,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
py::arg("max_draft_tokens"),
"Step paddle function");

m.def("speculate_get_logits",
&SpeculateGetLogits,
py::arg("draft_logits"),
py::arg("next_token_num"),
py::arg("batch_token_num"),
py::arg("cu_next_token_offset"),
py::arg("cu_batch_token_offset"),
py::arg("logits"),
py::arg("first_token_logits"),
py::arg("seq_lens_this_time"),
py::arg("seq_lens_encoder"),
"speculate get logits function");

m.def("text_image_gather_scatter",
&TextImageGatherScatter,
py::arg("input"),
13 changes: 13 additions & 0 deletions custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
@@ -600,6 +600,19 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
T* output,
int dim_embed,
int elem_cnt);

DLL_EXPORT int speculate_get_logits(Context* ctx,
float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size);
/*--------------------------------------- MTP end
* --------------------------------------------*/

@@ -0,0 +1,131 @@
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"

namespace xpu3 {
namespace plugin {

__device__ void prefix_sum(__shared_ptr__ int* sm_seq_lens_encoder,
__shared_ptr__ int* sm_seq_lens_this_time,
__shared_ptr__ int* sm_batch_token_num,
__shared_ptr__ int* sm_cu_batch_token_offset,
__shared_ptr__ int* sm_cu_next_token_offset,
__global_ptr__ int* batch_token_num,
__global_ptr__ int* cu_batch_token_offset,
__global_ptr__ const int* seq_lens_this_time,
__global_ptr__ const int* seq_lens_encoder,
const int real_bsz) {
int cid = core_id();
int clus_id = cluster_id();

if (clus_id < real_bsz && cid == 0) {
GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
int next_token_num_previous = 0;
for (int bid = 0; bid < real_bsz; bid++) {
sm_batch_token_num[bid] =
sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
if (bid == 0) {
sm_cu_batch_token_offset[bid] = 0;
sm_cu_next_token_offset[bid] = 0;
} else {
sm_cu_batch_token_offset[bid] =
sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
sm_cu_next_token_offset[bid] =
sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
}
next_token_num_previous =
sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
}
mfence_sm();
if (clus_id == 0) {
SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
SM2GM_ASYNC(sm_cu_batch_token_offset,
cu_batch_token_offset,
real_bsz * sizeof(int));
}
}
Comment on lines +21 to +47

Collaborator: In this block, lines 21-40 run on every cluster, while lines 41-46 run only on cluster 0. Doesn't that make the work in lines 21-46 effectively useful only when clus_id == 0?

Contributor (author): All clusters that take part in the computation need the prefix-sum results, so in lines 21-40 every cluster with id < real_bsz computes its own copy; writing back to GM only needs a single cluster, which is why lines 41-46 are limited to cluster 0.

Collaborator: OK, understood.

mfence_sm();
sync_all();
}
__global__ void speculate_get_logits(float* draft_logits,
int* next_token_num,
int* batch_token_num,
int* cu_next_token_offset,
int* cu_batch_token_offset,
const float* logits,
const float* first_token_logits,
const int* seq_lens_this_time,
const int* seq_lens_encoder,
const int real_bsz,
const int vocab_size) {
int cid = core_id();
int ncores = core_num();
int clus_id = cluster_id();
int nclusters = cluster_num();

// Per-core local-memory staging buffers (2 KB each).
const int lm_size = 2 * 1024;
const int lm_buf_len = lm_size / sizeof(float);
float first_token_logits_now_lm[lm_buf_len];
float logits_now_lm[lm_buf_len];
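// The 256 KB shared memory below is carved into five equal int buffers:
// seq_lens_encoder, seq_lens_this_time, batch_token_num,
// cu_batch_token_offset, and cu_next_token_offset.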

const int sm_size = 256 * 1024;
__shared__ char sm[sm_size];
const int sm_max_buf_len = sm_size / sizeof(int) / 5;
__shared_ptr__ int* sm_seq_lens_encoder = (__shared_ptr__ int*)sm;
__shared_ptr__ int* sm_seq_lens_this_time =
sm_seq_lens_encoder + sm_max_buf_len;
__shared_ptr__ int* sm_batch_token_num =
sm_seq_lens_this_time + sm_max_buf_len;
__shared_ptr__ int* sm_cu_batch_token_offset =
sm_batch_token_num + sm_max_buf_len;
__shared_ptr__ int* sm_cu_next_token_offset =
sm_cu_batch_token_offset + sm_max_buf_len;

prefix_sum(sm_seq_lens_encoder,
sm_seq_lens_this_time,
sm_batch_token_num,
sm_cu_batch_token_offset,
sm_cu_next_token_offset,
batch_token_num,
cu_batch_token_offset,
seq_lens_this_time,
seq_lens_encoder,
real_bsz);

for (int bid = clus_id; bid < real_bsz; bid += nclusters) {
auto* draft_logits_now =
draft_logits + sm_cu_batch_token_offset[bid] * vocab_size;
auto* logits_now = logits + sm_cu_next_token_offset[bid] * vocab_size;
auto* first_token_logits_now = first_token_logits + bid * vocab_size;
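// Prefill (encoder) batches write two rows per batch: first_token_logits
// into draft-logit row 0 and the single next-token row of logits into
// row 1. Decode batches copy all seq_lens_this_time rows from logits
// unchanged. Each core streams vocab_size floats through its LM tile.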

for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len) {
int read_len = min(lm_buf_len, vocab_size - i);
if (sm_seq_lens_encoder[bid] > 0) {
GM2LM_ASYNC(first_token_logits_now + i,
first_token_logits_now_lm,
read_len * sizeof(float));
GM2LM(logits_now + i, logits_now_lm, read_len * sizeof(float));
LM2GM_ASYNC(first_token_logits_now_lm,
draft_logits_now + i,
read_len * sizeof(float));
LM2GM(logits_now_lm,
draft_logits_now + vocab_size + i,
read_len * sizeof(float));
} else {
for (int j = 0; j < sm_seq_lens_this_time[bid]; j++) {
GM2LM(logits_now + j * vocab_size + i,
logits_now_lm,
read_len * sizeof(float));
LM2GM(logits_now_lm,
draft_logits_now + j * vocab_size + i,
read_len * sizeof(float));
}
}
}
}
}

} // namespace plugin
} // namespace xpu3
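As a companion to the prefix_sum discussion above: the per-batch token counts and cumulative offsets reduce to two exclusive prefix sums over simple per-batch rules. A host-side C++ sketch of the same arithmetic (a hypothetical standalone helper for illustration, not part of this PR):

#include <vector>

// Mirrors prefix_sum() above: a prefill (encoder) batch contributes two
// draft-logit rows but only one "next token" row; a decode batch
// contributes seq_lens_this_time[bid] rows in both layouts.
void compute_token_offsets(const std::vector<int>& seq_lens_this_time,
                           const std::vector<int>& seq_lens_encoder,
                           std::vector<int>& batch_token_num,
                           std::vector<int>& cu_batch_token_offset,
                           std::vector<int>& cu_next_token_offset) {
  const int real_bsz = static_cast<int>(seq_lens_this_time.size());
  int batch_acc = 0;  // running total of batch_token_num (exclusive scan)
  int next_acc = 0;   // running total of next_token_num (exclusive scan)
  for (int bid = 0; bid < real_bsz; ++bid) {
    const bool is_prefill = seq_lens_encoder[bid] > 0;
    batch_token_num[bid] = is_prefill ? 2 : seq_lens_this_time[bid];
    cu_batch_token_offset[bid] = batch_acc;
    cu_next_token_offset[bid] = next_acc;
    batch_acc += batch_token_num[bid];
    next_acc += is_prefill ? 1 : seq_lens_this_time[bid];
  }
}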