Fixes a layer_norm_nested backwards edge case. (#96788)
# Summary

Adds a test and a fix for the case where the input NestedTensor (NT) to layer_norm does not require grad.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96788
Approved by: https://github.com/cpuhrsch
Committed by: PyTorch MergeBot
Parent: 80e8e41ca7
Commit: 5612aa6acd
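For context, here is a minimal repro of the edge case this commit fixes, distilled from the regression test added below (a sketch assuming CPU and the default device; the test's `device` parameter is dropped for brevity):

    import torch

    # The constituent tensor does not require grad, so autograd only needs
    # gradients for LayerNorm's affine parameters, not for the input NT.
    a = torch.randn(1, 2, 4, dtype=torch.float64, requires_grad=False)
    nt = torch.nested.nested_tensor([a])

    ln = torch.nn.LayerNorm(nt.size(-1), dtype=torch.float64)
    out = ln(nt)

    # Before this fix, the backward pass raised:
    #   NotImplementedError: Cannot access storage of UndefinedTensorImpl
    out.backward(out.clone())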
@@ -221,6 +221,14 @@ std::tuple<Tensor, Tensor, Tensor> layer_norm_backward_nested(
         c10::nullopt /* device */,
         c10::nullopt /* pin_memory */,
         at::MemoryFormat::Contiguous);
+  } else {
+    dInput = at::native::zeros_like(
+        input_buffer,
+        c10::nullopt /* dtype */,
+        c10::nullopt /* layout */,
+        c10::nullopt /* device */,
+        c10::nullopt /* pin_memory */,
+        at::MemoryFormat::Contiguous);
   }
   if (grad_input_mask[1]) {
     dgamma = M > 0 ? at::native::empty_like(
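The crux of the fix: `layer_norm_backward_nested` returns `dInput` unconditionally, but when `grad_input_mask[0]` is false (the input does not require grad) the old code left `dInput` as an undefined tensor, and touching its storage downstream raised the `UndefinedTensorImpl` error. The new `else` branch materializes a zero-filled contiguous buffer instead. Below is a rough Python analogue of the corrected control flow; the function and variable names are illustrative only, not the actual kernel API:

    import torch

    def layer_norm_backward_sketch(input_buffer: torch.Tensor, grad_input_mask):
        # grad_input_mask[0] signals whether the caller wants an input gradient.
        if grad_input_mask[0]:
            # The real kernel fills this buffer with the actual gradient math.
            d_input = torch.empty_like(
                input_buffer, memory_format=torch.contiguous_format)
        else:
            # The fix: hand back a defined, zero-filled buffer rather than an
            # undefined tensor whose storage downstream code cannot access.
            d_input = torch.zeros_like(
                input_buffer, memory_format=torch.contiguous_format)
        return d_input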
@@ -2401,6 +2401,16 @@ class TestNestedTensorAutograd(TestCase):
         data = (a, b, c)
         assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
 
+    # Previously would error when input NT doesn't require grad
+    # NotImplementedError: Cannot access storage of UndefinedTensorImpl
+    def test_layer_norm_backward_edge_case(self, device):
+        size = 4
+        a = torch.randn(1, 2, size, requires_grad=False, dtype=torch.float64, device=device)
+        nt = torch.nested.nested_tensor([a])
+        nt_layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64)
+        out = nt_layer_norm(nt)
+        out.backward(out.clone())
+
     # TODO: OOM https://github.com/pytorch/pytorch/issues/95562
     @skipIfSlowGradcheckEnv
     @parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])