Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,6 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
int out_offset = 0;
for (int i = 0; i < bsz; i += buf_size) {
int64_t read_size = min(static_cast<int64_t>(bsz - i), buf_size);
GM2LM_ASYNC(base_model_seq_lens_this_time + i,
lm_base_model_seq_lens_this_time,
read_size * sizeof(int));
GM2LM_ASYNC(base_model_seq_lens_encoder + i,
lm_base_model_seq_lens_encoder,
read_size * sizeof(int));
GM2LM_ASYNC(
seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int));
GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int));
GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int));
for (int j = 0; j < read_size; j++) {
int cur_base_model_seq_lens_this_time =
Expand All @@ -69,6 +60,32 @@ __global__ void ComputeOrderKernel(const int* seq_lens_this_time,
in_offset += write_size;
}
mfence_lm();
}
}
}

for (int i = 0; i < bsz; i += buf_size) {
int64_t read_size = min(static_cast<int64_t>(bsz - i), buf_size);
GM2LM_ASYNC(base_model_seq_lens_this_time + i,
lm_base_model_seq_lens_this_time,
read_size * sizeof(int));
GM2LM_ASYNC(base_model_seq_lens_encoder + i,
lm_base_model_seq_lens_encoder,
read_size * sizeof(int));
GM2LM_ASYNC(
seq_lens_this_time + i, lm_seq_lens_this_time, read_size * sizeof(int));
GM2LM_ASYNC(accept_nums + i, lm_accept_nums, read_size * sizeof(int));
GM2LM(seq_lens_encoder + i, lm_seq_lens_encoder, read_size * sizeof(int));
for (int j = 0; j < read_size; j++) {
int cur_base_model_seq_lens_this_time =
lm_base_model_seq_lens_this_time[j];
int cur_base_model_seq_lens_encoder = lm_base_model_seq_lens_encoder[j];
int cur_seq_lens_this_time = lm_seq_lens_this_time[j];
int accept_num = lm_accept_nums[j];
int cur_seq_lens_encoder = lm_seq_lens_encoder[j];
// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
continue;
// 2. Base model stop at last verify-step.
} else if (cur_base_model_seq_lens_this_time != 0 &&
cur_seq_lens_this_time == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,17 @@ static int cpu_wrapper(Context* ctx,
const int input_token_num) {
int in_offset = 0; // input_offset(long)
int out_offset = 0; // output_offset(short)

// for support mix, encoder need set first
for (int i = 0; i < bsz; ++i) {
int cur_seq_lens_encoder = seq_lens_encoder[i];
if (cur_seq_lens_encoder > 0) {
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
}
}
}

for (int i = 0; i < bsz; ++i) {
int cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i];
int cur_base_model_seq_lens_encoder = base_model_seq_lens_encoder[i];
Expand All @@ -58,9 +69,7 @@ static int cpu_wrapper(Context* ctx,

// 1. eagle encoder. Base step=1
if (cur_seq_lens_encoder > 0) {
for (int j = 0; j < cur_seq_lens_encoder; j++) {
position_map[in_offset++] = out_offset++;
}
continue;
// 2. base model encoder. Base step=0
} else if (cur_base_model_seq_lens_encoder != 0) {
// nothing happens
Expand Down
206 changes: 137 additions & 69 deletions custom_ops/xpu_ops/test/test_eagle_get_hidden_states.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
Expand All @@ -12,79 +12,149 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np
import paddle

from fastdeploy.model_executor.ops.xpu import eagle_get_hidden_states


def test_eagle_get_hidden_states():
bs = np.random.randint(1, 8 + 1, dtype=np.int32)
input_token_num = np.random.randint(2 * 1024, 4 * 1024 + 1, dtype=np.int32)
dim_embed = np.random.randint(1, 4 * 1024 + 1, dtype=np.int32)
actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)

seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)
# dont care
seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)

seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)
# dont care
seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)

