Minimal NestedTensor (#72881)

Summary:
This PR adds a minimal version of a NestedTensor. It introduces the general harness that future development can be built around.
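As a quick orientation, here is a minimal usage sketch of the surface this PR introduces (constructor plus unbind), assembled from the docs and tests added below; autograd is not supported yet, so the tests wrap everything in torch.inference_mode():

>>> import torch
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> with torch.inference_mode():          # autograd is not supported yet
...     nt = torch.nested_tensor([a, b])  # pack two 1-d tensors of different lengths
...     a1, b1 = nt.unbind()              # views into the packed buffer, one per constituent
...     print(a1)
tensor([0, 1, 2])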

Pull Request resolved: https://github.com/pytorch/pytorch/pull/72881

Reviewed By: albanD

Differential Revision: D34259177

Pulled By: cpuhrsch

fbshipit-source-id: 0245c36f603424e20f3b09651043c207f526d760
(cherry picked from commit 10764e8d427f29b364567e4cbc86ed73c3933158)
Author: Christian Puhrsch
Date: 2022-03-02 07:29:19 -08:00
Committed by: PyTorch MergeBot
Parent: bbf4bc9f8e
Commit: 484c0de670
14 changed files with 593 additions and 1 deletion


@@ -234,6 +234,11 @@ filegroup(
srcs = glob(["aten/src/ATen/native/sparse/*.cpp"]),
)
filegroup(
name = "aten_native_nested_cpp",
srcs = glob(["aten/src/ATen/native/nested/*.cpp"]),
)
filegroup(
name = "aten_native_quantized_cpp",
srcs = glob(
@@ -406,6 +411,7 @@ cc_library(
":aten_native_mkldnn_cpp",
":aten_native_quantized_cpp",
":aten_native_sparse_cpp",
":aten_native_nested_cpp",
":aten_native_xnnpack",
":aten_src_ATen_config",
] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),


@@ -104,6 +104,7 @@ file(GLOB native_sparse_cpp "native/sparse/*.cpp")
file(GLOB native_quantized_cpp
"native/quantized/*.cpp"
"native/quantized/cpu/*.cpp")
file(GLOB native_nested_cpp "native/nested/*.cpp")
file(GLOB native_h "native/*.h")
file(GLOB native_ao_sparse_h
@@ -155,7 +156,7 @@ if(BUILD_LITE_INTERPRETER)
else()
set(
all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
${native_ao_sparse_cpp} ${native_sparse_cpp}
${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
${native_utils_cpp} ${native_xnnpack} ${generated_sources} ${core_generated_sources}
${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${ATen_NNAPI_SRCS} ${cpu_kernel_cpp}


@@ -0,0 +1,35 @@
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
namespace at {
namespace native {
NestedTensorImpl::NestedTensorImpl(
at::Tensor buffer,
at::Tensor nested_size_tensor)
: TensorImpl(
// TODO: This doesn't properly report is_cpu/is_cuda for NestedTensor.
// The intended resolution is that once #72827 lands we will be able to
// allocate separate dispatch keys for CPUNestedTensor (and any other
// hypothetical device backends for NestedTensor); then we will be
// able to derive this directly. If you need this to work before then,
// make sure you add CPU to this dispatch key set
c10::DispatchKeySet({DispatchKey::NestedTensor}),
buffer.dtype(),
buffer.device()),
buffer_(std::move(buffer)),
nested_size_tensor_(std::move(nested_size_tensor)) {
TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous());
int64_t size_dim = nested_size_tensor_.dim();
TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
remove_autograd_key();
key_set_ =
key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
}
} // namespace native
} // namespace at


@@ -0,0 +1,65 @@
#pragma once
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>
#include <c10/util/irange.h>
#include <ATen/MemoryOverlap.h>
#include <c10/core/MemoryFormat.h>
#include <c10/util/Metaprogramming.h>
namespace at {
namespace native {
struct NestedTensorImpl : public c10::TensorImpl {
explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor);
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
int64_t dim() const override {
TORCH_CHECK(
false, "dim is disabled. These methods are not virtual in fbcode.");
}
#endif
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
int64_t numel() const override {
TORCH_CHECK(
false, "numel is disabled. These methods are not virtual in fbcode.");
}
#endif
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
bool is_contiguous(at::MemoryFormat memory_format) const override {
TORCH_CHECK(
false,
"is_contiguous is disabled. These methods are not virtual in fbcode.");
}
#endif
const Tensor& get_nested_size_tensor() {
return nested_size_tensor_;
}
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
IntArrayRef sizes() const override {
TORCH_CHECK(
false,
"Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
return IntArrayRef();
}
#endif
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
IntArrayRef strides() const override {
TORCH_CHECK(
false,
"Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
return IntArrayRef();
}
#endif
const at::Tensor& get_buffer() const {
return buffer_;
}
private:
at::Tensor buffer_;
const at::Tensor nested_size_tensor_;
};
} // namespace native
} // namespace at


@@ -5473,6 +5473,7 @@
variants: function, method
dispatch:
CompositeExplicitAutograd: unbind
NestedTensor: NestedTensor_unbind
- func: unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]
variants: function, method
@@ -11248,3 +11249,6 @@
- func: unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]
variants: function
python_module: nn
- func: _nested_tensor(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
variants: function


@@ -0,0 +1,119 @@
#include <ATen/ATen.h>
#include <ATen/NamedTensorUtils.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/NestedTensorImpl.h>
#include <c10/core/DispatchKey.h>
namespace at {
namespace native {
at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_size_tensor) {
TORCH_CHECK(buffer.is_contiguous(), "Given buffer must be contiguous.");
return at::detail::make_tensor<NestedTensorImpl>(
std::move(buffer), std::move(nested_size_tensor));
}
bool is_nested_tensor_impl(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->key_set().has(
c10::DispatchKey::NestedTensor);
}
inline at::native::NestedTensorImpl* get_nested_tensor_impl(
const at::Tensor& tensor) {
TORCH_CHECK(
is_nested_tensor_impl(tensor),
"get_nested_tensor_impl requires a NestedTensor.");
return static_cast<at::native::NestedTensorImpl*>(
tensor.unsafeGetTensorImpl());
}
inline const at::Tensor& get_buffer(const at::Tensor& tensor) {
return get_nested_tensor_impl(tensor)->get_buffer();
}
inline const at::Tensor& get_nested_size_tensor(const at::Tensor& tensor) {
return get_nested_tensor_impl(tensor)->get_nested_size_tensor();
}
// CPU only!
// TODO: The algorithm here can be optimized, right now it involves a lot of
// small tensor manipulations
std::vector<at::Tensor> NestedTensor_unbind(
const at::Tensor& self,
int64_t dim) {
TORCH_CHECK(
dim == 0,
"NestedTensor can only be unbound along dimension 0 ",
"got dimension ",
dim,
" instead.");
auto esizes = get_nested_size_tensor(self);
std::vector<at::Tensor> result_tensors;
if (esizes.dim() == 0) {
return result_tensors;
}
auto esizes_chunks = esizes.unbind(0);
std::vector<int64_t> splits;
for (const auto i : c10::irange(esizes_chunks.size())) {
splits.push_back(esizes_chunks[i].prod().item<int64_t>());
}
auto buffer_chunks = at::split_with_sizes(get_buffer(self), splits);
for (const auto i : c10::irange(buffer_chunks.size())) {
const auto& esize_chunk = esizes_chunks[i];
result_tensors.push_back(buffer_chunks[i].view(IntArrayRef(
esize_chunk.data_ptr<int64_t>(),
esize_chunk.data_ptr<int64_t>() + esize_chunk.numel())));
}
return result_tensors;
}
/*
* The result of this function cannot be used by itself. The result needs to
* be wrapped in torch.nested.NestedTensor.
*/
Tensor _nested_tensor(
TensorList list,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
TensorOptions options_ =
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
pin_memory);
if (list.size() == 0) {
return wrap_buffer(ones({0}), ones({}));
}
std::vector<Tensor> sizes;
std::vector<Tensor> flat_tensors;
for (const auto i : c10::irange(list.size())) {
if (i > 0) {
int64_t dim_i = list[i].dim();
int64_t dim_prev = list[i - 1].dim();
TORCH_CHECK(
dim_i == dim_prev,
"All Tensors given to nested_tensor must have the same dimension. ",
"Found dimension ",
dim_i,
" for Tensor at index ",
i,
" and dimension ",
dim_prev,
" for Tensor at index ",
i - 1,
".");
}
// TODO: Remove call to contiguous once we support strides.
flat_tensors.push_back(list[i].reshape(-1).contiguous());
sizes.push_back(tensor(c10::IntArrayRef(list[i].sizes())));
}
TensorOptions options = flat_tensors[0].options().merge_in(options_);
return wrap_buffer(
at::native::cat(flat_tensors).to(options), at::native::stack(sizes));
}
} // namespace native
} // namespace at
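To make the packed layout concrete: the kernels above store a NestedTensor as one flat, contiguous buffer plus a 2-d nested_size_tensor holding each constituent's shape. The following plain-PyTorch sketch is illustrative only, not the C++ code path itself, and the pack/unpack helper names are made up; it mirrors _nested_tensor and NestedTensor_unbind using stock tensor ops:

import torch

def pack(tensors):
    # Mirror _nested_tensor: flatten every constituent into one contiguous
    # buffer and stack the per-constituent sizes into a 2-d size tensor.
    buffer = torch.cat([t.reshape(-1).contiguous() for t in tensors])
    nested_sizes = torch.stack([torch.tensor(t.size()) for t in tensors])
    return buffer, nested_sizes

def unpack(buffer, nested_sizes):
    # Mirror NestedTensor_unbind: split the buffer by each constituent's
    # element count, then view every chunk back to its recorded shape.
    rows = nested_sizes.unbind(0)
    splits = [int(row.prod().item()) for row in rows]
    chunks = torch.split(buffer, splits)
    return [c.view(row.tolist()) for c, row in zip(chunks, rows)]

ts = [torch.rand(3, 2), torch.rand(2, 3)]
buf, sizes = pack(ts)
assert all(torch.equal(x, y) for x, y in zip(ts, unpack(buf, sizes)))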


@@ -84,6 +84,7 @@ Features described in this documentation are classified by release status:
quantization
rpc
torch.random <random>
nested
sparse
storage
torch.testing <testing>

docs/source/nested.rst (new file, 62 lines)

@@ -0,0 +1,62 @@
torch.nested
============
.. automodule:: torch.nested
Introduction
++++++++++++
.. warning::
The PyTorch API of nested tensors is at the prototype stage and will change in the near future.
.. warning::
torch.NestedTensor currently does not support autograd. It needs to be used in the context
of torch.inference_mode().
NestedTensor allows the user to pack a list of Tensors into a single, efficient data structure.
The only constraint on the input Tensors is that they must all have the same number of dimensions.
This enables more efficient metadata representations and operator coverage.
Construction is straightforward and involves passing a list of Tensors to the constructor.
>>> a, b = torch.arange(3), torch.arange(5) + 3
>>> a
tensor([0, 1, 2])
>>> b
tensor([3, 4, 5, 6, 7])
>>> nt = torch.nested_tensor([a, b])
>>> nt
nested_tensor([
tensor([0, 1, 2]),
tensor([3, 4, 5, 6, 7])
])
Data type and device can be chosen via the usual keyword arguments.
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
tensor([0., 1., 2.], device='cuda:0'),
tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
Operator coverage
+++++++++++++++++
We are currently extending operator coverage wholesale, guided by specific ML use cases.
Operator coverage is therefore still very limited: only unbind is supported.
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
>>> nt
nested_tensor([
tensor([0., 1., 2.], device='cuda:0'),
tensor([3., 4., 5., 6., 7.], device='cuda:0')
])
>>> nt.unbind()
[tensor([0., 1., 2.], device='cuda:0'), tensor([3., 4., 5., 6., 7.], device='cuda:0')]

test/test_nestedtensor.py (new file, 184 lines)

@@ -0,0 +1,184 @@
# Owner(s): ["module: nestedtensor"]
import torch
import unittest
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
from torch import nested_tensor
# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.
def _iter_constructors():
# yield as_nested_tensor
yield nested_tensor
class TestNestedTensor(TestCase):
@torch.inference_mode()
def _test_unbind_case(self, a, b):
nt = nested_tensor([a, b])
a1, b1 = nt.unbind()
self.assertTrue(a is not a1)
self.assertTrue(b is not b1)
nt = nested_tensor([a, b], dtype=a.dtype)
a1, b1 = nt.unbind(0)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
a = torch.randn((2, 3)).add_(1)
nt = nested_tensor([a])
self.assertEqual(a, nt.unbind(0)[0])
@torch.inference_mode()
def test_unbind_0(self):
self._test_unbind_case(
torch.tensor([1, 2]), torch.tensor([7, 8]),
)
@torch.inference_mode()
def test_unbind_1(self):
self._test_unbind_case(
torch.tensor([1]), torch.tensor([7]),
)
# @torch.inference_mode()
# def test_unbind_2(self):
# self._test_unbind_case(
# torch.tensor(1), torch.tensor(7),
# )
@torch.inference_mode()
def test_unbind_3(self):
self._test_unbind_case(
torch.tensor([1.0]), torch.tensor([]),
)
@torch.inference_mode()
def test_unbind_4(self):
self._test_unbind_case(
torch.tensor([]), torch.tensor([]),
)
@torch.inference_mode()
def test_unbind_dim(self):
def _test_fn(unbind_fn):
a = torch.rand(3, 2)
b = torch.rand(2, 3)
nt = nested_tensor([a, b])
self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))
# Both of these tests are necessary, because we're using
# torch_function.
_test_fn(lambda x, dim: x.unbind(dim))
# TODO: Re-enable this once using torch_dispatch
# _test_fn(lambda x, dim: torch.unbind(x, dim))
@torch.inference_mode()
def test_nested_tensor(self):
self.assertRaises(TypeError, lambda: nested_tensor([3.0]))
self.assertRaises(TypeError, lambda: nested_tensor(torch.tensor([3.0])))
self.assertRaises(TypeError, lambda: nested_tensor(4.0))
@torch.inference_mode()
def test_nested_tensor_matching_dim(self):
self.assertRaisesRegex(
RuntimeError,
"Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
lambda: nested_tensor([torch.tensor(1.0), torch.tensor([])]),
)
self.assertRaisesRegex(
RuntimeError,
"Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
lambda: nested_tensor(
[torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
),
)
@torch.inference_mode()
def test_default_nested_tensor(self):
self.assertRaises(TypeError, lambda: nested_tensor())
default_nested_tensor = nested_tensor([])
default_tensor = torch.tensor([])
# self.assertEqual(default_nested_tensor.nested_dim(), 1)
# self.assertEqual(default_nested_tensor.nested_size(), ())
self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
self.assertEqual(default_nested_tensor.device, default_tensor.device)
self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
self.assertEqual(
default_nested_tensor.requires_grad, default_tensor.requires_grad
)
self.assertIsNone(default_tensor.grad)
# TODO: Re-enable once we have a performance driven
# use case and implementation.
# self.assertEqual(default_nested_tensor.is_pinned(),
# default_tensor.is_pinned())
@torch.inference_mode()
def test_dim(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertEqual(a1.dim(), 1)
a1 = constructor([torch.tensor(3.0)])
self.assertEqual(a1.dim(), 1)
a1 = constructor([torch.tensor([1, 2, 3, 4])])
self.assertEqual(a1.dim(), 2)
@unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
@torch.inference_mode()
def test_numel(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError, "numel is disabled", lambda: a1.numel(),
)
@unittest.skipIf(IS_FBCODE, "size is not virtual in fbcode.")
@torch.inference_mode()
def test_size(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError,
"NestedTensorImpl doesn't support sizes",
lambda: a1.size(),
)
@unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
@torch.inference_mode()
def test_stride(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError,
"NestedTensorImpl doesn't support strides",
lambda: a1.stride(),
)
@unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
@torch.inference_mode()
def test_is_contiguous(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError, "is_contiguous is disabled", lambda: a1.is_contiguous()
)
@torch.inference_mode()
def test_repr_string(self):
a = nested_tensor([])
expected = "nested_tensor([" "\n\n])"
self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected)
a = nested_tensor([torch.tensor(1.0)])
expected = "nested_tensor([" "\n tensor(1.)" "\n])"
self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected)
a = nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
expected = (
"nested_tensor([" "\n tensor([[1, 2]])" "," "\n tensor([[4, 5]])" "\n])"
)
self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected)


@@ -974,6 +974,7 @@ aten_cpu_source_non_codegen_list = [
"aten/src/ATen/MemoryOverlap.cpp",
"aten/src/ATen/MapAllocator.cpp",
"aten/src/ATen/NamedTensorUtils.cpp",
"aten/src/ATen/NestedTensorImpl.cpp",
"aten/src/ATen/ParallelCommon.cpp",
"aten/src/ATen/ParallelNative.cpp",
"aten/src/ATen/ParallelNativeTBB.cpp",
@@ -1316,6 +1317,7 @@ aten_native_source_non_codegen_list = [
"aten/src/ATen/native/WeightNorm.cpp",
"aten/src/ATen/native/group_norm.cpp",
"aten/src/ATen/native/layer_norm.cpp",
"aten/src/ATen/native/nested/NestedTensorMath.cpp",
"aten/src/ATen/native/sparse/ParamUtils.cpp",
"aten/src/ATen/native/sparse/SoftMax.cpp",
"aten/src/ATen/native/sparse/SparseBlas.cpp",


@@ -139,6 +139,7 @@ dispatch_keys = [
DispatchKey.QuantizedCUDA,
DispatchKey.CompositeImplicitAutograd,
DispatchKey.CompositeExplicitAutograd,
DispatchKey.NestedTensor,
# Meta is a magic key: it is automatically generated for structured
# kernels
DispatchKey.Meta,


@@ -819,6 +819,9 @@ from torch import __config__ as __config__
from torch import __future__ as __future__
from torch import profiler as profiler
from torch.nested._nestedtensor import NestedTensor
from torch.nested._nestedtensor import nested_tensor
_C._init_names(list(torch._storage_classes))
# attach docstrings to torch and tensor functions

torch/nested/__init__.py (new, empty file)


@@ -0,0 +1,109 @@
import torch
from functools import wraps
@wraps(torch._nested_tensor)
def nested_tensor(*args, **kwargs):
return NestedTensor(torch._nested_tensor(*args, **kwargs))
# TODO: This entire class is not really necessary now that NestedTensor lives
# in tree; before it lived out of tree and there was no way to conveniently
# override the string printing behavior. Now that we are in tree, we can
# directly override _tensor_str to capture this behavior, and the wrapper subclass
# is not necessary. See also https://github.com/pytorch/pytorch/issues/73506
class NestedTensor:
# data is a torch.Tensor backed by a NestedTensorImpl
def __init__(self, impl):
self._impl = impl
@property
def dtype(self):
"""
The data type of ```self``` NestedTensor.
"""
return self._impl.dtype
@property
def layout(self):
"""
The layout of ```self``` NestedTensor.
"""
return self._impl.layout
@property
def device(self):
"""
The device of ```self``` NestedTensor.
"""
return self._impl.device
@property
def requires_grad(self):
"""
Is ```True``` if gradients need to be computed for this Tensor.
"""
return self._impl.requires_grad
def stride(self):
"""
NestedTensor currently does not have a stride. This will throw.
"""
return self._impl.stride()
def size(self):
"""
NestedTensor currently does not have a size. This will throw.
"""
return self._impl.size()
def dim(self):
"""
The dimension of ```self``` NestedTensor.
"""
tensors = self.unbind()
if len(tensors) == 0:
return 1
return int(tensors[0].dim() + 1)
def numel(self):
"""
The number of elements of ```self``` NestedTensor.
"""
return self._impl.numel()
def is_contiguous(self):
"""
Returns true if ```self``` NestedTensor is contiguous.
"""
return self._impl.is_contiguous()
def __str__(self):
def _str(x, indent=0, tab=" "):
s = indent * tab + "[\n"
strs = list(map(str, x.unbind()))
strs = list(
map(
lambda xi: "\n".join(
map(lambda xij: (indent + 1) * tab + xij, xi.split("\n"))
),
strs,
)
)
s += ",\n".join(strs)
s += "\n" + indent * tab + "]"
return s
return "nested_tensor(" + _str(self) + ")"
def __repr__(self):
return self.__str__()
def unbind(self, dim=None):
if dim is None:
unbound = torch.ops.aten.unbind.int(self._impl, 0)
if len(unbound) == 0:
return ()
return unbound
return torch.ops.aten.unbind.int(self._impl, dim)