mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Minimal NestedTensor (#72881)
Summary: This PR adds a minimal version of a NestedTensor. It introduces the general harness future development can be built around. Pull Request resolved: https://github.com/pytorch/pytorch/pull/72881 Reviewed By: albanD Differential Revision: D34259177 Pulled By: cpuhrsch fbshipit-source-id: 0245c36f603424e20f3b09651043c207f526d760 (cherry picked from commit 10764e8d427f29b364567e4cbc86ed73c3933158)
This commit is contained in:
committed by
PyTorch MergeBot
parent
bbf4bc9f8e
commit
484c0de670
@ -234,6 +234,11 @@ filegroup(
|
||||
srcs = glob(["aten/src/ATen/native/sparse/*.cpp"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "aten_native_nested_cpp",
|
||||
srcs = glob(["aten/src/ATen/native/nested/*.cpp"]),
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "aten_native_quantized_cpp",
|
||||
srcs = glob(
|
||||
@ -406,6 +411,7 @@ cc_library(
|
||||
":aten_native_mkldnn_cpp",
|
||||
":aten_native_quantized_cpp",
|
||||
":aten_native_sparse_cpp",
|
||||
":aten_native_nested_cpp",
|
||||
":aten_native_xnnpack",
|
||||
":aten_src_ATen_config",
|
||||
] + generated_cpu_cpp + aten_ufunc_generated_cpu_sources("aten/src/ATen/{}"),
|
||||
|
@ -104,6 +104,7 @@ file(GLOB native_sparse_cpp "native/sparse/*.cpp")
|
||||
file(GLOB native_quantized_cpp
|
||||
"native/quantized/*.cpp"
|
||||
"native/quantized/cpu/*.cpp")
|
||||
file(GLOB native_nested_cpp "native/nested/*.cpp")
|
||||
|
||||
file(GLOB native_h "native/*.h")
|
||||
file(GLOB native_ao_sparse_h
|
||||
@ -155,7 +156,7 @@ if(BUILD_LITE_INTERPRETER)
|
||||
else()
|
||||
set(
|
||||
all_cpu_cpp ${base_cpp} ${ATen_CORE_SRCS} ${native_cpp}
|
||||
${native_ao_sparse_cpp} ${native_sparse_cpp}
|
||||
${native_ao_sparse_cpp} ${native_sparse_cpp} ${native_nested_cpp}
|
||||
${native_quantized_cpp} ${native_mkl_cpp} ${native_mkldnn_cpp}
|
||||
${native_utils_cpp} ${native_xnnpack} ${generated_sources} ${core_generated_sources}
|
||||
${ATen_CPU_SRCS} ${ATen_QUANTIZED_SRCS} ${ATen_NNAPI_SRCS} ${cpu_kernel_cpp}
|
||||
|
35
aten/src/ATen/NestedTensorImpl.cpp
Normal file
35
aten/src/ATen/NestedTensorImpl.cpp
Normal file
@ -0,0 +1,35 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
#include <c10/core/DispatchKey.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
NestedTensorImpl::NestedTensorImpl(
|
||||
at::Tensor buffer,
|
||||
at::Tensor nested_size_tensor)
|
||||
: TensorImpl(
|
||||
// TODO: This doesn't properly report is_cpu/is_cuda for NestedTensor.
|
||||
// The intended resolution is that once #72827 lands we will be able to
|
||||
// allocate separate dispatch keys for CPUNestedTensor (and any other
|
||||
// hypothetical device backends for NestedTensor); then we will be
|
||||
// able to derive this directly. If you need this to work before then,
|
||||
// make sure you add CPU to this dispatch key set
|
||||
c10::DispatchKeySet({DispatchKey::NestedTensor}),
|
||||
buffer.dtype(),
|
||||
buffer.device()),
|
||||
buffer_(std::move(buffer)),
|
||||
nested_size_tensor_(std::move(nested_size_tensor)) {
|
||||
TORCH_INTERNAL_ASSERT(nested_size_tensor_.is_contiguous());
|
||||
int64_t size_dim = nested_size_tensor_.dim();
|
||||
TORCH_INTERNAL_ASSERT(size_dim == 0 || size_dim == 2);
|
||||
remove_autograd_key();
|
||||
key_set_ =
|
||||
key_set_ - c10::DispatchKeySet({c10::DispatchKey::ADInplaceOrView});
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
65
aten/src/ATen/NestedTensorImpl.h
Normal file
65
aten/src/ATen/NestedTensorImpl.h
Normal file
@ -0,0 +1,65 @@
|
||||
#pragma once
|
||||
#include <ATen/Tensor.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/util/irange.h>
|
||||
#include <ATen/MemoryOverlap.h>
|
||||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
struct NestedTensorImpl : public c10::TensorImpl {
|
||||
explicit NestedTensorImpl(at::Tensor buffer, at::Tensor nested_size_tensor);
|
||||
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
int64_t dim() const override {
|
||||
TORCH_CHECK(
|
||||
false, "dim is disabled. These methods are not virtual in fbcode.");
|
||||
}
|
||||
#endif
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
int64_t numel() const override {
|
||||
TORCH_CHECK(
|
||||
false, "numel is disabled. These methods are not virtual in fbcode.");
|
||||
}
|
||||
#endif
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
bool is_contiguous(at::MemoryFormat memory_format) const override {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"is_contiguous is disabled. These methods are not virtual in fbcode.");
|
||||
}
|
||||
#endif
|
||||
const Tensor& get_nested_size_tensor() {
|
||||
return nested_size_tensor_;
|
||||
}
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
IntArrayRef sizes() const override {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Internal error: NestedTensorImpl doesn't support sizes. Please file an issue on https://github.com/pytorch/nestedtensor");
|
||||
return IntArrayRef();
|
||||
}
|
||||
#endif
|
||||
#ifndef C10_DISABLE_TENSORIMPL_EXTENSIBILITY
|
||||
IntArrayRef strides() const override {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"Internal error: NestedTensorImpl doesn't support strides. Please file an issue on https://github.com/pytorch/nestedtensor");
|
||||
return IntArrayRef();
|
||||
}
|
||||
#endif
|
||||
|
||||
const at::Tensor& get_buffer() const {
|
||||
return buffer_;
|
||||
}
|
||||
|
||||
private:
|
||||
at::Tensor buffer_;
|
||||
const at::Tensor nested_size_tensor_;
|
||||
};
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -5473,6 +5473,7 @@
|
||||
variants: function, method
|
||||
dispatch:
|
||||
CompositeExplicitAutograd: unbind
|
||||
NestedTensor: NestedTensor_unbind
|
||||
|
||||
- func: unbind.Dimname(Tensor(a -> *) self, Dimname dim) -> Tensor(a)[]
|
||||
variants: function, method
|
||||
@ -11248,3 +11249,6 @@
|
||||
- func: unflatten_dense_tensors(Tensor flat, Tensor[] tensors) -> Tensor[]
|
||||
variants: function
|
||||
python_module: nn
|
||||
|
||||
- func: _nested_tensor(Tensor[] list, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
variants: function
|
||||
|
119
aten/src/ATen/native/nested/NestedTensorMath.cpp
Normal file
119
aten/src/ATen/native/nested/NestedTensorMath.cpp
Normal file
@ -0,0 +1,119 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/NamedTensorUtils.h>
|
||||
#include <ATen/WrapDimUtils.h>
|
||||
#include <ATen/core/op_registration/op_registration.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
#include <c10/core/DispatchKey.h>
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
|
||||
|
||||
at::Tensor wrap_buffer(at::Tensor buffer, at::Tensor nested_size_tensor) {
|
||||
TORCH_CHECK(buffer.is_contiguous(), "Given buffer must be contiguous.");
|
||||
return at::detail::make_tensor<NestedTensorImpl>(
|
||||
std::move(buffer), std::move(nested_size_tensor));
|
||||
}
|
||||
|
||||
bool is_nested_tensor_impl(const at::Tensor& tensor) {
|
||||
return tensor.unsafeGetTensorImpl()->key_set().has(
|
||||
c10::DispatchKey::NestedTensor);
|
||||
}
|
||||
|
||||
inline at::native::NestedTensorImpl* get_nested_tensor_impl(
|
||||
const at::Tensor& tensor) {
|
||||
TORCH_CHECK(
|
||||
is_nested_tensor_impl(tensor),
|
||||
"get_nested_tensor_impl requires a NestedTensor.");
|
||||
return static_cast<at::native::NestedTensorImpl*>(
|
||||
tensor.unsafeGetTensorImpl());
|
||||
}
|
||||
|
||||
inline const at::Tensor& get_buffer(const at::Tensor& tensor) {
|
||||
return get_nested_tensor_impl(tensor)->get_buffer();
|
||||
}
|
||||
|
||||
inline const at::Tensor& get_nested_size_tensor(const at::Tensor& tensor) {
|
||||
return get_nested_tensor_impl(tensor)->get_nested_size_tensor();
|
||||
}
|
||||
|
||||
// CPU only!
|
||||
// TODO: The algorithm here can be optimized, right now it involves a lot of
|
||||
// small tensor manipulations
|
||||
std::vector<at::Tensor> NestedTensor_unbind(
|
||||
const at::Tensor& self,
|
||||
int64_t dim) {
|
||||
TORCH_CHECK(
|
||||
dim == 0,
|
||||
"NestedTensor can only be unbound along dimension 0 ",
|
||||
"got dimension ",
|
||||
dim,
|
||||
" instead.");
|
||||
auto esizes = get_nested_size_tensor(self);
|
||||
std::vector<at::Tensor> result_tensors;
|
||||
if (esizes.dim() == 0) {
|
||||
return result_tensors;
|
||||
}
|
||||
auto esizes_chunks = esizes.unbind(0);
|
||||
std::vector<int64_t> splits;
|
||||
for (const auto i : c10::irange(esizes_chunks.size())) {
|
||||
splits.push_back(esizes_chunks[i].prod().item<int64_t>());
|
||||
}
|
||||
auto buffer_chunks = at::split_with_sizes(get_buffer(self), splits);
|
||||
for (const auto i : c10::irange(buffer_chunks.size())) {
|
||||
const auto& esize_chunk = esizes_chunks[i];
|
||||
result_tensors.push_back(buffer_chunks[i].view(IntArrayRef(
|
||||
esize_chunk.data_ptr<int64_t>(),
|
||||
esize_chunk.data_ptr<int64_t>() + esize_chunk.numel())));
|
||||
}
|
||||
return result_tensors;
|
||||
}
|
||||
|
||||
/*
|
||||
* This result of this function cannot be used by itself. The result needs to
|
||||
* be wrapped in torch.nested.NestedTensor.
|
||||
*/
|
||||
Tensor _nested_tensor(
|
||||
TensorList list,
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
TensorOptions options_ =
|
||||
TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(
|
||||
pin_memory);
|
||||
|
||||
if (list.size() == 0) {
|
||||
return wrap_buffer(ones({0}), ones({}));
|
||||
}
|
||||
std::vector<Tensor> sizes;
|
||||
std::vector<Tensor> flat_tensors;
|
||||
for (const auto i : c10::irange(list.size())) {
|
||||
if (i > 0) {
|
||||
int64_t dim_i = list[i].dim();
|
||||
int64_t dim_prev = list[i - 1].dim();
|
||||
TORCH_CHECK(
|
||||
dim_i == dim_prev,
|
||||
"All Tensors given to nested_tensor must have the same dimension. ",
|
||||
"Found dimension ",
|
||||
dim_i,
|
||||
" for Tensor at index ",
|
||||
i,
|
||||
" and dimension ",
|
||||
dim_prev,
|
||||
" for Tensor at index ",
|
||||
i - 1,
|
||||
".");
|
||||
}
|
||||
// TODO: Remove call to contiguous once we support strides.
|
||||
flat_tensors.push_back(list[i].reshape(-1).contiguous());
|
||||
sizes.push_back(tensor(c10::IntArrayRef(list[i].sizes())));
|
||||
}
|
||||
|
||||
TensorOptions options = flat_tensors[0].options().merge_in(options_);
|
||||
|
||||
return wrap_buffer(
|
||||
at::native::cat(flat_tensors).to(options), at::native::stack(sizes));
|
||||
}
|
||||
} // namespace native
|
||||
} // namespace at
|
@ -84,6 +84,7 @@ Features described in this documentation are classified by release status:
|
||||
quantization
|
||||
rpc
|
||||
torch.random <random>
|
||||
nested
|
||||
sparse
|
||||
storage
|
||||
torch.testing <testing>
|
||||
|
62
docs/source/nested.rst
Normal file
62
docs/source/nested.rst
Normal file
@ -0,0 +1,62 @@
|
||||
torch.nested
|
||||
============
|
||||
|
||||
.. automodule:: torch.nested
|
||||
|
||||
Introduction
|
||||
++++++++++++
|
||||
|
||||
.. warning::
|
||||
|
||||
The PyTorch API of nested tensors is in prototype stage and will change in the near future.
|
||||
|
||||
.. warning::
|
||||
|
||||
torch.NestedTensor currently does not support autograd. It needs to be used in the context
|
||||
of torch.inference_mode().
|
||||
|
||||
NestedTensor allows the user to pack a list of Tensors into a single, efficient datastructure.
|
||||
|
||||
The only constraint on the input Tensors is that their dimension must match.
|
||||
|
||||
This enables more efficient metadata representations and operator coverage.
|
||||
|
||||
Construction is straightforward and involves passing a list of Tensors to the constructor.
|
||||
|
||||
>>> a, b = torch.arange(3), torch.arange(5) + 3
|
||||
>>> a
|
||||
tensor([0, 1, 2])
|
||||
>>> b
|
||||
tensor([3, 4, 5, 6, 7])
|
||||
>>> nt = torch.nested_tensor([a, b])
|
||||
>>> nt
|
||||
nested_tensor([
|
||||
tensor([0, 1, 2]),
|
||||
tensor([3, 4, 5, 6, 7])
|
||||
])
|
||||
|
||||
Data type and device can be chosen via the usual keyword arguments
|
||||
|
||||
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
|
||||
>>> nt
|
||||
nested_tensor([
|
||||
tensor([0., 1., 2.], device='cuda:0'),
|
||||
tensor([3., 4., 5., 6., 7.], device='cuda:0')
|
||||
])
|
||||
|
||||
|
||||
Operator coverage
|
||||
+++++++++++++++++
|
||||
|
||||
We are currently on our path to wholesale extend operator coverage guided by specific ML use cases.
|
||||
|
||||
Operator coverage thus is currently very limited and only unbind is supported.
|
||||
|
||||
>>> nt = torch.nested_tensor([a, b], dtype=torch.float32, device="cuda")
|
||||
>>> nt
|
||||
nested_tensor([
|
||||
tensor([0., 1., 2.], device='cuda:0'),
|
||||
tensor([3., 4., 5., 6., 7.], device='cuda:0')
|
||||
])
|
||||
>>> nt.unbind()
|
||||
[tensor([0., 1., 2.], device='cuda:0'), tensor([3., 4., 5., 6., 7.], device='cuda:0')]
|
184
test/test_nestedtensor.py
Normal file
184
test/test_nestedtensor.py
Normal file
@ -0,0 +1,184 @@
|
||||
# Owner(s): ["module: nestedtensor"]
|
||||
|
||||
import torch
|
||||
import unittest
|
||||
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
|
||||
from torch import nested_tensor
|
||||
|
||||
# Tests are ported from pytorch/nestedtensor.
|
||||
# This makes porting as_nested_tensor easier in the future.
|
||||
def _iter_constructors():
|
||||
# yield as_nested_tensor
|
||||
yield nested_tensor
|
||||
|
||||
|
||||
class TestNestedTensor(TestCase):
|
||||
@torch.inference_mode()
|
||||
def _test_unbind_case(self, a, b):
|
||||
nt = nested_tensor([a, b])
|
||||
a1, b1 = nt.unbind()
|
||||
self.assertTrue(a is not a1)
|
||||
self.assertTrue(b is not b1)
|
||||
|
||||
nt = nested_tensor([a, b], dtype=a.dtype)
|
||||
a1, b1 = nt.unbind(0)
|
||||
self.assertEqual(a, a1)
|
||||
self.assertEqual(b, b1)
|
||||
|
||||
a = torch.randn((2, 3)).add_(1)
|
||||
nt = nested_tensor([a])
|
||||
self.assertEqual(a, nt.unbind(0)[0])
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_unbind_0(self):
|
||||
self._test_unbind_case(
|
||||
torch.tensor([1, 2]), torch.tensor([7, 8]),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_unbind_1(self):
|
||||
self._test_unbind_case(
|
||||
torch.tensor([1]), torch.tensor([7]),
|
||||
)
|
||||
|
||||
# @torch.inference_mode()
|
||||
# def test_unbind_2(self):
|
||||
# self._test_unbind_case(
|
||||
# torch.tensor(1), torch.tensor(7),
|
||||
# )
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_unbind_3(self):
|
||||
self._test_unbind_case(
|
||||
torch.tensor([1.0]), torch.tensor([]),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_unbind_4(self):
|
||||
self._test_unbind_case(
|
||||
torch.tensor([]), torch.tensor([]),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_unbind_dim(self):
|
||||
def _test_fn(unbind_fn):
|
||||
a = torch.rand(3, 2)
|
||||
b = torch.rand(2, 3)
|
||||
nt = nested_tensor([a, b])
|
||||
self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))
|
||||
|
||||
# Both of these tests are necessary, because we're using
|
||||
# torch_function.
|
||||
_test_fn(lambda x, dim: x.unbind(dim))
|
||||
# TODO: Re-enable this once using torch_dispatch
|
||||
# _test_fn(lambda x, dim: torch.unbind(x, dim))
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_nested_tensor(self):
|
||||
self.assertRaises(TypeError, lambda: nested_tensor([3.0]))
|
||||
self.assertRaises(TypeError, lambda: nested_tensor(torch.tensor([3.0])))
|
||||
self.assertRaises(TypeError, lambda: nested_tensor(4.0))
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_nested_tensor_matching_dim(self):
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
|
||||
lambda: nested_tensor([torch.tensor(1.0), torch.tensor([])]),
|
||||
)
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
|
||||
lambda: nested_tensor(
|
||||
[torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
|
||||
),
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_default_nested_tensor(self):
|
||||
self.assertRaises(TypeError, lambda: nested_tensor())
|
||||
default_nested_tensor = nested_tensor([])
|
||||
default_tensor = torch.tensor([])
|
||||
# self.assertEqual(default_nested_tensor.nested_dim(), 1)
|
||||
# self.assertEqual(default_nested_tensor.nested_size(), ())
|
||||
self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
|
||||
self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
|
||||
self.assertEqual(default_nested_tensor.device, default_tensor.device)
|
||||
self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
|
||||
self.assertEqual(
|
||||
default_nested_tensor.requires_grad, default_tensor.requires_grad
|
||||
)
|
||||
self.assertIsNone(default_tensor.grad)
|
||||
# TODO: Re-enable once we have a performance driven
|
||||
# use case and implementation.
|
||||
# self.assertEqual(default_nested_tensor.is_pinned(),
|
||||
# default_tensor.is_pinned())
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_dim(self):
|
||||
for constructor in _iter_constructors():
|
||||
a1 = constructor([])
|
||||
self.assertEqual(a1.dim(), 1)
|
||||
a1 = constructor([torch.tensor(3.0)])
|
||||
self.assertEqual(a1.dim(), 1)
|
||||
a1 = constructor([torch.tensor([1, 2, 3, 4])])
|
||||
self.assertEqual(a1.dim(), 2)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
|
||||
@torch.inference_mode()
|
||||
def test_numel(self):
|
||||
for constructor in _iter_constructors():
|
||||
a1 = constructor([])
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError, "numel is disabled", lambda: a1.numel(),
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "size is not virtual in fbcode.")
|
||||
@torch.inference_mode()
|
||||
def test_size(self):
|
||||
for constructor in _iter_constructors():
|
||||
a1 = constructor([])
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"NestedTensorImpl doesn't support sizes",
|
||||
lambda: a1.size(),
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
|
||||
@torch.inference_mode()
|
||||
def test_stride(self):
|
||||
for constructor in _iter_constructors():
|
||||
a1 = constructor([])
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"NestedTensorImpl doesn't support strides",
|
||||
lambda: a1.stride(),
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
|
||||
@torch.inference_mode()
|
||||
def test_is_contiguous(self):
|
||||
for constructor in _iter_constructors():
|
||||
a1 = constructor([])
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError, "is_contiguous is disabled", lambda: a1.is_contiguous()
|
||||
)
|
||||
|
||||
@torch.inference_mode()
|
||||
def test_repr_string(self):
|
||||
a = nested_tensor([])
|
||||
expected = "nested_tensor([" "\n\n])"
|
||||
self.assertEqual(str(a), expected)
|
||||
self.assertEqual(repr(a), expected)
|
||||
|
||||
a = nested_tensor([torch.tensor(1.0)])
|
||||
expected = "nested_tensor([" "\n tensor(1.)" "\n])"
|
||||
self.assertEqual(str(a), expected)
|
||||
self.assertEqual(repr(a), expected)
|
||||
|
||||
a = nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
|
||||
expected = (
|
||||
"nested_tensor([" "\n tensor([[1, 2]])" "," "\n tensor([[4, 5]])" "\n])"
|
||||
)
|
||||
self.assertEqual(str(a), expected)
|
||||
self.assertEqual(repr(a), expected)
|
@ -974,6 +974,7 @@ aten_cpu_source_non_codegen_list = [
|
||||
"aten/src/ATen/MemoryOverlap.cpp",
|
||||
"aten/src/ATen/MapAllocator.cpp",
|
||||
"aten/src/ATen/NamedTensorUtils.cpp",
|
||||
"aten/src/ATen/NestedTensorImpl.cpp",
|
||||
"aten/src/ATen/ParallelCommon.cpp",
|
||||
"aten/src/ATen/ParallelNative.cpp",
|
||||
"aten/src/ATen/ParallelNativeTBB.cpp",
|
||||
@ -1316,6 +1317,7 @@ aten_native_source_non_codegen_list = [
|
||||
"aten/src/ATen/native/WeightNorm.cpp",
|
||||
"aten/src/ATen/native/group_norm.cpp",
|
||||
"aten/src/ATen/native/layer_norm.cpp",
|
||||
"aten/src/ATen/native/nested/NestedTensorMath.cpp",
|
||||
"aten/src/ATen/native/sparse/ParamUtils.cpp",
|
||||
"aten/src/ATen/native/sparse/SoftMax.cpp",
|
||||
"aten/src/ATen/native/sparse/SparseBlas.cpp",
|
||||
|
@ -139,6 +139,7 @@ dispatch_keys = [
|
||||
DispatchKey.QuantizedCUDA,
|
||||
DispatchKey.CompositeImplicitAutograd,
|
||||
DispatchKey.CompositeExplicitAutograd,
|
||||
DispatchKey.NestedTensor,
|
||||
# Meta is a magic key: it is automatically generated for structured
|
||||
# kernels
|
||||
DispatchKey.Meta,
|
||||
|
@ -819,6 +819,9 @@ from torch import __config__ as __config__
|
||||
from torch import __future__ as __future__
|
||||
from torch import profiler as profiler
|
||||
|
||||
from torch.nested._nestedtensor import NestedTensor
|
||||
from torch.nested._nestedtensor import nested_tensor
|
||||
|
||||
_C._init_names(list(torch._storage_classes))
|
||||
|
||||
# attach docstrings to torch and tensor functions
|
||||
|
0
torch/nested/__init__.py
Normal file
0
torch/nested/__init__.py
Normal file
109
torch/nested/_nestedtensor.py
Normal file
109
torch/nested/_nestedtensor.py
Normal file
@ -0,0 +1,109 @@
|
||||
import torch
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@wraps(torch._nested_tensor)
|
||||
def nested_tensor(*args, **kwargs):
|
||||
return NestedTensor(torch._nested_tensor(*args, **kwargs))
|
||||
|
||||
|
||||
# TODO: This entire class is not really necessary now that NestedTensor lives
|
||||
# in tree; before it lived out of tree and there was no way to conveniently
|
||||
# override the string printing behavior. Now that we are in tree, we can
|
||||
# directly override _tensor_str to capture this behavior, and the wrapper subclass
|
||||
# is not necessary. See also https://github.com/pytorch/pytorch/issues/73506
|
||||
class NestedTensor:
|
||||
# data is a torch.Tensor backed by a NestedTensorImpl
|
||||
|
||||
def __init__(self, impl):
|
||||
self._impl = impl
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
"""
|
||||
The data type of ```self``` NestedTensor.
|
||||
"""
|
||||
return self._impl.dtype
|
||||
|
||||
@property
|
||||
def layout(self):
|
||||
"""
|
||||
The layout of ```self``` NestedTensor.
|
||||
"""
|
||||
return self._impl.layout
|
||||
|
||||
@property
|
||||
def device(self):
|
||||
"""
|
||||
The device of ```self``` NestedTensor.
|
||||
"""
|
||||
return self._impl.device
|
||||
|
||||
@property
|
||||
def requires_grad(self):
|
||||
"""
|
||||
Is ```True``` if gradients need to be computed for this Tensor.
|
||||
"""
|
||||
return self._impl.requires_grad
|
||||
|
||||
def stride(self):
|
||||
"""
|
||||
NestedTensor currently does not have a stride. This will throw.
|
||||
"""
|
||||
return self._impl.stride()
|
||||
|
||||
def size(self):
|
||||
"""
|
||||
NestedTensor currently does not have a size. This will throw.
|
||||
"""
|
||||
return self._impl.size()
|
||||
|
||||
def dim(self):
|
||||
"""
|
||||
The dimension of ```self``` NestedTensor.
|
||||
"""
|
||||
tensors = self.unbind()
|
||||
if len(tensors) == 0:
|
||||
return 1
|
||||
return int(tensors[0].dim() + 1)
|
||||
|
||||
def numel(self):
|
||||
"""
|
||||
The number of elements of ```self``` NestedTensor.
|
||||
"""
|
||||
return self._impl.numel()
|
||||
|
||||
def is_contiguous(self):
|
||||
"""
|
||||
Returns true if ```self``` NestedTensor is contiguous.
|
||||
"""
|
||||
return self._impl.is_contiguous()
|
||||
|
||||
def __str__(self):
|
||||
def _str(x, indent=0, tab=" "):
|
||||
s = indent * tab + "[\n"
|
||||
strs = list(map(str, x.unbind()))
|
||||
strs = list(
|
||||
map(
|
||||
lambda xi: "\n".join(
|
||||
map(lambda xij: (indent + 1) * tab + xij, xi.split("\n"))
|
||||
),
|
||||
strs,
|
||||
)
|
||||
)
|
||||
s += ",\n".join(strs)
|
||||
s += "\n" + indent * tab + "]"
|
||||
return s
|
||||
|
||||
return "nested_tensor(" + _str(self) + ")"
|
||||
|
||||
def __repr__(self):
|
||||
return self.__str__()
|
||||
|
||||
def unbind(self, dim=None):
|
||||
if dim is None:
|
||||
unbound = torch.ops.aten.unbind.int(self._impl, 0)
|
||||
if len(unbound) == 0:
|
||||
return ()
|
||||
return unbound
|
||||
return torch.ops.aten.unbind.int(self._impl, dim)
|
Reference in New Issue
Block a user