Backward support for unbind() with NJT (#128032)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128032
Approved by: https://github.com/soulitzer
commit e1c1052829 (parent 27ae1f981d), committed by PyTorch MergeBot
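In practice, this change makes unbind() differentiable for jagged-layout nested tensors (NJTs). A minimal sketch of the new behavior, adapted from the test added below (CPU, default dtype; the variable names are mine, not from the PR):

import torch

# A jagged-layout nested tensor (NJT) as a differentiable leaf.
nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(5, 4), torch.randn(3, 4)],
    layout=torch.jagged,
    requires_grad=True,
)

a, b, c = nt.unbind()  # dense views of nt's components
b.sum().backward()     # gradient now flows back through unbind()

# Only the component that participated in the loss receives gradient.
assert (nt.grad.unbind()[1] == 1).all()
assert (nt.grad.unbind()[0] == 0).all()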
@@ -5610,6 +5610,25 @@ class TestNestedTensorSubclass(TestCase):
         for dynamic in [False, True, None]:
             self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
 
+    @dtypes(torch.float32, torch.double, torch.half)
+    def test_unbind_backward(self, device, dtype):
+        nt = torch.nested.nested_tensor(
+            [
+                torch.randn(2, 4, device=device),
+                torch.randn(5, 4, device=device),
+                torch.randn(3, 4, device=device),
+            ],
+            layout=torch.jagged,
+            requires_grad=True,
+        )
+
+        a, b, c = nt.unbind()
+        b.sum().backward()
+
+        expected_grad = torch.zeros_like(nt)
+        expected_grad.unbind()[1].add_(1.0)
+        torch._dynamo.disable(self.assertEqual)(nt.grad, expected_grad)
+
 
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
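The expected gradient in this test (in test/test_nestedtensor.py) is built by zeroing an NJT and writing ones into its second component, which matches what plain dense autograd would produce. A quick cross-check sketch, not part of the test:

import torch

# Same computation with independent dense leaves.
parts = [
    torch.randn(2, 4, requires_grad=True),
    torch.randn(5, 4, requires_grad=True),
    torch.randn(3, 4, requires_grad=True),
]
parts[1].sum().backward()

# parts[1].grad is all ones; parts[0] and parts[2] never get a grad,
# mirroring the zero components of nt.grad in the NJT test above.
assert (parts[1].grad == 1).all()
assert parts[0].grad is None and parts[2].grad is None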
@@ -2852,7 +2852,7 @@
       self: unbind_backward(grads, dim)
       result: auto_linear
     AutogradNestedTensor:
-      self: unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())
+      self: "self.layout() == c10::kJagged ? unbind_backward_nested_jagged(grads, self, dim) : unbind_backward_nested(grads, at::native::get_nested_tensor_impl(self)->get_nested_sizes(), dim, self.options())"
       result: auto_linear
 
 - name: stack(Tensor[] tensors, int dim=0) -> Tensor
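This hunk is the derivative formula for unbind under the AutogradNestedTensor key (in tools/autograd/derivatives.yaml): the previously unconditional call to unbind_backward_nested is replaced with a layout check, so jagged tensors take the new path while strided nested tensors keep the old one. A sketch of the user-visible routing, assuming the pre-existing strided path behaves as its formula implies:

import torch

# Both nested layouts now have a working unbind() backward; the formula
# above selects the implementation from self.layout() at runtime.
for layout in (torch.jagged, torch.strided):
    nt = torch.nested.nested_tensor(
        [torch.randn(2, 4), torch.randn(3, 4)],
        layout=layout,
        requires_grad=True,
    )
    nt.unbind()[0].sum().backward()
    print(layout, type(nt.grad))  # a nested gradient in both cases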
@@ -1014,6 +1014,23 @@ Tensor unbind_backward_nested(
   return at::_nested_tensor_from_tensor_list(grads_tensors);
 }
 
+Tensor unbind_backward_nested_jagged(
+    const variable_list& grads,
+    const Tensor& self,
+    int64_t dim) {
+  TORCH_INTERNAL_ASSERT(
+      dim == 0, "unbind_backward_nested_jagged() only supports dim=0")
+  auto grad_nt = at::zeros_like(self);
+  auto unbound_grads = grad_nt.unbind();
+  for (int64_t i : c10::irange(static_cast<int64_t>(grads.size()))) {
+    if (grads[i].defined()) {
+      unbound_grads[i].copy_(static_cast<Tensor>(grads[i]));
+    }
+  }
+
+  return grad_nt;
+}
+
 Tensor unsqueeze_to(const Tensor& self, c10::SymIntArrayRef sym_sizes) {
   auto result = self;
 
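The trick in unbind_backward_nested_jagged() is that unbind() on a jagged NJT returns views into the underlying packed values buffer, so copying each incoming gradient into its view fills the zero-initialized grad_nt directly, and undefined gradients (components that never reached the loss) simply stay zero. A hypothetical Python twin of the C++ helper above (the next hunk adds its matching declaration):

import torch

def unbind_backward_nested_jagged_py(grads, self_nt, dim=0):
    # Hypothetical Python rendering of the C++ helper; not part of the PR.
    assert dim == 0, "unbind_backward_nested_jagged() only supports dim=0"
    grad_nt = torch.zeros_like(self_nt)           # zero-filled NJT
    for view, g in zip(grad_nt.unbind(), grads):  # views into grad_nt's values
        if g is not None:                         # undefined grads stay zero
            view.copy_(g)
    return grad_nt

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)
g = unbind_backward_nested_jagged_py([None, torch.ones(3, 4)], nt)
assert (g.unbind()[1] == 1).all() and (g.unbind()[0] == 0).all()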
@@ -244,6 +244,10 @@ at::Tensor unbind_backward_nested(
     const Tensor& nt_sizes,
     int64_t dim,
     const at::TensorOptions& options);
+at::Tensor unbind_backward_nested_jagged(
+    const variable_list& grads,
+    const Tensor& self,
+    int64_t dim);
 at::Tensor unsqueeze_to(const at::Tensor& self, c10::SymIntArrayRef sym_sizes);
 at::Tensor unsqueeze_to(
     const at::Tensor& self,
@@ -472,6 +472,17 @@ register_jagged_func(
 )(jagged_unary_pointwise)
 
 
+@register_jagged_func(torch.ops.aten.zero_.default, "self: jt_all")
+def zero__default(func, *args, **kwargs):
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    inp = new_kwargs.pop("input")
+    func(inp._values)
+    return inp
+
+
 @register_jagged_func(
     torch.ops.aten._softmax.default, "self: jt, dim: any, half_to_float: any"
 )
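This last hunk registers aten.zero_ for jagged NJTs (in torch/nested/_internal/ops.py) by zeroing the underlying values buffer in place, presumably so that the zeroing performed along the new backward path (at::zeros_like(self) in unbind_backward_nested_jagged) works on NJT inputs. Usage sketch:

import torch

nt = torch.nested.nested_tensor(
    [torch.randn(2, 4), torch.randn(3, 4)], layout=torch.jagged
)
nt.zero_()                      # dispatches to zero__default above
assert (nt._values == 0).all()  # the packed values buffer is zeroed in place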