mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
[NT] Backward support for broadcasting binary ops (#112519)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112519 Approved by: https://github.com/jbschlosser ghstack dependencies: #113031
This commit is contained in:
committed by
PyTorch MergeBot
parent
d5007d8d8e
commit
c2084da14a
@ -31,7 +31,11 @@ from torch.testing._internal.common_utils import (
|
|||||||
TestCase,
|
TestCase,
|
||||||
)
|
)
|
||||||
|
|
||||||
from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged
|
from torch.nested._internal.nested_tensor import (
|
||||||
|
buffer_from_jagged,
|
||||||
|
jagged_from_list,
|
||||||
|
NestedTensor,
|
||||||
|
)
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
||||||
# Tests are ported from pytorch/nestedtensor.
|
# Tests are ported from pytorch/nestedtensor.
|
||||||
@ -2897,6 +2901,16 @@ class TestNestedTensorSubclass(TestCase):
|
|||||||
unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
|
unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# TODO: consolidate with the below
|
||||||
|
def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
|
||||||
|
Ds = nested_size[1:]
|
||||||
|
out = []
|
||||||
|
for s in nested_size[0]:
|
||||||
|
out.append(
|
||||||
|
torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64)
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
|
||||||
def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):
|
def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):
|
||||||
|
|
||||||
def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
|
def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
|
||||||
@ -3032,6 +3046,83 @@ class TestNestedTensorSubclass(TestCase):
|
|||||||
):
|
):
|
||||||
torch.split(nt, [1, 2], 1)
|
torch.split(nt, [1, 2], 1)
|
||||||
|
|
||||||
|
def test_binary_pointwise_broadcasting(self, device):
|
||||||
|
# (B, j0, 3, 4)
|
||||||
|
ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True)
|
||||||
|
# (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
|
||||||
|
# (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
|
||||||
|
# (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
|
||||||
|
# Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
|
||||||
|
t_sizes = (
|
||||||
|
(4,),
|
||||||
|
(1, 4),
|
||||||
|
(3, 1),
|
||||||
|
(1, 3, 1),
|
||||||
|
(1, 1, 1, 4),
|
||||||
|
# (1, 1, 1, 1, 4), (unsupported today)
|
||||||
|
)
|
||||||
|
|
||||||
|
def grad_test_func(t, *ts):
|
||||||
|
nt, _ = jagged_from_list(ts, None)
|
||||||
|
out = nt + t
|
||||||
|
return buffer_from_jagged(out)
|
||||||
|
|
||||||
|
for t_size in t_sizes:
|
||||||
|
t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
|
||||||
|
gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
|
||||||
|
|
||||||
|
@parametrize("keepdim", [False, True])
|
||||||
|
def test_sum_int_DimList(self, device, keepdim):
|
||||||
|
# (B, j0, 3, 4)
|
||||||
|
ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True)
|
||||||
|
|
||||||
|
# Check shape correctness
|
||||||
|
reduce_dims = (
|
||||||
|
# dims, expected shape, expected keepdim shape
|
||||||
|
# j0 is represented as None
|
||||||
|
((0, 1), (3, 4), (1, 1, 3, 4)),
|
||||||
|
((1, 2), None, None),
|
||||||
|
((2, 3), (3, None), (3, None, 1, 1)),
|
||||||
|
((0, 1, 3), (3,), (1, 1, 3, 1)),
|
||||||
|
((0, 1, 2), (4,), (1, 1, 1, 4)),
|
||||||
|
((0, 1, 2, 3), tuple(), (1, 1, 1, 1)),
|
||||||
|
)
|
||||||
|
for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
|
||||||
|
if (0 in rd) ^ (1 in rd):
|
||||||
|
with self.assertRaisesRegex(
|
||||||
|
RuntimeError,
|
||||||
|
"applying over the ragged dimension, but not the batch dimension"):
|
||||||
|
nt, _ = jagged_from_list(ts, None)
|
||||||
|
out = torch.sum(nt, dim=rd, keepdim=keepdim)
|
||||||
|
continue
|
||||||
|
|
||||||
|
nt, _ = jagged_from_list(ts, None)
|
||||||
|
out = torch.sum(nt, dim=rd, keepdim=keepdim)
|
||||||
|
ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
|
||||||
|
self.assertEqual(len(out.shape), len(ref_shape))
|
||||||
|
for o, r in zip(out.shape, ref_shape):
|
||||||
|
if r is not None:
|
||||||
|
self.assertEqual(o, r)
|
||||||
|
else:
|
||||||
|
self.assertTrue(isinstance(o, torch.SymInt))
|
||||||
|
|
||||||
|
# Check values correctness
|
||||||
|
# raggedness not reduced
|
||||||
|
nt, _ = jagged_from_list(ts, None)
|
||||||
|
out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
|
||||||
|
out_ref = torch.sum(nt.values(), dim=(1, 2))
|
||||||
|
self.assertIsInstance(out, NestedTensor)
|
||||||
|
# flatten to avoid having to replicate unsqueeze logic depending on keepdim
|
||||||
|
self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
|
||||||
|
|
||||||
|
# raggedness reduced away
|
||||||
|
nt, _ = jagged_from_list(ts, None)
|
||||||
|
out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
|
||||||
|
out_ref = torch.sum(nt.values(), dim=(0,))
|
||||||
|
self.assertNotIsInstance(out, NestedTensor)
|
||||||
|
self.assertTrue(torch.allclose(out, out_ref))
|
||||||
|
|
||||||
|
|
||||||
@dtypes(torch.float, torch.double, torch.half)
|
@dtypes(torch.float, torch.double, torch.half)
|
||||||
@parametrize("requires_grad", [False, True])
|
@parametrize("requires_grad", [False, True])
|
||||||
@parametrize("weights_only", [False, True])
|
@parametrize("weights_only", [False, True])
|
||||||
|
@ -16,6 +16,14 @@ MetadataShape compute_variant_shape(const at::Tensor& input) {
|
|||||||
return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
|
return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool is_python_dispatch(const at::Tensor& tensor) {
|
||||||
|
return tensor.unsafeGetTensorImpl()->is_python_dispatch();
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_cpp_nested_tensor(const at::Tensor& tensor) {
|
||||||
|
return tensor.is_nested() && !is_python_dispatch(tensor);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
InputMetadata::InputMetadata(
|
InputMetadata::InputMetadata(
|
||||||
@ -36,7 +44,7 @@ InputMetadata::InputMetadata(const at::Tensor& t)
|
|||||||
: InputMetadata(
|
: InputMetadata(
|
||||||
t.options(),
|
t.options(),
|
||||||
compute_variant_shape(t),
|
compute_variant_shape(t),
|
||||||
t.unsafeGetTensorImpl()->is_python_dispatch(),
|
is_python_dispatch(t),
|
||||||
t.is_nested()) {}
|
t.is_nested()) {}
|
||||||
|
|
||||||
at::Tensor InputMetadata::zeros_like() const {
|
at::Tensor InputMetadata::zeros_like() const {
|
||||||
@ -46,7 +54,9 @@ at::Tensor InputMetadata::zeros_like() const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
|
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
|
||||||
check_nestedness_same(grad);
|
if (!is_nestedness_same(grad)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
if (is_cpp_nested_tensor()) {
|
if (is_cpp_nested_tensor()) {
|
||||||
return grad._nested_tensor_size().is_same_size(shape_as_tensor());
|
return grad._nested_tensor_size().is_same_size(shape_as_tensor());
|
||||||
}
|
}
|
||||||
@ -54,19 +64,15 @@ bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
|
|||||||
}
|
}
|
||||||
|
|
||||||
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
|
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
|
||||||
// Currently NestedTensors are not expandable. If this support is added then
|
if (!maybe_expandable_to(grad)) {
|
||||||
// updates to reduce_grad will be needed
|
return false;
|
||||||
check_nestedness_same(grad);
|
}
|
||||||
return grad.is_nested()
|
return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
|
||||||
? false
|
|
||||||
: at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
|
at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
|
||||||
// Currently reduce_grad is only called if is_expandable_to_shape returns
|
// reduce_grad should only be called if is_expandable_to_shape returns true.
|
||||||
// true For nested tensors this always returns False, so this check
|
TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
|
||||||
// shouldn't fail
|
|
||||||
TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_)
|
|
||||||
return at::sum_to(std::move(grad), shape_as_dim_vector());
|
return at::sum_to(std::move(grad), shape_as_dim_vector());
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -75,7 +81,7 @@ std::stringstream InputMetadata::incompatible_shape_error_message(
|
|||||||
const at::Tensor& grad) const {
|
const at::Tensor& grad) const {
|
||||||
std::stringstream ss{};
|
std::stringstream ss{};
|
||||||
ss << "invalid gradient at index " << index << " - got ";
|
ss << "invalid gradient at index " << index << " - got ";
|
||||||
if (grad.is_nested() && !grad.unsafeGetTensorImpl()->is_python_dispatch()) {
|
if (::torch::autograd::is_cpp_nested_tensor(grad)) {
|
||||||
ss << grad._nested_tensor_size();
|
ss << grad._nested_tensor_size();
|
||||||
} else {
|
} else {
|
||||||
ss << grad.sym_sizes();
|
ss << grad.sym_sizes();
|
||||||
@ -106,21 +112,34 @@ SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
|
|||||||
return std::get<SymIntSmallVec>(shape_);
|
return std::get<SymIntSmallVec>(shape_);
|
||||||
}
|
}
|
||||||
|
|
||||||
void InputMetadata::check_nestedness_same(const at::Tensor& grad) const {
|
bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
|
||||||
bool grad_is_subclass = grad.unsafeGetTensorImpl()->is_python_dispatch();
|
return (
|
||||||
bool grad_is_nested = grad.is_nested();
|
grad.is_nested() == is_nested_ &&
|
||||||
bool grad_is_cpp_nested = grad_is_nested && !grad_is_subclass;
|
::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
|
||||||
TORCH_CHECK(
|
|
||||||
grad_is_cpp_nested == is_cpp_nested_tensor() &&
|
|
||||||
grad_is_nested == is_nested_,
|
|
||||||
"grad and the input wrt the gradient that is being computed for need to be "
|
|
||||||
"either both nested or both non-nested tensors. Also note that nested "
|
|
||||||
"tensors with different layouts do not compose currently.");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
at::Tensor InputMetadata::shape_as_tensor() const {
|
at::Tensor InputMetadata::shape_as_tensor() const {
|
||||||
return std::get<at::Tensor>(shape_);
|
return std::get<at::Tensor>(shape_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
|
||||||
|
// This is the initial step to determine whether or not the tensor represented
|
||||||
|
// by input_metadata is expandable to grad based on is-nestedness information
|
||||||
|
// alone. If this function returns true, then is_expandable_to_shape will be
|
||||||
|
// called. We support the following 3 types of expansion:
|
||||||
|
bool grad_is_nested = grad.is_nested();
|
||||||
|
if (!is_nested_ && !grad_is_nested) {
|
||||||
|
// Normal case (no NestedTensors are involved)
|
||||||
|
// (1) plain Tensor -> plain Tensor
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
// (2) python NT -> python NT
|
||||||
|
// (3) plain Tensor -> python NT
|
||||||
|
return (
|
||||||
|
grad_is_nested && is_python_dispatch(grad) &&
|
||||||
|
(!is_nested_ || is_tensor_subclass_));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace autograd
|
} // namespace autograd
|
||||||
} // namespace torch
|
} // namespace torch
|
||||||
|
@ -98,7 +98,8 @@ struct TORCH_API InputMetadata {
|
|||||||
|
|
||||||
private:
|
private:
|
||||||
at::Tensor shape_as_tensor() const;
|
at::Tensor shape_as_tensor() const;
|
||||||
void check_nestedness_same(const at::Tensor& grad) const;
|
bool is_nestedness_same(const at::Tensor& grad) const;
|
||||||
|
bool maybe_expandable_to(const at::Tensor& grad) const;
|
||||||
|
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
|
||||||
const at::TensorOptions options_;
|
const at::TensorOptions options_;
|
||||||
|
@ -177,7 +177,8 @@ Example::
|
|||||||
|
|
||||||
from torch.nested._internal.nested_tensor import jagged_from_list
|
from torch.nested._internal.nested_tensor import jagged_from_list
|
||||||
|
|
||||||
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
|
with torch.no_grad():
|
||||||
|
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
|
||||||
|
|
||||||
nt.requires_grad_(requires_grad)
|
nt.requires_grad_(requires_grad)
|
||||||
if pin_memory:
|
if pin_memory:
|
||||||
|
@ -61,7 +61,7 @@ class NestedTensor(torch.Tensor):
|
|||||||
torch.jagged,
|
torch.jagged,
|
||||||
values.device,
|
values.device,
|
||||||
False,
|
False,
|
||||||
False,
|
kwargs.get("requires_grad", False),
|
||||||
"sizes",
|
"sizes",
|
||||||
False,
|
False,
|
||||||
True, # dispatch_layout
|
True, # dispatch_layout
|
||||||
|
@ -11,6 +11,13 @@ __all__: List[Any] = []
|
|||||||
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
|
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# Simplifying assumption: we assume that the batch dim is always the left-most
|
||||||
|
# dim, and the ragged dim is always the second dim.
|
||||||
|
def _outer_to_inner_dim(ndim, dim):
|
||||||
|
assert dim >= 0 and dim < ndim
|
||||||
|
return 0 if dim < 2 else dim - 1
|
||||||
|
|
||||||
|
|
||||||
def _wrap_jagged_dim(ndim, dim, op_name):
|
def _wrap_jagged_dim(ndim, dim, op_name):
|
||||||
from torch._prims_common import canonicalize_dims
|
from torch._prims_common import canonicalize_dims
|
||||||
|
|
||||||
@ -19,7 +26,29 @@ def _wrap_jagged_dim(ndim, dim, op_name):
|
|||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
|
f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
|
||||||
)
|
)
|
||||||
return wrapped - 1
|
return _outer_to_inner_dim(ndim, wrapped)
|
||||||
|
|
||||||
|
|
||||||
|
def _wrap_jagged_dims(ndim, dims, op_name):
|
||||||
|
# ex: (2, 3, 4) -> (1, 2, 3)
|
||||||
|
# ex: (0, 1, 4) -> (0, 3)
|
||||||
|
from torch._prims_common import canonicalize_dims
|
||||||
|
|
||||||
|
wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
|
||||||
|
# This logic needs to be done after we canonicalize dims but before we
|
||||||
|
# map to inner dims so we can print a nicer error message.
|
||||||
|
zero_in_dims = 0 in wrapped_dims
|
||||||
|
one_in_dims = 1 in wrapped_dims
|
||||||
|
if zero_in_dims ^ one_in_dims:
|
||||||
|
apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
|
||||||
|
" dimension is not supported for NestedTensor"
|
||||||
|
)
|
||||||
|
return (
|
||||||
|
tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
|
||||||
|
zero_in_dims,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
|
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
|
||||||
@ -79,6 +108,30 @@ def raggedness_matches(nt, size):
|
|||||||
return list(nt._size[:end]) == list(size[:end])
|
return list(nt._size[:end]) == list(size[:end])
|
||||||
|
|
||||||
|
|
||||||
|
def squeeze_leading_ones(t):
|
||||||
|
# Note: [ Squeezing leading ones ]
|
||||||
|
#
|
||||||
|
# Squeeze leading ones from t.
|
||||||
|
#
|
||||||
|
# We want:
|
||||||
|
# (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
|
||||||
|
# (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) (not yet supported)
|
||||||
|
#
|
||||||
|
# 1) Squeeze extra ones and grab values from NT
|
||||||
|
# (1, 1, ?, ?) -> (?, ?) and (sum(*), ?, ?) -> (B, j0, ?, ?)
|
||||||
|
# 2) Do dense broadcasting:
|
||||||
|
# (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
|
||||||
|
# 3) Construct nested tensor
|
||||||
|
# (sum(*), ?, ?) -> (B, j0, ?, ?)
|
||||||
|
#
|
||||||
|
# If unsqueezing on the 0th dim becomes supported, we would unsqueeze
|
||||||
|
# at step (4) and we would need to update this function to record how
|
||||||
|
# many ones we unsqueezed.
|
||||||
|
while t.shape[0] == 1:
|
||||||
|
t = t.squeeze(0)
|
||||||
|
return t
|
||||||
|
|
||||||
|
|
||||||
def register_func(tables, aten_ops, schema_str):
|
def register_func(tables, aten_ops, schema_str):
|
||||||
if not isinstance(aten_ops, list):
|
if not isinstance(aten_ops, list):
|
||||||
aten_ops = [aten_ops]
|
aten_ops = [aten_ops]
|
||||||
@ -163,15 +216,17 @@ def jagged_binary_pointwise(func, *args, **kwargs):
|
|||||||
# === Handle broadcasting across the batch / ragged dims ===
|
# === Handle broadcasting across the batch / ragged dims ===
|
||||||
|
|
||||||
# Easy case: take advantage of pre-existing broadcasting logic
|
# Easy case: take advantage of pre-existing broadcasting logic
|
||||||
# when NT dim > non-NT dim
|
# ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
|
||||||
# ex: (B, j0, D_0, D_1) + (D_0, D_1) -> (B, j0, D_0, D_1)
|
# ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
|
||||||
# ex: (B, j0, D_0, D_1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
|
# ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
|
||||||
# ex: (B, j0, 1, 1) + (D_0, D_1) -> (B, j0, D_0, D_1)
|
nt, t = (a, b) if a_is_nt else (b, a)
|
||||||
# ex: (B, j0, 1, 1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
|
# See Note: [ Squeezing leading ones ]
|
||||||
if (a_is_nt and a.dim() > b.dim()) or (not a_is_nt and b.dim() > a.dim()):
|
if t.dim() > nt.dim():
|
||||||
arg1 = a._values if a_is_nt else a
|
raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
|
||||||
arg2 = b._values if not a_is_nt else b
|
t_squeezed = squeeze_leading_ones(t)
|
||||||
return NestedTensor(func(arg1, arg2, *args[2:], **kwargs), **extracted_kwargs)
|
if nt.dim() >= t_squeezed.dim() + 2:
|
||||||
|
lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
|
||||||
|
return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
|
||||||
|
|
||||||
# Harder case: do manual broadcasting over unbound components
|
# Harder case: do manual broadcasting over unbound components
|
||||||
# when NT dim == non-NT dim
|
# when NT dim == non-NT dim
|
||||||
@ -598,6 +653,31 @@ def is_same_size_default(func, *args, **kwargs):
|
|||||||
return args[0]._size == args[1]._size
|
return args[0]._size == args[1]._size
|
||||||
|
|
||||||
|
|
||||||
|
@register_jagged_func(
|
||||||
|
torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
|
||||||
|
)
|
||||||
|
def sum_dim_IntList(func, *args, **kwargs):
|
||||||
|
# sum_dim_IntList can produce a NT or a T depending on whether the ragged dims
|
||||||
|
# are reduced away.
|
||||||
|
_, new_kwargs = normalize_function(
|
||||||
|
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
|
||||||
|
)
|
||||||
|
inp = new_kwargs.pop("input")
|
||||||
|
assert inp._ragged_idx == 1
|
||||||
|
new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
|
||||||
|
inp.dim(), new_kwargs["dim"], "sum"
|
||||||
|
)
|
||||||
|
|
||||||
|
if not ragged_reduced_away:
|
||||||
|
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
|
||||||
|
else:
|
||||||
|
# Don't wrap because we reduced away the raggedness
|
||||||
|
out = func(inp._values, **new_kwargs)
|
||||||
|
if new_kwargs["keepdim"]:
|
||||||
|
out = out.unsqueeze(0)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
@register_jagged_func(torch.ops.aten.transpose.int, "self: jt, dim0: any, dim1: any")
|
@register_jagged_func(torch.ops.aten.transpose.int, "self: jt, dim0: any, dim1: any")
|
||||||
def transpose_int(func, *args, **kwargs):
|
def transpose_int(func, *args, **kwargs):
|
||||||
_, new_kwargs = normalize_function(
|
_, new_kwargs = normalize_function(
|
||||||
|
Reference in New Issue
Block a user