fix nan_to_num crash on DTensor gradients in domain-parallel training#1764
fix nan_to_num crash on DTensor gradients in domain-parallel training#1764negin513 wants to merge 1 commit into
Conversation
Greptile SummaryThis PR fixes a crash in domain-parallel training where
Important Files Changed
Reviews (1): Last reviewed commit: "fix for nan_to_num" | Re-trigger Greptile |
| local = grad.to_local() if hasattr(grad, "to_local") else grad | ||
| local.nan_to_num_(nan=0, posinf=1e5, neginf=-1e5) |
There was a problem hiding this comment.
Gradient clipping precedes NaN cleaning — ordering may produce incorrect clip norms
clip_grad_norm_ is called on line 782, before NaN values are cleaned. If any gradient contains NaN, the computed global norm will be NaN, making the clipping multiplier NaN as well and leaving all clipped gradients as NaN — then nan_to_num_ at line 797 zeros them. The effective result is that clipped gradients with NaN inputs may end up zeroed instead of clipped. This is a pre-existing ordering issue not introduced by this PR, but now that the NaN-cleaning path is being touched, it may be worth moving the NaN cleaning to before the gradient clipping call.
| local = grad.to_local() if hasattr(grad, "to_local") else grad | ||
| local.nan_to_num_(nan=0, posinf=1e5, neginf=-1e5) |
There was a problem hiding this comment.
In-place edit on
to_local() relies on shared storage with the DTensor
The fix calls grad.to_local() and mutates the result with nan_to_num_. This works only if to_local() returns the DTensor's internal _local_tensor by reference (sharing storage), so the in-place change is visible through param.grad when optimizer.step() reads it. If a PyTorch version returns a copy instead of a view, the NaN values in param.grad would not be cleaned and the optimizer would see corrupted gradients. Consider asserting or documenting the assumption, or alternatively using DTensor.from_local() to reconstruct the DTensor from the cleaned local shard if the in-place path ever proves fragile.
PhysicsNeMo Pull Request
Description
After domain-parallel backward,
param.gradis a DTensor withPartial(sum)placement -- meaning each rank holds an unreduced partial sum. PyTorch's DTensor dispatch has no registered sharding strategy fornan_to_numonPartialtensors, so it fails.Fix
This PR adds a fix that extracts the local shard with
.to_local()and applynan_to_num_in-place on it directly.Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness and is not a qualitative judgment of your work, nor is
it an indication that the PR will be accepted / rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.