[XPU] add speculate_get_logits #5497
Thanks for your contribution!
    baidu::xpu::api::Context* ctx =
        static_cast<const phi::XPUContext*>(dev_ctx)->x_context();
    if (draft_logits.is_cpu()) {
      ctx = new baidu::xpu::api::Context(baidu::xpu::api::kCPU);
When is this context destroyed? Could it cause a memory leak?
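The CPU path is generally used only for unit-test verification. Other operators besides this one may also never release this CPU ctx, so a unified audit may be needed later.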
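One possible way to avoid the leak discussed above, sketched under assumptions: let a `std::unique_ptr` own the CPU context only when one is created, while the device context stays borrowed. `FakeContext` and `RunWithContext` are hypothetical stand-ins for illustration; they are not part of the XPU API, and the real `baidu::xpu::api::Context` may have different ownership requirements.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for baidu::xpu::api::Context; NOT the real class.
struct FakeContext {
  static int live;  // number of FakeContext objects currently alive
  bool is_cpu;
  explicit FakeContext(bool cpu) : is_cpu(cpu) { ++live; }
  ~FakeContext() { --live; }
};
int FakeContext::live = 0;

// Hypothetical helper: picks the context the way the snippet above does,
// but gives the temporary CPU context an owner.
// Returns true if the CPU context was used.
bool RunWithContext(FakeContext* device_ctx, bool input_is_cpu) {
  std::unique_ptr<FakeContext> cpu_ctx;  // owns a context only on the CPU path
  FakeContext* ctx = device_ctx;         // borrowed; never deleted here
  if (input_is_cpu) {
    cpu_ctx = std::make_unique<FakeContext>(true);
    ctx = cpu_ctx.get();
  }
  assert(ctx != nullptr);
  // ... the kernel wrapper would be called with ctx here ...
  return ctx->is_cpu;
}  // cpu_ctx (if any) is destroyed here, on every return path
```

With this shape the device context is untouched, and the CPU context cannot outlive the call even if an early return is added later.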
hong19860320 left a comment:
LGTM
cmcamdy left a comment:
LGTM
Codecov Report
✅ All modified and coverable lines are covered by tests.
@@            Coverage Diff             @@
##           develop    #5497   +/-   ##
==========================================
  Coverage         ?   60.27%
==========================================
  Files            ?      329
  Lines            ?    41114
  Branches         ?     6261
==========================================
  Hits             ?    24782
  Misses           ?    14443
  Partials         ?     1889
        seq_lens_encoder,
        real_bsz,
        vocab_size);
    WRAPPER_DUMP(ctx);
Could you add checks on the input and output XPU pointers?
OK.
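A minimal sketch of the kind of argument check requested here, assuming plain null-pointer and size validation before the kernel is launched. `Status` and `CheckArgs` are hypothetical names used only for illustration; any checking macros or error codes in the real XPU wrapper layer are not reproduced here.

```cpp
#include <initializer_list>

// Hypothetical status codes, for illustration only.
enum class Status { kSuccess, kInvalidParam };

// Hypothetical helper: returns kInvalidParam if any pointer is null
// or either size is non-positive, kSuccess otherwise.
inline Status CheckArgs(std::initializer_list<const void*> ptrs,
                        int real_bsz,
                        int vocab_size) {
  for (const void* p : ptrs) {
    if (p == nullptr) return Status::kInvalidParam;
  }
  if (real_bsz <= 0 || vocab_size <= 0) return Status::kInvalidParam;
  return Status::kSuccess;
}
```

A wrapper could call such a helper with every input and output pointer before dispatching, so that a bad argument is reported instead of dereferenced on the device.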
    if (clus_id < real_bsz && cid == 0) {
      GM2SM_ASYNC(seq_lens_encoder, sm_seq_lens_encoder, real_bsz * sizeof(int));
      GM2SM(seq_lens_this_time, sm_seq_lens_this_time, real_bsz * sizeof(int));
      int next_token_num_previous = 0;
      for (int bid = 0; bid < real_bsz; bid++) {
        sm_batch_token_num[bid] =
            sm_seq_lens_encoder[bid] > 0 ? 2 : sm_seq_lens_this_time[bid];
        if (bid == 0) {
          sm_cu_batch_token_offset[bid] = 0;
          sm_cu_next_token_offset[bid] = 0;
        } else {
          sm_cu_batch_token_offset[bid] =
              sm_cu_batch_token_offset[bid - 1] + sm_batch_token_num[bid - 1];
          sm_cu_next_token_offset[bid] =
              sm_cu_next_token_offset[bid - 1] + next_token_num_previous;
        }
        next_token_num_previous =
            sm_seq_lens_encoder[bid] > 0 ? 1 : sm_seq_lens_this_time[bid];
      }
      mfence_sm();
      if (clus_id == 0) {
        SM2GM_ASYNC(sm_batch_token_num, batch_token_num, real_bsz * sizeof(int));
        SM2GM_ASYNC(sm_cu_batch_token_offset,
                    cu_batch_token_offset,
                    real_bsz * sizeof(int));
      }
    }
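In this part of the code, lines 21-40 are executed by every cluster, while lines 41-46 are executed only by cluster 0. Doesn't that mean that, for lines 21-46, only the work done when clus_id == 0 is actually useful?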
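Every cluster that takes part in the computation needs the prefix sums here, so each cluster with clus_id < real_bsz computes its own copy in lines 21-40. Writing the results back to GM only needs to be done by one cluster, so lines 41-46 run on cluster 0 only.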
OK, understood.
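The prefix-sum step discussed in this thread can be mirrored on the host as a plain C++ reference, which may help when checking the kernel in a unit test. This is a sketch that follows the loop in the snippet above; `TokenOffsets` and `ComputeTokenOffsets` are hypothetical names, and `std::vector` stands in for the shared-memory buffers.

```cpp
#include <cassert>
#include <vector>

// Hypothetical container for the three arrays the kernel loop fills.
struct TokenOffsets {
  std::vector<int> batch_token_num;
  std::vector<int> cu_batch_token_offset;
  std::vector<int> cu_next_token_offset;
};

// Host-side mirror of the per-cluster loop: exclusive prefix sums over
// the per-batch token counts.
TokenOffsets ComputeTokenOffsets(const std::vector<int>& seq_lens_encoder,
                                 const std::vector<int>& seq_lens_this_time) {
  assert(seq_lens_encoder.size() == seq_lens_this_time.size());
  const int real_bsz = static_cast<int>(seq_lens_encoder.size());
  TokenOffsets out;
  out.batch_token_num.resize(real_bsz);
  out.cu_batch_token_offset.resize(real_bsz);
  out.cu_next_token_offset.resize(real_bsz);

  int next_token_num_previous = 0;
  for (int bid = 0; bid < real_bsz; ++bid) {
    // Batches with seq_lens_encoder > 0 count as 2 tokens here and
    // 1 token for the "next" offset, as in the kernel snippet.
    out.batch_token_num[bid] =
        seq_lens_encoder[bid] > 0 ? 2 : seq_lens_this_time[bid];
    if (bid == 0) {
      out.cu_batch_token_offset[bid] = 0;
      out.cu_next_token_offset[bid] = 0;
    } else {
      out.cu_batch_token_offset[bid] =
          out.cu_batch_token_offset[bid - 1] + out.batch_token_num[bid - 1];
      out.cu_next_token_offset[bid] =
          out.cu_next_token_offset[bid - 1] + next_token_num_previous;
    }
    next_token_num_previous =
        seq_lens_encoder[bid] > 0 ? 1 : seq_lens_this_time[bid];
  }
  return out;
}
```

For example, with seq_lens_encoder = {5, 0, 0} and seq_lens_this_time = {5, 3, 2}, the token counts are {2, 3, 2}, the batch offsets are {0, 2, 5}, and the next-token offsets are {0, 1, 4}.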
Motivation
add speculate_get_logits
Modifications
add speculate_get_logits
Usage or Command
No
Accuracy Tests
No
Checklist
[FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]
[...] pre-commit before commit.
[...] release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.