diff --git a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu
index b8d70544a47..e1c13528ddd 100644
--- a/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu
+++ b/custom_ops/xpu_ops/src/plugin/src/kernel/kunlun3cpp/mtp_kernel/compute_order.xpu
@@ -35,15 +35,6 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
   int out_offset = 0;
   for (int i = 0; i < bsz; i += buf_size) {
     int64_t read_size = min(static_cast<int64_t>(bsz - i), buf_size);
-    GM2LM_ASYNC(base_model_seq_lens_this_time + i,
-                lm_base_model_seq_lens_this_time,
-                read_size * sizeof(int));
-    GM2LM_ASYNC(base_model_seq_lens_encoder + i,
-                lm_base_model_seq_lens_encoder,
-                read_size * sizeof(int));
-    GM2LM_ASYNC(
-        seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int));
-    GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int));
     GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int));
     for (int j = 0; j < read_size; j++) {
       int cur_base_model_seq_lens_this_time =
@@ -69,6 +60,32 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
           in_offset += write_size;
         }
         mfence_lm();
+      }
+    }
+  }
+
+  for (int i = 0; i < bsz; i += buf_size) {
+    int64_t read_size = min(static_cast<int64_t>(bsz - i), buf_size);
+    GM2LM_ASYNC(base_model_seq_lens_this_time + i,
+                lm_base_model_seq_lens_this_time,
+                read_size * sizeof(int));
+    GM2LM_ASYNC(base_model_seq_lens_encoder + i,
+                lm_base_model_seq_lens_encoder,
+                read_size * sizeof(int));
+    GM2LM_ASYNC(
+        seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int));
+    GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int));
+    GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int));
+    for (int j = 0; j < read_size; j++) {
+      int cur_base_model_seq_lens_this_time =
+          lm_base_model_seq_lens_this_time[j];
+      int cur_base_model_seq_lens_encoder = lm_base_model_seq_lens_encoder[j];
+      int cur_seq_lens_this_time = lm_seq_lens_this_time[j];
+      int accept_num = lm_accept_nums[j];
+      int cur_seq_lens_encoder = lm_seq_lens_encoder[j];
+      // 1. eagle encoder. Base step=1
+      if (cur_seq_lens_encoder > 0) {
+        continue;
         // 2. Base model stop at last verify-step.
       } else if (cur_base_model_seq_lens_this_time != 0 &&
                  cur_seq_lens_this_time == 0) {
diff --git a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp
index 64a45ad9b17..4f830ae4a74 100644
--- a/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp
+++ b/custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/compute_order.cpp
@@ -49,6 +49,17 @@ static int cpu_wrapper(Context* ctx,
                        const int input_token_num) {
   int in_offset = 0;   // input_offset(long)
   int out_offset = 0;  // output_offset(short)
+
+  // for support mix, encoder need set first
+  for (int i = 0; i < bsz; ++i) {
+    int cur_seq_lens_encoder = seq_lens_encoder[i];
+    if (cur_seq_lens_encoder > 0) {
+      for (int j = 0; j < cur_seq_lens_encoder; j++) {
+        position_map[in_offset++] = out_offset++;
+      }
+    }
+  }
+
   for (int i = 0; i < bsz; ++i) {
     int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
     int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
@@ -58,9 +69,7 @@ static int cpu_wrapper(Context* ctx,
 
     // 1. eagle encoder. Base step=1
     if (cur_seq_lens_encoder > 0) {
-      for (int j = 0; j < cur_seq_lens_encoder; j++) {
-        position_map[in_offset++] = out_offset++;
-      }
+      continue;
       // 2. base model encoder. Base step=0
     } else if (cur_base_model_seq_lens_encoder != 0) {
       // nothing happens
diff --git a/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py
index ac68a53e367..bd5636a919d 100644
--- a/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py
+++ b/custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py
@@ -1,6 +1,6 @@
-# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
 #
-# Licensed under the Apache License, Version 2.0 (the "License");
+# Licensed under the Apache License, Version 2.0 (the "License")
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
 #
@@ -12,79 +12,149 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import unittest
+
 import numpy as np
 import paddle
 
 from fastdeploy.model_executor.ops.xpu import eagle_get_hidden_states
 
 
-def test_eagle_get_hidden_states():
-    bs = np.random.randint(1, 8 + 1, dtype=np.int32)
-    input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
-    dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32)
-    actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)
-
-    seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
-    seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
-    accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
-    base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
-    base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
-    # dont care
-    seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
-    stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)
-
-    seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
-    seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
-    accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
-    base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
-    base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)
-    # dont care
-    seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
-    stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)
-
-    # fp32 test
-    input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
-    input_tensor = paddle.to_tensor(input, dtype=paddle.float32)
-    cpu_out = eagle_get_hidden_states(
-        input_tensor.cpu(),
-        seq_lens_this_time_tensor.cpu(),
-        seq_lens_encoder_tensor.cpu(),
-        seq_lens_decoder_tensor.cpu(),
-        stop_flags_tensor.cpu(),
-        accept_nums_tensor.cpu(),
-        base_model_seq_lens_this_time_tensor.cpu(),
-        base_model_seq_lens_encoder_tensor.cpu(),
-        actual_draft_token_num,
-    )
-    xpu_out = eagle_get_hidden_states(
-        input_tensor,
-        seq_lens_this_time_tensor,
-        seq_lens_encoder_tensor,
-        seq_lens_decoder_tensor,
-        stop_flags_tensor,
-        accept_nums_tensor,
-        base_model_seq_lens_this_time_tensor,
-        base_model_seq_lens_encoder_tensor,
+def ComputeOrderKernel(
+    seq_lens_this_time,
+    seq_lens_encoder,
+    base_model_seq_lens_this_time,
+    base_model_seq_lens_encoder,
+    accept_nums,
+    position_map,
+    output_token_num,
+    bsz,
+    actual_draft_token_num,
+    input_token_num,
+):
+    in_offset = 0
+    out_offset = 0
+    # set encoder position map first
+    for i in range(bsz):
+        cur_seq_lens_encoder = seq_lens_encoder[i]
+        cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i]
+        if cur_seq_lens_encoder > 0:
+            for j in range(cur_seq_lens_encoder):
+                position_map[in_offset] = out_offset
+                in_offset += 1
+                out_offset += 1
+
+    for i in range(bsz):
+        cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i]
+        # cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i]
+        cur_seq_lens_this_time = seq_lens_this_time[i]
+        accept_num = accept_nums[i]
+        cur_seq_lens_encoder = seq_lens_encoder[i]
+        # 1. eagle encoder. Base step=1
+        if cur_seq_lens_encoder > 0:
+            continue
+        # 2. Base model stop at last verify-step.
+        elif cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0:
+            in_offset += cur_base_model_seq_lens_this_time
+        # 4. stopped
+        elif cur_base_model_seq_lens_this_time == 0 and cur_seq_lens_this_time == 0:  # end
+            pass
+        else:
+            for i in range(accept_num):
+                position_map[in_offset] = out_offset
+                in_offset += 1
+                out_offset += 1
+            in_offset += cur_base_model_seq_lens_this_time - accept_num
+    output_token_num[0] = out_offset
+
+
+def rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt):
+    for elem_idx in range(elem_cnt):
+        ori_token_idx = int(elem_idx / dim_embed)
+        token_idx = position_map[ori_token_idx]
+        if token_idx >= 0:
+            offset = elem_idx % dim_embed
+            out[token_idx][offset] = input[ori_token_idx][offset]
+
+
+def eagle_get_hidden_states_ref(
+    input,
+    seq_lens_this_time,
+    seq_lens_encoder,
+    seq_lens_decoder,
+    stop_flags,
+    accept_nums,
+    base_model_seq_lens_this_time,
+    base_model_seq_lens_encoder,
+    actual_draft_token_num,
+):
+    input_token_num = input.shape[0]
+    dim_embed = input.shape[1]
+    bsz = seq_lens_this_time.shape[0]
+    position_map = paddle.full([input_token_num], 0xFFFFFFFF, seq_lens_this_time.dtype)
+    output_token_num = paddle.empty([1], seq_lens_this_time.dtype)
+    ComputeOrderKernel(
+        seq_lens_this_time,
+        seq_lens_encoder,
+        base_model_seq_lens_this_time,
+        base_model_seq_lens_encoder,
+        accept_nums,
+        position_map,
+        output_token_num,
+        bsz,
         actual_draft_token_num,
+        input_token_num,
     )
-    assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
-
-    # bf16/fp16 test
-    for dtype in [paddle.bfloat16, paddle.float16]:
-        input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int16)
-        input_tensor = paddle.to_tensor(input, dtype=dtype)
-        cpu_out = eagle_get_hidden_states(
-            input_tensor.cpu(),
-            seq_lens_this_time_tensor.cpu(),
-            seq_lens_encoder_tensor.cpu(),
-            seq_lens_decoder_tensor.cpu(),
-            stop_flags_tensor.cpu(),
-            accept_nums_tensor.cpu(),
-            base_model_seq_lens_this_time_tensor.cpu(),
-            base_model_seq_lens_encoder_tensor.cpu(),
+
+    output_token_num_cpu = output_token_num[0]
+    out = paddle.empty([output_token_num_cpu, dim_embed], input.dtype)
+    elem_cnt = input_token_num * dim_embed
+    rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt)
+    return out
+
+
+class TestEagleGetHiddenStates(unittest.TestCase):
+    def test_eagle_get_hidden_states(self):
+        np.random.seed(2023)
+        paddle.seed(2023)
+        bs = 2
+        input_token_num = 10
+        dim_embed = 512
+        actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)
+
+        seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
+        seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
+        accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
+        base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
+        base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
+
+        seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
+        stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)
+
+        seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
+        seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
+        accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
+        base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
+        base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)
+
+        seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
+        stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)
+
+        input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
+        input_tensor = paddle.to_tensor(input, dtype=paddle.float16)
+        out = eagle_get_hidden_states(
+            input_tensor,
+            seq_lens_this_time_tensor,
+            seq_lens_encoder_tensor,
+            seq_lens_decoder_tensor,
+            stop_flags_tensor,
+            accept_nums_tensor,
+            base_model_seq_lens_this_time_tensor,
+            base_model_seq_lens_encoder_tensor,
             actual_draft_token_num,
         )
-        xpu_out = eagle_get_hidden_states(
+
+        out_ref = eagle_get_hidden_states_ref(
             input_tensor,
             seq_lens_this_time_tensor,
             seq_lens_encoder_tensor,
@@ -95,10 +165,8 @@ def test_eagle_get_hidden_states():
             base_model_seq_lens_encoder_tensor,
             actual_draft_token_num,
         )
-        assert np.allclose(cpu_out.numpy(), xpu_out.numpy())
-
-    print("All test passed")
+        np.testing.assert_allclose(out.numpy(), out_ref.numpy())
 
 
 if __name__ == "__main__":
-    test_eagle_get_hidden_states()
+    unittest.main()
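
For reference, the snippet below is not part of the patch: it is a minimal sketch that runs the pure-Python ComputeOrderKernel helper defined in the test above on a tiny hand-made batch, to illustrate the position_map convention the patch implements (encoder tokens are mapped first, then only the accepted draft tokens of each decoder request). It assumes that helper is in scope; the batch values and the use of plain numpy arrays are invented for illustration.

import numpy as np

# Two requests: request 0 is an eagle encoder step with 3 prompt tokens,
# request 1 is a decoder step where the base model proposed 3 draft tokens
# and 2 of them were accepted. (Values are made up for this example.)
seq_lens_this_time = np.array([3, 1], dtype=np.int32)
seq_lens_encoder = np.array([3, 0], dtype=np.int32)
base_model_seq_lens_this_time = np.array([1, 3], dtype=np.int32)
base_model_seq_lens_encoder = np.array([1, 0], dtype=np.int32)
accept_nums = np.array([1, 2], dtype=np.int32)

input_token_num = 3 + 3  # encoder tokens first, then request 1's base-model tokens
position_map = np.full([input_token_num], -1, dtype=np.int32)  # -1 == 0xFFFFFFFF as int32
output_token_num = np.zeros([1], dtype=np.int32)

# ComputeOrderKernel is the reference helper from the test file above.
ComputeOrderKernel(
    seq_lens_this_time,
    seq_lens_encoder,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    accept_nums,
    position_map,
    output_token_num,
    2,  # bsz
    3,  # actual_draft_token_num
    input_token_num,
)

# Encoder tokens keep output positions 0..2, the 2 accepted tokens of request 1
# map to 3..4, and the rejected draft token stays at -1 (dropped by the rebuild).
print(position_map)         # [ 0  1  2  3  4 -1]
print(output_token_num[0])  # 5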