From 8f06f7e432cbf24cffe19b1154bceb9cf9c8e981 Mon Sep 17 00:00:00 2001 From: maruoheng Date: Fri, 5 Dec 2025 08:56:17 +0000 Subject: [PATCH 1/5] [XPU] add speculate_step_system_cache --- .../src/ops/mtp/speculate_step_helper.cc | 117 +++++++ .../src/ops/mtp/speculate_step_helper.h | 49 +++ .../src/ops/mtp/speculate_step_paddle.cc | 105 ++---- .../ops/mtp/speculate_step_system_cache.cc | 145 ++++++++ .../xpu_ops/src/plugin/include/xpu/plugin.h | 3 +- .../mtp_kernel/speculate_recover_block.xpu | 11 +- .../mtp_wrapper/speculate_recover_block.cpp | 21 +- .../test/test_speculate_step_system_cache.py | 316 ++++++++++++++++++ 8 files changed, 685 insertions(+), 82 deletions(-) create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc create mode 100644 custom_ops/xpu_ops/test/test_speculate_step_system_cache.py diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc new file mode 100644 index 00000000000..383abd9536b --- /dev/null +++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc @@ -0,0 +1,117 @@ +// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
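+
+// Shared implementation behind speculate_step_paddle and
+// speculate_step_system_cache. It first frees and re-dispatches KV-cache
+// blocks for the current speculative-decoding step, then recovers any
+// sequences that were previously paused for lack of free blocks. When
+// ori_seq_lens_decoder is supplied (the system-cache variant),
+// seq_lens_decoder is restored from it during recovery.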
+
+#include "speculate_step_helper.h"
+
+void SpeculateStepPaddleBase(
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &ori_seq_lens_encoder,
+    const paddle::optional<paddle::Tensor> &ori_seq_lens_decoder,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &step_block_list,
+    const paddle::Tensor &step_lens,
+    const paddle::Tensor &recover_block_list,
+    const paddle::Tensor &recover_lens,
+    const paddle::Tensor &need_block_list,
+    const paddle::Tensor &need_block_len,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const paddle::Tensor &input_ids,
+    const paddle::Tensor &pre_ids,
+    const paddle::Tensor &step_idx,
+    const paddle::Tensor &next_tokens,
+    const paddle::Tensor &first_token_ids,
+    const paddle::Tensor &accept_num,
+    const int block_size,
+    const int encoder_decoder_block_num,
+    const int max_draft_tokens) {
+  namespace api = baidu::xpu::api;
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
+  api::Context *ctx = xpu_ctx->x_context();
+  if (seq_lens_this_time.is_cpu()) {
+    ctx = new api::Context(api::kCPU);
+  }
+  const int bsz = seq_lens_this_time.shape()[0];
+  PADDLE_ENFORCE_LE(
+      bsz,
+      640,
+      phi::errors::InvalidArgument(
+          "Only support bsz <= 640, but received bsz is %d", bsz));
+  const int block_num_per_seq = block_tables.shape()[1];
+  const int length = input_ids.shape()[1];
+  const int pre_id_length = pre_ids.shape()[1];
+  const int max_decoder_block_num = pre_id_length / block_size;
+  int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block(
+      ctx,
+      const_cast<bool *>(stop_flags.data<bool>()),
+      const_cast<int *>(seq_lens_this_time.data<int>()),
+      const_cast<int *>(seq_lens_decoder.data<int>()),
+      const_cast<int *>(block_tables.data<int>()),
+      const_cast<int *>(encoder_block_lens.data<int>()),
+      const_cast<bool *>(is_block_step.data<bool>()),
+      const_cast<int *>(step_block_list.data<int>()),
+      const_cast<int *>(step_lens.data<int>()),
+      const_cast<int *>(recover_block_list.data<int>()),
+      const_cast<int *>(recover_lens.data<int>()),
+      const_cast<int *>(need_block_list.data<int>()),
+      const_cast<int *>(need_block_len.data<int>()),
+      const_cast<int *>(used_list_len.data<int>()),
+      const_cast<int *>(free_list.data<int>()),
+      const_cast<int *>(free_list_len.data<int>()),
+      const_cast<int64_t *>(first_token_ids.data<int64_t>()),
+      const_cast<int *>(accept_num.data<int>()),
+      bsz,
+      block_size,
+      block_num_per_seq,
+      max_decoder_block_num,
+      max_draft_tokens);
+  PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed.");
+  auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
+  int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
+  if (recover_lens_cpu_data > 0) {
+    r = baidu::xpu::api::plugin::speculate_recover_block(
+        ctx,
+        const_cast<int *>(recover_block_list.data<int>()),
+        const_cast<int *>(recover_lens.data<int>()),
+        const_cast<bool *>(stop_flags.data<bool>()),
+        const_cast<int *>(seq_lens_this_time.data<int>()),
+        ori_seq_lens_encoder.data<int>(),
+        ori_seq_lens_decoder ? ori_seq_lens_decoder.get_ptr()->data<int>() : nullptr,
+        const_cast<int *>(seq_lens_encoder.data<int>()),
+        const_cast<int *>(seq_lens_decoder.data<int>()),
+        const_cast<int *>(block_tables.data<int>()),
+        const_cast<int *>(free_list.data<int>()),
+        const_cast<int *>(free_list_len.data<int>()),
+        const_cast<int64_t *>(input_ids.data<int64_t>()),
+        pre_ids.data<int64_t>(),
+        step_idx.data<int64_t>(),
+        encoder_block_lens.data<int>(),
+        used_list_len.data<int>(),
+        next_tokens.data<int64_t>(),
+        first_token_ids.data<int64_t>(),
+        bsz,
+        block_num_per_seq,
+        length,
+        pre_id_length);
+    PD_CHECK(r == 0, "speculate_recover_block failed.");
+  }
+}
\ No newline at end of file
diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h
new file mode 100644
index 00000000000..4d9d5e97a7b
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h
@@ -0,0 +1,49 @@
+// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <paddle/phi/backends/xpu/xpu_context.h>
+#include "paddle/extension.h"
+#include "paddle/phi/core/enforce.h"
+#include "xpu/plugin.h"
+
+void SpeculateStepPaddleBase(
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &ori_seq_lens_encoder,
+    const paddle::optional<paddle::Tensor> &ori_seq_lens_decoder,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &step_block_list,
+    const paddle::Tensor &step_lens,
+    const paddle::Tensor &recover_block_list,
+    const paddle::Tensor &recover_lens,
+    const paddle::Tensor &need_block_list,
+    const paddle::Tensor &need_block_len,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const paddle::Tensor &input_ids,
+    const paddle::Tensor &pre_ids,
+    const paddle::Tensor &step_idx,
+    const paddle::Tensor &next_tokens,
+    const paddle::Tensor &first_token_ids,
+    const paddle::Tensor &accept_num,
+    const int block_size,
+    const int encoder_decoder_block_num,
+    const int max_draft_tokens);
\ No newline at end of file
diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
index d8b113fb81a..542f0f1a4fa 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
@@ -12,10 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include <paddle/phi/backends/xpu/xpu_context.h>
-#include "paddle/extension.h"
-#include "paddle/phi/core/enforce.h"
-#include "xpu/plugin.h"
+#include "speculate_step_helper.h"
 
 #ifndef PD_BUILD_STATIC_OP
 #define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
 #endif
@@ -48,77 +45,35 @@ void SpeculateStepPaddle(
     const int block_size,
     const int encoder_decoder_block_num,
     const int max_draft_tokens) {
-  namespace api = baidu::xpu::api;
-  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
-  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
-  auto xpu_ctx = static_cast<const phi::XPUContext *>(dev_ctx);
-  api::Context *ctx = xpu_ctx->x_context();
-  if (seq_lens_this_time.is_cpu()) {
-    ctx = new api::Context(api::kCPU);
-  }
-  const int bsz = seq_lens_this_time.shape()[0];
-  PADDLE_ENFORCE_LE(
-      bsz,
-      640,
-      phi::errors::InvalidArgument(
-          "Only support bsz <= 640, but received bsz is %d", bsz));
-  const int block_num_per_seq = block_tables.shape()[1];
-  const int length = input_ids.shape()[1];
-  const int pre_id_length = pre_ids.shape()[1];
-  const int max_decoder_block_num = pre_id_length / block_size;
-  int r = baidu::xpu::api::plugin::speculate_free_and_dispatch_block(
-      ctx,
-      const_cast<bool *>(stop_flags.data<bool>()),
-      const_cast<int *>(seq_lens_this_time.data<int>()),
-      const_cast<int *>(seq_lens_decoder.data<int>()),
-      const_cast<int *>(block_tables.data<int>()),
-      const_cast<int *>(encoder_block_lens.data<int>()),
-      const_cast<bool *>(is_block_step.data<bool>()),
-      const_cast<int *>(step_block_list.data<int>()),
-      const_cast<int *>(step_lens.data<int>()),
-      const_cast<int *>(recover_block_list.data<int>()),
-      const_cast<int *>(recover_lens.data<int>()),
-      const_cast<int *>(need_block_list.data<int>()),
-      const_cast<int *>(need_block_len.data<int>()),
-      const_cast<int *>(used_list_len.data<int>()),
-      const_cast<int *>(free_list.data<int>()),
-      const_cast<int *>(free_list_len.data<int>()),
-      const_cast<int64_t *>(first_token_ids.data<int64_t>()),
-      const_cast<int *>(accept_num.data<int>()),
-      bsz,
-      block_size,
-      block_num_per_seq,
-      max_decoder_block_num,
-      max_draft_tokens);
-  PD_CHECK(r == 0, "speculate_free_and_dispatch_block failed.");
-  auto recover_lens_cpu = recover_lens.copy_to(paddle::CPUPlace(), false);
-  int recover_lens_cpu_data = recover_lens_cpu.data<int>()[0];
-  if (recover_lens_cpu_data > 0) {
-    r = baidu::xpu::api::plugin::speculate_recover_block(
-        ctx,
-        const_cast<int *>(recover_block_list.data<int>()),
-        const_cast<int *>(recover_lens.data<int>()),
-        const_cast<bool *>(stop_flags.data<bool>()),
-        const_cast<int *>(seq_lens_this_time.data<int>()),
-        ori_seq_lens_encoder.data<int>(),
-        const_cast<int *>(seq_lens_encoder.data<int>()),
-        seq_lens_decoder.data<int>(),
-        const_cast<int *>(block_tables.data<int>()),
-        const_cast<int *>(free_list.data<int>()),
-        const_cast<int *>(free_list_len.data<int>()),
-        const_cast<int64_t *>(input_ids.data<int64_t>()),
-        pre_ids.data<int64_t>(),
-        step_idx.data<int64_t>(),
-        encoder_block_lens.data<int>(),
-        used_list_len.data<int>(),
-        next_tokens.data<int64_t>(),
-        first_token_ids.data<int64_t>(),
-        bsz,
-        block_num_per_seq,
-        length,
-        pre_id_length);
-    PD_CHECK(r == 0, "speculate_recover_block failed.");
-  }
+  SpeculateStepPaddleBase(
+      stop_flags,
+      seq_lens_this_time,
+      ori_seq_lens_encoder,
+      paddle::optional<paddle::Tensor>(),
+      seq_lens_encoder,
+      seq_lens_decoder,
+      block_tables,
+      encoder_block_lens,
+      is_block_step,
+      step_block_list,
+      step_lens,
+      recover_block_list,
+      recover_lens,
+      need_block_list,
+      need_block_len,
+      used_list_len,
+      free_list,
+      free_list_len,
+      input_ids,
+      pre_ids,
+      step_idx,
+      next_tokens,
+      first_token_ids,
+      accept_num,
+      block_size,
+      encoder_decoder_block_num,
+      max_draft_tokens
+  );
 }
 
 PD_BUILD_STATIC_OP(speculate_step_paddle)
diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc
new file mode 100644
index 00000000000..89643a457e5
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc
@@ -0,0 +1,145 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "speculate_step_helper.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+void SpeculateStepSystemCachePaddle(
+    const paddle::Tensor &stop_flags,
+    const paddle::Tensor &seq_lens_this_time,
+    const paddle::Tensor &ori_seq_lens_encoder,
+    const paddle::Tensor &ori_seq_lens_decoder,
+    const paddle::Tensor &seq_lens_encoder,
+    const paddle::Tensor &seq_lens_decoder,
+    const paddle::Tensor &block_tables,  // [bsz, block_num_per_seq]
+    const paddle::Tensor &encoder_block_lens,
+    const paddle::Tensor &is_block_step,
+    const paddle::Tensor &step_block_list,
+    const paddle::Tensor &step_lens,
+    const paddle::Tensor &recover_block_list,
+    const paddle::Tensor &recover_lens,
+    const paddle::Tensor &need_block_list,
+    const paddle::Tensor &need_block_len,
+    const paddle::Tensor &used_list_len,
+    const paddle::Tensor &free_list,
+    const paddle::Tensor &free_list_len,
+    const paddle::Tensor &input_ids,
+    const paddle::Tensor &pre_ids,
+    const paddle::Tensor &step_idx,
+    const paddle::Tensor &next_tokens,
+    const paddle::Tensor &first_token_ids,
+    const paddle::Tensor &accept_num,
+    const int block_size,
+    const int encoder_decoder_block_num,
+    const int max_draft_tokens) {
+  SpeculateStepPaddleBase(
+    stop_flags,
+    seq_lens_this_time,
+    ori_seq_lens_encoder,
+    paddle::make_optional(ori_seq_lens_decoder),
+    seq_lens_encoder,
+    seq_lens_decoder,
+    block_tables,
+    encoder_block_lens,
+    is_block_step,
+    step_block_list,
+    step_lens,
+    recover_block_list,
+    recover_lens,
+    need_block_list,
+    need_block_len,
+    used_list_len,
+    free_list,
+    free_list_len,
+    input_ids,
+    pre_ids,
+    step_idx,
+    next_tokens,
+    first_token_ids,
+    accept_num,
+    block_size,
+    encoder_decoder_block_num,
+    max_draft_tokens
+  );
+}
+
+PD_BUILD_STATIC_OP(speculate_step_system_cache)
+    .Inputs({"stop_flags",
+             "seq_lens_this_time",
+             "ori_seq_lens_encoder",
+             "ori_seq_lens_decoder",
+             "seq_lens_encoder",
+             "seq_lens_decoder",
+             "block_tables",
+             "encoder_block_lens",
+             "is_block_step",
+             "step_block_list",
+             "step_lens",
+             "recover_block_list",
+             "recover_lens",
+             "need_block_list",
+             "need_block_len",
+             "used_list_len",
+             "free_list",
+             "free_list_len",
+             "input_ids",
+             "pre_ids",
+             "step_idx",
+             "next_tokens",
+             "first_token_ids",
+             "accept_num"})
+    .Attrs({"block_size: int",
+            "encoder_decoder_block_num: int",
+            "max_draft_tokens: int"})
+    .Outputs({"stop_flags_out",
+              "seq_lens_this_time_out",
+              "seq_lens_encoder_out",
+              "seq_lens_decoder_out",
+              "block_tables_out",
+              "encoder_block_lens_out",
+              "is_block_step_out",
+              "step_block_list_out",
+              "step_lens_out",
+              "recover_block_list_out",
+              "recover_lens_out",
+              "need_block_list_out",
+              "need_block_len_out",
+              "used_list_len_out",
+              "free_list_out",
+              "free_list_len_out",
"input_ids_out", + "first_token_ids_out"}) + .SetInplaceMap({{"stop_flags", "stop_flags_out"}, + {"seq_lens_this_time", "seq_lens_this_time_out"}, + {"seq_lens_encoder", "seq_lens_encoder_out"}, + {"seq_lens_decoder", "seq_lens_decoder_out"}, + {"block_tables", "block_tables_out"}, + {"encoder_block_lens", "encoder_block_lens_out"}, + {"is_block_step", "is_block_step_out"}, + {"step_block_list", "step_block_list_out"}, + {"step_lens", "step_lens_out"}, + {"recover_block_list", "recover_block_list_out"}, + {"recover_lens", "recover_lens_out"}, + {"need_block_list", "need_block_list_out"}, + {"need_block_len", "need_block_len_out"}, + {"used_list_len", "used_list_len_out"}, + {"free_list", "free_list_out"}, + {"free_list_len", "free_list_len_out"}, + {"input_ids", "input_ids_out"}, + {"first_token_ids", "first_token_ids_out"}}) + .SetKernelFn(PD_KERNEL(SpeculateStepSystemCachePaddle)); + diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index 09a426a3126..bc27a54a94a 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -207,8 +207,9 @@ DLL_EXPORT int speculate_recover_block(Context* ctx, bool* stop_flags, int* seq_lens_this_time, const int* ori_seq_lens_encoder, + const int* ori_seq_lens_decoder, int* seq_lens_encoder, - const int* seq_lens_decoder, + int* seq_lens_decoder, int* block_tables, int* free_list, int* free_list_len, diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu index 46d24821dda..6eb7279d97d 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu @@ -33,8 +33,9 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] bool* stop_flags, int* seq_lens_this_time, const int* ori_seq_lens_encoder, + const int* ori_seq_lens_decoder, int* seq_lens_encoder, - const int* seq_lens_decoder, + int* seq_lens_decoder, int* block_tables, int* free_list, int* free_list_len, @@ -82,6 +83,7 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] for (int bid = cid; bid < recover_len_lm; bid += ncores) { int recover_id; int ori_seq_len_encoder; + int ori_seq_len_decoder; int step_idx_now; int encoder_block_len; int decoder_used_len; @@ -89,12 +91,19 @@ __global__ void speculate_recover_block(int* recover_block_list, // [bsz] GM2LM(recover_block_list + bid, &recover_id, sizeof(int)); GM2LM_ASYNC( ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int)); + if (ori_seq_lens_decoder != nullptr) { + GM2LM_ASYNC( + ori_seq_lens_decoder + recover_id, &ori_seq_len_decoder, sizeof(int)); + } GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int)); GM2LM_ASYNC( encoder_block_lens + recover_id, &encoder_block_len, sizeof(int)); GM2LM_ASYNC(used_list_len + recover_id, &decoder_used_len, sizeof(int)); GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t)); mfence(); + if (ori_seq_lens_decoder != nullptr) { + LM2GM_ASYNC(&ori_seq_len_decoder, seq_lens_decoder + recover_id, sizeof(int)); + } int seq_len = ori_seq_len_encoder + step_idx_now; mfence(); diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp 
b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp index 2996325c833..5f3c8bdf6c2 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp @@ -26,8 +26,9 @@ __attribute__((global)) void speculate_recover_block( bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -57,8 +58,9 @@ static int cpu_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -76,6 +78,9 @@ static int cpu_wrapper(Context *ctx, for (int bid = 0; bid < recover_len[0]; bid++) { const int recover_id = recover_block_list[bid]; const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id]; + if (ori_seq_lens_decoder != nullptr) { + seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id]; + } const int step_idx_now = step_idx[recover_id]; const int seq_len = ori_seq_len_encoder + step_idx_now; const int encoder_block_len = encoder_block_lens[recover_id]; @@ -112,8 +117,9 @@ static int xpu3_wrapper(Context *ctx, bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -136,6 +142,7 @@ static int xpu3_wrapper(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, + ori_seq_lens_decoder, seq_lens_encoder, seq_lens_decoder, block_tables, @@ -161,8 +168,9 @@ int speculate_recover_block(Context *ctx, bool *stop_flags, int *seq_lens_this_time, const int *ori_seq_lens_encoder, + const int *ori_seq_lens_decoder, int *seq_lens_encoder, - const int *seq_lens_decoder, + int *seq_lens_decoder, int *block_tables, int *free_list, int *free_list_len, @@ -185,7 +193,8 @@ int speculate_recover_block(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, - seq_lens_encoder); + ori_seq_lens_decoder); + WRAPPER_DUMP_PARAM1(ctx, seq_lens_encoder); WRAPPER_DUMP_PARAM6(ctx, seq_lens_decoder, block_tables, @@ -208,6 +217,7 @@ int speculate_recover_block(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, + ori_seq_lens_decoder, seq_lens_encoder, seq_lens_decoder, block_tables, @@ -232,6 +242,7 @@ int speculate_recover_block(Context *ctx, stop_flags, seq_lens_this_time, ori_seq_lens_encoder, + ori_seq_lens_decoder, seq_lens_encoder, seq_lens_decoder, block_tables, diff --git a/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py new file mode 100644 index 00000000000..d691533d03f --- /dev/null +++ b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py @@ -0,0 +1,316 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import unittest
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import speculate_step_system_cache
+
+# Fix the random seeds so the test is reproducible
+np.random.seed(2023)
+paddle.seed(2023)
+
+def generate_test_data():
+    """
+    Helper function that generates test data.
+    This logic was converted from a pytest fixture into a plain function
+    called by the test methods.
+    """
+    # max_bs = 128
+    max_bs = 8
+    bs = max_bs
+    max_seq_len = 8192
+    block_size = 64
+    block_bs = 8
+    block_ratio = 0.75
+    max_draft_tokens = 1
+    encoder_decoder_block_num = 1
+
+    # Generate the raw test data (reusing the original logic unchanged)
+    stop_flags = np.random.randint(0, 2, [max_bs]).astype("bool")
+    seq_lens_this_time = np.zeros([bs], "int32")
+    seq_lens_encoder = np.zeros([max_bs], "int32")
+    seq_lens_decoder = np.zeros([max_bs], "int32")
+    accept_num = np.random.randint(1, 3, [max_bs]).astype("int32")
+    for i in range(bs):
+        seq_lens_decoder[i] = 2 + i * 2
+        seq_lens_this_time[i] = 1
+
+    ori_seq_lens_encoder = np.zeros([max_bs], "int32")
+    ori_seq_lens_encoder[:] = seq_lens_decoder[:] // 2
+    ori_seq_lens_decoder = np.random.randint(1, 10, (max_bs), "int32")
+    step_idx = (seq_lens_decoder - ori_seq_lens_encoder).astype("int64")
+
+    max_block_num = block_bs * max_seq_len // block_size
+    free_list_len = int(max_block_num * (1 - block_ratio))
+    free_list_len = np.full([1], free_list_len, "int32")
+    free_list = np.arange(
+        max_block_num - 1, max_block_num - free_list_len.item() - 1, -1, dtype="int32"  # .item() converts to a scalar
+    )
+    encoder_block_lens = np.zeros([max_bs], "int32")
+    used_list_len = np.zeros([max_bs], "int32")
+    block_tables = np.full([max_bs, 128], -1, "int32")
+    encoder_block_id = 0
+
+    for i in range(bs):
+        enc_block_num = (ori_seq_lens_encoder[i] + block_size - 1) // block_size
+        encoder_block_lens[i] = enc_block_num
+        dec_block_num = (seq_lens_decoder[i] + block_size - 1) // block_size - enc_block_num
+        used_list_len[i] = dec_block_num
+        block_tables[i, :enc_block_num] = np.arange(encoder_block_id, encoder_block_id + enc_block_num, 1, "int32")
+        encoder_block_id += enc_block_num
+        if dec_block_num > 0:
+            block_tables[i, enc_block_num : enc_block_num + dec_block_num] = free_list[
+                free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1
+            ]
+            free_list[free_list_len[0] - 1 - dec_block_num : free_list_len[0] - 1] = -1
+            free_list_len[0] -= dec_block_num
+
+    assert free_list_len[0] >= 0, "free_list_len should not be negative"
+
+    is_block_step = np.zeros([max_bs], "bool")
+    is_block_step[:bs] = np.random.randint(0, 2, [bs]).astype("bool")
+    step_block_list = np.full([max_bs], -1, "int32")
+    step_lens = np.full([1], 0, "int32")
+
+    for i in range(bs):
+        if is_block_step[i]:
+            step_block_list[step_lens[0]] = i
+            step_lens[0] += 1
+
+    recover_lens = np.full([1], 0, "int32")
+    recover_block_list = np.full([max_bs], -1, "int32")
+    need_block_len = np.full([1], 0, "int32")
+    need_block_list = np.full([max_bs], -1, "int32")
+
+    input_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64")
+    pre_ids = np.random.randint(0, 1000, [max_bs, max_seq_len], "int64")
+    next_tokens = np.random.randint(0, 1000, [max_bs], "int64")
+    first_token_ids = np.random.randint(0, 1000, [max_bs], "int64")
+
+    paddle.set_device("cpu")
+    # Convert to paddle tensors (keeping the original logic)
+    data_cpu = {
+        "stop_flags": paddle.to_tensor(stop_flags),
+        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
+        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
+        "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
+        "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder),
+        "ori_seq_lens_decoder": paddle.to_tensor(ori_seq_lens_decoder),
+        "block_tables": paddle.to_tensor(block_tables),
+        "encoder_block_lens": paddle.to_tensor(encoder_block_lens),
+        "is_block_step": paddle.to_tensor(is_block_step),
+        "step_block_list": paddle.to_tensor(step_block_list),
+        "step_lens": paddle.to_tensor(step_lens),
+        "recover_block_list": paddle.to_tensor(recover_block_list),
+        "recover_lens": paddle.to_tensor(recover_lens),
+        "need_block_list": paddle.to_tensor(need_block_list),
+        "need_block_len": paddle.to_tensor(need_block_len),
+        "used_list_len": paddle.to_tensor(used_list_len),
+        "free_list_len": paddle.to_tensor(free_list_len),
+        "free_list": paddle.to_tensor(free_list),
+        "input_ids": paddle.to_tensor(input_ids),
+        "pre_ids": paddle.to_tensor(pre_ids),
+        "step_idx": paddle.to_tensor(step_idx),
+        "next_tokens": paddle.to_tensor(next_tokens),
+        "first_token_ids": paddle.to_tensor(first_token_ids),
+        "accept_num": paddle.to_tensor(accept_num),
+        "block_size": block_size,
+        "encoder_decoder_block_num": encoder_decoder_block_num,
+        "max_draft_tokens": max_draft_tokens,
+    }
+
+    paddle.set_device("xpu:0")
+    data_xpu = {
+        "stop_flags": paddle.to_tensor(stop_flags),
+        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
+        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
+        "seq_lens_decoder": paddle.to_tensor(seq_lens_decoder),
+        "ori_seq_lens_encoder": paddle.to_tensor(ori_seq_lens_encoder),
+        "ori_seq_lens_decoder": paddle.to_tensor(ori_seq_lens_decoder),
+        "block_tables": paddle.to_tensor(block_tables),
+        "encoder_block_lens": paddle.to_tensor(encoder_block_lens),
+        "is_block_step": paddle.to_tensor(is_block_step),
+        "step_block_list": paddle.to_tensor(step_block_list),
+        "step_lens": paddle.to_tensor(step_lens),
+        "recover_block_list": paddle.to_tensor(recover_block_list),
+        "recover_lens": paddle.to_tensor(recover_lens),
+        "need_block_list": paddle.to_tensor(need_block_list),
+        "need_block_len": paddle.to_tensor(need_block_len),
+        "used_list_len": paddle.to_tensor(used_list_len),
+        "free_list_len": paddle.to_tensor(free_list_len),
+        "free_list": paddle.to_tensor(free_list),
+        "input_ids": paddle.to_tensor(input_ids),
+        "pre_ids": paddle.to_tensor(pre_ids),
+        "step_idx": paddle.to_tensor(step_idx),
+        "next_tokens": paddle.to_tensor(next_tokens),
+        "first_token_ids": paddle.to_tensor(first_token_ids),
+        "accept_num": paddle.to_tensor(accept_num),
+        "block_size": block_size,
+        "encoder_decoder_block_num": encoder_decoder_block_num,
+        "max_draft_tokens": max_draft_tokens,
+    }
+
+    # Restore the default device so other tests are unaffected
+    paddle.set_device("cpu")
+
+    return data_cpu, data_xpu
+
+
+def speculate_step_paddle_execution(test_data):
+    """Test that speculate_step_system_cache executes and its outputs are reasonable."""
+    # Unpack the input data
+    stop_flags = test_data["stop_flags"]  # clone to avoid mutating the fixture data
+    seq_lens_this_time = test_data["seq_lens_this_time"]
+    ori_seq_lens_encoder = test_data["ori_seq_lens_encoder"]
+    ori_seq_lens_decoder = test_data["ori_seq_lens_decoder"]
+    seq_lens_encoder = test_data["seq_lens_encoder"]
+    seq_lens_decoder = test_data["seq_lens_decoder"]
+    block_tables = test_data["block_tables"]
+    encoder_block_lens = test_data["encoder_block_lens"]
+    is_block_step = test_data["is_block_step"]
+    step_block_list = test_data["step_block_list"]
+    step_lens = test_data["step_lens"]
+    recover_block_list = test_data["recover_block_list"]
+    recover_lens = test_data["recover_lens"]
+    need_block_list = test_data["need_block_list"]
+    need_block_len = test_data["need_block_len"]
+    used_list_len = test_data["used_list_len"]
+    free_list = test_data["free_list"]
+    free_list_len = test_data["free_list_len"]
+    input_ids = test_data["input_ids"]
+    pre_ids = test_data["pre_ids"]
+    step_idx = test_data["step_idx"]
+    next_tokens = test_data["next_tokens"]
+    first_token_ids = test_data["first_token_ids"]
+    accept_num = test_data["accept_num"]
+    block_size = test_data["block_size"]
+    encoder_decoder_block_num = test_data["encoder_decoder_block_num"]
+    max_draft_tokens = test_data["max_draft_tokens"]
+
+    # Optional: print key state before execution (enable for debugging)
+    if os.environ.get("STEP_TEST_DEBUG", "0") == "1":
+        print("-" * 50 + "before step op" + "-" * 50)
+        # ... (print statements omitted for brevity)
+
+    # Run the target op (the core test step)
+    speculate_step_system_cache(
+        stop_flags,
+        seq_lens_this_time,
+        ori_seq_lens_encoder,
+        ori_seq_lens_decoder,
+        seq_lens_encoder,
+        seq_lens_decoder,
+        block_tables,
+        encoder_block_lens,
+        is_block_step,
+        step_block_list,
+        step_lens,
+        recover_block_list,
+        recover_lens,
+        need_block_list,
+        need_block_len,
+        used_list_len,
+        free_list,
+        free_list_len,
+        input_ids,
+        pre_ids,
+        step_idx,
+        next_tokens,
+        first_token_ids,
+        accept_num,
+        block_size,
+        encoder_decoder_block_num,
+        max_draft_tokens,
+    )
+
+    # Optional: print key state after execution (enable for debugging)
+    if os.environ.get("STEP_TEST_DEBUG", "0") == "1":
+        print("-" * 50 + "after step op" + "-" * 50)
+        # ... (print statements omitted for brevity)
+
+    return test_data
+
+
+class TestSpeculateStepSystemCache(unittest.TestCase):
+    """
+    Test class derived from unittest.TestCase.
+    Every method whose name starts with 'test_' is treated as a test case.
+    """
+
+    def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08):
+        """
+        Custom assertion that compares the structure and data of two test_data dicts.
+        In unittest, custom assertions conventionally start with 'assert'.
+        """
+        # 1. First check that both test_data dicts have exactly the same keys
+        keys1 = set(test_data1.keys())
+        keys2 = set(test_data2.keys())
+        self.assertEqual(
+            keys1,
+            keys2,
+            msg=f"test_data keys differ!\nOnly in the first: {keys1 - keys2}\nOnly in the second: {keys2 - keys1}",
+        )
+
+        # 2. Check each field's data
+        for key in keys1:
+            data1 = test_data1[key]
+            data2 = test_data2[key]
+
+            # Distinguish paddle Tensors (convert to numpy) from plain scalars/arrays (use directly)
+            if isinstance(data1, paddle.Tensor):
+                np1 = data1.detach().cpu().numpy()
+            else:
+                np1 = np.asarray(data1)
+
+            if isinstance(data2, paddle.Tensor):
+                np2 = data2.detach().cpu().numpy()
+            else:
+                np2 = np.asarray(data2)
+
+            # 3. Compare the data
+            if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8):
+                # Boolean/integer: must match exactly
+                np.testing.assert_array_equal(np1, np2, err_msg=f"Field {key} data mismatch!")
+            else:
+                # Floating point: allow error within rtol/atol
+                np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"Field {key} float data mismatch!")
+
+        print("✅ Both test_data dicts match in structure and data!")
+
+    def test_speculate_step_system_cache_execution(self):
+        """
+        Core test case.
+        It obtains data via generate_test_data,
+        runs the function under test on CPU and on XPU,
+        and compares the results with the custom assertion.
+        """
+        print("\nRunning test: test_speculate_step_system_cache_execution")
+
+        # 1. Get the test data
+        data_cpu, data_xpu = generate_test_data()
+
+        # 2. Run the function under test
+        result_xpu = speculate_step_paddle_execution(data_xpu)
+        result_cpu = speculate_step_paddle_execution(data_cpu)
+
+        # 3. Assert the results match
+        self.assert_test_data_equal(result_xpu, result_cpu)
+
+
+if __name__ == "__main__":
+    # Run all test cases via the unittest main program
+    unittest.main()

From e0bec564929003897a9549c6db27ff6c12da3918 Mon Sep 17 00:00:00 2001
From: maruoheng
Date: Mon, 8 Dec 2025 07:14:56 +0000
Subject: [PATCH 2/5] [XPU] add speculate_step_system_cache

---
 .../src/ops/mtp/speculate_step_helper.cc      |  5 +-
 .../src/ops/mtp/speculate_step_helper.h       |  2 +-
 .../src/ops/mtp/speculate_step_paddle.cc      | 56 +++++++++----------
 .../ops/mtp/speculate_step_system_cache.cc    | 56 +++++++++----------
 .../mtp_kernel/speculate_recover_block.xpu    |  7 ++-
 .../mtp_wrapper/speculate_recover_block.cpp   |  2 +-
 .../test/test_speculate_step_system_cache.py  |  1 +
 7 files changed, 64 insertions(+), 65 deletions(-)

diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc
index 383abd9536b..2344531a333 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.cc
@@ -95,7 +95,8 @@ void SpeculateStepPaddleBase(
         const_cast<bool *>(stop_flags.data<bool>()),
         const_cast<int *>(seq_lens_this_time.data<int>()),
         ori_seq_lens_encoder.data<int>(),
-        ori_seq_lens_decoder ? ori_seq_lens_decoder.get_ptr()->data<int>() : nullptr,
+        ori_seq_lens_decoder ? ori_seq_lens_decoder.get_ptr()->data<int>()
+                             : nullptr,
         const_cast<int *>(seq_lens_encoder.data<int>()),
         const_cast<int *>(seq_lens_decoder.data<int>()),
         const_cast<int *>(block_tables.data<int>()),
@@ -114,4 +115,4 @@ void SpeculateStepPaddleBase(
         pre_id_length);
     PD_CHECK(r == 0, "speculate_recover_block failed.");
   }
-}
\ No newline at end of file
+}
diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h
index 4d9d5e97a7b..ea2eb2c9bb6 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_helper.h
@@ -46,4 +46,4 @@ void SpeculateStepPaddleBase(
     const paddle::Tensor &accept_num,
     const int block_size,
     const int encoder_decoder_block_num,
-    const int max_draft_tokens);
\ No newline at end of file
+    const int max_draft_tokens);
diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
index 542f0f1a4fa..1088b604c91 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_paddle.cc
@@ -45,35 +45,33 @@ void SpeculateStepPaddle(
     const int block_size,
     const int encoder_decoder_block_num,
     const int max_draft_tokens) {
-  SpeculateStepPaddleBase(
-      stop_flags,
-      seq_lens_this_time,
-      ori_seq_lens_encoder,
-      paddle::optional<paddle::Tensor>(),
-      seq_lens_encoder,
-      seq_lens_decoder,
-      block_tables,
-      encoder_block_lens,
-      is_block_step,
-      step_block_list,
-      step_lens,
-      recover_block_list,
-      recover_lens,
-      need_block_list,
-      need_block_len,
-      used_list_len,
-      free_list,
-      free_list_len,
-      input_ids,
-      pre_ids,
-      step_idx,
-      next_tokens,
-      first_token_ids,
-      accept_num,
-      block_size,
-      encoder_decoder_block_num,
-      max_draft_tokens
-  );
+  SpeculateStepPaddleBase(stop_flags,
+                          seq_lens_this_time,
+                          ori_seq_lens_encoder,
+                          paddle::optional<paddle::Tensor>(),
+                          seq_lens_encoder,
+                          seq_lens_decoder,
+                          block_tables,
+                          encoder_block_lens,
+                          is_block_step,
+                          step_block_list,
+                          step_lens,
+                          recover_block_list,
+                          recover_lens,
+                          need_block_list,
+                          need_block_len,
+                          used_list_len,
+                          free_list,
+                          free_list_len,
+                          input_ids,
+                          pre_ids,
+                          step_idx,
+                          next_tokens,
+                          first_token_ids,
+                          accept_num,
+                          block_size,
+                          encoder_decoder_block_num,
+                          max_draft_tokens);
 }
 
 PD_BUILD_STATIC_OP(speculate_step_paddle)
diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc
index 89643a457e5..0040600ca37 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_step_system_cache.cc
@@ -47,34 +47,33 @@ void SpeculateStepSystemCachePaddle(
     const int encoder_decoder_block_num,
     const int max_draft_tokens) {
   SpeculateStepPaddleBase(
-    stop_flags,
-    seq_lens_this_time,
-    ori_seq_lens_encoder,
-    paddle::make_optional(ori_seq_lens_decoder),
-    seq_lens_encoder,
-    seq_lens_decoder,
-    block_tables,
-    encoder_block_lens,
-    is_block_step,
-    step_block_list,
-    step_lens,
-    recover_block_list,
-    recover_lens,
-    need_block_list,
-    need_block_len,
-    used_list_len,
-    free_list,
-    free_list_len,
-    input_ids,
-    pre_ids,
-    step_idx,
-    next_tokens,
-    first_token_ids,
-    accept_num,
-    block_size,
-    encoder_decoder_block_num,
-    max_draft_tokens
-  );
+      stop_flags,
+      seq_lens_this_time,
+      ori_seq_lens_encoder,
+      paddle::make_optional(ori_seq_lens_decoder),
+      seq_lens_encoder,
+      seq_lens_decoder,
+      block_tables,
+      encoder_block_lens,
+      is_block_step,
+      step_block_list,
+      step_lens,
+      recover_block_list,
+      recover_lens,
+      need_block_list,
+      need_block_len,
+      used_list_len,
+      free_list,
+      free_list_len,
+      input_ids,
+      pre_ids,
+      step_idx,
+      next_tokens,
+      first_token_ids,
+      accept_num,
+      block_size,
+      encoder_decoder_block_num,
+      max_draft_tokens);
 }
 
 PD_BUILD_STATIC_OP(speculate_step_system_cache)
@@ -142,4 +141,3 @@ PD_BUILD_STATIC_OP(speculate_step_system_cache)
                     {"input_ids", "input_ids_out"},
                     {"first_token_ids", "first_token_ids_out"}})
     .SetKernelFn(PD_KERNEL(SpeculateStepSystemCachePaddle));
-
diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu
index 6eb7279d97d..85439819265 100644
--- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu
+++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_recover_block.xpu
@@ -92,8 +92,8 @@ __global__ void speculate_recover_block(int* recover_block_list,  // [bsz]
     GM2LM_ASYNC(
         ori_seq_lens_encoder + recover_id, &ori_seq_len_encoder, sizeof(int));
     if (ori_seq_lens_decoder != nullptr) {
-    GM2LM_ASYNC(
-        ori_seq_lens_decoder + recover_id, &ori_seq_len_decoder, sizeof(int));
+      GM2LM_ASYNC(
+          ori_seq_lens_decoder + recover_id, &ori_seq_len_decoder, sizeof(int));
     }
     GM2LM_ASYNC(step_idx + recover_id, &step_idx_now, sizeof(int));
     GM2LM_ASYNC(
@@ -102,7 +102,8 @@ __global__ void speculate_recover_block(int* recover_block_list,  // [bsz]
     GM2LM_ASYNC(next_tokens + recover_id, &next_token, sizeof(int64_t));
     mfence();
     if (ori_seq_lens_decoder != nullptr) {
-      LM2GM_ASYNC(&ori_seq_len_decoder, seq_lens_decoder + recover_id, sizeof(int));
+      LM2GM_ASYNC(
+          &ori_seq_len_decoder, seq_lens_decoder + recover_id, sizeof(int));
     }
 
     int seq_len = ori_seq_len_encoder + step_idx_now;
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp
index 5f3c8bdf6c2..0e270c0f01a 100644
--- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_recover_block.cpp
@@ -79,7 +79,7 @@ static int cpu_wrapper(Context *ctx,
     const int recover_id = recover_block_list[bid];
     const int ori_seq_len_encoder = ori_seq_lens_encoder[recover_id];
     if (ori_seq_lens_decoder != nullptr) {
-        seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id];
+      seq_lens_decoder[recover_id] = ori_seq_lens_decoder[recover_id];
     }
     const int step_idx_now = step_idx[recover_id];
     const int seq_len = ori_seq_len_encoder + step_idx_now;
diff --git a/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py
index d691533d03f..6b52efe13f7 100644
--- a/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py
+++ b/custom_ops/xpu_ops/test/test_speculate_step_system_cache.py
@@ -24,6 +24,7 @@
 np.random.seed(2023)
 paddle.seed(2023)
 
+
 def generate_test_data():
     """
     Helper function that generates test data.

From d5d8c7cb7a4efe172e66a907f1a5cfd20c1ab5fc Mon Sep 17 00:00:00 2001
From: maruoheng
Date: Thu, 11 Dec 2025 01:45:16 +0000
Subject: [PATCH 3/5] [XPU] add speculate_get_logits

---
 .../src/ops/mtp/speculate_get_logits.cc       |  77 ++++++++
 custom_ops/xpu_ops/src/ops/pybind/pybind.cc   |  23 +++
 .../xpu_ops/src/plugin/include/xpu/plugin.h   |  13 ++
 .../mtp_kernel/speculate_get_logits.xpu       | 131 +++++++++++++
 .../mtp_wrapper/speculate_get_logits.cpp      | 176 ++++++++++++++++++
 .../xpu_ops/test/test_speculate_get_logits.py | 172 +++++++++++++++++
 6 files changed, 592 insertions(+)
 create mode 100644 custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
 create mode 100644 custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_logits.xpu
 create mode 100644 custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp
 create mode 100644 custom_ops/xpu_ops/test/test_speculate_get_logits.py

diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
new file mode 100644
index 00000000000..106ac758d9d
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
@@ -0,0 +1,77 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
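+
+// Packs per-request logits into draft_logits for speculative decoding.
+// For a request still in prefill (seq_lens_encoder > 0) two rows are
+// written: its first_token_logits row followed by one row of logits; for
+// a decode request, seq_lens_this_time rows of logits are copied through
+// unchanged. batch_token_num receives the per-request row counts and
+// cu_batch_token_offset their exclusive prefix sum.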
+
+#include
+#include
+#include "paddle/common/flags.h"
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "xpu/internal/infra_op.h"
+#include "xpu/plugin.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+void SpeculateGetLogits(const paddle::Tensor& draft_logits,
+                        const paddle::Tensor& next_token_num,
+                        const paddle::Tensor& batch_token_num,
+                        const paddle::Tensor& cu_next_token_offset,
+                        const paddle::Tensor& cu_batch_token_offset,
+                        const paddle::Tensor& logits,
+                        const paddle::Tensor& first_token_logits,
+                        const paddle::Tensor& seq_lens_this_time,
+                        const paddle::Tensor& seq_lens_encoder) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  baidu::xpu::api::Context* ctx =
+      static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
+  if (draft_logits.is_cpu()) {
+    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
+  }
+  const int vocab_size = logits.shape()[1];
+  const int real_bsz = seq_lens_this_time.shape()[0];
+
+  baidu::xpu::api::plugin::speculate_get_logits(
+      ctx,
+      const_cast<float*>(draft_logits.data<float>()),
+      const_cast<int*>(next_token_num.data<int>()),
+      const_cast<int*>(batch_token_num.data<int>()),
+      const_cast<int*>(cu_next_token_offset.data<int>()),
+      const_cast<int*>(cu_batch_token_offset.data<int>()),
+      logits.data<float>(),
+      first_token_logits.data<float>(),
+      seq_lens_this_time.data<int>(),
+      seq_lens_encoder.data<int>(),
+      real_bsz,
+      vocab_size);
+}
+
+PD_BUILD_STATIC_OP(speculate_get_logits)
+    .Inputs({"draft_logits",
+             "next_token_num",
+             "batch_token_num",
+             "cu_next_token_offset",
+             "cu_batch_token_offset",
+             "logits",
+             "first_token_logits",
+             "seq_lens_this_time",
+             "seq_lens_encoder"})
+    .Outputs({"draft_logits_out",
+              "batch_token_num_out",
+              "cu_batch_token_offset_out"})
+    .SetInplaceMap({{"draft_logits", "draft_logits_out"},
+                    {"batch_token_num", "batch_token_num_out"},
+                    {"cu_batch_token_offset", "cu_batch_token_offset_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateGetLogits));
diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
index 0400aa02d7d..dbe62ee1cba 100644
--- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
+++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
     const int encoder_decoder_block_num,
     const int max_draft_tokens);
 
+void SpeculateGetLogits(const paddle::Tensor& draft_logits,
+                        const paddle::Tensor& next_token_num,
+                        const paddle::Tensor& batch_token_num,
+                        const paddle::Tensor& cu_next_token_offset,
+                        const paddle::Tensor& cu_batch_token_offset,
+                        const paddle::Tensor& logits,
+                        const paddle::Tensor& first_token_logits,
+                        const paddle::Tensor& seq_lens_this_time,
+                        const paddle::Tensor& seq_lens_encoder);
+
 void SaveOutMmsgStatic(const paddle::Tensor& x,
                        const paddle::Tensor& not_need_stop,
                        int64_t rank_id,
@@ -1174,6 +1184,19 @@
           py::arg("max_draft_tokens"),
           "Step paddle function");
 
+    m.def("speculate_get_logits",
+          &SpeculateGetLogits,
+          py::arg("draft_logits"),
+          py::arg("next_token_num"),
+          py::arg("batch_token_num"),
+          py::arg("cu_next_token_offset"),
+          py::arg("cu_batch_token_offset"),
+          py::arg("logits"),
+          py::arg("first_token_logits"),
+          py::arg("seq_lens_this_time"),
+          py::arg("seq_lens_encoder"),
+          "speculate get logits function");
+
     m.def("text_image_gather_scatter",
           &TextImageGatherScatter,
           py::arg("input"),
diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h index bc27a54a94a..a00fa46ec13 100644 --- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h +++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h @@ -600,6 +600,19 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx, T* output, int dim_embed, int elem_cnt); + +DLL_EXPORT int speculate_get_logits(Context* ctx, + float* draft_logits, + int* next_token_num, + int* batch_token_num, + int* cu_next_token_offset, + int* cu_batch_token_offset, + const float* logits, + const float* first_token_logits, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int real_bsz, + const int vocab_size); /*--------------------------------------- MTP end * --------------------------------------------*/ diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_logits.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_logits.xpu new file mode 100644 index 00000000000..9174b444c69 --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_logits.xpu @@ -0,0 +1,131 @@ +#include "xpu/kernel/cluster.h" +#include "xpu/kernel/cluster_partition.h" +#include "xpu/kernel/cluster_primitive.h" + +namespace xpu3 { +namespace plugin { + +__device__ void prefix_sum(__shared_ptr__ int* sm_seq_lens_encoder, + __shared_ptr__ int* sm_seq_lens_this_time, + __shared_ptr__ int* sm_batch_token_num, + __shared_ptr__ int* sm_cu_batch_token_offset, + __shared_ptr__ int* sm_cu_next_token_offset, + __global_ptr__ int* batch_token_num, + __global_ptr__ int* cu_batch_token_offset, + __global_ptr__ const int* seq_lens_this_time, + __global_ptr__ const int* seq_lens_encoder, + const int real_bsz) { + int cid = core_id(); + int clus_id = cluster_id(); + + if (clus_id < real_bsz && cid == 0) { + GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int)); + GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int)); + int next_token_num_previous = 0; + for (int bid = 0; bid < real_bsz; bid++) { + sm_batch_token_num[bid] = + sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid]; + if (bid == 0) { + sm_cu_batch_token_offset[bid] = 0; + sm_cu_next_token_offset[bid] = 0; + } else { + sm_cu_batch_token_offset[bid] = + sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1]; + sm_cu_next_token_offset[bid] = + sm_cu_next_token_offset[bid - 1] + next_token_num_previous; + } + next_token_num_previous = + sm_seq_lens_encoder[bid] > 0 ? 
1 : sm_seq_lens_this_time[bid]; + } + mfence_sm(); + if (clus_id == 0) { + SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int)); + SM2GM_ASYNC(sm_cu_batch_token_offset, + cu_batch_token_offset, + real_bsz * sizeof(int)); + } + } + mfence_sm(); + sync_all(); +} +__global__ void speculate_get_logits(float* draft_logits, + int* next_token_num, + int* batch_token_num, + int* cu_next_token_offset, + int* cu_batch_token_offset, + const float* logits, + const float* first_token_logits, + const int* seq_lens_this_time, + const int* seq_lens_encoder, + const int real_bsz, + const int vocab_size) { + int cid = core_id(); + int ncores = core_num(); + int clus_id = cluster_id(); + int nclusters = cluster_num(); + + int lm_size = 2 * 1024; + int lm_buf_len = lm_size / sizeof(float); + float first_token_logits_now_lm[lm_buf_len]; + float logits_now_lm[lm_buf_len]; + + const int sm_size = 256 * 1024; + __shared__ char sm[sm_size]; + int sm_max_buf_len = 256 * 1024 / sizeof(int); + sm_max_buf_len /= 5; + __shared_ptr__ int* sm_seq_lens_encoder = (__shared_ptr__ int*)sm; + __shared_ptr__ int* sm_seq_lens_this_time = + sm_seq_lens_encoder + sm_max_buf_len; + __shared_ptr__ int* sm_batch_token_num = + sm_seq_lens_this_time + sm_max_buf_len; + __shared_ptr__ int* sm_cu_batch_token_offset = + sm_batch_token_num + sm_max_buf_len; + __shared_ptr__ int* sm_cu_next_token_offset = + sm_cu_batch_token_offset + sm_max_buf_len; + + prefix_sum(sm_seq_lens_encoder, + sm_seq_lens_this_time, + sm_batch_token_num, + sm_cu_batch_token_offset, + sm_cu_next_token_offset, + batch_token_num, + cu_batch_token_offset, + seq_lens_this_time, + seq_lens_encoder, + real_bsz); + + for (int bid = clus_id; bid < real_bsz; bid += nclusters) { + auto* draft_logits_now = + draft_logits + sm_cu_batch_token_offset[bid] * vocab_size; + auto* logits_now = logits + sm_cu_next_token_offset[bid] * vocab_size; + auto* first_token_logits_now = first_token_logits + bid * vocab_size; + + for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len) { + int read_len = min(lm_buf_len, vocab_size - i); + if (sm_seq_lens_encoder[bid] > 0) { + GM2LM_ASYNC(first_token_logits_now + i, + first_token_logits_now_lm, + read_len * sizeof(float)); + GM2LM(logits_now + i, logits_now_lm, read_len * sizeof(float)); + LM2GM_ASYNC(first_token_logits_now_lm, + draft_logits_now + i, + read_len * sizeof(float)); + LM2GM(logits_now_lm, + draft_logits_now + vocab_size + i, + read_len * sizeof(float)); + } else { + for (int j = 0; j < sm_seq_lens_this_time[bid]; j++) { + GM2LM(logits_now + j * vocab_size + i, + logits_now_lm, + read_len * sizeof(float)); + LM2GM(logits_now_lm, + draft_logits_now + j * vocab_size + i, + read_len * sizeof(float)); + } + } + } + } +} + +} // namespace plugin +} // namespace xpu3 diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp new file mode 100644 index 00000000000..87eea101cbd --- /dev/null +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp @@ -0,0 +1,176 @@ +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+
+__attribute__((global)) void speculate_get_logits(
+    float* draft_logits,
+    int* next_token_num,
+    int* batch_token_num,
+    int* cu_next_token_offset,
+    int* cu_batch_token_offset,
+    const float* logits,
+    const float* first_token_logits,
+    const int* seq_lens_this_time,
+    const int* seq_lens_encoder,
+    const int real_bsz,
+    const int vocab_size);
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+static int cpu_wrapper(float* draft_logits,
+                       int* next_token_num,
+                       int* batch_token_num,
+                       int* cu_next_token_offset,
+                       int* cu_batch_token_offset,
+                       const float* logits,
+                       const float* first_token_logits,
+                       const int* seq_lens_this_time,
+                       const int* seq_lens_encoder,
+                       const int real_bsz,
+                       const int vocab_size) {
+  int batch_token_num_sum = 0;
+  int next_token_num_sum = 0;
+  for (int bid = 0; bid < real_bsz; bid++) {
+    // prefix sum
+    cu_batch_token_offset[bid] = batch_token_num_sum;
+    cu_next_token_offset[bid] = next_token_num_sum;
+
+    batch_token_num[bid] =
+        seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid];
+    next_token_num[bid] =
+        seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid];
+
+    batch_token_num_sum += batch_token_num[bid];
+    next_token_num_sum += next_token_num[bid];
+
+    auto* draft_logits_now =
+        draft_logits + cu_batch_token_offset[bid] * vocab_size;
+    auto* logits_now = logits + cu_next_token_offset[bid] * vocab_size;
+    auto* first_token_logits_now = first_token_logits + bid * vocab_size;
+    for (int i = 0; i < vocab_size; i++) {
+      if (seq_lens_encoder[bid] > 0) {
+        draft_logits_now[i] = first_token_logits_now[i];
+        draft_logits_now[vocab_size + i] = logits_now[i];
+      } else {
+        for (int j = 0; j < seq_lens_this_time[bid]; j++) {
+          draft_logits_now[j * vocab_size + i] = logits_now[j * vocab_size + i];
+        }
+      }
+    }
+  }
+  return api::SUCCESS;
+}
+
+static int xpu3_wrapper(Context* ctx,
+                        float* draft_logits,
+                        int* next_token_num,
+                        int* batch_token_num,
+                        int* cu_next_token_offset,
+                        int* cu_batch_token_offset,
+                        const float* logits,
+                        const float* first_token_logits,
+                        const int* seq_lens_this_time,
+                        const int* seq_lens_encoder,
+                        const int real_bsz,
+                        const int vocab_size) {
+  xpu3::plugin::speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+      draft_logits,
+      next_token_num,
+      batch_token_num,
+      cu_next_token_offset,
+      cu_batch_token_offset,
+      logits,
+      first_token_logits,
+      seq_lens_this_time,
+      seq_lens_encoder,
+      real_bsz,
+      vocab_size);
+  return api::SUCCESS;
+}
+
+int speculate_get_logits(Context* ctx,
+                         float* draft_logits,
+                         int* next_token_num,
+                         int* batch_token_num,
+                         int* cu_next_token_offset,
+                         int* cu_batch_token_offset,
+                         const float* logits,
+                         const float* first_token_logits,
+                         const int* seq_lens_this_time,
+                         const int* seq_lens_encoder,
+                         const int real_bsz,
+                         const int vocab_size) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_logits", float);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      draft_logits,
+                      next_token_num,
+                      batch_token_num,
+                      cu_next_token_offset,
+                      cu_batch_token_offset,
+                      logits);
+  WRAPPER_DUMP_PARAM5(ctx,
+                      first_token_logits,
+                      seq_lens_this_time,
+                      seq_lens_encoder,
+                      real_bsz,
+                      vocab_size);
+  WRAPPER_DUMP(ctx);
+  WRAPPER_ASSERT_LE(ctx, real_bsz, 256 * 1024 / sizeof(int) / 5);
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper(draft_logits,
+                       next_token_num,
+                       batch_token_num,
+                       cu_next_token_offset,
+                       cu_batch_token_offset,
+                       logits,
+                       first_token_logits,
+                       seq_lens_this_time,
+                       seq_lens_encoder,
+                       real_bsz,
+                       vocab_size);
+  }
+  if (ctx->dev().type() == api::kXPU3) {
+    return xpu3_wrapper(ctx,
+                        draft_logits,
+                        next_token_num,
+                        batch_token_num,
+                        cu_next_token_offset,
+                        cu_batch_token_offset,
+                        logits,
+                        first_token_logits,
+                        seq_lens_this_time,
+                        seq_lens_encoder,
+                        real_bsz,
+                        vocab_size);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/custom_ops/xpu_ops/test/test_speculate_get_logits.py b/custom_ops/xpu_ops/test/test_speculate_get_logits.py
new file mode 100644
index 00000000000..07a207aa701
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_speculate_get_logits.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
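+
+# This test runs speculate_get_logits on identical inputs placed on CPU and
+# on XPU, then checks that draft_logits, batch_token_num and
+# cu_batch_token_offset agree across the two devices.
+#
+# Worked example of the packing: with seq_lens_encoder = [1, 0] and
+# seq_lens_this_time = [5, 3], batch_token_num = [2, 3] and
+# cu_batch_token_offset = [0, 2]; draft_logits then holds five rows:
+# first_token_logits[0], logits[0], followed by logits[1:4] copied through.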
+
+import unittest
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import speculate_get_logits
+
+# Fix the random seeds so the test is reproducible
+np.random.seed(2023)
+paddle.seed(2023)
+
+
+def generate_test_data():
+    """
+    Helper function that generates test data.
+    This logic was converted from a pytest fixture into a plain function
+    called by the test methods.
+    """
+    real_bsz = 64
+    vocab_size = 2 * 1024
+    max_seq_len = 8 * 1024
+
+    # Generate the raw test data (reusing the original logic unchanged)
+    seq_lens_encoder = np.random.randint(0, 2, [real_bsz], dtype=np.int32)
+    seq_lens_this_time = np.random.randint(1, max_seq_len, [real_bsz], dtype=np.int32)
+    draft_logits_seqlen = 0
+    logits_seqlen = 0
+    for i in range(real_bsz):
+        if seq_lens_encoder[i] > 0:
+            draft_logits_seqlen += 2
+            logits_seqlen += 1
+        else:
+            draft_logits_seqlen += seq_lens_this_time[i]
+            logits_seqlen += seq_lens_this_time[i]
+
+    draft_logits = np.zeros([draft_logits_seqlen, vocab_size], dtype=np.float32)
+    next_token_num = np.zeros([real_bsz], dtype=np.int32)
+    batch_token_num = np.zeros([real_bsz], dtype=np.int32)
+    cu_next_token_offset = np.zeros([real_bsz], dtype=np.int32)
+    cu_batch_token_offset = np.zeros([real_bsz], dtype=np.int32)
+    logits = np.random.rand(logits_seqlen, vocab_size).astype(np.float32)
+    first_token_logits = np.random.rand(real_bsz, vocab_size).astype(np.float32)
+
+    paddle.set_device("cpu")
+    # Convert to paddle tensors (keeping the original logic)
+    data_cpu = {
+        "draft_logits": paddle.to_tensor(draft_logits),
+        "next_token_num": paddle.to_tensor(next_token_num),
+        "batch_token_num": paddle.to_tensor(batch_token_num),
+        "cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
+        "cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
+        "logits": paddle.to_tensor(logits),
+        "first_token_logits": paddle.to_tensor(first_token_logits),
+        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
+        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
+    }
+
+    paddle.set_device("xpu:0")
+    data_xpu = {
+        "draft_logits": paddle.to_tensor(draft_logits),
+        "next_token_num": paddle.to_tensor(next_token_num),
+        "batch_token_num": paddle.to_tensor(batch_token_num),
+        "cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
+        "cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
+        "logits": paddle.to_tensor(logits),
+        "first_token_logits": paddle.to_tensor(first_token_logits),
+        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
+        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
+    }
+
+    # Restore the default device so other tests are unaffected
+    paddle.set_device("cpu")
+
+    return data_cpu, data_xpu
+
+
+def speculate_get_logits_execution(test_data):
+    """Test that the function executes and its outputs are reasonable"""
+
+    # Run the target op (the core test step)
+    speculate_get_logits(**test_data)
+
+    return test_data
+
+
+class TestSpeculateGetLogits(unittest.TestCase):
+    """
+    Test class derived from unittest.TestCase.
+    Every method whose name starts with 'test_' is treated as a test case.
+    """
+
+    def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08, target_keys=None):
+        """
+        Custom assertion that compares the structure and data of two test_data dicts.
+        In unittest, custom assertions conventionally start with 'assert'.
+        """
+        # 1. First check that both test_data dicts have exactly the same keys
+        keys1 = set(test_data1.keys())
+        keys2 = set(test_data2.keys())
+        self.assertEqual(
+            keys1,
+            keys2,
+            msg=f"test_data keys differ!\nOnly in the first: {keys1 - keys2}\nOnly in the second: {keys2 - keys1}",
+        )
+
+        # 2. Check each field's data
+        if target_keys is not None and isinstance(target_keys, list):
+            local_target_key = target_keys
+        else:
+            local_target_key = keys1
+        for key in local_target_key:
+            data1 = test_data1[key]
+            data2 = test_data2[key]
+
+            # Distinguish paddle Tensors (convert to numpy) from plain scalars/arrays (use directly)
+            if isinstance(data1, paddle.Tensor):
+                np1 = data1.detach().cpu().numpy()
+            else:
+                np1 = np.asarray(data1)
+
+            if isinstance(data2, paddle.Tensor):
+                np2 = data2.detach().cpu().numpy()
+            else:
+                np2 = np.asarray(data2)
+
+            # 3. Compare the data
+            if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8):
+                # Boolean/integer: must match exactly
+                np.testing.assert_array_equal(np1, np2, err_msg=f"Field {key} data mismatch!")
+            else:
+                # Floating point: allow error within rtol/atol
+                np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"Field {key} float data mismatch!")
+
+        print("✅ Both test_data dicts match in structure and data!")
+
+    def test_speculate_get_logits(self):
+        """
+        Core test case.
+        It obtains data via generate_test_data,
+        runs the function under test on CPU and on XPU,
+        and compares the results with the custom assertion.
+        """
+        print("\nRunning test: test_speculate_get_logits")
+
+        # 1. Get the test data
+        data_cpu, data_xpu = generate_test_data()
+
+        # 2. Run the function under test
+        result_xpu = speculate_get_logits_execution(data_xpu)
+        result_cpu = speculate_get_logits_execution(data_cpu)
+
+        # 3. Assert the results match
+        target_keys = ["draft_logits", "batch_token_num", "cu_batch_token_offset"]
+        self.assert_test_data_equal(result_cpu, result_xpu, target_keys=target_keys)
+
+
+if __name__ == "__main__":
+    # Run all test cases via the unittest main program
+    unittest.main()

From c8fccdcf12481652d0c9c5eb34926b8f0a5946ee Mon Sep 17 00:00:00 2001
From: maruoheng
Date: Thu, 11 Dec 2025 02:11:38 +0000
Subject: [PATCH 4/5] delete context

---
 custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
index 106ac758d9d..b0608b6e023 100644
--- a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
@@ -56,6 +56,9 @@ void SpeculateGetLogits(const paddle::Tensor& draft_logits,
       seq_lens_encoder.data<int>(),
       real_bsz,
       vocab_size);
+  if (draft_logits.is_cpu()) {
+    delete ctx;
+  }
 }
 
 PD_BUILD_STATIC_OP(speculate_get_logits)

From 8ad13e704d9c9048a21030c646971fbcb8c76cbc Mon Sep 17 00:00:00 2001
From: maruoheng
Date: Thu, 11 Dec 2025 06:44:17 +0000
Subject: [PATCH 5/5] add ptr check

---
 .../src/wrapper/mtp_wrapper/speculate_get_logits.cpp | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp
index 87eea101cbd..0f96ea2e621 100644
--- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp
@@ -139,7 +139,15 @@ int speculate_get_logits(Context* ctx,
                       real_bsz,
                       vocab_size);
   WRAPPER_DUMP(ctx);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, next_token_num);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, batch_token_num);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_next_token_offset);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_batch_token_offset);
+  WRAPPER_CHECK_PTR(ctx, float, real_bsz* vocab_size, first_token_logits);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
   WRAPPER_ASSERT_LE(ctx, real_bsz, 256 * 1024 / sizeof(int) / 5);
+  WRAPPER_ASSERT_GT(ctx, vocab_size, 0);
   if (ctx->dev().type() ==
api::kCPU) { return cpu_wrapper(draft_logits, next_token_num,