# fp32 test
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
input_tensor = paddle.to_tensor(input, dtype=paddle.float32)
cpu_out = eagle_get_hidden_states(
input_tensor.cpu(),
seq_lens_this_time_tensor.cpu(),
seq_lens_encoder_tensor.cpu(),
seq_lens_decoder_tensor.cpu(),
stop_flags_tensor.cpu(),
accept_nums_tensor.cpu(),
base_model_seq_lens_this_time_tensor.cpu(),
base_model_seq_lens_encoder_tensor.cpu(),
actual_draft_token_num,
)
xpu_out = eagle_get_hidden_states(
input_tensor,
seq_lens_this_time_tensor,
seq_lens_encoder_tensor,
seq_lens_decoder_tensor,
stop_flags_tensor,
accept_nums_tensor,
base_model_seq_lens_this_time_tensor,
base_model_seq_lens_encoder_tensor,
def ComputeOrderKernel(
    seq_lens_this_time,
    seq_lens_encoder,
    base_model_seq_lens_this_time,
    base_model_seq_lens_encoder,
    accept_nums,
    position_map,
    output_token_num,
    bsz,
    actual_draft_token_num,
    input_token_num,
):
    """Fill ``position_map`` with the output slot of every kept input token.

    The map is built in two passes over the batch:

    1. Encoder pass: for every entry with ``seq_lens_encoder[i] > 0``, that
       many consecutive input positions are mapped to consecutive output
       positions.
    2. Second pass: entries handled by pass 1 are skipped. For the others:

       * ``base_model_seq_lens_this_time[i] != 0`` and
         ``seq_lens_this_time[i] == 0``: the input cursor advances by
         ``base_model_seq_lens_this_time[i]``; nothing is mapped.
       * both are ``0``: nothing happens.
       * otherwise: ``accept_nums[i]`` consecutive input positions are mapped
         to consecutive output positions, then the input cursor advances by
         ``base_model_seq_lens_this_time[i] - accept_nums[i]``.

    Args:
        seq_lens_this_time: per-entry lengths, indexed ``0..bsz-1``.
        seq_lens_encoder: per-entry encoder lengths, indexed ``0..bsz-1``.
        base_model_seq_lens_this_time: per-entry base-model lengths.
        base_model_seq_lens_encoder: not read by this function.
        accept_nums: per-entry number of accepted tokens.
        position_map: index-assignable sequence, mutated in place. Positions
            that are not mapped keep whatever value the caller stored there.
        output_token_num: index-assignable sequence; element ``0`` receives
            the total number of mapped tokens.
        bsz: number of batch entries to process.
        actual_draft_token_num: not read by this function.
        input_token_num: not read by this function.

    Returns:
        None. Results are written to ``position_map`` and
        ``output_token_num[0]``.
    """
    in_offset = 0
    out_offset = 0

    # Pass 1: encoder entries are laid out first in both input and output.
    for i in range(bsz):
        cur_seq_lens_encoder = seq_lens_encoder[i]
        if cur_seq_lens_encoder > 0:
            for _ in range(cur_seq_lens_encoder):
                position_map[in_offset] = out_offset
                in_offset += 1
                out_offset += 1

    # Pass 2: remaining entries, continuing from the cursors left by pass 1.
    for i in range(bsz):
        # 1. eagle encoder. Base step=1 (already mapped in pass 1).
        if seq_lens_encoder[i] > 0:
            continue

        cur_base_model_seq_lens_this_time = base_model_seq_lens_this_time[i]
        cur_seq_lens_this_time = seq_lens_this_time[i]

        # 2. Base model stop at last verify-step: skip its input tokens.
        if cur_base_model_seq_lens_this_time != 0 and cur_seq_lens_this_time == 0:
            in_offset += cur_base_model_seq_lens_this_time
        # 4. stopped: contributes neither input nor output tokens.
        elif cur_base_model_seq_lens_this_time == 0 and cur_seq_lens_this_time == 0:
            pass
        else:
            accept_num = accept_nums[i]
            # A throwaway loop variable keeps the batch index `i` intact.
            for _ in range(accept_num):
                position_map[in_offset] = out_offset
                in_offset += 1
                out_offset += 1
            # Step over the input tokens of this entry that were not accepted.
            in_offset += cur_base_model_seq_lens_this_time - accept_num

    output_token_num[0] = out_offset


def rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt):
    """Copy the kept rows of ``input`` into ``out`` according to ``position_map``.

    Elements are visited in flat order ``0..elem_cnt-1``; flat index ``e``
    addresses row ``e // dim_embed`` and column ``e % dim_embed`` of ``input``.
    An element is copied only when ``position_map[row] >= 0``, in which case
    it is written to ``out[position_map[row]][column]``.

    Args:
        input: 2-D indexable source (``input[row][col]``). The name shadows
            the builtin but is kept because it is part of the signature.
        position_map: destination row for every source row; negative entries
            mark rows that are dropped.
        out: 2-D index-assignable destination, mutated in place.
        dim_embed: number of columns per row.
        elem_cnt: number of flat elements to visit.

    Returns:
        None. ``out`` is modified in place.
    """
    for elem_idx in range(elem_cnt):
        # Exact integer split of the flat index; avoids a float round-trip.
        ori_token_idx, offset = divmod(elem_idx, dim_embed)
        token_idx = position_map[ori_token_idx]
        if token_idx >= 0:
            out[token_idx][offset] = input[ori_token_idx][offset]


