Skip to content

Commit 8ad13e7

Browse files
committed
add ptr check
1 parent 635126f commit 8ad13e7

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

custom_ops/xpu_ops/src/plugin/src/wrapper/mtp_wrapper/speculate_get_logits.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,15 @@ int speculate_get_logits(Context* ctx,
139139
real_bsz,
140140
vocab_size);
141141
WRAPPER_DUMP(ctx);
142+
WRAPPER_CHECK_PTR(ctx, int, real_bsz, next_token_num);
143+
WRAPPER_CHECK_PTR(ctx, int, real_bsz, batch_token_num);
144+
WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_next_token_offset);
145+
WRAPPER_CHECK_PTR(ctx, int, real_bsz, cu_batch_token_offset);
146+
WRAPPER_CHECK_PTR(ctx, float, real_bsz* vocab_size, first_token_logits);
147+
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_this_time);
148+
WRAPPER_CHECK_PTR(ctx, int, real_bsz, seq_lens_encoder);
142149
WRAPPER_ASSERT_LE(ctx, real_bsz, 256 * 1024 / sizeof(int) / 5);
150+
WRAPPER_ASSERT_GT(ctx, vocab_size, 0);
143151
if (ctx->dev().type() == api::kCPU) {
144152
return cpu_wrapper(draft_logits,
145153
next_token_num,

0 commit comments

Comments
 (0)