[Expanded Weights] fix layer norm (#80895)

Opacus found that Layer Norm can fail because of a wrong ordering in the ExpandedWeights code. What was happening is that all of our tests used inputs that require grad, so a layer norm check was always short-circuiting in the tests and avoiding the wrong ordering. This adds a test where the input does not require gradients and fixes the issue in Layer Norm.
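
To make the failure mode concrete, below is a minimal, hypothetical sketch of the ordering problem described above. It is not the actual ExpandedWeights layer norm implementation, and the gradient formulas are stand-ins; it only shows why a value used by the weight grad must be computed before the input-requires-grad short circuit.

    import torch

    # Hypothetical illustration of the ordering bug (not the real ExpandedWeights source):
    # an intermediate that the per-sample weight grad needs is only computed inside the
    # branch guarded by "does the input need a gradient?". Every pre-existing test set
    # requires_grad=True on the input, so that branch always ran and the bug stayed hidden.
    def buggy_backward(grad_output, normalized, input_needs_grad):
        results = {}
        if input_needs_grad:
            shared = normalized * grad_output      # intermediate also needed below
            results["input"] = shared              # stand-in for the real input grad formula
        results["weight"] = shared.sum(0)          # `shared` is undefined when input_needs_grad is False
        return results

    # Fixed ordering: compute the shared intermediate unconditionally, and let the
    # requires-grad check gate only the input grad itself.
    def fixed_backward(grad_output, normalized, input_needs_grad):
        results = {}
        shared = normalized * grad_output
        if input_needs_grad:
            results["input"] = shared
        results["weight"] = shared.sum(0)
        return results

    if __name__ == "__main__":
        g, x = torch.randn(4, 8), torch.randn(4, 8)
        fixed_backward(g, x, input_needs_grad=False)    # works
        # buggy_backward(g, x, input_needs_grad=False)  # would raise UnboundLocalError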

Closes #80952
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80895
Approved by: https://github.com/zou3519
Author:    samdow
Date:      2022-07-18 13:19:16 -04:00
Committer: PyTorch MergeBot
Parent:    7408004454
Commit:    2bcbea1ff6

2 changed files with 12 additions and 2 deletions


@@ -192,6 +192,16 @@ class TestExpandedWeightFunctional(TestCase):
             self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
 
+    @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
+    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
+        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
+        for sample_input in supported_inputs(op, sample_inputs):
+            if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
+                sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
+            sample_input.input.requires_grad_(False)
+            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
+
     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
     def test_unsupported_expand_weights(self, device, dtype, op):
         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)