# Python reference for eagle_get_hidden_states: builds a position map with
# ComputeOrderKernel, then gathers the mapped rows of `input` into a new tensor
# with rebuildHiddenStatesKernel and returns that tensor.
# `seq_lens_decoder` and `stop_flags` are accepted but not read here.
def eagle_get_hidden_states_ref(
input,
seq_lens_this_time,
seq_lens_encoder,
seq_lens_decoder,
stop_flags,
accept_nums,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
actual_draft_token_num,
):
# Sizes are taken from the first two dimensions of `input` and the first
# dimension of `seq_lens_this_time`.
input_token_num = input.shape[0]
dim_embed = input.shape[1]
bsz = seq_lens_this_time.shape[0]
# NOTE(review): 0xFFFFFFFF is presumably meant as the 32-bit pattern of -1, i.e. the
# "row dropped" marker that rebuildHiddenStatesKernel skips via `token_idx >= 0` --
# confirm paddle.full wraps this value for the given dtype instead of raising.
position_map = paddle.full([input_token_num], 0xFFFFFFFF, seq_lens_this_time.dtype)
# Single-element holder; ComputeOrderKernel stores the output token count in slot 0.
output_token_num = paddle.empty([1], seq_lens_this_time.dtype)
ComputeOrderKernel(
seq_lens_this_time,
seq_lens_encoder,
base_model_seq_lens_this_time,
base_model_seq_lens_encoder,
accept_nums,
position_map,
output_token_num,
bsz,
actual_draft_token_num,
input_token_num,
)
# NOTE(review): from here down to the blank line before `output_token_num_cpu`, the
# lines appear to be removed-side residue of the diff (old test body), not part of
# this function -- confirm against the actual file.
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())

