Files
pytorch/test/test_nestedtensor.py
2025-06-13 19:11:43 +00:00

9003 lines
348 KiB
Python

# Owner(s): ["module: nestedtensor"]
# ruff: noqa: F841
import ast
import io
import itertools
import math
import os
import random
import sys
import tempfile
import unittest
from functools import partial
from typing import Optional
import numpy as np
import torch
import torch._dynamo
import torch._dynamo.testing
import torch.nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import (
buffer_from_jagged,
jagged_from_list,
nested_view_from_values_offsets,
NestedTensor,
ViewNestedFromBuffer,
)
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention
from torch.testing._internal.common_cuda import (
PLATFORM_SUPPORTS_FUSED_ATTENTION,
SM70OrLater,
SM80OrLater,
tf32_on_and_off,
)
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
flex_attention_supported_platform,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
ops,
PYTORCH_CUDA_MEMCHECK,
skipCPUIf,
skipCUDAIf,
skipCUDAIfRocm,
skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
decorateIf,
freeze_rng_state,
gradcheck,
instantiate_parametrized_tests,
IS_FBCODE,
IS_WINDOWS,
markDynamoStrictTest,
NestedTensorTestCase,
parametrize,
run_tests,
serialTest,
skipIfRocm,
skipIfSlowGradcheckEnv,
skipIfTorchDynamo,
subtest,
TEST_WITH_ROCM,
xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.core import (
BinaryUfuncInfo,
ReductionOpInfo,
sample_skips_and_xfails,
SkipRule,
XFailRule,
)
from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.
def _iter_constructors():
# yield as_nested_tensor
yield torch.nested.nested_tensor
# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
compile_count = [0]
def counter(gm, example_inputs):
compile_count[0] += 1
return gm
compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
compiled_f(*inputs1)
compiled_f(*inputs2)
return compile_count[0] > 1
# Helper function to generate a pair of random nested tensors:
# one is contiguous and the other is not, but both hold the same entries.
# Each output nested tensor consists of
# * `len(ragged_sizes)` matrices
# * matrices[i].shape == (20, ragged_sizes[i])
def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
    """Return ``(nt_contiguous, nt_noncontiguous)`` with identical entries.

    Args:
        ragged_sizes: sequence of ints; component ``i`` of both outputs has
            shape ``(20, ragged_sizes[i])``.
        device: device for the random components.
        dtype: dtype for the random components.

    Returns:
        A tuple ``(nt_contiguous, nt_noncontiguous)``. The first is built from
        already-transposed components. The second is built from the
        untransposed components and then transposed as a whole nested tensor.
    """
    # Base components are (size, 20); transposing any of them gives (20, size).
    xs = [torch.randn((size, 20), device=device, dtype=dtype) for size in ragged_sizes]
    # Packing the transposed components yields the contiguous variant.
    nt_contiguous = torch.nested.nested_tensor([x.transpose(-1, -2) for x in xs])
    # Transposing the packed nested tensor yields the non-contiguous variant
    # with the same logical entries.
    nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2)
    return nt_contiguous, nt_noncontiguous
# Helper functions to pad a noncontiguous nested tensor
# can be replaced once to_padded_tensor supports noncontiguous memory
def noncontiguous_to_padded_tensor(input, shape=None):
    """Zero-pad the (possibly non-contiguous) components of ``input``.

    Args:
        input: nested tensor with at least one component. All components are
            expected to have the same number of dims.
        shape: optional per-component padded shape, excluding the leading
            batch dim. Any sequence of ints is accepted (list, tuple,
            ``torch.Size``). Defaults to the elementwise max of the component
            shapes.

    Returns:
        A dense tensor of shape ``(len(components), *shape)``, filled with
        zeros, with component ``i`` copied into the leading corner of slot
        ``i``.

    Raises:
        AssertionError: if ``input`` has no components.
    """
    tensors = input.unbind()
    ntensors = len(tensors)
    assert ntensors > 0
    if shape is None:
        # Elementwise max of all component shapes.
        shape = [max(sizes) for sizes in zip(*(t.shape for t in tensors))]
    # Unpacking (rather than list concatenation) accepts tuples and torch.Size.
    result = tensors[0].new_zeros([ntensors, *shape])
    for itensor, tensor in enumerate(tensors):
        view = result[itensor]
        # Narrow the slot down to the component's own extent in every dim.
        for idim, size in enumerate(tensor.shape):
            view = view.narrow(idim, 0, size)
        view.copy_(tensor)
    return result
# Helper function to generate a random nested tensor
def random_nt(
device,
dtype,
num_tensors,
max_dims,
min_dims=None,
layout=torch.strided,
require_non_empty=True,
):
if min_dims is None:
min_dims = tuple([0] * len(max_dims))
assert len(max_dims) == len(min_dims)
for min_dim, max_dim in zip(min_dims, max_dims):
assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
assert min_dim >= 0, "random_nt: min_dim must be non-negative"
if require_non_empty:
assert not (min_dim == 0 and max_dim == 1), (
"random_nt: zero cannot be the only possible value if require_non_empty is True"
)
if require_non_empty:
# Select a random idx that will be required to be non-empty
non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()
ts1 = []
for i, _ in enumerate(range(num_tensors)):
tensor_dims = []
for min_dim, max_dim in zip(min_dims, max_dims):
new_min_dim = min_dim
if require_non_empty and i == non_zero_idx and min_dim == 0:
new_min_dim = 1
tensor_dims.append(
torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()
)
t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
ts1.append(t1)
return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(
dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
):
sizes = [
[
d if d is not None else torch.randint(2, 10, size=(1,)).item()
for d in dims[1:]
]
for d in range(dims[0])
]
return torch.nested.nested_tensor(
[torch.randn(*size) for size in sizes],
device=device,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
)
# Creates an NT matching another NT's number of components and
# shape / ragged structure for all dims specified to be -1.
def random_nt_from_similar(other, dims=None):
if dims is None:
return torch.randn_like(other)
assert len(dims) == other.dim()
assert dims[0] == -1 or dims[0] == other.size(0)
ret_sizes = []
for t in other.unbind():
other_size = t.shape
ret_size = []
for i, d in enumerate(dims[1:]):
if d == -1:
ret_size.append(other_size[i])
else:
ret_size.append(d)
ret_sizes.append(ret_size)
return torch.nested.nested_tensor(
[torch.randn(*size) for size in ret_sizes], device=other.device
)
# makes naming nice for tests that parametrize over layout.
def layout_name(layout):
# e.g. "torch.jagged" -> "jagged"
return layout.__repr__().split(".")[-1]
def get_op_name(layout):
    """Derive a short op name from an op overload's ``__name__``.

    ``layout`` is the op object despite the parameter name, e.g.
    ``<OpOverload(op='aten.sum', overload='dim_IntList')>``. Its ``__name__``
    is cut at the first ``"."`` and only the text after the last ``"_"`` is
    kept: ``"sum.dim_IntList"`` -> ``"sum"``. Names containing underscores
    are therefore shortened too (``"new_empty.default"`` -> ``"empty"``).
    """
    base_name = layout.__name__.partition(".")[0]
    return base_name.rpartition("_")[2]
# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor_legacy(values):
offsets = torch.arange(
0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device
)
metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1}
nt = ViewNestedFromBuffer.apply(
values.view(-1, values.shape[-1]), offsets, metadata_cache
)
return nt
# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor_legacy(
values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache)
return nt
# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_nt_to_jagged_legacy(nt):
return buffer_from_jagged(nt)
# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_dense_to_nested_tensor(values):
nt = torch.nested.as_nested_tensor(values, layout=torch.jagged)
return nt
# Helper function for test_dummy_mha_with_nt
@torch.fx.wrap
def convert_jagged_to_nested_tensor(
values: torch.Tensor, offsets: torch.Tensor, max_length: int
) -> torch.Tensor:
nt = torch.nested.nested_tensor_from_jagged(
values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length
)
return nt
# Helper function for test_dummy_mha_with_nt
def convert_nt_to_jagged(nt):
return nt.values()
@markDynamoStrictTest
class TestNestedTensor(NestedTensorTestCase):
    """Tests for nested tensors built with ``torch.nested.nested_tensor``.

    None of these tests take a ``device`` argument; ``test_to`` and
    ``test_copy_`` additionally exercise CUDA when it is available.
    """

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        # Build ragged 1-D rows of random ints from Python lists and check that
        # each unbound component equals a dense tensor built from the same row.
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            # NOTE(review): max_seq_len is never 0 under the parametrization
            # above, so this branch is currently not exercised.
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        # NOTE: `id` shadows the builtin within this loop.
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        # Like test_2d_nested_tensor, but each token is expanded into a row of
        # length max_seq_len, so component i has shape (length_i, max_seq_len).
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            # NOTE(review): dead branch under the current parametrization.
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
        # Float variant of test_3d_nested_tensor.
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            # NOTE(review): dead branch under the current parametrization.
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(
                np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
            )
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float)
            )

    @torch.inference_mode()
    def _test_unbind_case(self, a, b):
        # Shared body for the test_unbind_* cases: unbind must return new
        # tensor objects that compare equal to the inputs.
        nt = torch.nested.nested_tensor([a, b])
        a1, b1 = nt.unbind()
        self.assertTrue(a is not a1)
        self.assertTrue(b is not b1)

        nt = torch.nested.nested_tensor([a, b], dtype=a.dtype)
        a1, b1 = nt.unbind(0)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)

        # Single-component nested tensor round trip.
        a = torch.randn((2, 3)).add_(1)
        nt = torch.nested.nested_tensor([a])
        self.assertEqual(a, nt.unbind(0)[0])

    @torch.inference_mode()
    def test_unbind_0(self):
        self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8]))

    @torch.inference_mode()
    def test_unbind_1(self):
        self._test_unbind_case(torch.tensor([1]), torch.tensor([7]))

    @torch.inference_mode()
    def test_unbind_3(self):
        self._test_unbind_case(torch.tensor([1.0]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_4(self):
        self._test_unbind_case(torch.tensor([]), torch.tensor([]))

    @torch.inference_mode()
    def test_unbind_dim(self):
        # Unbinding along a non-batch dim is expected to raise RuntimeError.
        def _test_fn(unbind_fn):
            a = torch.rand(3, 2)
            b = torch.rand(2, 3)
            nt = torch.nested.nested_tensor([a, b])
            self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))

        # Both of these tests are necessary, because we're using
        # torch_function.
        _test_fn(lambda x, dim: x.unbind(dim))
        # TODO: Re-enable this once using torch_dispatch
        # _test_fn(lambda x, dim: torch.unbind(x, dim))

    @torch.inference_mode()
    def test_nested_tensor(self):
        # A bare tensor or a scalar is not a valid list of components.
        self.assertRaises(
            TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))
        )
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))

    @torch.inference_mode()
    def test_nested_tensor_matching_dim(self):
        # Every component must have the same number of dims; the error names
        # the first offending index and the index it was compared against.
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
            lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
            lambda: torch.nested.nested_tensor(
                [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
            ),
        )

    @torch.inference_mode()
    def test_default_nested_tensor(self):
        # An empty nested tensor should report the same basic metadata as
        # an empty dense tensor.
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor())
        default_nested_tensor = torch.nested.nested_tensor([])
        default_tensor = torch.tensor([])
        # self.assertEqual(default_nested_tensor.nested_dim(), 1)
        # self.assertEqual(default_nested_tensor.nested_size(), ())
        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
        self.assertEqual(default_nested_tensor.device, default_tensor.device)
        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
        self.assertEqual(
            default_nested_tensor.requires_grad, default_tensor.requires_grad
        )
        # NOTE(review): this checks the dense tensor's grad; it may have been
        # intended to check default_nested_tensor.grad. Confirm.
        self.assertIsNone(default_tensor.grad)
        # TODO: Re-enable once we have a performance driven
        # use case and implementation.
        # self.assertEqual(default_nested_tensor.is_pinned(),
        #                  default_tensor.is_pinned())

    @torch.inference_mode()
    def test_dim(self):
        # dim() counts the batch dim plus the component dims.
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor(3.0)])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor([1, 2, 3, 4])])
            self.assertEqual(a1.dim(), 2)

    @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
    @torch.inference_mode()
    def test_numel(self):
        # numel() is the sum of the component numels.
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.numel(), 0)
            a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
            self.assertEqual(a1.numel(), 2)
            a1 = constructor([torch.randn(2, 2, 2)])
            self.assertEqual(a1.numel(), 8)
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
            self.assertEqual(a1.numel(), 12)
            a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
            self.assertEqual(a1.numel(), 27)
            a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
            self.assertEqual(a1.numel(), 341)

            # Interesting edge case: a zero-sized component contributes 0.
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
            self.assertEqual(a1.numel(), 6)

    @torch.inference_mode()
    def test_size(self):
        # size() without a dim is unsupported for nested tensors.
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support sizes",
                lambda: a1.size(),
            )

    def test_size_dim(self):
        # size(dim) works for regular dims and raises for ragged ones.
        a = torch.nested.nested_tensor([])
        self.assertEqual(a.size(0), 0)

        a = torch.nested.nested_tensor([torch.tensor(1)])
        self.assertEqual(a.size(0), 1)

        a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)])
        self.assertEqual(a.size(0), 2)

        a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)])
        self.assertEqual(a.size(0), 2)
        self.assertEqual(a.size(1), 1)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 2 is irregular and does not have a size",
            lambda: a.size(2),
        )

        a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
        self.assertEqual(a.size(0), 2)
        self.assertRaisesRegex(
            RuntimeError,
            "Given dimension 1 is irregular and does not have a size",
            lambda: a.size(1),
        )
        self.assertEqual(a.size(2), 4)

    @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
    @torch.inference_mode()
    def test_stride(self):
        # stride() is unsupported for nested tensors.
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support strides",
                lambda: a1.stride(),
            )

    @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
    @torch.inference_mode()
    def test_is_contiguous(self):
        # Test empty case
        nt_empty = torch.nested.nested_tensor([])
        assert nt_empty.is_contiguous()
        self.assertEqual(nt_empty, nt_empty.contiguous())

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))

        # Test contiguous case
        assert nt_contiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_contiguous.contiguous())

        # Test non_contiguous case
        assert not nt_noncontiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())

        # Test querying by memory_format
        self.assertTrue(
            nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
        )
        self.assertTrue(
            not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
        )

    @torch.inference_mode()
    def test_repr_string(self):
        # str() and repr() must produce the same text.
        a = torch.nested.nested_tensor([])
        expected = "nested_tensor([\n\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor(1.0)])
        expected = "nested_tensor([\n tensor(1.)\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
        expected = "nested_tensor([\n tensor([[1, 2]]),\n tensor([[4, 5]])\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

    def test_to_padded_tensor_on_empty_tensor(self):
        nt = torch.nested.nested_tensor([])
        empty = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(empty, torch.tensor([]))

    def test_nested_namespace(self):
        # The method form and the torch.nested function form must agree.
        nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
        result = nt.to_padded_tensor(4)
        nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(result, nested_namespace_result)

    def test_to(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))

        # .to() with no effective change returns the same object unless
        # copy=True, in which case it returns a new one.
        def test_copy_behavior(t, non_blocking=False):
            self.assertIs(t, t.to(t, non_blocking=non_blocking))
            self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
            self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
            self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
            self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
            self.assertIsNot(
                t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)
            )

            # Also try equivalent spellings of the tensor's current CUDA device.
            devices = [t.device]
            if t.device.type == "cuda":
                if t.device.index == -1:
                    devices.append(f"cuda:{torch.cuda.current_device()}")
                elif t.device.index == torch.cuda.current_device():
                    devices.append("cuda")
            for device in devices:
                self.assertIs(t, t.to(device, non_blocking=non_blocking))
                self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
                self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
                self.assertIsNot(
                    t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)
                )

        test_copy_behavior(nt)
        self.assertEqual(nt.device, nt.to("cpu").device)
        self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device)
        self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype)
        self.assertEqual(nt.device, nt.to(torch.float32).device)
        self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype)

        # No-op conversions share storage (same data_ptr); copy=True does not.
        def test_data_ptr(getter):
            self.assertEqual(getter(nt), getter(nt.to("cpu")))
            self.assertEqual(
                getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))
            )
            self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False)))
            self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True)))

        test_data_ptr(lambda nt: nt.data_ptr())

        if torch.cuda.is_available():
            for non_blocking in [True, False]:
                for cuda in [
                    "cuda",
                    "cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
                ]:
                    nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4))
                    test_copy_behavior(nt2, non_blocking)
                    self.assertEqual(
                        nt2.device, nt2.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt.device, nt2.to("cpu", non_blocking=non_blocking).device
                    )
                    self.assertEqual(
                        nt2.device, nt.to(cuda, non_blocking=non_blocking).device
                    )
                    self.assertIs(
                        torch.int32,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).dtype,
                    )
                    self.assertEqual(
                        nt.device,
                        nt2.to(
                            "cpu", dtype=torch.int32, non_blocking=non_blocking
                        ).device,
                    )
                    self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
                    self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)

    def test_copy_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt_copy = torch.empty_like(nt)
        nt_copy.copy_(nt)

        for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
            self.assertEqual(nt_ub, nt_copy_ub)

        # copy_ between nested tensors with different component sizes fails.
        nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])])
        self.assertRaisesRegex(
            RuntimeError,
            "copy_ only supports tensors that are the same size for Nested implementations",
            lambda: nt_error.copy_(nt),
        )

        if torch.cuda.is_available():
            nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4))
            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=True)
            # Wait for the asynchronous device-to-host copy before reading it.
            torch.cuda.current_stream(torch.cuda.current_device()).synchronize()
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

            nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
            nt_copy.copy_(nt, non_blocking=False)
            for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
                self.assertEqual(nt_ub, nt_copy_ub)

    def test_fill_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.fill_(10.0)
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(10.0)
            self.assertEqual(nt_ub, t)

        # A 1-D (one-element) tensor value is rejected; a 0-dim one is accepted.
        fill_tensor = torch.tensor([11.0])
        self.assertRaisesRegex(
            RuntimeError,
            "fill_ only supports 0-dimension value tensor",
            lambda: nt.fill_(fill_tensor),
        )

        nt.fill_(fill_tensor[0])
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(11.0)
            self.assertEqual(nt_ub, t)

    def test_zero_(self):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        nt.zero_()
        for nt_ub in nt.unbind():
            t = torch.empty_like(nt_ub)
            t.fill_(0.0)
            self.assertEqual(nt_ub, t)

    @parametrize(
        "func",
        [torch.ones_like, torch.zeros_like, torch.randn_like],
        name_fn=lambda f: f.__name__,
    )
    def test_like_functions(self, func):
        ntensors = 4
        nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
        # Re-seeding before both the nested call and the per-component calls
        # lets randn_like be compared too; this relies on the nested version
        # consuming the RNG in the same per-component order.
        torch.manual_seed(1)
        nt_like = func(nt)

        torch.manual_seed(1)
        for nt_ub in nt_like.unbind():
            t_like = func(nt_ub)
            self.assertEqual(nt_ub, t_like)

    def test_cat(self):
        # dim=0 success case
        # No constraints on ragged structures matching.
        x = random_nt_from_dims([5, None, 10])
        y = random_nt_from_dims([3, 4, None])
        output = torch.cat([x, y], dim=0)
        for out_component, xy_component in zip(
            output.unbind(), itertools.chain(x.unbind(), y.unbind())
        ):
            self.assertEqual(out_component, xy_component)

        # dim=-1 success case
        # shape (B, *, D)
        x = random_nt_from_dims([5, None, 10])
        # shape (B, *, D'); same structure as x but dim=-1 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 8])
        # should be shape (B, *, D + D') when supported
        output = torch.cat([x, y], dim=-1)
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=-1)
            )

        # dim between 0 and -1 success case
        x = random_nt_from_dims([5, None, 2, 3])
        # same structure as x but dim=2 differs
        y = random_nt_from_similar(x, dims=[-1, -1, 4, -1])
        output = torch.cat([x, y], dim=2)
        # Components have no batch dim, so nested dim 2 is component dim 1.
        for out_component, x_component, y_component in zip(
            output.unbind(), x.unbind(), y.unbind()
        ):
            self.assertEqual(
                out_component, torch.cat([x_component, y_component], dim=1)
            )

        # error case: mixed NT / dense inputs
        x = random_nt_from_dims([5, None, 2])
        y = torch.randn(5, 3, 2)
        with self.assertRaisesRegex(
            RuntimeError, "expected each tensor in given list to be nested"
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different dims
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([5, None, 2, 3])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: non-contiguous NT
        x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32)
        # transpose to put ragged dim next to batch dim
        x, y = x.transpose(-2, -1), y.transpose(-2, -1)
        with self.assertRaisesRegex(
            RuntimeError, "only contiguous nested tensors are supported"
        ):
            torch.cat([x, y], dim=-1)

        # error case: multiple ragged dims in inputs
        x = random_nt_from_dims([5, None, None, 2])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=-1)

        # error case: ragged dim not next to batch dim
        x = random_nt_from_dims([5, 2, None])
        y = random_nt_from_similar(x)
        with self.assertRaisesRegex(
            RuntimeError,
            "only nested tensors with a single ragged dim next to the batch dim are supported",
        ):
            torch.cat([x, y], dim=1)

        # error case: NTs with different batch sizes
        x = random_nt_from_dims([5, None, 2])
        y = random_nt_from_dims([3, None, 2])
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

        # error case: NTs with different ragged structures
        x = torch.nested.nested_tensor(
            [
                torch.randn(2, 6),
                torch.randn(4, 6),
                torch.randn(5, 6),
            ]
        )
        y = torch.nested.nested_tensor(
            [
                torch.randn(5, 6),
                torch.randn(4, 6),
                torch.randn(2, 6),
            ]
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "expected all nested tensors to have matching ragged structures outside of the concatenated dim",
        ):
            torch.cat([x, y], dim=-1)

    def test_nested_view_from_buffer_overflow_errors(self):
        # Component sizes near 2**63 must trip the storage-size overflow check
        # instead of producing a view over a 1-element buffer.
        buffer = torch.tensor([1])
        sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64)
        strides = torch.tensor(
            [[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64
        )
        offsets = torch.tensor(
            [[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64
        )
        with self.assertRaisesRegex(
            RuntimeError,
            r"Storage size calculation overflowed with sizes=\[9223372036854775807\] and strides=\[1094795585\]",
        ):
            nt = torch._nested_view_from_buffer(buffer, sizes, strides, offsets)
@markDynamoStrictTest
class TestNestedTensorDeviceType(NestedTensorTestCase):
# Helper function to generate a pair of random nested tensors
# the 2 nested tensors have same shapes
def random_nt_pair(self, device, dtype, num_tensors, max_dims):
    """Return two random nested tensors whose components have matching shapes.

    For each of the ``num_tensors`` components, dim ``i`` is drawn with
    ``torch.randint(low=0, high=max_dims[i])``, so zero-sized dims are
    possible. Both outputs then receive an independent ``torch.randn`` tensor
    of that shape on ``device`` with ``dtype``.
    """
    firsts = []
    seconds = []
    for _ in range(num_tensors):
        shape = tuple(
            torch.randint(low=0, high=limit, size=(1,)).item() for limit in max_dims
        )
        firsts.append(torch.randn(shape, device=device, dtype=dtype))
        seconds.append(torch.randn(shape, device=device, dtype=dtype))
    first_nt = torch.nested.nested_tensor(firsts, device=device, dtype=dtype)
    second_nt = torch.nested.nested_tensor(seconds, device=device, dtype=dtype)
    return first_nt, second_nt
@dtypes(*floating_types_and_half())
def test_detach(self, device, dtype):
    """detach() drops autograd tracking for nested tensors."""
    # Part 1: results computed from a detached nested tensor do not require grad.
    first = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
    second = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
    leaf_nt = torch.nested.nested_tensor([first, second], requires_grad=True)
    detached_nt = leaf_nt.detach()
    scaled = detached_nt * 4
    self.assertFalse(detached_nt.requires_grad)
    self.assertFalse(scaled.requires_grad)

    # Part 2: gradients only flow through the non-detached branch.
    a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
    b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
    x = torch.nested.as_nested_tensor([a, b])
    doubled = x * 2
    y = doubled.detach()
    self.assertFalse(y.requires_grad)
    self.assertIsNone(y.grad_fn)

    z = x + y
    torch.nested.to_padded_tensor(z, 0).sum().backward()
    # This is an incorrect gradient, but we assume that's what the user
    # wanted. detach() is an advanced option.
    self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
    self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
@parametrize("weights_only", [False, True])
def test_serialization(self, device, dtype, requires_grad, weights_only):
    """Round-trip nested tensors through torch.save / torch.load.

    Both a contiguous and a non-contiguous nested tensor are saved and loaded
    back. Each must compare equal to its source, keep identical size, stride
    and offset metadata, and match the other variant's contents.
    """

    def compare_metadata(nt1, nt2):
        self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
        self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
        self.assertEqual(
            nt1._nested_tensor_storage_offsets(),
            nt2._nested_tensor_storage_offsets(),
        )

    # Build the inputs on the parametrized device/dtype. Previously the
    # helper's defaults (CPU, float16) were always used, so every
    # device/dtype variant exercised the same case.
    nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
        (2, 3, 6, 7), device, dtype
    )
    # NOTE(review): requires_grad is parametrized but not applied to the
    # inputs; wiring it up needs confirmation that nested tensors requiring
    # grad round-trip through torch.save/torch.load (incl. weights_only).
    for a in [nt_contiguous, nt_noncontiguous]:
        buffer = io.BytesIO()
        torch.save(a, buffer)
        buffer.seek(0)
        b = torch.load(buffer, weights_only=weights_only)
        # should be both conceptually equal and metadata equivalent
        self.assertEqual(a, b)
        compare_metadata(a, b)
        # should be conceptually equal but not necessarily metadata equivalent
        self.assertEqual(b, nt_contiguous)
        self.assertEqual(b, nt_noncontiguous)
@dtypes(torch.float, torch.float16, torch.double)
def test_unbind_noncontiguous(self, device, dtype):
    """Unbinding a non-contiguous nested tensor yields the same components."""
    contiguous_nt, noncontiguous_nt = random_nt_noncontiguous_pair(
        (2, 3, 6, 7), device, dtype
    )
    expected_parts = contiguous_nt.unbind()
    actual_parts = noncontiguous_nt.unbind()
    self.assertEqual(len(expected_parts), len(actual_parts))
    # The lengths were just asserted equal, so zip covers every component.
    for expected, actual in zip(expected_parts, actual_parts):
        self.assertEqual(expected, actual)
@dtypes(torch.float)
@skipMeta
def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
    """Padding a ragged nested tensor and rebuilding it restores the original."""
    dense = torch.randn(4, 4, 4, device=device, dtype=dtype)
    components = list(torch.unbind(dense))
    # Drop the last row of the first component so the nested tensor is ragged.
    components[0] = components[0][:-1]
    nt = torch.nested.nested_tensor(components, device=device, dtype=dtype)

    padded = torch.nested.to_padded_tensor(nt, 0)
    rebuilt = torch._nested_from_padded_and_nested_example(padded, nt)

    for original_part, rebuilt_part in zip(nt.unbind(), rebuilt.unbind()):
        self.assertEqual(original_part, rebuilt_part)
    self.assertEqual(nt.device, rebuilt.device)
@dtypes(torch.float)
@dtypesIfCUDA(torch.float, torch.half)
@skipMeta
@torch.inference_mode()
def test_layer_norm(self, device, dtype):
    """LayerNorm on a nested tensor matches LayerNorm applied per component."""

    def randn(*shape):
        return torch.randn(*shape, device=device, dtype=dtype, requires_grad=False)

    def make_nt(components):
        return torch.nested.nested_tensor(components, device=device, dtype=dtype)

    def make_layer_norm(normalized_shape):
        return torch.nn.LayerNorm(normalized_shape, device=device, dtype=dtype)

    def check(nt, components, layer_norm, trailing_shape):
        # Compare the nested result component-by-component with the dense one.
        nt_result = layer_norm(nt)
        for nt_part, component in zip(nt_result.unbind(), components):
            reference = layer_norm(
                component.reshape(1, -1, *trailing_shape).squeeze(0)
            )
            self.assertEqual(nt_part, reference)

    def run_for_size(size):
        # Simple shapes test: equal-length components, each used twice.
        short_a, short_b = randn(2, size), randn(2, size)
        components = [short_a, short_b, short_a, short_b]
        check(make_nt(components), components, make_layer_norm(size), (size,))

        # More complex nt test with different lengths for each tensor
        len4, len10, len7 = randn(4, size), randn(10, size), randn(7, size)
        components = [len4, len10, len7, len4, len7]
        check(make_nt(components), components, make_layer_norm(size), (size,))

        if size > 128:
            # Multidimensional cases run only for smaller sizes to keep the
            # test fast.
            return

        # Multidimensional tensors after the irregular dim.
        len4, len10, len7 = (randn(rows, size, size, 4) for rows in (4, 10, 7))
        components = [len4, len10, len7, len4, len7]
        nt = make_nt(components)
        # Normalize over every dim after the irregular one.
        check(nt, components, make_layer_norm((size, size, 4)), (size, size, 4))
        # Normalize over only the trailing two dims, not all regular dims.
        check(nt, components, make_layer_norm((size, 4)), (size, size, 4))

    for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
        run_for_size(size)
@dtypes(torch.float)
@dtypesIfCUDA(torch.float, torch.half)
@skipMeta
@torch.inference_mode()
def test_layer_norm_breaking(self, device, dtype):
    """LayerNorm rejects normalized shapes that are invalid for the nested tensor."""
    size = 128
    first, second, third = (
        torch.randn(rows, size, size, 4, device=device, dtype=dtype, requires_grad=False)
        for rows in (4, 10, 7)
    )
    components = [first, second, third, first, third]
    nt = torch.nested.nested_tensor(components, device=device, dtype=dtype)

    # A 4-D normalized shape would reach into the irregular (ragged) dim.
    too_deep = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype)
    self.assertRaisesRegex(
        RuntimeError,
        "normalized_shape extends into irregular dimensions for the nested tensor",
        lambda: too_deep(nt),
    )

    # The leading normalized dim (size + 1) does not match the component dim (size).
    mismatched = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype)
    self.assertRaisesRegex(
        RuntimeError,
        "The shape at dimension 0",
        lambda: mismatched(nt),
    )
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
def test_embedding(self, device, layout):
    """Embedding over a nested int64 tensor agrees with per-component lookups."""
    lengths = torch.randint(5, 50, (8,))
    inputs = [
        torch.randint(100, (length,), device=device, dtype=torch.int64)
        for length in lengths
    ]
    x = torch.nested.nested_tensor(
        inputs, device=device, dtype=torch.int64, layout=layout
    )
    emb = torch.nn.Embedding(100, 8, device=device)
    y = emb(x)
    if layout == torch.jagged:
        # Backward is only exercised for the jagged layout.
        y.backward(torch.randn_like(y))

    # Keep the per-component comparison out of dynamo tracing.
    @torch._dynamo.disable
    def check(inputs, y):
        components = y.unbind()
        for idx, inp in enumerate(inputs):
            self.assertEqual(emb(inp), components[idx])

    check(inputs, y)
@skipMeta
@torch.inference_mode()
@dtypes(*floating_types_and_half())
def test_masked_fill(self, device, dtype):
    """masked_fill with a nested boolean mask matches per-component masked_fill."""
    # nested tensor masked_fill with a nested boolean mask
    (nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
    # The second tensor of the pair only serves to derive a boolean mask.
    # Assumes random_nt_pair yields matching component shapes (the same helper
    # feeds the elementwise nt1 + nt2 tests) -- TODO confirm.
    mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
    ref = torch.nested.nested_tensor(
        [t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]
    )
    out = nt.masked_fill(mask, 0)
    self.assertEqual(ref, out)
@dtypes(torch.float, torch.float16)
def test_to_padded_tensor_simple(self, device, dtype):
    """Pad a nested tensor whose first component is one row short."""
    t = torch.randn(4, 4, 4, device=device, dtype=dtype)
    components = list(torch.unbind(t))
    # Drop the last row of the first component to make the batch ragged.
    components[0] = components[0][:-1]
    nt = torch.nested.nested_tensor(components, device=device, dtype=dtype)
    for padding_value in (0, 1):
        padded = torch.nested.to_padded_tensor(nt, padding_value)
        expected = t.clone()
        # The dropped row must come back filled with the padding value.
        expected[0][-1].fill_(padding_value)
        self.assertEqual(padded, expected)
        self.assertEqual(padded.device, torch.device(device))
        self.assertEqual(padded.dtype, dtype)
@dtypes(torch.float, torch.float16)
def test_to_padded_tensor_output_size(self, device, dtype):
    """Pad into an explicit output_size that exceeds every component's shape."""
    t = torch.randn(4, 4, 4, device=device, dtype=dtype)
    output_size = (4, 6, 5)
    ts = list(torch.unbind(t))
    # Drop the last row of the first component to make the batch ragged.
    ts[0] = ts[0][:-1]
    nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
    for padding_value in (0, 1):
        padded = torch.nested.to_padded_tensor(
            nt, padding_value, output_size=output_size
        )
        # Everything outside the original (4, 4, 4) block is padding, and so is
        # the row dropped from the first component. Building the expectation
        # from padding_value directly keeps it valid for any padding value.
        correct_output = torch.full(
            output_size, padding_value, device=device, dtype=dtype
        )
        correct_output[:, :4, :4] = t
        correct_output[0, 3] = padding_value
        self.assertEqual(padded, correct_output)
        self.assertEqual(padded.device, torch.device(device))
        self.assertEqual(padded.dtype, dtype)
@dtypes(torch.float, torch.float16, torch.double)
def test_to_padded_tensor_dim2(self, device, dtype):
    """to_padded_tensor on 1-D components pads each one up to the longest."""
    ts = [torch.randn(n, device=device, dtype=dtype) for n in (160, 1240, 2400)]
    nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
    pad = 42
    # ts[2] is the longest component, so it defines the padded row length.
    rows = []
    for t in ts:
        row = torch.full_like(ts[2], pad)
        row[: t.size(0)].copy_(t)
        rows.append(row)
    correct_output = torch.stack(rows)
    padded = torch.nested.to_padded_tensor(nt, pad)
    self.assertEqual(padded, correct_output)
@dtypes(torch.float, torch.float16, torch.double)
def test_to_padded_tensor_dim3(self, device, dtype):
    """to_padded_tensor on 2-D components pads both dims up to the largest."""
    shapes = ((16, 21), (24, 32), (40, 53))
    ts = [torch.randn(*shape, device=device, dtype=dtype) for shape in shapes]
    nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
    pad = 42
    # ts[2] is the largest component in every dim, so it fixes the padded shape.
    largest = ts[2]
    correct_output = largest.new_full((len(ts), *largest.shape), pad)
    for idx, t in enumerate(ts):
        correct_output[idx, : t.size(0), : t.size(1)] = t
    padded = torch.nested.to_padded_tensor(nt, pad)
    self.assertEqual(padded, correct_output)
@dtypes(torch.float, torch.float16, torch.double)
def test_to_padded_tensor_dim4(self, device, dtype):
    """to_padded_tensor on 3-D components pads all three dims up to the largest."""
    ts = [
        torch.randn(16, 21, 13, device=device, dtype=dtype),
        torch.randn(24, 32, 14, device=device, dtype=dtype),
        torch.randn(40, 53, 16, device=device, dtype=dtype),
    ]
    nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
    pad = 42
    # Start from an all-padding tensor shaped after the largest component
    # (ts[2]) and copy each component into the leading corner of its slot.
    correct_output = torch.full(
        (len(ts), *ts[2].shape), pad, device=device, dtype=dtype
    )
    for slot, t in zip(correct_output, ts):
        slot[: t.size(0), : t.size(1), : t.size(2)].copy_(t)
    padded = torch.nested.to_padded_tensor(nt, pad)
    self.assertEqual(padded, correct_output)
# TODO: test noncontiguous to_padded_tensor
# For now this tests the functionality of noncontiguous_to_padded_tensor
# and the error message of to_padded_tensor
# since to_padded_tensor does not support noncontiguous buffer yet
@dtypes(torch.float, torch.float16, torch.double)
@torch.inference_mode()
def test_to_padded_tensor_noncontiguous(self, device, dtype):
    """Check the noncontiguous padding helper, and that the native op rejects noncontiguous input."""
    contiguous_nt, noncontiguous_nt = random_nt_noncontiguous_pair(
        (2, 3, 6, 7), device, dtype
    )
    # The helper applied to the noncontiguous twin must match the native
    # op applied to the contiguous one.
    expected = torch.nested.to_padded_tensor(contiguous_nt, 0.0)
    actual = noncontiguous_to_padded_tensor(noncontiguous_nt)
    self.assertEqual(expected, actual)
    # The native op itself refuses noncontiguous buffers for now.
    with self.assertRaisesRegex(
        RuntimeError,
        r"for now to_padded_tensor only supports contiguous nested tensor",
    ):
        torch.nested.to_padded_tensor(noncontiguous_nt, 0.0)
@skipMeta
def test_device_checks(self, device):
    """An empty nested tensor's is_cuda flag follows the requested device."""
    expected_is_cuda = "cuda" in str(device)
    empty_nt = torch.nested.nested_tensor([], device=device)
    self.assertEqual(empty_nt.is_cuda, expected_is_cuda)
@dtypes(torch.float, torch.float16, torch.double)
def test_nested_tensor_indexing(self, device, dtype):
    """Integer/tuple indexing and select() on a strided nested tensor."""
    # edge case: indexing an empty nested tensor is out of range
    empty_nt = torch.nested.nested_tensor([])
    with self.assertRaises(IndexError):
        empty_nt[0]

    x0 = torch.randn((2, 5), device=device, dtype=dtype)
    x1 = torch.randn((3, 4), device=device, dtype=dtype)
    nt = torch.nested.nested_tensor([x0, x1])

    # A single index only accepts integers along the batch dimension.
    self.assertEqual(nt[0], x0)
    self.assertEqual(nt[-1], x1)
    for out_of_range in (2, -3):
        with self.assertRaises(IndexError):
            nt[out_of_range]
    with self.assertRaises(NotImplementedError):
        nt[:]
    self.assertEqual(nt[...], nt)

    # Tuple indices: an integer batch index followed by any regular indexing
    # of the component dimensions.
    self.assertEqual(nt[0, 0, 0], x0[0, 0])
    self.assertEqual(nt[0, 1, :], x0[1, :])
    self.assertEqual(nt[1, ...], x1)
    with self.assertRaises(IndexError):
        nt[1, 4, 2]
    with self.assertRaises(NotImplementedError):
        nt[:, 1, 1]

    # select() along a non-batch dim applies to each component's dim - 1.
    for dim, out_of_range in ((1, 3), (2, 5)):
        selected = nt.select(dim, 0)
        self.assertEqual(selected[0], x0.select(dim - 1, 0))
        self.assertEqual(selected[1], x1.select(dim - 1, 0))
        with self.assertRaises(IndexError):
            nt.select(dim, out_of_range)

    # Indexing must return views: writes through them show up in nt.
    nt[0].fill_(100.0)
    filled = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
    self.assertEqual(nt[0], filled)
    nt[1, 1, :].fill_(200.0)
    filled_row = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
    self.assertEqual(nt[1, 1, :], filled_row)

    # Indexing must also work with requires_grad_(True); previously this
    # failed because the backward kernel for select.int uses .sizes().
    nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
    self.assertEqual(nt[0], x0)
    self.assertEqual(nt[-1], x1)
    grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
    nt[0].backward(grad_x0)
    expected_grad = torch.nested.nested_tensor(
        [grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]
    )
    self.assertEqual(nt.grad, expected_grad)
@parametrize(
    "func",
    [
        subtest(torch.nn.functional.relu, name="relu"),
        subtest(torch.nn.functional.relu_, name="relu_"),
        subtest(torch.nn.functional.gelu, name="gelu"),
        subtest(torch._C._nn.gelu_, name="gelu_"),
        subtest(torch.tanh, name="tanh"),
        subtest(torch.tanh_, name="tanh_"),
        subtest(torch.neg, name="neg"),
        subtest(torch.nn.functional.silu, name="silu"),
        subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"),
        subtest(torch.abs, name="abs"),
        subtest(torch.abs_, name="abs_"),
        subtest(torch.sgn, name="sgn"),
        subtest(torch.logical_not, name="logical_not"),
        subtest(torch.sin, name="sin"),
        subtest(torch.cos, name="cos"),
        subtest(torch.isinf, name="isinf"),
        subtest(torch.isposinf, name="isposinf"),
        subtest(torch.isneginf, name="isneginf"),
        subtest(torch.isnan, name="isnan"),
        subtest(torch.sqrt, name="sqrt"),
    ],
)
def test_unary_funcs(self, device, func):
    """Unary ops on a contiguous NT match per-component results; noncontiguous NTs raise.

    Several entries (relu_, gelu_, tanh_, silu_, abs_) are in place. Running
    them on ``nt`` and then again on its unbound components would, when
    unbind() returns views into nt's buffer, compare each result with an
    alias of itself. The per-component references are therefore computed on
    clones *before* the nested op runs.
    """
    nt, nt_noncontiguous = random_nt_noncontiguous_pair(
        (2, 3, 6, 7), device=device, dtype=torch.float32
    )
    expected = [func(t.clone()) for t in nt.unbind()]
    nested_result = func(nt)
    self.assertTrue(nested_result.is_nested)
    result_components = nested_result.unbind()
    self.assertEqual(len(expected), len(result_components))
    for t_ref, t_res in zip(expected, result_components):
        self.assertEqual(t_ref, t_res)
    self.assertRaisesRegex(
        RuntimeError,
        "NestedTensor must be contiguous to get buffer.",
        lambda: func(nt_noncontiguous),
    )
@parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
def test_binary_ops_with_scalar(self, device, func):
    """Comparison ops against a Python scalar work for contiguous and noncontiguous NTs."""
    nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
        (2, 3, 6, 7), device=device, dtype=torch.float32
    )
    scalar = 0.0
    # should work regardless of contiguity
    for nt in (nt_contiguous, nt_noncontiguous):
        nested_result = func(nt, scalar)
        self.assertTrue(nested_result.is_nested)
        pairs = zip(nt.unbind(), nested_result.unbind())
        for component, result_component in pairs:
            self.assertEqual(func(component, scalar), result_component)
@dtypes(*floating_types_and_half())
def test_nested_tensor_chunk(self, device, dtype):
    """chunk() along the last dim of a strided NT, plus its error paths."""
    # Transformer use case: split a packed projection into three pieces.
    a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
    b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
    c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
    dense_chunks = [x.chunk(3, dim=-1) for x in (a, b, c)]
    nt = torch.nested.nested_tensor([a, b, c])
    chunked = nt.chunk(3, dim=-1)
    for idx in range(3):
        expected = torch.nested.nested_tensor([parts[idx] for parts in dense_chunks])
        self.assertEqual(chunked[idx], expected)
    for chunk in chunked:
        self.assertFalse(chunk.is_contiguous())

    # Chunking along the ragged dim (1) or the batch dim (0) is rejected.
    last_dim_only = (
        "Chunk for nested tensors is currently only supported for the last dimension."
    )
    for bad_dim in (1, 0):
        with self.assertRaisesRegex(RuntimeError, last_dim_only):
            torch.chunk(nt, 5, dim=bad_dim)

    # Non-contiguous input is rejected.
    _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
    with self.assertRaisesRegex(
        RuntimeError, "chunk expects `self` to be contiguous."
    ):
        torch.chunk(nt_noncontiguous, 5, dim=-1)

    # The trailing dim (12) is not divisible by 5 chunks.
    with self.assertRaisesRegex(
        RuntimeError,
        "Chunk for nested tensors is only supported for "
        "nested tensors with trailing dimension divisible by chunks.",
    ):
        torch.chunk(nt, 5, dim=-1)

    # Backward through a chunk is not supported for strided nested tensors.
    a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
    b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
    nt_grad = torch.nested.as_nested_tensor([a, b])
    chunked = torch.chunk(nt_grad, 2, dim=-1)
    with self.assertRaisesRegex(
        RuntimeError, "Nested Strided Tensor doesn't support chunk backward."
    ):
        chunked[0].backward(chunked[0].clone())
@dtypes(*floating_types_and_half())
def test_nested_tensor_split_with_sizes(self, device, dtype):
    """split_with_sizes along the last dim of a strided NT, plus its error paths."""
    a = torch.randn(3, 20, device=device, dtype=dtype)
    b = torch.randn(2, 20, device=device, dtype=dtype)
    c = torch.randn(1, 20, device=device, dtype=dtype)
    split_sizes = [4, 6, 10]
    dense_splits = [x.split_with_sizes(split_sizes, dim=-1) for x in (a, b, c)]
    nt = torch.nested.nested_tensor([a, b, c])
    nt_splits = nt.split_with_sizes(split_sizes, dim=-1)
    for i, nt_split in enumerate(nt_splits):
        pieces = [splits[i] for splits in dense_splits]
        self.assertEqual(nt_split, torch.nested.nested_tensor(pieces))
        # Each nested piece must carry the strides of its dense counterpart.
        dense_strides = torch.stack(
            [torch.tensor(piece.stride()) for piece in pieces]
        )
        self.assertEqual(nt_split._nested_tensor_strides(), dense_strides)
        self.assertFalse(nt_split.is_contiguous())

    # Calling on the ragged dim (1) or the batch dim (0) is rejected.
    last_dim_only = (
        "split_with_sizes for nested tensors is currently only supported "
        "for the last dimension."
    )
    for bad_dim in (1, 0):
        with self.assertRaisesRegex(RuntimeError, last_dim_only):
            torch.split_with_sizes(nt, split_sizes, dim=bad_dim)

    # Non-contiguous input is rejected.
    _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
    with self.assertRaisesRegex(
        RuntimeError, "split_with_sizes expects `self` to be contiguous."
    ):
        torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1)

    # split_sizes that don't cover the full dim size are rejected.
    bad_split_sizes = [4, 6, 9]  # sums to 19, not 20
    with self.assertRaisesRegex(
        RuntimeError, "split_with_sizes expects split_sizes to sum exactly to 20"
    ):
        torch.split_with_sizes(nt, bad_split_sizes, dim=-1)
