Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[NT] Backward support for broadcasting binary ops (#112519)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112519
Approved by: https://github.com/jbschlosser
ghstack dependencies: #113031

Committed by: PyTorch MergeBot
Parent: d5007d8d8e
Commit: c2084da14a
@@ -31,7 +31,11 @@ from torch.testing._internal.common_utils import (
     TestCase,
 )
 
-from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged
+from torch.nested._internal.nested_tensor import (
+    buffer_from_jagged,
+    jagged_from_list,
+    NestedTensor,
+)
 
 import contextlib
 
 # Tests are ported from pytorch/nestedtensor.
@@ -2897,6 +2901,16 @@ class TestNestedTensorSubclass(TestCase):
             unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
         )
 
+    # TODO: consolidate with the below
+    def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
+        Ds = nested_size[1:]
+        out = []
+        for s in nested_size[0]:
+            out.append(
+                torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64)
+            )
+        return out
+
     def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):
 
         def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
@@ -3032,6 +3046,83 @@ class TestNestedTensorSubclass(TestCase):
         ):
             torch.split(nt, [1, 2], 1)
 
+    def test_binary_pointwise_broadcasting(self, device):
+        # (B, j0, 3, 4)
+        ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True)
+        # (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+        # (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+        # (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
+        # Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
+        t_sizes = (
+            (4,),
+            (1, 4),
+            (3, 1),
+            (1, 3, 1),
+            (1, 1, 1, 4),
+            # (1, 1, 1, 1, 4), (unsupported today)
+        )
+
+        def grad_test_func(t, *ts):
+            nt, _ = jagged_from_list(ts, None)
+            out = nt + t
+            return buffer_from_jagged(out)
+
+        for t_size in t_sizes:
+            t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
+            gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
+
+    @parametrize("keepdim", [False, True])
+    def test_sum_int_DimList(self, device, keepdim):
+        # (B, j0, 3, 4)
+        ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True)
+
+        # Check shape correctness
+        reduce_dims = (
+            # dims, expected shape, expected keepdim shape
+            # j0 is represented as None
+            ((0, 1), (3, 4), (1, 1, 3, 4)),
+            ((1, 2), None, None),
+            ((2, 3), (3, None), (3, None, 1, 1)),
+            ((0, 1, 3), (3,), (1, 1, 3, 1)),
+            ((0, 1, 2), (4,), (1, 1, 1, 4)),
+            ((0, 1, 2, 3), tuple(), (1, 1, 1, 1)),
+        )
+        for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
+            if (0 in rd) ^ (1 in rd):
+                with self.assertRaisesRegex(
+                        RuntimeError,
+                        "applying over the ragged dimension, but not the batch dimension"):
+                    nt, _ = jagged_from_list(ts, None)
+                    out = torch.sum(nt, dim=rd, keepdim=keepdim)
+                continue
+
+            nt, _ = jagged_from_list(ts, None)
+            out = torch.sum(nt, dim=rd, keepdim=keepdim)
+            ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
+            self.assertEqual(len(out.shape), len(ref_shape))
+            for o, r in zip(out.shape, ref_shape):
+                if r is not None:
+                    self.assertEqual(o, r)
+                else:
+                    self.assertTrue(isinstance(o, torch.SymInt))
+
+        # Check values correctness
+        # raggedness not reduced
+        nt, _ = jagged_from_list(ts, None)
+        out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
+        out_ref = torch.sum(nt.values(), dim=(1, 2))
+        self.assertIsInstance(out, NestedTensor)
+        # flatten to avoid having to replicate unsqueeze logic depending on keepdim
+        self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
+
+        # raggedness reduced away
+        nt, _ = jagged_from_list(ts, None)
+        out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
+        out_ref = torch.sum(nt.values(), dim=(0,))
+        self.assertNotIsInstance(out, NestedTensor)
+        self.assertTrue(torch.allclose(out, out_ref))
+
+
     @dtypes(torch.float, torch.double, torch.half)
     @parametrize("requires_grad", [False, True])
     @parametrize("weights_only", [False, True])
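For orientation, here is a minimal usage sketch of what `test_binary_pointwise_broadcasting` exercises end to end. This is not part of the diff: it assumes a build containing this commit, relies on the private helpers `jagged_from_list` and `buffer_from_jagged` imported above, and uses illustrative shapes only.

```python
import torch
from torch.nested._internal.nested_tensor import buffer_from_jagged, jagged_from_list

# Jagged (B, j0, 3, 4) nested tensor built from components of length 2, 3, 4.
ts = [torch.randn(n, 3, 4, dtype=torch.float64, requires_grad=True) for n in (2, 3, 4)]
# Dense operand of shape (1, 3, 1) that broadcasts over the batch/ragged dims.
t = torch.rand(1, 3, 1, dtype=torch.float64, requires_grad=True)

nt, _ = jagged_from_list(ts, None)        # (B, j0, 3, 4)
out = nt + t                              # broadcasting binary op on a jagged NT
buffer_from_jagged(out).sum().backward()  # backward now flows through the broadcast

print(t.grad.shape)      # torch.Size([1, 3, 1]) -- gradient reduced back to t's shape
print(ts[0].grad.shape)  # torch.Size([2, 3, 4])
```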
@@ -16,6 +16,14 @@ MetadataShape compute_variant_shape(const at::Tensor& input) {
   return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
 }
 
+bool is_python_dispatch(const at::Tensor& tensor) {
+  return tensor.unsafeGetTensorImpl()->is_python_dispatch();
+}
+
+bool is_cpp_nested_tensor(const at::Tensor& tensor) {
+  return tensor.is_nested() && !is_python_dispatch(tensor);
+}
+
 } // namespace
 
 InputMetadata::InputMetadata(
@@ -36,7 +44,7 @@ InputMetadata::InputMetadata(const at::Tensor& t)
     : InputMetadata(
           t.options(),
           compute_variant_shape(t),
-          t.unsafeGetTensorImpl()->is_python_dispatch(),
+          is_python_dispatch(t),
           t.is_nested()) {}
 
 at::Tensor InputMetadata::zeros_like() const {
@@ -46,7 +54,9 @@ at::Tensor InputMetadata::zeros_like() const {
 }
 
 bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
-  check_nestedness_same(grad);
+  if (!is_nestedness_same(grad)) {
+    return false;
+  }
   if (is_cpp_nested_tensor()) {
     return grad._nested_tensor_size().is_same_size(shape_as_tensor());
   }
@@ -54,19 +64,15 @@ bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
 }
 
 bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
-  // Currently NestedTensors are not expandable. If this support is added then
-  // updates to reduce_grad will be needed
-  check_nestedness_same(grad);
-  return grad.is_nested()
-      ? false
-      : at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
+  if (!maybe_expandable_to(grad)) {
+    return false;
+  }
+  return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
 }
 
 at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
-  // Currently reduce_grad is only called if is_expandable_to_shape returns
-  // true For nested tensors this always returns False, so this check
-  // shouldn't fail
-  TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_)
+  // reduce_grad should only be called if is_expandable_to_shape returns true.
+  TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
   return at::sum_to(std::move(grad), shape_as_dim_vector());
 }
 
@@ -75,7 +81,7 @@ std::stringstream InputMetadata::incompatible_shape_error_message(
     const at::Tensor& grad) const {
   std::stringstream ss{};
   ss << "invalid gradient at index " << index << " - got ";
-  if (grad.is_nested() && !grad.unsafeGetTensorImpl()->is_python_dispatch()) {
+  if (::torch::autograd::is_cpp_nested_tensor(grad)) {
    ss << grad._nested_tensor_size();
   } else {
     ss << grad.sym_sizes();
@@ -106,21 +112,34 @@ SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
   return std::get<SymIntSmallVec>(shape_);
 }
 
-void InputMetadata::check_nestedness_same(const at::Tensor& grad) const {
-  bool grad_is_subclass = grad.unsafeGetTensorImpl()->is_python_dispatch();
-  bool grad_is_nested = grad.is_nested();
-  bool grad_is_cpp_nested = grad_is_nested && !grad_is_subclass;
-  TORCH_CHECK(
-      grad_is_cpp_nested == is_cpp_nested_tensor() &&
-          grad_is_nested == is_nested_,
-      "grad and the input wrt the gradient that is being computed for need to be "
-      "either both nested or both non-nested tensors. Also note that nested "
-      "tensors with different layouts do not compose currently.");
+bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
+  return (
+      grad.is_nested() == is_nested_ &&
+      ::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
 }
 
 at::Tensor InputMetadata::shape_as_tensor() const {
   return std::get<at::Tensor>(shape_);
 }
 
+bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
+  // This is the initial step to determine whether or not the tensor represented
+  // by input_metadata is expandable to grad based on is-nestedness information
+  // alone. If this function returns true, then is_expandable_to_shape will be
+  // called. We support the following 3 types of expansion:
+  bool grad_is_nested = grad.is_nested();
+  if (!is_nested_ && !grad_is_nested) {
+    // Normal case (no NestedTensors are involved)
+    // (1) plain Tensor -> plain Tensor
+    return true;
+  } else {
+    // (2) python NT -> python NT
+    // (3) plain Tensor -> python NT
+    return (
+        grad_is_nested && is_python_dispatch(grad) &&
+        (!is_nested_ || is_tensor_subclass_));
+  }
+}
+
 } // namespace autograd
 } // namespace torch
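For readability, the decision in `maybe_expandable_to` can be paraphrased in Python as follows. This is exposition only, not code from the diff; the parameter names are stand-ins for the corresponding C++ fields and helpers (`is_nested_`, `is_tensor_subclass_`, `is_python_dispatch`).

```python
def maybe_expandable_to(input_is_nested, input_is_subclass,
                        grad_is_nested, grad_is_python_dispatch):
    # (1) plain Tensor -> plain Tensor: ordinary dense expansion rules apply.
    if not input_is_nested and not grad_is_nested:
        return True
    # (2) python-subclass NT -> python-subclass NT
    # (3) plain Tensor -> python-subclass NT (the dense operand of a broadcasted binary op)
    return grad_is_nested and grad_is_python_dispatch and (not input_is_nested or input_is_subclass)
```

In case (3), `reduce_grad` applies `at::sum_to` to a jagged gradient, which should ultimately dispatch to the new `sum.dim_IntList` jagged kernel registered later in this commit.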
@@ -98,7 +98,8 @@ struct TORCH_API InputMetadata {
 
  private:
   at::Tensor shape_as_tensor() const;
-  void check_nestedness_same(const at::Tensor& grad) const;
+  bool is_nestedness_same(const at::Tensor& grad) const;
+  bool maybe_expandable_to(const at::Tensor& grad) const;
 
   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
   const at::TensorOptions options_;
@@ -177,7 +177,8 @@ Example::
 
         from torch.nested._internal.nested_tensor import jagged_from_list
 
-        nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
+        with torch.no_grad():
+            nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
 
         nt.requires_grad_(requires_grad)
         if pin_memory:
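A small illustration of the user-facing effect of wrapping construction in `no_grad` (a sketch, assuming a build with this commit and that the jagged layout is available through the public `torch.nested.nested_tensor` factory): the returned tensor is a leaf, and its `requires_grad` flag is applied afterwards via `requires_grad_()`.

```python
import torch

a, b = torch.randn(2, 5), torch.randn(3, 5)
nt = torch.nested.nested_tensor([a, b], layout=torch.jagged, requires_grad=True)

print(nt.is_leaf)        # True: jagged_from_list ran under torch.no_grad()
print(nt.requires_grad)  # True: set by requires_grad_() after construction
```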
@@ -61,7 +61,7 @@ class NestedTensor(torch.Tensor):
             torch.jagged,
             values.device,
             False,
-            False,
+            kwargs.get("requires_grad", False),
             "sizes",
             False,
             True, # dispatch_layout
@@ -11,6 +11,13 @@ __all__: List[Any] = []
 JAGGED_OPS_TABLE: Dict[Any, Any] = {}
 
 
+# Simplifying assumption: we assume that the batch dim is always the left-most
+# dim, and the ragged dim is always the second dim.
+def _outer_to_inner_dim(ndim, dim):
+    assert dim >= 0 and dim < ndim
+    return 0 if dim < 2 else dim - 1
+
+
 def _wrap_jagged_dim(ndim, dim, op_name):
     from torch._prims_common import canonicalize_dims
 
@@ -19,7 +26,29 @@ def _wrap_jagged_dim(ndim, dim, op_name):
         raise RuntimeError(
             f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
         )
-    return wrapped - 1
+    return _outer_to_inner_dim(ndim, wrapped)
+
+
+def _wrap_jagged_dims(ndim, dims, op_name):
+    # ex: (2, 3, 4) -> (1, 2, 3)
+    # ex: (0, 1, 4) -> (0, 3)
+    from torch._prims_common import canonicalize_dims
+
+    wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
+    # This logic needs to be done after we canonicalize dims but before we
+    # map to inner dims so we can print a nicer error message.
+    zero_in_dims = 0 in wrapped_dims
+    one_in_dims = 1 in wrapped_dims
+    if zero_in_dims ^ one_in_dims:
+        apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
+        raise RuntimeError(
+            f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
+            " dimension is not supported for NestedTensor"
+        )
+    return (
+        tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
+        zero_in_dims,
+    )
 
 
 def check_schema(schema_str: str, func, *args, **kwargs) -> None:
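To see what the new dim-mapping helpers do, here is a short sketch; it assumes the private helpers are importable from `torch.nested._internal.ops` in a build containing this commit. Outer dims index the `(B, j0, D0, D1)` view of the nested tensor, while inner dims index the packed `values` buffer of shape `(sum(j0), D0, D1)`.

```python
from torch.nested._internal.ops import _outer_to_inner_dim, _wrap_jagged_dims

# Non-batch, non-ragged dims shift left by one when moving to the values buffer.
assert _outer_to_inner_dim(4, 2) == 1
assert _outer_to_inner_dim(4, 3) == 2

# Reducing over batch + ragged + one regular dim maps to inner dims (0, 2)
# and reports that the raggedness is reduced away.
assert _wrap_jagged_dims(4, (0, 1, 3), "sum") == ((0, 2), True)

# Reducing only regular dims keeps the raggedness.
assert _wrap_jagged_dims(4, (2, 3), "sum") == ((1, 2), False)
```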
@@ -79,6 +108,30 @@ def raggedness_matches(nt, size):
     return list(nt._size[:end]) == list(size[:end])
 
 
+def squeeze_leading_ones(t):
+    # Note: [ Squeezing leading ones ]
+    #
+    # Squeeze leading ones from t.
+    #
+    # We want:
+    #   (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)  (not yet supported)
+    #
+    # 1) Squeeze extra ones and grab values from NT
+    #    (1, 1, ?, ?) -> (?, ?)  and  (sum(*), ?, ?) -> (B, j0, ?, ?)
+    # 2) Do dense broadcasting:
+    #    (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
+    # 3) Construct nested tensor
+    #    (sum(*), ?, ?) -> (B, j0, ?, ?)
+    #
+    # If unsqueezing on the 0th dim becomes supported, we would unsqueeze
+    # at step (4) and we would need to update this function to record how
+    # many ones we unsqueezed.
+    while t.shape[0] == 1:
+        t = t.squeeze(0)
+    return t
+
+
 def register_func(tables, aten_ops, schema_str):
     if not isinstance(aten_ops, list):
         aten_ops = [aten_ops]
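A quick illustration of `squeeze_leading_ones` (a sketch only; the helper is private to this module):

```python
import torch
from torch.nested._internal.ops import squeeze_leading_ones

print(squeeze_leading_ones(torch.rand(1, 1, 3, 4)).shape)  # torch.Size([3, 4])
print(squeeze_leading_ones(torch.rand(2, 1, 3, 4)).shape)  # torch.Size([2, 1, 3, 4]) -- unchanged
```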
@@ -163,15 +216,17 @@ def jagged_binary_pointwise(func, *args, **kwargs):
     # === Handle broadcasting across the batch / ragged dims ===
 
     # Easy case: take advantage of pre-existing broadcasting logic
-    # when NT dim > non-NT dim
-    # ex: (B, j0, D_0, D_1) + (D_0, D_1) -> (B, j0, D_0, D_1)
-    # ex: (B, j0, D_0, D_1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
-    # ex: (B, j0, 1, 1) + (D_0, D_1) -> (B, j0, D_0, D_1)
-    # ex: (B, j0, 1, 1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
-    if (a_is_nt and a.dim() > b.dim()) or (not a_is_nt and b.dim() > a.dim()):
-        arg1 = a._values if a_is_nt else a
-        arg2 = b._values if not a_is_nt else b
-        return NestedTensor(func(arg1, arg2, *args[2:], **kwargs), **extracted_kwargs)
+    # ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
+    # ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
+    nt, t = (a, b) if a_is_nt else (b, a)
+    # See Note: [ Squeezing leading ones ]
+    if t.dim() > nt.dim():
+        raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
+    t_squeezed = squeeze_leading_ones(t)
+    if nt.dim() >= t_squeezed.dim() + 2:
+        lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
+        return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
 
     # Harder case: do manual broadcasting over unbound components
     # when NT dim == non-NT dim
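The "easy case" above lowers the broadcasted op directly onto the packed values buffer. A sketch of that equivalence (illustrative only; it relies on the private `_values` attribute and `jagged_from_list`):

```python
import torch
from torch.nested._internal.nested_tensor import jagged_from_list

ts = [torch.randn(n, 3, 4) for n in (2, 3, 4)]
t = torch.rand(1, 1, 3, 4)              # leading ones squeeze down to (3, 4)

nt, _ = jagged_from_list(ts, None)      # (B, j0, 3, 4)
out = nt + t                            # routed through jagged_binary_pointwise

# The result is just the dense op applied to the packed (sum(j0), 3, 4) buffer.
torch.testing.assert_close(out._values, nt._values + t.squeeze(0).squeeze(0))
```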
@@ -598,6 +653,31 @@ def is_same_size_default(func, *args, **kwargs):
     return args[0]._size == args[1]._size
 
 
+@register_jagged_func(
+    torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
+)
+def sum_dim_IntList(func, *args, **kwargs):
+    # sum_dim_IntList can produce a NT or a T depending on whether the ragged dims
+    # are reduced away.
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+    inp = new_kwargs.pop("input")
+    assert inp._ragged_idx == 1
+    new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
+        inp.dim(), new_kwargs["dim"], "sum"
+    )
+
+    if not ragged_reduced_away:
+        return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
+    else:
+        # Don't wrap because we reduced away the raggedness
+        out = func(inp._values, **new_kwargs)
+        if new_kwargs["keepdim"]:
+            out = out.unsqueeze(0)
+        return out
+
+
 @register_jagged_func(torch.ops.aten.transpose.int, "self: jt, dim0: any, dim1: any")
 def transpose_int(func, *args, **kwargs):
     _, new_kwargs = normalize_function(
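A sketch of the two output modes of the new kernel (illustrative only, again using the private helpers): keeping the ragged dim yields another jagged `NestedTensor`, while reducing over both the batch and ragged dims yields a plain dense tensor.

```python
import torch
from torch.nested._internal.nested_tensor import NestedTensor, jagged_from_list

ts = [torch.randn(n, 3, 4) for n in (2, 3, 4)]
nt, _ = jagged_from_list(ts, None)   # (B, j0, 3, 4)

out = torch.sum(nt, dim=(2, 3))      # ragged dim kept
print(isinstance(out, NestedTensor))             # True, shape (B, j0)

out = torch.sum(nt, dim=(0, 1))      # batch + ragged reduced away
print(isinstance(out, NestedTensor), out.shape)  # False torch.Size([3, 4])
```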