Add NestedTensor python subclass (#108314)

Adds an initial NestedTensor Python tensor subclass with a jagged layout: a batch of variable-length sequences is stored as a dense values buffer plus an offsets tensor, and the ragged dimension is represented by a singleton SymInt tied to that offsets tensor. Aten ops are routed through __torch_dispatch__ to a jagged ops table, which currently covers basic tensor attributes, unary/binary pointwise ops, and linear. Autograd's InputMetadata gains an is_nested_ flag so gradient shape checks can distinguish C++ nested tensors from the Python subclass. Tests are added in TestNestedTensorSubclass.
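A minimal usage sketch, based on the new tests in this PR (jagged_from_list and buffer_from_jagged are internal helpers added here; the import path is private and may change):

    import torch
    from torch.nested._internal.nested_tensor import (
        jagged_from_list,
        buffer_from_jagged,
    )

    a = torch.randn(2, 3)
    b = torch.randn(3, 3)
    c = torch.randn(4, 3)

    # Pack three (length_i, 3) tensors into a jagged NestedTensor of logical
    # shape (3, j, 3), where j is a singleton SymInt tied to `offsets`.
    nt, offsets = jagged_from_list([a, b, c], None)

    out = nt.sin().cos()           # unary pointwise ops dispatch to the values buffer
    buf = buffer_from_jagged(out)  # back to the packed (9, 3) dense buffer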

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108314
Approved by: https://github.com/jbschlosser
ghstack dependencies: #108808
soulitzer
2023-09-08 13:24:21 -04:00
committed by PyTorch MergeBot
parent 4a4a2fc1a5
commit 2bcff92540
9 changed files with 465 additions and 22 deletions

View File

@@ -28,6 +28,8 @@ from torch.testing._internal.common_utils import (
TestCase,
)
from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged
# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.
@@ -2826,10 +2828,99 @@ class TestNestedTensorAutograd(TestCase):
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
# We can probably parametrize existing tests instead of having a separate
# test class as we begin to support more ops. Also maybe rewrite with OpInfos.
class TestNestedTensorSubclass(TestCase):
def test_tensor_attributes(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
nt, _offsets = jagged_from_list([a, b, c], None)
for op in (
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.sym_size.default,
torch.ops.aten.dim.default,
torch.ops.aten.sym_numel.default,
):
op(nt)
with self.assertRaisesRegex(RuntimeError,
"directly calling torch.ops.aten.size"):
torch.ops.aten.size.default(nt)
singleton_int = torch.nested._internal.nested_tensor.get_tensor_id(_offsets)
self.assertEqual(nt.size(), (3, singleton_int, 3))
self.assertEqual(nt.shape, (3, singleton_int, 3))
self.assertEqual(nt.dim(), 3)
self.assertEqual(nt.numel(), 27)
for op in (
torch.ops.aten.sym_stride.default,
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.sym_storage_offset.default,
):
error_msg = "NestedTensors do not support directly querying strides"
with self.assertRaisesRegex(RuntimeError, error_msg):
if "memory_format" in op.__name__:
op(nt, torch.preserve_format)
op(nt)
def test_linear(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
weight = torch.randn(3, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c, weight):
nt, _ = jagged_from_list([a, b, c], None)
out = torch.nn.functional.linear(nt, weight)
return buffer_from_jagged(out)
gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)
def test_unary_pointwise(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
weight = torch.randn(3, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c, weight):
nt, _ = jagged_from_list([a, b, c], None)
out = nt.sin().cos()
return buffer_from_jagged(out)
gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)
def test_binary_pointwise(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
# Incorrect usage: the shape check will fail if the offsets tensors are not
# the same exact tensor object
nt1, _ = jagged_from_list([a, b, c], None)
nt2, _ = jagged_from_list([a, b, c], None)
self.assertRaisesRegex(
RuntimeError,
"expected lhs and rhs to have the same exact offsets tensor",
lambda: nt1 * nt2)
# Correct usage: chain the calls using the same offsets tensor object
def grad_test_func(a, b, c):
nt1, offsets = jagged_from_list([a, b, c], None)
nt2, offsets = jagged_from_list([a, b, c], offsets)
out = nt1 * nt2
return buffer_from_jagged(out)
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
instantiate_device_type_tests(TestNestedTensorSubclass, globals())
if __name__ == '__main__':
run_tests()

View File

@@ -1433,6 +1433,8 @@ def _are_functorch_transforms_active() -> _bool: ...
# Define in torch/csrc/autograd/init.cpp
def _set_python_dispatcher(dispatcher: object) -> None: ...
def _get_singleton_int(id: _int) -> SymInt: ...
class _TorchDispatchModeKey(Enum):
${torch_dispatch_mode_key_hints}

View File

@@ -201,11 +201,13 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
uint32_t add_input_metadata(
const at::TensorOptions& options,
c10::SymIntArrayRef shape,
bool is_tensor_subclass) noexcept {
bool is_tensor_subclass,
bool is_nested) noexcept {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
uint32_t input_nr = input_metadata_.size();
auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape};
input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass);
input_metadata_.emplace_back(
options, meta_shape, is_tensor_subclass, is_nested);
return input_nr;
}

View File

@@ -43,10 +43,12 @@ struct InputMetadata {
InputMetadata(
const at::TensorOptions& options,
MetadataShape input_shape,
bool is_tensor_subclass)
bool is_tensor_subclass,
bool is_nested)
: options_{options},
shape_{std::move(input_shape)},
is_tensor_subclass_{is_tensor_subclass},
is_nested_{is_nested},
was_default_constructed_{false} {
auto device_ = options.device();
stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
@@ -56,7 +58,8 @@ struct InputMetadata {
: InputMetadata(
t.options(),
compute_variant_shape(t),
t.unsafeGetTensorImpl()->is_python_dispatch()) {}
t.unsafeGetTensorImpl()->is_python_dispatch(),
t.is_nested()) {}
const at::TensorOptions options() const {
return options_;
@@ -84,26 +87,33 @@ struct InputMetadata {
at::Tensor zeros_like() const {
TORCH_CHECK(
!is_nested_tensor(),
"Zeros is not currently supported for nested tensors.")
!is_nested_, "Zeros is not currently supported for nested tensors.")
return at::zeros_symint(shape_as_dim_vector(), options_);
}
bool is_same_shape(const at::Tensor& grad) const {
void check_nestedness_same(const at::Tensor& grad) const {
bool grad_is_subclass = grad.unsafeGetTensorImpl()->is_python_dispatch();
bool grad_is_nested = grad.is_nested();
bool grad_is_cpp_nested = grad_is_nested && !grad_is_subclass;
TORCH_CHECK(
grad.is_nested() == is_nested_tensor(),
"Both grad and InputMetadata need to be either nested or non nested tensors.")
if (grad.is_nested()) {
return at::native::get_nested_sizes(grad).is_same_size(shape_as_tensor());
grad_is_cpp_nested == is_cpp_nested_tensor() &&
grad_is_nested == is_nested_,
"grad and the input wrt the gradient that is being computed for need to be "
"either both nested or both non-nested tensors. Also note that nested "
"tensors with different layouts do not compose currently.");
}
bool is_same_shape(const at::Tensor& grad) const {
check_nestedness_same(grad);
if (is_cpp_nested_tensor()) {
return grad._nested_tensor_size().is_same_size(shape_as_tensor());
}
return grad.sym_sizes().equals(shape_as_dim_vector());
}
bool is_expandable_to_shape(const at::Tensor& grad) const {
// Currently NestedTensors are not expandable. If this support is added then
// updates to reduce_grad will be needed
TORCH_CHECK(
grad.is_nested() == is_nested_tensor(),
"Both grad and InputMetadata need to be either nested or non nested tensors.")
check_nestedness_same(grad);
return grad.is_nested()
? false
: at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
@@ -113,7 +123,7 @@ struct InputMetadata {
// Currently reduce_grad is only called if is_expandable_to_shape returns
// true. For nested tensors this always returns False, so this check
// shouldn't fail
TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor())
TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_)
return at::sum_to(std::move(grad), shape_as_dim_vector());
}
@@ -122,13 +132,13 @@ struct InputMetadata {
const at::Tensor& grad) const {
std::stringstream ss;
ss << "invalid gradient at index " << index << " - got ";
if (grad.is_nested()) {
ss << at::native::get_nested_sizes(grad);
if (grad.is_nested() && !grad.unsafeGetTensorImpl()->is_python_dispatch()) {
ss << grad._nested_tensor_size();
} else {
ss << grad.sym_sizes();
}
ss << " but expected shape compatible with ";
if (is_nested_tensor()) {
if (is_cpp_nested_tensor()) {
ss << shape_as_tensor();
} else {
ss << shape_as_dim_vector();
@@ -140,8 +150,14 @@ struct InputMetadata {
return was_default_constructed_;
}
bool is_cpp_nested_tensor() const {
bool ret = c10::holds_alternative<at::Tensor>(shape_);
TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_))
return ret;
}
bool is_nested_tensor() const {
return (c10::holds_alternative<at::Tensor>(shape_));
return is_nested_;
}
c10::SymIntArrayRef shape_as_dim_vector() const {
@@ -156,8 +172,9 @@ struct InputMetadata {
private:
MetadataShape compute_variant_shape(const at::Tensor& input) {
if (input.is_nested()) {
auto nested_size = at::native::get_nested_sizes(input);
if (input.is_nested() &&
!input.unsafeGetTensorImpl()->is_python_dispatch()) {
auto nested_size = input._nested_tensor_size();
return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
}
return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
@@ -171,6 +188,7 @@ struct InputMetadata {
MetadataShape shape_;
c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
bool is_tensor_subclass_ = false;
bool is_nested_ = false;
bool was_default_constructed_ = true;
};
} // namespace autograd

View File

@@ -685,7 +685,8 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
view_info.base_.options(),
self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
// intentional
self.unsafeGetTensorImpl()->is_python_dispatch());
self.unsafeGetTensorImpl()->is_python_dispatch(),
self.is_nested());
diff_view_meta->grad_fn_ = std::move(fn);
}
diff_view_meta->set_attr_version(current_version);

View File

View File

@@ -0,0 +1,145 @@
from typing import Tuple
import torch
from torch._C import DispatchKey, DispatchKeySet
from torch.utils.weak import WeakTensorKeyDictionary
from typing import * # noqa: F403
_tensor_id_counter = 0
_tensor_id_registry = WeakTensorKeyDictionary()
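# Maps a tensor to a stable integer id (held weakly via WeakTensorKeyDictionary,
# so entries die with their tensor) and wraps it in a singleton SymInt.
# NestedTensor uses this as the "size" of its ragged dimension: two jagged
# shapes compare equal only if they came from the same offsets tensor.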
def get_tensor_id(tensor):
global _tensor_id_counter
if tensor not in _tensor_id_registry:
_tensor_id_registry[tensor] = _tensor_id_counter
_tensor_id_counter += 1
return torch._C._get_singleton_int(_tensor_id_registry[tensor])
class NestedTensor(torch.Tensor):
_values: torch.Tensor # type: ignore[assignment]
_offsets: torch.Tensor
_size: Tuple[int, int, int]
__torch_function__ = torch._C._disabled_torch_function_impl
@staticmethod
def __new__(cls, values, offsets, **kwargs):
ks = DispatchKeySet(DispatchKey.NestedTensor)
ks = ks.add(DispatchKey.AutogradNestedTensor)
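# The wrapper subclass owns no real storage: the outer tensor is created with
# dummy size/stride and the "sizes" dispatch policy so size queries dispatch
# back to Python, while the packed data lives in _values.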
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
cls,
(0,),
(0,),
0,
torch.contiguous_format,
values.dtype,
values.layout,
values.device,
False,
False,
"sizes",
False,
False,
ks,
)
# TODO: why does values require grad?
# if r.requires_grad:
# raise ValueError(
# "buffer should not require grad when constructing NestedTensor")
r._values = values.detach() if values.requires_grad else values
return r
def __init__(self, values, offsets, **kwargs):
super().__init__()
# Only support jagged for now.
assert offsets is not None
assert offsets.ndim == 1
assert not isinstance(values, NestedTensor)
assert values.ndim == 2
# In a later PR, we'll need to accept an additional size argument
# to handle dynamic shapes.
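# offsets has B + 1 entries delimiting B variable-length rows of the packed
# values buffer; the ragged dim's extent is the singleton SymInt for offsets.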
ragged_dim = get_tensor_id(offsets)
D = values.shape[1]
B = offsets.shape[0] - 1
self._size = (B, ragged_dim, D)
self._offsets = offsets
return
def values(self):
return self._values
def offsets(self):
return self._offsets
def __repr__(self):
# We should implement this in torch/_tensor_str.py instead
grad_fn_str = (
f", requires_grad={self.requires_grad}" if self.requires_grad else ""
)
if self.grad_fn:
grad_fn_str = f", grad_fn={self.grad_fn}"
return f"NestedTensor(size={self._size}, offsets={self.offsets}{grad_fn_str})"
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs
# Lazy import to avoid circular dependency
from .ops import lookup_jagged
fn = lookup_jagged(func, *args, **kwargs)
if fn is not None:
return fn(*args, **kwargs)
raise NotImplementedError
# Not actually a view!
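# This autograd.Function (and its inverse below) converts between the jagged
# NestedTensor and its packed values buffer; each direction constructs a new
# tensor while reusing the same offsets, hence "not actually a view".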
class ViewBufferFromNested(torch.autograd.Function):
@staticmethod
def forward(ctx, x: NestedTensor): # type: ignore[override]
ctx.save_for_backward(x.offsets())
return x.values()
@staticmethod
def backward(ctx, gO: torch.Tensor): # type: ignore[override]
(offsets,) = ctx.saved_tensors
return NestedTensor(gO, offsets=offsets)
# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
@staticmethod
def forward(ctx, values: torch.Tensor, offsets: torch.Tensor): # type: ignore[override]
return NestedTensor(values, offsets=offsets)
@staticmethod
def backward(ctx, gO: NestedTensor): # type: ignore[override]
return gO.values(), None, None
# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
tensors: Sequence[torch.Tensor], offsets: Optional[torch.Tensor]
) -> Tuple[NestedTensor, torch.Tensor]:
"""Constructs a NestedTensor backed by jagged layout from a list of tensors"""
assert len(set(t.dtype for t in tensors)) == 1 # noqa: C401
assert len(set(t.device for t in tensors)) == 1 # noqa: C401
assert all(t.ndim == 2 for t in tensors)
assert len(set(t.shape[1] for t in tensors)) == 1 # noqa: C401
lengths = torch.tensor([t.shape[0] for t in tensors])
_offsets = torch.cat([torch.tensor([0]), lengths.cumsum(0)])
if offsets is not None:
assert torch.all(offsets == _offsets).item()
else:
offsets = _offsets
return ViewNestedFromBuffer.apply(torch.cat(tensors, dim=0), offsets), offsets # type: ignore[call-overload]
def buffer_from_jagged(jagged):
return ViewBufferFromNested.apply(jagged)

View File

@@ -0,0 +1,183 @@
import functools
import torch
from .nested_tensor import NestedTensor
from typing import * # noqa: F403
__all__: List[Any] = []
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
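# check_schema validates positional args against a tiny schema string such as
# "self: jt, memory_format: any?": "jt" means NestedTensor, "t" means a dense
# Tensor, "any" matches anything, and a trailing "?" marks the arg optional.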
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
named_arg_types = schema_str.split(", ")
num_optional_args = sum([x.endswith("?") for x in named_arg_types])
min_args = len(named_arg_types) - num_optional_args
if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
raise ValueError(
f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
f"arguments and at most {len(named_arg_types)} arguments, but got: "
f"{len(args)} arguments"
)
arg_type_check_fns = {
"t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
"jt": lambda x: isinstance(x, NestedTensor),
"any": lambda x: True,
}
for i, named_arg_type in enumerate(named_arg_types):
name, arg_type = named_arg_type.split(": ")
is_optional = arg_type.endswith("?")
normalized_arg_type = arg_type[:-1] if is_optional else arg_type
if normalized_arg_type not in arg_type_check_fns.keys():
raise AssertionError(f"Unknown arg type: {normalized_arg_type}")
if i >= len(args):
if not is_optional:
raise ValueError(
f"NestedTensor {func.__name__}({schema_str}) "
f"missing required argument: {name}"
)
continue
if not arg_type_check_fns[normalized_arg_type](args[i]):
raise ValueError(
f"NestedTensor {func.__name__}({schema_str}): {name} should be of "
f"type {arg_type}, but got: {type(args[i])}"
)
def check_ragged_dim_same(
func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
# Accesses the cached _size directly instead of calling .shape here
assert len(a._size) == 3, "NestedTensor must be [B, *, D]"
if a._size[1] != b._size[1]:
raise RuntimeError(
f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
"same exact offsets tensor."
)
def register_func(tables, aten_ops, schema_str):
if not isinstance(aten_ops, list):
aten_ops = [aten_ops]
if not isinstance(tables, list):
tables = [tables]
def wrapper(func):
for aten_op in aten_ops:
def get_inner(aten_op):
def inner(*args, **kwargs):
check_schema(schema_str, func, *args, **kwargs)
return func(aten_op, *args, **kwargs)
return inner
for table in tables:
table[aten_op] = get_inner(aten_op)
return wrapper
register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
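# Entry point used by NestedTensor.__torch_dispatch__: ops tagged pointwise
# get generic unary/binary handling keyed on the number of tensor args;
# anything else must have been registered in JAGGED_OPS_TABLE.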
def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
if torch.Tag.pointwise in func.tags:
# Assume there aren't additional tensors that aren't the "unary/binary" args
num_tensor_args = sum([isinstance(x, torch.Tensor) for x in args])
if num_tensor_args == 1:
return functools.partial(jagged_unary_pointwise, func)
elif num_tensor_args == 2:
check_schema("lhs: jt, rhs: jt", func, *args, **kwargs)
return functools.partial(jagged_binary_pointwise, func)
else:
return None
return JAGGED_OPS_TABLE.get(func, None)
def extract_kwargs(arg):
kwargs = {
"offsets": arg.offsets(),
}
return kwargs
def jagged_unary_pointwise(func, *args, **kwargs):
return NestedTensor(func(args[0].values(), **kwargs), **extract_kwargs(args[0]))
def jagged_binary_pointwise(func, *args, **kwargs):
check_ragged_dim_same(func, args[0], "lhs", args[1], "rhs")
return NestedTensor(
func(args[0].values(), args[1].values(), **kwargs), **extract_kwargs(args[0])
)
@register_jagged_func(
[
torch.ops.aten.is_non_overlapping_and_dense.default,
torch.ops.aten.sym_size.default,
torch.ops.aten.dim.default,
torch.ops.aten.sym_numel.default,
],
"self: jt",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
if func == torch.ops.aten.is_non_overlapping_and_dense.default:
return False
if func == torch.ops.aten.sym_size.default:
return args[0]._size
if func == torch.ops.aten.dim.default:
return 3
if func == torch.ops.aten.sym_numel.default:
return args[0].values().numel()
@register_jagged_func(
[
torch.ops.aten.size.default,
torch.ops.aten.sym_stride.default,
torch.ops.aten.is_contiguous.default,
torch.ops.aten.is_contiguous.memory_format,
torch.ops.aten.sym_storage_offset.default,
],
"self: jt, memory_format: any?",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
if func == torch.ops.aten.size.default:
raise RuntimeError(
"NestedTensors does not support directly calling torch.ops.aten.size "
"please use `nested_tensor.size()` instead."
)
raise RuntimeError(
"NestedTensors do not support directly querying strides, "
"storage_offset, or contiguity."
)
@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
values = torch.mm(args[0].values(), args[1])
if len(args) == 3:
values += args[2]
return NestedTensor(values, **extract_kwargs(args[0]))
@register_jagged_func(
torch.ops.aten.linear_backward.default,
"self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
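# Forward is mm(values, weight) on the packed buffer, so on the buffers:
# ds = mm(grad_output, weight.T) re-wrapped with grad_output's offsets, and
# dw = mm(self.T, grad_output) accumulates over all packed rows.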
check_ragged_dim_same(func, args[0], "self", args[1], "grad_output")
ds = NestedTensor(torch.mm(args[1].values(), args[2].T), **extract_kwargs(args[1]))
dw = torch.mm(args[0].values().T, args[1].values())
db = None # NYI: gradient for bias, need to reduce over ragged dim
return (ds, dw, db)

View File

@@ -84,6 +84,7 @@ class DispatchKey(Enum):
Sparse = auto()
SparseCsrCPU = auto()
SparseCsrCUDA = auto()
NestedTensor = auto()
Python = auto()
FuncTorchDynamicLayerBackMode = auto()