[NT] Backward support for broadcasting binary ops (#112519)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112519
Approved by: https://github.com/jbschlosser
ghstack dependencies: #113031
soulitzer
2023-11-06 15:22:51 -05:00
committed by PyTorch MergeBot
parent d5007d8d8e
commit c2084da14a
6 changed files with 229 additions and 37 deletions

@ -31,7 +31,11 @@ from torch.testing._internal.common_utils import (
TestCase,
)
from torch.nested._internal.nested_tensor import jagged_from_list, buffer_from_jagged
from torch.nested._internal.nested_tensor import (
buffer_from_jagged,
jagged_from_list,
NestedTensor,
)
import contextlib
# Tests are ported from pytorch/nestedtensor.
@ -2897,6 +2901,16 @@ class TestNestedTensorSubclass(TestCase):
unittest.mock.patch.object(torch._dynamo.config, "suppress_errors", False)
)
# TODO: consolidate with the below
def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
Ds = nested_size[1:]
out = []
for s in nested_size[0]:
out.append(
torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64)
)
return out
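# Editor's note (illustration, not part of this diff): for
# nested_size=((2, 3, 4), 3, 4) this helper returns tensors of shapes
# (2, 3, 4), (3, 3, 4) and (4, 3, 4), i.e. the components of a (B, j0, 3, 4)
# jagged NestedTensor with B=3 and ragged lengths 2, 3 and 4.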
def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):
def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
@ -3032,6 +3046,83 @@ class TestNestedTensorSubclass(TestCase):
):
torch.split(nt, [1, 2], 1)
def test_binary_pointwise_broadcasting(self, device):
# (B, j0, 3, 4)
ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True)
# (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
# (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
# (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
# Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
t_sizes = (
(4,),
(1, 4),
(3, 1),
(1, 3, 1),
(1, 1, 1, 4),
# (1, 1, 1, 1, 4), (unsupported today)
)
def grad_test_func(t, *ts):
nt, _ = jagged_from_list(ts, None)
out = nt + t
return buffer_from_jagged(out)
for t_size in t_sizes:
t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
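# Editor's illustration (not part of this diff), mirroring the test above:
# a dense operand broadcasts against the trailing dims of a jagged NT in the
# forward pass, and in the backward pass its gradient is reduced back to the
# dense shape (this relies on the InputMetadata and sum.dim_IntList changes
# elsewhere in this PR).
#
#   ts = [torch.randn(2, 3, 4), torch.randn(5, 3, 4)]
#   nt, _ = jagged_from_list(ts, None)         # (B, j0, 3, 4)
#   t = torch.randn(4, requires_grad=True)     # broadcasts over (B, j0, 3)
#   out = nt + t                               # still a jagged NestedTensor
#   buffer_from_jagged(out).sum().backward()
#   assert t.grad.shape == (4,)                # grad summed back to t's shape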
@parametrize("keepdim", [False, True])
def test_sum_int_DimList(self, device, keepdim):
# (B, j0, 3, 4)
ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True)
# Check shape correctness
reduce_dims = (
# dims, expected shape, expected keepdim shape
# j0 is represented as None
((0, 1), (3, 4), (1, 1, 3, 4)),
((1, 2), None, None),  # reducing the ragged dim without the batch dim errors out (handled below)
((2, 3), (3, None), (3, None, 1, 1)),
((0, 1, 3), (3,), (1, 1, 3, 1)),
((0, 1, 2), (4,), (1, 1, 1, 4)),
((0, 1, 2, 3), tuple(), (1, 1, 1, 1)),
)
for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
if (0 in rd) ^ (1 in rd):
with self.assertRaisesRegex(
RuntimeError,
"applying over the ragged dimension, but not the batch dimension"):
nt, _ = jagged_from_list(ts, None)
out = torch.sum(nt, dim=rd, keepdim=keepdim)
continue
nt, _ = jagged_from_list(ts, None)
out = torch.sum(nt, dim=rd, keepdim=keepdim)
ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
self.assertEqual(len(out.shape), len(ref_shape))
for o, r in zip(out.shape, ref_shape):
if r is not None:
self.assertEqual(o, r)
else:
self.assertTrue(isinstance(o, torch.SymInt))
# Check values correctness
# raggedness not reduced
nt, _ = jagged_from_list(ts, None)
out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
out_ref = torch.sum(nt.values(), dim=(1, 2))
self.assertIsInstance(out, NestedTensor)
# flatten to avoid having to replicate unsqueeze logic depending on keepdim
self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
# raggedness reduced away
nt, _ = jagged_from_list(ts, None)
out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
out_ref = torch.sum(nt.values(), dim=(0,))
self.assertNotIsInstance(out, NestedTensor)
self.assertTrue(torch.allclose(out, out_ref))
@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
@parametrize("weights_only", [False, True])

@ -16,6 +16,14 @@ MetadataShape compute_variant_shape(const at::Tensor& input) {
return MetadataShape{std::in_place_type<SymIntSmallVec>, input.sym_sizes()};
}
bool is_python_dispatch(const at::Tensor& tensor) {
return tensor.unsafeGetTensorImpl()->is_python_dispatch();
}
bool is_cpp_nested_tensor(const at::Tensor& tensor) {
return tensor.is_nested() && !is_python_dispatch(tensor);
}
} // namespace
InputMetadata::InputMetadata(
@ -36,7 +44,7 @@ InputMetadata::InputMetadata(const at::Tensor& t)
: InputMetadata(
t.options(),
compute_variant_shape(t),
t.unsafeGetTensorImpl()->is_python_dispatch(),
is_python_dispatch(t),
t.is_nested()) {}
at::Tensor InputMetadata::zeros_like() const {
@ -46,7 +54,9 @@ at::Tensor InputMetadata::zeros_like() const {
}
bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
check_nestedness_same(grad);
if (!is_nestedness_same(grad)) {
return false;
}
if (is_cpp_nested_tensor()) {
return grad._nested_tensor_size().is_same_size(shape_as_tensor());
}
@ -54,19 +64,15 @@ bool InputMetadata::is_same_shape(const at::Tensor& grad) const {
}
bool InputMetadata::is_expandable_to_shape(const at::Tensor& grad) const {
// Currently NestedTensors are not expandable. If this support is added then
// updates to reduce_grad will be needed
check_nestedness_same(grad);
return grad.is_nested()
? false
: at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
if (!maybe_expandable_to(grad)) {
return false;
}
return at::is_expandable_to(shape_as_dim_vector(), grad.sym_sizes());
}
at::Tensor InputMetadata::reduce_grad(at::Tensor& grad) const {
// Currently reduce_grad is only called if is_expandable_to_shape returns
// true For nested tensors this always returns False, so this check
// shouldn't fail
TORCH_INTERNAL_ASSERT(!grad.is_nested() && !is_nested_)
// reduce_grad should only be called if is_expandable_to_shape returns true.
TORCH_INTERNAL_ASSERT(maybe_expandable_to(grad));
return at::sum_to(std::move(grad), shape_as_dim_vector());
}
@ -75,7 +81,7 @@ std::stringstream InputMetadata::incompatible_shape_error_message(
const at::Tensor& grad) const {
std::stringstream ss{};
ss << "invalid gradient at index " << index << " - got ";
if (grad.is_nested() && !grad.unsafeGetTensorImpl()->is_python_dispatch()) {
if (::torch::autograd::is_cpp_nested_tensor(grad)) {
ss << grad._nested_tensor_size();
} else {
ss << grad.sym_sizes();
@ -106,21 +112,34 @@ SymIntSmallVec& InputMetadata::mutable_shape_as_dim_vector() {
return std::get<SymIntSmallVec>(shape_);
}
void InputMetadata::check_nestedness_same(const at::Tensor& grad) const {
bool grad_is_subclass = grad.unsafeGetTensorImpl()->is_python_dispatch();
bool grad_is_nested = grad.is_nested();
bool grad_is_cpp_nested = grad_is_nested && !grad_is_subclass;
TORCH_CHECK(
grad_is_cpp_nested == is_cpp_nested_tensor() &&
grad_is_nested == is_nested_,
"grad and the input wrt the gradient that is being computed for need to be "
"either both nested or both non-nested tensors. Also note that nested "
"tensors with different layouts do not compose currently.");
bool InputMetadata::is_nestedness_same(const at::Tensor& grad) const {
return (
grad.is_nested() == is_nested_ &&
::torch::autograd::is_cpp_nested_tensor(grad) == is_cpp_nested_tensor());
}
at::Tensor InputMetadata::shape_as_tensor() const {
return std::get<at::Tensor>(shape_);
}
bool InputMetadata::maybe_expandable_to(const at::Tensor& grad) const {
// This is the initial step to determine whether or not the tensor represented
// by input_metadata is expandable to grad based on is-nestedness information
// alone. If this function returns true, then is_expandable_to_shape will be
// called. We support the following 3 types of expansion:
bool grad_is_nested = grad.is_nested();
if (!is_nested_ && !grad_is_nested) {
// Normal case (no NestedTensors are involved)
// (1) plain Tensor -> plain Tensor
return true;
} else {
// (2) python NT -> python NT
// (3) plain Tensor -> python NT
return (
grad_is_nested && is_python_dispatch(grad) &&
(!is_nested_ || is_tensor_subclass_));
}
}
} // namespace autograd
} // namespace torch

@ -98,7 +98,8 @@ struct TORCH_API InputMetadata {
private:
at::Tensor shape_as_tensor() const;
void check_nestedness_same(const at::Tensor& grad) const;
bool is_nestedness_same(const at::Tensor& grad) const;
bool maybe_expandable_to(const at::Tensor& grad) const;
// NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
const at::TensorOptions options_;

@ -177,7 +177,8 @@ Example::
from torch.nested._internal.nested_tensor import jagged_from_list
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
with torch.no_grad():
nt, _ = jagged_from_list(list_of_tensors, offsets=None, device=device, dtype=dtype)
nt.requires_grad_(requires_grad)
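# Editor's note (not part of this diff): constructing under no_grad() keeps the
# construction ops out of the autograd graph; requires_grad is then set directly
# on the result, so the returned NT behaves like a freshly created leaf tensor.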
if pin_memory:

@ -61,7 +61,7 @@ class NestedTensor(torch.Tensor):
torch.jagged,
values.device,
False,
False,
kwargs.get("requires_grad", False),
"sizes",
False,
True, # dispatch_layout
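# Editor's note (not part of this diff): the change above forwards the caller's
# requires_grad (presumably into torch.Tensor._make_wrapper_subclass) instead of
# hard-coding False, so a NestedTensor can be created as a leaf that requires grad.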

@ -11,6 +11,13 @@ __all__: List[Any] = []
JAGGED_OPS_TABLE: Dict[Any, Any] = {}
# Simplifying assumption: we assume that the batch dim is always the left-most
# dim, and the ragged dim is always the second dim.
def _outer_to_inner_dim(ndim, dim):
assert dim >= 0 and dim < ndim
return 0 if dim < 2 else dim - 1
def _wrap_jagged_dim(ndim, dim, op_name):
from torch._prims_common import canonicalize_dims
@ -19,7 +26,29 @@ def _wrap_jagged_dim(ndim, dim, op_name):
raise RuntimeError(
f"{op_name}(): not supported for NestedTensor on dim=0 or dim=1"
)
return wrapped - 1
return _outer_to_inner_dim(ndim, wrapped)
def _wrap_jagged_dims(ndim, dims, op_name):
# ex: (2, 3, 4) -> (1, 2, 3)
# ex: (0, 1, 4) -> (0, 3)
from torch._prims_common import canonicalize_dims
wrapped_dims = [canonicalize_dims(ndim, d) for d in dims]
# This logic needs to be done after we canonicalize dims but before we
# map to inner dims so we can print a nicer error message.
zero_in_dims = 0 in wrapped_dims
one_in_dims = 1 in wrapped_dims
if zero_in_dims ^ one_in_dims:
apply, not_apply = ("batch", "ragged") if zero_in_dims else ("ragged", "batch")
raise RuntimeError(
f"{op_name}(): applying over the {apply} dimension, but not the {not_apply}"
" dimension is not supported for NestedTensor"
)
return (
tuple(_outer_to_inner_dim(ndim, d) for d in dims if d != 0),
zero_in_dims,
)
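# Editor's illustration (not part of this diff): for an NT of ndim=4, i.e.
# (B, j0, D0, D1) with values of shape (sum(j0), D0, D1):
#   _wrap_jagged_dims(4, (0, 1, 3), "sum") -> ((0, 2), True)   # batch + ragged reduced
#   _wrap_jagged_dims(4, (2, 3), "sum")    -> ((1, 2), False)  # raggedness preserved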
def check_schema(schema_str: str, func, *args, **kwargs) -> None:
@ -79,6 +108,30 @@ def raggedness_matches(nt, size):
return list(nt._size[:end]) == list(size[:end])
def squeeze_leading_ones(t):
# Note: [ Squeezing leading ones ]
#
# Squeeze leading ones from t.
#
# We want:
# (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
# (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?) (not yet supported)
#
# 1) Squeeze extra ones and grab values from NT
# (1, 1, ?, ?) -> (?, ?) and (sum(*), ?, ?) -> (B, j0, ?, ?)
# 2) Do dense broadcasting:
# (sum(*), ?, ?) + (?, ?) -> (sum(*), ?, ?)
# 3) Construct nested tensor
# (sum(*), ?, ?) -> (B, j0, ?, ?)
#
# If unsqueezing on the 0th dim becomes supported, we would unsqueeze
# in a new step (4), and this function would need to record how many
# leading ones it squeezed so they could be restored.
while t.shape[0] == 1:
t = t.squeeze(0)
return t
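# Editor's illustration (not part of this diff): squeeze_leading_ones turns a
# (1, 1, D0, D1) operand into (D0, D1), which then broadcasts directly against
# the NT's values of shape (sum(j0), D0, D1) in step (2) above.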
def register_func(tables, aten_ops, schema_str):
if not isinstance(aten_ops, list):
aten_ops = [aten_ops]
@ -163,15 +216,17 @@ def jagged_binary_pointwise(func, *args, **kwargs):
# === Handle broadcasting across the batch / ragged dims ===
# Easy case: take advantage of pre-existing broadcasting logic
# when NT dim > non-NT dim
# ex: (B, j0, D_0, D_1) + (D_0, D_1) -> (B, j0, D_0, D_1)
# ex: (B, j0, D_0, D_1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
# ex: (B, j0, 1, 1) + (D_0, D_1) -> (B, j0, D_0, D_1)
# ex: (B, j0, 1, 1) + (1, D_0, D_1) -> (B, j0, D_0, D_1)
if (a_is_nt and a.dim() > b.dim()) or (not a_is_nt and b.dim() > a.dim()):
arg1 = a._values if a_is_nt else a
arg2 = b._values if not a_is_nt else b
return NestedTensor(func(arg1, arg2, *args[2:], **kwargs), **extracted_kwargs)
# ex: (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
# ex: (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
# ex: (B, j0, ?, ?) + (1, 1, ?, ?) -> (B, j0, ?, ?)
nt, t = (a, b) if a_is_nt else (b, a)
# See Note: [ Squeezing leading ones ]
if t.dim() > nt.dim():
raise NotImplementedError("NYI: broadcasting NT with T with larger dim")
t_squeezed = squeeze_leading_ones(t)
if nt.dim() >= t_squeezed.dim() + 2:
lhs, rhs = (nt._values, t_squeezed) if a_is_nt else (t_squeezed, nt._values)
return NestedTensor(func(lhs, rhs, *args[2:], **kwargs), **extracted_kwargs)
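# Editor's note (not part of this diff): after squeezing leading ones, the check
# nt.dim() >= t_squeezed.dim() + 2 guarantees the dense operand only lines up with
# the trailing dense dims (never the batch or ragged dim), so broadcasting it
# against nt._values is safe and identical for every unbound component.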
# Harder case: do manual broadcasting over unbound components
# when NT dim == non-NT dim
@ -598,6 +653,31 @@ def is_same_size_default(func, *args, **kwargs):
return args[0]._size == args[1]._size
@register_jagged_func(
torch.ops.aten.sum.dim_IntList, "self: jt, dim: any?, keepdim: any?, dtype: any?"
)
def sum_dim_IntList(func, *args, **kwargs):
# sum_dim_IntList can produce a NT or a T depending on whether the ragged dims
# are reduced away.
_, new_kwargs = normalize_function(
func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
)
inp = new_kwargs.pop("input")
assert inp._ragged_idx == 1
new_kwargs["dim"], ragged_reduced_away = _wrap_jagged_dims(
inp.dim(), new_kwargs["dim"], "sum"
)
if not ragged_reduced_away:
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
else:
# Don't wrap because we reduced away the raggedness
out = func(inp._values, **new_kwargs)
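# Editor's note (not part of this diff): the batch and ragged dims both map onto
# dim 0 of values(), so a keepdim reduction over them keeps only one size-1 dim
# where the NT view kept two; the extra unsqueeze(0) below restores that dim.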
if new_kwargs["keepdim"]:
out = out.unsqueeze(0)
return out
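# Editor's illustration (not part of this diff), with nt of shape (B, j0, 3, 4):
#   torch.sum(nt, dim=(2, 3))  -> NestedTensor of shape (B, j0)
#   torch.sum(nt, dim=(0, 1))  -> plain Tensor of shape (3, 4)
#   torch.sum(nt, dim=(1,))    -> error: ragged dim cannot be reduced without the batch dim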
@register_jagged_func(torch.ops.aten.transpose.int, "self: jt, dim0: any, dim1: any")
def transpose_int(func, *args, **kwargs):
_, new_kwargs = normalize_function(