Backward support for unbind() with NJT (#128032)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
Joel Schlosser
2024-06-20 19:18:13 -04:00
committed by PyTorch MergeBot
parent 27ae1f981d
commit e1c1052829
5 changed files with 52 additions and 1 deletion

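In user-facing terms, the change makes autograd work through unbind() on nested tensors with the jagged layout (NJT). A minimal sketch of the newly supported pattern, mirroring the test added below (shapes, dtype, and the use of sum() are illustrative):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)],
    layout=torch.jagged,
    requires_grad=True,
)
a, b, c = nt.unbind()  # views over the three jagged components
b.sum().backward()     # backward through unbind() on an NJT is what this commit enables
# nt.grad is a jagged nested tensor that is zero everywhere except the second
# component, which receives a gradient of all ones (d(b.sum())/db == 1).
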
View File

@@ -5610,6 +5610,25 @@ class TestNestedTensorSubclass(TestCase):
        for dynamic in [False, True, None]:
            self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))

    @dtypes(torch.float32, torch.double, torch.half)
    def test_unbind_backward(self, device, dtype):
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2, 4, device=device),
                torch.randn(5, 4, device=device),
                torch.randn(3, 4, device=device),
            ],
            layout=torch.jagged,
            requires_grad=True,
        )

        a, b, c = nt.unbind()
        b.sum().backward()

        expected_grad = torch.zeros_like(nt)
        expected_grad.unbind()[1].add_(1.0)
        torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())

View File

@@ -2852,7 +2852,7 @@
   self: unbind_backward(grads, dim)
   result: auto_linear
   AutogradNestedTensor:
-    self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())
+    self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())"
     result: auto_linear

 - name: stack(Tensor[] tensors, int dim=0) -> Tensor
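The derivatives.yaml formula above now dispatches the nested-tensor backward for unbind() on layout: jagged NJTs take the new unbind_backward_nested_jagged() helper, while strided nested tensors keep the pre-existing unbind_backward_nested() path. A small Python-level illustration of that branch (pick_nested_unbind_backward is a hypothetical helper used only to show the condition; c10::kJagged corresponds to torch.jagged):

import torch

def pick_nested_unbind_backward(self):
    # Hypothetical: returns the name of the C++ backward the formula selects.
    if self.layout == torch.jagged:
        return "unbind_backward_nested_jagged"  # new NJT path
    return "unbind_backward_nested"             # pre-existing strided-nested path

nt = torch.nested.nested_tensor(
    [torch.randn(2, 3), torch.randn(4, 3)], layout=torch.jagged
)
print(pick_nested_unbind_backward(nt))  # unbind_backward_nested_jagged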

View File

@@ -1014,6 +1014,23 @@ Tensor unbind_backward_nested(
  return at::_nested_tensor_from_tensor_list(grads_tensors);
}

Tensor unbind_backward_nested_jagged(
    const variable_list& grads,
    const Tensor& self,
    int64_t dim) {
  TORCH_INTERNAL_ASSERT(
      dim == 0, "unbind_backward_nested_jagged() only supports dim=0")
  auto grad_nt = at::zeros_like(self);
  auto unbound_grads = grad_nt.unbind();
  for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
    if (grads[i].defined()) {
      unbound_grads[i].copy_(static_cast<Tensor>(grads[i]));
    }
  }
  return grad_nt;
}

Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
  auto result = self;
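The new unbind_backward_nested_jagged() helper above assembles the gradient for the whole NJT from the per-component gradients. Roughly the same logic in Python, as a sketch rather than the code autograd actually runs (grads stands for a list with one entry per unbound component, None where a component received no gradient):

import torch

def unbind_backward_jagged_sketch(grads, self):
    # Only dim == 0 is supported, matching the TORCH_INTERNAL_ASSERT above.
    grad_nt = torch.zeros_like(self)  # zero-filled NJT with self's nested structure
    unbound = grad_nt.unbind()        # views into grad_nt's components
    for i, g in enumerate(grads):
        if g is not None:             # mirrors the grads[i].defined() check
            unbound[i].copy_(g)       # write this component's gradient in place
    return grad_nt

The in-place copy_() calls land in grad_nt itself because unbind() on an NJT returns views into the underlying values buffer, which is what the C++ helper relies on as well.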

View File

@@ -244,6 +244,10 @@ at::Tensor unbind_backward_nested(
    const Tensor& nt_sizes,
    int64_t dim,
    const at::TensorOptions& options);
at::Tensor unbind_backward_nested_jagged(
    const variable_list& grads,
    const Tensor& self,
    int64_t dim);
at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
at::Tensor unsqueeze_to(
    const at::Tensor& self,
View File

@@ -472,6 +472,17 @@ register_jagged_func(
)(jagged_unary_pointwise)


@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
def zero__default(func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    inp = new_kwargs.pop("input")
    func(inp._values)
    return inp


@register_jagged_func(
    torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
)
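Finally, the ops.py hunk registers aten.zero_.default for the jagged layout by zeroing the NJT's underlying values buffer. This is presumably what lets at::zeros_like(self) in the new C++ backward work on an NJT, since zeros_like fills an empty_like() result via zero_(). A quick check of the newly supported pattern (shapes illustrative):

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(5, 4)], layout=torch.jagged
)
nt.zero_()                        # dispatches to the jagged registration above
print(nt.unbind()[0])             # a (2, 4) tensor of zeros

grad_like = torch.zeros_like(nt)  # the pattern unbind_backward_nested_jagged() relies on
print(grad_like.unbind()[1])      # a (5, 4) tensor of zeros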