Add nested tensor support to autograd (#79446)

The issue that is tracking this work is: #79447

This is one in a series of PRs to add autograd support for nested tensors.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79446
Approved by: https://github.com/soulitzer
This commit is contained in:
drisspg
2022-06-16 21:09:17 +00:00
committed by PyTorch MergeBot
parent 4b342b30ad
commit f9656817df
9 changed files with 129 additions and 28 deletions

View File

@ -54,9 +54,6 @@ NestedTensorImpl::NestedTensorImpl(
TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous());
int64_t size_dim = nested_size_tensor_.dim();
TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
remove_autograd_key();
key_set_ =
key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
refresh_dim();
set_sizes_strides_policy(c10::TensorImpl::SizesStridesPolicy::CustomSizes);
}

View File

@ -16,6 +16,5 @@ TORCH_API std::vector<int64_t> NestedTensor_get_max_size(const NestedTensorImpl&
TORCH_API Tensor NestedTensor_to_padded_tensor_generic(const Tensor& t, double padding, OptionalIntArrayRef output_size);
} // namespace native
} // namespace at

View File

@ -482,6 +482,52 @@ class TestNestedTensorDeviceType(TestCase):
with self.assertRaisesRegex(RuntimeError, msg):
nt1.clone(memory_format=torch.channels_last)
class TestNestedTensorAutograd(TestCase):
    """Autograd coverage for nested tensors.

    gradcheck does not support nested tensors yet, so gradients are
    compared manually against hand-derived expected values.
    """

    def nt_equal(self, nt1, nt2):
        # Nested tensors have no direct equality op; compare metadata and
        # then each unbound component tensor pairwise.
        self.assertEqual(nt1.dtype, nt2.dtype)
        self.assertEqual(nt1.device, nt2.device)
        ub1 = nt1.unbind()
        ub2 = nt2.unbind()
        self.assertEqual(len(ub1), len(ub2))
        for t1, t2 in zip(ub1, ub2):
            self.assertEqual(t1, t2)

    def _create_nested_tensor_from_list(self, requires_grad=False):
        # Two components of different shapes exercises the ragged case.
        return torch.nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
                                    torch.randn(7, 8, requires_grad=requires_grad)])

    def _create_nested_tensor_from_mask(self, requires_grad=False):
        # All-ones mask: every row of `data` is kept in the nested result.
        data = torch.randn(2, 3, 4, requires_grad=requires_grad)
        mask = torch.ones_like(data[:, :, 0]).bool()
        return torch._nested_tensor_from_mask(data, mask)

    def test_set_requires_grad_from_list(self):
        nt = self._create_nested_tensor_from_list()
        nt.requires_grad_()
        assert nt.requires_grad

    def test_set_requires_grad_from_mask(self):
        nt = self._create_nested_tensor_from_mask()
        nt.requires_grad_()
        assert nt.requires_grad

    def test_backward_for_add_op(self):
        nt_1 = self._create_nested_tensor_from_mask()
        nt_2 = self._create_nested_tensor_from_mask()
        nt_1.requires_grad_()
        c = nt_1 + nt_2
        assert nt_1.requires_grad
        assert c.requires_grad
        grad_output = self._create_nested_tensor_from_mask()
        c.backward(grad_output)
        # Grad check doesn't work with nested yet.
        # d/dnt_1 (nt_1 + nt_2) = 1 * grad_output
        self.nt_equal(nt_1.grad, grad_output)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
if __name__ == '__main__':

View File