@dtypes(torch.float, torch.float16, torch.double)
@torch.inference_mode()
def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
    """Integer indexing yields identical components for contiguous and noncontiguous NTs."""
    nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
        (2, 3, 6, 7), device, dtype
    )
    num_components = nt_contiguous.size(0)
    self.assertEqual(num_components, nt_noncontiguous.size(0))
    for idx in range(num_components):
        self.assertEqual(nt_contiguous[idx], nt_noncontiguous[idx])
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
@parametrize("transpose", [True, False])
def test_nested_tensor_add(self, device, dtype, transpose):
    """Elementwise + between nested tensors, optionally with a transposed operand."""
    if not transpose:
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
    else:
        a = torch.randn(2, 2, 2, device=device, dtype=dtype)
        b = torch.rand(2, 2, 2, device=device, dtype=dtype)
        c, d = (x.transpose(-1, -2).contiguous() for x in (a, b))
        nt1 = torch.nested.nested_tensor([a, b, a, b])
        # Transposing back makes nt2's components equal a, b, a, b, but
        # without contiguous strides.
        nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
    ref = torch.nested.nested_tensor(
        [lhs + rhs for lhs, rhs in zip(nt1.unbind(), nt2.unbind())]
    )
    self.assertEqual(ref, nt1 + nt2)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
@parametrize("transpose", [True, False])
def test_nested_tensor_sub(self, device, dtype, transpose):
    """Elementwise - between nested tensors, optionally with a transposed operand."""
    if not transpose:
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
    else:
        a = torch.randn(2, 2, 2, device=device, dtype=dtype)
        b = torch.rand(2, 2, 2, device=device, dtype=dtype)
        c, d = (x.transpose(-1, -2).contiguous() for x in (a, b))
        nt1 = torch.nested.nested_tensor([a, b, a, b])
        # Transposing back makes nt2's components equal a, b, a, b, but
        # without contiguous strides.
        nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
    ref = torch.nested.nested_tensor(
        [lhs - rhs for lhs, rhs in zip(nt1.unbind(), nt2.unbind())]
    )
    self.assertEqual(ref, nt1 - nt2)
@onlyCUDA
@dtypes(torch.float, torch.float16)
@torch.inference_mode()
@parametrize("embedding_dim", [8, 128, 256, 384])
def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
    """add/mul between a nested tensor and a dense tensor with a size-1 ragged dim."""

    def _test_add_mul(nt, t):
        pairs = list(zip(nt.unbind(), t.unbind()))
        ref_add = torch.nested.nested_tensor([lhs + rhs for lhs, rhs in pairs])
        ref_mul = torch.nested.nested_tensor([lhs * rhs for lhs, rhs in pairs])
        self.assertEqual(nt.add(t), ref_add)
        self.assertEqual(nt.mul(t), ref_mul)

    batch_size = 32
    seq_lens = torch.randint(low=0, high=10, size=(batch_size,))

    # [B, *, D] nested with a [B, 1, D] dense operand
    nt = torch.nested.nested_tensor(
        [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens],
        device=device,
        dtype=dtype,
    )
    dense = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
    _test_add_mul(nt, dense)

    # [B, *] nested with a [B, 1] dense operand
    nt = torch.nested.nested_tensor(
        [torch.randn(seq_len) for seq_len in seq_lens], device=device, dtype=dtype
    )
    dense = torch.randn((batch_size, 1), device=device, dtype=dtype)
    _test_add_mul(nt, dense)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_mul(self, device, dtype):
    """Multiplication of nested tensors by nested tensors and by scalars."""
    # nested tensor * nested tensor
    (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
    ref = torch.nested.nested_tensor(
        [lhs * rhs for lhs, rhs in zip(nt1.unbind(), nt2.unbind())]
    )
    self.assertEqual(ref, nt1 * nt2)

    # nested tensor * scalar, as a Python number and a 0-dim tensor, both orders
    number = 10.0
    scalar = torch.tensor(number).to(dtype).to(device)
    ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
    products = (nt1 * number, number * nt1, nt1 * scalar, scalar * nt1)
    for out in products:
        self.assertEqual(out, ref)

    # error case: numel == 1 but dim > 0
    vector = torch.tensor([number]).to(dtype).to(device)
    with self.assertRaisesRegex(
        RuntimeError,
        "Expected both self and other to be nested, but got a nested self and non-nested other",
    ):
        nt1.mul(vector)
    with self.assertRaisesRegex(
        RuntimeError,
        "Expected both self and other to be nested, but got a non-nested self and nested other",
    ):
        vector.mul(nt1)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_div(self, device, dtype):
    """Division of nested tensors by scalars and by other nested tensors."""
    nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))

    # nested / Python scalar, plain and transposed
    scale = 4.0
    ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
    self.assertEqual(ref, nt / 4.0)
    self.assertEqual(ref.transpose(1, 2), nt.transpose(1, 2) / 4.0)

    # nested / nested, plain and transposed
    ref = torch.nested.nested_tensor(
        [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]
    )
    self.assertEqual(ref, nt / nt2)
    self.assertEqual(ref.transpose(1, 2), nt.transpose(1, 2) / nt2.transpose(1, 2))

    # Operands whose strides differ are rejected.
    nt_transpose_copy = torch.nested.nested_tensor(
        [t.transpose(0, 1) for t in nt.unbind()]
    )
    with self.assertRaisesRegex(
        RuntimeError, "div requires strides to match when given NestedTensors"
    ):
        nt_transpose_copy.transpose(1, 2) / nt2

    # Two chunks of the same NT have different offsets and are rejected.
    nt = torch.nested.nested_tensor(
        [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype
    )
    first_half, second_half = nt.chunk(2, -1)
    with self.assertRaisesRegex(
        RuntimeError, "div requires offsets to match when given NestedTensors"
    ):
        first_half / second_half
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_add_in_place(self, device, dtype):
    """In-place += on nested tensors matches per-component addition."""
    (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
    expected = torch.nested.nested_tensor(
        [lhs + rhs for lhs, rhs in zip(nt1.unbind(), nt2.unbind())]
    )
    nt1 += nt2
    self.assertEqual(expected, nt1)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_mul_in_place(self, device, dtype):
    """In-place *= on nested tensors with nested and scalar operands."""
    # nested tensor *= nested tensor
    (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
    expected = torch.nested.nested_tensor(
        [lhs * rhs for lhs, rhs in zip(nt1.unbind(), nt2.unbind())]
    )
    nt1 *= nt2
    self.assertEqual(expected, nt1)

    # nested tensor *= scalar, as a Python number and as a 0-dim tensor
    number = 10.0
    scalar = torch.tensor(number).to(dtype).to(device)
    ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
    out_number = nt1.clone()
    out_number *= number
    out_scalar = nt1.clone()
    out_scalar *= scalar
    for out in (out_number, out_scalar):
        self.assertEqual(out, ref)

    # A dense 0-dim tensor cannot absorb a nested product in place.
    with self.assertRaisesRegex(
        RuntimeError,
        r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
    ):
        scalar.mul_(nt1)

    # error case: numel == 1 but dim > 0
    vector = torch.tensor([number]).to(dtype).to(device)
    with self.assertRaisesRegex(
        RuntimeError,
        "Expected both self and other to be nested, but got a nested self and non-nested other",
    ):
        nt1.mul_(vector)
    with self.assertRaisesRegex(
        RuntimeError,
        "Expected both self and other to be nested, but got a non-nested self and nested other",
    ):
        vector.mul_(nt1)
@onlyCPU
@skipMeta
@dtypes(torch.float)
def test_nested_tensor_sum_dim(self, device, dtype):
    """sum() over the last dim of a nested tensor, forward and backward, plus errors."""
    # (ntensors, max_sizes) pairs forwarded to random_nt.
    params = ((2, (1, 1)), (4, (4, 4)), (10, (3, 5, 7)))

    def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
        nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
        nt2 = nt.clone()
        ub2 = nt2.unbind()
        nt.requires_grad_(True)
        for t in ub2:
            t.requires_grad_(True)
        nt_sum = nt.sum(dim=dim, keepdim=keepdim)
        # The caller passes the last nested dim; for each component that is
        # dim -1 -- assumes random_nt builds components with len(max_sizes)
        # dims (TODO confirm).
        ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2]
        self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum))

        # test backward with a random gradient shaped like each output component
        size = nt_sum._nested_tensor_size()
        gt2 = [
            torch.randn(row.tolist(), device=device, dtype=dtype) for row in size
        ]
        gt = torch.nested.nested_tensor(gt2).clone()
        nt_sum.backward(gt)
        for t2, g2 in zip(ub2_sum, gt2):
            t2.backward(g2)
        self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2]))

    for ntensors, max_sizes in params:
        test_sum(device, dtype, ntensors, max_sizes, len(max_sizes))

    # Test error inputs
    with self.assertRaisesRegex(
        RuntimeError, "NestedTensor can only be reduced across the last"
    ):
        torch.nested.nested_tensor(
            [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
        ).sum(0, keepdim=True)

    with self.assertRaisesRegex(
        RuntimeError, "NestedTensor only allows reduction of a single"
    ):
        torch.nested.nested_tensor(
            [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]
        ).sum([0, 1], keepdim=True)

    with self.assertRaisesRegex(
        RuntimeError, "NestedTensor always requires keepdim=True for now."
    ):
        torch.nested.nested_tensor(
            [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
        ).sum(-1)
@dtypes(torch.float, torch.float16)
def test_contiguous(self, device, dtype):
    """Chunks of a nested tensor are non-contiguous views until .contiguous() is called.

    Chunking a regular dim of an NT into more than one piece produces views
    whose numel is smaller than the underlying buffer. clone() used to build
    the new NT with a buffer the size of the original, so this guards that
    contiguous() yields a genuinely contiguous result.
    """
    nt_contiguous = torch.nested.nested_tensor(
        [torch.randn(rows, 20, device=device, dtype=dtype) for rows in (2, 4)]
    )
    # Split the regular last dimension (size 20) into 5 chunks.
    for chunk in nt_contiguous.chunk(5, dim=-1):
        self.assertFalse(chunk.is_contiguous())
        self.assertTrue(chunk.contiguous().is_contiguous())
@dtypes(torch.float, torch.float16)
@skipMeta
def test_clone(self, device, dtype):
    """clone() copies values into independent storage and validates memory_format."""
    nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1))
    nt2 = nt1.clone()
    # Verify the values match
    self.assertEqual(nt1, nt2)
    # Verify modifying nt2 doesn't affect nt1: if storage were shared, both
    # would hold the squared values and compare equal.
    nt2.mul_(nt1)
    for src, copied in zip(nt1.unbind(), nt2.unbind()):
        self.assertNotEqual(src, copied)
    # preserve_format is accepted and must still produce an equal copy.
    self.assertEqual(nt1.clone(memory_format=torch.preserve_format), nt1)
    msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
    with self.assertRaisesRegex(RuntimeError, msg):
        nt1.clone(memory_format=torch.channels_last)
# cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
@decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged)
@dtypes(torch.float, torch.double)
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
def test_dropout(self, device, dtype, layout):
    """Dropout on nested tensors: edge probabilities, scaling, and RNG reproducibility."""
    # edge case: empty nested tensor
    # TODO: support empty NT in jagged layout
    if layout == torch.strided:
        nt0 = torch.nested.nested_tensor([], layout=layout)
        y = torch.nn.functional.dropout(nt0, 0.5)
        self.assertEqual(nt0, y)
    # normal nested tensor
    ntensors = 4
    if layout == torch.jagged:
        nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout)
    else:
        nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout)
    # edge case: invalid dropout
    self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
    self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
    self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
    self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
    # edge case: no dropout
    dropouter = torch.nn.Dropout(0.0)
    y0 = dropouter(nt)
    y1 = torch.nn.functional.dropout(nt, 0.0)
    self.assertEqual(nt, y0)
    self.assertEqual(nt, y1)
    # edge case: all dropout
    dropouter = torch.nn.Dropout(1.0)
    y0 = dropouter(nt)
    y1 = torch.nn.functional.dropout(nt, 1.0)
    nt0 = torch.zeros_like(nt)
    self.assertEqual(nt0, y0)
    self.assertEqual(nt0, y1)
    # normal case: dropped entries are 0, kept ones are scaled by 1 / (1 - p)
    p = 0.2
    y = torch.nn.functional.dropout(nt, p)
    if layout == torch.jagged:
        expect = torch.where(y == 0.0, y, nt)
        expect /= 1.0 - p
        self.assertEqual(y, expect)
    else:
        expect = nt.clone()
        for i in range(ntensors):
            actual_tensor = y[i].view(-1)
            # expect[i] is a view into expect's buffer, so the in-place ops
            # below update expect itself. Scaling everything and then zeroing
            # the dropped positions matches the original per-element rule.
            expect_tensor = expect[i].view(-1)
            expect_tensor.div_(1.0 - p)
            expect_tensor.masked_fill_(actual_tensor == 0.0, 0.0)
        self.assertEqual(y, expect)
    # Module and functional forms must agree under the same RNG state.
    with freeze_rng_state():
        dropouter = torch.nn.Dropout(p)
        y0 = dropouter(nt)
    with freeze_rng_state():
        y1 = torch.nn.functional.dropout(nt, p)
    self.assertEqual(y0, y1)
@dtypes(torch.float, torch.double)
def test_dropout_noncontiguous(self, device, dtype):
    """Dropout on a transposed NT matches dropout on the original under the same RNG state."""
    ntensors = 4
    nt0 = random_nt(device, dtype, ntensors, (4, 4))
    nt1 = nt0.transpose(-1, -2)
    p = 0.3
    with freeze_rng_state():
        y_contiguous = torch.nn.Dropout(p)(nt0)
    with freeze_rng_state():
        # Transpose back so the result lines up with y_contiguous.
        y_transposed_back = torch.nn.functional.dropout(nt1, p).transpose(-1, -2)
    self.assertEqual(y_contiguous, y_transposed_back)
# cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_softmax(self, device, dtype):
    """Softmax on nested tensors: rejected dims, padding equivalence, edge cases."""
    # normal nested tensor
    ntensors = 4
    nt = random_nt(device, dtype, ntensors, (4, 4))

    # error case: softmax across the nested dimension, given as 0 or -3
    for nested_dim in (0, -3):
        with self.assertRaisesRegex(
            RuntimeError, "Cannot apply softmax across nested dimension 0"
        ):
            torch.nn.functional.softmax(nt, nested_dim)

    # error case: dimension out of range
    for bad_dim in (3, -4):
        with self.assertRaises(IndexError):
            torch.nn.functional.softmax(nt, bad_dim)

    # normal case: should equal softmax over the -inf-padded dense tensor
    y0 = torch.nn.Softmax(1)(nt)
    y1 = torch.nn.functional.softmax(nt, 1)
    self.assertEqual(y0, y1)
    pt = torch.nested.to_padded_tensor(nt, float("-inf"))
    # A fully padded slice gives 0.0 / 0.0 = nan; physically that should be 0.0.
    expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
    self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect)

    # edge case: empty nested tensor
    nt0 = torch.nested.nested_tensor([])
    self.assertEqual(nt0, torch.nn.functional.softmax(nt0, 1))

    # edge case: nesting scalars
    nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
    with self.assertRaises(RuntimeError):
        torch.nn.functional.softmax(nt1, 0)
    with self.assertRaises(IndexError):
        torch.nn.functional.softmax(nt1, 1)
@dtypes(torch.float, torch.double)
@torch.inference_mode()
def test_softmax_noncontiguous(self, device, dtype):
    """Softmax over the last dim agrees for contiguous and noncontiguous NTs."""
    nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
        (2, 3, 6, 7), device, dtype
    )
    expected = torch.nn.functional.softmax(nt_contiguous, -1)
    actual = torch.nn.functional.softmax(nt_noncontiguous, -1)
    self.assertEqual(expected, actual)
def _test_bmm(self, device, dtype):
# error case: not 3D tensors
nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
nt1 = torch.nested.nested_tensor(
[torch.randn(2), torch.randn(3)], device=device, dtype=dtype
)
nt2 = torch.nested.nested_tensor(
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
)
self.assertRaisesRegex(
RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0)
)
self.assertRaisesRegex(
RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1)
)
self.assertRaisesRegex(
RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2)
)
self.assertRaisesRegex(
RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0)
)
self.assertRaisesRegex(
RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1)
)
self.assertRaisesRegex(
RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2)
)
self.assertRaisesRegex(
RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0)
)
self.assertRaisesRegex(
RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1)
)
# error case: incompatible batch size
nt0 = torch.nested.nested_tensor(
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
)
nt1 = torch.nested.nested_tensor(
[torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
device=device,
dtype=dtype,
)
self.assertRaisesRegex(
RuntimeError,
"Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
lambda: nt0.bmm(nt1),
)
self.assertRaisesRegex(
RuntimeError,
"Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
lambda: nt1.bmm(nt0),
)
# error case: underlying matrices cannot be multiplied
nt0 = torch.nested.nested_tensor(
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
)
self.assertRaisesRegex(
RuntimeError,
r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
lambda: nt0.bmm(nt0),
)
# normal nested tensor
nt0 = torch.nested.nested_tensor(
[torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
)
nt1 = torch.nested.nested_tensor(
[torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
torch.nested.to_padded_tensor(nt1, 0.0)
)
if dtype == torch.float16:
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
else:
self.assertEqual(actual, expect)
# nested tensor bmm normal tensor
nt0 = torch.nested.nested_tensor(
[torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype
)
nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
if dtype == torch.float16:
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
else:
self.assertEqual(actual, expect)
# nested tensor bmm normal tensor with non-contiguous view
nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device)
nt1 = nt1.transpose(1, 2)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
if dtype == torch.float16:
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
else:
self.assertEqual(actual, expect)
# normal tensor bmm nested tensor
nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device)
nt1 = torch.nested.nested_tensor(
[torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype
)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0))
if dtype == torch.float16:
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
else:
self.assertEqual(actual, expect)
# test tensorcore path
nt0 = torch.nested.nested_tensor(
[torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype
)
nt1 = torch.nested.nested_tensor(
[torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype
)
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
torch.nested.to_padded_tensor(nt1, 0.0)
)
if dtype == torch.float16:
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
else:
self.assertEqual(actual, expect)
@onlyCUDA
@dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
@tf32_on_and_off(0.005)
def test_bmm_cuda(self, device, dtype):
    """Run the shared nested bmm checks on CUDA, including reduced-precision dtypes."""
    # Every comparison lives in the shared _test_bmm helper.
    self._test_bmm(device, dtype)
@onlyCPU
# float16 is not covered: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_bmm_cpu(self, device, dtype):
    """Run the shared nested bmm checks on CPU for full-precision dtypes."""
    # Every comparison lives in the shared _test_bmm helper.
    self._test_bmm(device, dtype)
# float16 is not covered: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_bmm_noncontiguous(self, device, dtype):
    """bmm on non-contiguous nested operands must match bmm on their contiguous twins."""
    lhs_contig, lhs_noncontig = random_nt_noncontiguous_pair((2, 3), device, dtype)
    rhs_contig, rhs_noncontig = random_nt_noncontiguous_pair((6, 7), device, dtype)
    expected = lhs_contig.transpose(-1, -2).bmm(rhs_contig)
    actual = lhs_noncontig.transpose(-1, -2).bmm(rhs_noncontig)
    self.assertEqual(expected, actual)
@dtypes(torch.float, torch.double)
@tf32_on_and_off(0.005)
def test_matmul_with_bmm_path(self, device, dtype):
    """Compare 4D nested matmul (the bmm path) against per-component matmul.

    Operands are laid out as [N, n_head, *, head_dim] @ [N, n_head, head_dim, *]
    for every batch size N in ``Ns``; a final case uses non-contiguous
    (transposed) operands.
    """

    def unbind_rebind_matmul(nt1, nt2):
        # Reference result: matmul each pair of components, then re-nest.
        out_ts = [t1.matmul(t2) for t1, t2 in zip(nt1.unbind(), nt2.unbind())]
        return torch.nested.nested_tensor(out_ts)

    Ns = [1, 2, 5]
    n_heads = np.random.randint(2, 5)
    head_dim = 3
    for N in Ns:
        # Start from empty lists for each N so the batch holds exactly N
        # components (accumulating across iterations would test 1, 3, 8).
        t1s = []
        t2s = []
        for _ in range(N):
            seq_len1 = np.random.randint(2, 5)
            seq_len2 = np.random.randint(2, 5)
            t1s.append(torch.randn(n_heads, seq_len1, head_dim))
            t2s.append(torch.randn(n_heads, head_dim, seq_len2))
        nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
        nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
        self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))

    # Non-contiguous operands: components built as [*, n_head, head_dim] and
    # transposed into [n_head, *, head_dim] and [n_head, head_dim, *].
    N = Ns[-1]
    t3s = []
    t4s = []
    for _ in range(N):
        seq_len = np.random.randint(2, 5)
        t3s.append(torch.randn(seq_len, n_heads, head_dim))
        t4s.append(torch.randn(seq_len, n_heads, head_dim))
    nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2)
    nt4 = (
        torch.nested.nested_tensor(t4s, device=device, dtype=dtype)
        .transpose(1, 2)
        .transpose(2, 3)
    )
    self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_matmul(self, device, dtype):
    """Check nested matmul: each validation error, then 3D/4D/5D results vs. padded matmul."""

    def nested(*shapes):
        # Nested tensor on the test device/dtype with one random component per shape.
        return torch.nested.nested_tensor(
            [torch.randn(shape) for shape in shapes], device=device, dtype=dtype
        )

    def check_against_padded(nt_a, nt_b):
        # Nested matmul must agree with matmul of the zero-padded operands.
        actual = torch.nested.to_padded_tensor(torch.matmul(nt_a, nt_b), 0.0)
        expect = torch.matmul(
            torch.nested.to_padded_tensor(nt_a, 0.0),
            torch.nested.to_padded_tensor(nt_b, 0.0),
        )
        self.assertEqual(actual, expect)

    # error case: one is nested but the other is not
    nt = nested(2, 3)
    t = torch.randn(4, device=device, dtype=dtype)
    with self.assertRaisesRegex(
        RuntimeError,
        "Expected both to be nested, but got a nested self and non-nested other",
    ):
        torch.matmul(nt, t)
    with self.assertRaisesRegex(
        RuntimeError,
        "Expected both to be nested, but got a non-nested self and nested other",
    ):
        torch.matmul(t, nt)

    # error case: not 3+D tensors
    nt0 = nested()
    nt1 = nested(2, 3)
    nt2 = nested((2, 4), (3, 4))
    rank_prefix = (
        r"matmul: For nested tensors, only inputs with >= 3 dims are currently "
        r"supported. "
    )
    first_input_too_small = [
        (nt0, nt0),
        (nt0, nt1),
        (nt0, nt2),
        (nt1, nt0),
        (nt1, nt1),
        (nt1, nt2),
    ]
    for lhs, rhs in first_input_too_small:
        with self.assertRaisesRegex(
            RuntimeError, rank_prefix + r"1st input has rank: [0-9]+"
        ):
            torch.matmul(lhs, rhs)
    for lhs, rhs in [(nt2, nt0), (nt2, nt1)]:
        with self.assertRaisesRegex(
            RuntimeError, rank_prefix + r"2nd input has rank: [0-9]+"
        ):
            torch.matmul(lhs, rhs)

    # error case: incompatible batch size
    nt0 = nested((2, 4), (3, 4))
    nt1 = nested((4, 6), (4, 5), (4, 7))
    batch_size_msg = (
        r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: "
        r"[0-9]+ but got: [0-9]+."
    )
    for lhs, rhs in [(nt0, nt1), (nt1, nt0)]:
        with self.assertRaisesRegex(RuntimeError, batch_size_msg):
            torch.matmul(lhs, rhs)

    # The error text starts with "matmul: " (no parentheses); an unescaped "()"
    # in a pattern is an empty regex group, not literal text, so spell it out.
    batch_dims_msg = (
        r"matmul: For nested tensors, batch dimensions must have the same sizes,"
    )
    # error case: incompatible (wrong) batch sizes that shouldn't even broadcast
    nt0 = nested((2, 2, 4), (2, 3, 4))
    nt1 = nested((3, 4, 6), (3, 4, 5))
    with self.assertRaisesRegex(RuntimeError, batch_dims_msg):
        torch.matmul(nt0, nt1)
    # error case: incompatible batch sizes that should technically broadcast
    nt0 = nested((2, 2, 4), (1, 3, 4))
    nt1 = nested((1, 4, 6), (3, 4, 5))
    with self.assertRaisesRegex(RuntimeError, batch_dims_msg):
        torch.matmul(nt0, nt1)

    # error case: underlying matrices cannot be multiplied
    nt0 = nested((2, 4), (3, 4))
    with self.assertRaisesRegex(
        RuntimeError, r"matmul: Nested tensors cannot be matrix multiplied"
    ):
        torch.matmul(nt0, nt0)

    # normal nested tensor: 3D
    check_against_padded(nested((2, 4), (3, 7)), nested((4, 6), (7, 5)))
    # normal nested tensor: 4D (with testing for batch_size=1)
    check_against_padded(nested((1, 2, 4), (8, 3, 7)), nested((1, 4, 6), (8, 7, 5)))
    # normal nested tensor: 5D
    check_against_padded(
        nested((8, 9, 2, 4), (8, 9, 3, 7)),
        nested((8, 9, 4, 6), (8, 9, 7, 5)),
    )
