Vectorize repetition/presence penalty in BaseGenerate.sample_token#14645
Open
nat-chan wants to merge 1 commit into
Open
Vectorize repetition/presence penalty in BaseGenerate.sample_token#14645nat-chan wants to merge 1 commit into
nat-chan wants to merge 1 commit into
Conversation
The per-token sampling penalties were applied with a nested Python loop over set(token_history) for each batch row. That loop grows with the generated sequence length and indexes the logits tensor with scalars, forcing a GPU->CPU sync on every decode step. Replace it with a single gather/scatter over the unique history tokens. The per-element arithmetic is unchanged, so the sampled logits are bit-for-bit identical, while the work runs entirely on-device and no longer scales with history length.
|
No actionable comments were generated in the recent review. 🎉 ℹ️ Recent review info⚙️ Run configurationConfiguration used: Path: .coderabbit.yaml Review profile: CHILL Plan: Pro Plus Run ID: 📒 Files selected for processing (1)
📝 WalkthroughWalkthroughBaseGenerate.sample_token was changed to apply repetition and presence penalties with vectorized tensor operations on the logits device. The code now gates the penalty block on non-empty token_history, gathers unique token IDs once, updates the selected logits columns, and writes the modified values back in place. The remaining sampling steps are unchanged. 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
BaseGenerate.sample_token(incomfy/text_encoders/llama.py) applies the repetition and presence penalties with a nested Python loop overset(token_history)for each batch row. This loop runs on every decode step and has two problems:logits[i, token_id] < 0), which forces a GPU→CPU sync on every element.This PR replaces the loop with a single gather (
index_select) / scatter (index_copy_) over the unique history tokens. The work now runs entirely on-device and no longer scales with history length.Numerical equivalence
The per-element arithmetic is preserved exactly (including
* (1.0 / repetition_penalty)to match the original float rounding), so sampled logits are bit-for-bit identical. Verified withtorch.equalacross 72 cases (float32 / bfloat16 / float16 × severalrepetition_penalty/presence_penaltyvalues × batch sizes 1 and 3) — all identical.Benchmark
Per-step penalty cost (bf16, vocab ≈ 152k, single GPU):
The original cost rises with history length; the vectorized version stays flat. The saving applies per decode step, so it compounds over long generations.