Fix FP8 activation quantization for >2D activations in mixed_precision_ops #14643
Open
liminfei-amd wants to merge 1 commit into
…n_ops mixed_precision_ops.Linear.forward only quantized activations that were 2D, or 3D (reshaped to 2D). Inputs with rank >= 4 (e.g. Anima's MLP activations, which are not reshaped to 3D the way the attention path is) fell through the `input_reshaped.ndim == 2` guard and reached scaled_mm as bf16, silently dispatching a bf16 kernel instead of FP8. Since MLP is roughly half the compute, the FP8 speedup was far below expectation. Generalize the existing 3D->2D reshape to any rank >= 3 (flatten the leading dims, keep the contraction dim) and reshape the output back to the original leading dims. 2D and 3D inputs are handled exactly as before; only rank >= 4 inputs change (now quantized instead of skipped). This matches the rank-agnostic handling already used by the training path (flatten(0, -2) / unflatten). Fixes Comfy-Org#14595. Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
Fix FP8 activation quantization fall-through for >2D activations in `mixed_precision_ops`
Problem
With a fully-quantized FP8 checkpoint and `--fast fp8_matrix_mult`, only the attention GEMMs (QKV / output proj) run in FP8 while the MLP GEMMs (up / down proj) fall back to bf16. Since MLP is roughly half of the compute, the measured FP8 speedup is far below expectation (~15%). Reported and root-caused by @TaihoC in #14595 (Anima checkpoint).
Root cause
In `mixed_precision_ops(...)` `Linear.forward` (`comfy/ops.py`), input activations are only quantized when they are 2D, or 3D (reshaped to 2D).
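The guard described above can be modelled on shapes alone. The following is a minimal sketch, not the code in `comfy/ops.py`: `old_guard` is a hypothetical helper that reproduces only the two conditions named in this PR (`input.ndim == 3` for the reshape, `input_reshaped.ndim == 2` for quantization).

```python
def old_guard(shape):
    """Shape-only model of the pre-fix guard (hypothetical helper).

    Returns (quantized, gemm_input_shape).
    """
    shape = tuple(shape)
    # Only 3D inputs were flattened to 2D: (B, T, C) -> (B*T, C).
    if len(shape) == 3:
        reshaped = (shape[0] * shape[1], shape[2])
    else:
        reshaped = shape
    # Quantization happened only when the (possibly reshaped) input was 2D.
    quantized = len(reshaped) == 2
    return quantized, reshaped


for shape in [(8, 64), (2, 8, 64), (2, 4, 8, 64)]:
    quantized, reshaped = old_guard(shape)
    print(shape, "->", reshaped, "quantized" if quantized else "NOT quantized")
```

In this model the 2D and 3D shapes reach the quantize branch, while the 4D shape passes through unchanged and is never quantized.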
Models whose Linear inputs have rank >= 4 (e.g. Anima's MLP activations, which are not reshaped to 3D the way the attention path is) never enter the `ndim == 2` branch, so `QuantizedTensor.from_float` is never called. The activation reaches `scaled_mm` as bf16 and a bf16 kernel is dispatched silently (`reshaped_3d` is pre-initialised to `False`, so there is no error, just a silent fallback).
Fix
Generalize the existing 3D->2D reshape to any rank >= 3 (flatten all leading dims, keep the contraction dim), and reshape the output back to the original leading dims:
- `input.ndim == 3` -> `input.ndim >= 3`, `input_shape[2]` -> `input_shape[-1]`
- `(input_shape[0], input_shape[1], out)` -> `(*input_shape[:-1], out)`
Backward compatible: 2D and 3D inputs are handled exactly as before; only rank >= 4 inputs change behavior (now quantized instead of silently skipped).
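The shape bookkeeping of the generalized rule can be illustrated without any tensor library. This is a sketch under assumptions: `plan_linear_shapes` and `out_features` are hypothetical names introduced here, and only the shape expressions `input_shape[-1]` and `(*input_shape[:-1], out)` come from the change itself.

```python
from math import prod


def plan_linear_shapes(input_shape, out_features):
    """Shape bookkeeping for the generalized reshape (hypothetical helper).

    Returns (gemm_input_shape, gemm_output_shape, restored_output_shape).
    """
    input_shape = tuple(input_shape)
    if len(input_shape) >= 3:
        # Flatten every leading dim into one row axis; keep the contraction dim.
        gemm_in = (prod(input_shape[:-1]), input_shape[-1])
    else:
        gemm_in = input_shape
    # The GEMM replaces the contraction dim with the output feature dim.
    gemm_out = (*gemm_in[:-1], out_features)
    # Restore the original leading dims on the way out.
    restored = (*input_shape[:-1], out_features)
    return gemm_in, gemm_out, restored


for shape in [(8, 64), (2, 8, 64), (2, 4, 8, 64), (2, 3, 4, 8, 64)]:
    gemm_in, gemm_out, restored = plan_linear_shapes(shape, out_features=256)
    # The final reshape is only valid if the element counts agree.
    assert prod(gemm_out) == prod(restored)
    print(shape, "->", gemm_in, "->", gemm_out, "->", restored)
```

For 2D and 3D shapes the result is the same as with the old `input_shape[0]`, `input_shape[1]`, `input_shape[2]` indexing; rank >= 4 shapes now also become a 2D GEMM input.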
Notes on quantization semantics
Flattening leading dims is the same operation the code already performs for 3D (it flattens `B,T`). For per-tensor FP8 and the block formats currently in use (MXFP8 / NVFP4 block along the last / contraction dim, which the reshape preserves), the meaning of quantization is unchanged. This matches the safer, more general of the two directions discussed in the issue.
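The invariance argument can be checked on a toy example. The sketch below assumes that scales are derived from the maximum absolute value, either over the whole tensor or over contiguous blocks along the last dim; it uses plain nested lists, an illustrative block size of 4 rather than the real MXFP8 / NVFP4 block sizes, and hypothetical helper names (`rows_of`, `block_amax`).

```python
BLOCK = 4  # illustrative block size, not the real MXFP8 / NVFP4 value
D0, D1, D2, C = 2, 2, 3, 8  # toy 4D shape; C is the contraction dim

# Deterministic 4D "activation" as nested lists, shape (D0, D1, D2, C).
x = [[[[((7 * i + 5 * j + 3 * k + c) % 11) - 5 for c in range(C)]
       for k in range(D2)] for j in range(D1)] for i in range(D0)]


def rows_of(t):
    """Flatten all leading dims of a nested list; each last-dim row stays intact."""
    if not isinstance(t[0], list):
        return [t]
    return [row for sub in t for row in rows_of(sub)]


def block_amax(rows, block):
    """Max |v| of each contiguous block along the last dim, row by row."""
    return [max(abs(v) for v in row[s:s + block])
            for row in rows for s in range(0, len(row), block)]


# Statistics computed directly on the 4D layout.
amax_4d = max(abs(v) for p in x for q in p for r in q for v in r)
blocks_4d = [max(abs(v) for v in x[i][j][k][s:s + BLOCK])
             for i in range(D0) for j in range(D1) for k in range(D2)
             for s in range(0, C, BLOCK)]

# The same statistics after flattening the leading dims to (D0*D1*D2, C).
rows = rows_of(x)
amax_2d = max(abs(v) for row in rows for v in row)
blocks_2d = block_amax(rows, BLOCK)

assert len(rows) == D0 * D1 * D2 and all(len(r) == C for r in rows)
assert amax_4d == amax_2d      # per-tensor statistic is unchanged
assert blocks_4d == blocks_2d  # per-block statistics along the last dim are unchanged
print("rows:", len(rows), "blocks:", len(blocks_2d))
```

Because the flatten keeps every last-dim row intact and in the same order, any scale derived from these statistics is the same before and after the reshape. A format that blocked across a leading dim would not have this property.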
Verification
The shape logic was verified with a standalone script (no FP8 hardware required):
4D inputs now take the quantize branch and the output shape/values round-trip
correctly, while 2D/3D behavior is unchanged. I do not have an FP8 GPU to measure
end-to-end kernel dispatch; reviewers with an FP8 device can confirm the MLP GEMMs
now dispatch as FP8 using the trace method in the issue.
Closes #14595.