Skip to content

Conversation

@RuohengMa
Copy link
Contributor

Motivation

add speculate_get_logits

Modifications

add speculate_get_logits

Usage or Command

No

Accuracy Tests

No

Checklist

  • Add at least a tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code, run pre-commit before commit.
  • Add unit tests. Please write the reason in this PR if no unit tests.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

@paddle-bot
Copy link

paddle-bot bot commented Dec 11, 2025

Thanks for your contribution!

@paddle-bot paddle-bot bot added the XPU label Dec 11, 2025
baidu::xpu::api::Context* ctx =
static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
if (draft_logits.is_cpu()) {
ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 context 什么时候被销毁掉呢?是否会造成内存泄露?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

CPU一般是用于单测验证,除了这里其他的算子可能也没对这个cpu的ctx做释放,后续可能需要统一排查一下

cmcamdy
cmcamdy previously approved these changes Dec 11, 2025
hong19860320
hong19860320 previously approved these changes Dec 11, 2025
Copy link
Collaborator

@hong19860320 hong19860320 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

cmcamdy
cmcamdy previously approved these changes Dec 11, 2025
Copy link
Collaborator

@cmcamdy cmcamdy left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@codecov-commenter
Copy link

codecov-commenter commented Dec 11, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
⚠️ Please upload report for BASE (develop@9f4512c). Learn more about missing BASE report.

Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #5497   +/-   ##
==========================================
  Coverage           ?   60.27%           
==========================================
  Files              ?      329           
  Lines              ?    41114           
  Branches           ?     6261           
==========================================
  Hits               ?    24782           
  Misses             ?    14443           
  Partials           ?     1889           
Flag Coverage Δ
GPU 60.27% <ø> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

seq_lens_encoder,
real_bsz,
vocab_size);
WRAPPER_DUMP(ctx);

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对输入输出的 xpu 指针加下 chekc 检查?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

好的

Comment on lines +21 to +47
if (clus_id < real_bsz && cid == 0) {
GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
int next_token_num_previous = 0;
for (int bid = 0; bid < real_bsz; bid++) {
sm_batch_token_num[bid] =
sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
if (bid == 0) {
sm_cu_batch_token_offset[bid] = 0;
sm_cu_next_token_offset[bid] = 0;
} else {
sm_cu_batch_token_offset[bid] =
sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
sm_cu_next_token_offset[bid] =
sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
}
next_token_num_previous =
sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
}
mfence_sm();
if (clus_id == 0) {
SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
SM2GM_ASYNC(sm_cu_batch_token_offset,
cu_batch_token_offset,
real_bsz * sizeof(int));
}
}

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个部分的代码逻辑,21行-40行是每个 cluster 都会执行,41-46行只有 cluster0 会执行,是不是等价于21行-46行的代码实际上只有 clus_id == 0 执行的有用?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的prefix sum所有参与计算的cluster都需要使用,所以21-40所有< real_bsz的cluster都要计算一份;但是写回gm的话,只要一个cluster写就行,所以41-46只要cluster0

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里的prefix sum所有参与计算的cluster都需要使用,所以21-40所有< real_bsz的cluster都要计算一份;但是写回gm的话,只要一个cluster写就行,所以41-46只要cluster0

好的,明白了

@RuohengMa RuohengMa dismissed stale reviews from hong19860320 and cmcamdy via 8ad13e7 December 11, 2025 06:47
@Jiang-Jia-Jun Jiang-Jia-Jun merged commit 12c76f8 into PaddlePaddle:develop Dec 12, 2025
13 of 17 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

7 participants