[bug] grad explodes to NaN while training Qwen3 Reranker with the unsloth tuner #6992

@MosRat

Description

Describe the bug
I encountered a gradient explosion issue resulting in NaN values when training Qwen/Qwen3-Reranker-8B (Listwise Generative Reranker) using ms-swift with the unsloth backend.

Symptoms:

  1. Normal Run: training proceeds, but the logs report grad_norm = nan. The loss still fluctuates/decreases, yet the gradients are clearly broken.
  2. Debug Run: with torch.autograd.set_detect_anomaly(True) enabled and a debug hook registered on the SDPA output, training crashes with RuntimeError: Function 'MmBackward0' returned nan values.
  3. Gradient Analysis: via the hook in my _debug_sdpa wrapper, I observed the gradients flowing backward into the attention layer grow exponentially from ~1e-6 to ~1e+36 within a single step, eventually hitting Infinity/NaN.

Reproduction Script:

#!/bin/bash
# Listwise training configuration
CACHE_DIR="./webq_cached_dataset"
export CUDA_VISIBLE_DEVICES=1                       # train on a single GPU (device index 1)
export PYTORCH_ALLOC_CONF=expandable_segments:True  # reduce allocator fragmentation
export UNSLOTH_COMPILE_DISABLE=1                    # disable Unsloth's compile path while debugging

echo ">>> Start SFT..."

uv --preview-features extra-build-dependencies run swift sft \
    --model Qwen/Qwen3-Reranker-8B \
    --task_type generative_reranker \
    --loss_type listwise_generative_reranker \
    --train_type lora \
    --tuner_backend unsloth \
    --torch_dtype bfloat16 \
    --learning_rate 2e-4 \
    --gradient_accumulation_steps 16 \
    --lora_rank 64 \
    --lora_alpha 128 \
    --dataset 'MTEB/scidocs-reranking' \
    # ... (rest of parameters as per context)

Debug Log & Stack Trace:
With torch.autograd.set_detect_anomaly(True) enabled and the attention call wrapped in torch.nn.attention.sdpa_kernel([torch.nn.attention.SDPBackend.MATH]), the gradients explode as follows (a sketch of this instrumentation follows the log):

# logs from `output.register_hook(check_grad_output)` on the output of F.scaled_dot_product_attention
[SDPA Backward] grad Max: 8.046627044677734e-06, Min: -6.794929504394531e-06
...
[SDPA Backward] grad Max: 566935683072.0, Min: -807453851648.0
...
[SDPA Backward] grad Max: 2.648071397852762e+36, Min: -2.1599954931504883e+36
# error raised
RuntimeError: Function 'MmBackward0' returned nan values in its 0th output.
  File "/home/User/workspace/ir_train/.venv/lib/python3.12/site-packages/accelerate/accelerator.py", line 2740, in backward
    loss.backward(**kwargs)
  File "/home/User/workspace/ir_train/.venv/lib/python3.12/site-packages/unsloth_zoo/gradient_checkpointing.py", line 598, in backward
    torch.autograd.backward(outputs_with_grad, args_with_grad)
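
For reference, the instrumentation behind the log above was roughly the following. This is a minimal sketch: check_grad_output and the _debug_sdpa monkey patch are local debug helpers, not part of ms-swift or unsloth, and the exact patch point inside the Qwen3 attention code is omitted here.

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

torch.autograd.set_detect_anomaly(True)  # stop at the first op that produces NaN in backward

def check_grad_output(grad: torch.Tensor) -> torch.Tensor:
    # backward hook on the SDPA output: log the range of the incoming gradient
    print(f"[SDPA Backward] grad Max: {grad.max().item()}, Min: {grad.min().item()}")
    return grad

_orig_sdpa = F.scaled_dot_product_attention

def _debug_sdpa(*args, **kwargs):
    # force the math backend so backward runs through plain matmuls,
    # then watch the gradient flowing back into the attention output
    with sdpa_kernel([SDPBackend.MATH]):
        output = _orig_sdpa(*args, **kwargs)
    if output.requires_grad:
        output.register_hook(check_grad_output)
    return output

F.scaled_dot_product_attention = _debug_sdpa  # debugging-only monkey patch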

Your hardware and system info

  • GPU: NVIDIA A100 (40GB)
  • CUDA: 12.8
  • System: Linux
  • Python: 3.12
  • Dependencies (from uv config):
    • torch==2.9.0 (installed via index pytorch-cu128)
    • unsloth (latest git)
    • ms-swift<=3.10.3
    • accelerate==1.11
    • deepspeed>=0.18.2

Additional context

  • The issue seems specific to the combination of the Unsloth backend and the Listwise Generative Reranker task.
  • The explosion happens during the backward pass after the attention mechanism (the gradients coming from the upper layers into SDPA are already huge; a way to confirm this is sketched below).
  • attn_impl was set to sdpa for debugging, but flash-attn exhibits the same behavior (NaN grad_norm).
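
One way to confirm the second point above is to attach a backward probe to each decoder layer and log the magnitude of the gradient arriving from the layers above. This is a minimal sketch: the attach_layer_grad_probes helper is hypothetical, model.model.layers assumes the stock transformers Qwen3 module layout (the unsloth/PEFT wrapping may change the path), and gradient checkpointing may affect how often the hooks fire.

import torch
from torch import nn

def attach_layer_grad_probes(model: nn.Module) -> None:
    """Print the magnitude of the gradient arriving at each decoder layer during backward."""
    def make_hook(name: str):
        def hook(module, grad_input, grad_output):
            g = grad_output[0]  # gradient flowing in from the layers above
            if g is not None:
                print(f"{name}: |grad| max={g.abs().max().item():.3e} "
                      f"nan={bool(torch.isnan(g).any())}")
        return hook

    # assumption: the wrapped model still exposes the decoder stack at model.model.layers
    for i, layer in enumerate(model.model.layers):
        layer.register_full_backward_hook(make_hook(f"decoder_layer_{i}"))

If the probes on the topmost layers already show huge magnitudes, the blow-up originates above the attention blocks (e.g. in the loss or lm_head path) rather than inside SDPA itself.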
