Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Add nested tensor support to autograd (#79446)
The issue tracking this work is #79447. This is one in a series of PRs adding autograd support for nested tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79446 Approved by: https://github.com/soulitzer
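For orientation, here is a minimal sketch of the user-facing behavior this series targets, modeled on the tests added below (the component shapes are arbitrary; the tests build their inputs with torch.nested_tensor and torch._nested_tensor_from_mask):

import torch

# Two nested tensors whose components have matching shapes.
a = torch.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
b = torch.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])

a.requires_grad_()                 # nested tensors can now be marked as requiring grad
c = a + b                          # c requires grad as well

# gradcheck does not handle nested tensors yet, so pass an explicit nested grad_output.
grad_output = torch.nested_tensor([torch.ones(2, 3), torch.ones(4, 5)])
c.backward(grad_output)

# For elementwise add the Jacobian is the identity, so a.grad reproduces grad_output
# component for component (compare via unbind(), as the new nt_equal helper below does).
for g, go in zip(a.grad.unbind(), grad_output.unbind()):
    assert torch.equal(g, go)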
committed by PyTorch MergeBot
parent 4b342b30ad
commit f9656817df
@@ -54,9 +54,6 @@ NestedTensorImpl::NestedTensorImpl(
TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous());
int64_t size_dim = nested_size_tensor_.dim();
TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
remove_autograd_key();
key_set_ =
key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
refresh_dim();
set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

@@ -16,6 +16,5 @@ TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl&

TORCH_API Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding, OptionalIntArrayRef output_size);


} // namespace native
} // namespace at

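NestedTensor_to_padded_tensor_generic densifies a nested tensor by padding every component out to the per-dimension maximum. A hedged sketch of how that surfaces in Python (the method name for this snapshot is assumed, not taken from this diff; later releases expose it as torch.nested.to_padded_tensor):

import torch

nt = torch.nested_tensor([torch.randn(1, 2), torch.randn(7, 8)])
padded = nt.to_padded_tensor(0.0)  # assumed Python binding for the generic kernel above
print(padded.shape)                # torch.Size([2, 7, 8]): zero-padded to the max sizes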
@@ -482,6 +482,52 @@ class TestNestedTensorDeviceType(TestCase):
        with self.assertRaisesRegex(RuntimeError, msg):
            nt1.clone(memory_format=torch.channels_last)

class TestNestedTensorAutograd(TestCase):
    def nt_equal(self, nt1, nt2):
        self.assertEqual(nt1.dtype, nt2.dtype)
        self.assertEqual(nt1.device, nt2.device)
        ub1 = nt1.unbind()
        ub2 = nt2.unbind()
        self.assertEqual(len(ub1), len(ub2))
        n = len(ub1)
        for i in range(n):
            self.assertEqual(ub1[i], ub2[i])

    def _create_nested_tensor_from_list(self, requires_grad=False):
        return torch.nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
                                    torch.randn(7, 8, requires_grad=requires_grad)])

    def _create_nested_tensor_from_mask(self, requires_grad=False):
        data = torch.randn(2, 3, 4, requires_grad=requires_grad)
        mask = torch.ones_like(data[:, :, 0]).bool()
        return torch._nested_tensor_from_mask(data, mask)

    def test_set_requires_grad_from_list(self):
        nt = self._create_nested_tensor_from_list()
        nt.requires_grad_()
        assert nt.requires_grad

    def test_set_requires_grad_from_mask(self):
        nt = self._create_nested_tensor_from_mask()
        nt.requires_grad_()
        assert nt.requires_grad

    def test_backward_for_add_op(self):
        nt_1 = self._create_nested_tensor_from_mask()
        nt_2 = self._create_nested_tensor_from_mask()

        nt_1.requires_grad_()
        c = nt_1 + nt_2

        assert nt_1.requires_grad
        assert c.requires_grad
        grad_output = self._create_nested_tensor_from_mask()
        c.backward(grad_output)

        # Grad check doesn't work with nested yet.
        # d/dnt_1 (nt + nt_1) = 1*grad_output
        self.nt_equal(nt_1.grad, grad_output)

instantiate_device_type_tests(TestNestedTensorDeviceType, globals())

if __name__ == '__main__':
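Two things in these tests are worth spelling out. First, _create_nested_tensor_from_mask builds its nested tensors from a dense batch plus a boolean mask; with an all-True mask every batch element keeps all of its rows. A small sketch of what that helper produces (shapes match the helper above):

import torch

data = torch.randn(2, 3, 4)
mask = torch.ones_like(data[:, :, 0]).bool()      # shape (2, 3), all True
nt = torch._nested_tensor_from_mask(data, mask)
print([c.shape for c in nt.unbind()])             # [torch.Size([3, 4]), torch.Size([3, 4])]

Second, because gradcheck cannot handle nested tensors yet, test_backward_for_add_op verifies the gradient by hand: addition is elementwise, so the derivative of nt_1 + nt_2 with respect to nt_1 is the identity, and nt_1.grad must equal grad_output, which nt_equal checks component by component via unbind().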
@@ -509,6 +509,9 @@ DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
    "dequantize_self",
    # lift() should never actually be called with a requires_grad=True tensor,
    "lift",
    # Nested Tensors related functions
    # _nested_tensor_size() should never actually be called with requires_grad=True tensor
    "_nested_tensor_size",
}

DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
@@ -516,8 +519,6 @@ DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
    "_slow_conv2d_forward",
    "slow_conv3d_forward",
    "channel_shuffle",
    # lift() should never actually be called with a requires_grad=True tensor,
    "lift",
    # If an input is returned as-is in output, we cannot guarantee its storage_impl
    # use count to be 1 either.
    *DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,

@@ -724,7 +724,7 @@ void validate_outputs(

if (!metadata.is_same_shape(grad)) {
if (metadata.is_expandable_to_shape(grad)) {
grad = at::sum_to(std::move(grad), metadata.shape());
grad = metadata.reduce_grad(grad);
} else {
const auto message = metadata.incompatible_shape_error_message(i, grad);
AT_ERROR(format_error(message.str()));

@@ -190,7 +190,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
bool is_tensor_subclass) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
input_metadata_.emplace_back(options, shape, is_tensor_subclass);
auto meta_shape = MetadataShape{c10::in_place_type<at::DimVector>, shape};
input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass);
return input_nr;
}

@@ -148,7 +148,8 @@ struct TORCH_API AccumulateGrad : public Node {
new_grad.sizes(),
new_grad.options()));
} else {
if (new_grad.is_sparse() || new_grad.is_sparse_csr()) {
if (new_grad.is_sparse() || new_grad.is_sparse_csr() ||
new_grad.is_nested()) {
update_grad(new_grad.clone());
} else {
if (new_grad.is_mkldnn()) {

@@ -1,11 +1,15 @@
#pragma once

#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Exception.h>
#include <c10/util/variant.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@@ -14,10 +18,13 @@
#endif

#include <cstdint>
#include <utility>

namespace torch {
namespace autograd {

using MetadataShape = c10::variant<at::DimVector, at::Tensor>;

/**
 * Records TensorOptions, shape of the tensor, whether or not the Python
 * dispatch key is set (tensor subclass), and, where applicable, the stream the
@@ -31,10 +38,10 @@ struct InputMetadata {

InputMetadata(
const at::TensorOptions options,
at::IntArrayRef shape,
MetadataShape input_shape,
bool is_tensor_subclass)
: options_{options},
shape_{shape},
shape_{input_shape},
is_tensor_subclass_{is_tensor_subclass} {
auto device_ = options.device();
stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
@@ -43,17 +50,13 @@ struct InputMetadata {
InputMetadata(const at::Tensor& t)
: InputMetadata(
t.options(),
t.sizes(),
compute_variant_shape(t),
t.unsafeGetTensorImpl()->is_python_dispatch()) {}

const at::TensorOptions options() const {
return options_;
}

at::IntArrayRef shape() const {
return shape_;
}

caffe2::TypeMeta dtype() const {
return options_.dtype();
}
@@ -75,37 +78,83 @@ struct InputMetadata {
}

at::Tensor zeros_like() const {
return at::zeros(shape_, options_);
TORCH_CHECK(
!is_nested_tensor(),
"Zeros is not currently supported for nested tensors.")
return at::zeros(shape_as_dim_vector(), options_);
}

bool is_same_shape(const at::Tensor& grad) const {
TORCH_CHECK(!grad.is_nested(), "Nested grads are not currently supported.")
return grad.sizes().equals(shape());
TORCH_CHECK(
grad.is_nested() == is_nested_tensor(),
"Both grad and InputMetadata need to be either nested or non nested tensors.")
if (grad.is_nested()) {
return at::native::get_nested_size_tensor(grad).is_same_size(
shape_as_tensor());
}
return grad.sizes().equals(shape_as_dim_vector());
}
bool is_expandable_to_shape(const at::Tensor& grad) const {
// TODO: Currently NestedTensors are not expandable.
return grad.is_nested() ? false
: at::is_expandable_to(shape(), grad.sizes());
// Currently NestedTensors are not expandable. If this support is added then
// updates to reduce_grad will be needed
TORCH_CHECK(
grad.is_nested() == is_nested_tensor(),
"Both grad and InputMetadata need to be either nested or non nested tensors.")
return grad.is_nested()
? false
: at::is_expandable_to(shape_as_dim_vector(), grad.sizes());
}

at::Tensor reduce_grad(at::Tensor& grad) const {
// Currently reduce_grad is only called if is_expandable_to_shape returns
// true For nested tensors this always returns False, so this check
// shouldn't fail
TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor())
return at::sum_to(std::move(grad), shape_as_dim_vector());
}

std::stringstream incompatible_shape_error_message(
const size_t index,
const at::Tensor& grad) const {
std::stringstream ss;
TORCH_CHECK(!grad.is_nested(), "Nested grads are not currently supported.")
ss << "invalid gradient at index " << index << " - got ";
ss << grad.sizes();
if (grad.is_nested()) {
ss << at::native::get_nested_size_tensor(grad);
} else {
ss << grad.sizes();
}
ss << " but expected shape compatible with ";
ss << shape();
if (is_nested_tensor()) {
ss << shape_as_tensor();
} else {
ss << shape_as_dim_vector();
}
return ss;
}

private:
bool is_nested_tensor() const {
return (c10::holds_alternative<at::Tensor>(shape_));
}
MetadataShape compute_variant_shape(const at::Tensor& input) {
if (input.is_nested()) {
auto nested_size = at::native::get_nested_size_tensor(input);
return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
}
return MetadataShape{c10::in_place_type<at::DimVector>, input.sizes()};
}

at::DimVector shape_as_dim_vector() const {
return c10::get<at::DimVector>(shape_);
}
at::Tensor shape_as_tensor() const {
return c10::get<at::Tensor>(shape_);
}

const at::TensorOptions options_;
at::DimVector shape_;
MetadataShape shape_;
c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
bool is_tensor_subclass_ = false;
};

} // namespace autograd
} // namespace torch

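A nested tensor has no single rectangular size: each component carries its own, and the overall "shape" is recorded as a 2-D size tensor with one row per component. That is why InputMetadata above now stores a c10::variant holding either an at::DimVector (dense case) or that nested size tensor, with shape_as_dim_vector() and shape_as_tensor() selecting the appropriate view. A small sketch of the Python-visible side of this, using the component shapes from the tests:

import torch

nt = torch.nested_tensor([torch.randn(1, 2), torch.randn(7, 8)])
print([c.shape for c in nt.unbind()])   # [torch.Size([1, 2]), torch.Size([7, 8])]
# No single torch.Size describes nt, so the autograd metadata keeps the per-component
# sizes (the nested size tensor) rather than a flat DimVector for nested inputs.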
@@ -17,7 +17,14 @@ inline bool obeys_layout_contract(
TORCH_INTERNAL_ASSERT(!variable.is_sparse());
TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());
if (variable.is_non_overlapping_and_dense()) {

if (variable.is_nested()) {
// TODO: Nested Tensor does not have an implementation of detach. The
// current implementation of nested tensor likely does obey the gradient
// contract and should return true, but this would likely change in the
// future
return false;
} else if (variable.is_non_overlapping_and_dense()) {
// Only look at stride for dimensions that are not of size 1.
const auto& grad_sizes = grad.sizes();
const auto& grad_strides = grad.strides();