Describe the bug
I encountered a gradient explosion issue resulting in NaN values when training Qwen/Qwen3-Reranker-8B (Listwise Generative Reranker) using ms-swift with the unsloth backend.
Symptoms:
- Normal Run: If I run the training normally, the training loop continues, but the logs report `grad_norm = nan`. The loss value still fluctuates/decreases, but the gradients are clearly broken.
- Debug Run: When enabling `torch.autograd.set_detect_anomaly(True)` and adding a debug hook to `sdpa`, the training crashes with `RuntimeError: Function 'MmBackward0' returned nan values`.
- Gradient Analysis: By hooking into `_debug_sdpa` (see the sketch below), I observed the gradients flowing backward into the attention layer growing exponentially from `1e-6` to `1e+36` within a single step, eventually hitting Infinity/NaN.
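For context, here is a minimal sketch of the kind of hook used to capture the numbers in the log further down. It is not the exact code: `_debug_sdpa` is assumed to be a thin wrapper patched in place of the model's SDPA call, and `check_grad_output` is the tensor hook whose prints appear in the debug log.

```python
import torch
import torch.nn.functional as F


def check_grad_output(grad: torch.Tensor) -> None:
    # Called during backward with the gradient arriving at the SDPA output.
    print(f"[SDPA Backward] grad Max: {grad.max().item()}, Min: {grad.min().item()}")
    if not torch.isfinite(grad).all():
        print("[SDPA Backward] non-finite gradient detected")


def _debug_sdpa(query, key, value, *args, **kwargs):
    # Thin wrapper around F.scaled_dot_product_attention that attaches the hook
    # to the attention output so the gradient flowing into it can be inspected.
    output = F.scaled_dot_product_attention(query, key, value, *args, **kwargs)
    if output.requires_grad:
        output.register_hook(check_grad_output)
    return output
```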
Reproduction Script:
```bash
#!/bin/bash
# Listwise training configuration
CACHE_DIR="./webq_cached_dataset"
export CUDA_VISIBLE_DEVICES=1
export PYTORCH_ALLOC_CONF=expandable_segments:True
export UNSLOTH_COMPILE_DISABLE=1

echo ">>> Start SFT..."
uv --preview-features extra-build-dependencies run swift sft \
  --model Qwen/Qwen3-Reranker-8B \
  --task_type generative_reranker \
  --loss_type listwise_generative_reranker \
  --train_type lora \
  --tuner_backend unsloth \
  --torch_dtype bfloat16 \
  --learning_rate 2e-4 \
  --gradient_accumulation_steps 16 \
  --lora_rank 64 \
  --lora_alpha 128 \
  --dataset 'MTEB/scidocs-reranking' \
  # ... (rest of parameters as per context)
```

Debug Log & Stack Trace:
With `torch.autograd.set_detect_anomaly(True)` enabled and the forward pass wrapped in `torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH])`, the gradients explode as follows:
```text
# logs from `output.register_hook(check_grad_output)` for output of F.sdpa
[SDPA Backward] grad Max: 8.046627044677734e-06, Min: -6.794929504394531e-06
...
[SDPA Backward] grad Max: 566935683072.0, Min: -807453851648.0
...
[SDPA Backward] grad Max: 2.648071397852762e+36, Min: -2.1599954931504883e+36

# error raised
RuntimeError: Function 'MmBackward0' returned nan values in its 0th output.
  File "/home/User/workspace/ir_train/.venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 2740, in backward
    loss.backward(**kwargs)
  File "/home/User/workspace/ir_train/.venv/lib/python3.12/site-packages/unsloth_zoo/gradient_checkpointing.py", line 598, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
```
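The debug run was set up roughly as in the self-contained sketch below; the random bf16 tensors are only placeholders for the real model's attention inputs, not the actual ms-swift/unsloth training loop, where the error above is raised from inside `loss.backward()`.

```python
import torch
import torch.nn.functional as F
from torch.nn.attention import sdpa_kernel, SDPBackend

torch.autograd.set_detect_anomaly(True)

# Toy bf16 tensors standing in for the real attention inputs (batch, heads, seq, head_dim).
q = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16, requires_grad=True)
k = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)
v = torch.randn(1, 8, 16, 64, dtype=torch.bfloat16)

# Force the MATH backend so backward runs through plain PyTorch ops, letting
# anomaly detection point at the exact failing function (here MmBackward0).
with sdpa_kernel([SDPBackend.MATH]):
    out = F.scaled_dot_product_attention(q, k, v)
    out.sum().backward()
```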
Your hardware and system info
- GPU: NVIDIA A100 (40GB)
- CUDA: 12.8
- System: Linux
- Python: 3.12
- Dependencies (from `uv` config):
  - `torch==2.9.0` (installed via index `pytorch-cu128`)
  - `unsloth` (latest git)
  - `ms-swift<=3.10.3`
  - `accelerate==1.11`
  - `deepspeed>=0.18.2`
Additional context
- The issue seems specific to the combination of Unsloth backend and the Listwise Generative Reranker task.
- The explosion happens during the backward pass, after the attention mechanism: the gradients coming from the upper layers into SDPA are already huge (see the layer-hook sketch below).
- `attn_impl` was set to `sdpa` for debugging, but `flash-attn` exhibits the same behavior (NaN grad_norm).
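To confirm that the huge gradients originate above the attention block rather than inside SDPA itself, one can attach full backward hooks per module and print the gradient arriving at each module's output. The snippet below is only a self-contained illustration on a toy `nn.Sequential`; with the real model the same hook would be registered on each decoder layer or `self_attn` module found via `model.named_modules()`.

```python
import torch
import torch.nn as nn


def make_hook(name):
    # grad_output holds the gradients flowing in from the layers above this module.
    def hook(module, grad_input, grad_output):
        for g in grad_output:
            if g is not None:
                flag = "  <-- non-finite" if not torch.isfinite(g).all() else ""
                print(f"[{name}] incoming grad abs-max: {g.abs().max().item():.3e}{flag}")
    return hook


# Toy stack standing in for the decoder layers of the real model.
model = nn.Sequential(*[nn.Linear(32, 32) for _ in range(4)])
for name, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.register_full_backward_hook(make_hook(name))

model(torch.randn(2, 32)).sum().backward()
```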