# NT (B, *, C, D) @ dense T (D, E): the dense operand is broadcast over every
# component. Runs on every instantiated device type (there is no @onlyCUDA).
@dtypes(torch.float, torch.double)
def test_matmul_nt_with_broadcasted_t(self, device, dtype):
    """matmul(nested, dense) must equal matmul-ing each component with the dense tensor."""
    nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype)
    t = torch.randn(5, 6, device=device, dtype=dtype)
    output = torch.matmul(nt, t)
    self.assertTrue(output.is_nested)
    self.assertEqual(nt.size(0), output.size(0))
    for component, out_component in zip(nt, output):
        self.assertEqual(out_component, torch.matmul(component, t))
# float16 is not covered: RuntimeError: "bmm" not implemented for 'Half'
@dtypes(torch.float, torch.double)
def test_matmul_noncontiguous(self, device, dtype):
    """matmul on non-contiguous nested operands must match matmul on their contiguous twins."""
    lhs_contig, lhs_noncontig = random_nt_noncontiguous_pair((2, 3), device, dtype)
    rhs_contig, rhs_noncontig = random_nt_noncontiguous_pair((6, 7), device, dtype)
    expected = torch.matmul(lhs_contig.transpose(-1, -2), rhs_contig)
    actual = torch.matmul(lhs_noncontig.transpose(-1, -2), rhs_noncontig)
    self.assertEqual(expected, actual)
@dtypes(torch.float, torch.double)
def test_linear(self, device, dtype):
    """Check F.linear on a nested input: one success case plus each validation error."""
    a = torch.randn(1, 2, device=device, dtype=dtype)
    b = torch.randn(2, 2, device=device, dtype=dtype)
    c = torch.randn(3, 2, device=device, dtype=dtype)
    nt = torch.nested.nested_tensor([a, b, c])
    weight = torch.randn(2, 2, device=device, dtype=dtype)
    bias = torch.randn(2, device=device, dtype=dtype)

    # success case (F is torch.nn.functional; torch.functional.F was only an
    # incidental alias of the same module)
    F.linear(nt, weight, bias)

    # invalid nested tensor dimension: 1-D components give a 2-D nested tensor
    msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2"
    nt1 = torch.nested.nested_tensor(
        [
            torch.randn(1, device=device, dtype=dtype),
            torch.randn(2, device=device, dtype=dtype),
        ]
    )
    with self.assertRaisesRegex(RuntimeError, msg):
        F.linear(nt1, weight, bias)

    # invalid weight shape (3-D weight)
    msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3"
    weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
    with self.assertRaisesRegex(RuntimeError, msg):
        F.linear(nt, weight1, bias)

    # inconsistent last dim of nested tensor
    msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:"
    nt2 = torch.nested.nested_tensor(
        [
            torch.randn(1, 2, device=device, dtype=dtype),
            torch.randn(2, 3, device=device, dtype=dtype),
        ]
    )
    with self.assertRaisesRegex(RuntimeError, msg):
        F.linear(nt2, weight, bias)

    # Mismatch of nested tensor last dim and weight dimension
    weight2 = torch.randn(2, 4, device=device, dtype=dtype)
    msg = (
        r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'"
        r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
    )
    with self.assertRaisesRegex(RuntimeError, msg):
        F.linear(nt, weight2, bias)

    # Nested tensor input and nested weight
    nt_weight = nt.clone()
    msg = r"Linear does not support nested weight when input is a nested tensor."
    with self.assertRaisesRegex(RuntimeError, msg):
        F.linear(nt, nt_weight, bias)
# TODO: test noncontiguous linear
# Until linear supports non-contiguous buffers, this only checks that a
# non-contiguous nested input is rejected with the expected message.
@dtypes(torch.float, torch.double)
def test_linear_noncontiguous(self, device, dtype):
    """linear must reject a non-contiguous nested input."""
    _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
    weight = torch.randn((8, 5), device=device, dtype=dtype)
    with self.assertRaisesRegex(
        RuntimeError, r"for now linear only supports contiguous nested tensor"
    ):
        torch.nn.functional.linear(nt_noncontiguous, weight)
@dtypes(torch.float, torch.float16, torch.double)
def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
    """to_padded_tensor must reject a nested tensor whose components all have zero numel."""
    empty_components = [torch.ones(1, 0), torch.ones(0, 0)]
    nt = torch.nested.nested_tensor(
        empty_components, device=device, dtype=dtype, layout=torch.strided
    )
    with self.assertRaisesRegex(
        RuntimeError, r"at least one constituent tensor should have non-zero numel"
    ):
        torch.nested.to_padded_tensor(nt, 0.0)
@dtypes(torch.float, torch.float16, torch.double)
def test_transpose(self, device, dtype):
    """transpose() on nested tensors: batch dim is fixed, range checks, padded equivalence."""
    nt = random_nt(device, dtype, 4, (4, 4))
    # The nested (batch) dimension, addressed positively or negatively, can't move.
    for dims in ((0, 1), (1, -3)):
        with self.assertRaisesRegex(
            RuntimeError, "Nested tensor dimension 0 cannot be transposed"
        ):
            nt.transpose(*dims)
    # Out-of-range dimensions raise IndexError.
    for dims in ((1, 3), (-4, -1)):
        with self.assertRaises(IndexError):
            nt.transpose(*dims)
    # Transposing the last two dims must match transposing the padded tensor.
    actual = noncontiguous_to_padded_tensor(nt.transpose(-1, -2))
    expected = torch.nested.to_padded_tensor(nt, 0.0).transpose(-1, -2)
    self.assertEqual(expected, actual)
@dtypes(torch.float, torch.float16, torch.double)
def test_squeeze_unsqueeze(self, device, dtype):
    """Check squeeze/unsqueeze error handling and nested size/stride bookkeeping."""
    nt = torch.nested.nested_tensor(
        [torch.arange(6).reshape(2, 3), torch.arange(15).reshape(5, 3)],
        device=device,
        dtype=dtype,
    )
    # squeeze() without a dim is rejected
    with self.assertRaisesRegex(
        RuntimeError, "For nested tensors, squeeze without the dim argument"
    ):
        nt.squeeze()
    # the nested (batch) dimension cannot be squeezed
    with self.assertRaisesRegex(
        RuntimeError, "For nested tensors, squeezing dimension 0"
    ):
        nt.squeeze(0)
    # out-of-range dim
    with self.assertRaises(IndexError):
        nt.squeeze(3)
    # squeezing a nested tensor made of singleton components
    singleton = torch.ones(1)
    nt_singleton = torch.nested.nested_tensor(
        [singleton, singleton], device=device, dtype=dtype
    )
    with self.assertRaisesRegex(
        RuntimeError, "For nested tensors, squeezing a nested tensor of singleton"
    ):
        nt_singleton.squeeze(1)
    # squeezing a dim whose size is not 1 is a no-op
    self.assertEqual(nt, nt.squeeze(-1))

    sizes = nt._nested_tensor_size()
    strides = nt._nested_tensor_strides()
    # dim 0 is the batch dim and cannot be unsqueezed
    for dim in (d for d in range(-2, 4) if d != 0):
        unsqueezed = nt.unsqueeze(dim)
        # A negative dim inserts at dim + nt.dim() + 1. The size/stride
        # matrices have no column for the batch dim, hence the extra -1.
        col = (dim + nt.dim() + 1 if dim < 0 else dim) - 1
        self.assertEqual(
            unsqueezed._nested_tensor_size()[:, col],
            torch.ones(2, dtype=torch.long),
        )
        new_stride = unsqueezed._nested_tensor_strides()[:, col]
        if dim in (nt.ndim, -1):
            # inserted as the innermost dimension
            expected_stride = torch.ones(2, dtype=torch.long)
        else:
            expected_stride = strides[:, col] * sizes[:, col]
        self.assertEqual(new_stride, expected_stride)
        squeezed = unsqueezed.squeeze(dim)
        self.assertEqual(squeezed, nt)
        self.assertEqual(squeezed._nested_tensor_size(), sizes)
        self.assertEqual(squeezed._nested_tensor_strides(), strides)
@dtypes(torch.float, torch.float16, torch.double)
def test_transpose_inference_mode_interaction(self, device, dtype):
    """transpose() must work under inference mode, for inputs built inside or outside it."""

    def check_transpose(nested):
        from_nested = noncontiguous_to_padded_tensor(nested.transpose(-1, -2))
        from_padded = torch.nested.to_padded_tensor(nested, 0.0).transpose(-1, -2)
        self.assertEqual(from_padded, from_nested)

    nt = random_nt(device, dtype, 4, (4, 4))
    # Construct in default mode and transpose while in inference mode
    with torch.inference_mode():
        check_transpose(nt)
    # Construct and transpose while in inference mode
    with torch.inference_mode():
        check_transpose(random_nt(device, dtype, 4, (4, 4)))
@dtypes(torch.float, torch.float16, torch.double)
def test_view(self, device, dtype):
    """Check nested view(): error handling and inheriting the ragged dimension via -1."""
    nt = random_nt(device, dtype, 4, (4, 4))
    # error case: empty shape
    self.assertRaisesRegex(
        RuntimeError,
        r"shape '\[\]' is invalid for a nested tensor",
        lambda: nt.view(()),
    )
    # error case: empty nested tensor
    nt_empty = torch.nested.nested_tensor([])
    self.assertRaisesRegex(
        RuntimeError,
        "empty nested tensor cannot be reshaped",
        lambda: nt_empty.view(-1),
    )
    # error case: -1 for batch size
    self.assertRaisesRegex(
        RuntimeError,
        r"view: For now nested view cannot change or infer the implicit batch dimension",
        lambda: nt.view(-1, 2, 3),
    )
    self.assertRaisesRegex(
        RuntimeError,
        r"shape '\[.*\]' is invalid for input of size [0-9]+",
        lambda: nt.view(4, 2, 3),
    )
    # normal case: nt has shape (2, (2, 3), 20)
    x0 = torch.randn((2, 20), device=device, dtype=dtype)
    x1 = torch.randn((3, 20), device=device, dtype=dtype)
    nt = torch.nested.nested_tensor([x0, x1])
    pt = torch.nested.to_padded_tensor(nt, 0.0)
    # error case: trying to reshape the batch dim, even to a legit total size
    self.assertRaisesRegex(
        RuntimeError,
        r"For now nested view cannot change or infer the implicit batch dimension",
        lambda: nt.transpose(-1, -2).view(40, -1),
    )
    # inherit only the ragged dimension:
    #   (2, 20) -> (2, 5, 4)
    #   (3, 20) -> (3, 5, 4)
    nt1 = nt.view(2, -1, 5, 4)
    # padded reference: (2, 3, 20) -> (2, 3, 5, 4)
    pt1 = pt.view(2, -1, 5, 4)
    self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
    # more than one -1 (even for "old" dims) should fail: this attempts
    # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2), but inheriting old sizes
    # is only allowed for a single dimension
    self.assertRaisesRegex(
        RuntimeError,
        r"only one dimension can be inferred",
        lambda: nt1.view(2, -1, -1, 2, 2),
    )
@dtypes(torch.float, torch.float16, torch.double)
def test_view_inference_mode_interaction(self, device, dtype):
    """view() must work under inference mode, for inputs built inside or outside it."""

    def make_nt():
        return torch.nested.nested_tensor(
            [torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
        )

    def check_view(nested):
        from_nested = noncontiguous_to_padded_tensor(nested.view(2, -1, 4, 5))
        from_padded = torch.nested.to_padded_tensor(nested, 0.0).view(2, -1, 4, 5)
        self.assertEqual(from_padded, from_nested)

    # Construct in default mode and view while in inference mode
    nt = make_nt()
    with torch.inference_mode():
        check_view(nt)
    # Construct and view while in inference mode
    with torch.inference_mode():
        check_view(make_nt())
@dtypes(torch.float, torch.float16, torch.double)
def test_reshape(self, device, dtype):
    """Check nested reshape(): error handling and inheriting the ragged dimension via -1."""
    nt = random_nt(device, dtype, 4, (4, 4))
    # error case: empty shape
    self.assertRaisesRegex(
        RuntimeError,
        r"shape '\[\]' is invalid for a nested tensor",
        lambda: nt.reshape(()),
    )
    # error case: empty nested tensor
    nt_empty = torch.nested.nested_tensor([])
    self.assertRaisesRegex(
        RuntimeError,
        "empty nested tensor cannot be reshaped",
        lambda: nt_empty.reshape(-1),
    )
    # error case: -1 for batch size
    self.assertRaisesRegex(
        RuntimeError,
        r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
        lambda: nt.reshape(-1, 2, 3),
    )
    self.assertRaisesRegex(
        RuntimeError,
        r"shape '\[.*\]' is invalid for input of size [0-9]+",
        lambda: nt.reshape(4, 2, 3),
    )
    # normal case
    x0 = torch.randn((2, 20), device=device, dtype=dtype)
    x1 = torch.randn((3, 20), device=device, dtype=dtype)
    nt = torch.nested.nested_tensor([x0, x1])  # (2, (2, 3), 20)
    pt = torch.nested.to_padded_tensor(nt, 0.0)
    # error case: trying to reshape the batch dim, even to a legit total size
    self.assertRaisesRegex(
        RuntimeError,
        r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
        lambda: nt.transpose(-1, -2).reshape(40, -1),
    )
    # inherit only the ragged dimension:
    #   (2, 20) -> (2, 5, 4)
    #   (3, 20) -> (3, 5, 4)
    nt1 = nt.reshape(2, -1, 5, 4)
    # padded reference: (2, 3, 20) -> (2, 3, 5, 4)
    pt1 = pt.reshape(2, -1, 5, 4)
    self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
    # more than one -1 (even for "old" dims) should fail: this attempts
    # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2), but inheriting old sizes
    # is only allowed for a single dimension
    self.assertRaisesRegex(
        RuntimeError,
        r"only one dimension can be inferred",
        lambda: nt1.reshape(2, -1, -1, 2, 2),
    )
def test_nested_masked_select(self, device):
    """torch.nested.masked_select keeps selected elements and builds per-row offsets.

    The mask is broadcast against ``t``; the last dimension of ``t`` is the
    ragged one, so offsets hold one entry per leading-index row plus a leading 0.
    """
    t = torch.randn([3, 3], device=device)

    # Nothing selected: empty values, all-zero offsets.
    mask = torch.tensor([False], device=device)
    njt = torch.nested.masked_select(t, mask)
    self.assertEqual(njt.values(), torch.tensor([], device=device))
    self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device))

    # Only the last row selected. The values are exact copies of t's elements,
    # so compare exactly (like the other cases) instead of with a loose 0.1 tolerance.
    mask = torch.tensor([[False], [False], [True]], device=device)
    njt = torch.nested.masked_select(t, mask)
    self.assertEqual(njt.values(), t[-1])
    self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device))

    # Ragged selection per row.
    mask = torch.tensor(
        [[False, False, True], [True, False, True], [False, False, True]],
        device=device,
    )
    njt = torch.nested.masked_select(t, mask)
    self.assertEqual(njt.values(), t.masked_select(mask))
    self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device))

    # Higher-rank input: rows are the flattened leading indices.
    t = torch.randn([2, 3, 3, 1], device=device)
    mask = torch.tensor(
        [
            [
                [[True], [False], [True]],
                [[True], [False], [True]],
                [[True], [False], [True]],
            ],
            [
                [[False], [True], [True]],
                [[False], [True], [True]],
                [[True], [True], [True]],
            ],
        ],
        device=device,
    )
    njt = torch.nested.masked_select(t, mask)
    self.assertEqual(njt.values(), t.masked_select(mask))
    self.assertEqual(
        njt.offsets(),
        torch.tensor(
            [0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13],
            device=device,
        ),
    )
@dtypes(torch.float, torch.float16, torch.double)
def test_narrow(self, device, dtype):
    """narrow() on nested tensors: dim=0 returns views; other dims and non-contiguous inputs error."""
    nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype)
    # narrow on dim=0 from start to end
    bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)]
    for start, end in bounds:
        narrowed = nt.narrow(dim=0, start=start, length=end - start)
        # ensure output is a view of nt
        self.assertIs(narrowed._base, nt)
        for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]):
            self.assertEqual(nc, c)
    # dim != 0 is not supported
    for dim in range(1, nt.dim()):
        with self.assertRaisesRegex(
            RuntimeError, "only dim=0 supported for nested tensors"
        ):
            nt.narrow(dim=dim, start=0, length=1)
    # error case: non-contiguous NT, built on the test's device/dtype
    _, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4), device, dtype)
    with self.assertRaisesRegex(
        RuntimeError, "only contiguous nested tensors supported"
    ):
        nt_noncont.narrow(dim=0, start=0, length=1)
@parametrize("input_dim", [3, 4])
@tf32_on_and_off(0.005)
def test_scaled_dot_product_attention(self, device, input_dim):
    """Nested SDPA (no mask, non-causal) must match per-sample dense SDPA.

    Supplying an explicit attn_mask or is_causal=True must raise.
    """

    def rand_tensor(*shape):
        return torch.randn(shape, device=device)

    def rand_mask(size):
        return torch.randint(0, 2, size=size, dtype=torch.bool, device=device)

    E = 8
    if input_dim == 3:
        # query (N, L, E) with ragged L; key/value (N, S, E) with ragged S
        query_shapes = [(2, E), (3, E), (4, E)]
        kv_shapes = [(3, E), (4, E), (5, E)]
    elif input_dim == 4:
        # query (N, N', L, E); key/value (N, N', S, E); N', L and S are ragged
        query_shapes = [(2, 2, E), (3, 3, E), (4, 4, E)]
        kv_shapes = [(2, 3, E), (3, 4, E), (4, 5, E)]
    else:
        self.fail(f"Invalid input_dim {input_dim} encountered in SDP test")

    query = torch.nested.nested_tensor([rand_tensor(*s) for s in query_shapes])
    key = torch.nested.nested_tensor([rand_tensor(*s) for s in kv_shapes])
    value = torch.nested.nested_tensor([rand_tensor(*s) for s in kv_shapes])

    # (N, L, S) boolean mask whose ragged L and S match the 3-D shapes above
    attn_mask = torch.nested.nested_tensor(
        [rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]
    )
    dropout_p = 0.0  # no dropout for reproducibility
    sdpa = torch.nn.functional.scaled_dot_product_attention

    # Success case: no attn_mask set and is_causal=False.
    actual = sdpa(
        query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p
    )
    expected = torch.nested.nested_tensor(
        [
            sdpa(
                q.unsqueeze(0),
                k.unsqueeze(0),
                v.unsqueeze(0),
                attn_mask=None,
                dropout_p=dropout_p,
            ).squeeze(0)
            for q, k, v in zip(query.unbind(), key.unbind(), value.unbind())
        ]
    )
    self.assertEqual(actual, expected)

    # Error case: explicit attn_mask set.
    with self.assertRaisesRegex(
        RuntimeError, "not supported when an explicit attn_mask is set"
    ):
        sdpa(query, key, value, attn_mask=attn_mask, dropout_p=dropout_p)

    # Error case: is_causal=True.
    with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
        sdpa(query, key, value, dropout_p=dropout_p, is_causal=True)
@dtypes(torch.float, torch.float16, torch.double)
def test_empty_like(self, device, dtype):
    """empty_like on nested tensors: metadata propagation, overrides, and memory formats."""
    ntensors = 4
    nt = random_nt(device, dtype, ntensors, (4, 4))

    # By default the result matches the source's sizes, dtype, device and layout.
    nt_empty = torch.empty_like(nt)
    self.assertTrue(nt.is_same_size(nt_empty))
    self.assertEqual(nt.dtype, nt_empty.dtype)
    self.assertEqual(nt.device, nt_empty.device)
    self.assertEqual(nt.layout, nt_empty.layout)

    # Device override (only testable when CUDA is available).
    if torch.cuda.is_available():
        if device == "cpu":
            nt_cuda = torch.empty_like(nt, device="cuda")
            self.assertEqual(torch.device("cuda").type, nt_cuda.device.type)
        else:
            nt_cpu = torch.empty_like(nt, device="cpu")
            self.assertEqual(torch.device("cpu").type, nt_cpu.device.type)

    # dtype override: the new tensor gets the requested dtype but keeps the
    # source's device and layout; the source itself is untouched.
    dtype_set = {torch.float, torch.float16, torch.double}
    for other_dtype in dtype_set - {dtype}:
        nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype)
        self.assertEqual(nt.dtype, dtype)
        self.assertEqual(nt_empty_other_dtype.dtype, other_dtype)
        self.assertEqual(nt.device, nt_empty_other_dtype.device)
        self.assertEqual(nt.layout, nt_empty_other_dtype.layout)

    # Create tensor for autograd
    nt_empty_req_grad = torch.empty_like(nt, requires_grad=True)
    self.assertTrue(nt_empty_req_grad.requires_grad)

    # Non-contiguous inputs must not fail to copy (built on the test's device/dtype).
    nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
    self.assertTrue(nt_cont.is_same_size(torch.empty_like(nt_cont)))
    self.assertTrue(nt_noncont.is_same_size(torch.empty_like(nt_noncont)))

    # torch.contiguous_format is accepted and yields contiguous results.
    for src in (nt_cont, nt_noncont):
        out = torch.empty_like(src, memory_format=torch.contiguous_format)
        self.assertTrue(src.is_same_size(out))
        self.assertTrue(out.is_contiguous())

    # Other memory formats are rejected.
    for src in (nt_cont, nt_noncont):
        for memory_format in (torch.channels_last, torch.channels_last_3d):
            with self.assertRaises(RuntimeError):
                torch.empty_like(src, memory_format=memory_format)
@markDynamoStrictTest
class TestNestedTensorAutograd(NestedTensorTestCase):
# Note [Gradcheck args check_batched_grad=False]: the common_utils testing version of gradcheck
# includes the default parameters used for testing ops with gradcheck. However, nested tensors
# do not support the stack op, so we turn check_batched_grad off for these tests.
def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
return torch.nested.nested_tensor(
[torch.randn(1, 2), torch.randn(7, 8)],
requires_grad=requires_grad,
device=tensor_device,
)
def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
return torch.nested.as_nested_tensor(
[
torch.randn(1, 2, requires_grad=requires_grad),
torch.randn(7, 8, requires_grad=requires_grad),
],
device=tensor_device,
)
def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
mask = torch.ones_like(data[:, :, 0]).bool()
return torch._nested_tensor_from_mask(data, mask)
def test_as_nested_tensor_propagates_gradients(self, device):
    """Gradients flowing into an as_nested_tensor result reach its source components."""
    # Components that don't require grad give a leaf that doesn't require grad.
    plain = torch.nested.as_nested_tensor(
        [
            torch.arange(3, dtype=torch.float, device=device),
            torch.arange(5, dtype=torch.float, device=device),
        ]
    )
    self.assertTrue(plain.is_leaf)
    self.assertFalse(plain.requires_grad)

    a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
    b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
    linked = torch.nested.as_nested_tensor([a, b])
    fake_grad = torch.nested.nested_tensor(
        [torch.ones_like(a), torch.zeros_like(b)], device=device
    )
    linked.backward(fake_grad)
    self.assertEqual(a.grad, fake_grad[0])
    self.assertEqual(b.grad, fake_grad[1])
def test_nested_tensor_generates_leaf(self, device):
    """nested_tensor builds a new leaf; gradients stop there and never reach the inputs."""
    a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
    b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)

    no_grad_nt = torch.nested.nested_tensor([a, b], requires_grad=False)
    self.assertTrue(no_grad_nt.is_leaf)
    self.assertFalse(no_grad_nt.requires_grad)

    grad_nt = torch.nested.nested_tensor([a, b], requires_grad=True)
    self.assertTrue(grad_nt.is_leaf)
    self.assertTrue(grad_nt.requires_grad)

    fake_grad = torch.nested.nested_tensor(
        [torch.ones_like(a), torch.zeros_like(b)], device=device
    )
    grad_nt.backward(fake_grad)
    self.assertEqual(grad_nt.grad, fake_grad)
    # The source tensors are not connected to the new leaf.
    self.assertIsNone(a.grad)
    self.assertIsNone(b.grad)
def test_set_requires_grad_from_list(self, device):
    """requires_grad_() can be set on a nested tensor built from a list."""
    nt = self._create_nested_tensor_from_list(device)
    nt.requires_grad_()
    # unittest assertion: a bare assert is stripped under `python -O`.
    self.assertTrue(nt.requires_grad)
def test_set_requires_grad_from_mask(self, device):
    """requires_grad_() can be set on a nested tensor built from a mask."""
    nt = self._create_nested_tensor_from_mask(device)
    nt.requires_grad_()
    # unittest assertion: a bare assert is stripped under `python -O`.
    self.assertTrue(nt.requires_grad)
def test_backward_for_add_op(self, device):
    """Backward of nt_1 + nt_2 passes grad_output straight through to nt_1."""
    nt_1 = self._create_nested_tensor_from_mask(device)
    nt_2 = self._create_nested_tensor_from_mask(device)
    nt_1.requires_grad_()
    c = nt_1 + nt_2
    # unittest assertions: bare asserts are stripped under `python -O`.
    self.assertTrue(nt_1.requires_grad)
    self.assertTrue(c.requires_grad)
    grad_output = self._create_nested_tensor_from_mask(device)
    c.backward(grad_output)
    # Grad check doesn't work with nested yet.
    # d/dnt_1 (nt_1 + nt_2) = 1 * grad_output
    self.assertEqual(nt_1.grad, grad_output)
def test_backward_for_sub_op(self, device):
    """Backward of nt_1 - nt_2 yields +grad_output for nt_1 and -grad_output for nt_2."""
    nt_1 = self._create_nested_tensor_from_mask(device)
    nt_2 = self._create_nested_tensor_from_mask(device)
    nt_1.requires_grad_()
    nt_2.requires_grad_()
    c = nt_1 - nt_2
    # unittest assertions: bare asserts are stripped under `python -O`.
    self.assertTrue(nt_1.requires_grad)
    self.assertTrue(nt_2.requires_grad)
    self.assertTrue(c.requires_grad)
    grad_output = self._create_nested_tensor_from_mask(device)
    c.backward(grad_output)
    self.assertEqual(nt_1.grad, grad_output)
    self.assertEqual(nt_2.grad, -1 * grad_output)
def test_backward_sub_strided(self, device):
    """Backward of lhs - rhs^T sends +grad to lhs and -grad^T to the transposed rhs."""
    lhs = torch.nested.nested_tensor(
        [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
        requires_grad=True,
        device=device,
    )
    rhs = torch.nested.nested_tensor(
        [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
        requires_grad=True,
        device=device,
    )
    diff = lhs - rhs.transpose(-1, -2)
    upstream = diff.clone()
    diff.backward(upstream)
    self.assertEqual(lhs.grad, upstream)
    self.assertEqual(rhs.grad, -1 * upstream.transpose(-1, -2))
def test_backward_add_strided(self, device):
    """Backward of lhs + rhs^T sends grad to lhs and grad^T to the transposed rhs."""
    lhs = torch.nested.nested_tensor(
        [torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
        requires_grad=True,
        device=device,
    )
    rhs = torch.nested.nested_tensor(
        [torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
        requires_grad=True,
        device=device,
    )
    total = lhs + rhs.transpose(-1, -2)
    upstream = total.clone()
    total.backward(upstream)
    self.assertEqual(lhs.grad, upstream)
    self.assertEqual(rhs.grad, upstream.transpose(-1, -2))
# Test Factory Functions
def test_nested_tensor_to_padded_tensor(self, device):
    """With an all-ones upstream grad, to_padded_tensor backward gives all-ones component grads."""
    expected_grad = torch.nested.nested_tensor(
        [torch.ones(1, 2), torch.ones(7, 8)], device=device
    )
    for padding_val in (0, 1):
        nt = self._create_leaf_nested_tensor_from_list(
            tensor_device=device, requires_grad=True
        )
        padded = torch.nested.to_padded_tensor(nt, padding_val)
        padded.backward(torch.ones(padded.shape, device=device))
        self.assertEqual(nt.grad, expected_grad)
def test_nested_tensor_from_mask_and_to_padded(self, device):
    """gradcheck _nested_tensor_from_mask followed by to_padded_tensor."""
    N, L, D = 2, 4, 4
    # Row 0 is never touched, so it stays fully valid; every later row is
    # truncated at a random position in [1, L - 2].
    mask = torch.ones(N, L, device=device)
    for i in range(1, N):
        end = torch.randint(1, L - 1, (1,), device=device)
        mask[i, end:] = 0
    mask = mask.bool()
    data = torch.randn(
        N, L, D, requires_grad=True, dtype=torch.float64, device=device
    )

    def grad_test_func(inpt):
        nt = torch._nested_tensor_from_mask(inpt, mask)
        # This implicitly tests to_padded_tensor grads
        return torch.nested.to_padded_tensor(nt, 0)

    # A bare `assert gradcheck(...)` is stripped under `python -O`, which would
    # skip the whole gradient check.
    self.assertTrue(
        gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
    )
def test_nested_tensor_from_padded(self, device):
nested_size = torch.tensor([[1, 2], [2, 2]])
padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
padded_tensor[0, 1, :] = 0
padded_tensor.requires_grad_()
def grad_test_func(tensor, nested_size):
nt = torch._nested_from_padded(
tensor, nested_size, fuse_transform_0213=False
)
# This implicitly tests to_padded_tensor grads
return torch.nested.to_padded_tensor(nt, 0)
data = (padded_tensor, nested_size)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_nested_tensor_from_padded_fused(self, device):
nested_size = torch.tensor([[1, 8], [2, 8]])
padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
padded_tensor[0, 1, :] = 0
padded_tensor.requires_grad_()
def grad_test_func(tensor, nested_size):
nt = torch._nested_from_padded(
tensor, nested_size, fuse_transform_0213=True
)
# This implicitly tests to_padded_tensor grads
return torch.nested.to_padded_tensor(nt, 0)
data = (padded_tensor, nested_size)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_nested_tensor_from_list(self, device):
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
c = torch.nested.as_nested_tensor([a, b, c])
# This implictily tests to_padded_tensor grads
return torch.nested.to_padded_tensor(c, 0)
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
def test_dropout_backward(self, layout):
    """Dropout's backward reuses the forward mask and scaling.

    Dropout is linear in its input for a fixed mask, so feeding the input
    itself as the upstream gradient must reproduce the forward output.
    """
    # The jagged layout needs matching trailing dims; strided NTs may differ.
    second_inner_dim = 5 if layout == torch.jagged else 4
    nt = torch.nested.nested_tensor(
        [torch.randn((2, 5)), torch.randn((3, second_inner_dim))],
        requires_grad=True,
        layout=layout,
    )
    dropped = torch.nn.functional.dropout(nt, 0.2)
    dropped.backward(nt.detach().clone())
    self.assertEqual(nt.grad, dropped)
def test_nested_tensor_bmm_gradcheck(self, device):
a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c, d):
nt0 = torch.nested.as_nested_tensor([a, b])
nt1 = torch.nested.as_nested_tensor([c, d])
result = nt0.bmm(nt1)
return torch.nested.to_padded_tensor(result, 0.0)
data = (a, b, c, d)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
@tf32_on_and_off(0.008)
def test_nested_tensor_bmm_backward(self, device):
    """NT.bmm grads match dense bmm grads computed on zero-padded copies."""

    def leaf_nt(*shapes):
        return torch.nested.nested_tensor(
            [torch.randn(shape) for shape in shapes],
            requires_grad=True,
            device=device,
        )

    nt0 = leaf_nt((2, 6), (3, 6))
    nt1 = leaf_nt((6, 4), (6, 5))
    with torch.no_grad():
        pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
        pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
    nested_out = nt0.bmm(nt1)
    dense_out = pt0.bmm(pt1)
    nested_out.backward(nested_out.clone())
    dense_out.backward(dense_out.clone())
    # Zero padding contributes nothing to the products, so the padded nested
    # grads should match the dense grads.
    for nested, dense in ((nt0, pt0), (nt1, pt1)):
        self.assertEqual(torch.nested.to_padded_tensor(nested.grad, 0.0), dense.grad)
def test_nested_tensor_matmul_gradcheck(self, device):
a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c, d):
nt0 = torch.nested.as_nested_tensor([a, b])
nt1 = torch.nested.as_nested_tensor([c, d])
result = torch.matmul(nt0, nt1)
return torch.nested.to_padded_tensor(result, 0.0)
data = (a, b, c, d)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
def test_nested_tensor_matmul_backward(self, device):
    """torch.matmul grads on 3-D NT components match dense grads on padded copies."""

    def leaf_nt(*shapes):
        return torch.nested.nested_tensor(
            [torch.randn(shape) for shape in shapes],
            requires_grad=True,
            device=device,
        )

    nt0 = leaf_nt((7, 2, 6), (7, 3, 6))
    nt1 = leaf_nt((7, 6, 4), (7, 6, 5))
    with torch.no_grad():
        pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
        pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
    nested_out = torch.matmul(nt0, nt1)
    dense_out = torch.matmul(pt0, pt1)
    nested_out.backward(nested_out.clone())
    dense_out.backward(dense_out.clone())
    # Zero padding contributes nothing to the products, so the padded nested
    # grads should match the dense grads.
    for nested, dense in ((nt0, pt0), (nt1, pt1)):
        self.assertEqual(torch.nested.to_padded_tensor(nested.grad, 0.0), dense.grad)
def test_nested_tensor_transpose_gradcheck(self, device):
a = torch.randn(2, 5, requires_grad=True, device=device)
b = torch.randn(3, 4, requires_grad=True, device=device)
def grad_test_func(a, b):
nt = torch.nested.as_nested_tensor([a, b])
result = nt.transpose(-2, -1).transpose(-2, -1)
return torch.nested.to_padded_tensor(result, 0.0)
data = (a, b)
assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
def test_nested_tensor_transpose_backward(self, device):
    """NT.transpose grads match dense transpose grads on a zero-padded copy."""
    nt = torch.nested.nested_tensor(
        [torch.randn(shape) for shape in ((2, 5), (3, 4))],
        requires_grad=True,
        device=device,
    )
    with torch.no_grad():
        pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
    # Run the same transpose + self-grad backward on the nested and dense copies.
    for source in (nt, pt):
        swapped = source.transpose(-2, -1)
        swapped.backward(swapped.clone())
    self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
def test_nested_tensor_reshape_gradcheck(self, device):
a = torch.randn(2, 6, requires_grad=True, device=device)
b = torch.randn(3, 6, requires_grad=True, device=device)
def grad_test_func(a, b):
nt = torch.nested.as_nested_tensor([a, b])
result = nt.reshape(2, -1, 2, 3)
return torch.nested.to_padded_tensor(result, 0.0)
data = (a, b)
assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
def test_nested_tensor_reshape_backward(self):
    """NT.reshape grads match dense reshape grads on a zero-padded copy.

    This test takes no ``device`` argument, so tensors use the default device.
    """
    nt = torch.nested.nested_tensor(
        [torch.randn(shape) for shape in ((2, 6), (3, 6))], requires_grad=True
    )
    with torch.no_grad():
        pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
    for source in (nt, pt):
        reshaped = source.reshape(2, -1, 2, 3)
        reshaped.backward(reshaped.clone())
    self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
def test_nested_tensor_squeeze_backward(self, device):
    """NT.squeeze(-1) grads match dense squeeze grads on a zero-padded copy."""
    nt = torch.nested.nested_tensor(
        [torch.randn(shape) for shape in ((2, 6, 1), (3, 6, 1))],
        requires_grad=True,
        device=device,
    )
    with torch.no_grad():
        pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
    for source in (nt, pt):
        squeezed = source.squeeze(-1)
        squeezed.backward(squeezed.clone())
    self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
def test_nested_tensor_squeeze_gradcheck(self, device):
a = torch.randn(
(2, 6, 1), dtype=torch.float64, requires_grad=True, device=device
)
b = torch.randn(
(3, 6, 1), dtype=torch.float64, requires_grad=True, device=device
)
def grad_test_func(a, b):
nt = torch.nested.as_nested_tensor([a, b])
result = nt.squeeze(-1)
return torch.nested.to_padded_tensor(result, 0.0)
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
def test_nested_tensor_unsqueeze_backward(self, device):
    """NT.unsqueeze(2) grads match dense unsqueeze grads on a zero-padded copy."""
    nt = torch.nested.nested_tensor(
        [torch.randn(shape) for shape in ((2, 6), (3, 6))],
        requires_grad=True,
        device=device,
    )
    with torch.no_grad():
        pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
    for source in (nt, pt):
        expanded = source.unsqueeze(2)
        expanded.backward(expanded.clone())
    self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
def test_nested_tensor_unsqueeze_gradcheck(self, device):
a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device)
b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device)
def grad_test_func(a, b):
nt = torch.nested.as_nested_tensor([a, b])
result = nt.unsqueeze(-1)
return torch.nested.to_padded_tensor(result, 0.0)
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
def test_nested_tensor_linear(self, device):
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
weight = torch.randn(
2, 2, requires_grad=True, dtype=torch.float64, device=device
)
bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c, weight, bias=None):
nt = torch.nested.as_nested_tensor([a, b, c])
# This implicitly tests to_padded_tensor grads
d = torch.functional.F.linear(nt, weight, bias)
return torch.nested.to_padded_tensor(d, 0)
data = (a, b, c, weight, bias)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
# Test linear with no bias added
data = (a, b, c, weight)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_nested_tensor_linear_plus_transpose(self, device):
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
weight = torch.randn(
2, 2, requires_grad=True, dtype=torch.float64, device=device
)
bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c, weight, bias=None):
nt = torch.nested.as_nested_tensor([a, b, c])
# This implicitly tests to_padded_tensor grads
d = torch.functional.F.linear(nt, weight, bias)
d = d.transpose(-1, -2).contiguous()
return torch.nested.to_padded_tensor(d, 0)
data = (a, b, c, weight, bias)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
# Test linear with no bias added
data = (a, b, c, weight)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_nested_tensor_softmax(self, device):
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c, dim):
nt = torch.nested.as_nested_tensor([a, b, c])
# This implicitly tests to_padded_tensor grads
d = torch.functional.F.softmax(nt, dim=dim)
return torch.nested.to_padded_tensor(d, 0)
# softmax over last dim
data = (a, b, c, -1)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_nested_tensor_linear_backward(self, device):
a = torch.randn(1, 2, requires_grad=False, device=device)
b = torch.randn(2, 2, requires_grad=False, device=device)
c = torch.randn(3, 2, requires_grad=False, device=device)
weight = torch.randn(2, 2, requires_grad=True, device=device)
bias = torch.randn(2, requires_grad=True, device=device)
nt = torch.nested.as_nested_tensor([a, b, c], device=device)
out = torch.functional.F.linear(nt, weight, bias)
out.backward(out.clone())
assert weight.grad is not None
assert bias.grad is not None
assert a.grad is None
assert b.grad is None
assert c.grad is None
def test_values_grad_with_broadcast(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
buffer = nt.values()
return buffer.sum()
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_to_buffer_series_ops_grad_with_broadcast(self, device):
a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
buffer = nt.values()
buffer = buffer * 2
return buffer.exp()
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_unbind_flow_through(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
ntT = nt.transpose(-1, -2)
unbound = ntT.unbind()
d = unbound[0]
d = torch.pow(d, 2)
return d
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_split_with_sizes_flow_through(self, device):
a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
splits = nt.split_with_sizes([2, 3], dim=-1)
unbound = splits[1].unbind()
d = unbound[0]
d = torch.pow(d, 2)
return d
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_indexing_backward(self, device):
    """Backward through ``nt[0]`` gives zero grad for the untouched component."""
    first, second = torch.randn((2, 5)), torch.randn((3, 4))
    nt = torch.nested.nested_tensor(
        [first, second], device=device, requires_grad=True
    )
    self.assertEqual(nt[0], first)
    self.assertEqual(nt[-1], second)
    upstream = torch.randn((2, 5), device=device)
    nt[0].backward(upstream)
    expected = torch.nested.nested_tensor(
        [upstream, torch.zeros((3, 4), device=device)]
    )
    self.assertEqual(nt.grad, expected)
def test_masked_fill_backward(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
mask = nt.detach().clone().to(bool)
out = nt.masked_fill(mask, 0)
out = torch.nested.to_padded_tensor(out, 0)
return out
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_gelu_backward(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
nt_gelu = torch.nn.functional.gelu(nt)
return torch.nested.to_padded_tensor(nt_gelu, 0)
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_relu_backward(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
nt_relu = torch.nn.functional.relu(nt)
return torch.nested.to_padded_tensor(nt_relu, 0)
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_selu_backward(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
nt_relu = torch.nn.functional.silu(nt)
return torch.nested.to_padded_tensor(nt_relu, 0)
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
def test_abs_backward(self, device):
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c])
nt_abs = torch.abs(nt)
return torch.nested.to_padded_tensor(nt_abs, 0)
data = (a, b, c)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
# Previously would error when input NT doesn't require grad
# NotImplementedError: Cannot access storage of UndefinedTensorImpl
def test_layer_norm_backward_edge_case(self, device):
size = 4
a = torch.randn(
1, 2, size, requires_grad=False, dtype=torch.float64, device=device
)
nt = torch.nested.nested_tensor([a])
nt_layer_norm = torch.nn.LayerNorm(
nt.size(-1), device=device, dtype=torch.float64
)
out = nt_layer_norm(nt)
out.backward(out.clone())
def test_accumulate_grad_different_strides(self, device):
a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device)
b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b):
nt_1 = torch.nested.as_nested_tensor([a, b])
nt_2 = nt_1.clone()
out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2)
return torch.nested.to_padded_tensor(out, 0)
data = (a, b)
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
# https://github.com/pytorch/pytorch/issues/95562
@skipIfSlowGradcheckEnv
@parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])
def test_layer_norm_backward(self, device, size):
    """gradcheck nn.LayerNorm over the last dim of a strided NT for many widths."""
    make_leaf = partial(
        torch.randn, requires_grad=True, dtype=torch.float64, device=device
    )
    a, b, c = (make_leaf(rows, 2, size) for rows in (1, 2, 3))

    def padded_layer_norm(*components):
        nt = torch.nested.as_nested_tensor(list(components))
        norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64)
        return torch.nested.to_padded_tensor(norm(nt), 0)

    assert gradcheck(padded_layer_norm, inputs=(a, b, c), check_batched_grad=False)