@ -509,6 +509,9 @@ DONT_ENFORCE_TENSOR_IMPL_USE_COUNT = {
"dequantize_self",
# lift() should never actually be called with a requires_grad=True tensor,
"lift",
# Nested Tensors related functions
# _nested_tensor_size() should never actually be called with requires_grad=True tensor
"_nested_tensor_size",
}
DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
@ -516,8 +519,6 @@ DONT_ENFORCE_STORAGE_IMPL_USE_COUNT = {
"_slow_conv2d_forward",
"slow_conv3d_forward",
"channel_shuffle",
# lift() should never actually be called with a requires_grad=True tensor,
"lift",
# If an input is returned as-is in output, we cannot guarantee its storage_impl
# use count to be 1 either.
*DONT_ENFORCE_TENSOR_IMPL_USE_COUNT,

View File

@ -724,7 +724,7 @@ void validate_outputs(
if (!metadata.is_same_shape(grad)) {
if (metadata.is_expandable_to_shape(grad)) {
grad = at::sum_to(std::move(grad), metadata.shape());
grad = metadata.reduce_grad(grad);
} else {
const auto message = metadata.incompatible_shape_error_message(i, grad);
AT_ERROR(format_error(message.str()));

View File

@ -190,7 +190,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
bool is_tensor_subclass) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
input_metadata_.emplace_back(options, shape, is_tensor_subclass);
auto meta_shape = MetadataShape{c10::in_place_type<at::DimVector>, shape};
input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass);
return input_nr;
}

View File

@ -148,7 +148,8 @@ struct TORCH_API AccumulateGrad : public Node {
new_grad.sizes(),
new_grad.options()));
} else {
if (new_grad.is_sparse() || new_grad.is_sparse_csr()) {
if (new_grad.is_sparse() || new_grad.is_sparse_csr() ||
new_grad.is_nested()) {
update_grad(new_grad.clone());
} else {
if (new_grad.is_mkldnn()) {

View File

@ -1,11 +1,15 @@
#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/NestedTensorImpl.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/Stream.h>
#include <c10/core/TensorImpl.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/util/Exception.h>
#include <c10/util/variant.h>
#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
@ -14,10 +18,13 @@
#endif
#include <cstdint>
#include <utility>
namespace torch {
namespace autograd {
using MetadataShape = c10::variant<at::DimVector, at::Tensor>;
/**
* Records TensorOptions, shape of the tensor, whether or not the Python
* dispatch key is set (tensor subclass), and, where applicable, the stream the
@ -31,10 +38,10 @@ struct InputMetadata {
InputMetadata(
const at::TensorOptions options,
at::IntArrayRef shape,
MetadataShape input_shape,
bool is_tensor_subclass)
: options_{options},
shape_{shape},
shape_{input_shape},
is_tensor_subclass_{is_tensor_subclass} {
auto device_ = options.device();
stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
@ -43,17 +50,13 @@ struct InputMetadata {
InputMetadata(const at::Tensor& t)
: InputMetadata(
t.options(),
t.sizes(),
compute_variant_shape(t),
t.unsafeGetTensorImpl()->is_python_dispatch()) {}
const at::TensorOptions options() const {
return options_;
}
at::IntArrayRef shape() const {
return shape_;
}
caffe2::TypeMeta dtype() const {
return options_.dtype();
}
@ -75,37 +78,83 @@ struct InputMetadata {
}
// Materializes a zero tensor matching the recorded shape. Nested tensors
// have no at::zeros support yet, so this is rejected for them.
at::Tensor zeros_like() const {
  TORCH_CHECK(
      !is_nested_tensor(),
      "Zeros is not currently supported for nested tensors.")
  return at::zeros(shape_as_dim_vector(), options_);
}
// True if `grad` has exactly the recorded shape. Requires grad and the
// recorded metadata to agree on nested-ness; nested shapes are compared
// via their per-component size tensors.
bool is_same_shape(const at::Tensor& grad) const {
  TORCH_CHECK(
      grad.is_nested() == is_nested_tensor(),
      "Both grad and InputMetadata need to be either nested or non nested tensors.")
  if (grad.is_nested()) {
    return at::native::get_nested_size_tensor(grad).is_same_size(
        shape_as_tensor());
  }
  return grad.sizes().equals(shape_as_dim_vector());
}
// True if `grad` can be reduced (summed) down to the recorded shape.
bool is_expandable_to_shape(const at::Tensor& grad) const {
  // Currently NestedTensors are not expandable. If this support is added then
  // updates to reduce_grad will be needed
  TORCH_CHECK(
      grad.is_nested() == is_nested_tensor(),
      "Both grad and InputMetadata need to be either nested or non nested tensors.")
  return grad.is_nested()
      ? false
      : at::is_expandable_to(shape_as_dim_vector(), grad.sizes());
}
// Sums `grad` down to the recorded shape. Only reachable when
// is_expandable_to_shape returned true, which it never does for nested
// tensors — assert that invariant rather than handling the nested case.
at::Tensor reduce_grad(at::Tensor& grad) const {
  TORCH_INTERNAL_ASSERT(!is_nested_tensor() && !grad.is_nested())
  return at::sum_to(std::move(grad), shape_as_dim_vector());
}
// Builds the "invalid gradient at index N" diagnostic, printing the nested
// size tensor for nested tensors and plain sizes otherwise.
std::stringstream incompatible_shape_error_message(
    const size_t index,
    const at::Tensor& grad) const {
  std::stringstream ss;
  ss << "invalid gradient at index " << index << " - got ";
  if (grad.is_nested()) {
    ss << at::native::get_nested_size_tensor(grad);
  } else {
    ss << grad.sizes();
  }
  ss << " but expected shape compatible with ";
  if (is_nested_tensor()) {
    ss << shape_as_tensor();
  } else {
    ss << shape_as_dim_vector();
  }
  return ss;
}
private:
// The shape variant stores an at::Tensor (the nested size tensor) only when
// the recorded input was a nested tensor.
bool is_nested_tensor() const {
  return c10::holds_alternative<at::Tensor>(shape_);
}
// Chooses the shape representation for `input`: dense tensors store a plain
// DimVector, nested tensors store their 2D per-component size tensor.
MetadataShape compute_variant_shape(const at::Tensor& input) {
  if (!input.is_nested()) {
    return MetadataShape{c10::in_place_type<at::DimVector>, input.sizes()};
  }
  auto nested_size = at::native::get_nested_size_tensor(input);
  return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
}
// Returns the dense shape by value; c10::get throws if shape_ actually
// holds a nested size tensor, so callers must check is_nested_tensor() first.
at::DimVector shape_as_dim_vector() const {
return c10::get<at::DimVector>(shape_);
}
// Returns the nested size tensor; c10::get throws if shape_ holds a plain
// DimVector, so callers must check is_nested_tensor() first.
at::Tensor shape_as_tensor() const {
return c10::get<at::Tensor>(shape_);
}
const at::TensorOptions options_;
at::DimVector shape_;
MetadataShape shape_;
c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
bool is_tensor_subclass_ = false;
};
} // namespace autograd
} // namespace torch

View File

@ -17,7 +17,14 @@ inline bool obeys_layout_contract(
TORCH_INTERNAL_ASSERT(!variable.is_sparse());
TORCH_INTERNAL_ASSERT(!grad.is_sparse_csr());
TORCH_INTERNAL_ASSERT(!variable.is_sparse_csr());
if (variable.is_non_overlapping_and_dense()) {
if (variable.is_nested()) {
// TODO: Nested Tensor does not have an implementation of detach. The
// current implementation of nested tensor likely does obey the gradient
// contract and should return true, but this would likely change in the
// future
return false;
} else if (variable.is_non_overlapping_and_dense()) {
// Only look at stride for dimensions that are not of size 1.
const auto& grad_sizes = grad.sizes();
const auto& grad_strides = grad.strides();