From 0e0f8db3239f28ae7092db9ef94fc54a30d5fa87 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Thu, 11 Dec 2025 10:29:56 +0000 Subject: [PATCH 1/3] fix git hidden states --- .../kunlun3cpp/mtp_kernel/compute_order.xpu | 35 ++- .../src/wrapper/mtp_wrapper/compute_order.cpp | 15 +- .../test/test_eagle_get_hidden_states.py | 200 ++++++++++++------ .../operators/test_eagle_get_hidden_states.py | 15 +- 4 files changed, 180 insertions(+), 85 deletions(-) diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu index b8d70544a47..1259cc6b5f5 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu @@ -35,15 +35,6 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, int out_offset = 0; for (int i = 0; i < bsz; i += buf_size) { int64_t read_size = min(static_cast(bsz - i), buf_size); - GM2LM_ASYNC(base_model_seq_lens_this_time + i, - lm_base_model_seq_lens_this_time, - read_size * sizeof(int)); - GM2LM_ASYNC(base_model_seq_lens_encoder + i, - lm_base_model_seq_lens_encoder, - read_size * sizeof(int)); - GM2LM_ASYNC( - seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int)); - GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int)); GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int)); for (int j = 0; j < read_size; j++) { int cur_base_model_seq_lens_this_time = @@ -69,6 +60,32 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, in_offset += write_size; } mfence_lm(); + } + } + } + + for (int i = 0; i < bsz; i += buf_size) { + int64_t read_size = min(static_cast(bsz - i), buf_size); + GM2LM_ASYNC(base_model_seq_lens_this_time + i, + lm_base_model_seq_lens_this_time, + read_size * sizeof(int)); + GM2LM_ASYNC(base_model_seq_lens_encoder + i, + lm_base_model_seq_lens_encoder, + read_size * sizeof(int)); + GM2LM_ASYNC( + seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int)); + GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int)); + GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int)); + for (int j = 0; j < read_size; j++) { + int cur_base_model_seq_lens_this_time = + lm_base_model_seq_lens_this_time[j]; + int cur_base_model_seq_lens_encoder = lm_base_model_seq_lens_encoder[j]; + int cur_seq_lens_this_time = lm_seq_lens_this_time[j]; + int accept_num = lm_accept_nums[j]; + int cur_seq_lens_encoder = lm_seq_lens_encoder[j]; + // 1. eagle encoder. Base step=1 + if (cur_seq_lens_encoder > 0) { + continue; // 2. Base model stop at last verify-step. } else if (cur_base_model_seq_lens_this_time != 0 && cur_seq_lens_this_time == 0) { diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp index 64a45ad9b17..4f830ae4a74 100644 --- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp +++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp @@ -49,6 +49,17 @@ static int cpu_wrapper(Context* ctx, const int input_token_num) { int in_offset = 0; // input_offset(long) int out_offset = 0; // output_offset(short) + + // for support mix, encoder need set first + for (int i = 0; i < bsz; ++i) { + int cur_seq_lens_encoder = seq_lens_encoder[i]; + if (cur_seq_lens_encoder > 0) { + for (int j = 0; j < cur_seq_lens_encoder; j++) { + position_map[in_offset++] = out_offset++; + } + } + } + for (int i = 0; i < bsz; ++i) { int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i]; int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i]; @@ -58,9 +69,7 @@ static int cpu_wrapper(Context* ctx, // 1. eagle encoder. Base step=1 if (cur_seq_lens_encoder > 0) { - for (int j = 0; j < cur_seq_lens_encoder; j++) { - position_map[in_offset++] = out_offset++; - } + continue; // 2. base model encoder. Base step=0 } else if (cur_base_model_seq_lens_encoder != 0) { // nothing happens diff --git a/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py index ac68a53e367..cb4969f87b1 100644 --- a/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py +++ b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py @@ -1,6 +1,6 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); +# Licensed under the Apache License, Version 2.0 (the "License") # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -12,79 +12,143 @@ # See the License for the specific language governing permissions and # limitations under the License. +import unittest + import numpy as np import paddle from fastdeploy.model_executor.ops.xpu import eagle_get_hidden_states -def test_eagle_get_hidden_states(): - bs = np.random.randint(1, 8 + 1, dtype=np.int32) - input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32) - dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32) - actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32) - - seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32) - seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) - accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32) - base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) - base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32) - # dont care - seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) - stop_flags = np.random.randint(0, 2, bs, dtype=np.int32) - - seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32) - seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32) - accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32) - base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32) - base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32) - # dont care - seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32) - stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32) - - # fp32 test - input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32) - input_tensor = paddle.to_tensor(input, dtype=paddle.float32) - cpu_out = eagle_get_hidden_states( - input_tensor.cpu(), - seq_lens_this_time_tensor.cpu(), - seq_lens_encoder_tensor.cpu(), - seq_lens_decoder_tensor.cpu(), - stop_flags_tensor.cpu(), - accept_nums_tensor.cpu(), - base_model_seq_lens_this_time_tensor.cpu(), - base_model_seq_lens_encoder_tensor.cpu(), - actual_draft_token_num, - ) - xpu_out = eagle_get_hidden_states( - input_tensor, - seq_lens_this_time_tensor, - seq_lens_encoder_tensor, - seq_lens_decoder_tensor, - stop_flags_tensor, - accept_nums_tensor, - base_model_seq_lens_this_time_tensor, - base_model_seq_lens_encoder_tensor, +def ComputeOrderKernel( + seq_lens_this_time, + seq_lens_encoder, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + accept_nums, + position_map, + output_token_num, + bsz, + actual_draft_token_num, + input_token_num, +): + in_offset = 0 + out_offset = 0 + for i in range(bsz): + cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i] + # cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i] + cur_seq_lens_this_time = seq_lens_this_time[i] + accept_num = accept_nums[i] + cur_seq_lens_encoder = seq_lens_encoder[i] + # 1. eagle encoder. Base step=1 + if cur_seq_lens_encoder > 0: + for j in range(cur_seq_lens_encoder): + position_map[in_offset] = out_offset + in_offset += 1 + out_offset += 1 + # 2. Base model stop at last verify-step. + elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0: + in_offset += cur_base_model_seq_lens_this_time + # 4. stopped + elif cur_base_model_seq_lens_this_time == 0 and cur_seq_lens_this_time == 0: # end + pass + else: + for i in range(accept_num): + position_map[in_offset] = out_offset + in_offset += 1 + out_offset += 1 + in_offset += cur_base_model_seq_lens_this_time - accept_num + output_token_num[0] = out_offset + + +def rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt): + for elem_idx in range(elem_cnt): + ori_token_idx = int(elem_idx / dim_embed) + token_idx = position_map[ori_token_idx] + if token_idx >= 0: + offset = elem_idx % dim_embed + out[token_idx][offset] = input[ori_token_idx][offset] + + +def eagle_get_hidden_states_ref( + input, + seq_lens_this_time, + seq_lens_encoder, + seq_lens_decoder, + stop_flags, + accept_nums, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + actual_draft_token_num, +): + input_token_num = input.shape[0] + dim_embed = input.shape[1] + bsz = seq_lens_this_time.shape[0] + position_map = paddle.full([input_token_num], 0xFFFFFFFF, seq_lens_this_time.dtype) + output_token_num = paddle.empty([1], seq_lens_this_time.dtype) + ComputeOrderKernel( + seq_lens_this_time, + seq_lens_encoder, + base_model_seq_lens_this_time, + base_model_seq_lens_encoder, + accept_nums, + position_map, + output_token_num, + bsz, actual_draft_token_num, + input_token_num, ) - assert np.allclose(cpu_out.numpy(), xpu_out.numpy()) - - # bf16/fp16 test - for dtype in [paddle.bfloat16, paddle.float16]: - input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int16) - input_tensor = paddle.to_tensor(input, dtype=dtype) - cpu_out = eagle_get_hidden_states( - input_tensor.cpu(), - seq_lens_this_time_tensor.cpu(), - seq_lens_encoder_tensor.cpu(), - seq_lens_decoder_tensor.cpu(), - stop_flags_tensor.cpu(), - accept_nums_tensor.cpu(), - base_model_seq_lens_this_time_tensor.cpu(), - base_model_seq_lens_encoder_tensor.cpu(), + + output_token_num_cpu = output_token_num[0] + out = paddle.empty([output_token_num_cpu, dim_embed], input.dtype) + elem_cnt = input_token_num * dim_embed + rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt) + return out + + +class TestEagleGetHiddenStates(unittest.TestCase): + def test_eagle_get_hidden_states(self): + np.random.seed(2023) + paddle.seed(2023) + bs = 2 + input_token_num = 10 + dim_embed = 512 + actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32) + + seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32) + seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) + accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32) + base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) + base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32) + + seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32) + stop_flags = np.random.randint(0, 2, bs, dtype=np.int32) + + seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32) + seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32) + accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32) + base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32) + base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32) + + seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32) + stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32) + + input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32) + input_tensor = paddle.to_tensor(input, dtype=paddle.float16) + print("input_tensor:", input_tensor) + out = eagle_get_hidden_states( + input_tensor, + seq_lens_this_time_tensor, + seq_lens_encoder_tensor, + seq_lens_decoder_tensor, + stop_flags_tensor, + accept_nums_tensor, + base_model_seq_lens_this_time_tensor, + base_model_seq_lens_encoder_tensor, actual_draft_token_num, ) - xpu_out = eagle_get_hidden_states( + + out_ref = eagle_get_hidden_states_ref( input_tensor, seq_lens_this_time_tensor, seq_lens_encoder_tensor, @@ -95,10 +159,8 @@ def test_eagle_get_hidden_states(): base_model_seq_lens_encoder_tensor, actual_draft_token_num, ) - assert np.allclose(cpu_out.numpy(), xpu_out.numpy()) - - print("All test passed") + np.testing.assert_allclose(out.numpy(), out_ref.numpy()) if __name__ == "__main__": - test_eagle_get_hidden_states() + unittest.main() diff --git a/tests/operators/test_eagle_get_hidden_states.py b/tests/operators/test_eagle_get_hidden_states.py index 0fb893c1384..e9c8ad35627 100644 --- a/tests/operators/test_eagle_get_hidden_states.py +++ b/tests/operators/test_eagle_get_hidden_states.py @@ -34,6 +34,16 @@ def ComputeOrderKernel( ): in_offset = 0 out_offset = 0 + # set encoder first + for i in range(bsz): + cur_seq_lens_encoder = seq_lens_encoder[i] + cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i] + if cur_seq_lens_encoder > 0: + for j in range(cur_seq_lens_encoder): + position_map[in_offset] = out_offset + in_offset += 1 + out_offset += 1 + for i in range(bsz): cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i] # cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i] @@ -42,10 +52,7 @@ def ComputeOrderKernel( cur_seq_lens_encoder = seq_lens_encoder[i] # 1. eagle encoder. Base step=1 if cur_seq_lens_encoder > 0: - for j in range(cur_seq_lens_encoder): - position_map[in_offset] = out_offset - in_offset += 1 - out_offset += 1 + continue # 2. Base model stop at last verify-step. elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0: in_offset += cur_base_model_seq_lens_this_time From d39f67c3b125acc7f00b76c417f275ebb68e31e9 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Thu, 11 Dec 2025 10:52:55 +0000 Subject: [PATCH 2/3] fix code style --- .../plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu | 2 +- tests/operators/test_eagle_get_hidden_states.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu index 1259cc6b5f5..e1c13528ddd 100644 --- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu +++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu @@ -60,7 +60,7 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time, in_offset += write_size; } mfence_lm(); - } + } } } diff --git a/tests/operators/test_eagle_get_hidden_states.py b/tests/operators/test_eagle_get_hidden_states.py index e9c8ad35627..131f68ff813 100644 --- a/tests/operators/test_eagle_get_hidden_states.py +++ b/tests/operators/test_eagle_get_hidden_states.py @@ -52,7 +52,7 @@ def ComputeOrderKernel( cur_seq_lens_encoder = seq_lens_encoder[i] # 1. eagle encoder. Base step=1 if cur_seq_lens_encoder > 0: - continue + continue # 2. Base model stop at last verify-step. elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0: in_offset += cur_base_model_seq_lens_this_time From e12d2ca03f47a465ddb3679ff24855ab6a9caba9 Mon Sep 17 00:00:00 2001 From: cmcamdy <1027740945@qq.com> Date: Thu, 11 Dec 2025 10:58:53 +0000 Subject: [PATCH 3/3] fix code style --- .../xpu_ops/test/test_eagle_get_hidden_states.py | 16 +++++++++++----- tests/operators/test_eagle_get_hidden_states.py | 15 ++++----------- 2 files changed, 15 insertions(+), 16 deletions(-) diff --git a/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py index cb4969f87b1..bd5636a919d 100644 --- a/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py +++ b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py @@ -34,6 +34,16 @@ def ComputeOrderKernel( ): in_offset = 0 out_offset = 0 + # set encoder position map first + for i in range(bsz): + cur_seq_lens_encoder = seq_lens_encoder[i] + cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i] + if cur_seq_lens_encoder > 0: + for j in range(cur_seq_lens_encoder): + position_map[in_offset] = out_offset + in_offset += 1 + out_offset += 1 + for i in range(bsz): cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i] # cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i] @@ -42,10 +52,7 @@ def ComputeOrderKernel( cur_seq_lens_encoder = seq_lens_encoder[i] # 1. eagle encoder. Base step=1 if cur_seq_lens_encoder > 0: - for j in range(cur_seq_lens_encoder): - position_map[in_offset] = out_offset - in_offset += 1 - out_offset += 1 + continue # 2. Base model stop at last verify-step. elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0: in_offset += cur_base_model_seq_lens_this_time @@ -135,7 +142,6 @@ def test_eagle_get_hidden_states(self): input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32) input_tensor = paddle.to_tensor(input, dtype=paddle.float16) - print("input_tensor:", input_tensor) out = eagle_get_hidden_states( input_tensor, seq_lens_this_time_tensor, diff --git a/tests/operators/test_eagle_get_hidden_states.py b/tests/operators/test_eagle_get_hidden_states.py index 131f68ff813..0fb893c1384 100644 --- a/tests/operators/test_eagle_get_hidden_states.py +++ b/tests/operators/test_eagle_get_hidden_states.py @@ -34,16 +34,6 @@ def ComputeOrderKernel( ): in_offset = 0 out_offset = 0 - # set encoder first - for i in range(bsz): - cur_seq_lens_encoder = seq_lens_encoder[i] - cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i] - if cur_seq_lens_encoder > 0: - for j in range(cur_seq_lens_encoder): - position_map[in_offset] = out_offset - in_offset += 1 - out_offset += 1 - for i in range(bsz): cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i] # cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i] @@ -52,7 +42,10 @@ def ComputeOrderKernel( cur_seq_lens_encoder = seq_lens_encoder[i] # 1. eagle encoder. Base step=1 if cur_seq_lens_encoder > 0: - continue + for j in range(cur_seq_lens_encoder): + position_map[in_offset] = out_offset + in_offset += 1 + out_offset += 1 # 2. Base model stop at last verify-step. elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0: in_offset += cur_base_model_seq_lens_this_time