# https://github.com/pytorch/pytorch/issues/95562
@skipIfSlowGradcheckEnv
# Could either mark slow or reduce size
@parametrize("size", [128, 32, 4, 2])
def test_layer_norm_backward_5d(self, device, size):
    """gradcheck nn.LayerNorm over the last three dims of 4-D NT components."""
    make_leaf = partial(
        torch.randn, requires_grad=True, dtype=torch.float64, device=device
    )
    a, b, c = (make_leaf(rows, size, size, 4) for rows in (4, 7, 10))

    def padded_layer_norm(*components):
        nt = torch.nested.as_nested_tensor(list(components))
        norm = torch.nn.LayerNorm(
            (size, size, nt.size(-1)), device=device, dtype=torch.float64
        )
        return torch.nested.to_padded_tensor(norm(nt), 0)

    assert gradcheck(padded_layer_norm, inputs=(a, b, c), check_batched_grad=False)
# Found in torch/testing/_comparison.py
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}


def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    """Return the largest elementwise relative deviation |true - computed| / |true|.

    NaN ratios (e.g. 0 / 0) are replaced by the default rtol for
    ``computed_value``'s dtype. Infinite ratios (a nonzero deviation over a
    zero reference) are clamped by ``nan_to_num_`` to the largest finite
    value of the deviation tensor's dtype.
    """
    deviation = true_value - computed_value
    deviation = torch.abs(deviation / true_value)
    # Fill in the nans with the default rtol
    torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
    return deviation.max().item()


def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
    """Return the largest elementwise absolute deviation |true - computed|."""
    deviation = true_value - computed_value
    atol = torch.abs(deviation).max().item()
    return atol


def get_tolerances(
    true_value: torch.Tensor,
    computed_value: torch.Tensor,
    fudge_factor: Optional[float] = None,
) -> tuple[float, float]:
    """Returns the absolute and relative tolerances for comparing two tensors.

    Each tolerance is the observed maximum deviation, floored at the default
    for ``computed_value``'s dtype and scaled by ``fudge_factor`` (1.0 when
    None). If the observed relative deviation saturated (an infinite ratio
    clamped by ``get_rtol``) or is absurdly large, the dtype's default rtol is
    returned instead.
    """
    fudge_factor = fudge_factor if fudge_factor is not None else 1.0
    atol = get_atol(true_value, computed_value)
    rtol = get_rtol(true_value, computed_value)
    # get_rtol clamps infinite ratios to the dtype's largest finite value.
    # For float32 / bfloat16 that is ~3.4e38 and trips the 1e30 guard below.
    # For float16 it is only 65504, which would otherwise produce a huge,
    # meaningless rtol, so detect saturation explicitly.
    rtol_saturated = rtol >= torch.finfo(computed_value.dtype).max
    atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
    rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
    # torch.isclose() has weird behavior around very large rtol values; see:
    # https://github.com/pytorch/pytorch/issues/102400
    if rtol > 1e30 or rtol_saturated:
        rtol = default_rtol[computed_value.dtype]
    return atol, rtol
# We can probably parametrize existing tests instead of having a separate
# test class as we begin to support more ops. Also maybe rewrite with OpInfos.
@markDynamoStrictTest
class TestNestedTensorSubclass(NestedTensorTestCase):
# TODO: consolidate with the below
def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
Ds = nested_size[1:]
out = []
for s in nested_size[0]:
out.append(
torch.randn(
s,
*Ds,
requires_grad=requires_grad,
device=device,
dtype=torch.float64,
)
)
return out
def _get_example_tensor_lists(
self,
include_list_of_lists=True,
include_requires_grad=True,
include_inner_dim_size_1=False,
include_2d_tensor=False,
):
def _make_tensor(
*shape, include_requires_grad=include_requires_grad, requires_grad=True
):
return torch.randn(
*shape,
requires_grad=(requires_grad if include_requires_grad else False),
)
# Purposefully introduce mixed requires_grad settings for the components
# when include_requires_grad=True.
example_lists = [
# (B, *, D) with B=4
[
_make_tensor(2, 5),
_make_tensor(3, 5, requires_grad=False),
_make_tensor(4, 5, requires_grad=False),
_make_tensor(6, 5),
],
# (B, *, D_0, D_1) with B=5
[
_make_tensor(2, 5, 6),
_make_tensor(3, 5, 6),
_make_tensor(4, 5, 6, requires_grad=False),
_make_tensor(5, 5, 6),
_make_tensor(6, 5, 6),
],
# (B, *, D_0, D_1, D_2) with B=6
[
_make_tensor(2, 5, 6, 7),
_make_tensor(3, 5, 6, 7),
_make_tensor(4, 5, 6, 7, requires_grad=False),
_make_tensor(5, 5, 6, 7),
_make_tensor(6, 5, 6, 7),
_make_tensor(7, 5, 6, 7),
],
]
if include_list_of_lists:
example_lists.append(
# (B, *, D) with B=3 in list form
[
_make_tensor(2, 5, requires_grad=False).tolist(),
_make_tensor(3, 5).tolist(),
_make_tensor(4, 5).tolist(),
]
)
if include_inner_dim_size_1:
example_lists.append(
[
_make_tensor(2, 1),
_make_tensor(3, 1, requires_grad=False),
_make_tensor(4, 1, requires_grad=False),
_make_tensor(6, 1),
] # (B, *, 1)
)
example_lists.append(
[
_make_tensor(2, 5, 1),
_make_tensor(3, 5, 1, requires_grad=False),
_make_tensor(4, 5, 1, requires_grad=False),
_make_tensor(6, 5, 1),
] # (B, *, 5, 1)
)
if include_2d_tensor:
example_lists.append(
[
_make_tensor(2),
_make_tensor(3, requires_grad=False),
_make_tensor(4, requires_grad=False),
_make_tensor(6),
] # (B, *)
)
return example_lists
@dtypes(torch.float32)
@parametrize(
    "contiguity",
    ["contig", "noncontig_transposed", "noncontig_with_holes"],
    name_fn=lambda c: c,
)
@parametrize("weights_only", [True, False])
def test_serialization(self, device, dtype, contiguity, weights_only):
    """Round-trip a jagged NT through torch.save / torch.load.

    Covers three cases: contiguous, non-contiguous via transpose, and
    non-contiguous with holes (explicit lengths).
    """
    builders = {
        "contig": lambda: random_nt_from_dims(
            [4, None, 10],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
        ),
        "noncontig_transposed": lambda: random_nt_from_dims(
            [3, None, 5, 2],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
        ).transpose(-3, -2),
        "noncontig_with_holes": lambda: torch.nested.nested_tensor_from_jagged(
            values=torch.randn(10, 3, device=device, dtype=dtype),
            offsets=torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int64),
            # these lengths specify holes
            lengths=torch.tensor([1, 2, 3], device=device, dtype=torch.int64),
        ),
    }
    build = builders.get(contiguity)
    if build is None:
        raise ValueError("invalid contiguity specified for test_serialization()")
    nt = build()

    # Populate the cached sizes / strides first so the cache cannot break
    # serialization. See https://github.com/pytorch/pytorch/issues/129366
    nt.size()
    nt.stride()

    with tempfile.TemporaryFile() as f:
        torch.save(nt, f)
        f.seek(0)
        nt_loaded = torch.load(f, weights_only=weights_only)

    self.assertIsNot(nt, nt_loaded)
    # Loading yields a new offsets tensor, hence a different nested int.
    self.assertEqualIgnoringNestedInts(nt, nt_loaded)
    self.assertEqual(nt._ragged_idx, nt_loaded._ragged_idx)

    def shape_without_ragged_dim(t):
        ragged = t._ragged_idx
        return (*t.shape[:ragged], *t.shape[ragged + 1 :])

    # Shapes must agree everywhere except at the nested int.
    self.assertEqual(shape_without_ragged_dim(nt), shape_without_ragged_dim(nt_loaded))
    # The metadata cache and any lengths must survive serialization.
    self.assertEqual(nt._metadata_cache, nt_loaded._metadata_cache)
    self.assertEqual(nt._lengths, nt_loaded._lengths)
def test_tensor_attributes(self, device):
    """Metadata ops work on a jagged NT, but direct aten.size.default is rejected."""
    make_leaf = partial(
        torch.randn, requires_grad=True, dtype=torch.float64, device=device
    )
    components = [make_leaf(rows, 3) for rows in (2, 3, 4)]
    nt = torch.nested.as_nested_tensor(components, layout=torch.jagged)
    offsets = nt.offsets()
    metadata_ops = (
        torch.ops.aten.is_non_overlapping_and_dense.default,
        torch.ops.aten.sym_size.default,
        torch.ops.aten.dim.default,
        torch.ops.aten.numel.default,
        torch.ops.aten.sym_numel.default,
        torch.ops.aten.sym_stride.default,
        torch.ops.aten.sym_storage_offset.default,
    )
    for op in metadata_ops:
        op(nt)
    with self.assertRaisesRegex(
        RuntimeError, "directly calling torch.ops.aten.size"
    ):
        torch.ops.aten.size.default(nt)
    ragged_dim = torch.nested._internal.nested_tensor.get_tensor_symint(
        offsets, coeff=1
    )
    expected_shape = (3, ragged_dim, 3)
    self.assertEqual(nt.size(), expected_shape)
    self.assertEqual(nt.shape, expected_shape)
    self.assertEqual(nt.dim(), 3)
    # (2 + 3 + 4) rows * 3 columns
    self.assertEqual(nt.numel(), 27)
@parametrize("nt_dim", [3, 4, 5])
def test_linear(self, device, nt_dim):
    """gradcheck F.linear (3 -> 4 features) on jagged NTs of rank 3, 4 and 5."""
    # Dense trailing dims per NT rank; the last one matches weight's in_features.
    fixed_shape = {3: (3,), 4: (4, 3), 5: (5, 4, 3)}[nt_dim]
    make_leaf = partial(
        torch.randn, requires_grad=True, dtype=torch.float64, device=device
    )
    a, b, c = (make_leaf(rows, *fixed_shape) for rows in (2, 3, 4))
    weight = make_leaf(4, 3)
    bias = make_leaf(4)

    def linear_values(a, b, c, weight, bias):
        nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
        return torch.nn.functional.linear(nt, weight, bias).values()

    gradcheck(
        linear_values, inputs=(a, b, c, weight, bias), check_batched_grad=False
    )
@onlyCUDA
@dtypes(torch.float32)
@serialTest()
def test_linear_backward_memory_usage(self, device, dtype):
    """Verify linear_backward() does not blow up peak memory for higher-dim inputs.

    See https://github.com/pytorch/pytorch/issues/141112
    """
    B, D, max_seq_len = 64, 512, 100
    torch._C._cuda_clearCublasWorkspaces()
    m = torch.nn.Linear(D, D, device=device)
    nt = torch.nested.as_nested_tensor(
        [
            torch.rand(size=[seq_len, D])
            for seq_len in torch.randint(max_seq_len, size=(B,))
        ],
        layout=torch.jagged,
        device=device,
    )
    # (B, j1, D) -> (B, j1, 1, D) for a higher dim input size
    nt = nt.unsqueeze(-2)
    # Reset and read the peak on the same device the test runs on.
    # reset_peak_memory_stats replaces the deprecated reset_max_memory_allocated.
    torch.cuda.reset_peak_memory_stats(device)
    m(nt).sum().backward()
    # expect under a GB for max memory allocated
    max_after_gb = torch.cuda.max_memory_allocated(device) // (1024**3)
    self.assertEqual(max_after_gb, 0)
def test_unary_pointwise(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
def grad_test_func(a, b, c):
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
out = torch.nn.functional.silu(nt.sin().cos())
return out.values()
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
def test_unary_pointwise_transposed_inputs(self, device):
    """Unary pointwise ops on a transposed (non-contiguous) jagged NT.

    The test checks that output contiguity follows the dense reference, that
    the shape is preserved, and that grads are correct.
    """

    def make_leaves():
        return [
            torch.randn(
                rows, 5, requires_grad=True, dtype=torch.float64, device=device
            )
            for rows in (2, 3, 4)
        ]

    a, b, c = make_leaves()
    nt = torch.nested.nested_tensor(
        [leaf.detach() for leaf in (a, b, c)], layout=torch.jagged
    )
    nt_t = nt.transpose(1, 2)
    self.assertFalse(nt_t.is_contiguous())
    out = torch.nn.functional.silu(nt_t.sin().cos())
    dense_reference = torch.nn.functional.silu(b.transpose(-1, -2).sin().cos())
    self.assertEqual(out.is_contiguous(), dense_reference.is_contiguous())
    self.assertEqual(nt_t.shape, out.shape)

    def transposed_unary_values(*components):
        jagged = torch.nested.as_nested_tensor(list(components), layout=torch.jagged)
        return torch.nn.functional.silu(jagged.transpose(1, 2).sin().cos()).values()

    gradcheck(
        transposed_unary_values, inputs=tuple(make_leaves()), check_batched_grad=False
    )
def test_binary_pointwise(self, device):
    """Binary pointwise ops need both jagged NTs to share one offsets tensor."""
    make_leaf = partial(
        torch.randn, requires_grad=True, dtype=torch.float64, device=device
    )
    a, b, c = (make_leaf(rows, 3) for rows in (2, 3, 4))

    # Incorrect usage: the shape check fails because each NT gets its own
    # offsets tensor, even though the components are identical.
    lhs = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
    rhs = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
    with self.assertRaisesRegex(
        RuntimeError,
        "cannot call binary pointwise function .* with inputs of shapes",
    ):
        lhs * rhs

    # Correct usage: build the second NT from the first NT's offsets object.
    def shared_offsets_product(*components):
        first = torch.nested.as_nested_tensor(list(components), layout=torch.jagged)
        # TODO: Switch to public API that takes in (values, offsets) once it exists
        second, _ = jagged_from_list(list(components), first.offsets())
        return (first * second).values()

    gradcheck(shared_offsets_product, inputs=(a, b, c), check_batched_grad=False)
def test_binary_pointwise_transposed(self, device):
    """Binary pointwise ops on transposed jagged NTs that share an offsets tensor."""
    a, b, c = (
        torch.randn(rows, 5, dtype=torch.float64, device=device) for rows in (2, 3, 4)
    )
    nt1, offsets = jagged_from_list([a, b, c], None)
    nt2, offsets = jagged_from_list([a, b, c], offsets)
    nt1_t = nt1.transpose(1, 2)
    nt2_t = nt2.transpose(1, 2)
    # NOTE(review): checks on nt1_t * nt2_t are currently disabled: that
    # nt1_t is non-contiguous, that the product's contiguity matches the dense
    # b^T * b^T, and that its shape matches nt1_t.

    # Mixing a transposed and a non-transposed NT must be rejected.
    with self.assertRaisesRegex(
        RuntimeError,
        "cannot call binary pointwise function mul.Tensor with inputs of shapes",
    ):
        nt1 * nt2_t

    a, b, c = (
        torch.randn(rows, 5, requires_grad=True, dtype=torch.float64, device=device)
        for rows in (2, 3, 4)
    )

    # Correct usage: chain the calls using the same offsets tensor object
    def transposed_product_values(*components):
        lhs, shared_offsets = jagged_from_list(list(components), None)
        rhs, _ = jagged_from_list(list(components), shared_offsets)
        return (lhs.transpose(1, 2) * rhs.transpose(1, 2)).values()

    gradcheck(transposed_product_values, inputs=(a, b, c), check_batched_grad=False)
def test_binary_pointwise_with_nested_int_second_arg(self, device):
    """Multiplying or adding an NT by its own nested int must raise.

    See https://github.com/pytorch/pytorch/issues/138496
    """
    nt = random_nt_from_dims(
        [3, None, 5],
        device=device,
        dtype=torch.float32,
        layout=torch.jagged,
    )
    for combine in (lambda x, y: x * y, lambda x, y: x + y):
        with self.assertRaisesRegex(RuntimeError, "invalid argument"):
            combine(nt, nt.size(1))
def test_split(self, device):
    """split() with an int chunk size slices every component along a dense dim.

    Splitting along the ragged dim is rejected.
    """
    a, b, c = (
        torch.randn(rows, 3, requires_grad=True, dtype=torch.float64, device=device)
        for rows in (2, 3, 4)
    )
    nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)

    pieces = torch.split(nt, 2, -1)
    self.assertEqual(len(pieces), 2)
    # a last dim of size 3 split into chunks of 2 -> columns [0:2] and [2:]
    for piece, cols in zip(pieces, (slice(0, 2), slice(2, None))):
        self.assertEqualIgnoringNestedInts(
            piece,
            torch.nested.as_nested_tensor(
                [a[:, cols], b[:, cols], c[:, cols]], layout=torch.jagged
            ),
        )

    with self.assertRaisesRegex(
        RuntimeError,
        r"split\(\): not supported for NestedTensor on ragged dim",
    ):
        torch.split(nt, 2, 1)
def test_split_with_sizes(self, device):
    """split() with explicit section sizes slices every component along a dense dim.

    Splitting along the ragged dim is rejected.
    """
    a, b, c = (
        torch.randn(rows, 3, requires_grad=True, dtype=torch.float64, device=device)
        for rows in (2, 3, 4)
    )
    nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)

    pieces = torch.split(nt, [1, 2], -1)
    self.assertEqual(len(pieces), 2)
    # sections [1, 2] over a last dim of size 3 -> columns [0:1] and [1:]
    for piece, cols in zip(pieces, (slice(0, 1), slice(1, None))):
        self.assertEqualIgnoringNestedInts(
            piece,
            torch.nested.as_nested_tensor(
                [a[:, cols], b[:, cols], c[:, cols]], layout=torch.jagged
            ),
        )

    with self.assertRaisesRegex(
        RuntimeError,
        r"split_with_sizes\(\): not supported for NestedTensor on ragged dim",
    ):
        torch.split(nt, [1, 2], 1)
def test_softmax(self, device):
    """softmax() over the last dense dim matches per-component softmax.

    Positive and negative dim spellings must agree, and the op must be
    differentiable (verified with gradcheck).
    """
    nt = random_nt_from_dims(
        [3, None, 5],
        device=device,
        dtype=torch.float32,
        layout=torch.jagged,
        requires_grad=True,
    )

    @torch._dynamo.disable
    def _compare_to_ref(nt, output, dim):
        # compare against softmax applied to each unbound component
        for component, result in zip(nt.unbind(), output.unbind()):
            self.assertEqual(component.softmax(dim=dim), result)

    # dim=2 on the NJT corresponds to dim=1 on each unbound component
    out_positive = nt.softmax(dim=2)
    _compare_to_ref(nt, out_positive, dim=1)

    # the negative spelling of the same dim must give the same result
    out_negative = nt.softmax(dim=-1)
    torch._dynamo.disable(self.assertEqual)(out_positive, out_negative)
    _compare_to_ref(nt, out_negative, dim=-1)

    def grad_test_func(a, b):
        stacked = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
        return stacked.softmax(dim=-1).values()

    a = torch.rand(4, 5, requires_grad=True, dtype=torch.float64, device=device)
    b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device)
    gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False)
