-
Notifications
You must be signed in to change notification settings - Fork 665
[XPU] add speculate_get_logits #5497
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
8f06f7e
ef359ae
e0bec56
902beb2
9f40fc6
d5d8c7c
c8fccdc
635126f
8ad13e7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| // Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. | ||
| // | ||
| // Licensed under the Apache License, Version 2.0 (the "License"); | ||
| // you may not use this file except in compliance with the License. | ||
| // You may obtain a copy of the License at | ||
| // | ||
| // http://www.apache.org/licenses/LICENSE-2.0 | ||
| // | ||
| // Unless required by applicable law or agreed to in writing, software | ||
| // distributed under the License is distributed on an "AS IS" BASIS, | ||
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| // See the License for the specific language governing permissions and | ||
| // limitations under the License. | ||
|
|
||
#include <paddle/phi/backends/xpu/xpu_context.h>
#include <stdio.h>

#include <memory>

#include "paddle/common/flags.h"
#include "paddle/extension.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "xpu/internal/infra_op.h"
#include "xpu/plugin.h"
|
|
||
// Older Paddle versions do not provide PD_BUILD_STATIC_OP; fall back to the
// regular PD_BUILD_OP with a "static_op_" name prefix so this file builds
// against both.
#ifndef PD_BUILD_STATIC_OP
#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
#endif
|
|
||
| void SpeculateGetLogits(const paddle::Tensor& draft_logits, | ||
| const paddle::Tensor& next_token_num, | ||
| const paddle::Tensor& batch_token_num, | ||
| const paddle::Tensor& cu_next_token_offset, | ||
| const paddle::Tensor& cu_batch_token_offset, | ||
| const paddle::Tensor& logits, | ||
| const paddle::Tensor& first_token_logits, | ||
| const paddle::Tensor& seq_lens_this_time, | ||
| const paddle::Tensor& seq_lens_encoder) { | ||
| phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId()); | ||
| auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place); | ||
| baidu::xpu::api::Context* ctx = | ||
| static_cast<const phi::XPUContext*>(dev_ctx)->x_context(); | ||
| if (draft_logits.is_cpu()) { | ||
| ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU); | ||
| } | ||
| const int vocab_size = logits.shape()[1]; | ||
| const int real_bsz = seq_lens_this_time.shape()[0]; | ||
|
|
||
| baidu::xpu::api::plugin::speculate_get_logits( | ||
| ctx, | ||
| const_cast<float*>(draft_logits.data<float>()), | ||
| const_cast<int*>(next_token_num.data<int>()), | ||
| const_cast<int*>(batch_token_num.data<int>()), | ||
| const_cast<int*>(cu_next_token_offset.data<int>()), | ||
| const_cast<int*>(cu_batch_token_offset.data<int>()), | ||
| logits.data<float>(), | ||
| first_token_logits.data<float>(), | ||
| seq_lens_this_time.data<int>(), | ||
| seq_lens_encoder.data<int>(), | ||
| real_bsz, | ||
| vocab_size); | ||
| if (draft_logits.is_cpu()) { | ||
| delete ctx; | ||
| } | ||
| } | ||
|
|
||
// Register the custom op. The three *_out outputs alias their inputs
// (SetInplaceMap), i.e. draft_logits, batch_token_num and
// cu_batch_token_offset are modified in place by the kernel.
PD_BUILD_STATIC_OP(speculate_get_logits)
    .Inputs({"draft_logits",
             "next_token_num",
             "batch_token_num",
             "cu_next_token_offset",
             "cu_batch_token_offset",
             "logits",
             "first_token_logits",
             "seq_lens_this_time",
             "seq_lens_encoder"})
    .Outputs({"draft_logits_out",
              "batch_token_num_out",
              "cu_batch_token_offset_out"})
    .SetInplaceMap({{"draft_logits", "draft_logits_out"},
                    {"batch_token_num", "batch_token_num_out"},
                    {"cu_batch_token_offset", "cu_batch_token_offset_out"}})
    .SetKernelFn(PD_KERNEL(SpeculateGetLogits));
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| #include "xpu/kernel/cluster.h" | ||
| #include "xpu/kernel/cluster_partition.h" | ||
| #include "xpu/kernel/cluster_primitive.h" | ||
|
|
||
| namespace xpu3 { | ||
| namespace plugin { | ||
|
|
||
// Computes, in shared memory, the per-batch token counts and their exclusive
// prefix sums (offsets) that the kernel below consumes:
//
//   sm_batch_token_num[bid]      = 2 if seq_lens_encoder[bid] > 0
//                                  (request still encoding), else
//                                  seq_lens_this_time[bid];
//   sm_cu_batch_token_offset     = exclusive prefix sum of batch_token_num;
//   sm_cu_next_token_offset      = exclusive prefix sum of the "next token"
//                                  counts (1 while encoding, else
//                                  seq_lens_this_time[bid]).
//
// NOTE: every participating cluster (clus_id < real_bsz) computes the full
// prefix sums into its own shared memory because the values are read locally
// later in the kernel; only cluster 0 writes batch_token_num and
// cu_batch_token_offset back to global memory — one writer is enough
// (per the author's review discussion on this PR).
__device__ void prefix_sum(__shared_ptr__ int* sm_seq_lens_encoder,
                           __shared_ptr__ int* sm_seq_lens_this_time,
                           __shared_ptr__ int* sm_batch_token_num,
                           __shared_ptr__ int* sm_cu_batch_token_offset,
                           __shared_ptr__ int* sm_cu_next_token_offset,
                           __global_ptr__ int* batch_token_num,
                           __global_ptr__ int* cu_batch_token_offset,
                           __global_ptr__ const int* seq_lens_this_time,
                           __global_ptr__ const int* seq_lens_encoder,
                           const int real_bsz) {
  int cid = core_id();
  int clus_id = cluster_id();

  // Only core 0 of each participating cluster does the (serial) scan.
  if (clus_id < real_bsz && cid == 0) {
    // The async GM->SM copy is followed by a blocking GM2SM, which also
    // orders the first transfer before the loop reads the shared buffers.
    GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
    GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
    int next_token_num_previous = 0;
    for (int bid = 0; bid < real_bsz; bid++) {
      sm_batch_token_num[bid] =
          sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
      if (bid == 0) {
        // Exclusive scans start at 0.
        sm_cu_batch_token_offset[bid] = 0;
        sm_cu_next_token_offset[bid] = 0;
      } else {
        sm_cu_batch_token_offset[bid] =
            sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
        sm_cu_next_token_offset[bid] =
            sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
      }
      next_token_num_previous =
          sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
    }
    // Make the scan results visible before cluster 0 copies them to GM.
    mfence_sm();
    if (clus_id == 0) {
      SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
      SM2GM_ASYNC(sm_cu_batch_token_offset,
                  cu_batch_token_offset,
                  real_bsz * sizeof(int));
    }
  }
  // All cores (including non-participating clusters) reach this barrier so
  // the kernel can safely read the shared buffers afterwards.
  mfence_sm();
  sync_all();
}
// XPU kernel: gathers per-request logits into `draft_logits` for speculative
// decoding. Work is distributed one request (bid) per cluster, with the
// vocab dimension split across cores in lm_buf_len-sized chunks.
//
// For an encoding request (seq_lens_encoder[bid] > 0) two rows are written:
// first_token_logits[bid] followed by logits at that request's next-token
// offset. Otherwise, seq_lens_this_time[bid] rows are copied from `logits`.
__global__ void speculate_get_logits(float* draft_logits,
                                     int* next_token_num,
                                     int* batch_token_num,
                                     int* cu_next_token_offset,
                                     int* cu_batch_token_offset,
                                     const float* logits,
                                     const float* first_token_logits,
                                     const int* seq_lens_this_time,
                                     const int* seq_lens_encoder,
                                     const int real_bsz,
                                     const int vocab_size) {
  int cid = core_id();
  int ncores = core_num();
  int clus_id = cluster_id();
  int nclusters = cluster_num();

  // 2 KiB per local-memory staging buffer (512 floats each); presumably
  // sized to fit core local memory alongside the second buffer — TODO
  // confirm against the hardware LM budget.
  int lm_size = 2 * 1024;
  int lm_buf_len = lm_size / sizeof(float);
  float first_token_logits_now_lm[lm_buf_len];
  float logits_now_lm[lm_buf_len];

  // Carve the 256 KiB shared memory into 5 equal int buffers.
  const int sm_size = 256 * 1024;
  __shared__ char sm[sm_size];
  int sm_max_buf_len = 256 * 1024 / sizeof(int);
  sm_max_buf_len /= 5;
  __shared_ptr__ int* sm_seq_lens_encoder = (__shared_ptr__ int*)sm;
  __shared_ptr__ int* sm_seq_lens_this_time =
      sm_seq_lens_encoder + sm_max_buf_len;
  __shared_ptr__ int* sm_batch_token_num =
      sm_seq_lens_this_time + sm_max_buf_len;
  __shared_ptr__ int* sm_cu_batch_token_offset =
      sm_batch_token_num + sm_max_buf_len;
  __shared_ptr__ int* sm_cu_next_token_offset =
      sm_cu_batch_token_offset + sm_max_buf_len;

  // Fills the sm_* buffers (and writes batch_token_num /
  // cu_batch_token_offset back to GM); ends with a cluster-wide barrier.
  prefix_sum(sm_seq_lens_encoder,
             sm_seq_lens_this_time,
             sm_batch_token_num,
             sm_cu_batch_token_offset,
             sm_cu_next_token_offset,
             batch_token_num,
             cu_batch_token_offset,
             seq_lens_this_time,
             seq_lens_encoder,
             real_bsz);

  // One request per cluster, strided over clusters.
  for (int bid = clus_id; bid < real_bsz; bid += nclusters) {
    auto* draft_logits_now =
        draft_logits + sm_cu_batch_token_offset[bid] * vocab_size;
    auto* logits_now = logits + sm_cu_next_token_offset[bid] * vocab_size;
    auto* first_token_logits_now = first_token_logits + bid * vocab_size;

    // Each core handles lm_buf_len-wide slices of the vocab dimension.
    for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len) {
      int read_len = min(lm_buf_len, vocab_size - i);
      if (sm_seq_lens_encoder[bid] > 0) {
        // Encoder step: draft row 0 <- first_token_logits,
        //               draft row 1 <- logits. The blocking GM2LM/LM2GM
        // calls also complete the preceding *_ASYNC transfers.
        GM2LM_ASYNC(first_token_logits_now + i,
                    first_token_logits_now_lm,
                    read_len * sizeof(float));
        GM2LM(logits_now + i, logits_now_lm, read_len * sizeof(float));
        LM2GM_ASYNC(first_token_logits_now_lm,
                    draft_logits_now + i,
                    read_len * sizeof(float));
        LM2GM(logits_now_lm,
              draft_logits_now + vocab_size + i,
              read_len * sizeof(float));
      } else {
        // Decode step: copy seq_lens_this_time[bid] logits rows verbatim.
        for (int j = 0; j < sm_seq_lens_this_time[bid]; j++) {
          GM2LM(logits_now + j * vocab_size + i,
                logits_now_lm,
                read_len * sizeof(float));
          LM2GM(logits_now_lm,
                draft_logits_now + j * vocab_size + i,
                read_len * sizeof(float));
        }
      }
    }
  }
}
|
|
||
| } // namespace plugin | ||
| } // namespace xpu3 | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个 context 什么时候被销毁掉呢?是否会造成内存泄露?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
CPU一般是用于单测验证,除了这里其他的算子可能也没对这个cpu的ctx做释放,后续可能需要统一排查一下