# bf16/fp16 test
for dtype in [paddle.bfloat16, paddle.float16]:
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int16)
input_tensor = paddle.to_tensor(input, dtype=dtype)
cpu_out = eagle_get_hidden_states(
input_tensor.cpu(),
seq_lens_this_time_tensor.cpu(),
seq_lens_encoder_tensor.cpu(),
seq_lens_decoder_tensor.cpu(),
stop_flags_tensor.cpu(),
accept_nums_tensor.cpu(),
base_model_seq_lens_this_time_tensor.cpu(),
base_model_seq_lens_encoder_tensor.cpu(),

# Allocate the output with one row per mapped token, then copy the mapped rows.
output_token_num_cpu = output_token_num[0]
out = paddle.empty([output_token_num_cpu, dim_embed], input.dtype)
elem_cnt = input_token_num * dim_embed
rebuildHiddenStatesKernel(input, position_map, out, dim_embed, elem_cnt)
return out


# Compares the eagle_get_hidden_states op against the Python reference
# eagle_get_hidden_states_ref on one seeded random input.
class TestEagleGetHiddenStates(unittest.TestCase):
def test_eagle_get_hidden_states(self):
# Fixed seeds so the randomly drawn case is the same on every run.
np.random.seed(2023)
paddle.seed(2023)
bs = 2
input_token_num = 10
dim_embed = 512
actual_draft_token_num = np.random.randint(2, 6, dtype=np.int32)

# NOTE(review): the length arrays below are drawn independently of each other, so
# e.g. accept_nums[i] may exceed base_model_seq_lens_this_time[i] -- confirm the
# seeded draw is a combination the op is expected to support.
seq_lens_this_time = np.random.randint(0, 2, bs, dtype=np.int32)
seq_lens_encoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
accept_nums = np.random.randint(0, actual_draft_token_num + 1, bs, dtype=np.int32)
base_model_seq_lens_this_time = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
base_model_seq_lens_encoder = np.random.randint(0, 2, bs, dtype=np.int32)

seq_lens_decoder = np.random.randint(0, input_token_num // bs + 1, bs, dtype=np.int32)
stop_flags = np.random.randint(0, 2, bs, dtype=np.int32)

# Convert every numpy input to an int32 paddle tensor.
seq_lens_this_time_tensor = paddle.to_tensor(seq_lens_this_time, dtype=paddle.int32)
seq_lens_encoder_tensor = paddle.to_tensor(seq_lens_encoder, dtype=paddle.int32)
accept_nums_tensor = paddle.to_tensor(accept_nums, dtype=paddle.int32)
base_model_seq_lens_this_time_tensor = paddle.to_tensor(base_model_seq_lens_this_time, dtype=paddle.int32)
base_model_seq_lens_encoder_tensor = paddle.to_tensor(base_model_seq_lens_encoder, dtype=paddle.int32)

seq_lens_decoder_tensor = paddle.to_tensor(seq_lens_decoder, dtype=paddle.int32)
stop_flags_tensor = paddle.to_tensor(stop_flags, dtype=paddle.int32)

# Integer values in [0, 10) stored as float16.
input = np.random.randint(0, 10, (input_token_num, dim_embed), dtype=np.int32)
input_tensor = paddle.to_tensor(input, dtype=paddle.float16)
out = eagle_get_hidden_states(
input_tensor,
seq_lens_this_time_tensor,
seq_lens_encoder_tensor,
seq_lens_decoder_tensor,
stop_flags_tensor,
accept_nums_tensor,
base_model_seq_lens_this_time_tensor,
base_model_seq_lens_encoder_tensor,
actual_draft_token_num,
)
# NOTE(review): the `xpu_out = ...` line below, the "Expand All" line, the
# `assert np.allclose(cpu_out...` line and the `print(...)` line appear to be
# removed-side diff lines / page chrome, not part of the new test -- confirm.
xpu_out = eagle_get_hidden_states(

out_ref = eagle_get_hidden_states_ref(
input_tensor,
seq_lens_this_time_tensor,
seq_lens_encoder_tensor,
Expand All @@ -95,10 +165,8 @@ def test_eagle_get_hidden_states():
base_model_seq_lens_encoder_tensor,
actual_draft_token_num,
)
assert np.allclose(cpu_out.numpy(), xpu_out.numpy())

print("All test passed")
# The op output must match the Python reference.
np.testing.assert_allclose(out.numpy(), out_ref.numpy())


if __name__ == "__main__":
test_eagle_get_hidden_states()
unittest.main()
Loading