Fast gradient clipping ignores masking #792

@isuruwg

🐛 Bug

We noticed that when ghost_fsdp is enabled while wrapping the model with the privacy engine, the criterion is wrapped as well, and the code takes the following path:

if "ghost" in grad_sample_mode → call _prepare_criterion
_prepare_criterion returns a [DPLossFastGradientClipping](https://github.com/meta-pytorch/opacus/blob/main/opacus/privacy_engine.py#L216) object
Inside the __call__ function of the DPLossFastGradientClipping class, this line performs the mean reduction: https://github.com/meta-pytorch/opacus/blob/main/opacus/utils/fast_gradient_clipping_utils.py#L121

However, this mean reduction does not take into account the ignore_index setting that would have been passed to the original PyTorch criterion (e.g. CrossEntropyLoss). In the original PyTorch implementation, targets equal to ignore_index are excluded when calculating the loss, and with mean reduction the average is taken over the non-ignored targets only, as per the documentation.
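For reference, here is a minimal sketch of the PyTorch behavior described above. The logits and targets are made up for illustration and are not from our training code; it only checks that the mean reduction of cross_entropy divides by the number of non-ignored targets.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

IGNORE_INDEX = -100  # default ignore_index of CrossEntropyLoss

# Illustrative data: 6 token positions, 5 classes, only 2 positions unmasked.
logits = torch.randn(6, 5)
targets = torch.tensor([IGNORE_INDEX, IGNORE_INDEX, 3, 1, IGNORE_INDEX, IGNORE_INDEX])

# Built-in mean reduction with ignore_index.
mean_loss = F.cross_entropy(logits, targets, ignore_index=IGNORE_INDEX, reduction="mean")

# Same quantity by hand: sum the per-token losses at valid positions and
# divide by the number of valid positions (2 here), not by all 6.
per_token = F.cross_entropy(logits, targets, ignore_index=IGNORE_INDEX, reduction="none")
valid = targets != IGNORE_INDEX
manual_mean = per_token[valid].sum() / valid.sum()

print(torch.allclose(mean_loss, manual_mean))
```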

The Opacus implementation, however, computes the mean as follows:

if self.loss_reduction == "mean":
    loss_per_sample = loss_per_sample.mean(dim=1)  # B

This computes the mean over the whole sequence dimension without excluding the positions we masked. The effect is apparent in a task like SQuAD, where the part of the output we care about is very small compared to the whole output: the answers are very short (a few tokens), so if the masking is not honored when calculating the loss, the loss is much lower than it should be, because the denominator equals the sequence length (e.g. 1024) instead of the correct value (2 or so).
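As a rough worked example of the scale of the problem, the arithmetic can be sketched in plain Python. The per-token loss value of 3.0 is a hypothetical number chosen for illustration; the sequence length and answer length are the ones mentioned above.

```python
# Hypothetical numbers: a 1024-token sequence in which only 2 answer tokens
# carry a loss; every masked position contributes 0 to the per-token loss.
seq_len = 1024
answer_token_losses = [3.0, 3.0]  # illustrative per-token loss values

per_token = answer_token_losses + [0.0] * (seq_len - len(answer_token_losses))

# Mean over every position, which is what .mean(dim=1) does per sample.
plain_mean = sum(per_token) / seq_len

# Mean over the unmasked positions only, which is what an
# ignore_index-aware mean reduction would report.
masked_mean = sum(answer_token_losses) / len(answer_token_losses)

print(plain_mean)                # 0.005859375
print(masked_mean)               # 3.0
print(masked_mean / plain_mean)  # 512.0
```

Under these assumptions the reported loss is smaller by a factor of seq_len / number of answer tokens, which is 512 here.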

As a result the model does not train, since the loss is always very low. Is support for ignore_index, similar to the PyTorch implementation, planned for the fast gradient clipping function?

As a workaround, if we manually calculate this loss outside of Opacus and then replace the loss_per_sample tensor with the correct values, could that cause issues in other places?
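To be concrete about what "the correct values" would be, this is a minimal sketch of the per-sample masked mean we have in mind. The helper name masked_loss_per_sample and the tensor shapes are our own assumptions for illustration; nothing here is part of the Opacus API, and it does not answer whether substituting the result into Opacus is safe.

```python
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100  # default ignore_index of CrossEntropyLoss


def masked_loss_per_sample(logits, targets, ignore_index=IGNORE_INDEX):
    """Hypothetical helper: per-sample mean over non-ignored tokens only.

    logits: (B, T, C) float tensor, targets: (B, T) long tensor.
    Returns a tensor of shape (B,).
    """
    # cross_entropy expects the class dimension at position 1: (B, C, T).
    per_token = F.cross_entropy(
        logits.transpose(1, 2),
        targets,
        ignore_index=ignore_index,
        reduction="none",
    )  # (B, T), with zeros at ignored positions
    valid_counts = (targets != ignore_index).sum(dim=1)  # (B,)
    # clamp avoids a division by zero for samples with no valid tokens.
    return per_token.sum(dim=1) / valid_counts.clamp(min=1)


torch.manual_seed(0)
B, T, C = 2, 8, 5
logits = torch.randn(B, T, C)
targets = torch.full((B, T), IGNORE_INDEX, dtype=torch.long)
targets[0, 2:4] = torch.tensor([1, 3])  # sample 0: two answer tokens
targets[1, 5] = 2                       # sample 1: one answer token

masked = masked_loss_per_sample(logits, targets)
print(masked.shape)  # torch.Size([2])
```

Each entry of the result should match what CrossEntropyLoss with ignore_index and mean reduction returns for that sample on its own.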

Thank you for looking into this.