def test_views_inherit_ragged_dim(self, device):
    """-1 entries in view()/expand() sizes keep the source's batch and ragged sizes."""
    # view(): (4, j, 8, 10) -> (4, j, 80)
    src = random_nt_from_dims(
        [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
    )
    ragged_inherited = src.view(4, -1, 80)  # -1 stands for the ragged dim
    self.assertEqual(src.shape[1], ragged_inherited.shape[1])
    both_inherited = src.view(-1, -1, 80)  # -1 for both batch and ragged dims
    self.assertEqual(src.shape[:2], both_inherited.shape[:2])

    # expand(): (3, j, 1) -> (3, j, 5), batch and ragged dims via -1
    src = random_nt_from_dims(
        [3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
    )
    expanded = src.expand(-1, -1, 5)
    self.assertEqual(src.shape[:2], expanded.shape[:2])
def test_view_ragged_idx_not_one(self, device):
    """view() on a transposed NJT (ragged dim moved off index 1) keeps sizes and base."""
    nt = random_nt_from_dims(
        [2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged
    )
    ragged_size = nt.size(1)
    transposed = nt.transpose(1, 2)
    reshaped = transposed.view(2, 20, ragged_size)
    self.assertEqual((2, 20, ragged_size), (reshaped.size()))
    self.assertEqual(reshaped._base, nt._base)
def test_unsafe_view(self, device):
    """aten._unsafe_view reshapes an NJT like view() but records no view base."""
    nt = random_nt_from_dims(
        [4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
    )
    ragged_size = nt.size(1)

    # Plain reshape of a ragged_idx == 1 NJT.
    merged = torch.ops.aten._unsafe_view(nt, (4, -1, 80))
    self.assertEqual((4, ragged_size, 80), tuple(merged.size()))
    # unlike view(), _unsafe_view does not track view information
    self.assertTrue(merged._base is None)

    # ragged_idx != 1: only the identity view is currently supported.
    identity = torch.ops.aten._unsafe_view(
        nt.transpose(1, 2), (4, 8, ragged_size, 10)
    )
    self.assertEqual((4, 8, ragged_size, 10), tuple(identity.size()))
    self.assertTrue(identity._base is None)
@xfailIfTorchDynamo
@parametrize("requires_grad", [False, True])
def test_reshape_decomp(self, device, requires_grad):
    """reshape() yields a view of a contiguous NJT and a contiguous copy otherwise.

    In both cases gradients must flow back to the source NJT.
    """
    # Contiguous source: reshape must return a view whose base is the source.
    src = (
        random_nt_from_dims(
            [3, None, 10],
            device=device,
            dtype=torch.float32,
            layout=torch.jagged,
        )
        .detach()
        .requires_grad_(requires_grad)
    )
    reshaped = src.reshape(-1, -1, 5, 2)
    self.assertEqual(reshaped.shape[:2], src.shape[:2])
    self.assertTrue(reshaped._is_view() and reshaped._base is src)
    if requires_grad:
        reshaped.backward(torch.ones_like(reshaped))
        self.assertEqual(src.grad, torch.ones_like(src))

    # Non-contiguous source: reshape must materialize a contiguous copy.
    src = random_nt_from_dims(
        [3, None, 5, 2],
        device=device,
        dtype=torch.float32,
        layout=torch.jagged,
        requires_grad=requires_grad,
    )
    transposed = src.transpose(-1, -2)
    self.assertFalse(transposed.is_contiguous())
    copied = transposed.reshape(-1, -1, 10)
    self.assertTrue(copied.is_contiguous())
    self.assertEqual(copied.shape[:2], src.shape[:2])
    if requires_grad:
        copied.backward(torch.ones_like(copied))
        self.assertEqual(src.grad, torch.ones_like(src))
def test_flatten_decomp(self, device):
    """flatten() over adjacent dense dims gives the same shape as the equivalent view()."""
    cases = (
        # (NJT dims, flatten start, flatten end, equivalent view sizes)
        ([3, None, 5, 2], -2, -1, (3, -1, 10)),
        ([3, None, 5, 2, 6], -3, -2, (3, -1, 10, 6)),
    )
    for dims, start_dim, end_dim, view_sizes in cases:
        nt = random_nt_from_dims(
            dims, device=device, dtype=torch.float32, layout=torch.jagged
        )
        flattened = nt.flatten(start_dim, end_dim)
        self.assertEqual(flattened.shape, nt.view(*view_sizes).shape)
def test_chunk(self, device):
    """chunk() on a jagged NJT along dense, batch, and ragged dims.

    * dense (last) dim: every chunk gets an equal slice, and chunk backward
      is verified with gradcheck;
    * batch dim: chunk sizes, per-chunk offsets and concatenated values match
      the source NJT, and backward through the chunks runs;
    * ragged dim: not supported, must raise.
    """
    # Dense (non-NJT) baseline: chunk + backward on a regular tensor.
    t = torch.randn(10, 4, 5, requires_grad=True, device=device)
    t_list = t.chunk(3, dim=0)
    loss = t_list[0].sum() + t_list[2].sum()
    loss.backward()

    # Chunk along the last (dense) dim.
    D = 30
    B = 8
    nt = random_nt_from_dims(
        [B, None, D],
        device=device,
        dtype=torch.float32,
        layout=torch.jagged,
        requires_grad=True,
    )
    NUM_CHUNKS = 3
    chunks = nt.chunk(NUM_CHUNKS, dim=-1)
    self.assertEqual(len(chunks), NUM_CHUNKS)
    for chunk in chunks:
        self.assertEqual(chunk.shape[-1], D // NUM_CHUNKS)

    # Numerically check chunk_backward on a dense dim.
    values = torch.randn(
        5, 11, dtype=torch.float64, device=device, requires_grad=True
    )
    offsets = torch.tensor([0, 2, 3, 5], device=device)

    def grad_test_func(values, offsets):
        nt = torch.nested.nested_tensor_from_jagged(values, offsets)
        chunks = nt.chunk(3, dim=-1)
        return chunks[0].values().sum()

    # gradcheck raises on mismatch by default; a bare ``assert`` around it adds
    # nothing and would be stripped entirely under ``python -O``.
    gradcheck(
        grad_test_func,
        inputs=(values, offsets),
        check_batched_grad=False,
    )

    # Chunk along the batch dim: the first NUM_CHUNKS - 1 chunks hold
    # ceil(B / NUM_CHUNKS) components each, the last holds the remainder.
    chunks = nt.chunk(NUM_CHUNKS, dim=0)
    self.assertEqual(len(chunks), NUM_CHUNKS)
    chunk_size = math.ceil(B / NUM_CHUNKS)
    for i, chunk in enumerate(chunks):
        if i < NUM_CHUNKS - 1:
            self.assertEqual(chunk.shape[0], chunk_size)
        else:
            self.assertEqual(chunk.shape[0], B - chunk_size * (NUM_CHUNKS - 1))
        # each chunk's offsets are the source offsets rebased to start at 0
        offsets_expected = (
            nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1]
            - nt._offsets[i * chunk_size]
        )
        self.assertEqual(chunk._offsets[1:], offsets_expected)
    self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))

    # Backward through all batch-dim chunks must run without error
    # (the gradient values are not checked here).
    loss = sum(chunk.values().sum() for chunk in chunks)
    loss.backward()

    # chunk on ragged dim not supported
    with self.assertRaisesRegex(
        RuntimeError, "chunk.* not supported for NestedTensor on ragged dim"
    ):
        nt.chunk(2, dim=1)
def test_squeeze(self, device):
    """squeeze() drops size-1 dense dims; batch and ragged dims are rejected."""
    B = 4
    D = 6

    # Size-1 dim in the middle: (B, j0, 1, D) -> (B, j0, D).
    nt = random_nt_from_dims(
        [B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged
    )
    j0 = nt.shape[1]
    for dim_arg in (-2, 2):
        squeezed = nt.squeeze(dim_arg)
        self.assertEqual(squeezed.shape, (B, j0, D))
        self.assertEqual(squeezed.unsqueeze(-2), nt)

    # Size-1 dim last: (B, j1, 1) -> (B, j1).
    nt = random_nt_from_dims(
        [B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
    )
    j1 = nt.shape[1]
    for dim_arg in (-1, 2):
        squeezed = nt.squeeze(dim_arg)
        self.assertEqual(squeezed.shape, (B, j1))
        self.assertEqual(squeezed.unsqueeze(-1), nt)

    # Squeezing the batch dim (0) or the ragged dim (1) is unsupported.
    unsupported = (
        (0, "squeeze.* not supported for NestedTensor on dim=0"),
        (1, "squeeze.* not supported for NestedTensor on ragged dim"),
    )
    for bad_dim, pattern in unsupported:
        with self.assertRaisesRegex(RuntimeError, pattern):
            nt.squeeze(bad_dim)
def test_binary_pointwise_broadcasting(self, device):
    """Adding a dense tensor that broadcasts over the trailing dims is differentiable."""
    # components that form an NJT of shape (B, j0, 3, 4)
    ts = self._get_list_for_jagged_tensor(
        ((2, 3, 4), 3, 4), device, requires_grad=True
    )

    def grad_test_func(t, *ts):
        nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged)
        return (nt + t).values()

    # Supported broadcasts:
    #   (B, j0, ?, ?) + (?)       -> (B, j0, ?, ?)
    #   (B, j0, ?, ?) + (?, ?)    -> (B, j0, ?, ?)
    #   (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
    # Unsupported today (hence absent from the list below):
    #   (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
    dense_sizes = ((4,), (1, 4), (3, 1), (1, 3, 1), (1, 1, 1, 4))
    for size in dense_sizes:
        dense = torch.rand(
            size, requires_grad=True, device=device, dtype=torch.float64
        )
        gradcheck(grad_test_func, inputs=(dense, *ts), check_batched_grad=False)
def test_threshold_backward(self, device):
    """aten.threshold_backward on NJTs equals the dense op on their values buffers."""
    spec = ((2, 3, 4), 16)
    first = self._get_list_for_jagged_tensor(spec, device=device, requires_grad=False)
    second = self._get_list_for_jagged_tensor(spec, device=device, requires_grad=False)
    # both NJTs share one offsets tensor so they are shape-compatible
    nt_first, shared_offsets = jagged_from_list(first, None)
    nt_second, _ = jagged_from_list(second, shared_offsets)

    dense_first = nt_first.values().detach().clone()
    dense_second = nt_second.values().detach().clone()

    nested_result = torch.ops.aten.threshold_backward(nt_first, nt_second, 0.0)
    dense_result = torch.ops.aten.threshold_backward(dense_first, dense_second, 0.0)
    self.assertEqual(dense_result, nested_result.values())
@onlyCUDA
@dtypes(torch.float32)
def test_record_stream(self, device, dtype):
    """Check that NJT.record_stream() keeps its buffers from being reused.

    Without record_stream(), a new NJT of the same sizes created after the
    first is released is expected to receive the same data pointers. With
    record_stream() on a stream that still has queued work, none of the
    released pointers may be handed out again.
    """

    def _create_nt():
        # Build an NJT whose values / offsets / lengths all live on CUDA and
        # return it together with the data pointers of those three buffers.
        values = torch.ones(1024, 4 * 1024, device="cuda")
        offsets = torch.tensor([0, 500, 1024], device="cuda", dtype=torch.int64)
        lengths = offsets.diff()
        nt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths)
        data_ptrs = {
            nt._values.data_ptr(),
            nt._offsets.data_ptr(),
            nt._lengths.data_ptr(),
        }
        return nt, data_ptrs

    def fn(record_stream):
        # Allocate an NJT, queue a long-running kernel on a side stream, and
        # optionally record the NJT on that stream. Only the pointers are
        # returned, so the NJT itself is dropped when this function returns.
        nt, data_ptrs = _create_nt()
        s = torch.cuda.Stream()
        with torch.cuda.stream(s):
            # emulate doing something long via sleep
            # NOTE(review): per_ms presumably approximates GPU cycles per
            # millisecond, making this roughly a 100 ms sleep — confirm.
            per_ms = 2e7
            torch.cuda._sleep(int(per_ms * 100))
            if record_stream:
                nt.record_stream(s)
        return data_ptrs

    # expect memory reuse when record_stream() is not run
    data_ptrs = fn(record_stream=False)
    nt, nt_data_ptrs = _create_nt()
    self.assertEqual(data_ptrs, nt_data_ptrs)
    del nt
    # drain all queued work (including the side-stream sleep) before the
    # second scenario so it starts from a quiet device
    torch.cuda.synchronize()

    # expect memory to be preserved (no reuse) when record_stream() is run
    data_ptrs = fn(record_stream=True)
    nt, nt_data_ptrs = _create_nt()
    self.assertEqual(len(data_ptrs.intersection(nt_data_ptrs)), 0)
@dtypes(torch.float32)
@parametrize(
    "func",
    [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
    name_fn=get_op_name,
)
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_jagged_op_different_output_shape_dim(
    self, device, dtype, keepdim, requires_grad, components_require_grad, func
):
    """
    Reductions (sum / mean) over valid dims give the expected shapes and values.

    Covers ops whose output shape differs from the input shape.
    """
    op_name = get_op_name(func)
    if op_name == "mean" and not keepdim:
        return

    # components that form an NJT of shape (B, j0, 3, 4)
    ts = self._get_list_for_jagged_tensor(
        ((2, 3, 4), 3, 4), device=device, requires_grad=True
    )

    # Each entry: (dims, expected shape, expected keepdim shape, equivalent dims
    # on the values buffer). None stands for the ragged size j0; shapes assume
    # ragged_idx == 1.
    if op_name == "sum":
        reduce_dims = (
            ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
            ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
            ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
            ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
            (
                (0, 1, 2, 3),
                (),
                (1, 1, 1, 1),
                (0, 1, 2),
            ),  # batch, ragged, non-batch, non-batch
            ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
        )
    elif op_name == "mean":
        reduce_dims = (
            ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
            ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
        )

    # Shape checks.
    for dims, shape_plain, shape_keepdim, _ in reduce_dims:
        nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
        out = func(nt, dim=dims, keepdim=keepdim)
        expected_shape = shape_keepdim if keepdim else shape_plain
        if torch.compiler.is_compiling():
            continue  # shape checks are skipped when running under dynamo
        self.assertEqual(len(out.shape), len(expected_shape))
        for actual_size, expected_size in zip(out.shape, expected_shape):
            if expected_size is None:
                self.assertTrue(isinstance(actual_size, torch.SymInt))
            else:
                self.assertEqual(actual_size, expected_size)

    # Value checks against the dense values buffer.
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,
    )
    for tensor_list, (dims, _, _, values_dims) in itertools.product(
        tensor_lists, reduce_dims
    ):
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        if nt.dim() <= dims[-1]:
            continue  # NJT too small for this reduction
        out_actual = func(nt, dim=dims, keepdim=keepdim)
        if nt._ragged_idx in dims:
            # raggedness reduced away: the output is dense
            out_expected = func(nt.values(), dim=values_dims, keepdim=keepdim)
            self.assertTrue(torch.allclose(out_actual, out_expected))
        else:
            # raggedness preserved: compare the flattened values
            out_expected = func(nt.values(), dim=values_dims)
            self.assertTrue(
                torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
            )
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_softmax_dim(
    self,
    device,
    dtype,
    requires_grad,
    components_require_grad,
):
    """
    softmax() over valid (non-batch, non-ragged) dims keeps the input shape and
    matches softmax on the values buffer over the equivalent dim.
    """
    # components that form an NJT of shape (B, j0, 3, 4)
    ts = self._get_list_for_jagged_tensor(
        ((2, 3, 4), 3, 4), device=device, requires_grad=True
    )
    expected_shape = (3, None, 3, 4)  # None stands for the ragged size j0
    # (dim on the NJT, equivalent dim on the values buffer); assumes ragged_idx == 1
    reduce_dims = ((2, 1), (3, 2))

    # Shape checks.
    for dim, _ in reduce_dims:
        nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
        out = torch.nn.functional.softmax(nt, dim=dim)
        # disabled when running under dynamo
        torch._dynamo.disable(self.assertEqual)(len(out.shape), len(expected_shape))
        for actual_size, expected_size in zip(out.shape, expected_shape):
            if expected_size is None:
                self.assertTrue(isinstance(actual_size, torch.SymInt))
            else:
                self.assertEqual(actual_size, expected_size)

    # Value checks against the dense values buffer.
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,
    )
    for tensor_list, (dim, values_dim) in itertools.product(
        tensor_lists, reduce_dims
    ):
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        if nt.dim() <= dim:
            continue
        out_actual = torch.nn.functional.softmax(nt, dim=dim)  # nested
        # dense tensor with one dim fewer than out_actual
        out_expected = torch.nn.functional.softmax(nt.values(), dim=values_dim)
        self.assertTrue(
            torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
        )
@dtypes(torch.float32)
@parametrize(
    "func",
    [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
    name_fn=get_op_name,
)
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_op_dim_reduce_ragged_idx_1_different_output_shape(
    self, device, dtype, keepdim, requires_grad, components_require_grad, func
):
    """
    Reducing (sum / mean) over the ragged dim when ragged_idx == 1 returns a
    dense tensor equal to reducing each component over its leading dim.

    Covers ops whose output shape differs from the input shape.
    """
    op_name = get_op_name(func)
    if op_name == "mean" and not keepdim:
        return

    ragged_dim = 1
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        out_actual = func(nt, dim=(ragged_dim,), keepdim=keepdim)

        # reference: reduce each component over its own dim 0, then stack
        per_component = [
            func(component, dim=ragged_dim - 1).unsqueeze(0)
            for component in nt.unbind()
        ]
        out_expected = torch.cat(per_component)
        if keepdim:
            out_expected = out_expected.unsqueeze(ragged_dim)

        self.assertFalse(
            out_actual.is_nested,
            f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
        )
        self.assertEqual(out_actual, out_expected)
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_softmax_dim_reduce_ragged_idx_1(
    self, device, dtype, requires_grad, components_require_grad
):
    """
    softmax() over the ragged dim (ragged_idx == 1) stays nested and matches
    softmax applied to each unbound component over its leading dim.
    """
    ragged_dim = 1
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
        include_2d_tensor=True,  # (B, *)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        out_actual = torch.nn.functional.softmax(nt, dim=ragged_dim)
        per_component = [
            torch.nn.functional.softmax(component, dim=ragged_dim - 1)
            for component in nt.unbind()
        ]
        out_expected = torch.cat(per_component)

        self.assertTrue(
            out_actual.is_nested,
            "softmax(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
        )
        self.assertTrue(torch.allclose(out_actual.values(), out_expected))
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_softmax_reduce_batch_dim(
    self, device, dtype, requires_grad, components_require_grad
):
    """
    softmax() over the batch dim (dim=0) of an NJT must raise.
    """
    batch_dim = 0
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "not supported when reducing across the batch dimension for NestedTensor",
        ):
            torch.nn.functional.softmax(nt, dim=batch_dim)
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_layer_norm_reduce_ragged_idx_1(
    self, device, dtype, requires_grad, components_require_grad
):
    """
    layer_norm() over the ragged dim and everything after it (ragged_idx == 1)
    stays nested and matches layer-normalizing each component on its own.
    """
    # requires_grad = False does not currently work with dynamo tests and throws this error:
    # AssertionError: SymInts must use SymNodeVariable.
    # If the underlying value is static, we will create a ConstantVariable and specialize.
    if torch._dynamo.is_compiling() and not requires_grad:
        return

    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        # layer norm only works for tensors with 3 or more dimensions
        if nt.dim() < 3:
            continue

        normalized_shape = nt.shape[nt._ragged_idx :]
        out_actual = torch.nn.functional.layer_norm(
            nt, normalized_shape=normalized_shape
        )
        # e.g. a 3D (B, *, M) NJT: normalize each of the B 2D (*, M) components
        per_component = [
            torch.nn.functional.layer_norm(component, normalized_shape=component.shape)
            for component in nt.unbind()
        ]
        out_expected = torch.cat(per_component)

        self.assertTrue(
            out_actual.is_nested,
            "layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
        )
        self.assertEqual(out_actual._values.shape, out_expected.shape)
        self.assertTrue(torch.allclose(out_actual.values(), out_expected))
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_layer_norm_2d_input(
    self,
    device,
    dtype,
    requires_grad,
    components_require_grad,
):
    """
    layer_norm() on an NJT with 2 or fewer dims must raise.
    """
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
        include_2d_tensor=True,  # (B, *)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        if nt.dim() > 2:
            continue  # only the 2D case is under test here
        with self.assertRaisesRegex(
            RuntimeError,
            "not supported for NestedTensor objects with 2 or fewer dimensions",
        ):
            torch.nn.functional.layer_norm(
                nt, normalized_shape=(nt.shape[nt._ragged_idx],)
            )
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_layer_norm_operate_on_batch_dim(
    self,
    device,
    dtype,
    requires_grad,
    components_require_grad,
):
    """
    layer_norm() whose normalized_shape covers the batch dim must raise.
    """
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
        include_2d_tensor=True,  # (B, *)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        if nt.dim() <= 2:
            continue  # cannot perform layer normalization on 2D tensors
        with self.assertRaisesRegex(
            RuntimeError,
            "not supported when normalizing over the batch dimension for NestedTensor",
        ):
            torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)
@dtypes(torch.float32)
@parametrize(
    "func",
    [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
    name_fn=get_op_name,
)
@parametrize(
    "transpose_offset", [1, 2]
)  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
    self,
    device,
    dtype,
    keepdim,
    requires_grad,
    components_require_grad,
    func,
    transpose_offset,
):
    """
    Reducing (sum / mean) over a transposed ragged dim (ragged_idx > 1) returns
    a dense tensor equal to reducing each component over that dim.

    Covers ops whose output shape differs from the input shape.
    """
    op_name = get_op_name(func)
    if op_name == "mean" and not keepdim:
        return

    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
        include_2d_tensor=True,  # (B, *)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        swap_with = nt._ragged_idx + transpose_offset
        if nt.dim() <= swap_with:
            continue  # not enough dims to move the ragged dim that far
        nt_transposed = nt.transpose(nt._ragged_idx, swap_with)
        ragged_dim = nt_transposed._ragged_idx

        out_actual = func(nt_transposed, dim=(ragged_dim,), keepdim=keepdim)
        per_component = [
            func(component, dim=ragged_dim - 1).unsqueeze(0)
            for component in nt_transposed.unbind()
        ]
        out_expected = torch.cat(per_component)
        if keepdim:
            out_expected = out_expected.unsqueeze(ragged_dim)

        self.assertFalse(
            out_actual.is_nested,
            f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
        )
        self.assertEqual(out_actual, out_expected)
@dtypes(torch.float32)
@parametrize(
    "transpose_offset", [1, 2]
)  # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
    self,
    device,
    dtype,
    requires_grad,
    components_require_grad,
    transpose_offset,
):
    """
    softmax() over a transposed ragged dim (ragged_idx > 1) must raise.

    softmax is an op whose output has the same shape as its input.
    """
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        swap_with = nt._ragged_idx + transpose_offset
        if nt.dim() <= swap_with:
            continue  # not enough dims to move the ragged dim that far
        nt_transposed = nt.transpose(nt._ragged_idx, swap_with)
        ragged_dim = nt_transposed._ragged_idx
        with self.assertRaisesRegex(
            RuntimeError,
            "not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor",
        ):
            torch.nn.functional.softmax(nt_transposed, dim=ragged_dim)
@dtypes(torch.float32)
@parametrize(
    "func",
    [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
    name_fn=get_op_name,
)
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_op_dim_transpose_non_ragged_dim_different_output_shape(
    self, device, dtype, keepdim, requires_grad, components_require_grad, func
):
    """
    Reductions (sum / mean) on NJTs whose last two dense dims were transposed
    give the same values as reducing the values buffer over the equivalent dims.

    Covers ops whose output shape differs from the input shape.
    """
    op_name = get_op_name(func)
    if op_name == "mean" and not keepdim:
        return

    # Each entry: (dims, expected shape, expected keepdim shape, equivalent dims
    # on the values buffer), with None standing for the ragged size j0.
    # This test only checks values: just the dims and the values-buffer dims
    # are used below.
    if op_name == "sum":
        reduce_dims = (
            ((0, 1), (3, 4), (1, 1, 3, 4), (0,)),  # batch, ragged
            ((2, 3), (3, None), (3, None, 1, 1), (1, 2)),  # non-batch, non-batch
            ((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)),  # batch, ragged, non-batch
            ((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)),  # batch, ragged, non-batch
            (
                (0, 1, 2, 3),
                (),
                (1, 1, 1, 1),
                (0, 1, 2),
            ),  # batch, ragged, non-batch, non-batch
            ((2,), (3, None, 4), (3, None, 1, 4), (1,)),  # non-batch
        )
    elif op_name == "mean":
        reduce_dims = (
            ((2,), (3, None, 4), (3, None, 1, 4), (1,)),
            ((3,), (3, None, 3), (3, None, 3, 1), (2,)),
        )

    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
    )
    for tensor_list, (dims, _, _, values_dims) in itertools.product(
        tensor_lists, reduce_dims
    ):
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        ).transpose(-1, -2)
        # skip unless every reduced dim exists and the transposed dims are
        # both non-batch, non-ragged dims
        if nt.dim() <= max(dims[-1], nt._ragged_idx + 2):
            continue
        out_actual = func(nt, dim=dims, keepdim=keepdim)
        if nt._ragged_idx in dims:
            # raggedness reduced away: the output is dense
            out_expected = func(nt.values(), dim=values_dims, keepdim=keepdim)
            self.assertTrue(torch.allclose(out_actual, out_expected))
        else:
            # raggedness preserved: compare the flattened values
            out_expected = func(nt.values(), dim=values_dims)
            self.assertTrue(
                torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
            )
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_softmax_dim_transpose_non_ragged_dim(
    self,
    device,
    dtype,
    requires_grad,
    components_require_grad,
):
    """
    softmax() on NJTs whose last two dense dims were transposed matches softmax
    on the values buffer over the equivalent dim.

    softmax is an op whose output has the same shape as its input.
    """
    # (dim on the NJT, equivalent dim on the values buffer); assumes
    # ragged_idx == 1. Only values are checked here, not shapes.
    reduce_dims = ((2, 1), (3, 2))
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False,
        include_requires_grad=components_require_grad,
        include_inner_dim_size_1=True,  # (B, *, 1)
    )
    for tensor_list, (dim, values_dim) in itertools.product(
        tensor_lists, reduce_dims
    ):
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        ).transpose(-1, -2)
        if nt.dim() <= max(dim, nt._ragged_idx + 2):
            continue
        out_actual = torch.nn.functional.softmax(nt, dim=dim)  # nested
        # dense tensor with one dim fewer than out_actual
        out_expected = torch.nn.functional.softmax(nt.values(), dim=values_dim)
        self.assertTrue(
            torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
        )
@dtypes(torch.float32)
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_sum_dim_reduce_ragged_and_non_batch(
    self,
    device,
    dtype,
    keepdim,
    requires_grad,
    components_require_grad,
):
    """
    sum() over the ragged dim together with a non-batch dim must raise.
    """
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False, include_requires_grad=components_require_grad
    )
    # each pair is (ragged dim, non-batch dim)
    reduce_dims = ((1, 2), (1, 3))
    for tensor_list, dims in itertools.product(tensor_lists, reduce_dims):
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        if nt.dim() <= dims[-1]:
            continue
        with self.assertRaisesRegex(
            RuntimeError,
            "reducing along a ragged and non-batch dimension is not supported",
        ):
            torch.sum(nt, dim=dims, keepdim=keepdim)
@dtypes(torch.float32)
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_sum_dim_reduce_batch_and_non_batch(
    self,
    device,
    dtype,
    keepdim,
    requires_grad,
    components_require_grad,
):
    """
    sum() over the batch dim together with a non-batch dim must raise.
    """
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False, include_requires_grad=components_require_grad
    )
    # each pair is (batch dim, non-batch dim)
    reduce_dims = ((0, 2), (0, 3))
    for tensor_list, dims in itertools.product(tensor_lists, reduce_dims):
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        if nt.dim() <= dims[-1]:
            continue
        with self.assertRaisesRegex(
            RuntimeError,
            "reducing along the batch dimension but not the ragged dimension "
            + "is not supported",
        ):
            torch.sum(nt, dim=dims, keepdim=keepdim)
@dtypes(torch.float32)
@parametrize(
    "func",
    [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
    name_fn=get_op_name,
)
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_op_dim_reduce_batch_only_different_output_shape(
    self, device, dtype, keepdim, requires_grad, components_require_grad, func
):
    """
    Reducing (sum / mean) over the batch dim alone must raise.
    """
    is_mean = get_op_name(func) == "mean"
    if is_mean and not keepdim:
        return

    batch_only = (0,)
    tensor_lists = self._get_example_tensor_lists(
        include_list_of_lists=False, include_requires_grad=components_require_grad
    )
    for tensor_list in tensor_lists:
        nt = torch.nested.nested_tensor(
            tensor_list,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=requires_grad,
        )
        with self.assertRaisesRegex(
            RuntimeError,
            "reducing along the batch dimension but not the ragged dimension "
            + "is not supported",
        ):
            func(nt, dim=batch_only, keepdim=keepdim)
@dtypes(torch.float32)
@parametrize(
"func",
[torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
name_fn=get_op_name,
)
@parametrize("keepdim", [False, True])
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_op_dim_with_lengths_different_output_shape(
self,
device,
dtype,
keepdim,
requires_grad,
components_require_grad,
func,
):
"""
Operator on NestedTensor fails when trying to reduce a nested tensor with lengths,
i.e. a nested tensor with holes, if reducing on the ragged dimension.
This test is for operators which return an output tensor with different shape than the input tensor.
"""
if get_op_name(func) == "mean" and not keepdim:
return
reduce_dims = ((1,), (2,), (2, 3))
lengths = torch.randint(5, 10, (20,), device=device)
offsets = torch.zeros((21,), device=device, dtype=torch.int)
torch.cumsum(lengths, dim=0, out=offsets[1:])
values = torch.randn(
(offsets[-1].item(), 20),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
nt_with_holes = torch.nested.nested_tensor_from_jagged(
values,
offsets,
lengths=offsets.diff() - 2, # arbitrary subtraction to create holes
)
for reduce_dim in reduce_dims:
if nt_with_holes.dim() > reduce_dim[-1]:
if nt_with_holes._ragged_idx in reduce_dim:
with self.assertRaisesRegex(
RuntimeError,
"reducing across the ragged dimension is not supported for "
+ "non-contiguous nested tensors with holes",
):
out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
else:
out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_softmax_dim_with_lengths(
self,
device,
dtype,
requires_grad,
components_require_grad,
):
"""
Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths,
i.e. a nested tensor with holes, if reducing on the ragged dimension.
"""
reduce_dims = (1, 2, 3)
lengths = torch.randint(5, 10, (20,), device=device)
offsets = torch.zeros((21,), device=device, dtype=torch.int)
torch.cumsum(lengths, dim=0, out=offsets[1:])
values = torch.randn(
(offsets[-1].item(), 20),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
nt_with_holes = torch.nested.nested_tensor_from_jagged(
values,
offsets,
lengths=offsets.diff() - 2, # arbitrary subtraction to create holes
)
for reduce_dim in reduce_dims:
if nt_with_holes.dim() > reduce_dim:
if nt_with_holes._ragged_idx == reduce_dim:
with self.assertRaisesRegex(
RuntimeError,
"not supported where lengths is not None "
+ "if reducing across the ragged dimension for NestedTensor",
):
out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
else:
out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
@skipIfTorchDynamo(
"ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
+ "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
+ "If the underlying value is static, we will create a ConstantVariable and specialize.`"
)
@dtypes(torch.float32)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_layer_norm_with_lengths(
self,
device,
dtype,
requires_grad,
components_require_grad,
):
"""
Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
i.e. a nested tensor with holes, if operating on the ragged dimension.
"""
# create components for nested tensor
lengths = torch.randint(5, 10, (20,), device=device)
offsets = torch.zeros((21,), device=device, dtype=torch.int)
torch.cumsum(lengths, dim=0, out=offsets[1:])
values = torch.randn(
(offsets[-1].item(), 10, 30),
device=device,
dtype=dtype,
requires_grad=requires_grad,
)
nt_with_holes = torch.nested.nested_tensor_from_jagged(
values,
offsets,
lengths=offsets.diff() - 2, # arbitrary subtraction to create holes
)
ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]
normalized_shapes = (
(10, 30), # normalization on non-ragged dimension passes
(ragged_size, 10, 30), # normalization on ragged dimension fails
)
for normalized_shape in normalized_shapes:
if ragged_size in normalized_shape:
with self.assertRaisesRegex(
RuntimeError,
"not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
):
out = torch.nn.functional.layer_norm(
nt_with_holes, normalized_shape=normalized_shape
)
else:
out = torch.nn.functional.layer_norm(
nt_with_holes, normalized_shape=normalized_shape
)
@unittest.skipIf(
PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
)
@onlyCUDA
def test_pin_memory(self, device):
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
for nt in [nt_contiguous, nt_noncontiguous]:
self.assertFalse(nt.is_pinned())
pinned = nt.pin_memory()
self.assertTrue(pinned.is_pinned())
self.assertEqual(nt, pinned)
self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
# test that pin_memory on already pinned tensor has no effect
self.assertIs(pinned, pinned.pin_memory())
self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
@torch.compiler.disable
def _validate_nt(
self,
nt,
device,
dtype,
layout,
requires_grad,
dim,
batch_size,
contiguous,
cached_min_seqlen=None,
cached_max_seqlen=None,
base=None,
ref_nt=None,
):
# Validate a bunch of properties after NT construction.
device = torch.device(device)
self.assertEqual(nt.dim(), dim)
self.assertEqual(nt.device, device)
self.assertEqual(nt.dtype, dtype)
self.assertEqual(nt.layout, layout)
self.assertEqual(nt.requires_grad, requires_grad)
self.assertEqual(nt.is_contiguous(), contiguous)
if layout == torch.jagged:
self.assertEqual(nt._values.device, device)
self.assertEqual(nt._offsets.device, device)
self.assertEqual(nt.shape[0], batch_size)
self.assertTrue(isinstance(nt.shape[1], torch.SymInt))
if base is not None:
self.assertTrue(nt._is_view() and nt._base is base)
replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache
self.assertEqual(
"min_seqlen" in replay_cache, cached_min_seqlen is not None
)
self.assertEqual(
"max_seqlen" in replay_cache, cached_max_seqlen is not None
)
self.assertEqual(
"min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None
)
self.assertEqual(
"max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None
)
if cached_min_seqlen is not None:
self.assertEqual(nt._min_seqlen, cached_min_seqlen)
if cached_max_seqlen is not None:
self.assertEqual(nt._max_seqlen, cached_max_seqlen)
if ref_nt is not None:
self.assertEqual(nt.size(0), ref_nt.size(0))
for n1, n2 in zip(nt.unbind(), ref_nt.unbind()):
self.assertEqual(n1, n2)
@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
@parametrize("components_require_grad", [False, True])
def test_jagged_layout_construction_nested_tensor(
self, device, dtype, requires_grad, components_require_grad
):
for tensor_list in self._get_example_tensor_lists(
include_list_of_lists=True, include_requires_grad=components_require_grad
):
nt = torch.nested.nested_tensor(
tensor_list,
device=device,
dtype=dtype,
layout=torch.jagged,
requires_grad=requires_grad,
)
expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
expected_batch_size = len(tensor_list)
expected_contiguous = True
expected_min_seqlen = min(
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
expected_max_seqlen = max(
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
self._validate_nt(
nt,
device,
dtype,
torch.jagged,
requires_grad,
expected_dim,
expected_batch_size,
expected_contiguous,
expected_min_seqlen,
expected_max_seqlen,
)
# Make sure grads -don't- flow back into original tensors for nested_tensor()
if requires_grad:
(nt * 2).backward(torch.ones_like(nt))
for t in tensor_list:
t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
self.assertTrue(t.grad is None)
@dtypes(torch.float, torch.double, torch.half)
@parametrize("components_require_grad", [False, True])
def test_jagged_layout_construction_as_nested_tensor(
self, device, dtype, components_require_grad
):
# NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list
for tensor_list in self._get_example_tensor_lists(
include_list_of_lists=False, include_requires_grad=components_require_grad
):
nt = torch.nested.as_nested_tensor(
tensor_list, device=device, dtype=dtype, layout=torch.jagged
)
# nt.requires_grad=True should be set if at least one component requires grad
expected_dim = tensor_list[0].dim() + 1
expected_batch_size = len(tensor_list)
expected_contiguous = True
expected_min_seqlen = min(
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
expected_max_seqlen = max(
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
self._validate_nt(
nt,
device,
dtype,
torch.jagged,
components_require_grad,
expected_dim,
expected_batch_size,
expected_contiguous,
expected_min_seqlen,
expected_max_seqlen,
)
# Make sure grads flow back into original tensors for as_nested_tensor()
if components_require_grad:
(nt * 2).backward(torch.ones_like(nt))
for t in tensor_list:
if t.requires_grad:
self.assertEqual(t.grad, torch.ones_like(t) * 2)
else:
self.assertTrue(t.grad is None)
@xfailIfTorchDynamo
@unittest.skipIf(
PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
)
@onlyCUDA
def test_jagged_layout_construction_with_pinned_memory(self, device):
for tensor_list in self._get_example_tensor_lists():
nt = torch.nested.nested_tensor(
tensor_list, layout=torch.jagged, device="cpu", pin_memory=True
)
expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
expected_batch_size = len(tensor_list)
expected_min_seqlen = min(
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
expected_max_seqlen = max(
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
for t in tensor_list
)
self._validate_nt(
nt,
device="cpu",
dtype=torch.float32,
layout=torch.jagged,
requires_grad=False,
dim=expected_dim,
batch_size=expected_batch_size,
contiguous=True,
cached_min_seqlen=expected_min_seqlen,
cached_max_seqlen=expected_max_seqlen,
)
self.assertTrue(nt.is_pinned())
@dtypes(torch.float, torch.double, torch.half)
@parametrize("requires_grad", [False, True])
@parametrize("values_is_view", [False, True])
def test_jagged_view_from_values_offsets(
self, device, dtype, requires_grad, values_is_view
):
if values_is_view:
# make values a view of base
base = torch.randn(
2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad
)
values = base.flatten(0, -2)
else:
values = torch.randn(
10, 5, device=device, dtype=dtype, requires_grad=requires_grad
)
offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
nt = nested_view_from_values_offsets(values, offsets)
expected_dim = values.dim() + 1
expected_batch_size = offsets.shape[0] - 1
expected_base = base if values_is_view else values
lengths = offsets.diff()
self._validate_nt(
nt,
device,
dtype,
torch.jagged,
requires_grad,
expected_dim,
expected_batch_size,
# ensure NT is a proper view
base=expected_base,
contiguous=True,
# if no min / max are passed, expect the metadata cache to be empty
cached_min_seqlen=None,
cached_max_seqlen=None,
)
if requires_grad:
# Make sure grads flow back
(nt * 2).backward(torch.ones_like(nt))
@torch.compiler.disable
def _check_grad(t):
self.assertTrue(t.grad is not None)
self.assertEqual(t.grad, torch.ones_like(t) * 2)
_check_grad(base if values_is_view else values)
@dtypes(torch.float)
@parametrize("pass_min_max", [False, True])
def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max):
# === construct from (values, offsets) ===
values = torch.randn(10, 5, device=device, dtype=dtype)
offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
# compute min / max seqlen
lengths = offsets.diff()
min_seqlen = lengths.min().item()
max_seqlen = lengths.max().item()
if pass_min_max:
nt = torch.nested.nested_tensor_from_jagged(
values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
)
else:
nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
self._validate_nt(
nt,
device,
dtype,
torch.jagged,
requires_grad=False,
dim=3,
batch_size=4,
contiguous=True,
cached_min_seqlen=(min_seqlen if pass_min_max else None),
cached_max_seqlen=(max_seqlen if pass_min_max else None),
base=values,
)
# === construct from (values, offsets, lengths) ===
lengths = torch.tensor([2, 1, 1, 2], device=device)
# compute min / max seqlen
min_seqlen = lengths.min().item()
max_seqlen = lengths.max().item()
if pass_min_max:
nt = torch.nested.nested_tensor_from_jagged(
values,
offsets=offsets,
lengths=lengths,
min_seqlen=min_seqlen,
max_seqlen=max_seqlen,
)
else:
nt = torch.nested.nested_tensor_from_jagged(
values, offsets=offsets, lengths=lengths
)
# when both offsets / lengths are specified, expect non-contiguous
self._validate_nt(
nt,
device,
dtype,
torch.jagged,
requires_grad=False,
dim=3,
batch_size=4,
contiguous=False,
cached_min_seqlen=(min_seqlen if pass_min_max else None),
cached_max_seqlen=(max_seqlen if pass_min_max else None),
base=values,
)
self.assertIs(nt.lengths(), lengths)
# === construct from (values, lengths) ===
values = torch.randn(14, 5, device=device, dtype=dtype)
lengths = torch.tensor([2, 3, 4, 5], device=device)
# compute min / max seqlen
min_seqlen = lengths.min().item()
max_seqlen = lengths.max().item()
if pass_min_max:
nt = torch.nested.nested_tensor_from_jagged(
values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen
)
else:
nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
# for now, if only lengths is specified, convert to offsets to integrate best with the
# existing kernels
expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device)
expected_nt = torch.nested.nested_tensor_from_jagged(
values, offsets=expected_offsets
)
self._validate_nt(
nt,
device,
dtype,
torch.jagged,
requires_grad=False,
dim=3,
batch_size=4,
contiguous=True,
cached_min_seqlen=(min_seqlen if pass_min_max else None),
cached_max_seqlen=(max_seqlen if pass_min_max else None),
base=values,
ref_nt=expected_nt,
)
# error case: no offsets or lengths
with self.assertRaisesRegex(
RuntimeError, "At least one of offsets or lengths is required"
):
torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)
@onlyCPU
def test_nested_tensor_from_jagged_fx_trace(self, device):
def fn(x, y):
return torch.nested.nested_tensor_from_jagged(x, y)
def user_unwrapped(x, y):
return fn(x, y)
with self.assertRaisesRegex(
RuntimeError,
"torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace",
):
torch.fx.symbolic_trace(user_unwrapped)
@dtypes(torch.float, torch.double, torch.half)
@parametrize("dim", range(5))
@parametrize(
"layout",
[torch.strided, torch.jagged],
name_fn=lambda l: f"layout_{str(l).split('.')[1]}",
)
@parametrize("requires_grad", [False, True])
@parametrize("contiguous", [False, True])
def test_as_nested_tensor_from_tensor(
self, device, dtype, dim, layout, requires_grad, contiguous
):
if dim == 0:
t = torch.tensor(3.0, requires_grad=requires_grad)
else:
t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad)
assert t.dim() == dim
if dim < 2:
# 0-1 dim tensors can't be converted to NTs
with self.assertRaisesRegex(
RuntimeError, "Expected tensor argument to have dim"
):
nt = torch.nested.as_nested_tensor(
t, device=device, dtype=dtype, layout=layout
)
return
orig_t = t
if not contiguous:
t = t.transpose(0, 1)
nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout)
expected_dim = t.dim()
expected_batch_size = t.size(0)
expected_seqlen = t.size(1) if layout == torch.jagged else None
self._validate_nt(
nt,
device,
dtype,
layout,
requires_grad=requires_grad,
dim=dim,
batch_size=expected_batch_size,
contiguous=True,
cached_min_seqlen=expected_seqlen,
cached_max_seqlen=expected_seqlen,
)
if torch.device(device) == t.device and dtype == t.dtype and contiguous:
# should be the non-copying (view) case
self.assertTrue(nt._is_view() and nt._base is t)
# should have equivalent components to construction from unbound tensor list
nt_from_unbind = torch.nested.as_nested_tensor(
list(t.unbind(0)), device=device, dtype=dtype, layout=layout
)
self.assertEqualIgnoringNestedInts(nt, nt_from_unbind)
# ensure call on a NT with the same properties returns the NT directly
nt2 = torch.nested.as_nested_tensor(
nt, device=device, dtype=dtype, layout=layout
)
self.assertTrue(nt is nt2)
# ensure call with device=None uses input tensor device
nt3 = torch.nested.as_nested_tensor(
t.to(device=device, dtype=dtype),
device=None,
dtype=None,
layout=layout,
)
self._validate_nt(
nt3,
device,
dtype,
layout,
requires_grad=requires_grad,
dim=dim,
batch_size=expected_batch_size,
contiguous=True,
cached_min_seqlen=expected_seqlen,
cached_max_seqlen=expected_seqlen,
)
# we don't support conversion between layouts this way atm
other_layout = torch.strided if layout == torch.jagged else torch.jagged
with self.assertRaisesRegex(
RuntimeError, "Converting between nested tensor layouts is not supported"
):
torch.nested.as_nested_tensor(
nt, device=device, dtype=dtype, layout=other_layout
)
if requires_grad:
# make sure gradients flow back into inputs
(nt * 2).backward(torch.ones_like(nt))
self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)
@dtypes(torch.float32)
def test_construction_from_list(self, device, dtype):
from torch.fx.experimental.symbolic_shapes import is_nested_int
# success case: single ragged dim anywhere but the batch dim
for nt_dim in [2, 3, 4]:
for ragged_dim in range(1, nt_dim):
B = 6
shapes = [list(range(3, 3 + nt_dim - 1)) for _ in range(B)]
for b in range(B):
# subtract 1 to convert to component dim space
shapes[b][ragged_dim - 1] = torch.randint(
2, 9, (1,), device=device, dtype=torch.int64
).item()
components = [
torch.randn(shape, device=device, dtype=dtype) for shape in shapes
]
nt = torch.nested.nested_tensor(components, layout=torch.jagged)
self.assertEqual(nt.dim(), nt_dim)
self.assertEqual(nt._ragged_idx, ragged_dim)
for d in range(nt_dim):
self.assertEqual(d == ragged_dim, is_nested_int(nt.shape[d]))
# error case: empty list
with self.assertRaisesRegex(
RuntimeError, "Cannot construct a nested tensor from an empty tensor list"
):
torch.nested.nested_tensor([], layout=torch.jagged)
# error case: list of zero-dim tensors
with self.assertRaisesRegex(
RuntimeError,
"Cannot construct a nested tensor from a list of zero-dim tensors",
):
torch.nested.nested_tensor(
[
torch.tensor(3.0, device=device, dtype=dtype),
torch.tensor(4.0, device=device, dtype=dtype),
torch.tensor(5.0, device=device, dtype=dtype),
],
layout=torch.jagged,
)
# error case: multiple ragged dims
with self.assertRaisesRegex(
RuntimeError,
"Cannot represent given tensor list as a nested tensor with the jagged layout",
):
torch.nested.nested_tensor(
[
torch.randn(2, 3, device=device, dtype=dtype),
torch.randn(4, 5, device=device, dtype=dtype),
],
layout=torch.jagged,
)
# error case: components on multiple devices
if "cuda" in device:
with self.assertRaisesRegex(
RuntimeError,
"When constructing a nested tensor, all tensors in list must be on the same device",
):
torch.nested.nested_tensor(
[
torch.randn(2, 3, device=device, dtype=dtype),
torch.randn(2, 4, device="cpu", dtype=dtype),
],
layout=torch.jagged,
)
# error case: components with multiple dtypes
with self.assertRaisesRegex(
RuntimeError,
"When constructing a nested tensor, all tensors in list must have the same dtype",
):
torch.nested.nested_tensor(
[
torch.randn(2, 3, device=device, dtype=dtype),
torch.randn(2, 4, device=device, dtype=torch.float64),
],
layout=torch.jagged,
)
# error case: components with multiple dims
with self.assertRaisesRegex(
RuntimeError,
"When constructing a nested tensor, all tensors in list must have the same dim",
):
torch.nested.nested_tensor(
[
torch.randn(2, 3, device=device, dtype=dtype),
torch.randn(2, 3, 4, device=device, dtype=dtype),
],
layout=torch.jagged,
)
@dtypes(torch.double, torch.half)
@onlyCUDA
def test_device_dtype_transfer_updates_offsets(self, device, dtype):
for tensor_list in self._get_example_tensor_lists():
orig_device = torch.device("cpu")
orig_dtype = torch.float32
nt = torch.nested.nested_tensor(
tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype
)
self.assertEqual(torch.int64, nt.offsets().dtype)
nt = nt.to(device=device).to(dtype=dtype)
# offsets should still be int64 on the new device
self.assertEqual(nt.values().device, nt.offsets().device)
self.assertEqual(torch.int64, nt.offsets().dtype)
def test_unbind(self, device):
for tensor_list in self._get_example_tensor_lists():
nt = torch.nested.nested_tensor(
tensor_list, layout=torch.jagged, device=device
) # ragged_idx = 1
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])
@parametrize("ragged_idx", [2, 3])
def test_unbind_transpose(self, device, ragged_idx):
for tensor_list in self._get_example_tensor_lists():
nt = torch.nested.nested_tensor(
tensor_list, layout=torch.jagged, device=device
)
if ragged_idx < nt.dim():
nt = nt.transpose(1, ragged_idx) # set ragged_idx
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(
t.transpose(0, ragged_idx - 1), tensor_list[i]
) # transpose back each element of result
def test_unbind_transpose_ragged_idx_last_dim(self, device):
for tensor_list in self._get_example_tensor_lists():
nt = torch.nested.nested_tensor(
tensor_list, layout=torch.jagged, device=device
).transpose(1, -1) # set ragged_idx = last dimension
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(
t.transpose(0, -1), tensor_list[i]
) # transpose back each element of result
def test_unbind_lengths(self, device):
values = torch.randn(16, 128, device=device)
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
lengths = torch.tensor([6, 2, 1, 2], device=device)
nt = torch.nested.nested_tensor_from_jagged(
values, offsets=offsets, lengths=lengths
) # 3D nested tensor
tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])])
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])
def test_unbind_lengths_ragged_idx_1(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
lengths = torch.tensor([6, 2, 1, 2], device=device)
ragged_idx = 1
nt = torch.nested._internal.nested_tensor.NestedTensor(
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
) # 4D nested tensor
tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :])
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])
def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
lengths = torch.tensor([6, 2, 1, 2], device=device)
ragged_idx = 2
nt = torch.nested._internal.nested_tensor.NestedTensor(
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
) # 4D nested tensor
self.assertRaisesRegex(
RuntimeError,
r"unbind\(\): nested tensor offsets and lengths.*",
lambda: nt.unbind(),
)
def test_unbind_lengths_ragged_idx_2(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 2, 4, 8], device=device)
lengths = torch.tensor([2, 1, 3], device=device)
ragged_idx = 2
nt = torch.nested._internal.nested_tensor.NestedTensor(
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
) # 4D nested tensor
tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :])
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])
def test_unbind_lengths_ragged_idx_3(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 100, 128], device=device)
lengths = torch.tensor([50, 28], device=device)
ragged_idx = 3
nt = torch.nested._internal.nested_tensor.NestedTensor(
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
) # 4D nested tensor
tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
out = nt.unbind()
self.assertEqual(len(out), len(tensor_list))
for i, t in enumerate(out):
self.assertEqual(t, tensor_list[i])
@skipIfTorchDynamo(
"TorchDynamo raises an error for ragged_idx == 0 earlier than Torch"
)
def test_unbind_lengths_ragged_idx_0(self, device):
values = torch.randn(16, 8, 128, device=device)
offsets = torch.tensor([0, 100, 128], device=device)
lengths = torch.tensor([50, 28], device=device)
ragged_idx = 0
nt = torch.nested._internal.nested_tensor.NestedTensor(
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
) # 4D nested tensor
tensor_list = []
for i in range(offsets.shape[0] - 1):
tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
self.assertRaisesRegex(
RuntimeError,
r"unbind\(\): nested tensor.*out of bounds",
lambda: nt.unbind(),
)
def test_narrow(self, device):
starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
buffer = (
torch.arange(0, 10, device=device, dtype=torch.int64)
.unsqueeze(0)
.expand(5, -1)
.clone()
.detach()
)
nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged)
self.assertTrue(nt._is_view() and nt._base is buffer)
# TODO: Use this approach when unbind is functional
# unbinded_nt = nt.unbind()
# for i in range(starts.shape[0]):
# self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i])
for i in range(starts.shape[0]):
self.assertEqual(
torch.arange(
starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64
),
nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
)
def test_njt_cat(self, device):
offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
values_1 = torch.randn(
3, 2, dtype=torch.float64, device=device, requires_grad=True
)
values_2 = torch.randn(
3, 4, dtype=torch.float64, device=device, requires_grad=True
)
def grad_test_func(values_1, values_2, offsets):
nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets)
nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets)
nt_3 = torch.cat([nt_1, nt_2], dim=-1)
return nt_3.values()
assert gradcheck(
grad_test_func,
inputs=(values_1, values_2, offsets),
check_batched_grad=False,
)
def test_is_contiguous(self, device):
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
narrow_base = (
torch.arange(0, 10, device=device, dtype=torch.int64)
.unsqueeze(0)
.expand(5, -1)
.clone()
)
nt_noncontiguous = torch.nested.narrow(
narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged
)
starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
nt_contiguous_narrow = torch.nested.narrow(
narrow_base, 1, starts_c, lengths_c, layout=torch.jagged
)
# Test contiguous case
assert nt_contiguous.is_contiguous()
# Test narrow case
assert not nt_noncontiguous.is_contiguous()
assert nt_contiguous_narrow.is_contiguous()
# Test querying by memory_format
self.assertTrue(
nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
)
self.assertTrue(
not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
)
self.assertTrue(
nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)
)
def test_layout_under_torch_dispatch_mode(self):
from torch.testing._internal.logging_tensor import (
capture_logs_with_logging_tensor_mode,
)
nt = random_nt_from_dims(
[2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
)
with capture_logs_with_logging_tensor_mode():
self.assertEqual(nt.layout, torch.jagged)
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
@parametrize(
"func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__
)
def test_like_shape(self, func):
nt = random_nt_from_dims(
[2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
)
nt_like = func(nt)
for nt_ub in nt_like.unbind():
t_like = func(nt_ub)
self.assertEqual(nt_ub.shape, t_like.shape)
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
@parametrize(
"func",
[
torch.empty_like,
torch.full_like,
torch.ones_like,
torch.rand_like,
torch.randint_like,
torch.randn_like,
torch.zeros_like,
],
name_fn=lambda f: f.__name__,
)
def test_like_value(self, func, device):
dtype = torch.float32 if func is not torch.randint_like else torch.int32
for nt in _sample_njts(device=device, dtype=dtype):
extra_kwarg_sets = [{}]
if func is torch.full_like:
extra_kwarg_sets = [{"fill_value": 4.2}]
elif func is torch.randint_like:
extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}]
# only test changing dtype / device from CUDA -> CPU because CUDA might not be
# available when running this test for CPU
change_dtype_device_settings = (
[False, True] if "cuda" in device else [False]
)
for change_dtype_device in change_dtype_device_settings:
if change_dtype_device:
new_dtype = (
torch.float64 if func is not torch.randint_like else torch.int64
)
new_device = "cpu" if "cuda" in device else device
new_layout = torch.strided
for extra_kwargs in extra_kwarg_sets:
extra_kwargs.update(
{
"dtype": new_dtype,
"device": new_device,
"layout": new_layout,
}
)
for extra_kwargs in extra_kwarg_sets:
nt_like = func(nt, **extra_kwargs)
self.assertEqual(nt.shape, nt_like.shape)
if change_dtype_device:
self.assertNotEqual(nt.device, nt_like.device)
self.assertNotEqual(nt.device, nt_like.dtype)
# layout should be ignored since only torch.jagged is supported
self.assertEqual(torch.jagged, nt_like.layout)
else:
self.assertEqual(nt.device, nt_like.device)
self.assertEqual(nt.dtype, nt_like.dtype)
self.assertEqual(nt.layout, nt_like.layout)
self.assertEqual(nt.layout, torch.jagged)
# don't bother trying to compare random or empty values
if func not in [
torch.empty_like,
torch.rand_like,
torch.randn_like,
torch.randint_like,
]:
for nt_ub in nt_like.unbind():
t_like = func(nt_ub, **extra_kwargs)
self.assertEqual(nt_ub, t_like)
def test_noncontiguous_pointwise(self, device):
    """Check clone() / detach() on a non-contiguous (ragged-dim-transposed) NJT.

    ``clone()`` (default memory format) and ``detach()`` must keep the input's
    non-contiguity, while ``clone(memory_format=torch.contiguous_format)`` must
    produce a contiguous result. All three must match the input's values,
    offsets, ragged index and shape.
    """
    a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
    b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
    c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
    nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged)
    # transpose ragged dim
    transposed = nt.transpose(1, 2)
    self.assertFalse(transposed.is_contiguous())

    def check_nt_equality(x, y):
        self.assertEqual(x.values(), y.values())
        self.assertEqual(x.offsets(), y.offsets())
        self.assertEqual(x._ragged_idx, y._ragged_idx)
        self.assertEqual(x.shape, y.shape)

    clone = transposed.clone()
    self.assertFalse(clone.is_contiguous())
    check_nt_equality(clone, transposed)

    clone_contig = transposed.clone(memory_format=torch.contiguous_format)
    self.assertTrue(clone_contig.is_contiguous())
    check_nt_equality(clone_contig, transposed)

    detached = transposed.detach()
    # check the detached NJT itself (not `clone`) keeps the input's layout
    self.assertFalse(detached.is_contiguous())
    check_nt_equality(detached, transposed)
