Skip to content

fix nan_to_num crash on DTensor gradients in domain-parallel training#1764

Open
negin513 wants to merge 1 commit into
NVIDIA:mainfrom
negin513:nan_to_num-fix-main
Open

fix nan_to_num crash on DTensor gradients in domain-parallel training#1764
negin513 wants to merge 1 commit into
NVIDIA:mainfrom
negin513:nan_to_num-fix-main

Conversation

@negin513

@negin513 negin513 commented Jun 30, 2026

Copy link
Copy Markdown
Member

PhysicsNeMo Pull Request

Description

After domain-parallel backward, param.grad is a DTensor with Partial(sum) placement -- meaning each rank holds an unreduced partial sum. PyTorch's DTensor dispatch has no registered sharding strategy for nan_to_num on Partial tensors, so it fails.

Fix

This PR adds a fix that extracts the local shard with .to_local() and apply nan_to_num_ in-place on it directly.

Checklist

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

@copy-pr-bot

copy-pr-bot Bot commented Jun 30, 2026

Copy link
Copy Markdown

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@greptile-apps

greptile-apps Bot commented Jun 30, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR fixes a crash in domain-parallel training where param.grad is a DTensor with Partial(sum) placement — PyTorch's DTensor dispatch has no registered sharding strategy for nan_to_num on Partial tensors, causing an error during the post-backward gradient cleaning step.

  • The fix extracts the local shard with grad.to_local() (using a duck-type check) and applies nan_to_num_ in-place on it, bypassing the DTensor dispatch for this operation.
  • For regular (non-DTensor) tensors the behaviour is unchanged: local is just grad itself, and the in-place call is equivalent to the original torch.nan_to_num(..., out=param.grad).

Important Files Changed

Filename Overview
examples/weather/stormcast/utils/trainer.py Replaces torch.nan_to_num with an in-place nan_to_num_ on the local shard to fix DTensor Partial placement crash; one subtle correctness concern around whether to_local() is guaranteed to share storage with the DTensor's internal local tensor across all PyTorch versions.

Reviews (1): Last reviewed commit: "fix for nan_to_num" | Re-trigger Greptile

Comment on lines +796 to +797
local = grad.to_local() if hasattr(grad, "to_local") else grad
local.nan_to_num_(nan=0, posinf=1e5, neginf=-1e5)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Gradient clipping precedes NaN cleaning — ordering may produce incorrect clip norms

clip_grad_norm_ is called on line 782, before NaN values are cleaned. If any gradient contains NaN, the computed global norm will be NaN, making the clipping multiplier NaN as well and leaving all clipped gradients as NaN — then nan_to_num_ at line 797 zeros them. The effective result is that clipped gradients with NaN inputs may end up zeroed instead of clipped. This is a pre-existing ordering issue not introduced by this PR, but now that the NaN-cleaning path is being touched, it may be worth moving the NaN cleaning to before the gradient clipping call.

Comment on lines +796 to +797
local = grad.to_local() if hasattr(grad, "to_local") else grad
local.nan_to_num_(nan=0, posinf=1e5, neginf=-1e5)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 In-place edit on to_local() relies on shared storage with the DTensor

The fix calls grad.to_local() and mutates the result with nan_to_num_. This works only if to_local() returns the DTensor's internal _local_tensor by reference (sharing storage), so the in-place change is visible through param.grad when optimizer.step() reads it. If a PyTorch version returns a copy instead of a view, the NaN values in param.grad would not be cleaned and the optimizer would see corrupted gradients. Consider asserting or documenting the assumption, or alternatively using DTensor.from_local() to reconstruct the DTensor from the cleaned local shard if the in-place path ever proves fragile.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant