Fixes a layer_norm_nested backwards edge case. (#96788)

# Summary
Add a test and a fix for the case where the input NT passed to layer norm does not require grad.
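A minimal repro of the failing case, mirroring the test added in this PR (the nested input itself does not require grad, but the LayerNorm's affine parameters do, so backward still runs through the nested layer-norm backward):

```python
import torch

# Nested input that does not require grad.
a = torch.randn(1, 2, 4, requires_grad=False, dtype=torch.float64)
nt = torch.nested.nested_tensor([a])

# LayerNorm's weight/bias require grad, so backward is still exercised.
ln = torch.nn.LayerNorm(nt.size(-1), dtype=torch.float64)
out = ln(nt)

# Before this fix:
#   NotImplementedError: Cannot access storage of UndefinedTensorImpl
out.backward(out.clone())
```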
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96788
Approved by: https://github.com/cpuhrsch
Author: Driss Guessous
Date: 2023-03-15 17:16:13 +00:00
Committed by: PyTorch MergeBot
Parent: 80e8e41ca7
Commit: 5612aa6acd
2 changed files with 18 additions and 0 deletions


@@ -221,6 +221,14 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
        c10::nullopt /* device */,
        c10::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  } else {
    dInput = at::native::zeros_like(
        input_buffer,
        c10::nullopt /* dtype */,
        c10::nullopt /* layout */,
        c10::nullopt /* device */,
        c10::nullopt /* pin_memory */,
        at::MemoryFormat::Contiguous);
  }
  if (grad_input_mask[1]) {
    dgamma = M > 0 ? at::native::empty_like(
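Only the tail of the pre-existing allocation is visible in this hunk, but the shape of the fix is clear: when `grad_input_mask[0]` is false (the input NT does not require grad), `dInput` used to be left as an undefined tensor, and downstream handling of the nested gradient then failed with `NotImplementedError: Cannot access storage of UndefinedTensorImpl`; the new `else` branch makes `dInput` a defined, contiguous zero buffer instead. A rough Python analogue of that allocation logic (hypothetical helper name, simplified requires-grad branch, not the actual ATen code):

```python
import torch

def allocate_d_input(input_buffer: torch.Tensor, input_requires_grad: bool) -> torch.Tensor:
    # input_requires_grad plays the role of grad_input_mask[0]; input_buffer is
    # the dense values buffer backing the nested tensor.
    if input_requires_grad:
        # Existing branch (simplified): allocate the buffer that the kernel
        # fills with the real per-element gradients.
        return torch.empty_like(input_buffer, memory_format=torch.contiguous_format)
    # New branch: return a defined, zero-filled buffer rather than leaving the
    # gradient undefined, which is what previously triggered
    # "Cannot access storage of UndefinedTensorImpl".
    return torch.zeros_like(input_buffer, memory_format=torch.contiguous_format)
```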


@@ -2401,6 +2401,16 @@ class TestNestedTensorAutograd(TestCase):
        data = (a, b, c)
        assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

    # Previously would error when input NT doesn't require grad
    # NotImplementedError: Cannot access storage of UndefinedTensorImpl
    def test_layer_norm_backward_edge_case(self, device):
        size = 4
        a = torch.randn(1, 2, size, requires_grad=False, dtype=torch.float64, device=device)
        nt = torch.nested.nested_tensor([a])
        nt_layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64)
        out = nt_layer_norm(nt)
        out.backward(out.clone())

    # TODO: OOM https://github.com/pytorch/pytorch/issues/95562
    @skipIfSlowGradcheckEnv
    @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])