def test_permute(self, device):
    """Validate permute() argument checking and the resulting jagged layout."""
    nt = random_nt_from_dims(
        [2, None, 3, 5], device, torch.float32, layout=torch.jagged
    )
    outer = nt.shape
    inner = nt.values().shape

    # (dims, expected error regex) pairs that permute() must reject.
    bad_orderings = (
        (
            (0, 2, 1),
            r"permute\(\): number of dimensions in the tensor input \(4\) "
            + r"does not match the length of the desired ordering of dimensions \(3\).",
        ),
        ((0, 2, -2, 3), r"permute\(\): duplicate dims are not allowed."),
        (
            (1, 0, 2, 3),
            "Permute is not supported on the batch dimension for jagged NT",
        ),
    )
    for dims, pattern in bad_orderings:
        with self.assertRaisesRegex(ValueError, pattern):
            nt.permute(*dims)

    permuted = nt.permute(0, 2, 1, -1)
    expected_outer = (outer[0], outer[2], outer[1], outer[3])
    expected_inner = (inner[1], inner[0], inner[2])
    self.assertEqual(permuted.shape, expected_outer)
    self.assertEqual(permuted.values().shape, expected_inner)
    # the ragged dim moved from position 1 to position 2
    self.assertEqual(permuted._ragged_idx, 2)
    # applying the same swap again restores the original NJT
    self.assertEqual(permuted.permute(0, 2, 1, 3), nt)
def test_to_dtype(self, device):
    """`.to(dtype)` converts an NJT's values, keeps int64 offsets, leaves the source alone."""

    def check_converted(converted, expected_dtype):
        self.assertEqual(expected_dtype, converted.dtype)
        self.assertEqual(expected_dtype, converted.values().dtype)
        self.assertEqual(torch.int64, converted.offsets().dtype)

    nt = random_nt_from_dims(
        [2, None, 3], device, torch.float32, layout=torch.jagged
    )
    converted = nt.to(torch.float64)
    # the source NJT must keep its original dtype
    self.assertEqual(torch.float32, nt.dtype)
    check_converted(converted, torch.float64)

    # same conversion on a non-contiguous (ragged-dim-transposed) view
    check_converted(nt.transpose(1, 2).to(torch.bfloat16), torch.bfloat16)
def test_to_copy(self, device):
    """aten._to_copy honors `dtype` for both a contiguous and a transposed NJT."""
    components = [
        torch.randn(
            n, 3, 4, requires_grad=True, dtype=torch.float64, device=device
        )
        for n in range(2, 5)
    ]
    nt = torch.nested.nested_tensor(components, layout=torch.jagged)
    for candidate in (nt, nt.transpose(1, 2)):
        copied = torch.ops.aten._to_copy(candidate, dtype=torch.float16)
        self.assertEqual(torch.float16, copied.dtype)
def test_copy_(self, device):
    """copy_() between NJTs: shared offsets, distinct nested ints, incompatible shapes."""

    def make_njt(factory, offsets):
        return torch.nested.nested_tensor_from_jagged(
            factory(4, 3, device=device), offsets
        )

    shared_offsets = torch.tensor([0, 2, 4], device=device)
    dst = make_njt(torch.zeros, shared_offsets)
    src = make_njt(torch.ones, shared_offsets)
    dst.copy_(src)
    torch._dynamo.disable(self.assertEqual)(dst, src)

    # A separately created offsets tensor gives the source a different nested int;
    # this should still work thanks to the unbind-based copy.
    other_src = make_njt(torch.ones, torch.tensor([0, 2, 4], device=device))
    dst.copy_(other_src)

    # fail when tensors have different sizes
    dst_transposed = dst.transpose(1, 2)
    with self.assertRaisesRegex(
        RuntimeError,
        "expected compatible input and src shapes, but got",
    ):
        dst_transposed.copy_(src)
# This can't happen in the opinfo tests due to subprocess creation
@onlyCUDA
@unittest.skipIf(
    TEST_WITH_ROCM,
    "In ROCm, kernel asserts are disabled due to performance overhead",
)
def test_index_put_error(self, device):
    """An out-of-bounds index_put on an NJT with holes must fail on CUDA.

    The repro runs in a child interpreter (presumably so the device-side
    assert cannot affect this process's CUDA context), and the test only
    checks that the child exits with a nonzero status. Component 1 has
    length 2 (see ``lengths``), so ragged index 2 is out of bounds.

    The script hard-codes ``device='cuda'``, hence ``@onlyCUDA``. On a
    CPU-only run the child would presumably fail for an unrelated reason
    (no CUDA), and the test would pass vacuously.
    """
    import subprocess

    r = subprocess.call(
        [
            sys.executable,
            "-c",
            """\
import torch
offsets = torch.tensor([0, 2, 5, 7], device='cuda')
lengths = torch.tensor([2, 2, 2], device='cuda')
indices = [
torch.tensor([0, 1, 2], device='cuda'),
torch.tensor([0, 2, 1], device='cuda'),
torch.tensor([0, 0, 0], device='cuda'),
]
a = torch.nested.nested_tensor_from_jagged(
torch.zeros(7, 3, device='cuda'), offsets, lengths
)
a[indices] = 1.0
torch.cuda.synchronize()
""",
        ]
    )
    self.assertNotEqual(r, 0)
@skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
def test_profiler_sequence_nr(self):
    """Profiled forward and backward `linear` events on an NJT share one sequence_nr."""
    with torch.profiler.profile() as prof:
        leaf = torch.randn(4, 6, requires_grad=True)
        offsets = torch.tensor([0, 2, 4])
        scaled = leaf * 2
        linear = torch.nn.Linear(6, 8)
        projected = linear(torch.nested.nested_tensor_from_jagged(scaled, offsets))
        projected.values().sum().backward()

    def linear_seq_nrs(backward):
        # Sequence numbers of `linear` events that are either all forward
        # (backward=False) or all backward (backward=True).
        found = []
        for evt in prof.events():
            name = evt.name.lower()
            if "linear" not in name or evt.sequence_nr == -1:
                continue
            if backward != ("backward" in name):
                continue
            if backward and "evaluate_function" in name:
                continue
            found.append(evt.sequence_nr)
        return found

    fwd_seq_nrs = linear_seq_nrs(backward=False)
    bwd_seq_nrs = linear_seq_nrs(backward=True)

    # There should only be one such event with a sequence number:
    # the PythonTLSSnapshot event - but, note that it's not terrible if
    # we end up with multiple events with the same sequence number - so we
    # could relax this check if it becomes inconvenient to maintain this
    # property.
    self.assertEqual(len(fwd_seq_nrs), 1)
    self.assertEqual(len(bwd_seq_nrs), 1)
    self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0])
def test_is_same_size(self, device):
    """aten.is_same_size on NJTs is True only for NJTs built on the same offsets."""

    def make_components():
        return [
            torch.randn(
                n, 3, 4, requires_grad=True, dtype=torch.float64, device=device
            )
            for n in range(2, 5)
        ]

    # nt2 reuses the offsets returned for nt1; nt4 reuses those returned for nt3.
    nt1, offsets_a = jagged_from_list(make_components(), None)
    nt2, offsets_a = jagged_from_list(make_components(), offsets_a)
    nt3, offsets_b = jagged_from_list(make_components(), None)
    nt4, offsets_b = jagged_from_list(make_components(), offsets_b)

    def check_size(first, second, third, fourth):
        same_size = torch.ops.aten.is_same_size
        self.assertTrue(same_size(first, second))
        self.assertTrue(same_size(third, fourth))
        self.assertFalse(same_size(first, third))

    check_size(nt1, nt2, nt3, nt4)
    # the relationship must survive transposing the ragged dim
    check_size(*(x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4)))
@skipIfTorchDynamo("compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
def test_specialize_dynamic_shape(self, device):
    """Compiling `values + same_size` must not error after `values` was marked dynamic.

    The dynamic size is specialized by the second operand's shape; see
    https://github.com/pytorch/pytorch/issues/127097
    """

    def add(values, same_size):
        return values + same_size

    values = torch.randn((18, 16), device=device)
    offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device)
    like_values = torch.randn_like(values)
    # Building an NJT over `values` marks it as dynamic; the NJT itself is unused.
    _ = torch.nested.nested_tensor_from_jagged(values, offsets)

    expected = add(values, like_values)
    actual = torch.compile(add)(values, like_values)
    self.assertEqual(expected, actual)
@skipIfTorchDynamo("compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
def test_specialize_dynamic_shape_recompile(self, device):
    """Recompiles caused by a changing total length must stop after shapes go dynamic."""

    def make_inputs(total_len):
        values = torch.randn((total_len, 16), device=device)
        offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
        return values, offsets, torch.randn_like(values)

    def add(values, same_size):
        return values + same_size

    counter = torch._dynamo.testing.CompileCounter()
    compiled_add = torch.compile(add, backend=counter, fullgraph=True)

    def run_and_compare(total_len):
        values, offsets, like_values = make_inputs(total_len)
        # This may add dynamic shape markings. Whatever markings exist, the goal
        # is that recompiles eventually stop as the shape changes.
        torch.nested.nested_tensor_from_jagged(values, offsets)
        self.assertEqual(add(values, like_values), compiled_add(values, like_values))

    run_and_compare(18)
    self.assertEqual(counter.frame_count, 1)

    run_and_compare(19)
    # we'll probably recompile here with dynamic shapes - it's okay if not though.
    frames_after_second = counter.frame_count
    self.assertIn(frames_after_second, [1, 2])

    # by now we should have compiled with dynamic shapes, so another new
    # shape must not trigger an additional recompile.
    run_and_compare(20)
    self.assertEqual(counter.frame_count, frames_after_second)
# Note 1: Math fallback doesn't work with bfloat16 on CUDA
# Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
@unittest.skipIf(
    TEST_WITH_ROCM,
    "ROCm doesn't support flash attention or mem_efficient attention for NT",
)
@tf32_on_and_off(0.005)
@dtypes(
    *(
        [torch.float16, torch.bfloat16, torch.float32]
        if SM80OrLater
        else [torch.float16, torch.float32]
    )
)
def test_sdpa(self, device, dtype):
    """Compare SDPA on jagged NJTs against dense SDPA run per sentence.

    The NJT outputs and input gradients for two sentences (lengths 11 and 13)
    are compared with dense attention on each sentence. Tolerances come from
    the gap between a float32 math-kernel reference and a ``dtype``
    math-kernel reference. The comparison is run both with and without the
    backward pass, under the default backend selection, with only the
    memory-efficient + math backends, and with only the math backend.
    """
    batch_size = 1
    emb_dims = 128
    n_heads = 8
    head_dims = emb_dims // n_heads
    sdpa = torch.nn.functional.scaled_dot_product_attention

    sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
    sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)

    query = torch.nn.Linear(
        emb_dims, emb_dims, bias=False, device=device, dtype=dtype
    )
    key = torch.nn.Linear(
        emb_dims, emb_dims, bias=False, device=device, dtype=dtype
    )
    value = torch.nn.Linear(
        emb_dims, emb_dims, bias=False, device=device, dtype=dtype
    )

    # NB: we make sure the leaf tensor we compute gradients for is the view-ed
    # tensor before it is transposed. This is because today we cannot backward
    # through view or unbind a transposed tensor.
    def dense_leaf(proj, x_dense):
        return (
            proj(x_dense)
            .view(batch_size, -1, n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )

    def nested_leaf(proj, x_nested):
        return (
            proj(x_nested)
            .view(*x_nested.size()[0:2], n_heads, head_dims)
            .detach()
            .requires_grad_(True)
        )

    # Simplest case: 1 sentence, no batching
    x_d1 = sen1.unsqueeze(0)
    x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged)
    q_d1, k_d1, v_d1 = (dense_leaf(p, x_d1) for p in (query, key, value))
    q_d1_t, k_d1_t, v_d1_t = (t.transpose(1, 2) for t in (q_d1, k_d1, v_d1))
    q_nt, k_nt, v_nt = (nested_leaf(p, x_nt) for p in (query, key, value))
    q_nt_t, k_nt_t, v_nt_t = (t.transpose(1, 2) for t in (q_nt, k_nt, v_nt))

    # High Precision Math Reference
    q_d1_f32, k_d1_f32, v_d1_f32 = (
        t.to(torch.float32) for t in (q_d1, k_d1, v_d1)
    )
    out_ref = torch.ops.aten._scaled_dot_product_attention_math(
        q_d1_f32.transpose(1, 2),
        k_d1_f32.transpose(1, 2),
        v_d1_f32.transpose(1, 2),
    )[0]
    grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32))

    # Low Precision Math Reference
    out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
        q_d1_t, k_d1_t, v_d1_t
    )[0]
    grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1))

    # Compute tolerances
    output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
    # fudge factor of 1.7 for smaller GPUs e.g., A2, A16
    grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(
        grads_ref[0], grads_lp_ref[0], 1.7
    )
    grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1])
    grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2])
    grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol]
    grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol]

    attn_d1 = sdpa(q_d1_t, k_d1_t, v_d1_t).transpose(1, 2)
    attn_nt = sdpa(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2)
    self.assertEqual(
        attn_d1,
        attn_nt.unbind()[0].unsqueeze(0),
        atol=output_ref_atol,
        rtol=output_ref_rtol,
    )

    # Simple case: 2 sentences, no extra params
    x_d2 = sen2.unsqueeze(0)
    x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
    q_d2, k_d2, v_d2 = (dense_leaf(p, x_d2) for p in (query, key, value))
    q_d2_t, k_d2_t, v_d2_t = (t.transpose(1, 2) for t in (q_d2, k_d2, v_d2))
    q_nt, k_nt, v_nt = (nested_leaf(p, x_nt) for p in (query, key, value))
    q_nt_t, k_nt_t, v_nt_t = (t.transpose(1, 2) for t in (q_nt, k_nt, v_nt))

    attn_d2 = sdpa(q_d2_t, k_d2_t, v_d2_t).transpose(1, 2)
    d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1))
    d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2))

    # Simple case 3: batch_size = 1, seq_len = 1
    q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device)
    q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged)
    q_nt_3 = q_nt_3.transpose(1, 2)
    attn_out = sdpa(q_nt_3, q_nt_3, q_nt_3)
    self.assertEqual(attn_out.shape, q_nt_3.shape)

    def check_forward_backward(skip_backward=False):
        # Compare the two-sentence NJT attention (and, unless skip_backward,
        # its input grads) with the per-sentence dense results.
        if skip_backward:
            # Forward-only: detached inputs under no_grad. The leaves'
            # requires_grad flags are left untouched so later grad checks work.
            with torch.no_grad():
                attn_nt = sdpa(
                    q_nt_t.detach(), k_nt_t.detach(), v_nt_t.detach()
                ).transpose(1, 2)
        else:
            attn_nt = sdpa(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2)

        attn_nts = attn_nt.unbind()
        self.assertEqual(
            attn_d1,
            attn_nts[0].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )
        self.assertEqual(
            attn_d2,
            attn_nts[1].unsqueeze(0),
            atol=output_ref_atol,
            rtol=output_ref_rtol,
        )
        if skip_backward:
            return

        nt_grads = torch.autograd.grad(attn_nt.values().sum(), (q_nt, k_nt, v_nt))
        for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
            nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
        ):
            unbound_nt_grads = nt_grad.unbind()
            self.assertEqual(
                d1_grad,
                unbound_nt_grads[0].unsqueeze(0),
                atol=grad_atol,
                rtol=grad_rtol,
            )
            self.assertEqual(
                d2_grad,
                unbound_nt_grads[1].unsqueeze(0),
                atol=grad_atol,
                rtol=grad_rtol,
            )

    def check_with_and_without_backward():
        for skip_backward in (False, True):
            check_forward_backward(skip_backward)

    # Default
    check_with_and_without_backward()

    # Test dispatcher works by calling only mem-effn and math (as they are safe for all devices)
    with torch.nn.attention.sdpa_kernel(
        [
            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
            torch.nn.attention.SDPBackend.MATH,
        ]
    ):
        check_with_and_without_backward()

    # Test math fallback
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
        # Math fallback doesn't work with bfloat16 on CUDA because
        # "group_gemm_dispatch" not implemented for 'BFloat16'
        if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
            check_with_and_without_backward()

    check_cudnn = os.getenv("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED", "0") == "1"
    if (
        "cuda" in str(device)
        and check_cudnn
        and (dtype == torch.float16 or dtype == torch.bfloat16)
    ):
        with self.assertRaisesRegex(RuntimeError, "cuDNN SDPA Nested Tensor"):
            with torch.nn.attention.sdpa_kernel(
                torch.nn.attention.SDPBackend.CUDNN_ATTENTION
            ):
                check_forward_backward()
@skipIfTorchDynamo("SDPA test compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
# Guarding with sqrt() doesn't work on ROCm?
@skipCUDAIfRocm
@onlyCUDA
@dtypes(
    *(
        [torch.float16, torch.bfloat16, torch.float32]
        if SM80OrLater
        else [torch.float16, torch.float32]
    )
)
def test_sdpa_compile(self, device, dtype):
    """torch.compile'd SDPA on a two-sentence NJT matches eager dense SDPA per sentence."""
    batch_size = 1
    emb_dims = 1024
    n_heads = 8
    head_dims = emb_dims // n_heads

    sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
    sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
    # query, key and value projections, created in that order
    projections = [
        torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
        for _ in range(3)
    ]

    def split_dense(proj, x):
        # (1, S, emb_dims) -> (1, n_heads, S, head_dims)
        return proj(x).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)

    def split_nested(proj, x):
        # same head split for the NJT, detached from the projection graph
        return (
            proj(x)
            .view(*x.size()[0:2], n_heads, head_dims)
            .detach()
            .transpose(1, 2)
        )

    x_d1 = sen1.unsqueeze(0)
    x_d2 = sen2.unsqueeze(0)
    x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)

    q_d1, k_d1, v_d1 = [split_dense(p, x_d1) for p in projections]
    q_d2, k_d2, v_d2 = [split_dense(p, x_d2) for p in projections]
    q_nt, k_nt, v_nt = [split_nested(p, x_nt) for p in projections]

    # Tolerances come from the gap between the float32 (high precision) and
    # `dtype` (low precision) math-kernel references on the first sentence.
    math_sdpa = torch.ops.aten._scaled_dot_product_attention_math
    out_ref = math_sdpa(
        q_d1.to(torch.float32), k_d1.to(torch.float32), v_d1.to(torch.float32)
    )[0]
    out_lp_ref = math_sdpa(q_d1, k_d1, v_d1)[0]
    atol, rtol = get_tolerances(out_ref, out_lp_ref, fudge_factor=2)

    sdpa = torch.nn.functional.scaled_dot_product_attention
    dense_outs = [
        sdpa(q_d1, k_d1, v_d1).transpose(1, 2),
        sdpa(q_d2, k_d2, v_d2).transpose(1, 2),
    ]
    compiled_sdpa = torch.compile(sdpa)
    attn_nts = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2).unbind()

    for idx, dense_out in enumerate(dense_outs):
        self.assertEqual(
            dense_out,
            attn_nts[idx].unsqueeze(0),
            atol=atol,
            rtol=rtol,
        )
@dtypes(torch.float32, torch.double, torch.half)
def test_sdpa_with_constant_sequence_length(self, device, dtype):
    """SDPA on NJTs whose ragged dim is not the sequence dim matches SDPA on the buffers.

    Inputs have shape (B, P*, S, D):
      B: batch size
      P*: ragged number of prompts
      S: (constant) sequence length
      D: embedding size
    The NJT output, and the gradient w.r.t. query, are compared against
    running SDPA directly on the packed values.
    """
    query = random_nt_from_dims(
        [4, None, 8, 10],
        device=device,
        dtype=dtype,
        layout=torch.jagged,
        requires_grad=True,
    )
    key = random_nt_from_similar(query)
    value = random_nt_from_similar(query)
    output = F.scaled_dot_product_attention(query, key, value)
    # assertIsInstance reports the actual type on failure
    self.assertIsInstance(output, NestedTensor)
    output.values().sum().backward()

    query_dense = query.detach().clone().requires_grad_(True)
    # should be equivalent to just running the buffers through
    output_dense = F.scaled_dot_product_attention(
        query_dense.values(), key.values(), value.values()
    )
    torch._dynamo.disable(self.assertEqual)(output._values, output_dense)

    output_dense.sum().backward()
    torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)
@onlyCUDA
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FUSED_ATTENTION,
    "Platform doesn't support flash or mem-efficient attention",
)
@dtypes(
    *(
        [torch.float16, torch.bfloat16, torch.float32]
        if SM80OrLater
        else [torch.float16, torch.float32]
    )
)
def test_sdpa_with_packed_in_proj(self, device, dtype):
    """SDPA on q/k/v chunked out of one packed projection matches per-component SDPA."""
    # shape (B, *, D)
    input_packed = random_nt_from_dims(
        [5, None, 10], device=device, dtype=dtype, layout=torch.jagged
    )

    # Do input projection.
    num_heads = 2
    # should be multiple of 4 for efficient kernels (e.g. flash / mem-efficient)
    head_dim = 8
    qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(
        device=device, dtype=dtype
    )

    def in_proj(x, qkv_linear=qkv_linear):
        # chunk() produces non-contiguous pieces to trigger
        # _is_safe_to_get_storage_as_tensor()
        q, k, v = (
            piece.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
            for piece in qkv_linear(x).chunk(3, dim=-1)
        )
        return q, k, v

    q, k, v = in_proj(input_packed)
    output = F.scaled_dot_product_attention(q, k, v, attn_mask=None)

    # compare to individually running unbound components through
    per_component = zip(input_packed.unbind(), output.transpose(-2, -3).unbind())
    for in_component, out_component in per_component:
        cq, ck, cv = in_proj(in_component)
        expected = F.scaled_dot_product_attention(cq, ck, cv).transpose(-2, -3)
        # Low Precision Math Reference
        math_ref = torch.ops.aten._scaled_dot_product_attention_math(cq, ck, cv)[
            0
        ].transpose(-2, -3)
        atol, rtol = get_tolerances(expected, math_ref, fudge_factor=2)
        self.assertEqual(expected, out_component, atol=atol, rtol=rtol)
