Skip to content

Commit 12c76f8

Browse files
RuohengMacmcamdyEmmonsCurse
authored
[XPU] add speculate_get_logits (#5497)
* [XPU] add speculate_step_system_cache

* [XPU] add speculate_step_system_cache

* [XPU] add speculate_get_logits

* delete context

* add ptr check

---------

Co-authored-by: cmcamdy <[email protected]>
Co-authored-by: YuBaoku <[email protected]>
1 parent 888c4b9 commit 12c76f8

File tree

6 files changed

+603
-0
lines changed

6 files changed

+603
-0
lines changed
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>

#include <memory>

#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
22+
23+
// Fallback for Paddle versions that do not provide PD_BUILD_STATIC_OP:
// register the op through the plain PD_BUILD_OP macro under a
// `static_op_`-prefixed name.
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
26+
27+
// Host-side launcher for the XPU `speculate_get_logits` plugin kernel used in
// speculative decoding. Gathers per-batch rows of `logits` (and, for prefill
// batches, `first_token_logits`) into `draft_logits`, and refreshes
// `batch_token_num` / `cu_batch_token_offset` in place (see the inplace map
// in the op registration).
//
// Grounded shapes: vocab_size is taken from logits.shape()[1] and the real
// batch size from seq_lens_this_time.shape()[0]; other tensor shapes are
// defined by the plugin kernel's contract.
void SpeculateGetLogits(const paddle::Tensor& draft_logits,
                        const paddle::Tensor& next_token_num,
                        const paddle::Tensor& batch_token_num,
                        const paddle::Tensor& cu_next_token_offset,
                        const paddle::Tensor& cu_batch_token_offset,
                        const paddle::Tensor& logits,
                        const paddle::Tensor& first_token_logits,
                        const paddle::Tensor& seq_lens_this_time,
                        const paddle::Tensor& seq_lens_encoder) {
  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
  baidu::xpu::api::Context* ctx =
      static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
  // CPU fallback: own the temporary context with RAII so it is released on
  // every path. The original raw `new`/`delete` pair leaked the context if
  // anything between them threw.
  std::unique_ptr<baidu::xpu::api::Context> cpu_ctx;
  if (draft_logits.is_cpu()) {
    cpu_ctx =
        std::make_unique<baidu::xpu::api::Context>(baidu::xpu::api::kCPU);
    ctx = cpu_ctx.get();
  }
  const int vocab_size = logits.shape()[1];
  const int real_bsz = seq_lens_this_time.shape()[0];

  // const_casts: the plugin takes mutable pointers for its in-place outputs.
  int ret = baidu::xpu::api::plugin::speculate_get_logits(
      ctx,
      const_cast<float*>(draft_logits.data<float>()),
      const_cast<int*>(next_token_num.data<int>()),
      const_cast<int*>(batch_token_num.data<int>()),
      const_cast<int*>(cu_next_token_offset.data<int>()),
      const_cast<int*>(cu_batch_token_offset.data<int>()),
      logits.data<float>(),
      first_token_logits.data<float>(),
      seq_lens_this_time.data<int>(),
      seq_lens_encoder.data<int>(),
      real_bsz,
      vocab_size);
  // The plugin reports failure through its return code (see plugin.h); the
  // original code silently discarded it.
  PADDLE_ENFORCE_XDNN_SUCCESS(ret, "speculate_get_logits");
}
63+
64+
// Registers `speculate_get_logits` as a Paddle custom op.
// draft_logits, batch_token_num and cu_batch_token_offset are modified in
// place: each is aliased to its corresponding `*_out` output via the
// inplace map, so the op has no freshly-allocated outputs.
PD_BUILD_STATIC_OP(speculate_get_logits)
    .Inputs({"draft_logits",
             "next_token_num",
             "batch_token_num",
             "cu_next_token_offset",
             "cu_batch_token_offset",
             "logits",
             "first_token_logits",
             "seq_lens_this_time",
             "seq_lens_encoder"})
    .Outputs({"draft_logits_out",
              "batch_token_num_out",
              "cu_batch_token_offset_out"})
    .SetInplaceMap({{"draft_logits", "draft_logits_out"},
                    {"batch_token_num", "batch_token_num_out"},
                    {"cu_batch_token_offset", "cu_batch_token_offset_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetLogits));

custom_ops/xpu_ops/src/ops/pybind/pybind.cc

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
470470
const int encoder_decoder_block_num,
471471
const int max_draft_tokens);
472472

473+
// Forward declaration of the XPU speculative-decoding logits-gather op; the
// definition lives in the speculate_get_logits op source file.
void SpeculateGetLogits(const paddle::Tensor& draft_logits,
                        const paddle::Tensor& next_token_num,
                        const paddle::Tensor& batch_token_num,
                        const paddle::Tensor& cu_next_token_offset,
                        const paddle::Tensor& cu_batch_token_offset,
                        const paddle::Tensor& logits,
                        const paddle::Tensor& first_token_logits,
                        const paddle::Tensor& seq_lens_this_time,
                        const paddle::Tensor& seq_lens_encoder);
482+
473483
void SaveOutMmsgStatic(const paddle::Tensor& x,
474484
const paddle::Tensor& not_need_stop,
475485
int64_t rank_id,
@@ -1174,6 +1184,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
11741184
py::arg("max_draft_tokens"),
11751185
"Step paddle function");
11761186

1187+
  // Python binding for the speculative-decoding logits-gather op (see
  // SpeculateGetLogits); argument names mirror the C++ parameter order.
  m.def("speculate_get_logits",
        &SpeculateGetLogits,
        py::arg("draft_logits"),
        py::arg("next_token_num"),
        py::arg("batch_token_num"),
        py::arg("cu_next_token_offset"),
        py::arg("cu_batch_token_offset"),
        py::arg("logits"),
        py::arg("first_token_logits"),
        py::arg("seq_lens_this_time"),
        py::arg("seq_lens_encoder"),
        "speculate get logits function");
1199+
11771200
m.def("text_image_gather_scatter",
11781201
&TextImageGatherScatter,
11791202
py::arg("input"),

custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -600,6 +600,19 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
600600
T* output,
601601
int dim_embed,
602602
int elem_cnt);
603+
604+
// Gathers per-batch rows of `logits` / `first_token_logits` into
// `draft_logits` for speculative decoding and recomputes batch_token_num /
// cu_batch_token_offset on device. Returns an XDNN error code (callers
// should check it). Non-const pointers are the in-place outputs.
DLL_EXPORT int speculate_get_logits(Context* ctx,
                                    float* draft_logits,
                                    int* next_token_num,
                                    int* batch_token_num,
                                    int* cu_next_token_offset,
                                    int* cu_batch_token_offset,
                                    const float* logits,
                                    const float* first_token_logits,
                                    const int* seq_lens_this_time,
                                    const int* seq_lens_encoder,
                                    const int real_bsz,
                                    const int vocab_size);
603616
/*--------------------------------------- MTP end
604617
* --------------------------------------------*/
605618

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
#include "xpu/kernel/cluster.h"
2+
#include "xpu/kernel/cluster_partition.h"
3+
#include "xpu/kernel/cluster_primitive.h"
4+
5+
namespace xpu3 {
namespace plugin {

// Per-cluster prefix-sum over per-batch token counts, run by core 0 of every
// cluster whose id is < real_bsz. For each batch bid:
//   sm_batch_token_num[bid]      = 2 when the batch is in the prefill phase
//                                  (seq_lens_encoder[bid] > 0), otherwise
//                                  seq_lens_this_time[bid];
//   sm_cu_batch_token_offset     = exclusive prefix sum of the above;
//   sm_cu_next_token_offset      = exclusive prefix sum of the batch's row
//                                  count in `logits` (1 for prefill batches,
//                                  else seq_lens_this_time[bid]).
// Only cluster 0 copies batch_token_num and cu_batch_token_offset back to
// global memory; every participating cluster keeps its own shared-memory
// copy for use by the caller.
// NOTE(review): the global `cu_next_token_offset` (and the kernel's
// `next_token_num`) buffers are never written here, although the host passes
// them as mutable pointers — confirm this is intended.
__device__ void prefix_sum(__shared_ptr__ int* sm_seq_lens_encoder,
                           __shared_ptr__ int* sm_seq_lens_this_time,
                           __shared_ptr__ int* sm_batch_token_num,
                           __shared_ptr__ int* sm_cu_batch_token_offset,
                           __shared_ptr__ int* sm_cu_next_token_offset,
                           __global_ptr__ int* batch_token_num,
                           __global_ptr__ int* cu_batch_token_offset,
                           __global_ptr__ const int* seq_lens_this_time,
                           __global_ptr__ const int* seq_lens_encoder,
                           const int real_bsz) {
  int cid = core_id();
  int clus_id = cluster_id();

  if (clus_id < real_bsz && cid == 0) {
    // Stage both length arrays into shared memory (first copy may overlap
    // with the second, blocking one).
    GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
    GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
    int next_token_num_previous = 0;
    for (int bid = 0; bid < real_bsz; bid++) {
      sm_batch_token_num[bid] =
          sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
      if (bid == 0) {
        sm_cu_batch_token_offset[bid] = 0;
        sm_cu_next_token_offset[bid] = 0;
      } else {
        sm_cu_batch_token_offset[bid] =
            sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
        sm_cu_next_token_offset[bid] =
            sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
      }
      // Row count contributed to `logits` by this batch (prefill emits one
      // next-token row; decode emits seq_lens_this_time rows).
      next_token_num_previous =
          sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
    }
    mfence_sm();
    // Only one cluster publishes the results to global memory.
    if (clus_id == 0) {
      SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
      SM2GM_ASYNC(sm_cu_batch_token_offset,
                  cu_batch_token_offset,
                  real_bsz * sizeof(int));
    }
  }
  mfence_sm();
  sync_all();
}
// Copies logits rows into the packed draft_logits layout for speculative
// decoding. Clusters stride over batches; within a batch, cores stride over
// the vocab dimension in 2 KB local-memory tiles.
//  * Prefill batch (seq_lens_encoder[bid] > 0): row 0 of the batch's slice
//    of draft_logits comes from first_token_logits[bid], row 1 from the
//    batch's single row of `logits`.
//  * Decode batch: seq_lens_this_time[bid] rows are copied straight through
//    from `logits`.
// NOTE(review): `next_token_num` and `cu_next_token_offset` parameters are
// unused by this kernel body (offsets are recomputed in shared memory by
// prefix_sum) — confirm intended.
__global__ void speculate_get_logits(float* draft_logits,
                                     int* next_token_num,
                                     int* batch_token_num,
                                     int* cu_next_token_offset,
                                     int* cu_batch_token_offset,
                                     const float* logits,
                                     const float* first_token_logits,
                                     const int* seq_lens_this_time,
                                     const int* seq_lens_encoder,
                                     const int real_bsz,
                                     const int vocab_size) {
  int cid = core_id();
  int ncores = core_num();
  int clus_id = cluster_id();
  int nclusters = cluster_num();

  // Two 2 KB local-memory staging buffers (512 floats each).
  int lm_size = 2 * 1024;
  int lm_buf_len = lm_size / sizeof(float);
  float first_token_logits_now_lm[lm_buf_len];
  float logits_now_lm[lm_buf_len];

  // Carve the 256 KB shared memory into five equal int arrays; each must be
  // able to hold real_bsz ints.
  const int sm_size = 256 * 1024;
  __shared__ char sm[sm_size];
  int sm_max_buf_len = 256 * 1024 / sizeof(int);
  sm_max_buf_len /= 5;
  __shared_ptr__ int* sm_seq_lens_encoder = (__shared_ptr__ int*)sm;
  __shared_ptr__ int* sm_seq_lens_this_time =
      sm_seq_lens_encoder + sm_max_buf_len;
  __shared_ptr__ int* sm_batch_token_num =
      sm_seq_lens_this_time + sm_max_buf_len;
  __shared_ptr__ int* sm_cu_batch_token_offset =
      sm_batch_token_num + sm_max_buf_len;
  __shared_ptr__ int* sm_cu_next_token_offset =
      sm_cu_batch_token_offset + sm_max_buf_len;

  // Fill the shared-memory offset tables (and publish two of them to GM).
  prefix_sum(sm_seq_lens_encoder,
             sm_seq_lens_this_time,
             sm_batch_token_num,
             sm_cu_batch_token_offset,
             sm_cu_next_token_offset,
             batch_token_num,
             cu_batch_token_offset,
             seq_lens_this_time,
             seq_lens_encoder,
             real_bsz);

  for (int bid = clus_id; bid < real_bsz; bid += nclusters) {
    // Destination slice for this batch, and its source rows in `logits`.
    auto* draft_logits_now =
        draft_logits + sm_cu_batch_token_offset[bid] * vocab_size;
    auto* logits_now = logits + sm_cu_next_token_offset[bid] * vocab_size;
    auto* first_token_logits_now = first_token_logits + bid * vocab_size;

    // Tile the vocab dimension across cores; read_len handles the tail tile.
    for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len) {
      int read_len = min(lm_buf_len, vocab_size - i);
      if (sm_seq_lens_encoder[bid] > 0) {
        // Prefill: row 0 <- first_token_logits, row 1 <- logits.
        GM2LM_ASYNC(first_token_logits_now + i,
                    first_token_logits_now_lm,
                    read_len * sizeof(float));
        GM2LM(logits_now + i, logits_now_lm, read_len * sizeof(float));
        LM2GM_ASYNC(first_token_logits_now_lm,
                    draft_logits_now + i,
                    read_len * sizeof(float));
        LM2GM(logits_now_lm,
              draft_logits_now + vocab_size + i,
              read_len * sizeof(float));
      } else {
        // Decode: copy each of this batch's rows through local memory.
        for (int j = 0; j < sm_seq_lens_this_time[bid]; j++) {
          GM2LM(logits_now + j * vocab_size + i,
                logits_now_lm,
                read_len * sizeof(float));
          LM2GM(logits_now_lm,
                draft_logits_now + j * vocab_size + i,
                read_len * sizeof(float));
        }
      }
    }
  }
}

}  // namespace plugin
}  // namespace xpu3

0 commit comments

Comments
 (0)