Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
[Expanded Weights] fix layer norm (#80895)
Opacus found that Layer Norm can fail due to a wrong ordering in the ExpandedWeights code. What was happening is that all of our tests used inputs that require grad, so a layer norm check was always short-circuiting in the tests and avoiding the wrong ordering. This adds a test where the input does not require gradients and fixes the issue in Layer Norm.

Closes #80952

Pull Request resolved: https://github.com/pytorch/pytorch/pull/80895
Approved by: https://github.com/zou3519
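For context, here is a minimal standalone sketch (not the PyTorch test itself) of the scenario the new test covers: per-sample gradients of nn.LayerNorm parameters when the input tensor does not require grad. The module, sizes, and reference for-loop below are illustrative assumptions; the actual test compares this kind of for-loop reference against the ExpandedWeights path via _compare_ew_and_for_loop_per_sample_grads.

    # Sketch (assumed setup): per-sample grads of LayerNorm parameters computed with a
    # reference for-loop while the *input* does not require grad, which is the case that
    # used to hit the wrong ordering in the ExpandedWeights layer norm rule.
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    batch_size, features = 4, 8
    layer = nn.LayerNorm(features).double()

    # Input deliberately does NOT require grad.
    x = torch.randn(batch_size, features, dtype=torch.double, requires_grad=False)

    per_sample_weight_grads = []
    for i in range(batch_size):
        loss = layer(x[i:i + 1]).mean()  # torch.mean as the loss reduction, as in the test
        g_weight, _g_bias = torch.autograd.grad(loss, (layer.weight, layer.bias))
        per_sample_weight_grads.append(g_weight)

    # One weight gradient per sample: shape (4, 8)
    print(torch.stack(per_sample_weight_grads).shape)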
@@ -192,6 +192,16 @@ class TestExpandedWeightFunctional(TestCase):
 
             self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
 
+    @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
+    def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op):
+        sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
+        for sample_input in supported_inputs(op, sample_inputs):
+            if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests
+                sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs)
+            sample_input.input.requires_grad_(False)
+
+            self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean)
+
     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,))
     def test_unsupported_expand_weights(self, device, dtype, op):
         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)