@skipIfTorchDynamo("SDPA test compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
# mha_varlen_fwd not supported on ROCm
@skipCUDAIfRocm
@onlyCUDA
@dtypes(
    *(
        [torch.float16, torch.bfloat16, torch.float32]
        if SM80OrLater
        else [torch.float16, torch.float32]
    )
)
def test_sdpa_backwards(self, device, dtype):
    """Backward through compiled NJT self-attention with a graph break gives all-ones grads."""
    values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype)
    offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64)

    @torch.compile
    def f(values, offsets):
        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
        nt = nt.transpose(-2, -3)
        # purposefully graph break to trigger view replay for subclass view input
        torch.tensor(1).item()
        attended = F.scaled_dot_product_attention(nt, nt, nt)
        return convert_nt_to_jagged(attended.transpose(-2, -3))

    f(values, offsets).sum().backward()
    self.assertEqual(values.grad, torch.ones_like(values))
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FUSED_ATTENTION,
    "Platform doesn't support flash or mem-efficient attention",
)
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
@onlyCUDA
@skipIfTorchDynamo()
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
def test_sdpa_autocast(self, device):
    """Mixed fp32/fp16 SDPA under CUDA autocast: NJT vs dense, eager vs compiled."""

    def fn_nt(values32, values16, offsets):
        nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16)
        nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16)
        nt32 = nt32.transpose(1, 2)
        nt16 = nt16.transpose(1, 2)
        return F.scaled_dot_product_attention(nt32, nt16, nt32)

    def fn_dense(x32, x16):
        x32 = x32.view(8, 16, 4, 16).transpose(1, 2)
        x16 = x16.view(8, 16, 4, 16).transpose(1, 2)
        return F.scaled_dot_product_attention(x32, x16, x32)

    def nt_as_dense(nt_out):
        # rearrange the NJT output's packed values into the dense (8, 16, 4, 16) layout
        return nt_out.values().transpose(0, 1).view(8, 16, 4, 16)

    values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32)
    values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16)
    offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)

    # Forward: dense/NJT x eager/compiled outputs must agree.
    x32 = values32.clone()
    x16 = values16.clone()
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        out_dense_eager = fn_dense(x32, x16)
        out_dense_compiled = torch.compile(fn_dense)(x32, x16)
        out_nt_eager = fn_nt(values32, values16, offsets)
        out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets)

    self.assertEqual(out_dense_eager, out_dense_compiled)
    expected = out_dense_eager.transpose(1, 2)
    self.assertEqual(expected, nt_as_dense(out_nt_eager))
    self.assertEqual(expected, nt_as_dense(out_nt_compiled))

    # Backward: gradients w.r.t. fresh fp32 / fp16 leaves must agree.
    def fresh_leaves():
        return tuple(
            x.detach().clone().requires_grad_(True) for x in (values32, values16)
        )

    v32_dense_eager, v16_dense_eager = fresh_leaves()
    v32_dense_compile, v16_dense_compile = fresh_leaves()
    v32_nt_eager, v16_nt_eager = fresh_leaves()
    v32_nt_compile, v16_nt_compile = fresh_leaves()

    with torch.autocast(device_type="cuda", dtype=torch.float16):
        losses = [
            fn_dense(v32_dense_eager, v16_dense_eager).sum(),
            torch.compile(fn_dense)(v32_dense_compile, v16_dense_compile).sum(),
            fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum(),
            torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets)
            .values()
            .sum(),
        ]
    for loss in losses:
        loss.backward()

    self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad)
    for nt_grad in (v32_nt_eager.grad, v32_nt_compile.grad):
        self.assertEqual(v32_dense_eager.grad, nt_grad, atol=1e-4, rtol=1e-4)

    self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad)
    for nt_grad in (v16_nt_eager.grad, v16_nt_compile.grad):
        self.assertEqual(v16_dense_eager.grad, nt_grad, atol=1e-5, rtol=5e-3)
@unittest.skipIf(
    not PLATFORM_SUPPORTS_FUSED_ATTENTION,
    "Platform doesn't support flash or mem-efficient attention",
)
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
@onlyCUDA
@skipIfTorchDynamo()
def test_sdpa_flop_counter(self, device):
    """FlopCounterMode reports identical SDPA fwd+bwd FLOPs for a real and a meta NJT."""
    from torch.utils.flop_counter import FlopCounterMode

    def get_flops(nt):
        counter = FlopCounterMode(display=False)
        with counter:
            out = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
            out.values().sum().backward()
        return counter.get_total_flops()

    def make_nt(target_device):
        vals = torch.randn(
            (8 * 16, 4, 16),
            requires_grad=True,
            device=target_device,
            dtype=torch.float16,
        )
        offs = torch.arange(
            0, 8 * 16 + 1, 16, device=target_device, dtype=torch.int32
        )
        return convert_jagged_to_nested_tensor(vals, offs, max_length=16).transpose(
            1, 2
        )

    nt = make_nt(device)
    nt_meta = make_nt("meta")
    self.assertEqual(get_flops(nt), get_flops(nt_meta))
@skipIfTorchDynamo()
def test_nested_tensor_activation_checkpoint(self, device):
    """Non-reentrant checkpointing (plain and selective) works through NJT construction."""
    values = torch.randn(
        9, 3, 256, requires_grad=True, device=device, dtype=torch.float32
    )
    lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64)
    offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)

    def sum_from_offsets(values, offsets):
        nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
        return convert_nt_to_jagged(nt).sum()

    checkpoint(sum_from_offsets, values, offsets, use_reentrant=False).backward()
    self.assertIsNotNone(values.grad)

    # Selective checkpointing with a policy that lists aten.cumsum; the offsets
    # are now computed (via cumsum) inside the checkpointed region.
    context_fn = partial(
        create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default]
    )
    values.grad = None

    def sum_from_lengths(values, lengths):
        region_offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
        nt = convert_jagged_to_nested_tensor(values, region_offsets, max_length=4)
        return convert_nt_to_jagged(nt).sum()

    checkpoint(
        sum_from_lengths, values, lengths, use_reentrant=False, context_fn=context_fn
    ).backward()
    self.assertIsNotNone(values.grad)
# Internally-defined NT use cases are lifted to here for maximum test realism.
# TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated.
@skipCUDAIfRocm  # not needed
@skipIfTorchDynamo("compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@parametrize("use_legacy_api", [True, False])
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
def test_dummy_mha_with_nt(self, device, use_legacy_api):
    """FX-trace and compile a toy attention module that builds NJTs from jagged inputs.

    The module projects a jagged ``value`` buffer with a Linear layer and wraps
    the result as ``key`` / ``value`` NJTs over the same ``offsets``, each with a
    different max-length hint. It converts the dense ``query`` to an NJT, runs
    SDPA restricted to the flash / memory-efficient backends, and converts the
    output back to a jagged buffer. ``use_legacy_api`` selects the ``*_legacy``
    conversion helpers instead of the current ones.

    Checks: the module survives ``torch.fx.symbolic_trace`` + ``torch.compile``;
    backward populates ``value.grad``; the returned cached max seqlens equal
    the max lengths passed in; the compiled output and gradient match an eager
    instance.
    """
    bs = 3
    d1 = 2
    d2 = 4
    d3 = 16
    n_heads = 2
    d_head = d3 // n_heads
    max_length_1 = 10
    max_length_2 = 20
    torch.manual_seed(0)

    class mha(torch.nn.Module):
        def __init__(self, use_legacy_api) -> None:
            super().__init__()
            # Reseed right before creating the Linear so every instance gets the
            # same weights; the compiled and eager instances below are compared.
            torch.manual_seed(0)
            self.linear = torch.nn.Linear(d2, d3, device=device)
            self.use_legacy_api = use_legacy_api

        def forward(self, query, value, offsets):
            value = self.linear(value)
            # key and value wrap the same projected buffer and offsets but carry
            # different max-length hints (max_length_1 vs. max_length_2).
            if self.use_legacy_api:
                key = convert_jagged_to_nested_tensor_legacy(
                    value, offsets, max_length_1
                )
                value = convert_jagged_to_nested_tensor_legacy(
                    value, offsets, max_length_2
                )
                query = convert_dense_to_nested_tensor_legacy(query)
            else:
                key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
                value = convert_jagged_to_nested_tensor(
                    value, offsets, max_length_2
                )
                query = convert_dense_to_nested_tensor(query)
            # split the last dim into heads: (bs, seq, d3) -> (bs, n_heads, seq, d_head)
            q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
            k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
            v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)

            with torch.nn.attention.sdpa_kernel(
                [
                    torch.nn.attention.SDPBackend.FLASH_ATTENTION,
                    torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
                ]
            ):
                attn_output = torch.nn.functional.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask=None,
                    dropout_p=0.0,
                    is_causal=False,
                )
            attn_output = attn_output.transpose(1, 2)
            if self.use_legacy_api:
                attn_output = convert_nt_to_jagged_legacy(attn_output)
            else:
                attn_output = convert_nt_to_jagged(attn_output)
            # also return the cached max seqlens so the caller can verify they
            # survive tracing and compiling
            return attn_output, key._max_seqlen, value._max_seqlen

    query = torch.rand(bs, d1, d3, device=device)
    value = torch.rand(30, d2, requires_grad=True, device=device)
    # total_length must be greater than max_length, otherwise the flash_attn backward will fail
    offsets = torch.tensor([0, 2, 3, 30], device=device)

    m = mha(use_legacy_api)
    symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
    m = torch.compile(symbolic_traced)
    attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
        query, value, offsets
    )
    loss = attn_output.sum()
    # Check that NT can be fx traced and torch.compile, and backward works
    loss.backward()

    # Check that value.requires_grad is not lost after tracing and compiling
    value_grad = value.grad  # save for comparison later
    self.assertIsNotNone(value_grad)
    # check that max_seqlen is cached properly
    self.assertEqual(cached_key_max_seqlen, max_length_1)
    self.assertEqual(cached_value_max_seqlen, max_length_2)

    # check if the output is numerically equivalent with the eager mode
    m_eager = mha(use_legacy_api)

    value.grad = None
    attn_output_eager, _, _ = m_eager(query, value, offsets)
    attn_output_eager.sum().backward()
    self.assertTrue(torch.allclose(attn_output_eager, attn_output))
    self.assertTrue(torch.allclose(value_grad, value.grad))
# Helper function to generate random query, key, value NJTs in (B, n_heads, *, D) format.
# If noncontig_with_holes is True, the results will be non-contiguous with holes (i.e. have
# both offsets and lengths specified).
def _rand_qkv(self, device, dtype, noncontig_with_holes=False, q_and_kv_match=True):
batch_size = 8
n_heads = 8
D = 16
def _rand_nt(noncontig_with_holes=noncontig_with_holes):
sentence_lengths = [random.randint(2, 1023) for _ in range(batch_size - 1)]
total = sum(sentence_lengths)
# shape (B, *, D_total) where D_total = n_heads * D
nt = torch.nested.nested_tensor(
[
torch.randn(l, n_heads * D, device=device, dtype=dtype)
for l in sentence_lengths
],
layout=torch.jagged,
)
if noncontig_with_holes:
nt = torch.nested.nested_tensor_from_jagged(
nt._values,
nt._offsets,
# -1 to introduce holes
lengths=nt._offsets.diff() - 1,
jagged_dim=nt._ragged_idx,
min_seqlen=nt._min_seqlen,
max_seqlen=nt._max_seqlen,
)
return nt
query = _rand_nt()
if q_and_kv_match:
key = torch.randn_like(query)
value = torch.randn_like(query)
else:
key = _rand_nt()
value = torch.randn_like(key)
# shape (B, *, D_total) -> (B, n_heads, *, D)
query = (
query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
value = (
value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
)
return query, key, value
@onlyCUDA
@flex_attention_supported_platform
@dtypes(torch.float32)
# non-contiguous with holes not supported yet
@decorateIf(unittest.skip, lambda params: params["noncontig_with_holes"])
@parametrize("noncontig_with_holes", [False, True])
@parametrize("cross_attention", [False, True])
@skipIfRocm
def test_flex_attention(self, device, dtype, noncontig_with_holes, cross_attention):
    """Causal FlexAttention on NJTs (block mask and score_mod forms) matches causal SDPA."""
    query, key, value = self._rand_qkv(
        device, dtype, noncontig_with_holes, q_and_kv_match=(not cross_attention)
    )
    inputs = (query, key, value)

    # causal masking expressed as a mask_mod (used to build a block mask) ...
    def causal_mask(b, h, q_idx, kv_idx):
        return q_idx >= kv_idx

    # ... and as an equivalent score_mod
    def causal_score_mod(score, b, h, q_idx, kv_idx):
        return torch.where(q_idx >= kv_idx, score, float("-inf"))

    # Cross attention additionally passes `key` when building the block mask.
    mask_nts = (query, key) if cross_attention else (query,)
    block_mask = create_nested_block_mask(
        causal_mask, 1, 1, *mask_nts, _compile=True
    )

    out_flex = flex_attention(query, key, value, block_mask=block_mask)
    grad_out = torch.randn_like(out_flex)

    def with_grads(out):
        # [output, d/dquery, d/dkey, d/dvalue] for a shared upstream gradient
        grads = torch.autograd.grad(out, inputs=inputs, grad_outputs=(grad_out,))
        return [out, *grads]

    flex_outs = with_grads(out_flex)
    flex_outs2 = with_grads(
        flex_attention(query, key, value, score_mod=causal_score_mod)
    )
    sdpa_outs = with_grads(
        F.scaled_dot_product_attention(query, key, value, is_causal=True)
    )

    # Compare flex vs. SDPA output and grads
    for flex, flex2, sdpa in zip(flex_outs, flex_outs2, sdpa_outs):
        self.assertTrue(flex.is_nested and flex2.is_nested and sdpa.is_nested)
        self.assertEqual(flex, sdpa, atol=1e-2, rtol=1e-2)
        self.assertEqual(flex2, sdpa, atol=1e-2, rtol=1e-2)
@onlyCUDA
@flex_attention_supported_platform
@dtypes(torch.float32)
def test_flex_attention_converts_stacked_seq_indices(self, device, dtype):
    """score_mod / mask_mod must see sequence-relative indices and the real batch index.

    FlexAttention on NJTs works in a "stacked sequence" index space; user mods
    that index a table sized by max_seqlen only work if FlexAttention converts
    those indices back to per-sequence indices. The batch index handed to the
    mods must also be the true one, not always 0
    (see https://github.com/pytorch/pytorch/issues/143788).
    """
    query, key, value = self._rand_qkv(device, dtype)

    # --- score_mod that indexes a table of length max_seqlen ---
    score_mod_table = torch.randn(query._max_seqlen, device=device, dtype=dtype)

    def my_score_mod(score, b, h, q_idx, kv_idx):
        return score_mod_table[q_idx]

    flex_attention(query, key, value, score_mod=my_score_mod)

    # --- batch-dependent score_mod ---
    batch_table = torch.randn(query.size(0), device=device, dtype=dtype)
    # batch index 0 keeps its score unchanged
    batch_table[0].zero_()

    def batch_specific_score_mod(score, b, h, q_idx, kv_idx):
        return score + batch_table[b]

    def identity_score_mod(score, b, h, q_idx, kv_idx):
        return score

    out_batch = flex_attention(query, key, value, score_mod=batch_specific_score_mod)
    out_identity = flex_attention(query, key, value, score_mod=identity_score_mod)
    # If b were always 0, the batch-specific mod would collapse to the identity.
    self.assertFalse(torch.allclose(out_batch._values, out_identity._values))

    # --- mask_mod indexing the same table, without / with a batch condition ---
    mask_mod_table = score_mod_table > 0.0

    def my_mask_mod(b, h, q_idx, kv_idx):
        return mask_mod_table[q_idx]

    def my_mask_mod2(b, h, q_idx, kv_idx):
        return mask_mod_table[q_idx] & (b == 0)

    masked_outputs = []
    for mask_mod in (my_mask_mod, my_mask_mod2):
        mask = create_nested_block_mask(mask_mod, 1, 1, query, _compile=True)
        masked_outputs.append(flex_attention(query, key, value, block_mask=mask))

    # Same guard as above, for the batch index passed to mask_mod.
    self.assertFalse(
        torch.allclose(masked_outputs[0]._values, masked_outputs[1]._values)
    )
@dtypes(torch.float32)
def test_apply_(self, device, dtype):
    """apply_ rewrites NJT values in place on CPU and raises TypeError elsewhere."""
    nt = random_nt_from_dims(
        [5, None, 10],
        device=device,
        dtype=dtype,
        layout=torch.jagged,
        requires_grad=True,
    )

    def f(x):
        return x * 2

    if device == "cpu":
        original_values = nt._values.detach().clone()
        nt.apply_(f)
        self.assertEqual(f(original_values), nt._values)
        # the in-place swap must not be recorded in the autograd graph
        self.assertIsNone(nt.grad)
        self.assertIsNone(nt._values.grad_fn)
    else:
        with self.assertRaisesRegex(
            TypeError, "apply_ is only implemented on CPU tensors"
        ):
            nt.apply_(f)
@onlyCUDA
@dtypes(torch.float64, torch.float32, torch.half)
@parametrize(
    "contiguity",
    ["noncontig_transposed", "noncontig_with_holes"],
    name_fn=lambda c: c,
)
def test_noncontiguous_to(self, device, dtype, contiguity):
    """NJT.to() must preserve non-contiguity, just as dense to() preserves strides.

    Two flavors of non-contiguous NJT are covered:
    1. non-contiguous transposed
    2. non-contiguous with holes
    """
    if contiguity == "noncontig_transposed":
        nt = random_nt_from_dims(
            [3, None, 5, 2],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
        ).transpose(-3, -2)
    elif contiguity == "noncontig_with_holes":
        nt = torch.nested.nested_tensor_from_jagged(
            values=torch.randn(10, 3, device=device, dtype=dtype),
            offsets=torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int64),
            # lengths shorter than the offset spans leave holes
            lengths=torch.tensor([1, 2, 3], device=device, dtype=torch.int64),
        )
    else:
        raise ValueError("invalid contiguity specified for test_noncontiguous_to()")

    def check_layout_preserved(converted):
        self.assertEqual(nt.is_contiguous(), converted.is_contiguous())
        self.assertEqual(nt._values.is_contiguous(), converted._values.is_contiguous())
        self.assertEqual(nt.shape, converted.shape)

    # --- dtype conversion ---
    other_dtype = {
        torch.float32: torch.half,
        torch.float64: torch.float32,
        torch.half: torch.float32,
    }[dtype]
    nt2 = nt.to(dtype=other_dtype)
    self.assertEqual(nt2.dtype, other_dtype)
    check_layout_preserved(nt2)
    # offsets / lengths are unaffected by a dtype change
    self.assertEqual(nt._offsets, nt2._offsets)
    self.assertEqual(nt._lengths, nt2._lengths)

    # --- device conversion ---
    cpu = torch.device("cpu")
    nt3 = nt.to(device=cpu)
    self.assertEqual(nt3.device, cpu)
    check_layout_preserved(nt3)
    # offsets / lengths move along with the values
    self.assertEqual(nt3._offsets.device, cpu)
    if nt._lengths is not None:
        self.assertEqual(nt3._lengths.device, cpu)
@dtypes(torch.float32)
def test_autograd_function_with_None_grad(self, device, dtype):
    """The engine must supply NJT-shaped zero grads for unused Function outputs.

    Only the first output of a two-output autograd.Function is used downstream,
    so the grad for the second output is None. The engine has to allocate
    correctly-shaped (NJT) zeros for it before calling backward().
    """

    class MyFunction(torch.autograd.Function):
        @staticmethod
        def forward(ctx, inp):
            out1 = inp + 1
            out2 = inp * 2
            return out1, out2

        @staticmethod
        def backward(ctx, grad_out1, grad_out2):
            # d(inp + 1)/d(inp) == 1 and d(inp * 2)/d(inp) == 2
            return grad_out1 + 2 * grad_out2

    f = MyFunction.apply
    nt = random_nt_from_dims(
        [5, None, 10],
        device=device,
        dtype=dtype,
        layout=torch.jagged,
        requires_grad=True,
    )
    (out1, _) = f(nt)
    out1.backward(torch.ones_like(out1))

    # grad_out2 arrives as zeros, so only out1's ones contribute to nt.grad
    self.assertEqual(nt.grad, torch.ones_like(nt))
@dtypes(torch.float64, torch.float32, torch.half)
def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
    """Round-trip the jagged <-> padded dense aten kernels, then probe CPU input checks."""
    to_padded = torch.ops.aten._jagged_to_padded_dense_forward
    to_jagged = torch.ops.aten._padded_dense_to_jagged_forward

    values = torch.randn(10, 5, device=device, dtype=dtype)
    # per-item lengths: 1, 2, 5, 2
    offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64)
    max_length = offsets.diff().max().item()
    padding_value = 1.3
    batch_size = offsets.shape[0] - 1
    total_L = values.shape[0]

    # jagged -> padded dense
    padded = to_padded(values, [offsets], [max_length], padding_value)
    self.assertEqual(padded.shape, (batch_size, max_length, values.shape[-1]))

    # padded dense -> jagged must give back the original values
    self.assertEqual(values, to_jagged(padded, [offsets], total_L))

    # success case: a smaller max length truncates as needed
    trunc_max_length = max_length - 1
    trunc_padded = to_padded(values, [offsets], [trunc_max_length], padding_value)
    self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded)

    # Everything below targets the CPU implementations only.
    if device != "cpu":
        return

    def expect_error(regex, kernel, *args):
        with self.assertRaisesRegex(RuntimeError, regex):
            kernel(*args)

    # CPU kernels currently support a single jagged dim only
    expect_error(
        "only a single jagged dim is supported",
        to_padded,
        values,
        [offsets, offsets],
        [max_length, max_length],
        padding_value,
    )
    expect_error(
        "only a single jagged dim is supported",
        to_jagged,
        padded,
        [offsets, offsets],
        total_L,
    )

    # offsets with more than one dim
    offsets2d = offsets.unsqueeze(-1)
    expect_error(
        "expected 1D offsets", to_padded, values, [offsets2d], [max_length], padding_value
    )
    expect_error("expected 1D offsets", to_jagged, padded, [offsets2d], total_L)

    # last offset disagrees with total_L
    offsets_bad_total = offsets.detach().clone()
    offsets_bad_total[-1] = total_L + 1
    expect_error(
        "final offset should match total_L value",
        to_jagged,
        padded,
        [offsets_bad_total],
        total_L,
    )

    # padded input with a single dim
    padded_1d = padded.flatten().detach().clone()
    expect_error("expected padded dim >= 2", to_jagged, padded_1d, [offsets], total_L)

    # an item of length 7, longer than max_length (5)
    offsets_too_long = torch.tensor(
        [0, 1, 8, 9, 10], device=device, dtype=torch.int64
    )
    expect_error(
        "found batch item of length", to_jagged, padded, [offsets_too_long], total_L
    )
@dtypes(torch.float32)
@skipIfTorchDynamo("Test compiles internally")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
def test_compile_preserves_metadata_cache(self, device, dtype):
    """A compiled transpose + SDPA round trip keeps the input NJT's metadata cache."""
    # dims (B, *, 3, 16)
    nt = random_nt_from_dims(
        [4, None, 3, 16],
        device=device,
        dtype=dtype,
        layout=torch.jagged,
        requires_grad=True,
    )
    # snapshot of the cache (expected to hold min / max seqlen) before compiling
    expected_cache = dict(nt._metadata_cache)

    @torch.compile
    def f(nt):
        q = nt.transpose(-3, -2)
        return F.scaled_dot_product_attention(q, q, q).transpose(-3, -2)

    output = f(nt)
    output.backward(torch.ones_like(output))
    self.assertEqual(output._metadata_cache, expected_cache)
@dtypes(torch.float32)
@skipIfTorchDynamo("Test compiles internally")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
def test_compile_with_dynamic_max_seq_len(self, device, dtype):
    """Reading max seqlen in a compiled fn must not recompile when only it changes."""

    def make_nt(seq_lens):
        # shape (B, *, 5)
        return torch.nested.nested_tensor(
            [torch.randn(n, 5) for n in seq_lens],
            layout=torch.jagged,
        )

    nt = make_nt([2, 3, 18])  # max seq len: 18
    nt2 = make_nt([2, 3, 19])  # max seq len: 19

    def f(nt):
        # TODO: Replace with public API when we can use @properties
        return torch.ones_like(nt) * nt._get_max_seqlen()

    for dynamic in (False, True, None):
        self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
@dtypes(torch.float32)
@skipIfTorchDynamo("Test compiles internally")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
def test_compile_with_dynamic_min_seq_len(self, device, dtype):
    """Reading min seqlen in a compiled fn must not recompile when only it changes."""

    def make_nt(seq_lens):
        # shape (B, *, 5)
        return torch.nested.nested_tensor(
            [torch.randn(n, 5) for n in seq_lens],
            layout=torch.jagged,
        )

    nt = make_nt([7, 8, 9])  # min seq len: 7
    nt2 = make_nt([8, 9, 10])  # min seq len: 8

    def f(nt):
        # TODO: Replace with public API when we can use @properties
        return torch.ones_like(nt) * nt._get_min_seqlen()

    for dynamic in (False, True, None):
        self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
@dtypes(torch.float32)
@skipIfTorchDynamo("Test compiles internally")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype):
    """Max seqlen propagated through pointwise ops stays dynamic (no recompiles)."""

    def make_nt(seq_lens):
        # shape (B, *, 5)
        return torch.nested.nested_tensor(
            [torch.randn(n, 5) for n in seq_lens],
            layout=torch.jagged,
        )

    nt = make_nt([2, 3, 18])  # max seq len: 18
    nt2 = make_nt([2, 3, 19])  # max seq len: 19

    def f(nt):
        intermediate = nt.sin() + 1
        # TODO: Replace with public API when we can use @properties
        return torch.ones_like(intermediate) * intermediate._get_max_seqlen()

    eager_result = f(nt)
    compiled_result = torch.compile(f, fullgraph=True, dynamic=False)(nt)
    self.assertEqual(eager_result, compiled_result)

    for dynamic in (False, True, None):
        self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
def test_dropout_inference_mode(self, device):
seq_len = 32
embed_dim = 128
nt = torch.nested.nested_tensor(
[
torch.randn(11, seq_len, embed_dim, device=device),
torch.randn(11, seq_len, embed_dim, device=device),
],
layout=torch.jagged,
device=device,
)
with torch.inference_mode():
torch.nn.functional.dropout(nt, p=0.05)
@dtypes(torch.float32, torch.double, torch.half)
def test_unbind_backward(self, device, dtype):
    """Backprop through one unbind() component puts grad only on that component.

    The components are created with the parametrized ``dtype``. Without it,
    every dtype variant of this test would silently exercise float32 only.
    """
    nt = torch.nested.nested_tensor(
        [
            torch.randn(2, 4, device=device, dtype=dtype),
            torch.randn(5, 4, device=device, dtype=dtype),
            torch.randn(3, 4, device=device, dtype=dtype),
        ],
        layout=torch.jagged,
        requires_grad=True,
    )
    _, b, _ = nt.unbind()
    b.sum().backward()

    @torch._dynamo.disable
    def check(nt):
        # ones for the middle component, zeros elsewhere
        expected_grad = torch.zeros_like(nt)
        expected_grad.unbind()[1].add_(1.0)
        self.assertEqual(nt.grad, expected_grad)

    check(nt)
@dtypes(torch.float32, torch.double, torch.half, torch.bool)
@parametrize("nt_dim", [2, 3, 4])
@parametrize("requires_grad", [False, True])
def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad):
    """NJT -> padded dense -> NJT round trip, including gradient flow."""
    if requires_grad and dtype is torch.bool:
        # bool tensors cannot carry gradients
        return

    # trailing (non-ragged) component dims, keyed by the NJT's total dim
    post_seq_len_shape = {2: (), 3: (10,), 4: (9, 10)}[nt_dim]

    def make_component(n):
        shape = (n, *post_seq_len_shape)
        if dtype is torch.bool:
            return torch.randint(2, shape, device=device, dtype=dtype)
        return torch.randn(*shape, device=device, dtype=dtype)

    nt = torch.nested.nested_tensor(
        [make_component(n) for n in range(2, 9)],
        layout=torch.jagged,
        requires_grad=requires_grad,
    )

    PADDING_VAL = 4.2
    # 7 components with lengths 2..8 -> padded shape (7, 8, ...)
    expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL)
    for row, component in enumerate(nt.unbind()):
        expected_padded[row, : component.shape[0]].copy_(component)

    padded = nt.to_padded_tensor(PADDING_VAL)
    self.assertEqual(expected_padded, padded)

    # padded dense -> NJT
    from torch.nested._internal.nested_tensor import nested_from_padded

    nt2 = nested_from_padded(padded, nt.offsets())
    self.assertEqual(nt, nt2)

    if requires_grad:
        # the bool + requires_grad combination already returned above
        nt2.backward(torch.ones_like(nt2))
        self.assertEqual(nt.grad, torch.ones_like(nt))
# blows up due to test parametrization otherwise
@torch._dynamo.utils.disable_cache_limit()
@skipIfTorchDynamo("SDPA test compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
@dtypes(torch.float32, torch.double, torch.half)
@parametrize("nt_dim", [2, 3, 4])
@parametrize("requires_grad", [False, True])
def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad):
    """Compiled NJT -> padded -> pointwise -> NJT must match eager and fuse on CUDA.

    On CUDA the generated code must contain no fallback calls to the aten
    jagged <-> padded conversion ops and no 3D buffer allocations. On CPU,
    where fusion isn't supported, the fallback calls are expected to appear.
    """
    from torch._inductor.utils import run_and_get_code
    from torch.nested._internal.nested_tensor import nested_from_padded

    # Only floating dtypes are parametrized here, so components are always randn.
    post_seq_len_shape = {2: (), 3: (10,), 4: (9, 10)}[nt_dim]
    nt = torch.nested.nested_tensor(
        [
            torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
            for n in range(2, 9)
        ],
        layout=torch.jagged,
        requires_grad=requires_grad,
    )

    def f(x):
        return x.sin() + 1

    @torch.compile(fullgraph=True)
    def g(nt):
        def _g(nt):
            PADDING_VAL = 4.2
            padded = f(nt.to_padded_tensor(PADDING_VAL))
            # NB: sum_S must be specified to use the lowering for dense -> jagged
            # and get full fusion
            return nested_from_padded(
                padded, nt.offsets(), sum_S=nt.values().shape[0]
            )

        # NB: use checkpointing to force fusion
        return torch.utils.checkpoint.checkpoint(_g, nt, use_reentrant=False)

    expected_output = f(nt)
    if requires_grad:
        expected_output.backward(torch.ones_like(expected_output))
        expected_grad = nt.grad.detach().clone()
        nt.grad = None

    compiled_output, generated_code = run_and_get_code(g, nt)
    if requires_grad:
        compiled_output.backward(torch.ones_like(compiled_output))
        compiled_grad = nt.grad.detach().clone()
        self.assertEqual(compiled_grad, expected_grad, rtol=1e-3, atol=1e-3)
    self.assertEqual(compiled_output, expected_output, rtol=1e-3, atol=1e-3)

    # === Verify that computation fusion happens. ===
    is_cuda = "cuda" in device
    # A fallback op call in the generated code means fusion didn't happen.
    fallback_calls = (
        "torch.ops.aten._padded_dense_to_jagged_forward.default(",
        "torch.ops.aten._jagged_to_padded_dense_forward.default(",
    )
    fallback_op_calls_present = any(
        call in code for code in generated_code for call in fallback_calls
    )
    # NB: Fusion isn't supported on CPU.
    self.assertEqual(is_cuda, not fallback_op_calls_present)

    if is_cuda:
        for code in generated_code:
            # If fusion happens, a 3D buffer with shape (B, max_seqlen, D) should
            # never be materialized. A buffer's dim is the number of elements in
            # the size tuple passed as the first arg to empty_strided_cuda().
            buffer_dims = [
                len(ast.parse(line.strip()).body[0].value.args[0].elts)
                for line in code.split("\n")
                if "empty_strided_cuda(" in line
            ]
            self.assertFalse(any(d == 3 for d in buffer_dims))
@dtypes(torch.float32)
@skipIfTorchDynamo("Test compiles internally")
@unittest.skipIf(
    sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
)
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@skipCUDAIfRocm
def test_compile_padded_dense_conversion_preserves_metadata_cache(
    self, device, dtype
):
    """A compiled NJT -> padded -> NJT round trip keeps the metadata cache.

    min / max seqlen are forwarded explicitly to nested_from_padded, so the
    rebuilt NJT's cache should equal the input's.
    """
    from torch.nested._internal.nested_tensor import nested_from_padded

    # dims (B, *, 3, 16)
    nt = random_nt_from_dims(
        [4, None, 3, 16],
        device=device,
        dtype=dtype,
        layout=torch.jagged,
        requires_grad=True,
    )
    # snapshot of the cache (expected to hold min / max seqlen)
    expected_cache = dict(nt._metadata_cache)

    @torch.compile
    def g(nt):
        intermediate = nt.to_padded_tensor(0.3).sin() + 1
        return nested_from_padded(
            intermediate,
            nt.offsets(),
            min_seqlen=nt._min_seqlen,
            max_seqlen=nt._max_seqlen,
            sum_S=nt.values().shape[0],
        )

    output = g(nt)
    output.backward(torch.ones_like(output))
    self.assertEqual(output._metadata_cache, expected_cache)
