diff --git a/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
new file mode 100644
index 00000000000..b0608b6e023
--- /dev/null
+++ b/custom_ops/xpu_ops/src/ops/mtp/speculate_get_logits.cc
@@ -0,0 +1,80 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include "paddle/common/flags.h"
+#include "paddle/extension.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
+#include "xpu/internal/infra_op.h"
+#include "xpu/plugin.h"
+
+#ifndef PD_BUILD_STATIC_OP
+#define PD_BUILD_STATIC_OP(name) PD_BUILD_OP(static_op_##name)
+#endif
+
+void SpeculateGetLogits(const paddle::Tensor& draft_logits,
+                        const paddle::Tensor& next_token_num,
+                        const paddle::Tensor& batch_token_num,
+                        const paddle::Tensor& cu_next_token_offset,
+                        const paddle::Tensor& cu_batch_token_offset,
+                        const paddle::Tensor& logits,
+                        const paddle::Tensor& first_token_logits,
+                        const paddle::Tensor& seq_lens_this_time,
+                        const paddle::Tensor& seq_lens_encoder) {
+  phi::XPUPlace place(phi::backends::xpu::GetXPUCurrentDeviceId());
+  auto dev_ctx = paddle::experimental::DeviceContextPool::Instance().Get(place);
+  baidu::xpu::api::Context* ctx =
+      static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
+  if (draft_logits.is_cpu()) {
+    ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
+  }
+  const int vocab_size = logits.shape()[1];
+  const int real_bsz = seq_lens_this_time.shape()[0];
+
+  baidu::xpu::api::plugin::speculate_get_logits(
+      ctx,
+      const_cast<float*>(draft_logits.data<float>()),
+      const_cast<int*>(next_token_num.data<int>()),
+      const_cast<int*>(batch_token_num.data<int>()),
+      const_cast<int*>(cu_next_token_offset.data<int>()),
+      const_cast<int*>(cu_batch_token_offset.data<int>()),
+      logits.data<float>(),
+      first_token_logits.data<float>(),
+      seq_lens_this_time.data<int>(),
+      seq_lens_encoder.data<int>(),
+      real_bsz,
+      vocab_size);
+  if (draft_logits.is_cpu()) {
+    delete ctx;
+  }
+}
+
+PD_BUILD_STATIC_OP(speculate_get_logits)
+    .Inputs({"draft_logits",
+             "next_token_num",
+             "batch_token_num",
+             "cu_next_token_offset",
+             "cu_batch_token_offset",
+             "logits",
+             "first_token_logits",
+             "seq_lens_this_time",
+             "seq_lens_encoder"})
+    .Outputs({"draft_logits_out",
+              "batch_token_num_out",
+              "cu_batch_token_offset_out"})
+    .SetInplaceMap({{"draft_logits", "draft_logits_out"},
+                    {"batch_token_num", "batch_token_num_out"},
+                    {"cu_batch_token_offset", "cu_batch_token_offset_out"}})
+    .SetKernelFn(PD_KERNEL(SpeculateGetLogits));
diff --git a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
index 0400aa02d7d..dbe62ee1cba 100644
--- a/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
+++ b/custom_ops/xpu_ops/src/ops/pybind/pybind.cc
@@ -470,6 +470,16 @@ void SpeculateStepPaddle(
     const int encoder_decoder_block_num,
     const int max_draft_tokens);
 
+void SpeculateGetLogits(const paddle::Tensor& draft_logits,
+                        const paddle::Tensor& next_token_num,
+                        const paddle::Tensor& batch_token_num,
+                        const paddle::Tensor& cu_next_token_offset,
+                        const paddle::Tensor& cu_batch_token_offset,
+                        const paddle::Tensor& logits,
+                        const paddle::Tensor& first_token_logits,
+                        const paddle::Tensor& seq_lens_this_time,
+                        const paddle::Tensor& seq_lens_encoder);
+
 void SaveOutMmsgStatic(const paddle::Tensor& x,
                        const paddle::Tensor& not_need_stop,
                        int64_t rank_id,
@@ -1174,6 +1184,19 @@ PYBIND11_MODULE(fastdeploy_ops, m) {
         py::arg("max_draft_tokens"),
         "Step paddle function");
 
+  m.def("speculate_get_logits",
+        &SpeculateGetLogits,
+        py::arg("draft_logits"),
+        py::arg("next_token_num"),
+        py::arg("batch_token_num"),
+        py::arg("cu_next_token_offset"),
+        py::arg("cu_batch_token_offset"),
+        py::arg("logits"),
+        py::arg("first_token_logits"),
+        py::arg("seq_lens_this_time"),
+        py::arg("seq_lens_encoder"),
+        "speculate get logits function");
+
   m.def("text_image_gather_scatter",
         &TextImageGatherScatter,
         py::arg("input"),
diff --git a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
index bc27a54a94a..a00fa46ec13 100644
--- a/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
+++ b/custom_ops/xpu_ops/src/plugin/include/xpu/plugin.h
@@ -600,6 +600,19 @@ DLL_EXPORT int rebuild_self_hidden_states(api::Context* ctx,
                                           T* output,
                                           int dim_embed,
                                           int elem_cnt);
+
+DLL_EXPORT int speculate_get_logits(Context* ctx,
+                                    float* draft_logits,
+                                    int* next_token_num,
+                                    int* batch_token_num,
+                                    int* cu_next_token_offset,
+                                    int* cu_batch_token_offset,
+                                    const float* logits,
+                                    const float* first_token_logits,
+                                    const int* seq_lens_this_time,
+                                    const int* seq_lens_encoder,
+                                    const int real_bsz,
+                                    const int vocab_size);
 /*--------------------------------------- MTP end
  * --------------------------------------------*/
 
diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_logits.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_logits.xpu
new file mode 100644
index 00000000000..9174b444c69
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/speculate_get_logits.xpu
@@ -0,0 +1,131 @@
+#include "xpu/kernel/cluster.h"
+#include "xpu/kernel/cluster_partition.h"
+#include "xpu/kernel/cluster_primitive.h"
+
+namespace xpu3 {
+namespace plugin {
+
+// Computes per-batch token counts and their exclusive prefix sums in shared
+// memory. Core 0 of every participating cluster keeps a private SM copy;
+// cluster 0 additionally writes batch_token_num / cu_batch_token_offset back
+// to global memory.
+__device__ void prefix_sum(__shared_ptr__ int* sm_seq_lens_encoder,
+                           __shared_ptr__ int* sm_seq_lens_this_time,
+                           __shared_ptr__ int* sm_batch_token_num,
+                           __shared_ptr__ int* sm_cu_batch_token_offset,
+                           __shared_ptr__ int* sm_cu_next_token_offset,
+                           __global_ptr__ int* batch_token_num,
+                           __global_ptr__ int* cu_batch_token_offset,
+                           __global_ptr__ const int* seq_lens_this_time,
+                           __global_ptr__ const int* seq_lens_encoder,
+                           const int real_bsz) {
+  int cid = core_id();
+  int clus_id = cluster_id();
+
+  if (clus_id < real_bsz && cid == 0) {
+    GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
+    GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
+    int next_token_num_previous = 0;
+    for (int bid = 0; bid < real_bsz; bid++) {
+      // A request still in the encoder (prefill) phase contributes two output
+      // rows; a decode request contributes seq_lens_this_time rows.
+      sm_batch_token_num[bid] =
+          sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
+      if (bid == 0) {
+        sm_cu_batch_token_offset[bid] = 0;
+        sm_cu_next_token_offset[bid] = 0;
+      } else {
+        sm_cu_batch_token_offset[bid] =
+            sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
+        sm_cu_next_token_offset[bid] =
+            sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
+      }
+      next_token_num_previous =
+          sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
+    }
+    mfence_sm();
+    if (clus_id == 0) {
+      SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
+      SM2GM_ASYNC(sm_cu_batch_token_offset,
+                  cu_batch_token_offset,
+                  real_bsz * sizeof(int));
+    }
+  }
+  mfence_sm();
+  sync_all();
+}
+
+__global__ void speculate_get_logits(float* draft_logits,
+                                     int* next_token_num,
+                                     int* batch_token_num,
+                                     int* cu_next_token_offset,
+                                     int* cu_batch_token_offset,
+                                     const float* logits,
+                                     const float* first_token_logits,
+                                     const int* seq_lens_this_time,
+                                     const int* seq_lens_encoder,
+                                     const int real_bsz,
+                                     const int vocab_size) {
+  int cid = core_id();
+  int ncores = core_num();
+  int clus_id = cluster_id();
+  int nclusters = cluster_num();
+
+  // Local-memory staging buffers: 2 KB each, i.e. 512 floats.
+  const int lm_size = 2 * 1024;
+  const int lm_buf_len = lm_size / sizeof(float);
+  float first_token_logits_now_lm[lm_buf_len];
+  float logits_now_lm[lm_buf_len];
+
+  // Carve five equally sized int arrays out of the 256 KB shared memory.
+  const int sm_size = 256 * 1024;
+  __shared__ char sm[sm_size];
+  int sm_max_buf_len = sm_size / sizeof(int);
+  sm_max_buf_len /= 5;
+  __shared_ptr__ int* sm_seq_lens_encoder = (__shared_ptr__ int*)sm;
+  __shared_ptr__ int* sm_seq_lens_this_time =
+      sm_seq_lens_encoder + sm_max_buf_len;
+  __shared_ptr__ int* sm_batch_token_num =
+      sm_seq_lens_this_time + sm_max_buf_len;
+  __shared_ptr__ int* sm_cu_batch_token_offset =
+      sm_batch_token_num + sm_max_buf_len;
+  __shared_ptr__ int* sm_cu_next_token_offset =
+      sm_cu_batch_token_offset + sm_max_buf_len;
+
+  prefix_sum(sm_seq_lens_encoder,
+             sm_seq_lens_this_time,
+             sm_batch_token_num,
+             sm_cu_batch_token_offset,
+             sm_cu_next_token_offset,
+             batch_token_num,
+             cu_batch_token_offset,
+             seq_lens_this_time,
+             seq_lens_encoder,
+             real_bsz);
+
+  // Each cluster handles one batch element at a time; the cores of a cluster
+  // stream vocab_size floats through local memory in lm_buf_len chunks.
+  for (int bid = clus_id; bid < real_bsz; bid += nclusters) {
+    auto* draft_logits_now =
+        draft_logits + sm_cu_batch_token_offset[bid] * vocab_size;
+    auto* logits_now = logits + sm_cu_next_token_offset[bid] * vocab_size;
+    auto* first_token_logits_now = first_token_logits + bid * vocab_size;
+
+    for (int i = cid * lm_buf_len; i < vocab_size; i += ncores * lm_buf_len) {
+      int read_len = min(lm_buf_len, vocab_size - i);
+      if (sm_seq_lens_encoder[bid] > 0) {
+        // Prefill: row 0 <- first_token_logits, row 1 <- logits.
+        GM2LM_ASYNC(first_token_logits_now + i,
+                    first_token_logits_now_lm,
+                    read_len * sizeof(float));
+        GM2LM(logits_now + i, logits_now_lm, read_len * sizeof(float));
+        LM2GM_ASYNC(first_token_logits_now_lm,
+                    draft_logits_now + i,
+                    read_len * sizeof(float));
+        LM2GM(logits_now_lm,
+              draft_logits_now + vocab_size + i,
+              read_len * sizeof(float));
+      } else {
+        // Decode: copy seq_lens_this_time rows of logits through local memory.
+        for (int j = 0; j < sm_seq_lens_this_time[bid]; j++) {
+          GM2LM(logits_now + j * vocab_size + i,
+                logits_now_lm,
+                read_len * sizeof(float));
+          LM2GM(logits_now_lm,
+                draft_logits_now + j * vocab_size + i,
+                read_len * sizeof(float));
+        }
+      }
+    }
+  }
+}
+
+}  // namespace plugin
+}  // namespace xpu3
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp
new file mode 100644
index 00000000000..0f96ea2e621
--- /dev/null
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp
@@ -0,0 +1,184 @@
+// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include
+#include
+#include "xpu/plugin.h"
+#include "xpu/refactor/impl_public/wrapper_check.h"
+
+namespace xpu3 {
+namespace plugin {
+
+__attribute__((global)) void speculate_get_logits(
+    float* draft_logits,
+    int* next_token_num,
+    int* batch_token_num,
+    int* cu_next_token_offset,
+    int* cu_batch_token_offset,
+    const float* logits,
+    const float* first_token_logits,
+    const int* seq_lens_this_time,
+    const int* seq_lens_encoder,
+    const int real_bsz,
+    const int vocab_size);
+
+}  // namespace plugin
+}  // namespace xpu3
+
+namespace baidu {
+namespace xpu {
+namespace api {
+namespace plugin {
+
+static int cpu_wrapper(float* draft_logits,
+                       int* next_token_num,
+                       int* batch_token_num,
+                       int* cu_next_token_offset,
+                       int* cu_batch_token_offset,
+                       const float* logits,
+                       const float* first_token_logits,
+                       const int* seq_lens_this_time,
+                       const int* seq_lens_encoder,
+                       const int real_bsz,
+                       const int vocab_size) {
+  int batch_token_num_sum = 0;
+  int next_token_num_sum = 0;
+  for (int bid = 0; bid < real_bsz; bid++) {
+    // Exclusive prefix sums of the per-batch token counts.
+    cu_batch_token_offset[bid] = batch_token_num_sum;
+    cu_next_token_offset[bid] = next_token_num_sum;
+
+    batch_token_num[bid] =
+        seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid];
+    next_token_num[bid] =
+        seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid];
+
+    batch_token_num_sum += batch_token_num[bid];
+    next_token_num_sum += next_token_num[bid];
+
+    auto* draft_logits_now =
+        draft_logits + cu_batch_token_offset[bid] * vocab_size;
+    auto* logits_now = logits + cu_next_token_offset[bid] * vocab_size;
+    auto* first_token_logits_now = first_token_logits + bid * vocab_size;
+    for (int i = 0; i < vocab_size; i++) {
+      if (seq_lens_encoder[bid] > 0) {
+        draft_logits_now[i] = first_token_logits_now[i];
+        draft_logits_now[vocab_size + i] = logits_now[i];
+      } else {
+        for (int j = 0; j < seq_lens_this_time[bid]; j++) {
+          draft_logits_now[j * vocab_size + i] = logits_now[j * vocab_size + i];
+        }
+      }
+    }
+  }
+  return api::SUCCESS;
+}
+
+static int xpu3_wrapper(Context* ctx,
+                        float* draft_logits,
+                        int* next_token_num,
+                        int* batch_token_num,
+                        int* cu_next_token_offset,
+                        int* cu_batch_token_offset,
+                        const float* logits,
+                        const float* first_token_logits,
+                        const int* seq_lens_this_time,
+                        const int* seq_lens_encoder,
+                        const int real_bsz,
+                        const int vocab_size) {
+  xpu3::plugin::speculate_get_logits<<<ctx->ncluster(), 64, ctx->xpu_stream>>>(
+      draft_logits,
+      next_token_num,
+      batch_token_num,
+      cu_next_token_offset,
+      cu_batch_token_offset,
+      logits,
+      first_token_logits,
+      seq_lens_this_time,
+      seq_lens_encoder,
+      real_bsz,
+      vocab_size);
+  return api::SUCCESS;
+}
+
+int speculate_get_logits(Context* ctx,
+                         float* draft_logits,
+                         int* next_token_num,
+                         int* batch_token_num,
+                         int* cu_next_token_offset,
+                         int* cu_batch_token_offset,
+                         const float* logits,
+                         const float* first_token_logits,
+                         const int* seq_lens_this_time,
+                         const int* seq_lens_encoder,
+                         const int real_bsz,
+                         const int vocab_size) {
+  WRAPPER_CHECK_CTX(ctx);
+  WRAPPER_DUMP_FUNCTION_T1(ctx, "speculate_get_logits", float);
+  WRAPPER_DUMP_PARAM6(ctx,
+                      draft_logits,
+                      next_token_num,
+                      batch_token_num,
+                      cu_next_token_offset,
+                      cu_batch_token_offset,
+                      logits);
+  WRAPPER_DUMP_PARAM5(ctx,
+                      first_token_logits,
+                      seq_lens_this_time,
+                      seq_lens_encoder,
+                      real_bsz,
+                      vocab_size);
+  WRAPPER_DUMP(ctx);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, next_token_num);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, batch_token_num);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_next_token_offset);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_batch_token_offset);
+  WRAPPER_CHECK_PTR(ctx, float, real_bsz * vocab_size, first_token_logits);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
+  WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
+  // The kernel keeps five int arrays of length real_bsz in 256 KB of shared
+  // memory, which bounds the supported batch size.
+  WRAPPER_ASSERT_LE(ctx, real_bsz, 256 * 1024 / sizeof(int) / 5);
+  WRAPPER_ASSERT_GT(ctx, vocab_size, 0);
+  if (ctx->dev().type() == api::kCPU) {
+    return cpu_wrapper(draft_logits,
+                       next_token_num,
+                       batch_token_num,
+                       cu_next_token_offset,
+                       cu_batch_token_offset,
+                       logits,
+                       first_token_logits,
+                       seq_lens_this_time,
+                       seq_lens_encoder,
+                       real_bsz,
+                       vocab_size);
+  }
+  if (ctx->dev().type() == api::kXPU3) {
+    return xpu3_wrapper(ctx,
+                        draft_logits,
+                        next_token_num,
+                        batch_token_num,
+                        cu_next_token_offset,
+                        cu_batch_token_offset,
+                        logits,
+                        first_token_logits,
+                        seq_lens_this_time,
+                        seq_lens_encoder,
+                        real_bsz,
+                        vocab_size);
+  }
+  WRAPPER_UNIMPLEMENTED(ctx);
+}
+
+}  // namespace plugin
+}  // namespace api
+}  // namespace xpu
+}  // namespace baidu
diff --git a/custom_ops/xpu_ops/test/test_speculate_get_logits.py b/custom_ops/xpu_ops/test/test_speculate_get_logits.py
new file mode 100644
index 00000000000..07a207aa701
--- /dev/null
+++ b/custom_ops/xpu_ops/test/test_speculate_get_logits.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+import paddle
+
+from fastdeploy.model_executor.ops.xpu import speculate_get_logits
+
+# Fix the random seeds so the test is reproducible.
+np.random.seed(2023)
+paddle.seed(2023)
+
+
+def generate_test_data():
+    """
+    Helper that builds the test data.
+    Converted from a pytest fixture into a plain function so the test
+    methods can call it directly.
+    """
+    real_bsz = 64
+    vocab_size = 2 * 1024
+    max_seq_len = 8 * 1024
+
+    # Generate the raw test data (reusing the original logic unchanged).
+    seq_lens_encoder = np.random.randint(0, 2, [real_bsz], dtype=np.int32)
+    seq_lens_this_time = np.random.randint(1, max_seq_len, [real_bsz], dtype=np.int32)
+    draft_logits_seqlen = 0
+    logits_seqlen = 0
+    for i in range(real_bsz):
+        if seq_lens_encoder[i] > 0:
+            draft_logits_seqlen += 2
+            logits_seqlen += 1
+        else:
+            draft_logits_seqlen += seq_lens_this_time[i]
+            logits_seqlen += seq_lens_this_time[i]
+
+    draft_logits = np.zeros([draft_logits_seqlen, vocab_size], dtype=np.float32)
+    next_token_num = np.zeros([real_bsz], dtype=np.int32)
+    batch_token_num = np.zeros([real_bsz], dtype=np.int32)
+    cu_next_token_offset = np.zeros([real_bsz], dtype=np.int32)
+    cu_batch_token_offset = np.zeros([real_bsz], dtype=np.int32)
+    logits = np.random.rand(logits_seqlen, vocab_size).astype(np.float32)
+    first_token_logits = np.random.rand(real_bsz, vocab_size).astype(np.float32)
+
+    paddle.set_device("cpu")
+    # Convert to paddle tensors (same logic as before).
+    data_cpu = {
+        "draft_logits": paddle.to_tensor(draft_logits),
+        "next_token_num": paddle.to_tensor(next_token_num),
+        "batch_token_num": paddle.to_tensor(batch_token_num),
+        "cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
+        "cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
+        "logits": paddle.to_tensor(logits),
+        "first_token_logits": paddle.to_tensor(first_token_logits),
+        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
+        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
+    }
+
+    paddle.set_device("xpu:0")
+    data_xpu = {
+        "draft_logits": paddle.to_tensor(draft_logits),
+        "next_token_num": paddle.to_tensor(next_token_num),
+        "batch_token_num": paddle.to_tensor(batch_token_num),
+        "cu_next_token_offset": paddle.to_tensor(cu_next_token_offset),
+        "cu_batch_token_offset": paddle.to_tensor(cu_batch_token_offset),
+        "logits": paddle.to_tensor(logits),
+        "first_token_logits": paddle.to_tensor(first_token_logits),
+        "seq_lens_this_time": paddle.to_tensor(seq_lens_this_time),
+        "seq_lens_encoder": paddle.to_tensor(seq_lens_encoder),
+    }
+
+    # Restore the default device so other tests are unaffected.
+    paddle.set_device("cpu")
+
+    return data_cpu, data_xpu
+
+
+def speculate_get_logits_execution(test_data):
+    """Run the op under test and return its (in-place updated) tensors."""
+
+    # Invoke the target op (the core step of the test).
+    speculate_get_logits(**test_data)
+
+    return test_data
+
+
+class TestSpeculateGetLogits(unittest.TestCase):
+    """
+    Test class based on unittest.TestCase.
+    Every method whose name starts with 'test_' is run as a test case.
+    """
+
+    def assert_test_data_equal(self, test_data1, test_data2, rtol=1e-05, atol=1e-08, target_keys=None):
+        """
+        Custom assertion comparing the structure and contents of two
+        test_data dicts. In unittest, custom assertions conventionally
+        start with 'assert'.
+        """
+        # 1. Check that both test_data dicts have exactly the same keys.
+        keys1 = set(test_data1.keys())
+        keys2 = set(test_data2.keys())
+        self.assertEqual(
+            keys1,
+            keys2,
+            msg=f"test_data keys differ!\nOnly in the first: {keys1 - keys2}\nOnly in the second: {keys2 - keys1}",
+        )
+
+        # 2. Compare the data field by field.
+        if target_keys is not None and isinstance(target_keys, list):
+            local_target_keys = target_keys
+        else:
+            local_target_keys = keys1
+        for key in local_target_keys:
+            data1 = test_data1[key]
+            data2 = test_data2[key]
+
+            # Distinguish paddle Tensors (convert to numpy) from plain
+            # scalars/arrays (use directly).
+            if isinstance(data1, paddle.Tensor):
+                np1 = data1.detach().cpu().numpy()
+            else:
+                np1 = np.asarray(data1)
+
+            if isinstance(data2, paddle.Tensor):
+                np2 = data2.detach().cpu().numpy()
+            else:
+                np2 = np.asarray(data2)
+
+            # 3. Compare the values.
+            if np1.dtype in (np.bool_, np.int8, np.int16, np.int32, np.int64, np.uint8):
+                # Boolean/integer data must match exactly.
+                np.testing.assert_array_equal(np1, np2, err_msg=f"Field {key} differs!")
+            else:
+                # Floating-point data may differ within rtol/atol.
+                np.testing.assert_allclose(np1, np2, rtol=rtol, atol=atol, err_msg=f"Float field {key} differs!")
+
+        print("✅ Both test_data dicts match in structure and values!")
+
+    def test_speculate_get_logits(self):
+        """
+        Core test case: build the data via generate_test_data, run the op
+        on both CPU and XPU, and compare the results with the custom
+        assertion.
+        """
+        print("\nRunning test: test_speculate_get_logits")
+
+        # 1. Build the test data.
+        data_cpu, data_xpu = generate_test_data()
+
+        # 2. Run the op under test.
+        result_xpu = speculate_get_logits_execution(data_xpu)
+        result_cpu = speculate_get_logits_execution(data_cpu)
+
+        # 3. Assert the results match.
+        target_keys = ["draft_logits", "batch_token_num", "cu_batch_token_offset"]
+        self.assert_test_data_equal(result_cpu, result_xpu, target_keys=target_keys)
+
+
+if __name__ == "__main__":
+    # Run all test cases via unittest's main entry point.
+    unittest.main()
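
Note for reviewers: the op's semantics are easiest to read off the cpu_wrapper reference above. Below is a minimal pure-NumPy sketch of the same layout logic; the function name speculate_get_logits_ref and its return convention are illustrative only and not part of the op's API.

    import numpy as np

    def speculate_get_logits_ref(logits, first_token_logits,
                                 seq_lens_this_time, seq_lens_encoder):
        """NumPy mirror of cpu_wrapper: returns (draft_logits,
        batch_token_num, cu_batch_token_offset)."""
        real_bsz, vocab_size = first_token_logits.shape
        # Prefill requests (seq_lens_encoder > 0) contribute two draft rows
        # and consume one logits row; decode requests copy
        # seq_lens_this_time rows of logits.
        batch_token_num = np.where(seq_lens_encoder > 0, 2, seq_lens_this_time)
        next_token_num = np.where(seq_lens_encoder > 0, 1, seq_lens_this_time)
        cu_batch = np.concatenate(([0], np.cumsum(batch_token_num)[:-1]))
        cu_next = np.concatenate(([0], np.cumsum(next_token_num)[:-1]))
        draft_logits = np.zeros((int(batch_token_num.sum()), vocab_size), np.float32)
        for bid in range(real_bsz):
            dst, src = cu_batch[bid], cu_next[bid]
            if seq_lens_encoder[bid] > 0:
                draft_logits[dst] = first_token_logits[bid]  # first token row
                draft_logits[dst + 1] = logits[src]          # draft token row
            else:
                n = seq_lens_this_time[bid]
                draft_logits[dst:dst + n] = logits[src:src + n]
        return (draft_logits,
                batch_token_num.astype(np.int32),
                cu_batch.astype(np.int32))

This matches the three outputs listed in the op's SetInplaceMap, which is why the unit test only compares draft_logits, batch_token_num, and cu_batch_token_offset.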