Add NestedTensor python subclass (#108314)
Description coming soon

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108314
Approved by: https://github.com/jbschlosser
ghstack dependencies: #108808
Committed by: PyTorch MergeBot
Parent: 4a4a2fc1a5
Commit: 2bcff92540
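
For orientation before the diff: a minimal usage sketch of the subclass this PR adds, assembled from the tests below (these helpers are internal, live under torch.nested._internal, and are subject to change):

import torch
from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged

# Three components with 2, 3, and 4 rows and a shared trailing dim of 3.
a, b, c = torch.randn(2, 3), torch.randn(3, 3), torch.randn(4, 3)
nt, offsets = jagged_from_list([a, b, c], None)
print(nt.shape)  # (3, j0, 3): the ragged dim is a singleton SymInt
out = nt.sin()   # pointwise ops route through __torch_dispatch__
buf = buffer_from_jagged(out)  # back to the packed (9, 3) values buffer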
test/test_nestedtensor.py

@@ -28,6 +28,8 @@ from torch.testing._internal.common_utils import (
     TestCase,
 )

+from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged
+
 # Tests are ported from pytorch/nestedtensor.
 # This makes porting as_nested_tensor easier in the future.
@@ -2826,10 +2828,99 @@ class TestNestedTensorAutograd(TestCase):
         data = (a, b, c)
         assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)

+
+# We can probably parametrize existing tests instead of having a separate
+# test class as we begin to support more ops. Also maybe rewrite with OpInfos.
+class TestNestedTensorSubclass(TestCase):
+    def test_tensor_attributes(self, device):
+        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
+        nt, _offsets = jagged_from_list([a, b, c], None)
+
+        for op in (
+            torch.ops.aten.is_non_overlapping_and_dense.default,
+            torch.ops.aten.sym_size.default,
+            torch.ops.aten.dim.default,
+            torch.ops.aten.sym_numel.default,
+        ):
+            op(nt)
+
+        with self.assertRaisesRegex(RuntimeError,
+                                    "directly calling torch.ops.aten.size"):
+            torch.ops.aten.size.default(nt)
+
+        singleton_int = torch.nested._internal.nested_tensor.get_tensor_id(_offsets)
+        self.assertEqual(nt.size(), (3, singleton_int, 3))
+        self.assertEqual(nt.shape, (3, singleton_int, 3))
+        self.assertEqual(nt.dim(), 3)
+        # numel is the size of the packed values buffer: (2 + 3 + 4) * 3 == 27
+        self.assertEqual(nt.numel(), 27)
+
+        for op in (
+            torch.ops.aten.sym_stride.default,
+            torch.ops.aten.is_contiguous.default,
+            torch.ops.aten.is_contiguous.memory_format,
+            torch.ops.aten.sym_storage_offset.default,
+        ):
+            error_msg = "NestedTensors do not support directly querying strides"
+            with self.assertRaisesRegex(RuntimeError, error_msg):
+                if "memory_format" in op.__name__:
+                    op(nt, torch.preserve_format)
+                op(nt)
+
+    def test_linear(self, device):
+        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
+        weight = torch.randn(3, 4, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c, weight):
+            nt, _ = jagged_from_list([a, b, c], None)
+            out = torch.nn.functional.linear(nt, weight)
+            return buffer_from_jagged(out)
+
+        gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)
+
+    def test_unary_pointwise(self, device):
+        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
+        weight = torch.randn(3, 4, requires_grad=True, dtype=torch.float64, device=device)
+
+        def grad_test_func(a, b, c, weight):
+            nt, _ = jagged_from_list([a, b, c], None)
+            out = nt.sin().cos()
+            return buffer_from_jagged(out)
+
+        gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)
+
+    def test_binary_pointwise(self, device):
+        a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
+        b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
+        c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
+
+        # Incorrect usage: shape check will fail if the offsets tensors are not
+        # the same exact tensor object
+        nt1, _ = jagged_from_list([a, b, c], None)
+        nt2, _ = jagged_from_list([a, b, c], None)
+
+        self.assertRaisesRegex(
+            RuntimeError,
+            "expected lhs and rhs to have the same exact offsets tensor",
+            lambda: nt1 * nt2)
+
+        # Correct usage: chain the calls using the same offsets tensor object
+        def grad_test_func(a, b, c):
+            nt1, offsets = jagged_from_list([a, b, c], None)
+            nt2, offsets = jagged_from_list([a, b, c], offsets)
+            out = nt1 * nt2
+            return buffer_from_jagged(out)
+
+        gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
+
+
 instantiate_parametrized_tests(TestNestedTensor)
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
 instantiate_device_type_tests(TestNestedTensorAutograd, globals())
+instantiate_device_type_tests(TestNestedTensorSubclass, globals())

 if __name__ == '__main__':
     run_tests()
torch/_C/__init__.pyi.in

@@ -1433,6 +1433,8 @@ def _are_functorch_transforms_active() -> _bool: ...

 # Defined in torch/csrc/autograd/init.cpp
 def _set_python_dispatcher(dispatcher: object) -> None: ...

+def _get_singleton_int(id: _int) -> SymInt: ...
+
 class _TorchDispatchModeKey(Enum):
     ${torch_dispatch_mode_key_hints}
torch/csrc/autograd/function.h

@@ -201,11 +201,13 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   uint32_t add_input_metadata(
       const at::TensorOptions& options,
       c10::SymIntArrayRef shape,
-      bool is_tensor_subclass) noexcept {
+      bool is_tensor_subclass,
+      bool is_nested) noexcept {
     // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
     uint32_t input_nr = input_metadata_.size();
     auto meta_shape = MetadataShape{c10::in_place_type<SymIntSmallVec>, shape};
-    input_metadata_.emplace_back(options, meta_shape, is_tensor_subclass);
+    input_metadata_.emplace_back(
+        options, meta_shape, is_tensor_subclass, is_nested);
     return input_nr;
   }
torch/csrc/autograd/input_metadata.h

@@ -43,10 +43,12 @@ struct InputMetadata {
   InputMetadata(
       const at::TensorOptions& options,
       MetadataShape input_shape,
-      bool is_tensor_subclass)
+      bool is_tensor_subclass,
+      bool is_nested)
       : options_{options},
         shape_{std::move(input_shape)},
         is_tensor_subclass_{is_tensor_subclass},
+        is_nested_{is_nested},
         was_default_constructed_{false} {
     auto device_ = options.device();
     stream_ = c10::impl::getDeviceGuardImpl(device_.type())->getStream(device_);
@@ -56,7 +58,8 @@ struct InputMetadata {
       : InputMetadata(
             t.options(),
             compute_variant_shape(t),
-            t.unsafeGetTensorImpl()->is_python_dispatch()) {}
+            t.unsafeGetTensorImpl()->is_python_dispatch(),
+            t.is_nested()) {}

   const at::TensorOptions options() const {
     return options_;
@@ -84,26 +87,33 @@ struct InputMetadata {

   at::Tensor zeros_like() const {
     TORCH_CHECK(
-        !is_nested_tensor(),
-        "Zeros is not currently supported for nested tensors.")
+        !is_nested_, "Zeros is not currently supported for nested tensors.")
     return at::zeros_symint(shape_as_dim_vector(), options_);
   }

-  bool is_same_shape(const at::Tensor& grad) const {
+  void check_nestedness_same(const at::Tensor& grad) const {
+    bool grad_is_subclass = grad.unsafeGetTensorImpl()->is_python_dispatch();
+    bool grad_is_nested = grad.is_nested();
+    bool grad_is_cpp_nested = grad_is_nested && !grad_is_subclass;
     TORCH_CHECK(
-        grad.is_nested() == is_nested_tensor(),
-        "Both grad and InputMetadata need to be either nested or non nested tensors.")
-    if (grad.is_nested()) {
-      return at::native::get_nested_sizes(grad).is_same_size(shape_as_tensor());
+        grad_is_cpp_nested == is_cpp_nested_tensor() &&
+            grad_is_nested == is_nested_,
+        "grad and the input wrt the gradient that is being computed for need to be "
+        "either both nested or both non-nested tensors. Also note that nested "
+        "tensors with different layouts do not compose currently.");
+  }
+
+  bool is_same_shape(const at::Tensor& grad) const {
+    check_nestedness_same(grad);
+    if (is_cpp_nested_tensor()) {
+      return grad._nested_tensor_size().is_same_size(shape_as_tensor());
     }
     return grad.sym_sizes().equals(shape_as_dim_vector());
   }
   bool is_expandable_to_shape(const at::Tensor& grad) const {
     // Currently NestedTensors are not expandable. If this support is added then
     // updates to reduce_grad will be needed
-    TORCH_CHECK(
-        grad.is_nested() == is_nested_tensor(),
-        "Both grad and InputMetadata need to be either nested or non nested tensors.")
+    check_nestedness_same(grad);
     return grad.is_nested()
         ? false
         : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
@@ -113,7 +123,7 @@ struct InputMetadata {
     // Currently reduce_grad is only called if is_expandable_to_shape returns
     // true. For nested tensors this always returns False, so this check
     // shouldn't fail
-    TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_tensor())
+    TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_)
     return at::sum_to(std::move(grad), shape_as_dim_vector());
   }
@@ -122,13 +132,13 @@ struct InputMetadata {
       const at::Tensor& grad) const {
     std::stringstream ss;
     ss << "invalid gradient at index " << index << " - got ";
-    if (grad.is_nested()) {
-      ss << at::native::get_nested_sizes(grad);
+    if (grad.is_nested() && !grad.unsafeGetTensorImpl()->is_python_dispatch()) {
+      ss << grad._nested_tensor_size();
     } else {
       ss << grad.sym_sizes();
     }
     ss << " but expected shape compatible with ";
-    if (is_nested_tensor()) {
+    if (is_cpp_nested_tensor()) {
       ss << shape_as_tensor();
     } else {
       ss << shape_as_dim_vector();
@@ -140,8 +150,14 @@ struct InputMetadata {
     return was_default_constructed_;
   }

+  bool is_cpp_nested_tensor() const {
+    bool ret = c10::holds_alternative<at::Tensor>(shape_);
+    TORCH_INTERNAL_ASSERT(ret == (is_nested_ && !is_tensor_subclass_))
+    return ret;
+  }
+
   bool is_nested_tensor() const {
-    return (c10::holds_alternative<at::Tensor>(shape_));
+    return is_nested_;
   }

   c10::SymIntArrayRef shape_as_dim_vector() const {
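
The new predicate pair encodes the layout split this PR has to respect: a "cpp" nested tensor (the existing NestedTensorImpl, whose sizes live in a tensor, hence the holds_alternative check) versus the new python-subclass jagged tensor, which also reports is_nested but goes through python dispatch. From the python side, a hedged sketch of the two:

import torch
from torch.nested._internal.nested_tensor import jagged_from_list

# Existing C++ nested tensor: nested, but not a python-dispatch subclass.
cpp_nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
assert cpp_nt.is_nested

# New jagged subclass: also nested (DispatchKey.NestedTensor is in its key
# set), but python-dispatch, so it must not be treated as cpp-nested here.
py_nt, _ = jagged_from_list([torch.randn(2, 3), torch.randn(4, 3)], None)
assert py_nt.is_nested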
@@ -156,8 +172,9 @@ struct InputMetadata {

  private:
  MetadataShape compute_variant_shape(const at::Tensor& input) {
-    if (input.is_nested()) {
-      auto nested_size = at::native::get_nested_sizes(input);
+    if (input.is_nested() &&
+        !input.unsafeGetTensorImpl()->is_python_dispatch()) {
+      auto nested_size = input._nested_tensor_size();
       return MetadataShape{c10::in_place_type<at::Tensor>, nested_size};
     }
     return MetadataShape{c10::in_place_type<SymIntSmallVec>, input.sym_sizes()};
@@ -171,6 +188,7 @@ struct InputMetadata {
   MetadataShape shape_;
   c10::Stream stream_ = c10::Stream(c10::Stream::Default::DEFAULT, device());
   bool is_tensor_subclass_ = false;
+  bool is_nested_ = false;
   bool was_default_constructed_ = true;
 };
 } // namespace autograd
torch/csrc/autograd/variable.cpp

@@ -685,7 +685,8 @@ const std::shared_ptr<torch::autograd::Node>& VariableHooks::grad_fn(
           view_info.base_.options(),
           self.sym_sizes(), // Note: sizes(), not base_.sizes(), is
                             // intentional
-          self.unsafeGetTensorImpl()->is_python_dispatch());
+          self.unsafeGetTensorImpl()->is_python_dispatch(),
+          self.is_nested());
       diff_view_meta->grad_fn_ = std::move(fn);
     }
     diff_view_meta->set_attr_version(current_version);
torch/nested/_internal/__init__.py (new file, 0 lines)

torch/nested/_internal/nested_tensor.py (new file, 145 lines)
@@ -0,0 +1,145 @@
from typing import Tuple

import torch
from torch._C import DispatchKey, DispatchKeySet
from torch.utils.weak import WeakTensorKeyDictionary
from typing import *  # noqa: F403

_tensor_id_counter = 0
_tensor_id_registry = WeakTensorKeyDictionary()


def get_tensor_id(tensor):
    global _tensor_id_counter
    if tensor not in _tensor_id_registry:
        _tensor_id_registry[tensor] = _tensor_id_counter
        _tensor_id_counter += 1
    return torch._C._get_singleton_int(_tensor_id_registry[tensor])
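
A hedged sketch of the contract here: the registry weakly keys on the tensor object, so the same offsets tensor always yields the same singleton SymInt, while a different tensor (even with equal values) gets a fresh id:

offsets = torch.tensor([0, 2, 5, 9])
j0 = get_tensor_id(offsets)           # singleton SymInt for this object
j0_again = get_tensor_id(offsets)     # same object -> same singleton int
j1 = get_tensor_id(offsets.clone())   # equal values, new object -> new id
# When `offsets` dies, WeakTensorKeyDictionary drops its entry.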


class NestedTensor(torch.Tensor):
    _values: torch.Tensor  # type: ignore[assignment]
    _offsets: torch.Tensor
    _size: Tuple[int, int, int]

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, values, offsets, **kwargs):
        ks = DispatchKeySet(DispatchKey.NestedTensor)
        ks = ks.add(DispatchKey.AutogradNestedTensor)
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            (0,),
            (0,),
            0,
            torch.contiguous_format,
            values.dtype,
            values.layout,
            values.device,
            False,
            False,
            "sizes",
            False,
            False,
            ks,
        )
        # TODO: why does values require grad?
        # if r.requires_grad:
        #     raise ValueError(
        #         "buffer should not require grad when constructing NestedTensor")
        r._values = values.detach() if values.requires_grad else values
        return r

    def __init__(self, values, offsets, **kwargs):
        super().__init__()
        # Only support jagged for now.
        assert offsets is not None
        assert offsets.ndim == 1
        assert not isinstance(values, NestedTensor)
        assert values.ndim == 2

        # In a later PR, we'll need to accept an additional size argument
        # to handle dynamic shapes.
        ragged_dim = get_tensor_id(offsets)
        D = values.shape[1]
        B = offsets.shape[0] - 1
        self._size = (B, ragged_dim, D)
        self._offsets = offsets
        return
    def values(self):
        return self._values

    def offsets(self):
        return self._offsets

    def __repr__(self):
        # We should implement this in torch/_tensor_str.py instead
        grad_fn_str = (
            f", requires_grad={self.requires_grad}" if self.requires_grad else ""
        )
        if self.grad_fn:
            grad_fn_str = f", grad_fn={self.grad_fn}"
        return f"NestedTensor(size={self._size}, offsets={self._offsets}{grad_fn_str})"

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = {} if kwargs is None else kwargs

        # Lazy import to avoid circular dependency
        from .ops import lookup_jagged

        fn = lookup_jagged(func, *args, **kwargs)
        if fn is not None:
            return fn(*args, **kwargs)

        raise NotImplementedError
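
So every aten call on the subclass funnels through lookup_jagged (defined in ops.py below): pointwise ops are handled generically, anything else must be registered in JAGGED_OPS_TABLE, and unregistered ops fail loudly. A hedged illustration:

nt, _ = jagged_from_list([torch.randn(2, 3), torch.randn(4, 3)], None)
nt.cos()              # pointwise tag -> jagged_unary_pointwise
nt.dim()              # registered in JAGGED_OPS_TABLE -> 3
torch.matmul(nt, nt)  # nothing registered for it -> NotImplementedError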


# Not actually a view!
class ViewBufferFromNested(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: NestedTensor):  # type: ignore[override]
        ctx.save_for_backward(x.offsets())
        return x.values()

    @staticmethod
    def backward(ctx, gO: torch.Tensor):  # type: ignore[override]
        (offsets,) = ctx.saved_tensors
        return NestedTensor(gO, offsets=offsets)


# Not actually a view!
class ViewNestedFromBuffer(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values: torch.Tensor, offsets: torch.Tensor):  # type: ignore[override]
        return NestedTensor(values, offsets=offsets)

    @staticmethod
    def backward(ctx, gO: NestedTensor):  # type: ignore[override]
        return gO.values(), None, None


# Need to make it obvious that users should be passing in offsets
def jagged_from_list(
    tensors: Sequence[torch.Tensor], offsets: Optional[torch.Tensor]
) -> Tuple[NestedTensor, torch.Tensor]:
    """Constructs a NestedTensor backed by jagged layout from a list of tensors"""
    assert len(set(t.dtype for t in tensors)) == 1  # noqa: C401
    assert len(set(t.device for t in tensors)) == 1  # noqa: C401
    assert all(t.ndim == 2 for t in tensors)
    assert len(set(t.shape[1] for t in tensors)) == 1  # noqa: C401

    lengths = torch.tensor([t.shape[0] for t in tensors])
    _offsets = torch.cat([torch.tensor([0]), lengths.cumsum(0)])
    if offsets is not None:
        assert torch.all(offsets == _offsets).item()
    else:
        offsets = _offsets

    return ViewNestedFromBuffer.apply(torch.cat(tensors, dim=0), offsets), offsets  # type: ignore[call-overload]


def buffer_from_jagged(jagged):
    return ViewBufferFromNested.apply(jagged)
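
To make the layout concrete, a worked example using the shapes from the tests: components with 2, 3, and 4 rows give lengths (2, 3, 4), so the offsets are the cumulative boundaries (0, 2, 5, 9) and the packed buffer is (9, 3):

a, b, c = torch.randn(2, 3), torch.randn(3, 3), torch.randn(4, 3)
nt, offsets = jagged_from_list([a, b, c], None)
# offsets == tensor([0, 2, 5, 9]); component i is
# nt.values()[offsets[i]:offsets[i + 1]]
buf = buffer_from_jagged(nt)  # (9, 3), differentiable w.r.t. a, b, c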

torch/nested/_internal/ops.py (new file, 183 lines)
@@ -0,0 +1,183 @@
import functools

import torch
from .nested_tensor import NestedTensor
from typing import *  # noqa: F403

__all__: List[Any] = []

JAGGED_OPS_TABLE: Dict[Any, Any] = {}


def check_schema(schema_str: str, func, *args, **kwargs) -> None:
    named_arg_types = schema_str.split(", ")
    num_optional_args = sum([x.endswith("?") for x in named_arg_types])
    min_args = len(named_arg_types) - num_optional_args

    if not (len(args) >= min_args and len(args) <= len(named_arg_types)):
        raise ValueError(
            f"NestedTensor {func.__name__}({schema_str}): expected at least {min_args} "
            f"arguments and at most {len(named_arg_types)} arguments, but got: "
            f"{len(args)} arguments"
        )

    arg_type_check_fns = {
        "t": lambda x: isinstance(x, torch.Tensor) and not isinstance(x, NestedTensor),
        "jt": lambda x: isinstance(x, NestedTensor),
        "any": lambda x: True,
    }
    for i, named_arg_type in enumerate(named_arg_types):
        name, arg_type = named_arg_type.split(": ")
        is_optional = arg_type.endswith("?")
        normalized_arg_type = arg_type[:-1] if is_optional else arg_type
        if normalized_arg_type not in arg_type_check_fns.keys():
            raise AssertionError(f"Unknown arg type: {normalized_arg_type}")

        if i >= len(args):
            if not is_optional:
                raise ValueError(
                    f"NestedTensor {func.__name__}({schema_str}) "
                    f"missing required argument: {name}"
                )
            continue

        if not arg_type_check_fns[normalized_arg_type](args[i]):
            raise ValueError(
                f"NestedTensor {func.__name__}({schema_str}): {name} should be of "
                f"type {arg_type}, but got: {type(args[i])}"
            )
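
The schema mini-language above is just comma-separated "name: type" pairs, where the types are "t" (plain tensor), "jt" (jagged NestedTensor), or "any", with a trailing "?" marking optionals. A hedged illustration (nt, t, and func stand for a NestedTensor, a plain tensor, and any aten op; they are not defined in this file):

check_schema("self: jt, memory_format: any?", func, nt)  # ok: optional omitted
check_schema("self: jt, memory_format: any?", func, nt, torch.preserve_format)  # ok
check_schema("lhs: jt, rhs: jt", func, nt, t)  # ValueError: rhs should be of type jt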


def check_ragged_dim_same(
    func, a: NestedTensor, a_name: str, b: NestedTensor, b_name: str
) -> None:
    # Calling into .shape here
    assert len(a._size) == 3, "NestedTensor must be [B, *, D]"
    if a._size[1] != b._size[1]:
        raise RuntimeError(
            f"NestedTensor {func.__name__}: expected {a_name} and {b_name} to have the "
            "same exact offsets tensor."
        )


def register_func(tables, aten_ops, schema_str):
    if not isinstance(aten_ops, list):
        aten_ops = [aten_ops]
    if not isinstance(tables, list):
        tables = [tables]

    def wrapper(func):
        for aten_op in aten_ops:

            def get_inner(aten_op):
                def inner(*args, **kwargs):
                    check_schema(schema_str, func, *args, **kwargs)
                    return func(aten_op, *args, **kwargs)

                return inner

            for table in tables:
                table[aten_op] = get_inner(aten_op)
        return func

    return wrapper


register_jagged_func = functools.partial(register_func, JAGGED_OPS_TABLE)
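
register_func is a decorator factory: it wraps the handler with schema validation, binds the concrete aten overload via get_inner, and files the result in each table. A hypothetical registration (clone is not part of this diff) just to show the shape of the API; extract_kwargs, defined below, propagates the offsets:

@register_jagged_func(torch.ops.aten.clone.default, "self: jt, memory_format: any?")
def clone_default(func, *args, **kwargs):
    # func is the bound aten overload; run it on the packed values and rewrap.
    return NestedTensor(func(args[0].values(), **kwargs), **extract_kwargs(args[0]))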


def lookup_jagged(func, *args, **kwargs) -> Optional[Callable]:
    if torch.Tag.pointwise in func.tags:
        # Assume there aren't additional tensors that aren't the "unary/binary" args
        num_tensor_args = sum([isinstance(x, torch.Tensor) for x in args])
        if num_tensor_args == 1:
            return functools.partial(jagged_unary_pointwise, func)
        elif num_tensor_args == 2:
            check_schema("lhs: jt, rhs: jt", func, *args, **kwargs)
            return functools.partial(jagged_binary_pointwise, func)
        else:
            return None
    return JAGGED_OPS_TABLE.get(func, None)


def extract_kwargs(arg):
    kwargs = {
        "offsets": arg.offsets(),
    }
    return kwargs


def jagged_unary_pointwise(func, *args, **kwargs):
    return NestedTensor(func(args[0].values(), **kwargs), **extract_kwargs(args[0]))


def jagged_binary_pointwise(func, *args, **kwargs):
    check_ragged_dim_same(func, args[0], "lhs", args[1], "rhs")
    return NestedTensor(
        func(args[0].values(), args[1].values(), **kwargs), **extract_kwargs(args[0])
    )


@register_jagged_func(
    [
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.sym_numel.default,
    ],
    "self: jt",
)
def tensor_attr_supported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.is_non_overlapping_and_dense.default:
        return False

    if func == torch.ops.aten.sym_size.default:
        return args[0]._size

    if func == torch.ops.aten.dim.default:
        return 3

    if func == torch.ops.aten.sym_numel.default:
        return args[0].values().numel()


@register_jagged_func(
    [
        torch.ops.aten.size.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.is_contiguous.default,
        torch.ops.aten.is_contiguous.memory_format,
        torch.ops.aten.sym_storage_offset.default,
    ],
    "self: jt, memory_format: any?",
)
def tensor_attr_unsupported_getter(func, *args, **kwargs):
    if func == torch.ops.aten.size.default:
        raise RuntimeError(
            "NestedTensors do not support directly calling torch.ops.aten.size; "
            "please use `nested_tensor.size()` instead."
        )

    raise RuntimeError(
        "NestedTensors do not support directly querying strides, "
        "storage_offset, or contiguity."
    )


@register_jagged_func(torch.ops.aten.linear.default, "input: jt, weight: t, bias: t?")
def linear_default(func, *args, **kwargs):
    # Note: multiplies by weight directly (no transpose), so weight is
    # expected as (in_features, out_features) here.
    values = torch.mm(args[0].values(), args[1])
    if len(args) == 3:
        values += args[2]
    return NestedTensor(values, **extract_kwargs(args[0]))


@register_jagged_func(
    torch.ops.aten.linear_backward.default,
    "self: jt, grad_output: jt, weight: t, output_mask: any",
)
def linear_backward_default(func, *args, **kwargs):
    check_ragged_dim_same(func, args[0], "self", args[1], "grad_output")
    ds = NestedTensor(torch.mm(args[1].values(), args[2].T), **extract_kwargs(args[1]))
    dw = torch.mm(args[0].values().T, args[1].values())
    db = None  # NYI: gradient for bias, need to reduce over ragged dim
    return (ds, dw, db)
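
One wrinkle worth noting from linear_default above: this jagged implementation multiplies by the weight directly rather than by its transpose, so callers pass weight as (in_features, out_features), which is why test_linear uses a (3, 4) weight with inputs whose trailing dim is 3. A hedged sketch:

nt, _ = jagged_from_list([torch.randn(2, 3), torch.randn(4, 3)], None)
w = torch.randn(3, 4)                    # (in_features, out_features)
out = torch.nn.functional.linear(nt, w)  # values (6, 3) @ (3, 4) -> (6, 4)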

torchgen/model.py

@@ -84,6 +84,7 @@ class DispatchKey(Enum):
     Sparse = auto()
     SparseCsrCPU = auto()
     SparseCsrCUDA = auto()
+    NestedTensor = auto()

     Python = auto()
     FuncTorchDynamicLayerBackMode = auto()