# See https://github.com/pytorch/pytorch/issues/128649
@dtypes(torch.float32)
def test_composite_op_in_inference_mode(self, device, dtype):
    """reshape() under inference_mode yields a view or a copy, as appropriate."""

    def make_nt(dims):
        return random_nt_from_dims(
            dims,
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

    # contiguous input: reshape is expected to return a view
    nt = make_nt([4, None, 48])
    with torch.inference_mode():
        output = nt.reshape([4, -1, 3, 16])
    self.assertEqual(output.shape, (4, nt.shape[1], 3, 16))
    self.assertTrue(output._is_view())

    # transposed input: reshape is expected to copy
    nt = make_nt([4, None, 3, 16]).transpose(-1, -2)
    with torch.inference_mode():
        output = nt.reshape([4, -1, 48])
    self.assertEqual(output.shape, (4, nt.shape[1], 48))
    self.assertFalse(output._is_view())
@dtypes(torch.float32)
def test_composite_op_with_custom_mode(self, device, dtype):
    """A composite op (reshape) on an NJT works under a passthrough dispatch mode."""
    from torch.utils._python_dispatch import TorchDispatchMode

    # simple passthrough TorchDispatchMode
    class CustomDispatchMode(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            # kwargs may legitimately be None; `**None` would raise TypeError
            return func(*args, **(kwargs or {}))

    nt = random_nt_from_dims(
        [4, None, 2, 3],
        device=device,
        dtype=dtype,
        layout=torch.jagged,
        requires_grad=True,
    )
    with CustomDispatchMode():
        res = nt.reshape(4, -1, 6)
    self.assertEqual(res.shape, (4, nt.shape[1], 6))
@skipIfTorchDynamo("compiles internally")
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
@dtypes(torch.float32)
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_broadcast_shapes_on_in_graph_constructed_njt(self, device, dtype):
    """broadcast_shapes() on an NJT built inside the graph must not over-guard.

    A guard must not be wrongly installed on the freshly-created nested int.
    See https://github.com/pytorch/pytorch/issues/145874 for more context.
    """
    nt = torch.nested.nested_tensor(
        [torch.randn(n) for n in (2, 3, 4)],
        layout=torch.jagged,
        device=device,
        dtype=dtype,
    )
    values, offsets = (t.detach().clone() for t in (nt._values, nt._offsets))

    @torch.compile(fullgraph=True)
    def f(values, offsets):
        nt = torch.nested.nested_tensor_from_jagged(values, offsets)
        # NB: torch.where() utilizes broadcast_shapes() underneath
        return torch.where(nt > 0.0, torch.ones_like(nt), torch.zeros_like(nt))

    output = f(values, offsets)
    self.assertTrue(output.is_nested)
    self.assertEqual(nt.shape[:-1], output.shape[:-1])
    for expected_part, actual_part in zip(nt.unbind(), output.unbind()):
        self.assertEqual(expected_part.shape, actual_part.shape)
# The following lists specify skips and xfails for particular SampleInputs. Note that
# these are attempted to be matched from top to bottom and only one at most will
# be matched, so order matters! The guiding general principle here should be one
# xfail / skip per bug if at all possible :)
#
# Each rule pairs an op_match_fn(device, op) predicate with an optional
# sample_match_fn(device, sample) predicate. When present, error_type / error_msg
# describe the failure a matched sample is expected to raise; error_msg appears to
# be treated as a regex (note the ".*" in one pattern below) — confirm against
# XFailRule before relying on that.
FORWARD_SKIPS_AND_XFAILS = [
# not implemented
XFailRule(
error_type=NotImplementedError,
op_match_fn=lambda device, op: op.full_name
in {
# unary
# needs log_sigmoid_forward, which returns a tuple
"nn.functional.logsigmoid",
"nn.functional.prelu",
# needs rrelu_with_noise
"nn.functional.rrelu",
# binary
"__rsub__",
"complex",
"floor_divide",
"polar",
"rsub",
# reduction
"count_nonzero",
"linalg.vector_norm",
"nansum",
"std",
"std.unbiased",
"var",
"var.unbiased",
},
name="not_implemented",
),
# expected: torch.where() support has some limitations
# 1. condition must be an NJT
# 2. no dense tensors of higher dim than the NJT
XFailRule(
error_type=ValueError,
error_msg="expected condition to be a jagged layout NestedTensor",
op_match_fn=lambda device, op: op.full_name == "where",
sample_match_fn=lambda device, sample: not sample.kwargs["condition"].is_nested,
),
XFailRule(
error_type=ValueError,
error_msg="broadcasting nested tensors with dense tensors of equal or higher dim",
op_match_fn=lambda device, op: op.full_name == "where",
# either the dense input or the dense "other" has dim >= the condition's dim
sample_match_fn=lambda device, sample: (
(
not sample.input.is_nested
and sample.input.dim() >= sample.kwargs["condition"].dim()
)
or (
not sample.kwargs["other"].is_nested
and sample.kwargs["other"].dim() >= sample.kwargs["condition"].dim()
)
),
),
# expected: masked ops don't support jagged layout
XFailRule(
error_type=ValueError,
error_msg="expects strided",
op_match_fn=lambda device, op: op.full_name
in {
"masked.amax",
"masked.amin",
"masked.argmax",
"masked.argmin",
"masked.logsumexp",
"masked.mean",
"masked.norm",
"masked.prod",
"masked.std",
"masked.sum",
"masked.var",
},
name="no_masked_jagged_support",
),
# Op doesn't support lengths being present
XFailRule(
error_type=ValueError,
error_msg="expected input to be a contiguous jagged layout NestedTensor",
op_match_fn=lambda device, op: (op.full_name == "nn.functional.linear"),
sample_match_fn=lambda device, sample: (sample.input._lengths is not None),
name="no_linear_noncontig_holes_support",
),
# nanmean sometimes hits an unimplemented nansum() path and other times hits an
# unimplemented sum() path
# sample_match_fn below is the exact negation of the condition used by the
# "ragged_dim_reduction_noncontig_holes" rule, so ragged-dim reductions on
# noncontig-with-holes samples fall through to that rule instead.
XFailRule(
error_type=NotImplementedError,
op_match_fn=lambda device, op: (op.full_name == "nanmean"),
sample_match_fn=lambda device, sample: (
not (
"noncontig_holes" in sample.name
and "dim" in sample.kwargs
and (
(
isinstance(sample.kwargs["dim"], int)
and sample.kwargs["dim"] == sample.input._ragged_idx
)
or (
isinstance(sample.kwargs["dim"], (tuple, list))
and sample.input._ragged_idx in sample.kwargs["dim"]
)
)
)
),
name="nansum_unimplemented",
),
# expected: reducing across the ragged dimension is not supported for non-contiguous
# nested tensors with holes
XFailRule(
error_type=RuntimeError,
error_msg=(
"reducing across the ragged dimension is not supported for non-contiguous "
"nested tensors with holes"
),
op_match_fn=lambda device, op: (
# min.reduction_with_dim and max.reduction_with_dim aren't associated with
# ReductionOpInfo entries sadly even though they're reductions
isinstance(op, ReductionOpInfo) or "reduction_with_dim" in op.full_name
),
# matches when "dim" is the ragged dim or is a tuple / list containing it
sample_match_fn=lambda device, sample: (
"noncontig_holes" in sample.name
and "dim" in sample.kwargs
and (
(
isinstance(sample.kwargs["dim"], int)
and sample.kwargs["dim"] == sample.input._ragged_idx
)
or (
isinstance(sample.kwargs["dim"], (tuple, list))
and sample.input._ragged_idx in sample.kwargs["dim"]
)
)
),
name="ragged_dim_reduction_noncontig_holes",
),
# expected: index_put() doesn't work on non-contiguous NJTs without ragged dimension indices
# (indices cover dims 0 .. len(indices) - 1, so the ragged dim is not indexed when
# _ragged_idx lies beyond that range)
XFailRule(
error_type=RuntimeError,
error_msg="If ragged dimension is not part of indices, this only works on contiguous NJTs",
op_match_fn=lambda device, op: (op.full_name == "index_put"),
sample_match_fn=lambda device, sample: (
not sample.input.is_contiguous()
and len(sample.kwargs["indices"]) - 1 < sample.input._ragged_idx
),
name="index_put_noncontig_holes_no_ragged_dim_indices",
),
# select() only supports dim=0 for non-contiguous with holes NJTs for now
# NOTE(review): no error_type / error_msg is given here, so this presumably
# xfails on any error — confirm XFailRule's defaults.
XFailRule(
op_match_fn=lambda device, op: (op.full_name == "select"),
sample_match_fn=lambda device, sample: (
sample.kwargs["dim"] != 0 and "noncontig_holes" in sample.name
),
name="unsupported_select_on_non_batch_dim_with_noncontig_holes",
),
# these don't work on non-contiguous NJTs yet
# (matched when lengths are present or the ragged dim is not dim 1)
XFailRule(
error_type=ValueError,
error_msg="expected self to be a contiguous jagged layout NestedTensor",
op_match_fn=lambda device, op: (
op.full_name
in {
"chunk",
"masked_select",
"narrow",
"split",
"split_with_sizes",
"squeeze",
}
),
sample_match_fn=lambda device, sample: (
sample.input._lengths is not None or sample.input._ragged_idx != 1
),
name="missing_noncontig_support",
),
# these don't work on the ragged dim yet
XFailRule(
error_type=RuntimeError,
error_msg="not supported for NestedTensor on ragged dim",
op_match_fn=lambda device, op: (
op.full_name
in {
"chunk",
"narrow",
"select",
"split",
}
),
sample_match_fn=lambda device, sample: "ragged_dim" in sample.name,
name="ragged_dim_unsupported",
),
XFailRule(
error_type=RuntimeError,
# error comes from usage of view() in the decomp
error_msg="does not support ragged_idx != 1 except when",
op_match_fn=lambda device, op: (op.full_name == "unflatten"),
sample_match_fn=lambda device, sample: "noncontig_transposed" in sample.name,
name="unflatten_ragged_dim_unsupported",
),
# these don't work on the batch dim yet
XFailRule(
error_type=RuntimeError,
error_msg="not supported for NestedTensor on dim=0",
op_match_fn=lambda device, op: (
op.full_name
in {
"narrow",
"split",
"split_with_sizes",
"unsqueeze",
}
),
sample_match_fn=lambda device, sample: "batch_dim" in sample.name,
name="batch_dim_unsupported",
),
XFailRule(
error_type=RuntimeError,
# error comes from usage of view() in the decomp
error_msg="cannot view shape",
op_match_fn=lambda device, op: (op.full_name == "unflatten"),
sample_match_fn=lambda device, sample: "batch_dim" in sample.name,
name="unflatten_batch_dim_unsupported",
),
# expected: bmm / matmul sometimes use a to_padded_tensor() fallback which isn't
# supported for non-contig NJTs with holes
# (matched only when both operands have the same dim)
XFailRule(
error_type=RuntimeError,
error_msg="not supported for nested tensors with holes",
op_match_fn=lambda device, op: (op.full_name in {"bmm", "matmul"}),
sample_match_fn=lambda device, sample: (
"noncontig_holes" in sample.name
# "other" is the name for the matmul arg and "mat2" is the name for the bmm arg
and sample.input.dim()
== sample.kwargs.get("other", sample.kwargs.get("mat2")).dim()
),
name="mm_noncontig_holes",
),
# some jiterator op failures due to unsupported jagged layout
XFailRule(
error_type=RuntimeError,
error_msg="unsupported tensor layout",
op_match_fn=lambda device, op: op.full_name
in {
"jiterator_binary",
"jiterator_binary_return_by_ref",
"jiterator_unary",
},
name="no_jiterator_jagged_support",
),
# Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
# tensor with 1 in ragged dim.
XFailRule(
error_type=RuntimeError,
error_msg="cannot call binary pointwise function .* with inputs of shapes",
op_match_fn=lambda device, op: (isinstance(op, BinaryUfuncInfo)),
sample_match_fn=lambda device, sample: (
"noncontig_holes" in sample.name
and "broadcasting 1 over ragged" in sample.name
),
name="binary_noncontig_holes_broadcasting_1_over_ragged",
),
]
# Skip / xfail rules for test_backward. test_backward also runs the forward op, so
# every forward rule is spliced in as well.
# NOTE(review): the split_with_sizes SkipRule is placed ahead of the spliced-in
# forward rules, which suggests rule order matters (first match wins?) -- confirm
# in sample_skips_and_xfails before reordering entries.
BACKWARD_SKIPS_AND_XFAILS = [
    # split_with_sizes backward segfaults, so skip it. It's trying to use the NST
    # logic for NJT
    SkipRule(
        op_match_fn=lambda device, op: op.full_name == "split_with_sizes",
        name="split_with_sizes_backward_segfault",
    ),
    *FORWARD_SKIPS_AND_XFAILS,
    # Backwards is generally broken for non-contiguous NJTs with holes. Rather than
    # determine the exceptions in detail, just skip for now. Fix is to ensure
    # that summing over gradients during backwards after broadcasting takes into
    # account holes / lengths.
    SkipRule(
        op_match_fn=lambda device, op: (
            isinstance(op, BinaryUfuncInfo)
            or op.full_name in {"mean", "where", "unsqueeze"}
        ),
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="broken_noncontig_holes_backward",
    ),
    # mean(): need to examine backwards formula
    XFailRule(
        error_type=RuntimeError,
        error_msg="SymIntArrayRef expected to contain only concrete integers",
        op_match_fn=lambda device, op: (op.full_name in {"mean"}),
        sample_match_fn=lambda device, sample: (
            "full reduction" not in sample.name
            and "normal dim reduction" not in sample.name
        ),
        name="broken_mean_backward",
    ),
    # RuntimeError: expand(): cannot expand shape (3, 3, 1, j44) -> [3, 3, 7, j44]
    # with noncontig transposed inputs to mean()
    XFailRule(
        error_type=RuntimeError,
        error_msg="cannot expand shape",
        op_match_fn=lambda device, op: (op.full_name == "mean"),
        sample_match_fn=lambda device, sample: (
            "normal dim reduction" in sample.name
            and "noncontig_transposed" in sample.name
        ),
        name="broken_mean_backward2",
    ),
    # unsqueeze() backward tries to call squeeze with noncontig transposed,
    # but that's not supported
    XFailRule(
        error_type=ValueError,
        error_msg="expected self to be a contiguous jagged layout NestedTensor",
        op_match_fn=lambda device, op: (op.full_name == "unsqueeze"),
        sample_match_fn=lambda device, sample: (
            "noncontig_transposed" in sample.name or "ragged_dim" in sample.name
        ),
        name="broken_unsqueeze_backward",
    ),
    # RuntimeError: view(): cannot view shape (3, j62, 1, 7, 3) as [3, j58, 7, 3]
    # with unflatten()
    XFailRule(
        error_type=RuntimeError,
        error_msg="cannot view shape",
        op_match_fn=lambda device, op: (op.full_name in {"unflatten"}),
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="broken_unflatten_backward",
    ),
    # sum() backward is not implemented for non-full reductions
    XFailRule(
        error_type=NotImplementedError,
        error_msg="aten._nested_sum_backward.default",
        op_match_fn=lambda device, op: (op.full_name == "sum"),
        sample_match_fn=lambda device, sample: ("full reduction" not in sample.name),
        name="broken_sum_backward",
    ),
    # squeeze(): invalid gradient shape; need to check formula
    XFailRule(
        error_type=RuntimeError,
        error_msg="returned an invalid gradient at index 0",
        op_match_fn=lambda device, op: (op.full_name == "squeeze"),
        sample_match_fn=lambda device, sample: (
            sample.name == "5D_contig_with_seqlen_cache: normal_dim"
            and sample.kwargs["dim"] == 3
        ),
        name="broken_squeeze_backward",
    ),
    # sgn() / masked_select(): backwards formulas don't work at all
    XFailRule(
        error_type=RuntimeError,
        error_msg="NestedTensor does not support directly calling torch.ops.aten.size",
        op_match_fn=lambda device, op: (op.full_name in {"sgn", "masked_select"}),
        name="broken_sgn_masked_select_backward",
    ),
    # select(): grad_output is an NJT for non-batch-dim operation
    XFailRule(
        error_type=ValueError,
        error_msg="expected grad_output to be a tensor",
        op_match_fn=lambda device, op: (op.full_name == "select"),
        sample_match_fn=lambda device, sample: ("batch_dim" not in sample.name),
        name="broken_select_backward",
    ),
    # prod(): completely broken in every way (no error type / message pinned)
    XFailRule(
        op_match_fn=lambda device, op: (op.full_name == "prod"),
        name="broken_prod_backward",
    ),
    # pow() / float_power(): use where() underneath; broken for (NT, T) broadcasting cases
    XFailRule(
        error_type=ValueError,
        error_msg="expected condition to be a jagged layout NestedTensor",
        op_match_fn=lambda device, op: (op.full_name in {"pow", "float_power"}),
        sample_match_fn=lambda device, sample: ("(NT, T)" in sample.name),
        name="broken_pow_backward",
    ),
    # __rpow__() backward is also broken, but for the reverse (T, NT) broadcasting cases
    XFailRule(
        error_type=ValueError,
        error_msg="expected condition to be a jagged layout NestedTensor",
        op_match_fn=lambda device, op: (op.full_name == "__rpow__"),
        sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
        name="broken_rpow_backward",
    ),
    # linear(): some formula problem when bias is used; seems to be platform-specific
    # (fails locally but not in CI)
    SkipRule(
        # result2.use_count() <= 1 INTERNAL ASSERT FAILED
        op_match_fn=lambda device, op: (op.full_name == "nn.functional.linear"),
        sample_match_fn=lambda device, sample: ("with bias" in sample.name),
        name="broken_linear_backward",
    ),
    # narrow(): unimplemented backward
    XFailRule(
        error_type=RuntimeError,
        error_msg="derivative for aten::narrow is not implemented",
        op_match_fn=lambda device, op: (op.full_name == "narrow"),
        name="broken_narrow_backward",
    ),
    # min / max: need factory function support for ragged dim reductions
    # where the output is dense but sizes still contain a nested int
    XFailRule(
        error_type=RuntimeError,
        error_msg="SymIntArrayRef expected to contain only concrete integers",
        op_match_fn=lambda device, op: (
            op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
        ),
        sample_match_fn=lambda device, sample: ("ragged dim" in sample.name),
        name="broken_min_max_reduction_with_dim_backward_on_ragged_dim",
    ),
    # copysign(): formula is broken for (T, NT) broadcasting
    XFailRule(
        error_type=RuntimeError,
        error_msg="SymIntArrayRef expected to contain only concrete integers",
        op_match_fn=lambda device, op: (op.full_name == "copysign"),
        sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
        name="broken_copysign_backward",
    ),
    # amin() / amax(): broken in a host of ways I don't think it's a good use of time
    # to try to sift through
    SkipRule(
        op_match_fn=lambda device, op: (op.full_name in {"amin", "amax"}),
        name="broken_amin_amax_backward",
    ),
    # binary ops on noncontig NJTs with holes under "(NT, T) broadcasting all 1s":
    # backward ends up reducing across the ragged dim, which is unsupported for
    # non-contiguous inputs
    XFailRule(
        error_type=RuntimeError,
        error_msg="reducing across the ragged dimension is not supported for non-contiguous",
        op_match_fn=lambda device, op: (
            isinstance(op, BinaryUfuncInfo)
            # doesn't happen for these ops for some reason
            and op.full_name
            not in {"copysign", "max.binary", "maximum", "min.binary", "minimum"}
        ),
        sample_match_fn=lambda device, sample: (
            "(NT, T) broadcasting all 1s" in sample.name
            and "noncontig_holes" in sample.name
        ),
        name="binary_noncontig_holes_ragged_dim_reduction",
    ),
    # rms_norm(): same ragged-dim reduction failure for inputs that carry _lengths
    # (presumably the non-contiguous-with-holes samples -- confirm)
    XFailRule(
        error_type=RuntimeError,
        error_msg="reducing across the ragged dimension is not supported for non-contiguous",
        op_match_fn=lambda device, op: (op.full_name == "nn.functional.rms_norm"),
        sample_match_fn=lambda device, sample: (sample.input._lengths is not None),
        name="rms_norm_noncontig_holes_ragged_dim_reduction",
    ),
    # expected: autodiff on complex dtype is not supported
    XFailRule(
        error_type=RuntimeError,
        error_msg=(
            "_nested_view_from_jagged does not support automatic differentiation "
            "for outputs with complex dtype"
        ),
        op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}),
        name="no_complex_autodiff",
    ),
    # Bug: need to use the correct nested int in the return shape
    XFailRule(
        error_type=RuntimeError,
        error_msg="Function CloneBackward0 returned an invalid gradient",
        op_match_fn=lambda device, op: (op.full_name == "clone"),
        sample_match_fn=lambda device, sample: (
            sample.kwargs.get("memory_format", None) == torch.contiguous_format
        ),
        name="clone_wrong_nested_int_for_gradient",
    ),
    # some min / max ops use masked_fill_ underneath sometimes, which isn't implemented
    XFailRule(
        error_type=NotImplementedError,
        error_msg="aten.masked_fill_.Scalar",
        op_match_fn=lambda device, op: (
            op.full_name
            in {"max.binary", "min.binary", "minimum", "maximum", "copysign"}
        ),
        name="unimplemented_masked_fill",
    ),
]
# Skip / xfail rules for test_compile_forward: every eager forward rule plus
# failures that only show up under torch.compile.
COMPILE_FORWARD_SKIPS_AND_XFAILS = [
    *FORWARD_SKIPS_AND_XFAILS,
    # Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
    # e.g. Expected 5 == 4
    XFailRule(
        error_type=AssertionError,
        op_match_fn=lambda device, op: (op.full_name == "fill"),
        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
        name="fill_aot_autograd_bug_with_transposed_input",
    ),
    # Bug: cross-device conversions with to() result in new nested ints within compile only
    XFailRule(
        error_type=AssertionError,
        error_msg="The values for attribute 'shape' do not match",
        op_match_fn=lambda device, op: (op.full_name == "to"),
        sample_match_fn=lambda device, sample: ("-> cpu" in sample.name),
        name="cross_device_transfer_wrong_nested_int_in_compile",
    ),
    # clone() -> preserve format on a non-contiguous NJT with holes currently uses
    # unbind(), leading to data-dependent expression. Should be fixed via torch._check()
    XFailRule(
        error_type=torch._dynamo.exc.Unsupported,
        # Ne(u1, u0) (unhinted: Ne(u1, u0)). (Size-like symbols: u1, u0)
        error_msg="Could not guard on data-dependent expression",
        op_match_fn=lambda device, op: (op.full_name == "clone"),
        sample_match_fn=lambda device, sample: (
            "noncontig_holes" in sample.name
            and sample.kwargs.get("memory_format", None) == torch.contiguous_format
        ),
        name="clone_unbind_data_dependency",
    ),
    # chunk(): broken in several ways on the batch dim; revisit after similar
    # data-dependency issues are handled for narrow()
    SkipRule(
        op_match_fn=lambda device, op: (op.full_name == "chunk"),
        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
        name="broken_chunk_compile_backward_on_batch_dim",
    ),
    # select on batch dim currently uses unbind(), leading to data-dependent error in
    # torch.compile that needs to be addressed via torch._check()
    XFailRule(
        error_type=torch._dynamo.exc.InternalTorchDynamoError,
        error_msg="Pending unbacked symbols",
        op_match_fn=lambda device, op: (op.full_name == "select"),
        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
        name="broken_select_backward_unbacked",
    ),
    # Bug: no idea what's going on here; needs investigation within AOTAutograd
    XFailRule(
        op_match_fn=lambda device, op: (op.full_name == "nan_to_num"),
        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
        name="crazy_aot_autograd_bug1",
    ),
    # Bug: also no idea what's going on here: needs investigation within AOTAutograd
    XFailRule(
        op_match_fn=lambda device, op: (op.full_name == "isreal"),
        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
        name="crazy_aot_autograd_bug2",
    ),
]
# Skip / xfail rules for test_compile_backward: compile-backward-specific rules
# first, followed by all compile-forward rules and all eager-backward rules
# (the test exercises both the compiled forward and the backward pass).
COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
    # non-contiguous with holes inputs + torch.compile doesn't work great today; need
    # torch._check() statements. Skip these and handle them later.
    SkipRule(
        op_match_fn=lambda device, op: True,
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="noncontig_holes_data_dependency",
    ),
    # mean(): weird bug
    XFailRule(
        error_type=torch._dynamo.exc.BackendCompilerFailed,
        error_msg="'NestedIntNode' object has no attribute 'sub'",
        op_match_fn=lambda device, op: (op.full_name == "mean"),
        sample_match_fn=lambda device, sample: (
            "full reduction" not in sample.name
            and "normal dim reduction" not in sample.name
        ),
        name="broken_mean_compile_backward",
    ),
    # min() / max(): weird bug
    XFailRule(
        error_type=AttributeError,
        error_msg="'ConstantIntNode' object has no attribute 'add'",
        op_match_fn=lambda device, op: (
            op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
        ),
        sample_match_fn=lambda device, sample: ("ragged dim" in sample.name),
        name="broken_min_max_compile_backward",
    ),
    # to() fails with data-dependent guards OR Unknown layout in record_stream_any_impl;
    # need to fix with torch._check(), etc.
    XFailRule(
        op_match_fn=lambda device, op: (op.full_name == "to"),
        sample_match_fn=lambda device, sample: ("-> cpu" in sample.name),
        name="to_data_dependency",
    ),
    # copysign(): formula is broken for (T, NT) broadcasting
    XFailRule(
        error_type=AttributeError,
        error_msg="'ConstantIntNode' object has no attribute 'add'",
        op_match_fn=lambda device, op: (op.full_name == "copysign"),
        sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
        name="broken_copysign_compile_backward",
    ),
    # in compile, these complex ops use view_as_real(), which isn't implemented
    XFailRule(
        error_type=NotImplementedError,
        error_msg="aten.view_as_real.default",
        op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}),
        name="unimplemented_view_as_real",
    ),
    *COMPILE_FORWARD_SKIPS_AND_XFAILS,
    *BACKWARD_SKIPS_AND_XFAILS,
]
# Ops whose compiled output is checked against the eager output while ignoring
# nested ints (assertEqualIgnoringNestedInts) rather than with a strict
# assertEqual: masked_select is expected to output a different shape.
COMPARE_TENSOR_COMPONENT_EQUALITY = {"masked_select"}
# OpInfo-based NJT tests. These tests use an NJT-specific op_db generated from the
# standard op_db. Note that certain tradeoffs were made between coverage and test
# runtime:
# * All tests run with dtype=torch.float32 only
class TestNestedTensorOpInfo(NestedTensorTestCase):
    """OpInfo-driven eager / compiled, forward / backward tests for jagged NJTs.

    Each OpInfo test iterates over the samples of an op from ``njt_op_db`` and uses
    the matching ``*_SKIPS_AND_XFAILS`` rule list (via ``sample_skips_and_xfails``)
    to skip or xfail known-broken (op, sample) combinations.
    """

    # TODO: move this
    def _gen_grad_outputs(self, out_val):
        """Build the ``(outputs, grad_outputs)`` pair for ``torch.autograd.grad``.

        For a list / tuple ``out_val``, only the elements that have a ``grad_fn``
        are kept, each paired with a ``torch.ones_like`` grad_output. A single
        output is returned as-is together with a 1-tuple of ``ones_like``.
        """
        if isinstance(out_val, (list, tuple)):
            need_grad_outs = tuple(o for o in out_val if o.grad_fn is not None)
            # derive grad_outputs from the already-filtered outputs so the two
            # tuples are aligned by construction
            grad_outputs = tuple(torch.ones_like(o) for o in need_grad_outs)
            return need_grad_outs, grad_outputs
        else:
            return out_val, (torch.ones_like(out_val),)

    def _gen_grad_inputs(self, sample):
        """Return every tensor in the sample's input / args / kwargs that requires grad."""
        inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
        return [
            inp
            for inp in inps
            if isinstance(inp, torch.Tensor) and inp.requires_grad
        ]

    @ops(
        [op for op in njt_op_db if op.supports_njt],
        allowed_dtypes=(torch.float32,),
    )
    @tf32_on_and_off(0.005)
    @sample_skips_and_xfails(FORWARD_SKIPS_AND_XFAILS)
    def test_forward(self, device, dtype, op):
        """Eager forward output of each op must match its reference implementation."""
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device,
            dtype=dtype,
            requires_grad=False,
            use_subtests=True,
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                # compare to reference, but expect different nested int
                out = op.op(sample.input, *sample.args, **sample.kwargs)
                out_ref = op.ref(op, sample)
                self.assertEqualIgnoringNestedInts(out, out_ref)
                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out
                    )

                # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
                # TODO: Add xfails for other inplace ops instead of hardcoding
                if op.inplace_variant and "index_put" in op.full_name:
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                    self.assertEqualIgnoringNestedInts(sample.input, out_ref)

    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    @tf32_on_and_off(0.005)
    @sample_skips_and_xfails(BACKWARD_SKIPS_AND_XFAILS)
    def test_backward(self, device, dtype, op):
        """Eager forward output and input gradients must match the reference."""
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device, dtype=dtype, requires_grad=True, use_subtests=True
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                # compare to reference, but expect different nested int
                out = op.op(sample.input, *sample.args, **sample.kwargs)
                out_ref = op.ref(op, sample)
                self.assertEqualIgnoringNestedInts(out, out_ref)
                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out
                    )

                g_inps = self._gen_grad_inputs(sample)
                if len(g_inps) > 0:
                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out)
                    grads = torch.autograd.grad(
                        need_grad_outs, inputs=g_inps, grad_outputs=grad_outputs
                    )

                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out_ref)
                    grads_ref = torch.autograd.grad(
                        need_grad_outs, inputs=g_inps, grad_outputs=grad_outputs
                    )

                    self.assertEqualNoncontigAware(grads, grads_ref)

    @ops(
        [op for op in njt_op_db if op.supports_njt],
        allowed_dtypes=(torch.float32,),
    )
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    # needed to avoid "data dependent operator: aten._local_scalar_dense.default"
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @sample_skips_and_xfails(COMPILE_FORWARD_SKIPS_AND_XFAILS)
    def test_compile_forward(self, device, dtype, op):
        """Compiled (fullgraph) forward output must match the eager output."""
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device, dtype=dtype, requires_grad=False, use_subtests=True
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                torch.compiler.reset()

                op_fn = op.op

                def f(*args, **kwargs):
                    return op_fn(*args, **kwargs)

                compiled_f = torch.compile(
                    f, fullgraph=True, backend="aot_eager_decomp_partition"
                )

                out_ref = f(sample.input, *sample.args, **sample.kwargs)
                out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)

                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref
                    )

                if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                    self.assertEqualIgnoringNestedInts(out_compile, out_ref)
                else:
                    self.assertEqual(out_compile, out_ref)

                # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
                # TODO: Add xfails for other inplace ops instead of hardcoding
                if op.inplace_variant and "index_put" in op.full_name:
                    # Use a separate name: ``f`` above closes over ``op_fn``, so
                    # rebinding ``op_fn`` would silently change what ``f`` (and
                    # ``compiled_f``) call.
                    inplace_fn = op.inplace_variant

                    def in_f(*args, **kwargs):
                        return inplace_fn(*args, **kwargs)

                    compiled_in_f = torch.compile(
                        in_f, fullgraph=True, backend="aot_eager_decomp_partition"
                    )

                    # mutates sample.input in place; it should then equal out_ref
                    compiled_in_f(sample.input, *sample.args, **sample.kwargs)
                    if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                        self.assertEqualIgnoringNestedInts(sample.input, out_ref)
                    else:
                        self.assertEqual(sample.input, out_ref)

    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    # needed to avoid "data dependent operator: aten._local_scalar_dense.default"
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @sample_skips_and_xfails(COMPILE_BACKWARD_SKIPS_AND_XFAILS)
    def test_compile_backward(self, device, dtype, op):
        """Compiled forward output and input gradients must match eager."""
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device, dtype=dtype, requires_grad=True, use_subtests=True
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                torch.compiler.reset()

                op_fn = op.op

                def f(*args, **kwargs):
                    return op_fn(*args, **kwargs)

                compiled_f = torch.compile(
                    f, fullgraph=True, backend="aot_eager_decomp_partition"
                )

                out_ref = f(sample.input, *sample.args, **sample.kwargs)
                out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)

                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref
                    )

                if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                    self.assertEqualIgnoringNestedInts(out_compile, out_ref)
                else:
                    self.assertEqual(out_compile, out_ref)

                g_inps = self._gen_grad_inputs(sample)
                if len(g_inps) > 0:
                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out_compile)
                    grads_compile = torch.autograd.grad(
                        need_grad_outs,
                        inputs=g_inps,
                        grad_outputs=grad_outputs,
                    )

                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out_ref)
                    grads_ref = torch.autograd.grad(
                        need_grad_outs,
                        inputs=g_inps,
                        grad_outputs=grad_outputs,
                    )

                    self.assertEqualNoncontigAware(grads_compile, grads_ref)

    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    # needed to avoid "data dependent operator: aten._local_scalar_dense.default"
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @skipIfTorchDynamo(
        "Dynamo fails on pending unbacked symints at assertEqual(ref_y[0][0][0].item(), 2)"
    )
    def test_nested_tensor_non_contiguous_mutation(self):
        """In-place element write into an NJT built from transposed (non-contiguous) values."""

        def fn(x, x0):
            x[0, 0, 0] = 2
            return x

        def _inp():
            base = torch.zeros(32, 3)
            v = base.t()
            return torch.nested.nested_tensor_from_jagged(
                v,
                offsets=torch.tensor([0, 2, 3]),
            ), torch.ones(2, 32)

        ref_x, ref_x0 = _inp()
        ref_y = fn(ref_x, ref_x0)

        self.assertEqual(ref_y[0][0][0].item(), 2)

        y = torch.compile(fn, fullgraph=True, backend="aot_eager")(*_inp())
        self.assertEqual(y[0][0][0], 2)

    def test_nested_tensor_input_mutation_backward(self):
        # See Note [AOTAutograd Tangent Subclassness for mutated inputs]
        # NJT tangent is always subclass, See torch/csrc/autograd/python_function.cpp, use_zeros_like.
        # This test checks that AOTD correctly guess NJT tangent as NJT.
        def fn(x):
            x.mul_(2)
            return x + 1

        def _inp():
            v = torch.zeros(32, 3, requires_grad=True)
            return torch.nested.nested_tensor_from_jagged(
                v,
                offsets=torch.tensor([0, 2, 3]),
            ).clone()

        ref_x = _inp()
        ref_y = fn(ref_x)
        ref_y.sum().backward()

        x = _inp()
        y = torch.compile(fn, fullgraph=True, backend="aot_eager")(x)
        y.sum().backward()
from torch.nested._internal.nested_int import NestedIntNode
class TestNestedInt(torch.testing._internal.common_utils.TestCase):
def test_comparisons(self):
a = torch.SymInt(NestedIntNode(1, 1))
b = torch.SymInt(NestedIntNode(1, 1))
c = torch.SymInt(NestedIntNode(2, 1))
d = 3
self.assertTrue(a == a)
self.assertTrue(a == b)
self.assertFalse(a != a)
self.assertFalse(a != b)
self.assertFalse(a == c)
self.assertTrue(a != c)
self.assertFalse(a == d)
self.assertTrue(a != d)
self.assertFalse(d == a)
self.assertTrue(d != a)
# ge
self.assertTrue(a >= a)
self.assertTrue(a >= b)
self.assertTrue(b >= a)
with self.assertRaises(ValueError):
_ = a >= c
with self.assertRaises(ValueError):
_ = c >= a
with self.assertRaises(ValueError):
_ = c >= 3
self.assertTrue(c >= 2)
self.assertTrue(c >= 1)
self.assertFalse(c <= 1)
# lt
self.assertFalse(a < a)
self.assertFalse(a < b)
self.assertFalse(b < a)
with self.assertRaises(ValueError):
_ = a < c
with self.assertRaises(ValueError):
_ = c < a
with self.assertRaises(ValueError):
_ = 3 < a
with self.assertRaises(ValueError):
_ = 2 < a
self.assertTrue(a > 1)
# le
self.assertTrue(a <= a)
self.assertTrue(b <= a)
self.assertTrue(a <= b)
with self.assertRaises(ValueError):
_ = a <= c
with self.assertRaises(ValueError):
_ = c <= a
with self.assertRaises(ValueError):
_ = 3 <= c
self.assertTrue(c >= 2)
self.assertTrue(c >= 1)
self.assertFalse(c <= 1)
# gt
self.assertFalse(a > a)
self.assertFalse(b > a)
self.assertFalse(a > b)
with self.assertRaises(ValueError):
_ = a > c
with self.assertRaises(ValueError):
_ = c > a
with self.assertRaises(ValueError):
_ = a > 3
with self.assertRaises(ValueError):
_ = a > 2
self.assertTrue(a > 1)
def test_with_factor(self):
a = torch.SymInt(NestedIntNode(1, 5))
b = torch.SymInt(NestedIntNode(1, 10))
# eq
self.assertFalse(a == b)
self.assertFalse(a >= b)
self.assertTrue(b >= a)
self.assertTrue(a <= b)
self.assertFalse(b <= a)
# ne
self.assertTrue(a != b)
# mul
self.assertTrue(a * 2 == b)
self.assertTrue(a * 3 >= b)
self.assertTrue(a * 2 == 2 * a)
# Instantiate the generic test classes defined in this file: parametrized variants
# for TestNestedTensor, and per-device-type variants for the rest (registered into
# this module's globals() so test discovery finds them).
instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
instantiate_device_type_tests(TestNestedTensorSubclass, globals())
instantiate_device_type_tests(TestNestedTensorOpInfo, globals())

if __name__ == "__main__":
    run_tests()