Skip to content

Vectorize repetition/presence penalty in BaseGenerate.sample_token#14645

Open
nat-chan wants to merge 1 commit into
Comfy-Org:masterfrom
nat-chan:perf/vectorize-sample-token-penalties
Open

Vectorize repetition/presence penalty in BaseGenerate.sample_token#14645
nat-chan wants to merge 1 commit into
Comfy-Org:masterfrom
nat-chan:perf/vectorize-sample-token-penalties

Conversation

@nat-chan

Copy link
Copy Markdown

Summary

BaseGenerate.sample_token (in comfy/text_encoders/llama.py) applies the repetition and presence penalties with a nested Python loop over set(token_history) for each batch row. This loop runs on every decode step and has two problems:

  1. It scales as O(len(history)) per generated token, so the cost grows with the sequence length (effectively O(n²) over a full generation).
  2. It indexes the logits tensor with Python scalars (logits[i, token_id] < 0), which forces a GPU→CPU sync on every element.

This PR replaces the loop with a single gather (index_select) / scatter (index_copy_) over the unique history tokens. The work now runs entirely on-device and no longer scales with history length.

Numerical equivalence

The per-element arithmetic is preserved exactly (including * (1.0 / repetition_penalty) to match the original float rounding), so sampled logits are bit-for-bit identical. Verified with torch.equal across 72 cases (float32 / bfloat16 / float16 × several repetition_penalty / presence_penalty values × batch sizes 1 and 3) — all identical.

Benchmark

Per-step penalty cost (bf16, vocab ≈ 152k, single GPU):

history length original loop vectorized
64 3.06 ms 1.43 ms
256 8.31 ms 0.06 ms

The original cost rises with history length; the vectorized version stays flat. The saving applies per decode step, so it compounds over long generations.

The per-token sampling penalties were applied with a nested Python loop over
set(token_history) for each batch row. That loop grows with the generated
sequence length and indexes the logits tensor with scalars, forcing a
GPU->CPU sync on every decode step.

Replace it with a single gather/scatter over the unique history tokens. The
per-element arithmetic is unchanged, so the sampled logits are bit-for-bit
identical, while the work runs entirely on-device and no longer scales with
history length.
@coderabbitai

coderabbitai Bot commented Jun 26, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: a7bf4b6e-7003-49fc-8e9a-cf2258bd0f19

📥 Commits

Reviewing files that changed from the base of the PR and between 7cb784e and 2bb8d10.

📒 Files selected for processing (1)
  • comfy/text_encoders/llama.py

📝 Walkthrough

Walkthrough

BaseGenerate.sample_token was changed to apply repetition and presence penalties with vectorized tensor operations on the logits device. The code now gates the penalty block on non-empty token_history, gathers unique token IDs once, updates the selected logits columns, and writes the modified values back in place. The remaining sampling steps are unchanged.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and concisely summarizes the main change: vectorizing repetition/presence penalty handling in sample_token.
Description check ✅ Passed The description directly explains the same vectorization change, its motivation, and the reported results.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant