Skip to content

Fix FP8 activation quantization for >2D activations in mixed_precision_ops#14643

Open
liminfei-amd wants to merge 1 commit into
Comfy-Org:masterfrom
liminfei-amd:fix-fp8-nd-activation-quant-14595
Open

Fix FP8 activation quantization for >2D activations in mixed_precision_ops#14643
liminfei-amd wants to merge 1 commit into
Comfy-Org:masterfrom
liminfei-amd:fix-fp8-nd-activation-quant-14595

Conversation

@liminfei-amd

Copy link
Copy Markdown

Fix FP8 activation quantization fall-through for >2D activations in mixed_precision_ops

Problem

With a fully-quantized FP8 checkpoint and --fast fp8_matrix_mult, only the
attention GEMMs (QKV / output proj) run in FP8 while the MLP GEMMs (up / down
proj) fall back to bf16. Since MLP is roughly half of the compute, the measured
FP8 speedup is far below expectation (~15%). Reported and root-caused by @TaihoC
in #14595 (Anima checkpoint).

Root cause

In mixed_precision_ops(...) Linear.forward (comfy/ops.py), input activations
are only quantized when they are 2D, or 3D (reshaped to 2D):

input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
if input_reshaped.ndim == 2:
    ...
    input = QuantizedTensor.from_float(input_reshaped, self.layout_type, scale=scale)

Models whose Linear inputs have rank >= 4 (e.g. Anima's MLP activations, which are
not reshaped to 3D the way the attention path is) never enter the ndim == 2
branch, so QuantizedTensor.from_float is never called. The activation reaches
scaled_mm as bf16 and a bf16 kernel is dispatched silently (reshaped_3d is
pre-initialised to False, so there is no error — just a silent fallback).

Fix

Generalize the existing 3D->2D reshape to any rank >= 3 (flatten all leading dims,
keep the contraction dim), and reshape the output back to the original leading
dims:

  • input.ndim == 3 -> input.ndim >= 3, input_shape[2] -> input_shape[-1]
  • output reshape (input_shape[0], input_shape[1], out) -> (*input_shape[:-1], out)

Backward compatible: 2D and 3D inputs are handled exactly as before; only rank >= 4
inputs change behavior (now quantized instead of silently skipped).

Notes on quantization semantics

Flattening leading dims is the same operation the code already performs for 3D
(it flattens B,T). For per-tensor FP8 and the block formats currently in use
(MXFP8 / NVFP4 block along the last / contraction dim, which the reshape
preserves), the meaning of quantization is unchanged. This matches the safer,
more general of the two directions discussed in the issue.

Verification

The shape logic was verified with a standalone script (no FP8 hardware required):
4D inputs now take the quantize branch and the output shape/values round-trip
correctly, while 2D/3D behavior is unchanged. I do not have an FP8 GPU to measure
end-to-end kernel dispatch; reviewers with an FP8 device can confirm the MLP GEMMs
now dispatch as FP8 using the trace method in the issue.

Closes #14595.

…n_ops

mixed_precision_ops.Linear.forward only quantized activations that were 2D, or
3D (reshaped to 2D). Inputs with rank >= 4 (e.g. Anima's MLP activations, which
are not reshaped to 3D the way the attention path is) fell through the
`input_reshaped.ndim == 2` guard and reached scaled_mm as bf16, silently
dispatching a bf16 kernel instead of FP8. Since MLP is roughly half the compute,
the FP8 speedup was far below expectation.

Generalize the existing 3D->2D reshape to any rank >= 3 (flatten the leading
dims, keep the contraction dim) and reshape the output back to the original
leading dims. 2D and 3D inputs are handled exactly as before; only rank >= 4
inputs change (now quantized instead of skipped). This matches the rank-agnostic
handling already used by the training path (flatten(0, -2) / unflatten).

Fixes Comfy-Org#14595.

Signed-off-by: liminfei-amd <91481003+liminfei-amd@users.noreply.github.com>
@coderabbitai

coderabbitai Bot commented Jun 26, 2026

Copy link
Copy Markdown

Review Change Stack

No actionable comments were generated in the recent review. 🎉

ℹ️ Recent review info
⚙️ Run configuration

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro Plus

Run ID: 98bde03f-9789-458a-8b82-dd379dc0e0e9

📥 Commits

Reviewing files that changed from the base of the PR and between 7cb784e and f8af8da.

📒 Files selected for processing (1)
  • comfy/ops.py

📝 Walkthrough

Walkthrough

MixedPrecisionOps.Linear.forward was updated so the quantized inference path now handles inputs with rank 3 or higher. The code tracks the reshape with reshaped_nd, collapses leading dimensions into a 2D view before quantization, and reshapes the result back to (*input_shape[:-1], self.weight.shape[0]) after computation.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 0.00% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (4 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly states the FP8 activation quantization fix for higher-rank activations in mixed_precision_ops.
Description check ✅ Passed The description matches the change and explains the same FP8 quantization issue, root cause, fix, and verification.
Linked Issues check ✅ Passed The change generalizes activation quantization for rank>=3 in comfy/ops.py, matching the issue's FP8 MLP fallback fix.
Out of Scope Changes check ✅ Passed The only code change is the targeted Linear.forward reshape/quantization adjustment, with no unrelated edits evident.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Anima FP8 MLP activation quantization is not handled correctly

1 participant