# Summary

### Update API

```Py
class AuxRequest(NamedTuple):
    """Request which auxiliary outputs to compute from flex_attention.

    Each field is a boolean indicating whether that auxiliary output should be computed.
    """

    lse: bool = False
    max_scores: bool = False


class AuxOutput(NamedTuple):
    """Auxiliary outputs from the flex_attention operation.

    Fields will be None if not requested, or contain the tensor if requested.
    """

    lse: Optional[Tensor] = None
    max_scores: Optional[Tensor] = None


out_only = flex_attention(query, key, value, score_mod)

out_max, aux_max = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(max_scores=True),
)

out_both, aux_both = flex_attention(
    query,
    key,
    value,
    score_mod,
    return_aux=AuxRequest(lse=True, max_scores=True),
)
```

This returns the max post-`score_mod` scores from flex_attention.

Not being able to break BC is kind of annoying here, since we end up with a combinatorial problem: every additional return value would otherwise need a new kwarg gating whether it gets returned, and we would have to support the 2**N possible return groups. Ideally there isn't much more we need to return, but we might want to think about how best to set this up for expansion in the future. For now I added a keyword-only argument.

Maybe we make an `ExtraReturns`-type kwarg that can grow, so we don't need to keep adding new top-level args. We could also return a struct that holds all the extra tensors and start a deprecation cycle for logsumexp, eventually returning just one `ExtraReturns`-like struct with the tensors.

### Req Grad

I currently don't return a max_scores that supports backpropagating gradients. This might be feasible, but since max is essentially one-hot on the inputs followed by a reduction, we would either need to save an additional `max_location` from the forward, or recompute the max score and apply the gradient only to the first occurrence when there are multiple equal scores (we need to check whether that is how the vanilla max op in torch is defined). For now there is no grad support; we can revisit if needed.

## Perf

I am going to disable this for flex_decode, since at least initially the motivation is training. It is also harder than it should be to have ops return None or optional tensors. If max_scores is not requested, we should probably just create a zero-sized tensor so that we don't slow down the hot path.
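As a rough sketch of the zero-sized placeholder idea (illustrative only; the helper names and shapes below are assumptions, not the actual kernel plumbing):

```Py
import torch
from typing import NamedTuple, Optional


class AuxRequest(NamedTuple):
    lse: bool = False
    max_scores: bool = False


def _alloc_aux_buffers(query: torch.Tensor, return_aux: AuxRequest):
    # Hypothetical helper: always hand the kernel real tensors, but make them
    # zero-sized when the corresponding aux output was not requested, so the
    # hot path pays nothing for unused auxiliary writes.
    B, Hq, M, _ = query.shape
    lse = query.new_empty((B, Hq, M) if return_aux.lse else (0,))
    max_scores = query.new_empty((B, Hq, M) if return_aux.max_scores else (0,))
    return lse, max_scores


def _to_optional(buf: torch.Tensor, requested: bool) -> Optional[torch.Tensor]:
    # At the Python API layer, map an unrequested placeholder back to None.
    return buf if requested else None
```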
```Shell
🔝 Top 5 TFlops Deltas (by absolute %): shape: (5, 7)
attn_type      | dtype          | shape(B,Hq,M,Hkv,N,D)         | TFlops (base) | TFlops (max) | delta     | pct_delta
causal         | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64)   | 249.514658    | 243.078974   | 6.435684  | 2.647569
alibi          | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   | 57.971274     | 56.633641    | 1.337633  | 2.361905
noop           | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64)   | 244.052884    | 248.65129    | -4.598406 | -1.849339
noop           | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  | 280.71254     | 275.686991   | 5.025549  | 1.822918
sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 152.970031    | 150.489109   | 2.480923  | 1.648573

🔺 Top 5 Positive TFlops Deltas (highest +%): shape: (5, 7)
attn_type      | dtype          | shape(B,Hq,M,Hkv,N,D)         | TFlops (base) | TFlops (max) | delta     | pct_delta
causal         | torch.bfloat16 | (4, 16, 2048, 16, 2048, 64)   | 249.514658    | 243.078974   | 6.435684  | 2.647569
alibi          | torch.bfloat16 | (2, 16, 1024, 16, 1024, 64)   | 57.971274     | 56.633641    | 1.337633  | 2.361905
noop           | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  | 280.71254     | 275.686991   | 5.025549  | 1.822918
sliding_window | torch.bfloat16 | (2, 16, 16384, 16, 16384, 64) | 152.970031    | 150.489109   | 2.480923  | 1.648573
causal         | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64)   | 161.031318    | 158.597808   | 2.43351   | 1.534391

🔻 Top 5 Negative TFlops Deltas (lowest -%): shape: (5, 7)
attn_type      | dtype          | shape(B,Hq,M,Hkv,N,D)         | TFlops (base) | TFlops (max) | delta     | pct_delta
noop           | torch.bfloat16 | (4, 16, 1024, 16, 1024, 64)   | 244.052884    | 248.65129    | -4.598406 | -1.849339
alibi          | torch.bfloat16 | (2, 16, 1024, 4, 1024, 128)   | 175.546923    | 177.81205    | -2.265127 | -1.273888
sliding_window | torch.bfloat16 | (4, 16, 16384, 4, 16384, 64)  | 156.282597    | 158.209134   | -1.926537 | -1.217715
sliding_window | torch.bfloat16 | (2, 16, 2048, 16, 2048, 128)  | 232.542929    | 235.140136   | -2.597207 | -1.104536
alibi          | torch.bfloat16 | (2, 16, 1024, 16, 1024, 128)  | 169.652791    | 171.475986   | -1.823195 | -1.063236
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161667

Approved by: https://github.com/Chillee, https://github.com/BoyuanFeng
# Owner(s): ["module: nestedtensor"]
# ruff: noqa: F841
import ast
import io
import itertools
import math
import os
import random
import sys
import tempfile
import unittest
from functools import partial
from typing import Optional

import numpy as np

import torch
import torch._dynamo
import torch._dynamo.testing
import torch.nn
import torch.nn.functional as F
from torch.nested._internal.nested_tensor import (
    buffer_from_jagged,
    jagged_from_list,
    nested_view_from_values_offsets,
    NestedTensor,
    ViewNestedFromBuffer,
)
from torch.nn.attention.flex_attention import create_nested_block_mask, flex_attention
from torch.testing._internal.common_cuda import (
    PLATFORM_SUPPORTS_FUSED_ATTENTION,
    SM70OrLater,
    SM80OrLater,
    tf32_on_and_off,
)
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    flex_attention_supported_platform,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    ops,
    PYTORCH_CUDA_MEMCHECK,
    skipCPUIf,
    skipCUDAIf,
    skipCUDAIfRocm,
    skipMeta,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
    decorateIf,
    freeze_rng_state,
    gradcheck,
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_WINDOWS,
    markDynamoStrictTest,
    NestedTensorTestCase,
    parametrize,
    run_tests,
    serialTest,
    skipIfRocm,
    skipIfSlowGradcheckEnv,
    skipIfTorchDynamo,
    subtest,
    TEST_WITH_ROCM,
    xfailIfTorchDynamo,
)
from torch.testing._internal.opinfo.core import (
    BinaryUfuncInfo,
    ReductionOpInfo,
    sample_skips_and_xfails,
    SkipRule,
    XFailRule,
)
from torch.testing._internal.opinfo.definitions.nested import _sample_njts, njt_op_db
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts


# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.


def _iter_constructors():
    # yield as_nested_tensor
    yield torch.nested.nested_tensor


# Returns True if the function recompiles between inputs1 and inputs2 with the
# specified dynamic setting.
def _recompiles_for_inputs(fn, inputs1, inputs2, dynamic=True):
    compile_count = [0]

    def counter(gm, example_inputs):
        compile_count[0] += 1
        return gm

    compiled_f = torch.compile(fn, fullgraph=True, backend=counter, dynamic=dynamic)
    compiled_f(*inputs1)
    compiled_f(*inputs2)
    return compile_count[0] > 1


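def _example_recompile_check():
    # Illustrative usage sketch of _recompiles_for_inputs (hypothetical, not an
    # actual test case): with dynamic=True, a batch-size-only change between
    # otherwise identical inputs is generally not expected to trigger a recompile.
    def _double(x):
        return x * 2

    return _recompiles_for_inputs(
        _double, (torch.randn(4, 8),), (torch.randn(6, 8),), dynamic=True
    )

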
# Helper function to generate a pair of random nested tensors
# one is contiguous, the other is not, but they appear to have same entries
# an output nested tensor consists of
# * `len(ragged_sizes)` matrices
# * matrices[i].shape == (20, ragged_sizes[i])


def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
    xs = []
    for size in ragged_sizes:
        xs.append(torch.randn((size, 20), device=device, dtype=dtype))
    # contiguous nested tensor
    ys = []
    for x in xs:
        ys.append(x.transpose(-1, -2))
    nt_contiguous = torch.nested.nested_tensor(ys)
    # noncontiguous nested tensor
    n = len(ragged_sizes)
    nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2)
    return nt_contiguous, nt_noncontiguous


# Helper functions to pad a noncontiguous nested tensor
# can be replaced once to_padded_tensor supports noncontiguous memory


def noncontiguous_to_padded_tensor(input, shape=None):
    tensors = input.unbind()
    ntensors = len(tensors)
    assert ntensors > 0
    if shape is None:
        shape = []
        for size in tensors[0].shape:
            shape.append(size)
        for i in range(1, ntensors):
            new_shape = tensors[i].shape
            for j in range(len(shape)):
                shape[j] = max(shape[j], new_shape[j])
        shape = [ntensors] + shape
    result = tensors[0].new_zeros(shape)
    for itensor in range(ntensors):
        tensor = tensors[itensor]
        view = result[itensor]
        for idim in range(tensor.dim()):
            view = view.narrow(idim, 0, tensor.size(idim))
        view.copy_(tensor)
    return result


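def _example_noncontiguous_padding_agreement():
    # Illustrative usage sketch (hypothetical helper, not an actual test): the
    # manual padding helper above should agree with torch.nested.to_padded_tensor
    # on the contiguous twin produced by random_nt_noncontiguous_pair.
    nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
    padded_ref = torch.nested.to_padded_tensor(nt_contiguous, 0.0)
    padded_manual = noncontiguous_to_padded_tensor(nt_noncontiguous)
    return torch.equal(padded_ref, padded_manual)

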
# Helper function to generate a random nested tensor
|
|
|
|
|
|
def random_nt(
|
|
device,
|
|
dtype,
|
|
num_tensors,
|
|
max_dims,
|
|
min_dims=None,
|
|
layout=torch.strided,
|
|
require_non_empty=True,
|
|
):
|
|
if min_dims is None:
|
|
min_dims = tuple([0] * len(max_dims))
|
|
|
|
assert len(max_dims) == len(min_dims)
|
|
for min_dim, max_dim in zip(min_dims, max_dims):
|
|
assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
|
|
assert min_dim >= 0, "random_nt: min_dim must be non-negative"
|
|
if require_non_empty:
|
|
assert not (min_dim == 0 and max_dim == 1), (
|
|
"random_nt: zero cannot be the only possible value if require_non_empty is True"
|
|
)
|
|
|
|
if require_non_empty:
|
|
# Select a random idx that will be required to be non-empty
|
|
non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()
|
|
|
|
ts1 = []
|
|
for i, _ in enumerate(range(num_tensors)):
|
|
tensor_dims = []
|
|
for min_dim, max_dim in zip(min_dims, max_dims):
|
|
new_min_dim = min_dim
|
|
if require_non_empty and i == non_zero_idx and min_dim == 0:
|
|
new_min_dim = 1
|
|
tensor_dims.append(
|
|
torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item()
|
|
)
|
|
t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
|
|
ts1.append(t1)
|
|
|
|
return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
|
|
|
|
|
|
# Alternate approach to generating a random NT.
|
|
# dims should be something like [5, None, 10], with None indicating that a
|
|
# random ragged structure should be used
|
|
def random_nt_from_dims(
|
|
dims, device=None, dtype=None, layout=torch.strided, requires_grad=False
|
|
):
|
|
sizes = [
|
|
[
|
|
d if d is not None else torch.randint(2, 10, size=(1,)).item()
|
|
for d in dims[1:]
|
|
]
|
|
for d in range(dims[0])
|
|
]
|
|
return torch.nested.nested_tensor(
|
|
[torch.randn(*size) for size in sizes],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=layout,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
|
|
# Creates an NT matching another NT's number of components and
|
|
# shape / ragged structure for all dims specified to be -1.
|
|
def random_nt_from_similar(other, dims=None):
|
|
if dims is None:
|
|
return torch.randn_like(other)
|
|
assert len(dims) == other.dim()
|
|
assert dims[0] == -1 or dims[0] == other.size(0)
|
|
|
|
ret_sizes = []
|
|
for t in other.unbind():
|
|
other_size = t.shape
|
|
ret_size = []
|
|
for i, d in enumerate(dims[1:]):
|
|
if d == -1:
|
|
ret_size.append(other_size[i])
|
|
else:
|
|
ret_size.append(d)
|
|
ret_sizes.append(ret_size)
|
|
|
|
return torch.nested.nested_tensor(
|
|
[torch.randn(*size) for size in ret_sizes], device=other.device
|
|
)
|
|
|
|
|
|
# makes naming nice for tests that parametrize over layout.
|
|
def layout_name(layout):
|
|
# e.g. "torch.jagged" -> "jagged"
|
|
return layout.__repr__().split(".")[-1]
|
|
|
|
|
|
def get_op_name(layout):
|
|
# e.g. "<OpOverload(op='aten.sum', overload='dim_IntList')>" -> "sum"
|
|
return layout.__name__.split(".")[0].split("_")[-1]
|
|
|
|
|
|
# Helper function for test_dummy_mha_with_nt
|
|
@torch.fx.wrap
|
|
def convert_dense_to_nested_tensor_legacy(values):
|
|
offsets = torch.arange(
|
|
0, values.shape[0] * values.shape[1] + 1, values.shape[1], device=values.device
|
|
)
|
|
metadata_cache = {"max_seqlen": values.shape[1], "min_seqlen": 1}
|
|
nt = ViewNestedFromBuffer.apply(
|
|
values.view(-1, values.shape[-1]), offsets, metadata_cache
|
|
)
|
|
return nt
|
|
|
|
|
|
# Helper function for test_dummy_mha_with_nt
|
|
@torch.fx.wrap
|
|
def convert_jagged_to_nested_tensor_legacy(
|
|
values: torch.Tensor, offsets: torch.Tensor, max_length: int
|
|
) -> torch.Tensor:
|
|
metadata_cache = {"max_seqlen": max_length, "min_seqlen": 1}
|
|
nt = ViewNestedFromBuffer.apply(values, offsets, metadata_cache)
|
|
return nt
|
|
|
|
|
|
# Helper function for test_dummy_mha_with_nt
|
|
@torch.fx.wrap
|
|
def convert_nt_to_jagged_legacy(nt):
|
|
return buffer_from_jagged(nt)
|
|
|
|
|
|
# Helper function for test_dummy_mha_with_nt
|
|
@torch.fx.wrap
|
|
def convert_dense_to_nested_tensor(values):
|
|
nt = torch.nested.as_nested_tensor(values, layout=torch.jagged)
|
|
return nt
|
|
|
|
|
|
# Helper function for test_dummy_mha_with_nt
|
|
@torch.fx.wrap
|
|
def convert_jagged_to_nested_tensor(
|
|
values: torch.Tensor, offsets: torch.Tensor, max_length: int
|
|
) -> torch.Tensor:
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values, offsets, lengths=None, min_seqlen=1, max_seqlen=max_length
|
|
)
|
|
return nt
|
|
|
|
|
|
# Helper function for test_dummy_mha_with_nt
|
|
def convert_nt_to_jagged(nt):
|
|
return nt.values()
|
|
|
|
|
|
@markDynamoStrictTest
|
|
class TestNestedTensor(NestedTensorTestCase):
|
|
@parametrize("batch_size", [2, 4])
|
|
@parametrize("max_seq_len", [3, 5])
|
|
@parametrize("vocab_size", [10, 20])
|
|
def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
|
|
data = []
|
|
nested_tensor_ref_list = []
|
|
for _ in range(batch_size):
|
|
if max_seq_len == 0:
|
|
length = 0
|
|
else:
|
|
length = np.random.randint(low=1, high=max_seq_len)
|
|
row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
|
|
data.append(row)
|
|
nested_tensor_ref_list.append(torch.Tensor(row))
|
|
nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
|
|
nested_tensor_list = nested_tensor.unbind()
|
|
for id in range(batch_size):
|
|
self.assertEqual(
|
|
nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
|
|
)
|
|
|
|
@parametrize("batch_size", [2, 4])
|
|
@parametrize("max_seq_len", [3, 5])
|
|
@parametrize("vocab_size", [10, 20])
|
|
def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
|
|
data = []
|
|
nested_tensor_ref_list = []
|
|
for _ in range(batch_size):
|
|
if max_seq_len == 0:
|
|
length = 0
|
|
else:
|
|
length = np.random.randint(low=1, high=max_seq_len)
|
|
row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
|
|
row = [list(item * np.arange(max_seq_len)) for item in row]
|
|
data.append(row)
|
|
nested_tensor_ref_list.append(torch.Tensor(row))
|
|
nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
|
|
nested_tensor_list = nested_tensor.unbind()
|
|
for id in range(batch_size):
|
|
self.assertEqual(
|
|
nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.int64)
|
|
)
|
|
|
|
@parametrize("batch_size", [2, 4])
|
|
@parametrize("max_seq_len", [3, 5])
|
|
@parametrize("vocab_size", [10, 20])
|
|
def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
|
|
data = []
|
|
nested_tensor_ref_list = []
|
|
for _ in range(batch_size):
|
|
if max_seq_len == 0:
|
|
length = 0
|
|
else:
|
|
length = np.random.randint(low=1, high=max_seq_len)
|
|
row = list(
|
|
np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
|
|
)
|
|
row = [list(item * np.arange(max_seq_len)) for item in row]
|
|
data.append(row)
|
|
nested_tensor_ref_list.append(torch.Tensor(row))
|
|
nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
|
|
nested_tensor_list = nested_tensor.unbind()
|
|
for id in range(batch_size):
|
|
self.assertEqual(
|
|
nested_tensor_list[id], nested_tensor_ref_list[id].type(torch.float)
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def _test_unbind_case(self, a, b):
|
|
nt = torch.nested.nested_tensor([a, b])
|
|
a1, b1 = nt.unbind()
|
|
self.assertTrue(a is not a1)
|
|
self.assertTrue(b is not b1)
|
|
|
|
nt = torch.nested.nested_tensor([a, b], dtype=a.dtype)
|
|
a1, b1 = nt.unbind(0)
|
|
self.assertEqual(a, a1)
|
|
self.assertEqual(b, b1)
|
|
|
|
a = torch.randn((2, 3)).add_(1)
|
|
nt = torch.nested.nested_tensor([a])
|
|
self.assertEqual(a, nt.unbind(0)[0])
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_0(self):
|
|
self._test_unbind_case(torch.tensor([1, 2]), torch.tensor([7, 8]))
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_1(self):
|
|
self._test_unbind_case(torch.tensor([1]), torch.tensor([7]))
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_3(self):
|
|
self._test_unbind_case(torch.tensor([1.0]), torch.tensor([]))
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_4(self):
|
|
self._test_unbind_case(torch.tensor([]), torch.tensor([]))
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_dim(self):
|
|
def _test_fn(unbind_fn):
|
|
a = torch.rand(3, 2)
|
|
b = torch.rand(2, 3)
|
|
nt = torch.nested.nested_tensor([a, b])
|
|
self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))
|
|
|
|
# Both of these tests are necessary, because we're using
|
|
# torch_function.
|
|
_test_fn(lambda x, dim: x.unbind(dim))
|
|
# TODO: Re-enable this once using torch_dispatch
|
|
# _test_fn(lambda x, dim: torch.unbind(x, dim))
|
|
|
|
@torch.inference_mode()
|
|
def test_nested_tensor(self):
|
|
self.assertRaises(
|
|
TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0]))
|
|
)
|
|
self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))
|
|
|
|
@torch.inference_mode()
|
|
def test_nested_tensor_matching_dim(self):
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
|
|
lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
|
|
lambda: torch.nested.nested_tensor(
|
|
[torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
|
|
),
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_default_nested_tensor(self):
|
|
self.assertRaises(TypeError, lambda: torch.nested.nested_tensor())
|
|
default_nested_tensor = torch.nested.nested_tensor([])
|
|
default_tensor = torch.tensor([])
|
|
# self.assertEqual(default_nested_tensor.nested_dim(), 1)
|
|
# self.assertEqual(default_nested_tensor.nested_size(), ())
|
|
self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
|
|
self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
|
|
self.assertEqual(default_nested_tensor.device, default_tensor.device)
|
|
self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
|
|
self.assertEqual(
|
|
default_nested_tensor.requires_grad, default_tensor.requires_grad
|
|
)
|
|
self.assertIsNone(default_tensor.grad)
|
|
# TODO: Re-enable once we have a performance driven
|
|
# use case and implementation.
|
|
# self.assertEqual(default_nested_tensor.is_pinned(),
|
|
# default_tensor.is_pinned())
|
|
|
|
@torch.inference_mode()
|
|
def test_dim(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertEqual(a1.dim(), 1)
|
|
a1 = constructor([torch.tensor(3.0)])
|
|
self.assertEqual(a1.dim(), 1)
|
|
a1 = constructor([torch.tensor([1, 2, 3, 4])])
|
|
self.assertEqual(a1.dim(), 2)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
|
|
@torch.inference_mode()
|
|
def test_numel(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertEqual(a1.numel(), 0)
|
|
a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
|
|
self.assertEqual(a1.numel(), 2)
|
|
a1 = constructor([torch.randn(2, 2, 2)])
|
|
self.assertEqual(a1.numel(), 8)
|
|
a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
|
|
self.assertEqual(a1.numel(), 12)
|
|
a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
|
|
self.assertEqual(a1.numel(), 27)
|
|
a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
|
|
self.assertEqual(a1.numel(), 341)
|
|
|
|
# Interesting edge case
|
|
a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
|
|
self.assertEqual(a1.numel(), 6)
|
|
|
|
@torch.inference_mode()
|
|
def test_size(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"NestedTensorImpl doesn't support sizes",
|
|
lambda: a1.size(),
|
|
)
|
|
|
|
def test_size_dim(self):
|
|
a = torch.nested.nested_tensor([])
|
|
self.assertEqual(a.size(0), 0)
|
|
|
|
a = torch.nested.nested_tensor([torch.tensor(1)])
|
|
self.assertEqual(a.size(0), 1)
|
|
|
|
a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)])
|
|
self.assertEqual(a.size(0), 2)
|
|
|
|
a = torch.nested.nested_tensor([torch.rand(1, 2), torch.rand(1, 8)])
|
|
self.assertEqual(a.size(0), 2)
|
|
self.assertEqual(a.size(1), 1)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Given dimension 2 is irregular and does not have a size",
|
|
lambda: a.size(2),
|
|
)
|
|
|
|
a = torch.nested.nested_tensor([torch.rand(3, 4), torch.rand(5, 4)])
|
|
self.assertEqual(a.size(0), 2)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Given dimension 1 is irregular and does not have a size",
|
|
lambda: a.size(1),
|
|
)
|
|
self.assertEqual(a.size(2), 4)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
|
|
@torch.inference_mode()
|
|
def test_stride(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"NestedTensorImpl doesn't support strides",
|
|
lambda: a1.stride(),
|
|
)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
|
|
@torch.inference_mode()
|
|
def test_is_contiguous(self):
|
|
# Test empty case
|
|
nt_empty = torch.nested.nested_tensor([])
|
|
assert nt_empty.is_contiguous()
|
|
self.assertEqual(nt_empty, nt_empty.contiguous())
|
|
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
|
|
|
|
# Test contiguous case
|
|
assert nt_contiguous.is_contiguous()
|
|
self.assertEqual(nt_contiguous, nt_contiguous.contiguous())
|
|
|
|
# Test non_contiguous case
|
|
assert not nt_noncontiguous.is_contiguous()
|
|
self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())
|
|
|
|
# Test querying by memory_format
|
|
self.assertTrue(
|
|
nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
|
|
)
|
|
self.assertTrue(
|
|
not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_repr_string(self):
|
|
a = torch.nested.nested_tensor([])
|
|
expected = "nested_tensor([\n\n])"
|
|
self.assertEqual(str(a), expected)
|
|
self.assertEqual(repr(a), expected)
|
|
|
|
a = torch.nested.nested_tensor([torch.tensor(1.0)])
|
|
expected = "nested_tensor([\n tensor(1.)\n])"
|
|
self.assertEqual(str(a), expected)
|
|
self.assertEqual(repr(a), expected)
|
|
|
|
a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
|
|
expected = "nested_tensor([\n tensor([[1, 2]]),\n tensor([[4, 5]])\n])"
|
|
self.assertEqual(str(a), expected)
|
|
self.assertEqual(repr(a), expected)
|
|
|
|
def test_to_padded_tensor_on_empty_tensor(self):
|
|
nt = torch.nested.nested_tensor([])
|
|
empty = torch.nested.to_padded_tensor(nt, 4)
|
|
self.assertEqual(empty, torch.tensor([]))
|
|
|
|
def test_nested_namespace(self):
|
|
nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
|
|
result = nt.to_padded_tensor(4)
|
|
nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
|
|
self.assertEqual(result, nested_namespace_result)
|
|
|
|
def test_to(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
|
|
|
|
def test_copy_behavior(t, non_blocking=False):
|
|
self.assertIs(t, t.to(t, non_blocking=non_blocking))
|
|
self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
|
|
self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
|
|
self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
|
|
self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
|
|
self.assertIsNot(
|
|
t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)
|
|
)
|
|
|
|
devices = [t.device]
|
|
if t.device.type == "cuda":
|
|
if t.device.index == -1:
|
|
devices.append(f"cuda:{torch.cuda.current_device()}")
|
|
elif t.device.index == torch.cuda.current_device():
|
|
devices.append("cuda")
|
|
for device in devices:
|
|
self.assertIs(t, t.to(device, non_blocking=non_blocking))
|
|
self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
|
|
self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
|
|
self.assertIsNot(
|
|
t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True)
|
|
)
|
|
|
|
test_copy_behavior(nt)
|
|
self.assertEqual(nt.device, nt.to("cpu").device)
|
|
self.assertEqual(nt.device, nt.to("cpu", dtype=torch.float32).device)
|
|
self.assertIs(torch.float32, nt.to("cpu", dtype=torch.float32).dtype)
|
|
self.assertEqual(nt.device, nt.to(torch.float32).device)
|
|
self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype)
|
|
|
|
def test_data_ptr(getter):
|
|
self.assertEqual(getter(nt), getter(nt.to("cpu")))
|
|
self.assertEqual(
|
|
getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False))
|
|
)
|
|
self.assertEqual(getter(nt), getter(nt.to("cpu", copy=False)))
|
|
self.assertNotEqual(getter(nt), getter(nt.to("cpu", copy=True)))
|
|
|
|
test_data_ptr(lambda nt: nt.data_ptr())
|
|
|
|
if torch.cuda.is_available():
|
|
for non_blocking in [True, False]:
|
|
for cuda in [
|
|
"cuda",
|
|
"cuda:0" if torch.cuda.device_count() == 1 else "cuda:1",
|
|
]:
|
|
nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4))
|
|
test_copy_behavior(nt2, non_blocking)
|
|
self.assertEqual(
|
|
nt2.device, nt2.to(cuda, non_blocking=non_blocking).device
|
|
)
|
|
self.assertEqual(
|
|
nt.device, nt2.to("cpu", non_blocking=non_blocking).device
|
|
)
|
|
self.assertEqual(
|
|
nt2.device, nt.to(cuda, non_blocking=non_blocking).device
|
|
)
|
|
self.assertIs(
|
|
torch.int32,
|
|
nt2.to(
|
|
"cpu", dtype=torch.int32, non_blocking=non_blocking
|
|
).dtype,
|
|
)
|
|
self.assertEqual(
|
|
nt.device,
|
|
nt2.to(
|
|
"cpu", dtype=torch.int32, non_blocking=non_blocking
|
|
).device,
|
|
)
|
|
self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
|
|
self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)
|
|
|
|
def test_copy_(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
|
|
nt_copy = torch.empty_like(nt)
|
|
nt_copy.copy_(nt)
|
|
|
|
for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
|
|
self.assertEqual(nt_ub, nt_copy_ub)
|
|
|
|
nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"copy_ only supports tensors that are the same size for Nested implementations",
|
|
lambda: nt_error.copy_(nt),
|
|
)
|
|
|
|
if torch.cuda.is_available():
|
|
nt = random_nt(torch.device("cuda"), torch.float32, ntensors, (4, 4))
|
|
nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
|
|
nt_copy.copy_(nt, non_blocking=True)
|
|
torch.cuda.current_stream(torch.cuda.current_device()).synchronize()
|
|
for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
|
|
self.assertEqual(nt_ub, nt_copy_ub)
|
|
|
|
nt_copy = torch.empty_like(nt, device=torch.device("cpu"))
|
|
nt_copy.copy_(nt, non_blocking=False)
|
|
for nt_ub, nt_copy_ub in zip(nt.unbind(), nt_copy):
|
|
self.assertEqual(nt_ub, nt_copy_ub)
|
|
|
|
def test_fill_(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
|
|
nt.fill_(10.0)
|
|
for nt_ub in nt.unbind():
|
|
t = torch.empty_like(nt_ub)
|
|
t.fill_(10.0)
|
|
self.assertEqual(nt_ub, t)
|
|
|
|
fill_tensor = torch.tensor([11.0])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"fill_ only supports 0-dimension value tensor",
|
|
lambda: nt.fill_(fill_tensor),
|
|
)
|
|
|
|
nt.fill_(fill_tensor[0])
|
|
for nt_ub in nt.unbind():
|
|
t = torch.empty_like(nt_ub)
|
|
t.fill_(11.0)
|
|
self.assertEqual(nt_ub, t)
|
|
|
|
def test_zero_(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
|
|
nt.zero_()
|
|
for nt_ub in nt.unbind():
|
|
t = torch.empty_like(nt_ub)
|
|
t.fill_(0.0)
|
|
self.assertEqual(nt_ub, t)
|
|
|
|
@parametrize(
|
|
"func",
|
|
[torch.ones_like, torch.zeros_like, torch.randn_like],
|
|
name_fn=lambda f: f.__name__,
|
|
)
|
|
def test_like_functions(self, func):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device("cpu"), torch.float32, ntensors, (4, 4))
|
|
torch.manual_seed(1)
|
|
nt_like = func(nt)
|
|
|
|
torch.manual_seed(1)
|
|
for nt_ub in nt_like.unbind():
|
|
t_like = func(nt_ub)
|
|
self.assertEqual(nt_ub, t_like)
|
|
|
|
def test_cat(self):
|
|
# dim=0 success case
|
|
# No constraints on ragged structures matching.
|
|
x = random_nt_from_dims([5, None, 10])
|
|
y = random_nt_from_dims([3, 4, None])
|
|
output = torch.cat([x, y], dim=0)
|
|
for out_component, xy_component in zip(
|
|
output.unbind(), itertools.chain(x.unbind(), y.unbind())
|
|
):
|
|
self.assertEqual(out_component, xy_component)
|
|
|
|
# dim=-1 success case
|
|
# shape (B, *, D)
|
|
x = random_nt_from_dims([5, None, 10])
|
|
# shape (B, *, D'); same structure as x but dim=-1 differs
|
|
y = random_nt_from_similar(x, dims=[-1, -1, 8])
|
|
# should be shape (B, *, D + D') when supported
|
|
output = torch.cat([x, y], dim=-1)
|
|
for out_component, x_component, y_component in zip(
|
|
output.unbind(), x.unbind(), y.unbind()
|
|
):
|
|
self.assertEqual(
|
|
out_component, torch.cat([x_component, y_component], dim=-1)
|
|
)
|
|
|
|
# dim between 0 and -1 success case
|
|
x = random_nt_from_dims([5, None, 2, 3])
|
|
# same structure as x but dim=2 differs
|
|
y = random_nt_from_similar(x, dims=[-1, -1, 4, -1])
|
|
output = torch.cat([x, y], dim=2)
|
|
for out_component, x_component, y_component in zip(
|
|
output.unbind(), x.unbind(), y.unbind()
|
|
):
|
|
self.assertEqual(
|
|
out_component, torch.cat([x_component, y_component], dim=1)
|
|
)
|
|
|
|
# error case: mixed NT / dense inputs
|
|
x = random_nt_from_dims([5, None, 2])
|
|
y = torch.randn(5, 3, 2)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "expected each tensor in given list to be nested"
|
|
):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: NTs with different dims
|
|
x = random_nt_from_dims([5, None, 2])
|
|
y = random_nt_from_dims([5, None, 2, 3])
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"expected all nested tensors to have matching ragged structures outside of the concatenated dim",
|
|
):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: non-contiguous NT
|
|
x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32)
|
|
# transpose to put ragged dim next to batch dim
|
|
x, y = x.transpose(-2, -1), y.transpose(-2, -1)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only contiguous nested tensors are supported"
|
|
):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: multiple ragged dims in inputs
|
|
x = random_nt_from_dims([5, None, None, 2])
|
|
y = random_nt_from_similar(x)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"only nested tensors with a single ragged dim next to the batch dim are supported",
|
|
):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: ragged dim not next to batch dim
|
|
x = random_nt_from_dims([5, 2, None])
|
|
y = random_nt_from_similar(x)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"only nested tensors with a single ragged dim next to the batch dim are supported",
|
|
):
|
|
torch.cat([x, y], dim=1)
|
|
|
|
# error case: NTs with different batch sizes
|
|
x = random_nt_from_dims([5, None, 2])
|
|
y = random_nt_from_dims([3, None, 2])
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"expected all nested tensors to have matching ragged structures outside of the concatenated dim",
|
|
):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: NTs with different ragged structures
|
|
x = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 6),
|
|
torch.randn(4, 6),
|
|
torch.randn(5, 6),
|
|
]
|
|
)
|
|
y = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(5, 6),
|
|
torch.randn(4, 6),
|
|
torch.randn(2, 6),
|
|
]
|
|
)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"expected all nested tensors to have matching ragged structures outside of the concatenated dim",
|
|
):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
def test_nested_view_from_buffer_overflow_errors(self):
|
|
buffer = torch.tensor([1])
|
|
sizes = torch.tensor([[2**63 - 1], [2**63 - 1], [3]], dtype=torch.int64)
|
|
strides = torch.tensor(
|
|
[[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64
|
|
)
|
|
offsets = torch.tensor(
|
|
[[0x41414141], [0x41414141], [0x41414141]], dtype=torch.int64
|
|
)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"Storage size calculation overflowed with sizes=\[9223372036854775807\] and strides=\[1094795585\]",
|
|
):
|
|
nt = torch._nested_view_from_buffer(buffer, sizes, strides, offsets)
|
|
|
|
|
|
@markDynamoStrictTest
|
|
class TestNestedTensorDeviceType(NestedTensorTestCase):
|
|
# Helper function to generate a pair of random nested tensors
|
|
# the 2 nested tensors have same shapes
|
|
def random_nt_pair(self, device, dtype, num_tensors, max_dims):
|
|
ts1 = []
|
|
ts2 = []
|
|
for _ in range(num_tensors):
|
|
tensor_dims = tuple(
|
|
[
|
|
torch.randint(low=0, high=max_dim, size=(1,)).item()
|
|
for max_dim in max_dims
|
|
]
|
|
)
|
|
t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
|
|
t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
|
|
ts1.append(t1)
|
|
ts2.append(t2)
|
|
return (
|
|
torch.nested.nested_tensor(ts1, device=device, dtype=dtype),
|
|
torch.nested.nested_tensor(ts2, device=device, dtype=dtype),
|
|
)
|
|
|
|
@dtypes(*floating_types_and_half())
|
|
def test_detach(self, device, dtype):
|
|
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
|
|
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
|
|
x = torch.nested.nested_tensor([a, b], requires_grad=True)
|
|
|
|
x_detach = x.detach()
|
|
|
|
z = x_detach * 4
|
|
self.assertFalse(x_detach.requires_grad)
|
|
self.assertFalse(z.requires_grad)
|
|
|
|
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
|
|
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
|
|
x = torch.nested.as_nested_tensor([a, b])
|
|
|
|
y = x * 2
|
|
y = y.detach()
|
|
self.assertFalse(y.requires_grad)
|
|
self.assertIsNone(y.grad_fn)
|
|
|
|
z = x + y
|
|
torch.nested.to_padded_tensor(z, 0).sum().backward()
|
|
# This is an incorrect gradient, but we assume that's what the user
|
|
# wanted. detach() is an advanced option.
|
|
self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
|
|
self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
|
|
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("weights_only", [False, True])
|
|
def test_serialization(self, device, dtype, requires_grad, weights_only):
|
|
def compare_metadata(nt1, nt2):
|
|
self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
|
|
self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
|
|
self.assertEqual(
|
|
nt1._nested_tensor_storage_offsets(),
|
|
nt2._nested_tensor_storage_offsets(),
|
|
)
|
|
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
|
|
for a in [nt_contiguous, nt_noncontiguous]:
|
|
buffer = io.BytesIO()
|
|
serialized = torch.save(a, buffer)
|
|
buffer.seek(0)
|
|
b = torch.load(buffer, weights_only=weights_only)
|
|
# should be both conceptually equal and metadata equivalent
|
|
self.assertEqual(a, b)
|
|
compare_metadata(a, b)
|
|
# should be conceptually equal but not necessarily metadata equivalent
|
|
self.assertEqual(b, nt_contiguous)
|
|
self.assertEqual(b, nt_noncontiguous)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_unbind_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3, 6, 7), device, dtype
|
|
)
|
|
ub_contiguous = nt_contiguous.unbind()
|
|
ub_noncontiguous = nt_noncontiguous.unbind()
|
|
self.assertEqual(len(ub_contiguous), len(ub_noncontiguous))
|
|
n = len(ub_contiguous)
|
|
for i in range(n):
|
|
self.assertEqual(ub_contiguous[i], ub_noncontiguous[i])
|
|
|
|
@dtypes(torch.float)
|
|
@skipMeta
|
|
def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
|
|
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
padded = torch.nested.to_padded_tensor(nt, 0)
|
|
|
|
nt_to = torch._nested_from_padded_and_nested_example(padded, nt)
|
|
|
|
for t1, t2 in zip(nt.unbind(), nt_to.unbind()):
|
|
self.assertEqual(t1, t2)
|
|
self.assertEqual(nt.device, nt_to.device)
|
|
|
|
@dtypes(torch.float)
|
|
@dtypesIfCUDA(torch.float, torch.half)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_layer_norm(self, device, dtype):
|
|
def _test(size):
|
|
# Simple shapes test
|
|
t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
|
|
t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
|
|
ts = [t0, t1, t0, t1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
|
|
nt_result = layer_norm(nt)
|
|
for nt_subresult, t in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
# More complex nt test with different lengths for each tensor
|
|
t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False)
|
|
t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False)
|
|
t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False)
|
|
ts = [t0, t1, t2, t0, t2]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
|
|
nt_result = layer_norm(nt)
|
|
for nt_subresult, t in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
if size <= 128:
|
|
# Test with multidimensional tensors after irregular dim
|
|
# (run only with smaller dimensions to ensure fast execution)
|
|
t0 = torch.randn(
|
|
4, size, size, 4, device=device, dtype=dtype, requires_grad=False
|
|
)
|
|
t1 = torch.randn(
|
|
10, size, size, 4, device=device, dtype=dtype, requires_grad=False
|
|
)
|
|
t2 = torch.randn(
|
|
7, size, size, 4, device=device, dtype=dtype, requires_grad=False
|
|
)
|
|
ts = [t0, t1, t2, t0, t2]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm(
|
|
(size, size, 4), device=device, dtype=dtype
|
|
)
|
|
nt_result = layer_norm(nt)
|
|
for nt_subresult, t in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
# Test where the normalizing dimensions are not all
|
|
layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype)
|
|
nt_result = layer_norm(nt)
|
|
for nt_subresult, t in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
|
|
_test(size)
|
|
|
|
@dtypes(torch.float)
|
|
@dtypesIfCUDA(torch.float, torch.half)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_layer_norm_breaking(self, device, dtype):
|
|
size = 128
|
|
t0 = torch.randn(
|
|
4, size, size, 4, device=device, dtype=dtype, requires_grad=False
|
|
)
|
|
t1 = torch.randn(
|
|
10, size, size, 4, device=device, dtype=dtype, requires_grad=False
|
|
)
|
|
t2 = torch.randn(
|
|
7, size, size, 4, device=device, dtype=dtype, requires_grad=False
|
|
)
|
|
ts = [t0, t1, t2, t0, t2]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"normalized_shape extends into irregular dimensions for the nested tensor",
|
|
lambda: layer_norm(nt),
|
|
)
|
|
layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"The shape at dimension 0",
|
|
lambda: layer_norm(nt),
|
|
)
|
|
|
|
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
|
|
def test_embedding(self, device, layout):
|
|
inputs = [
|
|
torch.randint(100, (L,), device=device, dtype=torch.int64)
|
|
for L in torch.randint(5, 50, (8,))
|
|
]
|
|
x = torch.nested.nested_tensor(
|
|
inputs, device=device, dtype=torch.int64, layout=layout
|
|
)
|
|
emb = torch.nn.Embedding(100, 8, device=device)
|
|
y = emb(x)
|
|
if layout == torch.jagged:
|
|
y.backward(torch.randn_like(y))
|
|
|
|
@torch._dynamo.disable
|
|
def check(inputs, y):
|
|
ys = y.unbind()
|
|
for i, inp in enumerate(inputs):
|
|
self.assertEqual(emb(inp), ys[i])
|
|
|
|
check(inputs, y)
|
|
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
@dtypes(*floating_types_and_half())
|
|
def test_masked_fill(self, device, dtype):
|
|
# nested tensor * nested tensor
|
|
(nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
|
|
ref = torch.nested.nested_tensor(
|
|
[t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())]
|
|
)
|
|
out = nt.masked_fill(mask, 0)
|
|
self.assertEqual(ref, out)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
def test_to_padded_tensor_simple(self, device, dtype):
|
|
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
for padding_value in (0, 1):
|
|
padded = torch.nested.to_padded_tensor(nt, padding_value)
|
|
|
|
correct_output = t.clone()
|
|
if padding_value == 0:
|
|
correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
|
|
else:
|
|
correct_output[0][-1] = torch.ones_like(correct_output[0][-1])
|
|
|
|
self.assertEqual(padded, correct_output)
|
|
self.assertEqual(padded.device, torch.device(device))
|
|
self.assertEqual(padded.dtype, dtype)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
def test_to_padded_tensor_output_size(self, device, dtype):
|
|
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
|
|
output_size = (4, 6, 5)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
for padding_value in (0, 1):
|
|
padded = torch.nested.to_padded_tensor(
|
|
nt, padding_value, output_size=output_size
|
|
)
|
|
correct_output = (
|
|
torch.ones(output_size, device=device, dtype=dtype) * padding_value
|
|
)
|
|
correct_output[:4:, :4, :4] = t.clone()
|
|
if padding_value == 0:
|
|
correct_output[0][3] = torch.zeros_like(correct_output[0][3])
|
|
else:
|
|
correct_output[0][3] = torch.ones_like(correct_output[0][3])
|
|
|
|
self.assertEqual(padded, correct_output)
|
|
self.assertEqual(padded.device, torch.device(device))
|
|
self.assertEqual(padded.dtype, dtype)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_dim2(self, device, dtype):
|
|
ts = [
|
|
torch.randn(160, device=device, dtype=dtype),
|
|
torch.randn(1240, device=device, dtype=dtype),
|
|
torch.randn(2400, device=device, dtype=dtype),
|
|
]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
pad = 42
|
|
correct_output = []
|
|
for t in ts:
|
|
next_output = torch.ones_like(ts[2]) * pad
|
|
correct_output.append(next_output)
|
|
next_output[: t.size(0)].copy_(t)
|
|
correct_output = torch.stack(correct_output)
|
|
padded = torch.nested.to_padded_tensor(nt, pad)
|
|
self.assertEqual(padded, correct_output)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_dim3(self, device, dtype):
|
|
ts = [
|
|
torch.randn(16, 21, device=device, dtype=dtype),
|
|
torch.randn(24, 32, device=device, dtype=dtype),
|
|
torch.randn(40, 53, device=device, dtype=dtype),
|
|
]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
pad = 42
|
|
correct_output = []
|
|
for t in ts:
|
|
next_output = torch.ones_like(ts[2]) * pad
|
|
correct_output.append(next_output)
|
|
next_output[: t.size(0), : t.size(1)].copy_(t)
|
|
correct_output = torch.stack(correct_output)
|
|
padded = torch.nested.to_padded_tensor(nt, pad)
|
|
self.assertEqual(padded, correct_output)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_dim4(self, device, dtype):
|
|
ts = [
|
|
torch.randn(16, 21, 13, device=device, dtype=dtype),
|
|
torch.randn(24, 32, 14, device=device, dtype=dtype),
|
|
torch.randn(40, 53, 16, device=device, dtype=dtype),
|
|
]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
pad = 42
|
|
correct_output = []
|
|
for t in ts:
|
|
next_output = torch.ones_like(ts[2]) * pad
|
|
correct_output.append(next_output)
|
|
next_output[: t.size(0), : t.size(1), : t.size(2)].copy_(t)
|
|
correct_output = torch.stack(correct_output)
|
|
padded = torch.nested.to_padded_tensor(nt, pad)
|
|
self.assertEqual(padded, correct_output)
|
|
|
|
# TODO: test noncontiguous to_padded_tensor
|
|
# For now this tests the functionality of noncontiguous_to_padded_tensor
|
|
# and the error message of to_padded_tensor
|
|
# since to_padded_tensor does not support noncontiguous buffer yet
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
@torch.inference_mode()
|
|
def test_to_padded_tensor_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3, 6, 7), device, dtype
|
|
)
|
|
# test noncontiguous_to_padded_tensor functionality
|
|
self.assertEqual(
|
|
torch.nested.to_padded_tensor(nt_contiguous, 0.0),
|
|
noncontiguous_to_padded_tensor(nt_noncontiguous),
|
|
)
|
|
# test to_padded_tensor error message
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"for now to_padded_tensor only supports contiguous nested tensor",
|
|
lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0),
|
|
)
|
|
|
|
@skipMeta
|
|
def test_device_checks(self, device):
|
|
nt = torch.nested.nested_tensor([], device=device)
|
|
is_cuda = "cuda" in str(device)
|
|
self.assertEqual(nt.is_cuda, is_cuda)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_nested_tensor_indexing(self, device, dtype):
|
|
# edge case: empty nested tensor
|
|
nt0 = torch.nested.nested_tensor([])
|
|
self.assertRaises(IndexError, lambda: nt0[0])
|
|
# normal case
|
|
x0 = torch.randn((2, 5), device=device, dtype=dtype)
|
|
x1 = torch.randn((3, 4), device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([x0, x1])
|
|
# single index: only support integer in the batch dimension
|
|
self.assertEqual(nt[0], x0)
|
|
self.assertEqual(nt[-1], x1)
|
|
self.assertRaises(IndexError, lambda: nt[2])
|
|
self.assertRaises(IndexError, lambda: nt[-3])
|
|
self.assertRaises(NotImplementedError, lambda: nt[:])
|
|
self.assertEqual(nt[...], nt)
|
|
# tuple of indices: only support integer in the batch dimension
|
|
# + all possible indexing in the original tensor dimensions
|
|
self.assertEqual(nt[0, 0, 0], x0[0, 0])
|
|
self.assertEqual(nt[0, 1, :], x0[1, :])
|
|
self.assertEqual(nt[1, ...], x1)
|
|
self.assertRaises(IndexError, lambda: nt[1, 4, 2])
|
|
self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
|
|
# test select on non-batch dimensions
|
|
self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0))
|
|
self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0))
|
|
self.assertRaises(IndexError, lambda: nt.select(1, 3))
|
|
self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0))
|
|
self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0))
|
|
self.assertRaises(IndexError, lambda: nt.select(2, 5))
|
|
# make sure indexing returns a view
|
|
nt[0].fill_(100.0)
|
|
answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
|
|
self.assertEqual(nt[0], answer)
|
|
nt[1, 1, :].fill_(200.0)
|
|
answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
|
|
self.assertEqual(nt[1, 1, :], answer)
|
|
|
|
# Test that indexing works when requires_grad_(True)
|
|
# previously this was failing because the backward kernel for select.int uses .sizes()
|
|
nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
|
|
self.assertEqual(nt[0], x0)
|
|
self.assertEqual(nt[-1], x1)
|
|
grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
|
|
nt[0].backward(grad_x0)
|
|
expected_grad = torch.nested.nested_tensor(
|
|
[grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)]
|
|
)
|
|
self.assertEqual(nt.grad, expected_grad)
|
|
|
|
@parametrize(
|
|
"func",
|
|
[
|
|
subtest(torch.nn.functional.relu, name="relu"),
|
|
subtest(torch.nn.functional.relu_, name="relu_"),
|
|
subtest(torch.nn.functional.gelu, name="gelu"),
|
|
subtest(torch._C._nn.gelu_, name="gelu_"),
|
|
subtest(torch.tanh, name="tanh"),
|
|
subtest(torch.tanh_, name="tanh_"),
|
|
subtest(torch.neg, name="neg"),
|
|
subtest(torch.nn.functional.silu, name="silu"),
|
|
subtest(partial(torch.nn.functional.silu, inplace=True), name="silu_"),
|
|
subtest(torch.abs, name="abs"),
|
|
subtest(torch.abs_, name="abs_"),
|
|
subtest(torch.sgn, name="sgn"),
|
|
subtest(torch.logical_not, name="logical_not"),
|
|
subtest(torch.sin, name="sin"),
|
|
subtest(torch.cos, name="cos"),
|
|
subtest(torch.isinf, name="isinf"),
|
|
subtest(torch.isposinf, name="isposinf"),
|
|
subtest(torch.isneginf, name="isneginf"),
|
|
subtest(torch.isnan, name="isnan"),
|
|
subtest(torch.sqrt, name="sqrt"),
|
|
],
|
|
)
|
|
def test_unary_funcs(self, device, func):
|
|
nt, nt_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3, 6, 7), device=device, dtype=torch.float32
|
|
)
|
|
nested_result = func(nt)
|
|
self.assertTrue(nested_result.is_nested)
|
|
for t, t_res in zip(nt.unbind(), nested_result.unbind()):
|
|
self.assertEqual(func(t), t_res)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"NestedTensor must be contiguous to get buffer.",
|
|
lambda: func(nt_noncontiguous),
|
|
)
|
|
|
|
@parametrize("func", [subtest(torch.ge, name="ge"), subtest(torch.eq, name="eq")])
|
|
def test_binary_ops_with_scalar(self, device, func):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3, 6, 7), device=device, dtype=torch.float32
|
|
)
|
|
scalar = 0.0
|
|
|
|
# should work regardless of contiguity
|
|
for nt in (nt_contiguous, nt_noncontiguous):
|
|
nested_result = func(nt, scalar)
|
|
self.assertTrue(nested_result.is_nested)
|
|
for t, t_res in zip(nt.unbind(), nested_result.unbind()):
|
|
self.assertEqual(func(t, scalar), t_res)
|
|
|
|
@dtypes(*floating_types_and_half())
|
|
def test_nested_tensor_chunk(self, device, dtype):
|
|
# Transformer use case
|
|
a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
|
|
b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
|
|
c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
|
|
a_chunks = a.chunk(3, dim=-1)
|
|
b_chunks = b.chunk(3, dim=-1)
|
|
c_chunks = c.chunk(3, dim=-1)
|
|
|
|
a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]]
|
|
b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]]
|
|
c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]]
|
|
|
|
nt = torch.nested.nested_tensor([a, b, c])
|
|
chunked = nt.chunk(3, dim=-1)
|
|
|
|
self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt))
|
|
self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt))
|
|
self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt))
|
|
|
|
for chunk in chunked:
|
|
self.assertFalse(chunk.is_contiguous())
|
|
|
|
# Failure chunking on ragged dimensions
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Chunk for nested tensors is currently only supported for the last dimension.",
|
|
lambda: torch.chunk(nt, 5, dim=1),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Chunk for nested tensors is currently only supported for the last dimension.",
|
|
lambda: torch.chunk(nt, 5, dim=0),
|
|
)
|
|
|
|
# Failure on non-contiguous nt
|
|
_, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"chunk expects `self` to be contiguous.",
|
|
lambda: torch.chunk(nt_noncontiguous, 5, dim=-1),
|
|
)
|
|
|
|
# Failure when calling non divisible n_chunks
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Chunk for nested tensors is only supported for "
|
|
"nested tensors with trailing dimension divisible by chunks.",
|
|
lambda: torch.chunk(nt, 5, dim=-1),
|
|
)
|
|
|
|
# Failure when calling backward on a chunk
|
|
a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
|
|
b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
|
|
nt_grad = torch.nested.as_nested_tensor([a, b])
|
|
chunked = torch.chunk(nt_grad, 2, dim=-1)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Nested Strided Tensor doesn't support chunk backward.",
|
|
lambda: chunked[0].backward(chunked[0].clone()),
|
|
)
|
|
|
|
    @dtypes(*floating_types_and_half())
    def test_nested_tensor_split_with_sizes(self, device, dtype):
        a = torch.randn(3, 20, device=device, dtype=dtype)
        b = torch.randn(2, 20, device=device, dtype=dtype)
        c = torch.randn(1, 20, device=device, dtype=dtype)

        split_sizes = [4, 6, 10]
        a_splits = a.split_with_sizes(split_sizes, dim=-1)
        b_splits = b.split_with_sizes(split_sizes, dim=-1)
        c_splits = c.split_with_sizes(split_sizes, dim=-1)

        nt = torch.nested.nested_tensor([a, b, c])
        nt_splits = nt.split_with_sizes(split_sizes, dim=-1)

        for i, nt_split in enumerate(nt_splits):
            self.assertEqual(
                nt_split,
                torch.nested.nested_tensor([a_splits[i], b_splits[i], c_splits[i]]),
            )
            dense_strides = torch.stack(
                [
                    torch.tensor(a_splits[i].stride()),
                    torch.tensor(b_splits[i].stride()),
                    torch.tensor(c_splits[i].stride()),
                ]
            )
            self.assertEqual(nt_split._nested_tensor_strides(), dense_strides)
            self.assertFalse(nt_split.is_contiguous())

        # Failure calling on ragged dimensions
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=1),
        )

        # Failure calling on non-last dimension
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes for nested tensors is currently only supported for the last dimension.",
            lambda: torch.split_with_sizes(nt, split_sizes, dim=0),
        )

        # Failure on non-contiguous nt
        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects `self` to be contiguous.",
            lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1),
        )

        # Failure when calling with split_sizes that don't cover the full dim size
        bad_split_sizes = [4, 6, 9]  # don't add up to 20
        self.assertRaisesRegex(
            RuntimeError,
            "split_with_sizes expects split_sizes to sum exactly to 20",
            lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1),
        )

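    # For reference, split_with_sizes on the last dim likewise maps to splitting each
    # constituent: nt.split_with_sizes([4, 6, 10], dim=-1) produces three NTs where the
    # j-th result holds the j-th split of every constituent, and each result is a
    # non-contiguous view of the original buffer (hence the stride checks above).
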
    @dtypes(torch.float, torch.float16, torch.double)
    @torch.inference_mode()
    def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
            (2, 3, 6, 7), device, dtype
        )
        self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0))
        n = nt_contiguous.size(0)
        for i in range(n):
            self.assertEqual(nt_contiguous[i], nt_noncontiguous[i])

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    @parametrize("transpose", [True, False])
    def test_nested_tensor_add(self, device, dtype, transpose):
        if transpose:
            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
            c = a.transpose(-1, -2).contiguous()
            d = b.transpose(-1, -2).contiguous()
            nt1 = torch.nested.nested_tensor([a, b, a, b])
            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
        else:
            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 + nt2
        self.assertEqual(ref, out)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    @parametrize("transpose", [True, False])
    def test_nested_tensor_sub(self, device, dtype, transpose):
        if transpose:
            a = torch.randn(2, 2, 2, device=device, dtype=dtype)
            b = torch.rand(2, 2, 2, device=device, dtype=dtype)
            c = a.transpose(-1, -2).contiguous()
            d = b.transpose(-1, -2).contiguous()
            nt1 = torch.nested.nested_tensor([a, b, a, b])
            nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
        else:
            (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 - nt2
        self.assertEqual(ref, out)

    @onlyCUDA
    @dtypes(torch.float, torch.float16)
    @torch.inference_mode()
    @parametrize("embedding_dim", [8, 128, 256, 384])
    def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
        def _test_add_mul(nt, t):
            ref_add = torch.nested.nested_tensor(
                [t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
            )
            ref_mul = torch.nested.nested_tensor(
                [t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())]
            )
            self.assertEqual(nt.add(t), ref_add)
            self.assertEqual(nt.mul(t), ref_mul)

        batch_size = 32
        seq_lens = torch.randint(low=0, high=10, size=(batch_size,))

        # [B, *, D], [B, 1, D] case
        ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
        _test_add_mul(nt, t)

        # [B, *], [B, 1] case
        ts = [torch.randn(seq_len) for seq_len in seq_lens]
        nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
        t = torch.randn((batch_size, 1), device=device, dtype=dtype)
        _test_add_mul(nt, t)

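    # For reference, the dense operand here broadcasts per constituent: a [B, 1, D] (or
    # [B, 1]) tensor is unbound along the batch dim, so nt.add(t) is equivalent to adding
    # t[i] to the i-th constituent, which is exactly what _test_add_mul checks above.
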
    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        out = nt1 * nt2
        self.assertEqual(ref, out)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
        out_number0 = nt1 * number
        out_number1 = number * nt1
        out_scalar0 = nt1 * scalar
        out_scalar1 = scalar * nt1
        self.assertEqual(out_number0, ref)
        self.assertEqual(out_number1, ref)
        self.assertEqual(out_scalar0, ref)
        self.assertEqual(out_scalar1, ref)
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul(vector),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul(nt1),
        )

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_div(self, device, dtype):
        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
        scale = 4.0
        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
        out = nt / 4.0
        self.assertEqual(ref, out)
        ref_transposed = ref.transpose(1, 2)
        out = nt.transpose(1, 2) / 4.0
        self.assertEqual(ref_transposed, out)

        ref = torch.nested.nested_tensor(
            [t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())]
        )
        out = nt / nt2
        self.assertEqual(ref, out)

        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
        self.assertEqual(ref.transpose(1, 2), out)

        nt_transpose_copy = torch.nested.nested_tensor(
            [t.transpose(0, 1) for t in nt.unbind()]
        )

        self.assertRaisesRegex(
            RuntimeError,
            "div requires strides to match when given NestedTensors",
            lambda: nt_transpose_copy.transpose(1, 2) / nt2,
        )

        nt = torch.nested.nested_tensor(
            [torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype
        )
        nt_chunks = nt.chunk(2, -1)
        self.assertRaisesRegex(
            RuntimeError,
            "div requires offsets to match when given NestedTensors",
            lambda: nt_chunks[0] / nt_chunks[1],
        )

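    # For reference, binary div on two NTs requires matching nested metadata: the strides
    # check above fails because one operand is a transposed copy and the other a transposed
    # view, and the offsets check fails because the two chunk halves start at different
    # offsets within the shared buffer.
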
    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_add_in_place(self, device, dtype):
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        nt1 += nt2
        self.assertEqual(ref, nt1)

    @dtypes(torch.float, torch.float16)
    @skipMeta
    @torch.inference_mode()
    def test_nested_tensor_mul_in_place(self, device, dtype):
        # nested tensor * nested tensor
        (nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
        ref = torch.nested.nested_tensor(
            [t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())]
        )
        nt1 *= nt2
        self.assertEqual(ref, nt1)
        # nested tensor * scalar
        number = 10.0
        scalar = torch.tensor(number).to(dtype).to(device)
        ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
        out_number = nt1.clone()
        out_number *= number
        out_scalar = nt1.clone()
        out_scalar *= scalar
        self.assertEqual(out_number, ref)
        self.assertEqual(out_scalar, ref)
        self.assertRaisesRegex(
            RuntimeError,
            r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
            lambda: scalar.mul_(nt1),
        )
        # error case: numel == 1 but dim > 0
        vector = torch.tensor([number]).to(dtype).to(device)
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a nested self and non-nested other",
            lambda: nt1.mul_(vector),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Expected both self and other to be nested, but got a non-nested self and nested other",
            lambda: vector.mul_(nt1),
        )

    @onlyCPU
    @skipMeta
    @dtypes(torch.float)
    def test_nested_tensor_sum_dim(self, device, dtype):
        params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7)))

        def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
            nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
            nt2 = nt.clone()
            ub2 = nt2.unbind()
            nt.requires_grad_(True)
            [t.requires_grad_(True) for t in ub2]
            nt_sum = nt.sum(dim=dim, keepdim=keepdim)
            ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2]
            self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum))

            # test backward
            # generate gradient tensor that has the same size as the output
            size = nt_sum._nested_tensor_size()
            gt2 = []
            for i in range(ntensors):
                gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype))
            gt = torch.nested.nested_tensor(gt2).clone()
            nt_sum.backward(gt)
            for t2, g2 in zip(ub2_sum, gt2):
                t2.backward(g2)
            self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2]))
            return

        for ntensors, max_sizes in params:
            test_sum(device, dtype, ntensors, max_sizes, len(max_sizes))

        # Test error inputs
        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor can only be reduced across the last"
        ):
            torch.nested.nested_tensor(
                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
            ).sum(0, keepdim=True)

        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor only allows reduction of a single"
        ):
            torch.nested.nested_tensor(
                [torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]
            ).sum([0, 1], keepdim=True)

        with self.assertRaisesRegex(
            RuntimeError, "NestedTensor always requires keepdim=True for now."
        ):
            torch.nested.nested_tensor(
                [torch.tensor([3, 4, 5]), torch.tensor([1, 2])]
            ).sum(-1)

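    # For reference, strided NT reductions are currently restricted to a single, last
    # (non-ragged) dimension with keepdim=True, so nt.sum(dim=-1, keepdim=True) matches
    # [t.sum(-1, keepdim=True) for t in nt.unbind()]; the error cases above pin down
    # those restrictions.
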
    @dtypes(torch.float, torch.float16)
    def test_contiguous(self, device, dtype):
        # Since we don't have access to the buffer in Python, it is harder to show
        # exactly what we are testing for. When we call chunk on a consistent dim of
        # an NT with more than one chunk, the resulting tensors are views of the
        # original NT whose numel is now less than the size of the buffer. Clone was
        # previously creating a new NT with a buffer that was the same size as the
        # original.
        nt_contiguous = torch.nested.nested_tensor(
            [
                torch.randn(2, 20, device=device, dtype=dtype),
                torch.randn(4, 20, device=device, dtype=dtype),
            ]
        )
        # Split up the last dimension, which has a consistent size of 20, into 5 chunks
        chunks = nt_contiguous.chunk(5, dim=-1)

        # Check that the chunks only become contiguous after calling contiguous()
        for chunk in chunks:
            self.assertFalse(chunk.is_contiguous())
            self.assertTrue(chunk.contiguous().is_contiguous())

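    # For reference, a chunk view only covers part of the shared buffer, so is_contiguous()
    # is False; calling .contiguous() copies it into a freshly packed buffer, after which
    # is_contiguous() is True, as asserted above.
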
    @dtypes(torch.float, torch.float16)
    @skipMeta
    def test_clone(self, device, dtype):
        nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1))
        nt2 = nt1.clone()
        # Verify the values match
        self.assertEqual(nt1, nt2)
        # Verify modifying nt2 doesn't affect nt1
        nt2.mul_(nt1)
        ub1 = nt1.unbind()
        ub2 = nt2.unbind()
        for i in range(len(ub1)):
            self.assertNotEqual(ub1[i], ub2[i])

        nt1.clone(memory_format=torch.preserve_format)
        msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
        with self.assertRaisesRegex(RuntimeError, msg):
            nt1.clone(memory_format=torch.channels_last)

    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
    @decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged)
    @dtypes(torch.float, torch.double)
    @parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
    def test_dropout(self, device, dtype, layout):
        # edge case: empty nested tensor
        # TODO: support empty NT in jagged layout
        if layout == torch.strided:
            nt0 = torch.nested.nested_tensor([], layout=layout)
            y = torch.nn.functional.dropout(nt0, 0.5)
            self.assertEqual(nt0, y)
        # normal nested tensor
        ntensors = 4
        if layout == torch.jagged:
            nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout)
        else:
            nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout)
        # edge case: invalid dropout
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
        # edge case: no dropout
        dropouter = torch.nn.Dropout(0.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 0.0)
        self.assertEqual(nt, y0)
        self.assertEqual(nt, y1)
        # edge case: all dropout
        dropouter = torch.nn.Dropout(1.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 1.0)
        nt0 = torch.zeros_like(nt)
        self.assertEqual(nt0, y0)
        self.assertEqual(nt0, y1)
        # normal case: normal dropout
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
        expect = nt.clone()
        if layout == torch.jagged:
            expect = torch.where(y == 0.0, y, nt)
            expect /= 1.0 - p
            self.assertEqual(y, expect)
        else:
            expect = nt.clone()
            for i in range(ntensors):
                actual_tensor = y[i].view(-1)
                expect_tensor = expect[i].view(-1)
                for j in range(actual_tensor.shape[0]):
                    if actual_tensor[j].item() == 0.0:
                        expect_tensor[j] = 0.0
                    else:
                        expect_tensor[j] /= 1.0 - p
            self.assertEqual(y, expect)
        with freeze_rng_state():
            dropouter = torch.nn.Dropout(p)
            y0 = dropouter(nt)
        with freeze_rng_state():
            y1 = torch.nn.functional.dropout(nt, p)
        self.assertEqual(y0, y1)

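    # For reference, dropout on an NT behaves like dropout on each constituent: surviving
    # elements are rescaled by 1 / (1 - p), which is what the expected-value reconstruction
    # above verifies for both the jagged and strided layouts.
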
@dtypes(torch.float, torch.double)
|
|
def test_dropout_noncontiguous(self, device, dtype):
|
|
ntensors = 4
|
|
nt0 = random_nt(device, dtype, ntensors, (4, 4))
|
|
nt1 = nt0.transpose(-1, -2)
|
|
p = 0.3
|
|
with freeze_rng_state():
|
|
dropouter = torch.nn.Dropout(p)
|
|
y0 = dropouter(nt0)
|
|
with freeze_rng_state():
|
|
y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2)
|
|
self.assertEqual(y0, y1)
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_softmax(self, device, dtype):
|
|
# normal nested tensor
|
|
ntensors = 4
|
|
nt = random_nt(device, dtype, ntensors, (4, 4))
|
|
# error case: softmax across nested dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Cannot apply softmax across nested dimension 0",
|
|
lambda: torch.nn.functional.softmax(nt, 0),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Cannot apply softmax across nested dimension 0",
|
|
lambda: torch.nn.functional.softmax(nt, -3),
|
|
)
|
|
# error case: dimension out of range
|
|
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
|
|
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
|
|
# normal case: should equal to padding -inf
|
|
softmaxer = torch.nn.Softmax(1)
|
|
y0 = softmaxer(nt)
|
|
y1 = torch.nn.functional.softmax(nt, 1)
|
|
self.assertEqual(y0, y1)
|
|
pt = torch.nested.to_padded_tensor(nt, float("-inf"))
|
|
# if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan
|
|
# however, physically speaking that should be 0.0
|
|
expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
|
|
self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect)
|
|
# edge case: empty nested tensor
|
|
nt0 = torch.nested.nested_tensor([])
|
|
y = torch.nn.functional.softmax(nt0, 1)
|
|
self.assertEqual(nt0, y)
|
|
# edge case: nesting scalars
|
|
nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
|
|
self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
|
|
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
@torch.inference_mode()
|
|
def test_softmax_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3, 6, 7), device, dtype
|
|
)
|
|
self.assertEqual(
|
|
torch.nn.functional.softmax(nt_contiguous, -1),
|
|
torch.nn.functional.softmax(nt_noncontiguous, -1),
|
|
)
|
|
|
|
def _test_bmm(self, device, dtype):
|
|
# error case: not 3D tensors
|
|
nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn(2), torch.randn(3)], device=device, dtype=dtype
|
|
)
|
|
nt2 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1)
|
|
)
|
|
# error case: incompatible batch size
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
|
|
lambda: nt0.bmm(nt1),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
|
|
lambda: nt1.bmm(nt0),
|
|
)
|
|
# error case: underlying matrices cannot be multiplied
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
|
|
lambda: nt0.bmm(nt0),
|
|
)
|
|
# normal nested tensor
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
|
|
)
|
|
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
|
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
|
|
torch.nested.to_padded_tensor(nt1, 0.0)
|
|
)
|
|
if dtype == torch.float16:
|
|
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
self.assertEqual(actual, expect)
|
|
|
|
# nested tensor bmm normal tensor
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 7)), torch.randn((3, 7))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.rand(2, 7, 5, dtype=dtype, device=device)
|
|
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
|
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
|
|
if dtype == torch.float16:
|
|
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
self.assertEqual(actual, expect)
|
|
|
|
# nested tensor bmm normal tensor with non-contiguous view
|
|
nt1 = torch.rand(2, 5, 7, dtype=dtype, device=device)
|
|
nt1 = nt1.transpose(1, 2)
|
|
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
|
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(nt1)
|
|
if dtype == torch.float16:
|
|
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
self.assertEqual(actual, expect)
|
|
|
|
# normal tensor bmm nested tensor
|
|
nt0 = torch.rand(2, 5, 7, dtype=dtype, device=device)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((7, 6)), torch.randn((7, 5))], device=device, dtype=dtype
|
|
)
|
|
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
|
expect = nt0.bmm(torch.nested.to_padded_tensor(nt1, 0.0))
|
|
if dtype == torch.float16:
|
|
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
self.assertEqual(actual, expect)
|
|
|
|
# test tensorcore path
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype
|
|
)
|
|
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
|
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(
|
|
torch.nested.to_padded_tensor(nt1, 0.0)
|
|
)
|
|
if dtype == torch.float16:
|
|
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
self.assertEqual(actual, expect)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float, torch.double, torch.float16, torch.bfloat16)
|
|
@tf32_on_and_off(0.005)
|
|
def test_bmm_cuda(self, device, dtype):
|
|
self._test_bmm(device, dtype)
|
|
|
|
@onlyCPU
|
|
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_bmm_cpu(self, device, dtype):
|
|
self._test_bmm(device, dtype)
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_bmm_noncontiguous(self, device, dtype):
|
|
nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3), device, dtype
|
|
)
|
|
nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
|
|
(6, 7), device, dtype
|
|
)
|
|
self.assertEqual(
|
|
nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
|
|
nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous),
|
|
)
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
@tf32_on_and_off(0.005)
|
|
def test_matmul_with_bmm_path(self, device, dtype):
|
|
def unbind_rebind_matmul(nt1, nt2):
|
|
t1s = nt1.unbind()
|
|
t2s = nt2.unbind()
|
|
out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
|
|
return torch.nested.nested_tensor(out_ts)
|
|
|
|
# [N, n_head, *, head_dim], [N, n_head, head_dim, *]
|
|
Ns = [1, 2, 5]
|
|
n_heads = np.random.randint(2, 5)
|
|
head_dim = 3
|
|
t1s = []
|
|
t2s = []
|
|
for N in Ns:
|
|
for _ in range(N):
|
|
seq_len1 = np.random.randint(2, 5)
|
|
seq_len2 = np.random.randint(2, 5)
|
|
t1s.append(torch.randn(n_heads, seq_len1, head_dim))
|
|
t2s.append(torch.randn(n_heads, head_dim, seq_len2))
|
|
nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
|
|
nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
|
|
self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))
|
|
|
|
# test with noncontiguous
|
|
t3s = []
|
|
t4s = []
|
|
for _ in range(N):
|
|
seq_len = np.random.randint(2, 5)
|
|
t3s.append(torch.randn(seq_len, n_heads, head_dim))
|
|
t4s.append(torch.randn(seq_len, n_heads, head_dim))
|
|
nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(
|
|
1, 2
|
|
)
|
|
nt4 = (
|
|
torch.nested.nested_tensor(t4s, device=device, dtype=dtype)
|
|
.transpose(1, 2)
|
|
.transpose(2, 3)
|
|
)
|
|
self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_matmul(self, device, dtype):
|
|
# error case: one is nested but the other is not
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn(2), torch.randn(3)], device=device, dtype=dtype
|
|
)
|
|
t = torch.randn(4, device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both to be nested, but got a nested self and non-nested other",
|
|
lambda: torch.matmul(nt, t),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both to be nested, but got a non-nested self and nested other",
|
|
lambda: torch.matmul(t, nt),
|
|
)
|
|
# error case: not 3+D tensors
|
|
nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn(2), torch.randn(3)], device=device, dtype=dtype
|
|
)
|
|
nt2 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt0, nt0),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt0, nt1),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt0, nt2),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt1, nt0),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt1, nt1),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt1, nt2),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt2, nt0),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt2, nt1),
|
|
)
|
|
# error case: incompatible batch size
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
|
|
lambda: torch.matmul(nt0, nt1),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
|
|
lambda: torch.matmul(nt1, nt0),
|
|
)
|
|
        # error case: incompatible (wrong) batch sizes that shouldn't even broadcast
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 2, 4)), torch.randn((2, 3, 4))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((3, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"matmul(): For nested tensors, batch dimensions must have the same sizes,",
|
|
lambda: torch.matmul(nt0, nt1),
|
|
)
|
|
# error case: incompatible batch sizes that should technically broadcast
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 2, 4)), torch.randn((1, 3, 4))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((1, 4, 6)), torch.randn((3, 4, 5))], device=device, dtype=dtype
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"matmul(): For nested tensors, batch dimensions must have the same sizes,",
|
|
lambda: torch.matmul(nt0, nt1),
|
|
)
|
|
# error case: underlying matrices cannot be multiplied
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"matmul(): Nested tensors cannot be matrix multiplied",
|
|
lambda: torch.matmul(nt0, nt0),
|
|
)
|
|
# normal nested tensor: 3D
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype
|
|
)
|
|
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
|
|
expect = torch.matmul(
|
|
torch.nested.to_padded_tensor(nt0, 0.0),
|
|
torch.nested.to_padded_tensor(nt1, 0.0),
|
|
)
|
|
self.assertEqual(actual, expect)
|
|
# normal nested tensor: 4D (with testing for batch_size=1)
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((1, 2, 4)), torch.randn((8, 3, 7))], device=device, dtype=dtype
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((1, 4, 6)), torch.randn((8, 7, 5))], device=device, dtype=dtype
|
|
)
|
|
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
|
|
expect = torch.matmul(
|
|
torch.nested.to_padded_tensor(nt0, 0.0),
|
|
torch.nested.to_padded_tensor(nt1, 0.0),
|
|
)
|
|
self.assertEqual(actual, expect)
|
|
# normal nested tensor: 5D
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((8, 9, 2, 4)), torch.randn((8, 9, 3, 7))],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((8, 9, 4, 6)), torch.randn((8, 9, 7, 5))],
|
|
device=device,
|
|
dtype=dtype,
|
|
)
|
|
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
|
|
expect = torch.matmul(
|
|
torch.nested.to_padded_tensor(nt0, 0.0),
|
|
torch.nested.to_padded_tensor(nt1, 0.0),
|
|
)
|
|
self.assertEqual(actual, expect)
|
|
|
|
# only supported on CUDA for now
|
|
@dtypes(torch.float, torch.double)
|
|
def test_matmul_nt_with_broadcasted_t(self, device, dtype):
|
|
# NT (B, *, C, D) with T (D, E) broadcasting case
|
|
nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype)
|
|
t = torch.randn(5, 6, device=device, dtype=dtype)
|
|
output = torch.matmul(nt, t)
|
|
|
|
# should be equivalent to matmul-ing each component with the dense tensor
|
|
self.assertEqual(nt.size(0), output.size(0))
|
|
for component, out_component in zip(nt, output):
|
|
self.assertEqual(out_component, torch.matmul(component, t))
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_matmul_noncontiguous(self, device, dtype):
|
|
nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3), device, dtype
|
|
)
|
|
nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair(
|
|
(6, 7), device, dtype
|
|
)
|
|
self.assertEqual(
|
|
torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous),
|
|
torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous),
|
|
)
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
def test_linear(self, device, dtype):
|
|
a = torch.randn(1, 2, device=device, dtype=dtype)
|
|
b = torch.randn(2, 2, device=device, dtype=dtype)
|
|
c = torch.randn(3, 2, device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([a, b, c])
|
|
|
|
weight = torch.randn(2, 2, device=device, dtype=dtype)
|
|
bias = torch.randn(2, device=device, dtype=dtype)
|
|
# success case
|
|
torch.functional.F.linear(nt, weight, bias)
|
|
|
|
# invalid nested tensor dimension
|
|
msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2"
|
|
nt1 = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(1, device=device, dtype=dtype),
|
|
torch.randn(2, device=device, dtype=dtype),
|
|
]
|
|
)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt1, weight, bias)
|
|
|
|
# invalid weight shape
|
|
msg = r"Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3"
|
|
weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt, weight1, bias)
|
|
|
|
# inconsistent last dim of nested tensor
|
|
msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:"
|
|
nt2 = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(1, 2, device=device, dtype=dtype),
|
|
torch.randn(2, 3, device=device, dtype=dtype),
|
|
]
|
|
)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt2, weight, bias)
|
|
|
|
# Mismatch of nested tensor last dim and weight dimension
|
|
weight2 = torch.randn(2, 4, device=device, dtype=dtype)
|
|
msg = (
|
|
r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'"
|
|
r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
|
|
)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt, weight2, bias)
|
|
|
|
# Nested tensor input and nested weight
|
|
nt_weight = nt.clone()
|
|
msg = r"Linear does not support nested weight when input is a nested tensor."
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt, nt_weight, bias)
|
|
|
|
# TODO: test noncontiguous linear
|
|
# For now this tests the error message of linear
|
|
# since linear does not support noncontiguous buffer yet
|
|
@dtypes(torch.float, torch.double)
|
|
def test_linear_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3, 6, 7), device, dtype
|
|
)
|
|
weight = torch.randn((8, 5), device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"for now linear only supports contiguous nested tensor",
|
|
lambda: torch.nn.functional.linear(nt_noncontiguous, weight),
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
|
|
ts = [torch.ones(1, 0), torch.ones(0, 0)]
|
|
nt = torch.nested.nested_tensor(
|
|
ts, device=device, dtype=dtype, layout=torch.strided
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"at least one constituent tensor should have non-zero numel",
|
|
lambda: torch.nested.to_padded_tensor(nt, 0.0),
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_transpose(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# error case: transpose nested dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Nested tensor dimension 0 cannot be transposed",
|
|
lambda: nt.transpose(0, 1),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Nested tensor dimension 0 cannot be transposed",
|
|
lambda: nt.transpose(1, -3),
|
|
)
|
|
# error case: dimension out of range
|
|
self.assertRaises(IndexError, lambda: nt.transpose(1, 3))
|
|
self.assertRaises(IndexError, lambda: nt.transpose(-4, -1))
|
|
# normal case
|
|
ntT = nt.transpose(-1, -2)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.transpose(-1, -2)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_squeeze_unsqueeze(self, device, dtype):
|
|
a = torch.arange(6).reshape(2, 3)
|
|
b = torch.arange(15).reshape(5, 3)
|
|
nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
|
|
# error case: squeeze no dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"For nested tensors, squeeze without the dim argument",
|
|
lambda: nt.squeeze(),
|
|
)
|
|
# error case: squeeze nested dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"For nested tensors, squeezing dimension 0",
|
|
lambda: nt.squeeze(0),
|
|
)
|
|
# error case: dimension out of range
|
|
self.assertRaises(IndexError, lambda: nt.squeeze(3))
|
|
# error case: squeeze nested tensor of singleton tensors
|
|
c = torch.ones(1)
|
|
nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"For nested tensors, squeezing a nested tensor of singleton",
|
|
lambda: nt_singleton.squeeze(1),
|
|
)
|
|
|
|
# squeezing a dim which does not have size 1 should be a no-op
|
|
nt2 = nt.squeeze(-1)
|
|
self.assertEqual(nt, nt2)
|
|
|
|
# test cases that should work
|
|
nt_sizes = nt._nested_tensor_size()
|
|
nt_strides = nt._nested_tensor_strides()
|
|
for i in range(-2, 4):
|
|
if i == 0:
|
|
# cannot unsqueeze batch dim
|
|
continue
|
|
nt_unsqueezed = nt.unsqueeze(i)
|
|
# negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1
|
|
wrapped_i = i + nt.dim() + 1 if i < 0 else i
|
|
            # col_index into the nt size tensor requires subtracting 1 to ignore the batch dim
|
|
size_idx = wrapped_i - 1
|
|
self.assertEqual(
|
|
nt_unsqueezed._nested_tensor_size()[:, size_idx],
|
|
torch.ones(2, dtype=torch.long),
|
|
)
|
|
unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
|
|
if i == nt.ndim or i == -1:
|
|
self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
|
|
else:
|
|
stride_col_after = nt_strides[:, size_idx]
|
|
size_col_after = nt_sizes[:, size_idx]
|
|
self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
|
|
nt_squeezed = nt_unsqueezed.squeeze(i)
|
|
self.assertEqual(nt_squeezed, nt)
|
|
self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
|
|
self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_transpose_inference_mode_interaction(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# Construct in default mode and transpose while in inference mode
|
|
with torch.inference_mode():
|
|
ntT = nt.transpose(-1, -2)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.transpose(-1, -2)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
# Construct and transpose while in inference mode
|
|
with torch.inference_mode():
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
ntT = nt.transpose(-1, -2)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.transpose(-1, -2)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_view(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# error case: empty shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[\]' is invalid for a nested tensor",
|
|
lambda: nt.view(()),
|
|
)
|
|
# error case: empty nested tensor
|
|
nt_empty = torch.nested.nested_tensor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"empty nested tensor cannot be reshaped",
|
|
lambda: nt_empty.view(-1),
|
|
)
|
|
# error case: -1 for batch size
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"view: For now nested view cannot change or infer the implicit batch dimension",
|
|
lambda: nt.view(-1, 2, 3),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[.*\]' is invalid for input of size [0-9]+",
|
|
lambda: nt.view(4, 2, 3),
|
|
)
|
|
# normal case
|
|
x0 = torch.randn((2, 20), device=device, dtype=dtype)
|
|
x1 = torch.randn((3, 20), device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([x0, x1])
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
        # error case: trying to reshape the batch dim to a legit shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"For now nested view cannot change or infer the implicit batch dimension",
|
|
lambda: nt.transpose(-1, -2).view(40, -1),
|
|
)
|
|
# inherit only the ragged dimension
|
|
# (2, 20) -> (2, 5, 4)
|
|
# (3, 20) -> (3, 5, 4)
|
|
nt1 = nt.view(2, -1, 5, 4)
|
|
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
|
|
pt1 = pt.view(2, -1, 5, 4)
|
|
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
|
|
|
|
# more than one -1 (even for "old" dims), should fail
|
|
        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
|
|
# but we ban "inherit old behavior" for >1 dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"only one dimension can be inferred",
|
|
lambda: nt1.view(2, -1, -1, 2, 2),
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_view_inference_mode_interaction(self, device, dtype):
|
|
# Construct in default mode and view while in inference mode
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
|
|
)
|
|
with torch.inference_mode():
|
|
ntT = nt.view(2, -1, 4, 5)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.view(2, -1, 4, 5)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
# Construct and view while in inference mode
|
|
with torch.inference_mode():
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype
|
|
)
|
|
ntT = nt.view(2, -1, 4, 5)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.view(2, -1, 4, 5)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_reshape(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# error case: empty shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[\]' is invalid for a nested tensor",
|
|
lambda: nt.reshape(()),
|
|
)
|
|
# error case: empty nested tensor
|
|
nt_empty = torch.nested.nested_tensor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"empty nested tensor cannot be reshaped",
|
|
lambda: nt_empty.reshape(-1),
|
|
)
|
|
# error case: -1 for batch size
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
|
|
lambda: nt.reshape(-1, 2, 3),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[.*\]' is invalid for input of size [0-9]+",
|
|
lambda: nt.reshape(4, 2, 3),
|
|
)
|
|
# normal case
|
|
x0 = torch.randn((2, 20), device=device, dtype=dtype)
|
|
x1 = torch.randn((3, 20), device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([x0, x1]) # (2, (2, 3), 20)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
        # error case: trying to reshape the batch dim to a legit shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
|
|
lambda: nt.transpose(-1, -2).reshape(40, -1),
|
|
)
|
|
# inherit only the ragged dimension
|
|
# (2, 20) -> (2, 5, 4)
|
|
# (3, 20) -> (3, 5, 4)
|
|
nt1 = nt.reshape(2, -1, 5, 4)
|
|
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
|
|
pt1 = pt.reshape(2, -1, 5, 4)
|
|
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
|
|
|
|
# more than one -1 (even for "old" dims), should fail
|
|
        # this attempts to do (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
|
|
# but we ban "inherit old behavior" for >1 dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"only one dimension can be inferred",
|
|
lambda: nt1.reshape(2, -1, -1, 2, 2),
|
|
)
|
|
|
|
def test_nested_masked_select(self, device):
|
|
t = torch.randn([3, 3], device=device)
|
|
mask = torch.tensor([False], device=device)
|
|
|
|
njt = torch.nested.masked_select(t, mask)
|
|
self.assertEqual(njt.values(), torch.tensor([], device=device))
|
|
self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 0], device=device))
|
|
|
|
mask = torch.tensor([[False], [False], [True]], device=device)
|
|
njt = torch.nested.masked_select(t, mask)
|
|
self.assertEqual(njt.values(), t[-1], atol=0.1, rtol=0.1)
|
|
self.assertEqual(njt.offsets(), torch.tensor([0, 0, 0, 3], device=device))
|
|
|
|
mask = torch.tensor(
|
|
[[False, False, True], [True, False, True], [False, False, True]],
|
|
device=device,
|
|
)
|
|
njt = torch.nested.masked_select(t, mask)
|
|
self.assertEqual(njt.values(), t.masked_select(mask))
|
|
self.assertEqual(njt.offsets(), torch.tensor([0, 1, 3, 4], device=device))
|
|
|
|
t = torch.randn([2, 3, 3, 1], device=device)
|
|
mask = torch.tensor(
|
|
[
|
|
[
|
|
[[True], [False], [True]],
|
|
[[True], [False], [True]],
|
|
[[True], [False], [True]],
|
|
],
|
|
[
|
|
[[False], [True], [True]],
|
|
[[False], [True], [True]],
|
|
[[True], [True], [True]],
|
|
],
|
|
],
|
|
device=device,
|
|
)
|
|
njt = torch.nested.masked_select(t, mask)
|
|
self.assertEqual(njt.values(), t.masked_select(mask))
|
|
self.assertEqual(
|
|
njt.offsets(),
|
|
torch.tensor(
|
|
[0, 1, 1, 2, 3, 3, 4, 5, 5, 6, 6, 7, 8, 8, 9, 10, 11, 12, 13],
|
|
device=device,
|
|
),
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_narrow(self, device, dtype):
|
|
nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype)
|
|
|
|
# narrow on dim=0 from start to end
|
|
bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)]
|
|
for start, end in bounds:
|
|
length = end - start
|
|
narrowed = nt.narrow(dim=0, start=start, length=length)
|
|
# ensure output is a view
|
|
self.assertTrue(narrowed._base is nt)
|
|
for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]):
|
|
self.assertEqual(nc, c)
|
|
|
|
# dim != 0 is not supported
|
|
for dim in range(1, nt.dim()):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only dim=0 supported for nested tensors"
|
|
):
|
|
nt.narrow(dim=dim, start=0, length=1)
|
|
|
|
# error case: non-contiguous NT
|
|
_, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4))
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only contiguous nested tensors supported"
|
|
):
|
|
nt_noncont.narrow(dim=0, start=0, length=1)
|
|
|
|
@parametrize("input_dim", [3, 4])
|
|
@tf32_on_and_off(0.005)
|
|
def test_scaled_dot_product_attention(self, device, input_dim):
|
|
def rand_tensor(*shape):
|
|
return torch.randn(shape, device=device)
|
|
|
|
E = 8
|
|
if input_dim == 3:
|
|
# Shape: (N, L, E); ragged L
|
|
query = torch.nested.nested_tensor(
|
|
[rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)]
|
|
)
|
|
|
|
# Shape: (N, S, E); ragged S
|
|
key = torch.nested.nested_tensor(
|
|
[rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
|
|
)
|
|
value = torch.nested.nested_tensor(
|
|
[rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)]
|
|
)
|
|
elif input_dim == 4:
|
|
            # In the 4D case, both L and S are ragged
|
|
# Shape: (N, N', L, E); ragged N' and L
|
|
query = torch.nested.nested_tensor(
|
|
[rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)]
|
|
)
|
|
# Shape: (N, N', S, E); ragged N' and S
|
|
key = torch.nested.nested_tensor(
|
|
[rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
|
|
)
|
|
value = torch.nested.nested_tensor(
|
|
[rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)]
|
|
)
|
|
else:
|
|
self.fail(f"Invalid input_dim {input_dim} encountered in SDP test")
|
|
|
|
def rand_mask(size):
|
|
return torch.randint(0, 2, size=size, dtype=torch.bool, device=device)
|
|
|
|
# Shape: (N, L, S); ragged L and S matching above
|
|
attn_mask = torch.nested.nested_tensor(
|
|
[rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))]
|
|
)
|
|
|
|
dropout_p = 0.0 # no dropout for reproducibility
|
|
|
|
# Success case: no attn_mask set and is_causal=False.
|
|
actual = torch.nn.functional.scaled_dot_product_attention(
|
|
query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p
|
|
)
|
|
|
|
expected_outputs = []
|
|
for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()):
|
|
output = torch.nn.functional.scaled_dot_product_attention(
|
|
q.unsqueeze(0),
|
|
k.unsqueeze(0),
|
|
v.unsqueeze(0),
|
|
attn_mask=None,
|
|
dropout_p=dropout_p,
|
|
)
|
|
expected_outputs.append(output.squeeze(0))
|
|
expected_output_nested = torch.nested.nested_tensor(expected_outputs)
|
|
self.assertEqual(actual, expected_output_nested)
|
|
|
|
# Error case: explicit attn_mask set.
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "not supported when an explicit attn_mask is set"
|
|
):
|
|
torch.nn.functional.scaled_dot_product_attention(
|
|
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p
|
|
)
|
|
|
|
# Error case: is_causal=True.
|
|
with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
|
|
torch.nn.functional.scaled_dot_product_attention(
|
|
query, key, value, dropout_p=dropout_p, is_causal=True
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_empty_like(self, device, dtype):
|
|
ntensors = 4
|
|
nt = random_nt(device, dtype, ntensors, (4, 4))
|
|
|
|
# Create empty on same device as original nested tensor
|
|
nt_empty = torch.empty_like(nt)
|
|
assert nt.is_same_size(nt_empty)
|
|
self.assertEqual(nt.dtype, nt_empty.dtype)
|
|
self.assertEqual(nt.device, nt_empty.device)
|
|
self.assertEqual(nt.layout, nt_empty.layout)
|
|
|
|
if torch.cuda.is_available():
|
|
if device == "cpu":
|
|
nt_cuda = torch.empty_like(nt, device="cuda")
|
|
self.assertEqual(torch.device("cuda").type, nt_cuda.device.type)
|
|
else:
|
|
nt_cpu = torch.empty_like(nt, device="cpu")
|
|
self.assertEqual(torch.device("cpu").type, nt_cpu.device.type)
|
|
|
|
# Check changing dtype of empty_like nested tensor output
|
|
dtype_set = {torch.float, torch.float16, torch.double}
|
|
for other_dtype in dtype_set - {dtype}:
|
|
nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype)
|
|
self.assertEqual(nt.dtype, dtype)
|
|
self.assertEqual(nt_empty_other_dtype.dtype, other_dtype)
|
|
self.assertEqual(nt.device, nt_empty.device)
|
|
self.assertEqual(nt.layout, nt_empty.layout)
|
|
|
|
# Create tensor for autograd
|
|
nt_empty_req_grad = torch.empty_like(nt, requires_grad=True)
|
|
self.assertEqual(nt_empty_req_grad.requires_grad, True)
|
|
|
|
# Test noncontiguous tensor does not fail to copy
|
|
nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7))
|
|
nt_empty = torch.empty_like(nt_cont)
|
|
assert nt_cont.is_same_size(nt_empty)
|
|
nt_empty_non_contig = torch.empty_like(nt_noncont)
|
|
assert nt_noncont.is_same_size(nt_empty_non_contig)
|
|
|
|
# Test the contiguous memory format option
|
|
nt_empty_contig = torch.empty_like(
|
|
nt_cont, memory_format=torch.contiguous_format
|
|
)
|
|
assert nt_cont.is_same_size(nt_empty_contig)
|
|
assert nt_empty_contig.is_contiguous()
|
|
|
|
nt_empty_non_contig = torch.empty_like(
|
|
nt_noncont, memory_format=torch.contiguous_format
|
|
)
|
|
assert nt_noncont.is_same_size(nt_empty_non_contig)
|
|
assert nt_empty_non_contig.is_contiguous()
|
|
|
|
# Test other memory formats fail
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last),
|
|
)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last),
|
|
)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d),
|
|
)
|
|
self.assertRaises(
|
|
RuntimeError,
|
|
lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d),
|
|
)
|
|
|
|
|
|
@markDynamoStrictTest
|
|
class TestNestedTensorAutograd(NestedTensorTestCase):
|
|
    # Note [Gradcheck args check_batched_grad=False]: the common_utils testing version of gradcheck
    # includes the default parameters used for testing ops with gradcheck. However, nested tensor
    # does not support the stack op, so we turn it off for these tests.
|
|
def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
|
|
return torch.nested.nested_tensor(
|
|
[torch.randn(1, 2), torch.randn(7, 8)],
|
|
requires_grad=requires_grad,
|
|
device=tensor_device,
|
|
)
|
|
|
|
def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
|
|
return torch.nested.as_nested_tensor(
|
|
[
|
|
torch.randn(1, 2, requires_grad=requires_grad),
|
|
torch.randn(7, 8, requires_grad=requires_grad),
|
|
],
|
|
device=tensor_device,
|
|
)
|
|
|
|
def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
|
|
data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
|
|
mask = torch.ones_like(data[:, :, 0]).bool()
|
|
return torch._nested_tensor_from_mask(data, mask)
|
|
|
|
def test_as_nested_tensor_propagates_gradients(self, device):
|
|
a = torch.arange(3, dtype=torch.float, device=device)
|
|
b = torch.arange(5, dtype=torch.float, device=device)
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
# tensors with requires_grad=False are leaves
|
|
self.assertTrue(nt.is_leaf)
|
|
self.assertTrue(not nt.requires_grad)
|
|
|
|
a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
|
|
b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
|
|
nt2 = torch.nested.as_nested_tensor([a, b])
|
|
fake_grad = torch.nested.nested_tensor(
|
|
[torch.ones_like(a), torch.zeros_like(b)], device=device
|
|
)
|
|
nt2.backward(fake_grad)
|
|
self.assertEqual(a.grad, fake_grad[0])
|
|
self.assertEqual(b.grad, fake_grad[1])
|
|
|
|
def test_nested_tensor_generates_leaf(self, device):
|
|
a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
|
|
b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
|
|
|
|
nt = torch.nested.nested_tensor([a, b], requires_grad=False)
|
|
self.assertTrue(nt.is_leaf)
|
|
self.assertTrue(not nt.requires_grad)
|
|
|
|
nt2 = torch.nested.nested_tensor([a, b], requires_grad=True)
|
|
self.assertTrue(nt2.is_leaf)
|
|
self.assertTrue(nt2.requires_grad)
|
|
|
|
fake_grad = torch.nested.nested_tensor(
|
|
[torch.ones_like(a), torch.zeros_like(b)], device=device
|
|
)
|
|
nt2.backward(fake_grad)
|
|
self.assertEqual(nt2.grad, fake_grad)
|
|
self.assertEqual(a.grad, None)
|
|
self.assertEqual(b.grad, None)
|
|
|
|
def test_set_requires_grad_from_list(self, device):
|
|
nt = self._create_nested_tensor_from_list(device)
|
|
nt.requires_grad_()
|
|
assert nt.requires_grad
|
|
|
|
def test_set_requires_grad_from_mask(self, device):
|
|
nt = self._create_nested_tensor_from_mask(device)
|
|
nt.requires_grad_()
|
|
assert nt.requires_grad
|
|
|
|
def test_backward_for_add_op(self, device):
|
|
nt_1 = self._create_nested_tensor_from_mask(device)
|
|
nt_2 = self._create_nested_tensor_from_mask(device)
|
|
|
|
nt_1.requires_grad_()
|
|
c = nt_1 + nt_2
|
|
|
|
assert nt_1.requires_grad
|
|
assert c.requires_grad
|
|
grad_output = self._create_nested_tensor_from_mask(device)
|
|
c.backward(grad_output)
|
|
|
|
# Grad check doesn't work with nested yet.
|
|
# d/dnt_1 (nt + nt_1) = 1*grad_output
|
|
self.assertEqual(nt_1.grad, grad_output)
|
|
|
|
def test_backward_for_sub_op(self, device):
|
|
nt_1 = self._create_nested_tensor_from_mask(device)
|
|
nt_2 = self._create_nested_tensor_from_mask(device)
|
|
|
|
nt_1.requires_grad_()
|
|
nt_2.requires_grad_()
|
|
c = nt_1 - nt_2
|
|
|
|
assert nt_1.requires_grad
|
|
assert nt_2.requires_grad
|
|
assert c.requires_grad
|
|
grad_output = self._create_nested_tensor_from_mask(device)
|
|
c.backward(grad_output)
|
|
|
|
self.assertEqual(nt_1.grad, grad_output)
|
|
self.assertEqual(nt_2.grad, -1 * grad_output)
|
|
|
|
def test_backward_sub_strided(self, device):
|
|
a = torch.nested.nested_tensor(
|
|
[torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
b = torch.nested.nested_tensor(
|
|
[torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
c = a - b.transpose(-1, -2)
|
|
grad_output = c.clone()
|
|
c.backward(grad_output)
|
|
self.assertEqual(a.grad, grad_output)
|
|
self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2))
|
|
|
|
def test_backward_add_strided(self, device):
|
|
a = torch.nested.nested_tensor(
|
|
[torch.randn(9, 2, 4), torch.randn(12, 2, 4)],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
b = torch.nested.nested_tensor(
|
|
[torch.randn(9, 4, 2), torch.randn(12, 4, 2)],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
c = a + b.transpose(-1, -2)
|
|
grad_output = c.clone()
|
|
c.backward(grad_output)
|
|
self.assertEqual(a.grad, grad_output)
|
|
self.assertEqual(b.grad, grad_output.transpose(-1, -2))
|
|
|
|
# Test Factory Functions
|
|
def test_nested_tensor_to_padded_tensor(self, device):
|
|
for padding_val in [0, 1]:
|
|
nt = self._create_leaf_nested_tensor_from_list(
|
|
tensor_device=device, requires_grad=True
|
|
)
|
|
|
|
out = torch.nested.to_padded_tensor(nt, padding_val)
|
|
grad_output = torch.ones(out.shape, device=device)
|
|
out.backward(grad_output)
|
|
|
|
self.assertEqual(
|
|
nt.grad,
|
|
torch.nested.nested_tensor(
|
|
[torch.ones(1, 2), torch.ones(7, 8)], device=device
|
|
),
|
|
)
|
|
|
|
def test_nested_tensor_from_mask_and_to_padded(self, device):
|
|
N, L, D = 2, 4, 4
|
|
mask = torch.ones(N, L, device=device)
|
|
for i in range(1, N):
|
|
end = torch.randint(1, L - 1, (1,), device=device)
|
|
mask[i, end:] = 0
|
|
|
|
mask[0, :] = 1
|
|
mask = mask.bool()
|
|
|
|
data = torch.randn(
|
|
N, L, D, requires_grad=True, dtype=torch.float64, device=device
|
|
)
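
        # gradcheck() numerically estimates gradients from the float64 inputs and
        # compares them against autograd's analytic gradients. check_batched_grad is
        # disabled below on the assumption that vmap-style batched gradients aren't
        # supported for nested tensors.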
|
|
|
|
def grad_test_func(inpt):
|
|
nt = torch._nested_tensor_from_mask(inpt, mask)
|
|
# This implicitly tests to_padded_tensor grads
|
|
return torch.nested.to_padded_tensor(nt, 0)
|
|
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_from_padded(self, device):
|
|
nested_size = torch.tensor([[1, 2], [2, 2]])
|
|
padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
|
|
padded_tensor[0, 1, :] = 0
|
|
padded_tensor.requires_grad_()
|
|
|
|
def grad_test_func(tensor, nested_size):
|
|
nt = torch._nested_from_padded(
|
|
tensor, nested_size, fuse_transform_0213=False
|
|
)
|
|
# This implicitly tests to_padded_tensor grads
|
|
return torch.nested.to_padded_tensor(nt, 0)
|
|
|
|
data = (padded_tensor, nested_size)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_from_padded_fused(self, device):
|
|
nested_size = torch.tensor([[1, 8], [2, 8]])
|
|
padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
|
|
padded_tensor[0, 1, :] = 0
|
|
padded_tensor.requires_grad_()
|
|
|
|
def grad_test_func(tensor, nested_size):
|
|
nt = torch._nested_from_padded(
|
|
tensor, nested_size, fuse_transform_0213=True
|
|
)
|
|
# This implicitly tests to_padded_tensor grads
|
|
return torch.nested.to_padded_tensor(nt, 0)
|
|
|
|
data = (padded_tensor, nested_size)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_from_list(self, device):
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
c = torch.nested.as_nested_tensor([a, b, c])
|
|
            # This implicitly tests to_padded_tensor grads
|
|
return torch.nested.to_padded_tensor(c, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
|
|
def test_dropout_backward(self, layout):
|
|
if layout == torch.jagged:
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 5)), torch.randn((3, 5))],
|
|
requires_grad=True,
|
|
layout=layout,
|
|
)
|
|
else:
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 5)), torch.randn((3, 4))],
|
|
requires_grad=True,
|
|
layout=layout,
|
|
)
|
|
p = 0.2
|
|
y = torch.nn.functional.dropout(nt, p)
|
|
y.backward(nt.detach().clone())
|
|
self.assertEqual(nt.grad, y)
|
|
|
|
def test_nested_tensor_bmm_gradcheck(self, device):
|
|
a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, d):
|
|
nt0 = torch.nested.as_nested_tensor([a, b])
|
|
nt1 = torch.nested.as_nested_tensor([c, d])
|
|
result = nt0.bmm(nt1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b, c, d)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
|
|
|
|
@tf32_on_and_off(0.008)
|
|
def test_nested_tensor_bmm_backward(self, device):
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((2, 6)), torch.randn((3, 6))],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((6, 4)), torch.randn((6, 5))],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
with torch.no_grad():
|
|
pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
|
|
pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt0.bmm(nt1)
|
|
ypt = pt0.bmm(pt1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
|
|
|
|
def test_nested_tensor_matmul_gradcheck(self, device):
|
|
a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, d):
|
|
nt0 = torch.nested.as_nested_tensor([a, b])
|
|
nt1 = torch.nested.as_nested_tensor([c, d])
|
|
result = torch.matmul(nt0, nt1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b, c, d)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
|
|
|
|
def test_nested_tensor_matmul_backward(self, device):
|
|
nt0 = torch.nested.nested_tensor(
|
|
[torch.randn((7, 2, 6)), torch.randn((7, 3, 6))],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
nt1 = torch.nested.nested_tensor(
|
|
[torch.randn((7, 6, 4)), torch.randn((7, 6, 5))],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
with torch.no_grad():
|
|
pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
|
|
pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
|
|
|
|
ynt = torch.matmul(nt0, nt1)
|
|
ypt = torch.matmul(pt0, pt1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
|
|
|
|
def test_nested_tensor_transpose_gradcheck(self, device):
|
|
a = torch.randn(2, 5, requires_grad=True, device=device)
|
|
b = torch.randn(3, 4, requires_grad=True, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.transpose(-2, -1).transpose(-2, -1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
|
|
|
|
def test_nested_tensor_transpose_backward(self, device):
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 5)), torch.randn((3, 4))],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.transpose(-2, -1)
|
|
ypt = pt.transpose(-2, -1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_reshape_gradcheck(self, device):
|
|
a = torch.randn(2, 6, requires_grad=True, device=device)
|
|
b = torch.randn(3, 6, requires_grad=True, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.reshape(2, -1, 2, 3)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
|
|
|
|
def test_nested_tensor_reshape_backward(self):
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True
|
|
)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.reshape(2, -1, 2, 3)
|
|
ypt = pt.reshape(2, -1, 2, 3)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_squeeze_backward(self, device):
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 6, 1)), torch.randn((3, 6, 1))],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.squeeze(-1)
|
|
ypt = pt.squeeze(-1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_squeeze_gradcheck(self, device):
|
|
a = torch.randn(
|
|
(2, 6, 1), dtype=torch.float64, requires_grad=True, device=device
|
|
)
|
|
b = torch.randn(
|
|
(3, 6, 1), dtype=torch.float64, requires_grad=True, device=device
|
|
)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.squeeze(-1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
|
|
|
|
def test_nested_tensor_unsqueeze_backward(self, device):
|
|
nt = torch.nested.nested_tensor(
|
|
[torch.randn((2, 6)), torch.randn((3, 6))],
|
|
requires_grad=True,
|
|
device=device,
|
|
)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.unsqueeze(2)
|
|
ypt = pt.unsqueeze(2)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_unsqueeze_gradcheck(self, device):
|
|
a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device)
|
|
b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.unsqueeze(-1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
|
|
|
|
def test_nested_tensor_linear(self, device):
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
weight = torch.randn(
|
|
2, 2, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, weight, bias=None):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
# This implicitly tests to_padded_tensor grads
|
|
d = torch.functional.F.linear(nt, weight, bias)
|
|
return torch.nested.to_padded_tensor(d, 0)
|
|
|
|
data = (a, b, c, weight, bias)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# Test linear with no bias added
|
|
data = (a, b, c, weight)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_linear_plus_transpose(self, device):
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
weight = torch.randn(
|
|
2, 2, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, weight, bias=None):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
# This implicitly tests to_padded_tensor grads
|
|
d = torch.functional.F.linear(nt, weight, bias)
|
|
d = d.transpose(-1, -2).contiguous()
|
|
return torch.nested.to_padded_tensor(d, 0)
|
|
|
|
data = (a, b, c, weight, bias)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# Test linear with no bias added
|
|
data = (a, b, c, weight)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_softmax(self, device):
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, dim):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
# This implicitly tests to_padded_tensor grads
|
|
d = torch.functional.F.softmax(nt, dim=dim)
|
|
return torch.nested.to_padded_tensor(d, 0)
|
|
|
|
# softmax over last dim
|
|
data = (a, b, c, -1)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_linear_backward(self, device):
|
|
a = torch.randn(1, 2, requires_grad=False, device=device)
|
|
b = torch.randn(2, 2, requires_grad=False, device=device)
|
|
c = torch.randn(3, 2, requires_grad=False, device=device)
|
|
|
|
weight = torch.randn(2, 2, requires_grad=True, device=device)
|
|
bias = torch.randn(2, requires_grad=True, device=device)
|
|
nt = torch.nested.as_nested_tensor([a, b, c], device=device)
|
|
|
|
out = torch.functional.F.linear(nt, weight, bias)
|
|
|
|
out.backward(out.clone())
|
|
|
|
assert weight.grad is not None
|
|
assert bias.grad is not None
|
|
|
|
assert a.grad is None
|
|
assert b.grad is None
|
|
assert c.grad is None
|
|
|
|
def test_values_grad_with_broadcast(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
buffer = nt.values()
|
|
return buffer.sum()
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_to_buffer_series_ops_grad_with_broadcast(self, device):
|
|
a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
buffer = nt.values()
|
|
buffer = buffer * 2
|
|
return buffer.exp()
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_unbind_flow_through(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
ntT = nt.transpose(-1, -2)
|
|
unbound = ntT.unbind()
|
|
d = unbound[0]
|
|
d = torch.pow(d, 2)
|
|
return d
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_split_with_sizes_flow_through(self, device):
|
|
a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
splits = nt.split_with_sizes([2, 3], dim=-1)
|
|
unbound = splits[1].unbind()
|
|
d = unbound[0]
|
|
d = torch.pow(d, 2)
|
|
return d
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_indexing_backward(self, device):
|
|
x0 = torch.randn((2, 5))
|
|
x1 = torch.randn((3, 4))
|
|
nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True)
|
|
self.assertEqual(nt[0], x0)
|
|
self.assertEqual(nt[-1], x1)
|
|
grad_x0 = torch.randn((2, 5), device=device)
|
|
nt[0].backward(grad_x0)
|
|
expected_grad = torch.nested.nested_tensor(
|
|
[grad_x0, torch.zeros((3, 4), device=device)]
|
|
)
|
|
self.assertEqual(nt.grad, expected_grad)
|
|
|
|
def test_masked_fill_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
mask = nt.detach().clone().to(bool)
|
|
out = nt.masked_fill(mask, 0)
|
|
out = torch.nested.to_padded_tensor(out, 0)
|
|
return out
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_gelu_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
nt_gelu = torch.nn.functional.gelu(nt)
|
|
return torch.nested.to_padded_tensor(nt_gelu, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_relu_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
nt_relu = torch.nn.functional.relu(nt)
|
|
return torch.nested.to_padded_tensor(nt_relu, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_selu_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
            nt_silu = torch.nn.functional.silu(nt)
            return torch.nested.to_padded_tensor(nt_silu, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_abs_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
nt_abs = torch.abs(nt)
|
|
return torch.nested.to_padded_tensor(nt_abs, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# Previously would error when input NT doesn't require grad
|
|
# NotImplementedError: Cannot access storage of UndefinedTensorImpl
|
|
def test_layer_norm_backward_edge_case(self, device):
|
|
size = 4
|
|
a = torch.randn(
|
|
1, 2, size, requires_grad=False, dtype=torch.float64, device=device
|
|
)
|
|
nt = torch.nested.nested_tensor([a])
|
|
nt_layer_norm = torch.nn.LayerNorm(
|
|
nt.size(-1), device=device, dtype=torch.float64
|
|
)
|
|
out = nt_layer_norm(nt)
|
|
out.backward(out.clone())
|
|
|
|
def test_accumulate_grad_different_strides(self, device):
|
|
a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt_1 = torch.nested.as_nested_tensor([a, b])
|
|
nt_2 = nt_1.clone()
|
|
out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2)
|
|
return torch.nested.to_padded_tensor(out, 0)
|
|
|
|
data = (a, b)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/95562
|
|
@skipIfSlowGradcheckEnv
|
|
@parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])
|
|
def test_layer_norm_backward(self, device, size):
|
|
a = torch.randn(
|
|
1, 2, size, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
b = torch.randn(
|
|
2, 2, size, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
c = torch.randn(
|
|
3, 2, size, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
layer_norm = torch.nn.LayerNorm(
|
|
nt.size(-1), device=device, dtype=torch.float64
|
|
)
|
|
nt_layer_norm = layer_norm(nt)
|
|
return torch.nested.to_padded_tensor(nt_layer_norm, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/95562
|
|
@skipIfSlowGradcheckEnv
|
|
# Could either mark slow or reduce size
|
|
@parametrize("size", [128, 32, 4, 2])
|
|
def test_layer_norm_backward_5d(self, device, size):
|
|
a = torch.randn(
|
|
4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
b = torch.randn(
|
|
7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
c = torch.randn(
|
|
10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
layer_norm = torch.nn.LayerNorm(
|
|
(size, size, nt.size(-1)), device=device, dtype=torch.float64
|
|
)
|
|
nt_layer_norm = layer_norm(nt)
|
|
return torch.nested.to_padded_tensor(nt_layer_norm, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
|
|
# Default tolerances found in torch/testing/_comparison.py
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
|
|
|
|
|
|
def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
|
|
deviation = true_value - computed_value
|
|
deviation = torch.abs(deviation / true_value)
|
|
# Fill in the nans with the default rtol
|
|
torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
|
|
return deviation.max().item()
|
|
|
|
|
|
def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
|
|
deviation = true_value - computed_value
|
|
atol = torch.abs(deviation).max().item()
|
|
return atol
|
|
|
|
|
|
def get_tolerances(
|
|
true_value: torch.Tensor,
|
|
computed_value: torch.Tensor,
|
|
fudge_factor: Optional[float] = None,
|
|
) -> tuple[float, float]:
|
|
"""Returns the absolute and relative tolerances for comparing two tensors."""
|
|
fudge_factor = fudge_factor if fudge_factor is not None else 1.0
|
|
atol = get_atol(true_value, computed_value)
|
|
rtol = get_rtol(true_value, computed_value)
|
|
|
|
atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
|
|
rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
|
|
    # torch.isclose() behaves unexpectedly for very large rtol values; see:
    # https://github.com/pytorch/pytorch/issues/102400
|
|
if rtol > 1e30:
|
|
rtol = default_rtol[computed_value.dtype]
|
|
return atol, rtol
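
# Hypothetical usage sketch for get_tolerances() (ref_out / actual_out are
# illustrative names, not defined in this file):
#   atol, rtol = get_tolerances(ref_out, actual_out, fudge_factor=2.0)
#   torch.testing.assert_close(actual_out, ref_out, atol=atol, rtol=rtol)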
|
|
|
|
|
|
# We can probably parametrize existing tests instead of having a separate
# test class as we begin to support more ops. Also, maybe rewrite with OpInfos.
|
|
@markDynamoStrictTest
|
|
class TestNestedTensorSubclass(NestedTensorTestCase):
|
|
# TODO: consolidate with the below
|
|
def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
|
|
Ds = nested_size[1:]
|
|
out = []
|
|
for s in nested_size[0]:
|
|
out.append(
|
|
torch.randn(
|
|
s,
|
|
*Ds,
|
|
requires_grad=requires_grad,
|
|
device=device,
|
|
dtype=torch.float64,
|
|
)
|
|
)
|
|
return out
|
|
|
|
def _get_example_tensor_lists(
|
|
self,
|
|
include_list_of_lists=True,
|
|
include_requires_grad=True,
|
|
include_inner_dim_size_1=False,
|
|
include_2d_tensor=False,
|
|
):
|
|
def _make_tensor(
|
|
*shape, include_requires_grad=include_requires_grad, requires_grad=True
|
|
):
|
|
return torch.randn(
|
|
*shape,
|
|
requires_grad=(requires_grad if include_requires_grad else False),
|
|
)
|
|
|
|
# Purposefully introduce mixed requires_grad settings for the components
|
|
# when include_requires_grad=True.
|
|
example_lists = [
|
|
# (B, *, D) with B=4
|
|
[
|
|
_make_tensor(2, 5),
|
|
_make_tensor(3, 5, requires_grad=False),
|
|
_make_tensor(4, 5, requires_grad=False),
|
|
_make_tensor(6, 5),
|
|
],
|
|
# (B, *, D_0, D_1) with B=5
|
|
[
|
|
_make_tensor(2, 5, 6),
|
|
_make_tensor(3, 5, 6),
|
|
_make_tensor(4, 5, 6, requires_grad=False),
|
|
_make_tensor(5, 5, 6),
|
|
_make_tensor(6, 5, 6),
|
|
],
|
|
# (B, *, D_0, D_1, D_2) with B=6
|
|
[
|
|
_make_tensor(2, 5, 6, 7),
|
|
_make_tensor(3, 5, 6, 7),
|
|
_make_tensor(4, 5, 6, 7, requires_grad=False),
|
|
_make_tensor(5, 5, 6, 7),
|
|
_make_tensor(6, 5, 6, 7),
|
|
_make_tensor(7, 5, 6, 7),
|
|
],
|
|
]
|
|
|
|
if include_list_of_lists:
|
|
example_lists.append(
|
|
# (B, *, D) with B=3 in list form
|
|
[
|
|
_make_tensor(2, 5, requires_grad=False).tolist(),
|
|
_make_tensor(3, 5).tolist(),
|
|
_make_tensor(4, 5).tolist(),
|
|
]
|
|
)
|
|
|
|
if include_inner_dim_size_1:
|
|
example_lists.append(
|
|
[
|
|
_make_tensor(2, 1),
|
|
_make_tensor(3, 1, requires_grad=False),
|
|
_make_tensor(4, 1, requires_grad=False),
|
|
_make_tensor(6, 1),
|
|
] # (B, *, 1)
|
|
)
|
|
example_lists.append(
|
|
[
|
|
_make_tensor(2, 5, 1),
|
|
_make_tensor(3, 5, 1, requires_grad=False),
|
|
_make_tensor(4, 5, 1, requires_grad=False),
|
|
_make_tensor(6, 5, 1),
|
|
] # (B, *, 5, 1)
|
|
)
|
|
|
|
if include_2d_tensor:
|
|
example_lists.append(
|
|
[
|
|
_make_tensor(2),
|
|
_make_tensor(3, requires_grad=False),
|
|
_make_tensor(4, requires_grad=False),
|
|
_make_tensor(6),
|
|
] # (B, *)
|
|
)
|
|
|
|
return example_lists
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize(
|
|
"contiguity",
|
|
["contig", "noncontig_transposed", "noncontig_with_holes"],
|
|
name_fn=lambda c: c,
|
|
)
|
|
@parametrize("weights_only", [True, False])
|
|
def test_serialization(self, device, dtype, contiguity, weights_only):
|
|
# Test with 3 cases:
|
|
# 1. contiguous
|
|
# 2. non-contiguous transposed
|
|
# 3. non-contiguous with holes
|
|
if contiguity == "contig":
|
|
nt = random_nt_from_dims(
|
|
[4, None, 10],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
)
|
|
elif contiguity == "noncontig_transposed":
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
).transpose(-3, -2)
|
|
elif contiguity == "noncontig_with_holes":
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values=torch.randn(10, 3, device=device, dtype=dtype),
|
|
offsets=torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int64),
|
|
# these lengths specify holes
|
|
lengths=torch.tensor([1, 2, 3], device=device, dtype=torch.int64),
|
|
)
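            # With offsets [0, 3, 7, 10] but lengths [1, 2, 3], components can use
            # fewer rows of `values` than their offsets range allows, leaving unused
            # rows ("holes") in the backing buffer.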
|
|
else:
|
|
raise ValueError("invalid contiguity specified for test_serialization()")
|
|
|
|
# Access sizes / strides to ensure cache doesn't break serialization.
|
|
# See https://github.com/pytorch/pytorch/issues/129366
|
|
nt.size()
|
|
nt.stride()
|
|
|
|
with tempfile.TemporaryFile() as f:
|
|
torch.save(nt, f)
|
|
f.seek(0)
|
|
nt_loaded = torch.load(f, weights_only=weights_only)
|
|
|
|
self.assertIsNot(nt, nt_loaded)
|
|
# we expect a new offsets tensor -> different nested int upon load
|
|
self.assertEqualIgnoringNestedInts(nt, nt_loaded)
|
|
self.assertEqual(nt._ragged_idx, nt_loaded._ragged_idx)
|
|
# ensure shapes are equal except nested int
|
|
nt_rest_of_shape = (
|
|
*nt.shape[: nt._ragged_idx],
|
|
*nt.shape[nt._ragged_idx + 1 :],
|
|
)
|
|
nt_loaded_rest_of_shape = (
|
|
*nt_loaded.shape[: nt_loaded._ragged_idx],
|
|
*nt_loaded.shape[nt_loaded._ragged_idx + 1 :],
|
|
)
|
|
self.assertEqual(nt_rest_of_shape, nt_loaded_rest_of_shape)
|
|
# ensure metadata cache is carried through serialization
|
|
self.assertEqual(nt._metadata_cache, nt_loaded._metadata_cache)
|
|
# ensure lengths are carried through if present
|
|
self.assertEqual(nt._lengths, nt_loaded._lengths)
|
|
|
|
def test_tensor_attributes(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
_offsets = nt.offsets()
|
|
|
|
for op in (
|
|
torch.ops.aten.is_non_overlapping_and_dense.default,
|
|
torch.ops.aten.sym_size.default,
|
|
torch.ops.aten.dim.default,
|
|
torch.ops.aten.numel.default,
|
|
torch.ops.aten.sym_numel.default,
|
|
torch.ops.aten.sym_stride.default,
|
|
torch.ops.aten.sym_storage_offset.default,
|
|
):
|
|
op(nt)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "directly calling torch.ops.aten.size"
|
|
):
|
|
torch.ops.aten.size.default(nt)
|
|
|
|
nested_int = torch.nested._internal.nested_tensor.get_tensor_symint(
|
|
_offsets, coeff=1
|
|
)
|
|
self.assertEqual(nt.size(), (3, nested_int, 3))
|
|
self.assertEqual(nt.shape, (3, nested_int, 3))
|
|
self.assertEqual(nt.dim(), 3)
|
|
self.assertEqual(nt.numel(), 27)
|
|
|
|
@parametrize("nt_dim", [3, 4, 5])
|
|
def test_linear(self, device, nt_dim):
|
|
if nt_dim == 3:
|
|
fixed_shape = (3,)
|
|
elif nt_dim == 4:
|
|
fixed_shape = (4, 3)
|
|
elif nt_dim == 5:
|
|
fixed_shape = (5, 4, 3)
|
|
|
|
a = torch.randn(
|
|
2, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
b = torch.randn(
|
|
3, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
c = torch.randn(
|
|
4, *fixed_shape, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
weight = torch.randn(
|
|
4, 3, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
bias = torch.randn(4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, weight, bias):
|
|
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
out = torch.nn.functional.linear(nt, weight, bias)
|
|
return out.values()
|
|
|
|
gradcheck(
|
|
grad_test_func, inputs=(a, b, c, weight, bias), check_batched_grad=False
|
|
)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float32)
|
|
@serialTest()
|
|
def test_linear_backward_memory_usage(self, device, dtype):
|
|
# Verify that linear_backward() doesn't use more memory than it should
|
|
# for higher dim input sizes.
|
|
# See https://github.com/pytorch/pytorch/issues/141112
|
|
B, D, max_seq_len = 64, 512, 100
|
|
torch._C._cuda_clearCublasWorkspaces()
|
|
m = torch.nn.Linear(D, D, device=device)
|
|
nt = torch.nested.as_nested_tensor(
|
|
[
|
|
torch.rand(size=[seq_len, D])
|
|
for seq_len in torch.randint(max_seq_len, size=(B,))
|
|
],
|
|
layout=torch.jagged,
|
|
device=device,
|
|
)
|
|
|
|
# (B, j1, D) -> (B, j1, 1, D) for a higher dim input size
|
|
nt = nt.unsqueeze(-2)
|
|
# linear_backward() should not explode the max memory usage
|
|
torch.cuda.reset_max_memory_allocated()
|
|
m(nt).sum().backward()
|
|
# expect under a GB for max memory allocated
|
|
max_after_gb = torch.cuda.max_memory_allocated(0) // (1024**3)
|
|
self.assertEqual(max_after_gb, 0)
|
|
|
|
def test_unary_pointwise(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
out = torch.nn.functional.silu(nt.sin().cos())
|
|
return out.values()
|
|
|
|
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
|
|
|
|
def test_unary_pointwise_transposed_inputs(self, device):
|
|
a, b, c = (
|
|
torch.randn(
|
|
i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
for i in range(3)
|
|
)
|
|
|
|
nt = torch.nested.nested_tensor(
|
|
[a.detach(), b.detach(), c.detach()], layout=torch.jagged
|
|
)
|
|
nt_t = nt.transpose(1, 2)
|
|
self.assertFalse(nt_t.is_contiguous())
|
|
out = torch.nn.functional.silu(nt_t.sin().cos())
|
|
self.assertEqual(
|
|
out.is_contiguous(),
|
|
torch.nn.functional.silu(b.transpose(-1, -2).sin().cos()).is_contiguous(),
|
|
)
|
|
|
|
self.assertEqual(nt_t.shape, out.shape)
|
|
|
|
a, b, c = (
|
|
torch.randn(
|
|
i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
for i in range(3)
|
|
)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
nt_t = nt.transpose(1, 2)
|
|
out = torch.nn.functional.silu(nt_t.sin().cos())
|
|
return out.values()
|
|
|
|
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
|
|
|
|
def test_binary_pointwise(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
        # Incorrect usage: the shape check will fail if the offsets tensors are not
        # the same exact tensor object
|
|
nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
nt2 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"cannot call binary pointwise function .* with inputs of shapes",
|
|
lambda: nt1 * nt2,
|
|
)
|
|
|
|
# Correct usage: chain the calls using the same offsets tensor object
|
|
def grad_test_func(a, b, c):
|
|
nt1 = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
# TODO: Switch to public API that takes in (values, offsets) once it exists
|
|
nt2, offsets = jagged_from_list([a, b, c], nt1.offsets())
|
|
out = nt1 * nt2
|
|
return out.values()
|
|
|
|
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
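
        # Note (sketch, not exercised here): the public API can express the same
        # shared-offsets pairing, e.g. with hypothetical values1/values2 buffers:
        #   nt1 = torch.nested.nested_tensor_from_jagged(values1, offsets)
        #   nt2 = torch.nested.nested_tensor_from_jagged(values2, offsets)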
|
|
|
|
def test_binary_pointwise_transposed(self, device):
|
|
a, b, c = (
|
|
torch.randn(i + 2, 5, dtype=torch.float64, device=device) for i in range(3)
|
|
)
|
|
|
|
nt1, offsets = jagged_from_list([a, b, c], None)
|
|
nt2, offsets = jagged_from_list([a, b, c], offsets)
|
|
|
|
nt1_t = nt1.transpose(1, 2)
|
|
nt2_t = nt2.transpose(1, 2)
|
|
|
|
# out = nt1_t * nt2_t
|
|
# self.assertFalse(nt1_t.is_contiguous())
|
|
# self.assertEqual(out.is_contiguous(), (b.transpose(-1, -2) * b.transpose(-1, -2)).is_contiguous())
|
|
# self.assertEqual(out.shape, nt1_t.shape)
|
|
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"cannot call binary pointwise function mul.Tensor with inputs of shapes",
|
|
lambda: nt1 * nt2_t,
|
|
)
|
|
|
|
a, b, c = (
|
|
torch.randn(
|
|
i + 2, 5, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
for i in range(3)
|
|
)
|
|
|
|
# Correct usage: chain the calls using the same offsets tensor object
|
|
def grad_test_func(a, b, c):
|
|
nt1, offsets = jagged_from_list([a, b, c], None)
|
|
nt2, offsets = jagged_from_list([a, b, c], offsets)
|
|
nt1_t = nt1.transpose(1, 2)
|
|
nt2_t = nt2.transpose(1, 2)
|
|
out = nt1_t * nt2_t
|
|
return out.values()
|
|
|
|
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
|
|
|
|
def test_binary_pointwise_with_nested_int_second_arg(self, device):
|
|
# See https://github.com/pytorch/pytorch/issues/138496
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5],
|
|
device=device,
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "invalid argument"):
|
|
nt * nt.size(1)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "invalid argument"):
|
|
nt + nt.size(1)
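            # nt.size(1) is a nested int (a symbolic ragged length), so using it as
            # a plain scalar operand is rejected rather than silently broadcast.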
|
|
|
|
def test_split(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
out = torch.split(nt, 2, -1)
|
|
self.assertEqual(len(out), 2)
|
|
self.assertEqualIgnoringNestedInts(
|
|
out[0],
|
|
torch.nested.as_nested_tensor(
|
|
[a[:, 0:2], b[:, 0:2], c[:, 0:2]], layout=torch.jagged
|
|
),
|
|
)
|
|
self.assertEqualIgnoringNestedInts(
|
|
out[1],
|
|
torch.nested.as_nested_tensor(
|
|
[a[:, 2:], b[:, 2:], c[:, 2:]], layout=torch.jagged
|
|
),
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"split\(\): not supported for NestedTensor on ragged dim",
|
|
):
|
|
torch.split(nt, 2, 1)
|
|
|
|
def test_split_with_sizes(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
nt = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
out = torch.split(nt, [1, 2], -1)
|
|
self.assertEqual(len(out), 2)
|
|
self.assertEqualIgnoringNestedInts(
|
|
out[0],
|
|
torch.nested.as_nested_tensor(
|
|
[a[:, 0:1], b[:, 0:1], c[:, 0:1]], layout=torch.jagged
|
|
),
|
|
)
|
|
self.assertEqualIgnoringNestedInts(
|
|
out[1],
|
|
torch.nested.as_nested_tensor(
|
|
[a[:, 1:], b[:, 1:], c[:, 1:]], layout=torch.jagged
|
|
),
|
|
)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"split_with_sizes\(\): not supported for NestedTensor on ragged dim",
|
|
):
|
|
torch.split(nt, [1, 2], 1)
|
|
|
|
def test_softmax(self, device):
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5],
|
|
device=device,
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
requires_grad=True,
|
|
)
|
|
|
|
# operate on dim=2
|
|
output = nt.softmax(dim=2)
|
|
|
|
@torch._dynamo.disable
|
|
def _compare_to_ref(nt, output, dim):
|
|
for in_component, out_component in zip(nt.unbind(), output.unbind()):
|
|
self.assertEqual(in_component.softmax(dim=dim), out_component)
|
|
|
|
# dim=2 -> dim=1 after unbind
|
|
_compare_to_ref(nt, output, dim=1)
|
|
|
|
# operate on dim=-1
|
|
output2 = nt.softmax(dim=-1)
|
|
torch._dynamo.disable(self.assertEqual)(output, output2)
|
|
_compare_to_ref(nt, output2, dim=-1)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b], layout=torch.jagged)
|
|
out = nt.softmax(dim=-1)
|
|
return out.values()
|
|
|
|
a = torch.rand(4, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.rand(8, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
gradcheck(grad_test_func, inputs=(a, b), check_batched_grad=False)
|
|
|
|
def test_views_inherit_ragged_dim(self, device):
|
|
# view
|
|
nt = random_nt_from_dims(
|
|
[4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
# inherit ragged dim via -1
|
|
view = nt.view(4, -1, 80)
|
|
self.assertEqual(nt.shape[1], view.shape[1])
|
|
# inherit batch and ragged dims via -1
|
|
view2 = nt.view(-1, -1, 80)
|
|
self.assertEqual(nt.shape[:2], view2.shape[:2])
|
|
|
|
# expand
|
|
nt = random_nt_from_dims(
|
|
[3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
# inherit batch and ragged dims via -1
|
|
view = nt.expand(-1, -1, 5)
|
|
self.assertEqual(nt.shape[:2], view.shape[:2])
|
|
|
|
def test_view_ragged_idx_not_one(self, device):
|
|
nt = random_nt_from_dims(
|
|
[2, None, 20], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
|
|
view_transposed = nt.transpose(1, 2).view(2, 20, nt.size(1))
|
|
self.assertEqual((2, 20, nt.size(1)), (view_transposed.size()))
|
|
self.assertEqual(view_transposed._base, nt._base)
|
|
|
|
def test_unsafe_view(self, device):
|
|
nt = random_nt_from_dims(
|
|
[4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
# basic view
|
|
view1 = torch.ops.aten._unsafe_view(nt, (4, -1, 80))
|
|
self.assertEqual((4, nt.size(1), 80), tuple(view1.size()))
|
|
# _unsafe_view differs from view in that the view information is not tracked
|
|
self.assertTrue(view1._base is None)
|
|
|
|
        # Test _unsafe_view when ragged_idx != 1; currently only the identity view is supported
|
|
nt_t = nt.transpose(1, 2)
|
|
view2 = torch.ops.aten._unsafe_view(nt_t, (4, 8, nt.size(1), 10))
|
|
self.assertEqual((4, 8, nt.size(1), 10), tuple(view2.size()))
|
|
self.assertTrue(view2._base is None)
|
|
|
|
@xfailIfTorchDynamo
|
|
@parametrize("requires_grad", [False, True])
|
|
def test_reshape_decomp(self, device, requires_grad):
|
|
# contiguous NT should result in view.
|
|
nt = (
|
|
random_nt_from_dims(
|
|
[3, None, 10],
|
|
device=device,
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
)
|
|
.detach()
|
|
.requires_grad_(requires_grad)
|
|
)
|
|
view = nt.reshape(-1, -1, 5, 2)
|
|
self.assertEqual(view.shape[:2], nt.shape[:2])
|
|
self.assertTrue(view._is_view() and view._base is nt)
|
|
# make sure gradients flow back
|
|
if requires_grad:
|
|
view.backward(torch.ones_like(view))
|
|
self.assertEqual(nt.grad, torch.ones_like(nt))
|
|
|
|
# non-contiguous NT should result in contiguous copy
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2],
|
|
device=device,
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
nt_noncontig = nt.transpose(-1, -2)
|
|
self.assertFalse(nt_noncontig.is_contiguous())
|
|
copy = nt_noncontig.reshape(-1, -1, 10)
|
|
self.assertTrue(copy.is_contiguous())
|
|
self.assertEqual(copy.shape[:2], nt.shape[:2])
|
|
# make sure gradients flow back
|
|
if requires_grad:
|
|
copy.backward(torch.ones_like(copy))
|
|
self.assertEqual(nt.grad, torch.ones_like(nt))
|
|
|
|
def test_flatten_decomp(self, device):
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
flattened = nt.flatten(-2, -1)
|
|
self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape)
|
|
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
flattened = nt.flatten(-3, -2)
|
|
self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape)
|
|
|
|
def test_chunk(self, device):
|
|
        # non-NJT case
|
|
t = torch.randn(10, 4, 5, requires_grad=True)
|
|
t_list = t.chunk(3, dim=0)
|
|
loss = t_list[0].sum() + t_list[2].sum()
|
|
loss.backward()
|
|
|
|
# normal case
|
|
D = 30
|
|
B = 8
|
|
nt = random_nt_from_dims(
|
|
[B, None, D],
|
|
device=device,
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
requires_grad=True,
|
|
)
|
|
NUM_CHUNKS = 3
|
|
chunks = nt.chunk(NUM_CHUNKS, dim=-1)
|
|
self.assertEqual(len(chunks), NUM_CHUNKS)
|
|
for i in range(NUM_CHUNKS):
|
|
self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS)
|
|
|
|
# test chunk_backward
|
|
values = torch.randn(
|
|
5, 11, dtype=torch.float64, device=device, requires_grad=True
|
|
)
|
|
offsets = torch.tensor([0, 2, 3, 5], device=device)
|
|
|
|
def grad_test_func(values, offsets):
|
|
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
|
|
chunks = nt.chunk(3, dim=-1)
|
|
return chunks[0].values().sum()
|
|
|
|
assert gradcheck(
|
|
grad_test_func,
|
|
inputs=(values, offsets),
|
|
check_batched_grad=False,
|
|
)
|
|
|
|
# chunk on batch dim
|
|
chunks = nt.chunk(NUM_CHUNKS, dim=0)
|
|
self.assertEqual(len(chunks), NUM_CHUNKS)
|
|
chunk_size = math.ceil(B / NUM_CHUNKS)
|
|
for i in range(NUM_CHUNKS):
|
|
if i < NUM_CHUNKS - 1:
|
|
self.assertEqual(chunks[i].shape[0], chunk_size)
|
|
else:
|
|
self.assertEqual(chunks[i].shape[0], B - chunk_size * (NUM_CHUNKS - 1))
|
|
offsets_expected = (
|
|
nt._offsets[i * chunk_size + 1 : (i + 1) * chunk_size + 1]
|
|
- nt._offsets[i * chunk_size]
|
|
)
|
|
self.assertEqual(chunks[i]._offsets[1:], offsets_expected)
|
|
self.assertEqual(nt._values, torch.cat([x._values for x in chunks], dim=0))
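
        # Rationale: each batch-dim chunk reuses the corresponding slice of the
        # original offsets, shifted so that its first offset is 0, and the values
        # buffer is shared, so concatenating the chunks' values recovers nt._values.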
|
|
|
|
        # chunk (dim=0) doesn't support backward directly yet; backprop through values() instead
|
|
loss = (
|
|
chunks[0].values().sum()
|
|
+ chunks[1].values().sum()
|
|
+ chunks[2].values().sum()
|
|
)
|
|
loss.backward()
|
|
|
|
# chunk on ragged dim not supported
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "chunk.* not supported for NestedTensor on ragged dim"
|
|
):
|
|
nt.chunk(2, dim=1)
|
|
|
|
def test_squeeze(self, device):
|
|
B = 4
|
|
D = 6
|
|
# squeeze middle dim
|
|
nt = random_nt_from_dims(
|
|
[B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
j0 = nt.shape[1]
|
|
|
|
for dim_arg in [-2, 2]:
|
|
out = nt.squeeze(dim_arg)
|
|
self.assertEqual(out.shape, (B, j0, D))
|
|
self.assertEqual(out.unsqueeze(-2), nt)
|
|
|
|
# squeeze last dim
|
|
nt = random_nt_from_dims(
|
|
[B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged
|
|
)
|
|
j1 = nt.shape[1]
|
|
|
|
for dim_arg in [-1, 2]:
|
|
out = nt.squeeze(dim_arg)
|
|
self.assertEqual(out.shape, (B, j1))
|
|
self.assertEqual(out.unsqueeze(-1), nt)
|
|
|
|
# squeeze on batch dim not supported
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "squeeze.* not supported for NestedTensor on dim=0"
|
|
):
|
|
nt.squeeze(0)
|
|
|
|
# squeeze on ragged dim not supported
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "squeeze.* not supported for NestedTensor on ragged dim"
|
|
):
|
|
nt.squeeze(1)
|
|
|
|
def test_binary_pointwise_broadcasting(self, device):
|
|
# (B, j0, 3, 4)
|
|
ts = self._get_list_for_jagged_tensor(
|
|
((2, 3, 4), 3, 4), device, requires_grad=True
|
|
)
|
|
# (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
|
|
# (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
|
|
# (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
|
|
# Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
|
|
t_sizes = (
|
|
(4,),
|
|
(1, 4),
|
|
(3, 1),
|
|
(1, 3, 1),
|
|
(1, 1, 1, 4),
|
|
# (1, 1, 1, 1, 4), (unsupported today)
|
|
)
|
|
|
|
def grad_test_func(t, *ts):
|
|
nt = torch.nested.as_nested_tensor(list(ts), layout=torch.jagged)
|
|
out = nt + t
|
|
return out.values()
|
|
|
|
for t_size in t_sizes:
|
|
t = torch.rand(
|
|
t_size, requires_grad=True, device=device, dtype=torch.float64
|
|
)
|
|
gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
|
|
|
|
def test_threshold_backward(self, device):
|
|
ts1 = self._get_list_for_jagged_tensor(
|
|
((2, 3, 4), 16), device=device, requires_grad=False
|
|
)
|
|
ts2 = self._get_list_for_jagged_tensor(
|
|
((2, 3, 4), 16), device=device, requires_grad=False
|
|
)
|
|
|
|
nt1, offsets = jagged_from_list(ts1, None)
|
|
nt2, offsets = jagged_from_list(ts2, offsets)
|
|
buf1 = nt1.values().detach().clone()
|
|
buf2 = nt2.values().detach().clone()
|
|
|
|
res_nt = torch.ops.aten.threshold_backward(nt1, nt2, 0.0)
|
|
res_dense = torch.ops.aten.threshold_backward(buf1, buf2, 0.0)
|
|
|
|
self.assertEqual(res_dense, res_nt.values())
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float32)
|
|
def test_record_stream(self, device, dtype):
|
|
def _create_nt():
|
|
values = torch.ones(1024, 4 * 1024, device="cuda")
|
|
offsets = torch.tensor([0, 500, 1024], device="cuda", dtype=torch.int64)
|
|
lengths = offsets.diff()
|
|
nt = torch.nested.nested_tensor_from_jagged(values, offsets, lengths)
|
|
data_ptrs = {
|
|
nt._values.data_ptr(),
|
|
nt._offsets.data_ptr(),
|
|
nt._lengths.data_ptr(),
|
|
}
|
|
return nt, data_ptrs
|
|
|
|
def fn(record_stream):
|
|
nt, data_ptrs = _create_nt()
|
|
s = torch.cuda.Stream()
|
|
|
|
with torch.cuda.stream(s):
|
|
# emulate doing something long via sleep
|
|
per_ms = 2e7
|
|
torch.cuda._sleep(int(per_ms * 100))
|
|
if record_stream:
|
|
nt.record_stream(s)
|
|
return data_ptrs
|
|
|
|
# expect memory reuse when record_stream() is not run
|
|
data_ptrs = fn(record_stream=False)
|
|
nt, nt_data_ptrs = _create_nt()
|
|
self.assertEqual(data_ptrs, nt_data_ptrs)
|
|
del nt
|
|
torch.cuda.synchronize()
|
|
|
|
# expect memory to be preserved (no reuse) when record_stream() is run
|
|
data_ptrs = fn(record_stream=True)
|
|
nt, nt_data_ptrs = _create_nt()
|
|
self.assertEqual(len(data_ptrs.intersection(nt_data_ptrs)), 0)
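
        # record_stream() marks the NJT's backing allocations as in use on the side
        # stream, so the caching allocator holds onto them until that stream's work
        # finishes instead of immediately reusing the freed blocks (hence no data
        # pointer overlap above).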
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize(
|
|
"func",
|
|
[torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
|
|
name_fn=get_op_name,
|
|
)
|
|
@parametrize("keepdim", [False, True])
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_jagged_op_different_output_shape_dim(
|
|
self, device, dtype, keepdim, requires_grad, components_require_grad, func
|
|
):
|
|
"""
|
|
Operator passes when reducing on valid reduction dimensions.
|
|
This test is for operators which return an output tensor with a shape different from the input tensor.
|
|
"""
|
|
if get_op_name(func) == "mean" and not keepdim:
|
|
return
|
|
|
|
op_name = get_op_name(func)
|
|
|
|
ts = self._get_list_for_jagged_tensor(
|
|
((2, 3, 4), 3, 4), device=device, requires_grad=True
|
|
) # (B, j0, 3, 4)
|
|
|
|
# verify correctness of shapes (assuming that ragged_idx == 1)
|
|
if op_name == "sum":
|
|
reduce_dims = (
|
|
((0, 1), (3, 4), (1, 1, 3, 4), (0,)), # batch, ragged
|
|
((2, 3), (3, None), (3, None, 1, 1), (1, 2)), # non-batch, non-batch
|
|
((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)), # batch, ragged, non-batch
|
|
((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)), # batch, ragged, non-batch
|
|
(
|
|
(0, 1, 2, 3),
|
|
(),
|
|
(1, 1, 1, 1),
|
|
(0, 1, 2),
|
|
), # batch, ragged, non-batch, non-batch
|
|
((2,), (3, None, 4), (3, None, 1, 4), (1,)), # non-batch
|
|
) # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
|
|
elif op_name == "mean":
|
|
reduce_dims = (
|
|
((2,), (3, None, 4), (3, None, 1, 4), (1,)),
|
|
((3,), (3, None, 3), (3, None, 3, 1), (2,)),
|
|
)
|
|
|
|
for rd, ref_shape_no_keepdim, ref_shape_keepdim, _ in reduce_dims:
|
|
nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
|
|
out = func(nt, dim=rd, keepdim=keepdim)
|
|
ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
|
|
if not torch.compiler.is_compiling(): # if not using torch dynamo
|
|
self.assertEqual(len(out.shape), len(ref_shape))
|
|
for o, r in zip(out.shape, ref_shape):
|
|
if r is not None:
|
|
self.assertEqual(o, r)
|
|
else:
|
|
self.assertTrue(isinstance(o, torch.SymInt))
|
|
|
|
# verify correctness of values
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True,
|
|
)
|
|
for tensor_list, reduce_dim_tuple in itertools.product(
|
|
tensor_lists, reduce_dims
|
|
):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
|
|
|
|
if nt.dim() > reduce_dim[-1]:
|
|
out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
|
|
if nt._ragged_idx in reduce_dim: # raggedness reduced away
|
|
out_expected = func(
|
|
nt.values(), dim=reduce_dim_expected, keepdim=keepdim
|
|
)
|
|
self.assertTrue(torch.allclose(out_actual, out_expected))
|
|
else: # raggedness preserved
|
|
out_expected = func(nt.values(), dim=reduce_dim_expected)
|
|
self.assertTrue(
|
|
torch.allclose(
|
|
out_actual.values().view(-1), out_expected.view(-1)
|
|
)
|
|
)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
@parametrize(
|
|
"func",
|
|
[torch.nn.functional.softmax, torch.nn.functional.log_softmax],
|
|
name_fn=lambda func: func.__name__,
|
|
)
|
|
def test_softmax_dim(
|
|
self,
|
|
device,
|
|
dtype,
|
|
requires_grad,
|
|
components_require_grad,
|
|
func,
|
|
):
|
|
"""
|
|
Softmax passes when reducing on valid reduction dimensions.
|
|
"""
|
|
ts = self._get_list_for_jagged_tensor(
|
|
((2, 3, 4), 3, 4), device=device, requires_grad=True
|
|
) # (B, j0, 3, 4)
|
|
|
|
output_shape = (3, None, 3, 4)
|
|
|
|
# verify correctness of shapes (assuming that ragged_idx == 1)
|
|
reduce_dims = (
|
|
(2, 1),
|
|
(3, 2),
|
|
) # (reduction dimension, effective reduction dimension for baseline)
|
|
|
|
for reduce_dim, _ in reduce_dims:
|
|
nt = torch.nested.as_nested_tensor(ts, layout=torch.jagged)
|
|
out_actual = func(nt, dim=reduce_dim)
|
|
torch._dynamo.disable(self.assertEqual)(
|
|
len(out_actual.shape), len(output_shape)
|
|
) # disable if running on dynamo
|
|
for dim_actual, dim_expected in zip(out_actual.shape, output_shape):
|
|
if dim_expected is not None:
|
|
self.assertEqual(dim_actual, dim_expected)
|
|
else:
|
|
self.assertTrue(isinstance(dim_actual, torch.SymInt))
|
|
|
|
# verify correctness of values
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True,
|
|
)
|
|
for tensor_list, reduce_dim_tuple in itertools.product(
|
|
tensor_lists, reduce_dims
|
|
):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
reduce_dim, reduce_dim_expected = reduce_dim_tuple
|
|
|
|
if nt.dim() > reduce_dim:
|
|
# nested tensor
|
|
out_actual = func(nt, dim=reduce_dim)
|
|
                # dense tensor with one fewer dimension than out_actual
|
|
out_expected = func(nt.values(), dim=reduce_dim_expected)
|
|
self.assertTrue(
|
|
torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
|
|
)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize(
|
|
"func",
|
|
[torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
|
|
name_fn=get_op_name,
|
|
)
|
|
@parametrize("keepdim", [False, True])
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_op_dim_reduce_ragged_idx_1_different_output_shape(
|
|
self, device, dtype, keepdim, requires_grad, components_require_grad, func
|
|
):
|
|
"""
|
|
Operator on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1.
|
|
This test is for operators which return an output tensor with a shape different from the input tensor.
|
|
"""
|
|
if get_op_name(func) == "mean" and not keepdim:
|
|
return
|
|
|
|
op_name = get_op_name(func)
|
|
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
)
|
|
reduce_dim = (1,) # ragged
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
|
|
out_expected = torch.cat(
|
|
[func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0) for t in nt.unbind()]
|
|
)
|
|
if keepdim:
|
|
out_expected = out_expected.unsqueeze(reduce_dim[0])
|
|
|
|
self.assertFalse(
|
|
out_actual.is_nested,
|
|
f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
|
|
) # output is a dense tensor
|
|
self.assertEqual(out_actual, out_expected)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_softmax_dim_reduce_ragged_idx_1(
|
|
self, device, dtype, requires_grad, components_require_grad
|
|
):
|
|
"""
|
|
Softmax on NestedTensor passes when trying to reduce across ragged dimension, where ragged_idx == 1.
|
|
"""
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
include_2d_tensor=True, # (B, *)
|
|
)
|
|
reduce_dim = 1 # ragged
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
out_actual = torch.nn.functional.softmax(nt, dim=reduce_dim)
|
|
out_expected = torch.cat(
|
|
[
|
|
torch.nn.functional.softmax(t, dim=reduce_dim - 1)
|
|
for t in nt.unbind()
|
|
]
|
|
)
|
|
|
|
self.assertTrue(
|
|
out_actual.is_nested,
|
|
"softmax(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
|
|
) # output is a nested tensor
|
|
self.assertTrue(torch.allclose(out_actual.values(), out_expected))
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
@parametrize(
|
|
"func",
|
|
[torch.nn.functional.softmax, torch.nn.functional.log_softmax],
|
|
name_fn=lambda func: func.__name__,
|
|
)
|
|
def test_softmax_reduce_batch_dim(
|
|
self, device, dtype, requires_grad, components_require_grad, func
|
|
):
|
|
"""
|
|
Softmax on NestedTensor fails when trying to reduce across batch dimension.
|
|
"""
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
)
|
|
reduce_dim = 0 # batch
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"not supported when reducing across the batch dimension for NestedTensor",
|
|
):
|
|
out = func(nt, dim=reduce_dim)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_layer_norm_reduce_ragged_idx_1(
|
|
self, device, dtype, requires_grad, components_require_grad
|
|
):
|
|
"""
|
|
Layer normalization on NestedTensor passes when trying to normalize across ragged dimension, where ragged_idx == 1.
|
|
"""
|
|
|
|
# requires_grad = False does not currently work with dynamo tests and throws this error:
|
|
# AssertionError: SymInts must use SymNodeVariable.
|
|
# If the underlying value is static, we will create a ConstantVariable and specialize.
|
|
if torch._dynamo.is_compiling() and not requires_grad:
|
|
return
|
|
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
)
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
if (
|
|
nt.dim() >= 3
|
|
): # layer norm only works for tensors with 3 or more dimensions
|
|
normalized_shape = nt.shape[nt._ragged_idx :]
|
|
|
|
out_actual = torch.nn.functional.layer_norm(
|
|
nt, normalized_shape=normalized_shape
|
|
)
|
|
out_expected = torch.cat(
|
|
[
|
|
torch.nn.functional.layer_norm(t, normalized_shape=t.shape)
|
|
for t in nt.unbind()
|
|
]
|
|
) # e.g. in 3D tensor (B, *, M), performs layer normalization on B 2D tensors (*, M)
|
|
|
|
self.assertTrue(
|
|
out_actual.is_nested,
|
|
"layer_norm(): the result of reducing a nested tensor along the ragged dimension is a nested tensor",
|
|
) # output is a nested tensor
|
|
self.assertEqual(out_actual._values.shape, out_expected.shape)
|
|
self.assertTrue(torch.allclose(out_actual.values(), out_expected))
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_layer_norm_2d_input(
|
|
self,
|
|
device,
|
|
dtype,
|
|
requires_grad,
|
|
components_require_grad,
|
|
):
|
|
"""
|
|
Layer normalization on NestedTensor fails when trying to operate on a 2-dimensional tensor
|
|
"""
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
include_2d_tensor=True, # (B, *)
|
|
)
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
if nt.dim() <= 2:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"not supported for NestedTensor objects with 2 or fewer dimensions",
|
|
):
|
|
out = torch.nn.functional.layer_norm(
|
|
nt, normalized_shape=(nt.shape[nt._ragged_idx],)
|
|
)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_layer_norm_operate_on_batch_dim(
|
|
self,
|
|
device,
|
|
dtype,
|
|
requires_grad,
|
|
components_require_grad,
|
|
):
|
|
"""
|
|
Layer normalization on NestedTensor fails when trying to operate on the batch dimension
|
|
"""
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
include_2d_tensor=True, # (B, *)
|
|
)
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
if nt.dim() > 2: # cannot perform layer normalization on 2D tensors
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"not supported when normalizing over the batch dimension for NestedTensor",
|
|
):
|
|
out = torch.nn.functional.layer_norm(nt, normalized_shape=nt.shape)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize(
|
|
"func",
|
|
[torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
|
|
name_fn=get_op_name,
|
|
)
|
|
@parametrize(
|
|
"transpose_offset", [1, 2]
|
|
) # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
|
|
@parametrize("keepdim", [False, True])
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_op_dim_reduce_ragged_idx_greater_than_1_different_output_shape(
|
|
self,
|
|
device,
|
|
dtype,
|
|
keepdim,
|
|
requires_grad,
|
|
components_require_grad,
|
|
func,
|
|
transpose_offset,
|
|
):
|
|
"""
|
|
Operator on NestedTensor passes when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
|
|
This test is for operators which return an output tensor with a shape different from the input tensor.
|
|
"""
|
|
if get_op_name(func) == "mean" and not keepdim:
|
|
return
|
|
|
|
op_name = get_op_name(func)
|
|
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
include_2d_tensor=True, # (B, *)
|
|
)
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
if nt.dim() > nt._ragged_idx + transpose_offset:
|
|
nt_transposed = nt.transpose(
|
|
nt._ragged_idx, nt._ragged_idx + transpose_offset
|
|
)
|
|
reduce_dim = (nt_transposed._ragged_idx,) # ragged
|
|
|
|
out_actual = func(nt_transposed, dim=reduce_dim, keepdim=keepdim)
|
|
out_expected = torch.cat(
|
|
[
|
|
func(t, dim=(reduce_dim[0] - 1)).unsqueeze(0)
|
|
for t in nt_transposed.unbind()
|
|
]
|
|
)
|
|
if keepdim:
|
|
out_expected = out_expected.unsqueeze(reduce_dim[0])
|
|
|
|
self.assertFalse(
|
|
out_actual.is_nested,
|
|
f"{op_name}(): the result of reducing a nested tensor along the ragged dimension is a dense tensor",
|
|
) # output is a dense tensor
|
|
self.assertEqual(out_actual, out_expected)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize(
|
|
"transpose_offset", [1, 2]
|
|
) # [transpose consecutive dimensions, transpose nonconsecutive dimensions]
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_softmax_dim_reduce_ragged_idx_greater_than_1_same_output_shape(
|
|
self,
|
|
device,
|
|
dtype,
|
|
requires_grad,
|
|
components_require_grad,
|
|
transpose_offset,
|
|
):
|
|
"""
|
|
Softmax on NestedTensor fails when trying to reduce across a transposed ragged dimension, i.e. ragged_idx > 1
|
|
This test is for operators which return an output tensor with the same shape as the input tensor.
|
|
"""
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
)
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
if nt.dim() > nt._ragged_idx + transpose_offset:
|
|
nt_transposed = nt.transpose(
|
|
nt._ragged_idx, nt._ragged_idx + transpose_offset
|
|
)
|
|
reduce_dim = nt_transposed._ragged_idx # ragged
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"not supported when reducing along the ragged dimension for ragged_idx > 1 for NestedTensor",
|
|
):
|
|
out = torch.nn.functional.softmax(nt_transposed, dim=reduce_dim)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize(
|
|
"func",
|
|
[torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
|
|
name_fn=get_op_name,
|
|
)
|
|
@parametrize("keepdim", [False, True])
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_op_dim_transpose_non_ragged_dim_different_output_shape(
|
|
self, device, dtype, keepdim, requires_grad, components_require_grad, func
|
|
):
|
|
"""
|
|
Operator passes when reducing transposed nested tensors on valid reduction dimensions.
|
|
This test is for operators which return an output tensor with a shape different from the input tensor.
|
|
"""
|
|
if get_op_name(func) == "mean" and not keepdim:
|
|
return
|
|
|
|
# verify correctness of shapes (assuming that ragged_idx == 1)
|
|
if get_op_name(func) == "sum":
|
|
reduce_dims = (
|
|
((0, 1), (3, 4), (1, 1, 3, 4), (0,)), # batch, ragged
|
|
((2, 3), (3, None), (3, None, 1, 1), (1, 2)), # non-batch, non-batch
|
|
((0, 1, 3), (3,), (1, 1, 3, 1), (0, 2)), # batch, ragged, non-batch
|
|
((0, 1, 2), (4,), (1, 1, 1, 4), (0, 1)), # batch, ragged, non-batch
|
|
(
|
|
(0, 1, 2, 3),
|
|
(),
|
|
(1, 1, 1, 1),
|
|
(0, 1, 2),
|
|
), # batch, ragged, non-batch, non-batch
|
|
((2,), (3, None, 4), (3, None, 1, 4), (1,)), # non-batch
|
|
) # (dims, expected shape, expected keepdim shape, reduce_dim_expected), where j0 is represented as None
|
|
elif get_op_name(func) == "mean":
|
|
reduce_dims = (
|
|
((2,), (3, None, 4), (3, None, 1, 4), (1,)),
|
|
((3,), (3, None, 3), (3, None, 3, 1), (2,)),
|
|
)
|
|
|
|
# verify correctness of values
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
)
|
|
for tensor_list, reduce_dim_tuple in itertools.product(
|
|
tensor_lists, reduce_dims
|
|
):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
).transpose(-1, -2)
|
|
|
|
reduce_dim, _, _, reduce_dim_expected = reduce_dim_tuple
|
|
|
|
if nt.dim() > max(
|
|
reduce_dim[-1], nt._ragged_idx + 2
|
|
): # ensure that transposed dimensions are non-batch, non-ragged dimensions
|
|
out_actual = func(nt, dim=reduce_dim, keepdim=keepdim)
|
|
if nt._ragged_idx in reduce_dim: # raggedness reduced away
|
|
out_expected = func(
|
|
nt.values(), dim=reduce_dim_expected, keepdim=keepdim
|
|
)
|
|
self.assertTrue(torch.allclose(out_actual, out_expected))
|
|
else: # raggedness preserved
|
|
out_expected = func(nt.values(), dim=reduce_dim_expected)
|
|
self.assertTrue(
|
|
torch.allclose(
|
|
out_actual.values().view(-1), out_expected.view(-1)
|
|
)
|
|
)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_softmax_dim_transpose_non_ragged_dim(
|
|
self,
|
|
device,
|
|
dtype,
|
|
requires_grad,
|
|
components_require_grad,
|
|
):
|
|
"""
|
|
Softmax passes when reducing transposed nested tensors on valid reduction dimensions.
|
|
This test is for operators which return an output tensor with the same shape as the input tensor.
|
|
"""
|
|
# verify correctness of shapes (assuming that ragged_idx == 1)
|
|
reduce_dims = (
|
|
(2, 1),
|
|
(3, 2),
|
|
) # (reduction dimension, effective reduction dimension for baseline)
|
|
|
|
# verify correctness of values
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False,
|
|
include_requires_grad=components_require_grad,
|
|
include_inner_dim_size_1=True, # (B, *, 1)
|
|
)
|
|
for tensor_list, reduce_dim_tuple in itertools.product(
|
|
tensor_lists, reduce_dims
|
|
):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
).transpose(-1, -2)
|
|
|
|
reduce_dim, reduce_dim_expected = reduce_dim_tuple
|
|
|
|
if nt.dim() > max(reduce_dim, nt._ragged_idx + 2):
|
|
out_actual = torch.nn.functional.softmax(
|
|
nt, dim=reduce_dim
|
|
) # nested tensor
|
|
out_expected = torch.nn.functional.softmax(
|
|
nt.values(), dim=reduce_dim_expected
|
|
) # dense tensor of dimensions 1 less than out_actual
|
|
|
|
self.assertTrue(
|
|
torch.allclose(out_actual.values().view(-1), out_expected.view(-1))
|
|
)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("keepdim", [False, True])
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_sum_dim_reduce_ragged_and_non_batch(
|
|
self,
|
|
device,
|
|
dtype,
|
|
keepdim,
|
|
requires_grad,
|
|
components_require_grad,
|
|
):
|
|
"""
|
|
Sum on NestedTensor fails when trying to reduce across ragged and non-batch dimensions
|
|
"""
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False, include_requires_grad=components_require_grad
|
|
)
|
|
reduce_dims = (
|
|
(1, 2), # ragged, non-batch
|
|
(1, 3), # ragged, non-batch
|
|
)
|
|
|
|
for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
if nt.dim() > reduce_dim[-1]:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"reducing along a ragged and non-batch dimension is not supported",
|
|
):
|
|
out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize("keepdim", [False, True])
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_sum_dim_reduce_batch_and_non_batch(
|
|
self,
|
|
device,
|
|
dtype,
|
|
keepdim,
|
|
requires_grad,
|
|
components_require_grad,
|
|
):
|
|
"""
|
|
Sum on NestedTensor fails when trying to reduce across batch and non-batch dimensions
|
|
"""
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False, include_requires_grad=components_require_grad
|
|
)
|
|
reduce_dims = (
|
|
(0, 2), # batch, non-batch
|
|
(0, 3), # batch, non-batch
|
|
)
|
|
|
|
for tensor_list, reduce_dim in itertools.product(tensor_lists, reduce_dims):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
if nt.dim() > reduce_dim[-1]:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"reducing along the batch dimension but not the ragged dimension "
|
|
+ "is not supported",
|
|
):
|
|
out = torch.sum(nt, dim=reduce_dim, keepdim=keepdim)
|
|
|
|
@dtypes(torch.float32)
|
|
@parametrize(
|
|
"func",
|
|
[torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
|
|
name_fn=get_op_name,
|
|
)
|
|
@parametrize("keepdim", [False, True])
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_op_dim_reduce_batch_only_different_output_shape(
|
|
self, device, dtype, keepdim, requires_grad, components_require_grad, func
|
|
):
|
|
"""
|
|
Operator on NestedTensor fails when trying to reduce across batch dimension
|
|
"""
|
|
if get_op_name(func) == "mean" and not keepdim:
|
|
return
|
|
|
|
tensor_lists = self._get_example_tensor_lists(
|
|
include_list_of_lists=False, include_requires_grad=components_require_grad
|
|
)
|
|
reduce_dim = (0,) # batch
|
|
|
|
for tensor_list in tensor_lists:
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"reducing along the batch dimension but not the ragged dimension "
|
|
+ "is not supported",
|
|
):
|
|
out = func(nt, dim=reduce_dim, keepdim=keepdim)
|
|
|
|
    @dtypes(torch.float32)
    @parametrize(
        "func",
        [torch.ops.aten.sum.dim_IntList, torch.ops.aten.mean.dim],
        name_fn=get_op_name,
    )
    @parametrize("keepdim", [False, True])
    @parametrize("requires_grad", [False, True])
    @parametrize("components_require_grad", [False, True])
    def test_op_dim_with_lengths_different_output_shape(
        self,
        device,
        dtype,
        keepdim,
        requires_grad,
        components_require_grad,
        func,
    ):
        """
        Operator on NestedTensor fails when trying to reduce a nested tensor with lengths,
        i.e. a nested tensor with holes, if reducing on the ragged dimension.
        This test is for operators which return an output tensor with different shape than the input tensor.
        """
        if get_op_name(func) == "mean" and not keepdim:
            return

        reduce_dims = ((1,), (2,), (2, 3))

        lengths = torch.randint(5, 10, (20,), device=device)
        offsets = torch.zeros((21,), device=device, dtype=torch.int)
        torch.cumsum(lengths, dim=0, out=offsets[1:])
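
        # Note: offsets is the prefix sum of lengths with a leading zero, so component i
        # of the jagged tensor corresponds to values[offsets[i] : offsets[i + 1]]. The
        # lengths passed to nested_tensor_from_jagged below are shorter than
        # offsets.diff(), leaving unused rows ("holes") between components.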
        values = torch.randn(
            (offsets[-1].item(), 20),
            device=device,
            dtype=dtype,
            requires_grad=requires_grad,
        )

        nt_with_holes = torch.nested.nested_tensor_from_jagged(
            values,
            offsets,
            lengths=offsets.diff() - 2,  # arbitrary subtraction to create holes
        )

        for reduce_dim in reduce_dims:
            if nt_with_holes.dim() > reduce_dim[-1]:
                if nt_with_holes._ragged_idx in reduce_dim:
                    with self.assertRaisesRegex(
                        RuntimeError,
                        "reducing across the ragged dimension is not supported for "
                        + "non-contiguous nested tensors with holes",
                    ):
                        out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)
                else:
                    out = func(nt_with_holes, dim=reduce_dim, keepdim=keepdim)

@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_softmax_dim_with_lengths(
|
|
self,
|
|
device,
|
|
dtype,
|
|
requires_grad,
|
|
components_require_grad,
|
|
):
|
|
"""
|
|
Softmax on NestedTensor fails when trying to reduce a nested tensor with lengths,
|
|
i.e. a nested tensor with holes, if reducing on the ragged dimension.
|
|
"""
|
|
reduce_dims = (1, 2, 3)
|
|
|
|
lengths = torch.randint(5, 10, (20,), device=device)
|
|
offsets = torch.zeros((21,), device=device, dtype=torch.int)
|
|
torch.cumsum(lengths, dim=0, out=offsets[1:])
|
|
|
|
values = torch.randn(
|
|
(offsets[-1].item(), 20),
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
nt_with_holes = torch.nested.nested_tensor_from_jagged(
|
|
values,
|
|
offsets,
|
|
lengths=offsets.diff() - 2, # arbitrary subtraction to create holes
|
|
)
|
|
|
|
for reduce_dim in reduce_dims:
|
|
if nt_with_holes.dim() > reduce_dim:
|
|
if nt_with_holes._ragged_idx == reduce_dim:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"not supported where lengths is not None "
|
|
+ "if reducing across the ragged dimension for NestedTensor",
|
|
):
|
|
out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
|
|
else:
|
|
out = torch.nn.functional.softmax(nt_with_holes, dim=reduce_dim)
|
|
|
|
@skipIfTorchDynamo(
|
|
"ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx] does not currently work "
|
|
+ "with dynamo tests and throws this error: `AssertionError: SymInts must use SymNodeVariable. "
|
|
+ "If the underlying value is static, we will create a ConstantVariable and specialize.`"
|
|
)
|
|
@dtypes(torch.float32)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_layer_norm_with_lengths(
|
|
self,
|
|
device,
|
|
dtype,
|
|
requires_grad,
|
|
components_require_grad,
|
|
):
|
|
"""
|
|
Layer normalization on NestedTensor fails when trying to operate on a nested tensor with lengths,
|
|
i.e. a nested tensor with holes, if operating on the ragged dimension.
|
|
"""
|
|
|
|
# create components for nested tensor
|
|
lengths = torch.randint(5, 10, (20,), device=device)
|
|
offsets = torch.zeros((21,), device=device, dtype=torch.int)
|
|
torch.cumsum(lengths, dim=0, out=offsets[1:])
|
|
values = torch.randn(
|
|
(offsets[-1].item(), 10, 30),
|
|
device=device,
|
|
dtype=dtype,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
nt_with_holes = torch.nested.nested_tensor_from_jagged(
|
|
values,
|
|
offsets,
|
|
lengths=offsets.diff() - 2, # arbitrary subtraction to create holes
|
|
)
|
|
|
|
ragged_size = nt_with_holes.shape[nt_with_holes._ragged_idx]
|
|
|
|
normalized_shapes = (
|
|
(10, 30), # normalization on non-ragged dimension passes
|
|
(ragged_size, 10, 30), # normalization on ragged dimension fails
|
|
)
|
|
|
|
for normalized_shape in normalized_shapes:
|
|
if ragged_size in normalized_shape:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"not supported where lengths is not None if operating on the ragged dimension for NestedTensor",
|
|
):
|
|
out = torch.nn.functional.layer_norm(
|
|
nt_with_holes, normalized_shape=normalized_shape
|
|
)
|
|
else:
|
|
out = torch.nn.functional.layer_norm(
|
|
nt_with_holes, normalized_shape=normalized_shape
|
|
)
|
|
|
|
    @unittest.skipIf(
        PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
    )
    @onlyCUDA
    def test_pin_memory(self, device):
        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
        for nt in [nt_contiguous, nt_noncontiguous]:
            self.assertFalse(nt.is_pinned())
            pinned = nt.pin_memory()
            self.assertTrue(pinned.is_pinned())
            self.assertEqual(nt, pinned)
            self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
            # test that pin_memory on already pinned tensor has no effect
            self.assertIs(pinned, pinned.pin_memory())
            self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())

@torch.compiler.disable
|
|
def _validate_nt(
|
|
self,
|
|
nt,
|
|
device,
|
|
dtype,
|
|
layout,
|
|
requires_grad,
|
|
dim,
|
|
batch_size,
|
|
contiguous,
|
|
cached_min_seqlen=None,
|
|
cached_max_seqlen=None,
|
|
base=None,
|
|
ref_nt=None,
|
|
):
|
|
# Validate a bunch of properties after NT construction.
|
|
device = torch.device(device)
|
|
self.assertEqual(nt.dim(), dim)
|
|
self.assertEqual(nt.device, device)
|
|
self.assertEqual(nt.dtype, dtype)
|
|
self.assertEqual(nt.layout, layout)
|
|
self.assertEqual(nt.requires_grad, requires_grad)
|
|
self.assertEqual(nt.is_contiguous(), contiguous)
|
|
|
|
if layout == torch.jagged:
|
|
self.assertEqual(nt._values.device, device)
|
|
self.assertEqual(nt._offsets.device, device)
|
|
self.assertEqual(nt.shape[0], batch_size)
|
|
self.assertTrue(isinstance(nt.shape[1], torch.SymInt))
|
|
|
|
if base is not None:
|
|
self.assertTrue(nt._is_view() and nt._base is base)
|
|
replay_cache = nt._view_func(torch.randn_like(nt._base))._metadata_cache
|
|
self.assertEqual(
|
|
"min_seqlen" in replay_cache, cached_min_seqlen is not None
|
|
)
|
|
self.assertEqual(
|
|
"max_seqlen" in replay_cache, cached_max_seqlen is not None
|
|
)
|
|
|
|
self.assertEqual(
|
|
"min_seqlen" in nt._metadata_cache, cached_min_seqlen is not None
|
|
)
|
|
self.assertEqual(
|
|
"max_seqlen" in nt._metadata_cache, cached_max_seqlen is not None
|
|
)
|
|
|
|
if cached_min_seqlen is not None:
|
|
self.assertEqual(nt._min_seqlen, cached_min_seqlen)
|
|
|
|
if cached_max_seqlen is not None:
|
|
self.assertEqual(nt._max_seqlen, cached_max_seqlen)
|
|
|
|
if ref_nt is not None:
|
|
self.assertEqual(nt.size(0), ref_nt.size(0))
|
|
for n1, n2 in zip(nt.unbind(), ref_nt.unbind()):
|
|
self.assertEqual(n1, n2)
|
|
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_jagged_layout_construction_nested_tensor(
|
|
self, device, dtype, requires_grad, components_require_grad
|
|
):
|
|
for tensor_list in self._get_example_tensor_lists(
|
|
include_list_of_lists=True, include_requires_grad=components_require_grad
|
|
):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
|
|
expected_batch_size = len(tensor_list)
|
|
expected_contiguous = True
|
|
expected_min_seqlen = min(
|
|
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
|
|
for t in tensor_list
|
|
)
|
|
expected_max_seqlen = max(
|
|
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
|
|
for t in tensor_list
|
|
)
|
|
self._validate_nt(
|
|
nt,
|
|
device,
|
|
dtype,
|
|
torch.jagged,
|
|
requires_grad,
|
|
expected_dim,
|
|
expected_batch_size,
|
|
expected_contiguous,
|
|
expected_min_seqlen,
|
|
expected_max_seqlen,
|
|
)
|
|
|
|
# Make sure grads -don't- flow back into original tensors for nested_tensor()
|
|
if requires_grad:
|
|
(nt * 2).backward(torch.ones_like(nt))
|
|
for t in tensor_list:
|
|
t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
|
|
self.assertTrue(t.grad is None)
|
|
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_jagged_layout_construction_as_nested_tensor(
|
|
self, device, dtype, components_require_grad
|
|
):
|
|
# NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list
|
|
for tensor_list in self._get_example_tensor_lists(
|
|
include_list_of_lists=False, include_requires_grad=components_require_grad
|
|
):
|
|
nt = torch.nested.as_nested_tensor(
|
|
tensor_list, device=device, dtype=dtype, layout=torch.jagged
|
|
)
|
|
|
|
# nt.requires_grad=True should be set if at least one component requires grad
|
|
expected_dim = tensor_list[0].dim() + 1
|
|
expected_batch_size = len(tensor_list)
|
|
expected_contiguous = True
|
|
expected_min_seqlen = min(
|
|
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
|
|
for t in tensor_list
|
|
)
|
|
expected_max_seqlen = max(
|
|
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
|
|
for t in tensor_list
|
|
)
|
|
self._validate_nt(
|
|
nt,
|
|
device,
|
|
dtype,
|
|
torch.jagged,
|
|
components_require_grad,
|
|
expected_dim,
|
|
expected_batch_size,
|
|
expected_contiguous,
|
|
expected_min_seqlen,
|
|
expected_max_seqlen,
|
|
)
|
|
|
|
# Make sure grads flow back into original tensors for as_nested_tensor()
|
|
if components_require_grad:
|
|
(nt * 2).backward(torch.ones_like(nt))
|
|
for t in tensor_list:
|
|
if t.requires_grad:
|
|
self.assertEqual(t.grad, torch.ones_like(t) * 2)
|
|
else:
|
|
self.assertTrue(t.grad is None)
|
|
|
|
@xfailIfTorchDynamo
|
|
@unittest.skipIf(
|
|
PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property"
|
|
)
|
|
@onlyCUDA
|
|
def test_jagged_layout_construction_with_pinned_memory(self, device):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list, layout=torch.jagged, device="cpu", pin_memory=True
|
|
)
|
|
|
|
expected_dim = torch.as_tensor(tensor_list[0]).dim() + 1
|
|
expected_batch_size = len(tensor_list)
|
|
expected_min_seqlen = min(
|
|
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
|
|
for t in tensor_list
|
|
)
|
|
expected_max_seqlen = max(
|
|
(torch.tensor(t) if isinstance(t, list) else t).shape[0]
|
|
for t in tensor_list
|
|
)
|
|
self._validate_nt(
|
|
nt,
|
|
device="cpu",
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
requires_grad=False,
|
|
dim=expected_dim,
|
|
batch_size=expected_batch_size,
|
|
contiguous=True,
|
|
cached_min_seqlen=expected_min_seqlen,
|
|
cached_max_seqlen=expected_max_seqlen,
|
|
)
|
|
self.assertTrue(nt.is_pinned())
|
|
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("values_is_view", [False, True])
|
|
def test_jagged_view_from_values_offsets(
|
|
self, device, dtype, requires_grad, values_is_view
|
|
):
|
|
if values_is_view:
|
|
# make values a view of base
|
|
base = torch.randn(
|
|
2, 3, 4, 5, 6, device=device, dtype=dtype, requires_grad=requires_grad
|
|
)
|
|
values = base.flatten(0, -2)
|
|
else:
|
|
values = torch.randn(
|
|
10, 5, device=device, dtype=dtype, requires_grad=requires_grad
|
|
)
|
|
offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
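        # offsets [0, 2, 4, 6, 10] give a batch of offsets.shape[0] - 1 == 4 components.
        # nested_view_from_values_offsets builds the NT as a view of `values`, so
        # _validate_nt below checks nt._base and, when requires_grad is set, gradients
        # flow back to the base tensor rather than to a copy.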
|
|
|
|
nt = nested_view_from_values_offsets(values, offsets)
|
|
|
|
expected_dim = values.dim() + 1
|
|
expected_batch_size = offsets.shape[0] - 1
|
|
expected_base = base if values_is_view else values
|
|
lengths = offsets.diff()
|
|
self._validate_nt(
|
|
nt,
|
|
device,
|
|
dtype,
|
|
torch.jagged,
|
|
requires_grad,
|
|
expected_dim,
|
|
expected_batch_size,
|
|
# ensure NT is a proper view
|
|
base=expected_base,
|
|
contiguous=True,
|
|
# if no min / max are passed, expect the metadata cache to be empty
|
|
cached_min_seqlen=None,
|
|
cached_max_seqlen=None,
|
|
)
|
|
|
|
if requires_grad:
|
|
# Make sure grads flow back
|
|
(nt * 2).backward(torch.ones_like(nt))
|
|
|
|
@torch.compiler.disable
|
|
def _check_grad(t):
|
|
self.assertTrue(t.grad is not None)
|
|
self.assertEqual(t.grad, torch.ones_like(t) * 2)
|
|
|
|
_check_grad(base if values_is_view else values)
|
|
|
|
@dtypes(torch.float)
|
|
@parametrize("pass_min_max", [False, True])
|
|
def test_nested_tensor_from_jagged(self, device, dtype, pass_min_max):
|
|
# === construct from (values, offsets) ===
|
|
values = torch.randn(10, 5, device=device, dtype=dtype)
|
|
offsets = torch.tensor([0, 2, 4, 6, 10], device=device, dtype=torch.int64)
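        # With these offsets, the jagged tensor has batch_size 4 and component lengths
        # offsets.diff() == [2, 2, 2, 4] covering all 10 rows of `values`, so min/max
        # seqlen are 2 and 4, matching the values computed below.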
|
|
|
|
# compute min / max seqlen
|
|
lengths = offsets.diff()
|
|
min_seqlen = lengths.min().item()
|
|
max_seqlen = lengths.max().item()
|
|
|
|
if pass_min_max:
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values, offsets=offsets, min_seqlen=min_seqlen, max_seqlen=max_seqlen
|
|
)
|
|
else:
|
|
nt = torch.nested.nested_tensor_from_jagged(values, offsets=offsets)
|
|
self._validate_nt(
|
|
nt,
|
|
device,
|
|
dtype,
|
|
torch.jagged,
|
|
requires_grad=False,
|
|
dim=3,
|
|
batch_size=4,
|
|
contiguous=True,
|
|
cached_min_seqlen=(min_seqlen if pass_min_max else None),
|
|
cached_max_seqlen=(max_seqlen if pass_min_max else None),
|
|
base=values,
|
|
)
|
|
|
|
# === construct from (values, offsets, lengths) ===
|
|
lengths = torch.tensor([2, 1, 1, 2], device=device)
|
|
|
|
# compute min / max seqlen
|
|
min_seqlen = lengths.min().item()
|
|
max_seqlen = lengths.max().item()
|
|
|
|
if pass_min_max:
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values,
|
|
offsets=offsets,
|
|
lengths=lengths,
|
|
min_seqlen=min_seqlen,
|
|
max_seqlen=max_seqlen,
|
|
)
|
|
else:
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values, offsets=offsets, lengths=lengths
|
|
)
|
|
|
|
# when both offsets / lengths are specified, expect non-contiguous
|
|
self._validate_nt(
|
|
nt,
|
|
device,
|
|
dtype,
|
|
torch.jagged,
|
|
requires_grad=False,
|
|
dim=3,
|
|
batch_size=4,
|
|
contiguous=False,
|
|
cached_min_seqlen=(min_seqlen if pass_min_max else None),
|
|
cached_max_seqlen=(max_seqlen if pass_min_max else None),
|
|
base=values,
|
|
)
|
|
self.assertIs(nt.lengths(), lengths)
|
|
|
|
# === construct from (values, lengths) ===
|
|
values = torch.randn(14, 5, device=device, dtype=dtype)
|
|
lengths = torch.tensor([2, 3, 4, 5], device=device)
|
|
|
|
# compute min / max seqlen
|
|
min_seqlen = lengths.min().item()
|
|
max_seqlen = lengths.max().item()
|
|
|
|
if pass_min_max:
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values, lengths=lengths, min_seqlen=min_seqlen, max_seqlen=max_seqlen
|
|
)
|
|
else:
|
|
nt = torch.nested.nested_tensor_from_jagged(values, lengths=lengths)
|
|
|
|
# for now, if only lengths is specified, convert to offsets to integrate best with the
|
|
# existing kernels
|
|
expected_offsets = torch.tensor([0, 2, 5, 9, 14], device=device)
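        # expected_offsets is the prefix sum of lengths == [2, 3, 4, 5] with a leading
        # zero, i.e. the offsets an equivalent (values, offsets) construction would use.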
|
|
expected_nt = torch.nested.nested_tensor_from_jagged(
|
|
values, offsets=expected_offsets
|
|
)
|
|
self._validate_nt(
|
|
nt,
|
|
device,
|
|
dtype,
|
|
torch.jagged,
|
|
requires_grad=False,
|
|
dim=3,
|
|
batch_size=4,
|
|
contiguous=True,
|
|
cached_min_seqlen=(min_seqlen if pass_min_max else None),
|
|
cached_max_seqlen=(max_seqlen if pass_min_max else None),
|
|
base=values,
|
|
ref_nt=expected_nt,
|
|
)
|
|
|
|
# error case: no offsets or lengths
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "At least one of offsets or lengths is required"
|
|
):
|
|
torch.nested.nested_tensor_from_jagged(values, offsets=None, lengths=None)
|
|
|
|
with self.assertRaisesRegex(ValueError, "Expected jagged_dim >=1, but got 0."):
|
|
torch.nested.nested_tensor_from_jagged(
|
|
values, lengths=lengths, jagged_dim=0
|
|
)
|
|
|
|
@onlyCPU
|
|
def test_nested_tensor_from_jagged_fx_trace(self, device):
|
|
def fn(x, y):
|
|
return torch.nested.nested_tensor_from_jagged(x, y)
|
|
|
|
def user_unwrapped(x, y):
|
|
return fn(x, y)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"torch.nested.nested_tensor_from_jagged does not support tracing with fx.symbolic_trace",
|
|
):
|
|
torch.fx.symbolic_trace(user_unwrapped)
|
|
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("dim", range(5))
|
|
@parametrize(
|
|
"layout",
|
|
[torch.strided, torch.jagged],
|
|
name_fn=lambda l: f"layout_{str(l).split('.')[1]}",
|
|
)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("contiguous", [False, True])
|
|
def test_as_nested_tensor_from_tensor(
|
|
self, device, dtype, dim, layout, requires_grad, contiguous
|
|
):
|
|
if dim == 0:
|
|
t = torch.tensor(3.0, requires_grad=requires_grad)
|
|
else:
|
|
t = torch.randn(*(3 for _ in range(dim)), requires_grad=requires_grad)
|
|
assert t.dim() == dim
|
|
|
|
if dim < 2:
|
|
# 0-1 dim tensors can't be converted to NTs
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "Expected tensor argument to have dim"
|
|
):
|
|
nt = torch.nested.as_nested_tensor(
|
|
t, device=device, dtype=dtype, layout=layout
|
|
)
|
|
return
|
|
|
|
orig_t = t
|
|
if not contiguous:
|
|
t = t.transpose(0, 1)
|
|
|
|
nt = torch.nested.as_nested_tensor(t, device=device, dtype=dtype, layout=layout)
|
|
expected_dim = t.dim()
|
|
expected_batch_size = t.size(0)
|
|
expected_seqlen = t.size(1) if layout == torch.jagged else None
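        # Converting a dense (B, N, ...) tensor yields B components that all have length
        # t.size(1), so for the jagged layout the cached min and max seqlen coincide.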
|
|
self._validate_nt(
|
|
nt,
|
|
device,
|
|
dtype,
|
|
layout,
|
|
requires_grad=requires_grad,
|
|
dim=dim,
|
|
batch_size=expected_batch_size,
|
|
contiguous=True,
|
|
cached_min_seqlen=expected_seqlen,
|
|
cached_max_seqlen=expected_seqlen,
|
|
)
|
|
|
|
if torch.device(device) == t.device and dtype == t.dtype and contiguous:
|
|
# should be the non-copying (view) case
|
|
self.assertTrue(nt._is_view() and nt._base is t)
|
|
|
|
# should have equivalent components to construction from unbound tensor list
|
|
nt_from_unbind = torch.nested.as_nested_tensor(
|
|
list(t.unbind(0)), device=device, dtype=dtype, layout=layout
|
|
)
|
|
self.assertEqualIgnoringNestedInts(nt, nt_from_unbind)
|
|
|
|
# ensure call on a NT with the same properties returns the NT directly
|
|
nt2 = torch.nested.as_nested_tensor(
|
|
nt, device=device, dtype=dtype, layout=layout
|
|
)
|
|
self.assertTrue(nt is nt2)
|
|
|
|
# ensure call with device=None uses input tensor device
|
|
nt3 = torch.nested.as_nested_tensor(
|
|
t.to(device=device, dtype=dtype),
|
|
device=None,
|
|
dtype=None,
|
|
layout=layout,
|
|
)
|
|
self._validate_nt(
|
|
nt3,
|
|
device,
|
|
dtype,
|
|
layout,
|
|
requires_grad=requires_grad,
|
|
dim=dim,
|
|
batch_size=expected_batch_size,
|
|
contiguous=True,
|
|
cached_min_seqlen=expected_seqlen,
|
|
cached_max_seqlen=expected_seqlen,
|
|
)
|
|
|
|
# we don't support conversion between layouts this way atm
|
|
other_layout = torch.strided if layout == torch.jagged else torch.jagged
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "Converting between nested tensor layouts is not supported"
|
|
):
|
|
torch.nested.as_nested_tensor(
|
|
nt, device=device, dtype=dtype, layout=other_layout
|
|
)
|
|
|
|
if requires_grad:
|
|
# make sure gradients flow back into inputs
|
|
(nt * 2).backward(torch.ones_like(nt))
|
|
self.assertEqual(orig_t.grad, torch.ones_like(orig_t) * 2)
|
|
|
|
@dtypes(torch.float32)
|
|
def test_construction_from_list(self, device, dtype):
|
|
from torch.fx.experimental.symbolic_shapes import is_nested_int
|
|
|
|
# success case: single ragged dim anywhere but the batch dim
|
|
for nt_dim in [2, 3, 4]:
|
|
for ragged_dim in range(1, nt_dim):
|
|
B = 6
|
|
shapes = [list(range(3, 3 + nt_dim - 1)) for _ in range(B)]
|
|
for b in range(B):
|
|
# subtract 1 to convert to component dim space
|
|
shapes[b][ragged_dim - 1] = torch.randint(
|
|
2, 9, (1,), device=device, dtype=torch.int64
|
|
).item()
|
|
|
|
components = [
|
|
torch.randn(shape, device=device, dtype=dtype) for shape in shapes
|
|
]
|
|
nt = torch.nested.nested_tensor(components, layout=torch.jagged)
|
|
|
|
self.assertEqual(nt.dim(), nt_dim)
|
|
self.assertEqual(nt._ragged_idx, ragged_dim)
|
|
for d in range(nt_dim):
|
|
self.assertEqual(d == ragged_dim, is_nested_int(nt.shape[d]))
|
|
|
|
# error case: empty list
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "Cannot construct a nested tensor from an empty tensor list"
|
|
):
|
|
torch.nested.nested_tensor([], layout=torch.jagged)
|
|
|
|
# error case: list of zero-dim tensors
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Cannot construct a nested tensor from a list of zero-dim tensors",
|
|
):
|
|
torch.nested.nested_tensor(
|
|
[
|
|
torch.tensor(3.0, device=device, dtype=dtype),
|
|
torch.tensor(4.0, device=device, dtype=dtype),
|
|
torch.tensor(5.0, device=device, dtype=dtype),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
# error case: multiple ragged dims
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Cannot represent given tensor list as a nested tensor with the jagged layout",
|
|
):
|
|
torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 3, device=device, dtype=dtype),
|
|
torch.randn(4, 5, device=device, dtype=dtype),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
# error case: components on multiple devices
|
|
if "cuda" in device:
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"When constructing a nested tensor, all tensors in list must be on the same device",
|
|
):
|
|
torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 3, device=device, dtype=dtype),
|
|
torch.randn(2, 4, device="cpu", dtype=dtype),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
# error case: components with multiple dtypes
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"When constructing a nested tensor, all tensors in list must have the same dtype",
|
|
):
|
|
torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 3, device=device, dtype=dtype),
|
|
torch.randn(2, 4, device=device, dtype=torch.float64),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
# error case: components with multiple dims
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"When constructing a nested tensor, all tensors in list must have the same dim",
|
|
):
|
|
torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 3, device=device, dtype=dtype),
|
|
torch.randn(2, 3, 4, device=device, dtype=dtype),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
@dtypes(torch.double, torch.half)
|
|
@onlyCUDA
|
|
def test_device_dtype_transfer_updates_offsets(self, device, dtype):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
orig_device = torch.device("cpu")
|
|
orig_dtype = torch.float32
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list, layout=torch.jagged, device=orig_device, dtype=orig_dtype
|
|
)
|
|
|
|
self.assertEqual(torch.int64, nt.offsets().dtype)
|
|
nt = nt.to(device=device).to(dtype=dtype)
|
|
|
|
# offsets should still be int64 on the new device
|
|
self.assertEqual(nt.values().device, nt.offsets().device)
|
|
self.assertEqual(torch.int64, nt.offsets().dtype)
|
|
|
|
def test_unbind(self, device):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list, layout=torch.jagged, device=device
|
|
) # ragged_idx = 1
|
|
out = nt.unbind()
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(t, tensor_list[i])
|
|
|
|
@parametrize("ragged_idx", [2, 3])
|
|
def test_unbind_transpose(self, device, ragged_idx):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list, layout=torch.jagged, device=device
|
|
)
|
|
if ragged_idx < nt.dim():
|
|
nt = nt.transpose(1, ragged_idx) # set ragged_idx
|
|
out = nt.unbind()
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(
|
|
t.transpose(0, ragged_idx - 1), tensor_list[i]
|
|
) # transpose back each element of result
|
|
|
|
def test_unbind_transpose_ragged_idx_last_dim(self, device):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list, layout=torch.jagged, device=device
|
|
).transpose(1, -1) # set ragged_idx = last dimension
|
|
out = nt.unbind()
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(
|
|
t.transpose(0, -1), tensor_list[i]
|
|
) # transpose back each element of result
|
|
|
|
def test_unbind_lengths(self, device):
|
|
values = torch.randn(16, 128, device=device)
|
|
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
|
|
lengths = torch.tensor([6, 2, 1, 2], device=device)
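        # Here offsets.diff() == [8, 4, 1, 3] but lengths == [6, 2, 1, 2], so each
        # component uses only a prefix of its segment; unbind() should return
        # values[offsets[i] : offsets[i] + lengths[i]] for each i, which is exactly
        # what the reference tensor_list below collects.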
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values, offsets=offsets, lengths=lengths
|
|
) # 3D nested tensor
|
|
|
|
tensor_list = []
|
|
for i in range(offsets.shape[0] - 1):
|
|
tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i])])
|
|
|
|
out = nt.unbind()
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(t, tensor_list[i])
|
|
|
|
def test_unbind_lengths_ragged_idx_1(self, device):
|
|
values = torch.randn(16, 8, 128, device=device)
|
|
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
|
|
lengths = torch.tensor([6, 2, 1, 2], device=device)
|
|
ragged_idx = 1
|
|
nt = torch.nested._internal.nested_tensor.NestedTensor(
|
|
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
|
|
) # 4D nested tensor
|
|
|
|
tensor_list = []
|
|
for i in range(offsets.shape[0] - 1):
|
|
tensor_list.append(values[offsets[i] : (offsets[i] + lengths[i]), :, :])
|
|
|
|
out = nt.unbind()
|
|
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(t, tensor_list[i])
|
|
|
|
def test_unbind_lengths_ragged_idx_equals_2_bad_dim(self, device):
|
|
values = torch.randn(16, 8, 128, device=device)
|
|
offsets = torch.tensor([0, 8, 12, 13, 16], device=device)
|
|
lengths = torch.tensor([6, 2, 1, 2], device=device)
|
|
ragged_idx = 2
|
|
nt = torch.nested._internal.nested_tensor.NestedTensor(
|
|
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
|
|
) # 4D nested tensor
|
|
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"unbind\(\): nested tensor offsets and lengths.*",
|
|
lambda: nt.unbind(),
|
|
)
|
|
|
|
def test_unbind_lengths_ragged_idx_2(self, device):
|
|
values = torch.randn(16, 8, 128, device=device)
|
|
offsets = torch.tensor([0, 2, 4, 8], device=device)
|
|
lengths = torch.tensor([2, 1, 3], device=device)
|
|
ragged_idx = 2
|
|
nt = torch.nested._internal.nested_tensor.NestedTensor(
|
|
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
|
|
) # 4D nested tensor
|
|
|
|
tensor_list = []
|
|
for i in range(offsets.shape[0] - 1):
|
|
tensor_list.append(values[:, offsets[i] : (offsets[i] + lengths[i]), :])
|
|
|
|
out = nt.unbind()
|
|
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(t, tensor_list[i])
|
|
|
|
def test_unbind_lengths_ragged_idx_3(self, device):
|
|
values = torch.randn(16, 8, 128, device=device)
|
|
offsets = torch.tensor([0, 100, 128], device=device)
|
|
lengths = torch.tensor([50, 28], device=device)
|
|
ragged_idx = 3
|
|
nt = torch.nested._internal.nested_tensor.NestedTensor(
|
|
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
|
|
) # 4D nested tensor
|
|
|
|
tensor_list = []
|
|
for i in range(offsets.shape[0] - 1):
|
|
tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
|
|
|
|
out = nt.unbind()
|
|
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(t, tensor_list[i])
|
|
|
|
@skipIfTorchDynamo(
|
|
"TorchDynamo raises an error for ragged_idx == 0 earlier than Torch"
|
|
)
|
|
def test_unbind_lengths_ragged_idx_0(self, device):
|
|
values = torch.randn(16, 8, 128, device=device)
|
|
offsets = torch.tensor([0, 100, 128], device=device)
|
|
lengths = torch.tensor([50, 28], device=device)
|
|
ragged_idx = 0
|
|
nt = torch.nested._internal.nested_tensor.NestedTensor(
|
|
values, offsets=offsets, lengths=lengths, _ragged_idx=ragged_idx
|
|
) # 4D nested tensor
|
|
|
|
tensor_list = []
|
|
for i in range(offsets.shape[0] - 1):
|
|
tensor_list.append(values[:, :, offsets[i] : (offsets[i] + lengths[i])])
|
|
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"unbind\(\): nested tensor.*out of bounds",
|
|
lambda: nt.unbind(),
|
|
)
|
|
|
|
def test_narrow(self, device):
|
|
starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
|
|
lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
|
|
buffer = (
|
|
torch.arange(0, 10, device=device, dtype=torch.int64)
|
|
.unsqueeze(0)
|
|
.expand(5, -1)
|
|
.clone()
|
|
.detach()
|
|
)
|
|
nt = torch.nested.narrow(buffer, 1, starts, lengths, layout=torch.jagged)
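        # narrow() with the jagged layout returns a view over `buffer` where component i
        # covers buffer[i, starts[i] : starts[i] + lengths[i]]; the loop below verifies
        # this through the underlying values()/offsets()/lengths(), per the TODO above
        # about switching to unbind once it is functional here.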
|
|
|
|
self.assertTrue(nt._is_view() and nt._base is buffer)
|
|
|
|
# TODO: Use this approach when unbind is functional
|
|
# unbinded_nt = nt.unbind()
|
|
# for i in range(starts.shape[0]):
|
|
# self.assertEqual(torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64), unbinded_nt[i])
|
|
for i in range(starts.shape[0]):
|
|
self.assertEqual(
|
|
torch.arange(
|
|
starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64
|
|
),
|
|
nt.values()[nt.offsets()[i] : (nt.offsets()[i] + nt.lengths()[i])],
|
|
)
|
|
|
|
def test_njt_cat(self, device):
|
|
offsets = torch.tensor([0, 2, 3], device=device, dtype=torch.int64)
|
|
values_1 = torch.randn(
|
|
3, 2, dtype=torch.float64, device=device, requires_grad=True
|
|
)
|
|
values_2 = torch.randn(
|
|
3, 4, dtype=torch.float64, device=device, requires_grad=True
|
|
)
|
|
|
|
def grad_test_func(values_1, values_2, offsets):
|
|
nt_1 = torch.nested.nested_tensor_from_jagged(values_1, offsets)
|
|
nt_2 = torch.nested.nested_tensor_from_jagged(values_2, offsets)
|
|
nt_3 = torch.cat([nt_1, nt_2], dim=-1)
|
|
return nt_3.values()
|
|
|
|
assert gradcheck(
|
|
grad_test_func,
|
|
inputs=(values_1, values_2, offsets),
|
|
check_batched_grad=False,
|
|
)
|
|
|
|
def test_is_contiguous(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
nt_contiguous = torch.nested.as_nested_tensor([a, b, c], layout=torch.jagged)
|
|
|
|
starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
|
|
lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
|
|
narrow_base = (
|
|
torch.arange(0, 10, device=device, dtype=torch.int64)
|
|
.unsqueeze(0)
|
|
.expand(5, -1)
|
|
.clone()
|
|
)
|
|
nt_noncontiguous = torch.nested.narrow(
|
|
narrow_base, 1, starts_nc, lengths_nc, layout=torch.jagged
|
|
)
|
|
|
|
starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
|
|
lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
|
|
nt_contiguous_narrow = torch.nested.narrow(
|
|
narrow_base, 1, starts_c, lengths_c, layout=torch.jagged
|
|
)
|
|
|
|
# Test contiguous case
|
|
assert nt_contiguous.is_contiguous()
|
|
|
|
# Test narrow case
|
|
assert not nt_noncontiguous.is_contiguous()
|
|
assert nt_contiguous_narrow.is_contiguous()
|
|
|
|
# Test querying by memory_format
|
|
self.assertTrue(
|
|
nt_contiguous.is_contiguous(memory_format=torch.contiguous_format)
|
|
)
|
|
self.assertTrue(
|
|
not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format)
|
|
)
|
|
self.assertTrue(
|
|
nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format)
|
|
)
|
|
|
|
def test_layout_under_torch_dispatch_mode(self):
|
|
from torch.testing._internal.logging_tensor import (
|
|
capture_logs_with_logging_tensor_mode,
|
|
)
|
|
|
|
nt = random_nt_from_dims(
|
|
[2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
|
|
)
|
|
|
|
with capture_logs_with_logging_tensor_mode():
|
|
self.assertEqual(nt.layout, torch.jagged)
|
|
|
|
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
|
|
@parametrize(
|
|
"func", [torch.empty_like, torch.randn_like], name_fn=lambda f: f.__name__
|
|
)
|
|
def test_like_shape(self, func):
|
|
nt = random_nt_from_dims(
|
|
[2, None, 3], torch.device("cpu"), torch.float32, layout=torch.jagged
|
|
)
|
|
nt_like = func(nt)
|
|
|
|
for nt_ub in nt_like.unbind():
|
|
t_like = func(nt_ub)
|
|
self.assertEqual(nt_ub.shape, t_like.shape)
|
|
|
|
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
|
|
@parametrize(
|
|
"func",
|
|
[
|
|
torch.empty_like,
|
|
torch.full_like,
|
|
torch.ones_like,
|
|
torch.rand_like,
|
|
torch.randint_like,
|
|
torch.randn_like,
|
|
torch.zeros_like,
|
|
],
|
|
name_fn=lambda f: f.__name__,
|
|
)
|
|
def test_like_value(self, func, device):
|
|
dtype = torch.float32 if func is not torch.randint_like else torch.int32
|
|
for nt in _sample_njts(device=device, dtype=dtype):
|
|
extra_kwarg_sets = [{}]
|
|
if func is torch.full_like:
|
|
extra_kwarg_sets = [{"fill_value": 4.2}]
|
|
elif func is torch.randint_like:
|
|
extra_kwarg_sets = [{"high": 5}, {"low": 4, "high": 9}]
|
|
|
|
# only test changing dtype / device from CUDA -> CPU because CUDA might not be
|
|
# available when running this test for CPU
|
|
change_dtype_device_settings = (
|
|
[False, True] if "cuda" in device else [False]
|
|
)
|
|
for change_dtype_device in change_dtype_device_settings:
|
|
if change_dtype_device:
|
|
new_dtype = (
|
|
torch.float64 if func is not torch.randint_like else torch.int64
|
|
)
|
|
new_device = "cpu" if "cuda" in device else device
|
|
new_layout = torch.strided
|
|
for extra_kwargs in extra_kwarg_sets:
|
|
extra_kwargs.update(
|
|
{
|
|
"dtype": new_dtype,
|
|
"device": new_device,
|
|
"layout": new_layout,
|
|
}
|
|
)
|
|
|
|
for extra_kwargs in extra_kwarg_sets:
|
|
nt_like = func(nt, **extra_kwargs)
|
|
self.assertEqual(nt.shape, nt_like.shape)
|
|
if change_dtype_device:
|
|
self.assertNotEqual(nt.device, nt_like.device)
|
|
                        self.assertNotEqual(nt.dtype, nt_like.dtype)
|
|
# layout should be ignored since only torch.jagged is supported
|
|
self.assertEqual(torch.jagged, nt_like.layout)
|
|
else:
|
|
self.assertEqual(nt.device, nt_like.device)
|
|
self.assertEqual(nt.dtype, nt_like.dtype)
|
|
self.assertEqual(nt.layout, nt_like.layout)
|
|
self.assertEqual(nt.layout, torch.jagged)
|
|
|
|
# don't bother trying to compare random or empty values
|
|
if func not in [
|
|
torch.empty_like,
|
|
torch.rand_like,
|
|
torch.randn_like,
|
|
torch.randint_like,
|
|
]:
|
|
for nt_ub in nt_like.unbind():
|
|
t_like = func(nt_ub, **extra_kwargs)
|
|
self.assertEqual(nt_ub, t_like)
|
|
|
|
def test_noncontiguous_pointwise(self, device):
|
|
a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
nt = torch.nested.nested_tensor([a, b, c], layout=torch.jagged)
|
|
# transpose ragged dim
|
|
transposed = nt.transpose(1, 2)
|
|
self.assertFalse(transposed.is_contiguous())
|
|
clone = transposed.clone()
|
|
|
|
def check_nt_equality(x, y):
|
|
self.assertEqual(x.values(), y.values())
|
|
self.assertEqual(x.offsets(), y.offsets())
|
|
self.assertEqual(x._ragged_idx, y._ragged_idx)
|
|
self.assertEqual(x.shape, y.shape)
|
|
|
|
self.assertFalse(clone.is_contiguous())
|
|
check_nt_equality(clone, transposed)
|
|
|
|
clone_contig = transposed.clone(memory_format=torch.contiguous_format)
|
|
self.assertTrue(clone_contig.is_contiguous())
|
|
check_nt_equality(clone_contig, transposed)
|
|
|
|
detached = transposed.detach()
|
|
self.assertFalse(clone.is_contiguous())
|
|
check_nt_equality(detached, transposed)
|
|
|
|
def test_permute(self, device):
|
|
nt = random_nt_from_dims(
|
|
[2, None, 3, 5], device, torch.float32, layout=torch.jagged
|
|
)
|
|
nt_shape = nt.shape
|
|
nt_inner_shape = nt.values().shape
|
|
with self.assertRaisesRegex(
|
|
ValueError,
|
|
r"permute\(\): number of dimensions in the tensor input \(4\) "
|
|
+ r"does not match the length of the desired ordering of dimensions \(3\).",
|
|
):
|
|
nt.permute(0, 2, 1)
|
|
with self.assertRaisesRegex(
|
|
ValueError, r"permute\(\): duplicate dims are not allowed."
|
|
):
|
|
nt.permute(0, 2, -2, 3)
|
|
with self.assertRaisesRegex(
|
|
ValueError, "Permute is not supported on the batch dimension for jagged NT"
|
|
):
|
|
nt.permute(1, 0, 2, 3)
|
|
nt_permute = nt.permute(0, 2, 1, -1)
|
|
self.assertEqual(
|
|
nt_permute.shape, (nt_shape[0], nt_shape[2], nt_shape[1], nt_shape[3])
|
|
)
|
|
self.assertEqual(
|
|
nt_permute.values().shape,
|
|
(nt_inner_shape[1], nt_inner_shape[0], nt_inner_shape[2]),
|
|
)
|
|
self.assertEqual(nt_permute._ragged_idx, 2)
|
|
self.assertEqual(nt_permute.permute(0, 2, 1, 3), nt)
|
|
|
|
def test_to_dtype(self, device):
|
|
nt = random_nt_from_dims(
|
|
[2, None, 3], device, torch.float32, layout=torch.jagged
|
|
)
|
|
nt_after = nt.to(torch.float64)
|
|
self.assertEqual(torch.float32, nt.dtype)
|
|
self.assertEqual(torch.float64, nt_after.dtype)
|
|
self.assertEqual(torch.float64, nt_after.values().dtype)
|
|
self.assertEqual(torch.int64, nt_after.offsets().dtype)
|
|
|
|
noncontiguous_nt = nt.transpose(1, 2)
|
|
noncontiguous_nt_after = noncontiguous_nt.to(torch.bfloat16)
|
|
self.assertEqual(torch.bfloat16, noncontiguous_nt_after.dtype)
|
|
self.assertEqual(torch.bfloat16, noncontiguous_nt_after.values().dtype)
|
|
self.assertEqual(torch.int64, noncontiguous_nt_after.offsets().dtype)
|
|
|
|
def test_to_copy(self, device):
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(
|
|
i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
for i in range(3)
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
nt_copy_dtype = torch.ops.aten._to_copy(nt, dtype=torch.float16)
|
|
self.assertEqual(torch.float16, nt_copy_dtype.dtype)
|
|
|
|
nt_t = nt.transpose(1, 2)
|
|
nt_t_copy_dtype = torch.ops.aten._to_copy(nt_t, dtype=torch.float16)
|
|
self.assertEqual(torch.float16, nt_t_copy_dtype.dtype)
|
|
|
|
def test_copy_(self, device):
|
|
offsets = torch.tensor([0, 2, 4], device=device)
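        # `a` and `b` below are built from the same offsets tensor, so they have matching
        # nested sizes and copy_ succeeds directly; `offsets_2` later has equal contents
        # but is a distinct tensor, exercising the path for mismatched nested ints noted
        # in the comment further down.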
|
|
a = torch.nested.nested_tensor_from_jagged(
|
|
torch.zeros(4, 3, device=device), offsets
|
|
)
|
|
b = torch.nested.nested_tensor_from_jagged(
|
|
torch.ones(4, 3, device=device), offsets
|
|
)
|
|
a.copy_(b)
|
|
torch._dynamo.disable(self.assertEqual)(a, b)
|
|
|
|
offsets_2 = torch.tensor([0, 2, 4], device=device)
|
|
c = torch.nested.nested_tensor_from_jagged(
|
|
torch.ones(4, 3, device=device), offsets_2
|
|
)
|
|
# should work even though the nested ints are different due to unbound-based copy
|
|
a.copy_(c)
|
|
|
|
# fail when tensors have different sizes
|
|
a = a.transpose(1, 2)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"expected compatible input and src shapes, but got",
|
|
):
|
|
a.copy_(b)
|
|
|
|
# This can't happen in the opinfo tests due to subprocess creation
|
|
@unittest.skipIf(
|
|
TEST_WITH_ROCM,
|
|
"In ROCm, kernel asserts are disabled due to performance overhead",
|
|
)
|
|
def test_index_put_error(self, device):
|
|
import subprocess
|
|
|
|
with self.subTest():
|
|
r = subprocess.call(
|
|
[
|
|
sys.executable,
|
|
"-c",
|
|
"""\
|
|
import torch
|
|
offsets = torch.tensor([0, 2, 5, 7], device='cuda')
|
|
lengths = torch.tensor([2, 2, 2], device='cuda')
|
|
indices = [
|
|
torch.tensor([0, 1, 2], device='cuda'),
|
|
torch.tensor([0, 2, 1], device='cuda'),
|
|
torch.tensor([0, 0, 0], device='cuda'),
|
|
]
|
|
a = torch.nested.nested_tensor_from_jagged(
|
|
torch.zeros(7, 3, device='cuda'), offsets, lengths
|
|
)
|
|
a[indices] = 1.0
|
|
torch.cuda.synchronize()
|
|
""",
|
|
]
|
|
)
|
|
self.assertTrue(r != 0)
|
|
|
|
@skipIfTorchDynamo("Dynamo doesn't know how to trace prof.events()")
|
|
def test_profiler_sequence_nr(self):
|
|
with torch.profiler.profile() as prof:
|
|
values = torch.randn(4, 6, requires_grad=True)
|
|
offsets = torch.tensor([0, 2, 4])
|
|
values = values * 2
|
|
l = torch.nn.Linear(6, 8)
|
|
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
|
|
|
|
nt = l(nt)
|
|
val = nt.values()
|
|
|
|
loss = val.sum()
|
|
loss.backward()
|
|
|
|
fwd_seq_nrs = []
|
|
for evt in prof.events():
|
|
if (
|
|
"linear" in evt.name.lower()
|
|
and "backward" not in evt.name.lower()
|
|
and evt.sequence_nr != -1
|
|
):
|
|
fwd_seq_nrs.append(evt.sequence_nr)
|
|
|
|
bwd_seq_nrs = []
|
|
for evt in prof.events():
|
|
if (
|
|
"linear" in evt.name.lower()
|
|
and "backward" in evt.name.lower()
|
|
and "evaluate_function" not in evt.name.lower()
|
|
and evt.sequence_nr != -1
|
|
):
|
|
bwd_seq_nrs.append(evt.sequence_nr)
|
|
|
|
# There should be only one such event with a sequence number:
# the PythonTLSSnapshot event. Note that it's not terrible if
# we end up with multiple events with the same sequence number, so we
# could relax this check if it becomes inconvenient to maintain this
# property.
|
|
self.assertEqual(len(fwd_seq_nrs), 1)
|
|
self.assertEqual(len(bwd_seq_nrs), 1)
|
|
self.assertEqual(fwd_seq_nrs[0], bwd_seq_nrs[0])
|
|
|
|
def test_is_same_size(self, device):
|
|
def get_3_tensors():
|
|
return [
|
|
torch.randn(
|
|
i + 2, 3, 4, requires_grad=True, dtype=torch.float64, device=device
|
|
)
|
|
for i in range(3)
|
|
]
|
|
|
|
nt1, offsets1 = jagged_from_list(get_3_tensors(), None)
|
|
nt2, offsets1 = jagged_from_list(get_3_tensors(), offsets1)
|
|
|
|
nt3, offsets2 = jagged_from_list(get_3_tensors(), None)
|
|
nt4, offsets2 = jagged_from_list(get_3_tensors(), offsets2)
|
|
|
|
def check_size(nt1, nt2, nt3, nt4):
|
|
self.assertTrue(torch.ops.aten.is_same_size(nt1, nt2))
|
|
self.assertTrue(torch.ops.aten.is_same_size(nt3, nt4))
|
|
self.assertFalse(torch.ops.aten.is_same_size(nt1, nt3))
|
|
|
|
check_size(nt1, nt2, nt3, nt4)
|
|
|
|
nt1_t, nt2_t, nt3_t, nt4_t = (x.transpose(1, 2) for x in (nt1, nt2, nt3, nt4))
|
|
check_size(nt1_t, nt2_t, nt3_t, nt4_t)
|
|
|
|
@skipIfTorchDynamo("compiles internally")
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
def test_specialize_dynamic_shape(self, device):
|
|
values = torch.randn((18, 16), device=device)
|
|
offsets = torch.tensor([0, 2, 3, 6, 15, 18], device=device)
|
|
like_values = torch.randn_like(values)
|
|
|
|
# this marks values as dynamic
|
|
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
|
|
|
|
def fn(values, same_size):
|
|
# here, the dynamic shape is specialized by same_size's shape
|
|
# https://github.com/pytorch/pytorch/issues/127097
|
|
# make sure this doesn't error out in torch.compile
|
|
return values + same_size
|
|
|
|
self.assertEqual(
|
|
fn(values, like_values),
|
|
torch.compile(fn)(values, like_values),
|
|
)
|
|
|
|
@skipIfTorchDynamo("compiles internally")
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
def test_specialize_dynamic_shape_recompile(self, device):
|
|
def generate_inp(total_len):
|
|
values = torch.randn((total_len, 16), device=device)
|
|
offsets = torch.tensor([0, 2, 3, 6, 15, total_len], device=device)
|
|
like_values = torch.randn_like(values)
|
|
return values, offsets, like_values
|
|
|
|
def check_results(ref_fn, res_fn, args):
|
|
values, offsets, like_values = args
|
|
# this may add dynamic shape markings
# the goal of this test is to make sure that, whatever markings are there,
# we eventually stop recompiling as the shape changes.
|
|
nt = torch.nested.nested_tensor_from_jagged(values, offsets)
|
|
|
|
self.assertEqual(ref_fn(values, like_values), res_fn(values, like_values))
|
|
|
|
def fn(values, same_size):
|
|
return values + same_size
|
|
|
|
compile_counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
compiled_fn = torch.compile(fn, backend=compile_counter, fullgraph=True)
|
|
check_results(fn, compiled_fn, generate_inp(18))
|
|
self.assertEqual(compile_counter.frame_count, 1)
|
|
|
|
check_results(fn, compiled_fn, generate_inp(19))
|
|
# we'll probably recompile here with dynamic shapes - it's okay if not though.
|
|
frame_count_2 = compile_counter.frame_count
|
|
self.assertIn(frame_count_2, [1, 2])
|
|
|
|
# make sure that by now we've already compiled with dynamic shapes, so additional
|
|
# shapes should not trigger additional recompiles.
|
|
check_results(fn, compiled_fn, generate_inp(20))
|
|
self.assertEqual(compile_counter.frame_count, frame_count_2)
|
|
|
|
# Note 1: Math fallback doesn't work with bfloat16 on CUDA
|
|
# Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
|
|
@unittest.skipIf(
|
|
TEST_WITH_ROCM,
|
|
"ROCm doesn't support flash attention or mem_efficient attention for NT",
|
|
)
|
|
@tf32_on_and_off(0.005)
|
|
@dtypes(
|
|
*(
|
|
[torch.float16, torch.bfloat16, torch.float32]
|
|
if SM80OrLater
|
|
else [torch.float16, torch.float32]
|
|
)
|
|
)
|
|
def test_sdpa(self, device, dtype):
|
|
batch_size = 1
|
|
emb_dims = 128
|
|
n_heads = 8
|
|
head_dims = emb_dims // n_heads
|
|
|
|
sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
|
|
sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
|
|
|
|
query = torch.nn.Linear(
|
|
emb_dims, emb_dims, bias=False, device=device, dtype=dtype
|
|
)
|
|
key = torch.nn.Linear(
|
|
emb_dims, emb_dims, bias=False, device=device, dtype=dtype
|
|
)
|
|
value = torch.nn.Linear(
|
|
emb_dims, emb_dims, bias=False, device=device, dtype=dtype
|
|
)
|
|
|
|
# Simplest case: 1 sentence, no batching
|
|
x_d1 = sen1.unsqueeze(0)
|
|
x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged)
|
|
|
|
# See note below for why we detach here.
|
|
q_d1 = (
|
|
query(x_d1)
|
|
.view(batch_size, -1, n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
q_d1_t = q_d1.transpose(1, 2)
|
|
k_d1 = (
|
|
key(x_d1)
|
|
.view(batch_size, -1, n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
k_d1_t = k_d1.transpose(1, 2)
|
|
v_d1 = (
|
|
value(x_d1)
|
|
.view(batch_size, -1, n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
v_d1_t = v_d1.transpose(1, 2)
|
|
|
|
q_nt = (
|
|
query(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
q_nt_t = q_nt.transpose(1, 2)
|
|
k_nt = (
|
|
key(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
k_nt_t = k_nt.transpose(1, 2)
|
|
v_nt = (
|
|
value(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
v_nt_t = v_nt.transpose(1, 2)
|
|
|
|
# High Precision Math Reference
|
|
q_d1_f32 = q_d1.to(torch.float32)
|
|
k_d1_f32 = k_d1.to(torch.float32)
|
|
v_d1_f32 = v_d1.to(torch.float32)
|
|
q_d1_f32_t = q_d1_f32.transpose(1, 2)
|
|
k_d1_f32_t = k_d1_f32.transpose(1, 2)
|
|
v_d1_f32_t = v_d1_f32.transpose(1, 2)
|
|
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
|
|
q_d1_f32_t, k_d1_f32_t, v_d1_f32_t
|
|
)[0]
|
|
grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32))
|
|
|
|
# Low Precision Math Reference
|
|
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
|
|
q_d1_t, k_d1_t, v_d1_t
|
|
)[0]
|
|
grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1))
|
|
|
|
# Compute tolerances
|
|
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
|
|
# fudge factor of 1.7 for smaller GPUs e.g., A2, A16
|
|
grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(
|
|
grads_ref[0], grads_lp_ref[0], 1.7
|
|
)
|
|
grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1])
|
|
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2])
|
|
grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol]
|
|
grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol]
|
|
|
|
attn_d1 = torch.nn.functional.scaled_dot_product_attention(
|
|
q_d1_t, k_d1_t, v_d1_t
|
|
).transpose(1, 2)
|
|
attn_nt = torch.nn.functional.scaled_dot_product_attention(
|
|
q_nt_t, k_nt_t, v_nt_t
|
|
).transpose(1, 2)
|
|
|
|
self.assertEqual(
|
|
attn_d1,
|
|
attn_nt.unbind()[0].unsqueeze(0),
|
|
atol=output_ref_atol,
|
|
rtol=output_ref_rtol,
|
|
)
|
|
|
|
# Simple case: 2 sentences, no extra params
|
|
x_d2 = sen2.unsqueeze(0)
|
|
x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
|
|
|
|
# NB: we make sure the leaf tensor we compute gradients for is the viewed tensor before
# it is transposed. This is because today we cannot backward through view() or unbind() of a
# transposed tensor.
|
|
q_d2 = (
|
|
query(x_d2)
|
|
.view(batch_size, -1, n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
q_d2_t = q_d2.transpose(1, 2)
|
|
k_d2 = (
|
|
key(x_d2)
|
|
.view(batch_size, -1, n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
k_d2_t = k_d2.transpose(1, 2)
|
|
v_d2 = (
|
|
value(x_d2)
|
|
.view(batch_size, -1, n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
v_d2_t = v_d2.transpose(1, 2)
|
|
|
|
q_nt = (
|
|
query(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
q_nt_t = q_nt.transpose(1, 2)
|
|
k_nt = (
|
|
key(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
k_nt_t = k_nt.transpose(1, 2)
|
|
v_nt = (
|
|
value(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.requires_grad_(True)
|
|
)
|
|
v_nt_t = v_nt.transpose(1, 2)
|
|
|
|
attn_d2 = torch.nn.functional.scaled_dot_product_attention(
|
|
q_d2_t, k_d2_t, v_d2_t
|
|
).transpose(1, 2)
|
|
d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1))
|
|
d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2))
|
|
|
|
# Simple case 3: batch_size = 1, seq_len = 1
|
|
q_3 = torch.randn(1, 8, 16, dtype=dtype, device=device)
|
|
q_nt_3 = torch.nested.as_nested_tensor([q_3], layout=torch.jagged)
|
|
q_nt_3 = q_nt_3.transpose(1, 2)
|
|
attn_out = torch.nn.functional.scaled_dot_product_attention(
|
|
q_nt_3, q_nt_3, q_nt_3
|
|
)
|
|
self.assertEqual(attn_out.shape, q_nt_3.shape)
|
|
|
|
@parametrize("skip_backward", [True, False])
|
|
def check_forward_backward(skip_backward=False):
|
|
if not skip_backward:
|
|
attn_nt = torch.nn.functional.scaled_dot_product_attention(
|
|
q_nt_t, k_nt_t, v_nt_t
|
|
).transpose(1, 2)
|
|
else:
|
|
x_nt.requires_grad = False
|
|
q_nt.requires_grad = False
|
|
k_nt.requires_grad = False
|
|
v_nt.requires_grad = False
|
|
tq = q_nt_t.detach()
|
|
tk = k_nt_t.detach()
|
|
tv = v_nt_t.detach()
|
|
with torch.no_grad():
|
|
attn_nt = torch.nn.functional.scaled_dot_product_attention(
|
|
tq, tk, tv
|
|
).transpose(1, 2)
|
|
|
|
attn_nts = attn_nt.unbind()
|
|
self.assertEqual(
|
|
attn_d1,
|
|
attn_nts[0].unsqueeze(0),
|
|
atol=output_ref_atol,
|
|
rtol=output_ref_rtol,
|
|
)
|
|
self.assertEqual(
|
|
attn_d2,
|
|
attn_nts[1].unsqueeze(0),
|
|
atol=output_ref_atol,
|
|
rtol=output_ref_rtol,
|
|
)
|
|
|
|
if not skip_backward:
|
|
nt_grads = torch.autograd.grad(
|
|
attn_nt.values().sum(), (q_nt, k_nt, v_nt)
|
|
)
|
|
for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(
|
|
nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols
|
|
):
|
|
unbound_nt_grads = nt_grad.unbind()
|
|
self.assertEqual(
|
|
d1_grad,
|
|
unbound_nt_grads[0].unsqueeze(0),
|
|
atol=grad_atol,
|
|
rtol=grad_rtol,
|
|
)
|
|
self.assertEqual(
|
|
d2_grad,
|
|
unbound_nt_grads[1].unsqueeze(0),
|
|
atol=grad_atol,
|
|
rtol=grad_rtol,
|
|
)
|
|
|
|
# Default
|
|
check_forward_backward()
|
|
|
|
# Test that the dispatcher works by enabling only mem-efficient and math (as they are safe for all devices)
|
|
with torch.backends.cuda.sdp_kernel(
|
|
enable_flash=False, enable_mem_efficient=True, enable_math=True
|
|
):
|
|
check_forward_backward()
|
|
|
|
# Test math fallback
|
|
with torch.backends.cuda.sdp_kernel(
|
|
enable_flash=False, enable_mem_efficient=False, enable_math=True
|
|
):
|
|
# Math fallback doesn't work with bfloat16 on CUDA because
|
|
# "group_gemm_dispatch" not implemented for 'BFloat16'
|
|
if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
|
|
check_forward_backward()
|
|
check_cudnn = os.getenv("TORCH_CUDNN_SDPA_NESTED_TENSOR_ENABLED", "0") == "1"
|
|
if (
|
|
"cuda" in str(device)
|
|
and check_cudnn
|
|
and (dtype == torch.float16 or dtype == torch.bfloat16)
|
|
):
|
|
with torch.nn.attention.sdpa_kernel(
|
|
torch.nn.attention.SDPBackend.CUDNN_ATTENTION
|
|
):
|
|
check_forward_backward()
|
|
|
|
@skipIfTorchDynamo("SDPA test compiles internally")
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
# Guarding with sqrt() doesn't work on ROCm?
|
|
@skipCUDAIfRocm
|
|
@onlyCUDA
|
|
@dtypes(
|
|
*(
|
|
[torch.float16, torch.bfloat16, torch.float32]
|
|
if SM80OrLater
|
|
else [torch.float16, torch.float32]
|
|
)
|
|
)
|
|
def test_sdpa_compile(self, device, dtype):
|
|
batch_size = 1
|
|
emb_dims = 1024
|
|
n_heads = 8
|
|
head_dims = emb_dims // n_heads
|
|
|
|
sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
|
|
sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
|
|
|
|
query = torch.nn.Linear(
|
|
emb_dims, emb_dims, bias=False, device=device, dtype=dtype
|
|
)
|
|
key = torch.nn.Linear(
|
|
emb_dims, emb_dims, bias=False, device=device, dtype=dtype
|
|
)
|
|
value = torch.nn.Linear(
|
|
emb_dims, emb_dims, bias=False, device=device, dtype=dtype
|
|
)
|
|
|
|
# Simplest case: 1 sentence, no batching
|
|
x_d1 = sen1.unsqueeze(0)
|
|
x_d2 = sen2.unsqueeze(0)
|
|
x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
|
|
|
|
q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
|
|
q_nt = (
|
|
query(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.transpose(1, 2)
|
|
)
|
|
k_nt = (
|
|
key(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.transpose(1, 2)
|
|
)
|
|
v_nt = (
|
|
value(x_nt)
|
|
.view(*x_nt.size()[0:2], n_heads, head_dims)
|
|
.detach()
|
|
.transpose(1, 2)
|
|
)
|
|
|
|
# High Precision Math Reference
|
|
q_d1_f32 = q_d1.to(torch.float32)
|
|
k_d1_f32 = k_d1.to(torch.float32)
|
|
v_d1_f32 = v_d1.to(torch.float32)
|
|
out_ref = torch.ops.aten._scaled_dot_product_attention_math(
|
|
q_d1_f32, k_d1_f32, v_d1_f32
|
|
)[0]
|
|
# Low Precision Math Reference
|
|
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(
|
|
q_d1, k_d1, v_d1
|
|
)[0]
|
|
output_ref_atol, output_ref_rtol = get_tolerances(
|
|
out_ref, out_lp_ref, fudge_factor=2
|
|
)
|
|
|
|
attn_d1 = torch.nn.functional.scaled_dot_product_attention(
|
|
q_d1, k_d1, v_d1
|
|
).transpose(1, 2)
|
|
attn_d2 = torch.nn.functional.scaled_dot_product_attention(
|
|
q_d2, k_d2, v_d2
|
|
).transpose(1, 2)
|
|
|
|
compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention)
|
|
attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2)
|
|
|
|
attn_nts = attn_nt.unbind()
|
|
self.assertEqual(
|
|
attn_d1,
|
|
attn_nts[0].unsqueeze(0),
|
|
atol=output_ref_atol,
|
|
rtol=output_ref_rtol,
|
|
)
|
|
self.assertEqual(
|
|
attn_d2,
|
|
attn_nts[1].unsqueeze(0),
|
|
atol=output_ref_atol,
|
|
rtol=output_ref_rtol,
|
|
)
|
|
|
|
@dtypes(torch.float32, torch.double, torch.half)
|
|
def test_sdpa_with_constant_sequence_length(self, device, dtype):
|
|
# shape (B, P*, S, D)
|
|
# B: batch size
|
|
# P*: ragged number of prompts
|
|
# S: (constant) sequence length
|
|
# D: embedding size
|
|
query = random_nt_from_dims(
|
|
[4, None, 8, 10],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=True,
|
|
)
|
|
key = random_nt_from_similar(query)
|
|
value = random_nt_from_similar(query)
|
|
output = F.scaled_dot_product_attention(query, key, value)
|
|
self.assertTrue(isinstance(output, NestedTensor))
|
|
output.values().sum().backward()
|
|
|
|
query_dense = query.detach().clone().requires_grad_(True)
|
|
# should be equivalent to just running the buffers through
|
|
output_dense = F.scaled_dot_product_attention(
|
|
query_dense.values(), key.values(), value.values()
|
|
)
|
|
torch._dynamo.disable(self.assertEqual)(output._values, output_dense)
|
|
output_dense.sum().backward()
|
|
torch._dynamo.disable(self.assertEqual)(query.grad, query_dense.grad)
|
|
|
|
@onlyCUDA
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FUSED_ATTENTION,
|
|
"Platform doesn't support flash or mem-efficient attention",
|
|
)
|
|
@dtypes(
|
|
*(
|
|
[torch.float16, torch.bfloat16, torch.float32]
|
|
if SM80OrLater
|
|
else [torch.float16, torch.float32]
|
|
)
|
|
)
|
|
def test_sdpa_with_packed_in_proj(self, device, dtype):
|
|
# shape (B, *, D)
|
|
input_packed = random_nt_from_dims(
|
|
[5, None, 10], device=device, dtype=dtype, layout=torch.jagged
|
|
)
|
|
|
|
# Do input projection.
|
|
num_heads = 2
|
|
# should be multiple of 4 for efficient kernels (e.g. flash / mem-efficient)
|
|
head_dim = 8
|
|
qkv_linear = torch.nn.Linear(10, num_heads * head_dim * 3).to(
|
|
device=device, dtype=dtype
|
|
)
|
|
|
|
def in_proj(input_packed, qkv_linear=qkv_linear):
|
|
qkv_post_proj = qkv_linear(input_packed)
|
|
# these are non-contiguous to trigger _is_safe_to_get_storage_as_tensor()
|
|
q, k, v = qkv_post_proj.chunk(3, dim=-1)
|
|
q = q.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
|
|
k = k.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
|
|
v = v.unflatten(-1, [num_heads, head_dim]).transpose(-2, -3)
|
|
return q, k, v
|
|
|
|
q, k, v = in_proj(input_packed)
|
|
output = F.scaled_dot_product_attention(q, k, v, attn_mask=None)
|
|
|
|
# compare to individually running unbound components through
|
|
for in_component, out_component in zip(
|
|
input_packed.unbind(), output.transpose(-2, -3).unbind()
|
|
):
|
|
q, k, v = in_proj(in_component)
|
|
out = F.scaled_dot_product_attention(q, k, v).transpose(-2, -3)
|
|
|
|
# Low Precision Math Reference
|
|
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[
|
|
0
|
|
].transpose(-2, -3)
|
|
output_ref_atol, output_ref_rtol = get_tolerances(
|
|
out, out_lp_ref, fudge_factor=2
|
|
)
|
|
|
|
self.assertEqual(
|
|
out, out_component, atol=output_ref_atol, rtol=output_ref_rtol
|
|
)
|
|
|
|
@skipIfTorchDynamo("SDPA test compiles internally")
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
# mha_varlen_fwd not supported on ROCm
|
|
@skipCUDAIfRocm
|
|
@onlyCUDA
|
|
@dtypes(
|
|
*(
|
|
[torch.float16, torch.bfloat16, torch.float32]
|
|
if SM80OrLater
|
|
else [torch.float16, torch.float32]
|
|
)
|
|
)
|
|
def test_sdpa_backwards(self, device, dtype):
|
|
values = torch.randn(9, 3, 256, requires_grad=True, device=device, dtype=dtype)
|
|
offsets = torch.tensor([0, 1, 3, 5, 9], device=device, dtype=torch.int64)
|
|
|
|
@torch.compile
|
|
def f(values, offsets):
|
|
nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
|
|
nt = nt.transpose(-2, -3)
|
|
# purposefully graph break to trigger view replay for subclass view input
|
|
torch.tensor(1).item()
|
|
output = F.scaled_dot_product_attention(nt, nt, nt).transpose(-2, -3)
|
|
return convert_nt_to_jagged(output)
|
|
|
|
output = f(values, offsets)
|
|
output.sum().backward()
|
|
self.assertEqual(values.grad, torch.ones_like(values))
|
|
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FUSED_ATTENTION,
|
|
"Platform doesn't support flash or mem-efficient attention",
|
|
)
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@skipCUDAIfRocm
|
|
@onlyCUDA
|
|
@skipIfTorchDynamo()
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
def test_sdpa_autocast(self, device):
|
|
def fn_nt(values32, values16, offsets):
|
|
nt32 = convert_jagged_to_nested_tensor(values32, offsets, max_length=16)
|
|
nt16 = convert_jagged_to_nested_tensor(values16, offsets, max_length=16)
|
|
nt32 = nt32.transpose(1, 2)
|
|
nt16 = nt16.transpose(1, 2)
|
|
return F.scaled_dot_product_attention(nt32, nt16, nt32)
|
|
|
|
def fn_dense(x32, x16):
|
|
x32 = x32.view(8, 16, 4, 16).transpose(1, 2)
|
|
x16 = x16.view(8, 16, 4, 16).transpose(1, 2)
|
|
return F.scaled_dot_product_attention(x32, x16, x32)
|
|
|
|
values32 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float32)
|
|
values16 = torch.randn((8 * 16, 4, 16), device=device, dtype=torch.float16)
|
|
offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
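# 8 equal-length sequences of 16 elements each -> offsets [0, 16, 32, ..., 128]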
|
|
|
|
x32 = values32.clone()
|
|
x16 = values16.clone()
|
|
|
|
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
|
out_dense_eager = fn_dense(x32, x16)
|
|
out_dense_compiled = torch.compile(fn_dense)(x32, x16)
|
|
out_nt_eager = fn_nt(values32, values16, offsets)
|
|
out_nt_compiled = torch.compile(fn_nt)(values32, values16, offsets)
|
|
|
|
self.assertEqual(out_dense_eager, out_dense_compiled)
|
|
self.assertEqual(
|
|
out_dense_eager.transpose(1, 2),
|
|
out_nt_eager.values().transpose(0, 1).view(8, 16, 4, 16),
|
|
)
|
|
self.assertEqual(
|
|
out_dense_eager.transpose(1, 2),
|
|
out_nt_compiled.values().transpose(0, 1).view(8, 16, 4, 16),
|
|
)
|
|
|
|
def get_values():
|
|
return tuple(
|
|
x.detach().clone().requires_grad_(True) for x in (values32, values16)
|
|
)
|
|
|
|
v32_dense_eager, v16_dense_eager = get_values()
|
|
v32_dense_compile, v16_dense_compile = get_values()
|
|
v32_nt_eager, v16_nt_eager = get_values()
|
|
v32_nt_compile, v16_nt_compile = get_values()
|
|
|
|
with torch.autocast(device_type="cuda", dtype=torch.float16):
|
|
loss_dense_eager = fn_dense(v32_dense_eager, v16_dense_eager).sum()
|
|
loss_dense_compile = torch.compile(fn_dense)(
|
|
v32_dense_compile, v16_dense_compile
|
|
).sum()
|
|
loss_nt_eager = fn_nt(v32_nt_eager, v16_nt_eager, offsets).values().sum()
|
|
loss_nt_compile = (
|
|
torch.compile(fn_nt)(v32_nt_compile, v16_nt_compile, offsets)
|
|
.values()
|
|
.sum()
|
|
)
|
|
|
|
loss_dense_eager.backward()
|
|
loss_dense_compile.backward()
|
|
loss_nt_eager.backward()
|
|
loss_nt_compile.backward()
|
|
|
|
self.assertEqual(v32_dense_eager.grad, v32_dense_compile.grad)
|
|
self.assertEqual(v32_dense_eager.grad, v32_nt_eager.grad, atol=1e-4, rtol=1e-4)
|
|
self.assertEqual(
|
|
v32_dense_eager.grad, v32_nt_compile.grad, atol=1e-4, rtol=1e-4
|
|
)
|
|
|
|
self.assertEqual(v16_dense_eager.grad, v16_dense_compile.grad)
|
|
self.assertEqual(v16_dense_eager.grad, v16_nt_eager.grad, atol=1e-5, rtol=5e-3)
|
|
self.assertEqual(
|
|
v16_dense_eager.grad, v16_nt_compile.grad, atol=1e-5, rtol=5e-3
|
|
)
|
|
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FUSED_ATTENTION,
|
|
"Platform doesn't support flash or mem-efficient attention",
|
|
)
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@skipCUDAIfRocm
|
|
@onlyCUDA
|
|
@skipIfTorchDynamo()
|
|
def test_sdpa_flop_counter(self, device):
|
|
from torch.utils.flop_counter import FlopCounterMode
|
|
|
|
def get_flops(nt):
|
|
flop_counter = FlopCounterMode(display=False)
|
|
with flop_counter:
|
|
ret = torch.nn.functional.scaled_dot_product_attention(nt, nt, nt)
|
|
ret.values().sum().backward()
|
|
return flop_counter.get_total_flops()
|
|
|
|
values = torch.randn(
|
|
(8 * 16, 4, 16), requires_grad=True, device=device, dtype=torch.float16
|
|
)
|
|
offsets = torch.arange(0, 8 * 16 + 1, 16, device=device, dtype=torch.int32)
|
|
nt = convert_jagged_to_nested_tensor(values, offsets, max_length=16).transpose(
|
|
1, 2
|
|
)
|
|
|
|
values_meta = torch.randn(
|
|
(8 * 16, 4, 16), requires_grad=True, device="meta", dtype=torch.float16
|
|
)
|
|
offsets_meta = torch.arange(0, 8 * 16 + 1, 16, device="meta", dtype=torch.int32)
|
|
nt_meta = convert_jagged_to_nested_tensor(
|
|
values_meta, offsets_meta, max_length=16
|
|
).transpose(1, 2)
|
|
|
|
self.assertEqual(get_flops(nt), get_flops(nt_meta))
|
|
|
|
@skipIfTorchDynamo()
|
|
def test_nested_tensor_activation_checkpoint(self, device):
|
|
values = torch.randn(
|
|
9, 3, 256, requires_grad=True, device=device, dtype=torch.float32
|
|
)
|
|
lengths = torch.tensor([1, 2, 3, 3], device=device, dtype=torch.int64)
|
|
offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
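# e.g. lengths [1, 2, 3, 3] -> offsets [0, 1, 3, 6, 9], matching the 9 rows in values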
|
|
|
|
def fn(values, offsets):
|
|
nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
|
|
return convert_nt_to_jagged(nt).sum()
|
|
|
|
checkpoint(fn, values, offsets, use_reentrant=False).backward()
|
|
self.assertIsNotNone(values.grad)
|
|
|
|
context_fn = partial(
|
|
create_selective_checkpoint_contexts, [torch.ops.aten.cumsum.default]
|
|
)
|
|
|
|
values.grad = None
|
|
|
|
def fn(values, lengths):
|
|
offsets = F.pad(lengths, pad=(1, 0)).cumsum(dim=0)
|
|
nt = convert_jagged_to_nested_tensor(values, offsets, max_length=4)
|
|
return convert_nt_to_jagged(nt).sum()
|
|
|
|
checkpoint(
|
|
fn, values, lengths, use_reentrant=False, context_fn=context_fn
|
|
).backward()
|
|
self.assertIsNotNone(values.grad)
|
|
|
|
# Internally-defined NT use cases are lifted to here for maximum test realism.
|
|
# TODO: Remove these when ViewNestedFromBuffer, etc. are deprecated.
|
|
@skipCUDAIfRocm # not needed
|
|
@skipIfTorchDynamo("compiles internally")
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@parametrize("use_legacy_api", [True, False])
|
|
@skipCPUIf(True, "SPDA Math NT fallback causes failure: see issue #133644")
|
|
def test_dummy_mha_with_nt(self, device, use_legacy_api):
|
|
bs = 3
|
|
d1 = 2
|
|
d2 = 4
|
|
d3 = 16
|
|
n_heads = 2
|
|
d_head = d3 // n_heads
|
|
max_length_1 = 10
|
|
max_length_2 = 20
|
|
torch.manual_seed(0)
|
|
|
|
class mha(torch.nn.Module):
|
|
def __init__(self, use_legacy_api) -> None:
|
|
super().__init__()
|
|
torch.manual_seed(0)
|
|
self.linear = torch.nn.Linear(d2, d3, device=device)
|
|
self.use_legacy_api = use_legacy_api
|
|
|
|
def forward(self, query, value, offsets):
|
|
value = self.linear(value)
|
|
if self.use_legacy_api:
|
|
key = convert_jagged_to_nested_tensor_legacy(
|
|
value, offsets, max_length_1
|
|
)
|
|
value = convert_jagged_to_nested_tensor_legacy(
|
|
value, offsets, max_length_2
|
|
)
|
|
query = convert_dense_to_nested_tensor_legacy(query)
|
|
else:
|
|
key = convert_jagged_to_nested_tensor(value, offsets, max_length_1)
|
|
value = convert_jagged_to_nested_tensor(
|
|
value, offsets, max_length_2
|
|
)
|
|
query = convert_dense_to_nested_tensor(query)
|
|
q = query.view(bs, -1, n_heads, d_head).transpose(1, 2)
|
|
k = key.view(bs, -1, n_heads, d_head).transpose(1, 2)
|
|
v = value.view(bs, -1, n_heads, d_head).transpose(1, 2)
|
|
|
|
with torch.nn.attention.sdpa_kernel(
|
|
[
|
|
torch.nn.attention.SDPBackend.FLASH_ATTENTION,
|
|
torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
|
|
]
|
|
):
|
|
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
|
q,
|
|
k,
|
|
v,
|
|
attn_mask=None,
|
|
dropout_p=0.0,
|
|
is_causal=False,
|
|
)
|
|
attn_output = attn_output.transpose(1, 2)
|
|
if self.use_legacy_api:
|
|
attn_output = convert_nt_to_jagged_legacy(attn_output)
|
|
else:
|
|
attn_output = convert_nt_to_jagged(attn_output)
|
|
return attn_output, key._max_seqlen, value._max_seqlen
|
|
|
|
query = torch.rand(bs, d1, d3, device=device)
|
|
value = torch.rand(30, d2, requires_grad=True, device=device)
|
|
# total_length must be > max_length, otherwise flash_attn backward will fail
|
|
offsets = torch.tensor([0, 2, 3, 30], device=device)
|
|
|
|
m = mha(use_legacy_api)
|
|
symbolic_traced: torch.fx.GraphModule = torch.fx.symbolic_trace(m)
|
|
m = torch.compile(symbolic_traced)
|
|
attn_output, cached_key_max_seqlen, cached_value_max_seqlen = m(
|
|
query, value, offsets
|
|
)
|
|
loss = attn_output.sum()
|
|
# Check that NT can be fx traced and torch.compile, and backward works
|
|
loss.backward()
|
|
|
|
# Check that value.requires_grad is not lost after tracing and compiling
|
|
value_grad = value.grad # save for comparison later
|
|
self.assertIsNotNone(value_grad)
|
|
# check that max_seqlen is cached properly
|
|
self.assertEqual(cached_key_max_seqlen, max_length_1)
|
|
self.assertEqual(cached_value_max_seqlen, max_length_2)
|
|
|
|
# check if the output is numerically equivalent with the eager mode
|
|
m_eager = mha(use_legacy_api)
|
|
|
|
value.grad = None
|
|
attn_output_eager, _, _ = m_eager(query, value, offsets)
|
|
attn_output_eager.sum().backward()
|
|
self.assertTrue(torch.allclose(attn_output_eager, attn_output))
|
|
self.assertTrue(torch.allclose(value_grad, value.grad))
|
|
|
|
# Helper function to generate random query, key, value NJTs in (B, n_heads, *, D) format.
|
|
# If noncontig_with_holes is True, the results will be non-contiguous with holes (i.e. have
|
|
# both offsets and lengths specified).
|
|
def _rand_qkv(self, device, dtype, noncontig_with_holes=False, q_and_kv_match=True):
|
|
batch_size = 8
|
|
n_heads = 8
|
|
D = 16
|
|
|
|
def _rand_nt(noncontig_with_holes=noncontig_with_holes):
|
|
sentence_lengths = [random.randint(2, 1023) for _ in range(batch_size - 1)]
|
|
total = sum(sentence_lengths)
|
|
|
|
# shape (B, *, D_total) where D_total = n_heads * D
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(l, n_heads * D, device=device, dtype=dtype)
|
|
for l in sentence_lengths
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
if noncontig_with_holes:
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
nt._values,
|
|
nt._offsets,
|
|
# -1 to introduce holes
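# (each component's length is one less than its offsets-allotted slot, leaving an unused 'hole' in the values buffer)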
|
|
lengths=nt._offsets.diff() - 1,
|
|
jagged_dim=nt._ragged_idx,
|
|
min_seqlen=nt._min_seqlen,
|
|
max_seqlen=nt._max_seqlen,
|
|
)
|
|
|
|
return nt
|
|
|
|
query = _rand_nt()
|
|
if q_and_kv_match:
|
|
key = torch.randn_like(query)
|
|
value = torch.randn_like(query)
|
|
else:
|
|
key = _rand_nt()
|
|
value = torch.randn_like(key)
|
|
|
|
# shape (B, *, D_total) -> (B, n_heads, *, D)
|
|
query = (
|
|
query.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
|
|
)
|
|
key = key.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
|
|
value = (
|
|
value.unflatten(-1, [n_heads, D]).transpose(1, 2).detach().requires_grad_()
|
|
)
|
|
|
|
return query, key, value
|
|
|
|
@unittest.skip(
|
|
"Temporarily skip - nested tensor backward pass broken after return-max-scores commit"
|
|
)
|
|
@onlyCUDA
|
|
@flex_attention_supported_platform
|
|
@dtypes(torch.float32)
|
|
# non-contiguous with holes not supported yet
|
|
@decorateIf(unittest.skip, lambda params: params["noncontig_with_holes"])
|
|
@parametrize("noncontig_with_holes", [False, True])
|
|
@parametrize("cross_attention", [False, True])
|
|
@skipIfRocm
|
|
def test_flex_attention(self, device, dtype, noncontig_with_holes, cross_attention):
|
|
query, key, value = self._rand_qkv(
|
|
device, dtype, noncontig_with_holes, q_and_kv_match=(not cross_attention)
|
|
)
|
|
|
|
# Run FlexAttention with a causal mask
|
|
def causal_mask(b, h, q_idx, kv_idx):
|
|
return q_idx >= kv_idx
|
|
|
|
if cross_attention:
|
|
block_mask = create_nested_block_mask(
|
|
causal_mask, 1, 1, query, key, _compile=True
|
|
)
|
|
else:
|
|
block_mask = create_nested_block_mask(
|
|
causal_mask, 1, 1, query, _compile=True
|
|
)
|
|
|
|
out_flex = flex_attention(query, key, value, block_mask=block_mask)
|
|
grad_out = torch.randn_like(out_flex)
|
|
grads_flex = torch.autograd.grad(
|
|
out_flex, inputs=(query, key, value), grad_outputs=(grad_out,)
|
|
)
|
|
flex_outs = [out_flex, *grads_flex]
|
|
|
|
# Run FlexAttention with a score_mod that represents causal attention
|
|
def causal_score_mod(score, b, h, q_idx, kv_idx):
|
|
return torch.where(q_idx >= kv_idx, score, float("-inf"))
|
|
|
|
out_flex2 = flex_attention(query, key, value, score_mod=causal_score_mod)
|
|
grads_flex2 = torch.autograd.grad(
|
|
out_flex2, inputs=(query, key, value), grad_outputs=(grad_out,)
|
|
)
|
|
flex_outs2 = [out_flex2, *grads_flex2]
|
|
|
|
# Run causal SDPA for comparison
|
|
out_sdpa = F.scaled_dot_product_attention(query, key, value, is_causal=True)
|
|
grads_sdpa = torch.autograd.grad(
|
|
out_sdpa, inputs=(query, key, value), grad_outputs=(grad_out,)
|
|
)
|
|
sdpa_outs = [out_sdpa, *grads_sdpa]
|
|
|
|
# Compare flex vs. SDPA output and grads
|
|
for flex, flex2, sdpa in zip(flex_outs, flex_outs2, sdpa_outs):
|
|
self.assertTrue(flex.is_nested and flex2.is_nested and sdpa.is_nested)
|
|
self.assertEqual(flex, sdpa, atol=1e-2, rtol=1e-2)
|
|
self.assertEqual(flex2, sdpa, atol=1e-2, rtol=1e-2)
|
|
|
|
@onlyCUDA
|
|
@flex_attention_supported_platform
|
|
@dtypes(torch.float32)
|
|
def test_flex_attention_converts_stacked_seq_indices(self, device, dtype):
|
|
# This test verifies that a score_mod function written to operate within
|
|
# NJT sequence index space, such as a lookup table, works correctly. This
|
|
# validates that FlexAttention properly converts indices from the
# "stacked sequence" space used for NJT into sequence-relative indices.
|
|
query, key, value = self._rand_qkv(device, dtype)
|
|
|
|
# Test with score_mod
|
|
score_mod_table = torch.randn(query._max_seqlen, device=device, dtype=dtype)
|
|
|
|
def my_score_mod(score, b, h, q_idx, kv_idx):
|
|
return score_mod_table[q_idx]
|
|
|
|
flex_attention(query, key, value, score_mod=my_score_mod)
|
|
|
|
# Test with batch-specific score_mod
|
|
batch_size = query.size(0)
|
|
batch_table = torch.randn(batch_size, device=device, dtype=dtype)
|
|
# Keep score the same for batch index == 0
|
|
batch_table[0].zero_()
|
|
|
|
def batch_specific_score_mod(score, b, h, q_idx, kv_idx):
|
|
return score + batch_table[b]
|
|
|
|
def identity_score_mod(score, b, h, q_idx, kv_idx):
|
|
return score
|
|
|
|
output = flex_attention(query, key, value, score_mod=batch_specific_score_mod)
|
|
output_identity = flex_attention(
|
|
query, key, value, score_mod=identity_score_mod
|
|
)
|
|
|
|
# Guard against a bug where the batch index passed to score_mod is always b == 0.
|
|
# Output would be equivalent to applying an identity score_mod.
|
|
# See https://github.com/pytorch/pytorch/issues/143788
|
|
self.assertFalse(torch.allclose(output._values, output_identity._values))
|
|
|
|
# Test with mask_mod
|
|
mask_mod_table = score_mod_table > 0.0
|
|
|
|
def my_mask_mod(b, h, q_idx, kv_idx):
|
|
return mask_mod_table[q_idx]
|
|
|
|
def my_mask_mod2(b, h, q_idx, kv_idx):
|
|
return mask_mod_table[q_idx] & (b == 0)
|
|
|
|
block_mask = create_nested_block_mask(my_mask_mod, 1, 1, query, _compile=True)
|
|
output = flex_attention(query, key, value, block_mask=block_mask)
|
|
|
|
block_mask2 = create_nested_block_mask(my_mask_mod2, 1, 1, query, _compile=True)
|
|
output2 = flex_attention(query, key, value, block_mask=block_mask2)
|
|
|
|
# Guard against a bug where the batch index passed to mask_mod is always b == 0.
|
|
# See https://github.com/pytorch/pytorch/issues/143788
|
|
self.assertFalse(torch.allclose(output._values, output2._values))
|
|
|
|
@dtypes(torch.float32)
|
|
def test_apply_(self, device, dtype):
|
|
nt = random_nt_from_dims(
|
|
[5, None, 10],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=True,
|
|
)
|
|
|
|
def f(x):
|
|
return x * 2
|
|
|
|
if device != "cpu":
|
|
with self.assertRaisesRegex(
|
|
TypeError, "apply_ is only implemented on CPU tensors"
|
|
):
|
|
nt.apply_(f)
|
|
return
|
|
|
|
before = nt._values.detach().clone()
|
|
|
|
nt.apply_(f)
|
|
expected = f(before)
|
|
self.assertEqual(expected, nt._values)
|
|
# apply_ should swap values in-place without appending to autograd graph
|
|
self.assertIsNone(nt.grad)
|
|
self.assertIsNone(nt._values.grad_fn)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float64, torch.float32, torch.half)
|
|
@parametrize(
|
|
"contiguity",
|
|
["noncontig_transposed", "noncontig_with_holes"],
|
|
name_fn=lambda c: c,
|
|
)
|
|
def test_noncontiguous_to(self, device, dtype, contiguity):
|
|
# Dense tensors preserve non-contiguity through to() calls (i.e. strides are
|
|
# preserved). Test for the analogous behavior for NJTs:
|
|
# 1. non-contiguous transposed
|
|
# 2. non-contiguous with holes
|
|
if contiguity == "noncontig_transposed":
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
).transpose(-3, -2)
|
|
elif contiguity == "noncontig_with_holes":
|
|
nt = torch.nested.nested_tensor_from_jagged(
|
|
values=torch.randn(10, 3, device=device, dtype=dtype),
|
|
offsets=torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int64),
|
|
# these lengths specify holes
|
|
lengths=torch.tensor([1, 2, 3], device=device, dtype=torch.int64),
|
|
)
|
|
else:
|
|
raise ValueError("invalid contiguity specified for test_noncontiguous_to()")
|
|
|
|
# test dtype conversion
|
|
dtype_conversions = {
|
|
torch.float32: torch.half,
|
|
torch.float64: torch.float32,
|
|
torch.half: torch.float32,
|
|
}
|
|
other_dtype = dtype_conversions[dtype]
|
|
nt2 = nt.to(dtype=other_dtype)
|
|
self.assertEqual(nt2.dtype, other_dtype)
|
|
self.assertEqual(nt.is_contiguous(), nt2.is_contiguous())
|
|
self.assertEqual(nt._values.is_contiguous(), nt2._values.is_contiguous())
|
|
self.assertEqual(nt.shape, nt2.shape)
|
|
# expect no change for offsets / lengths
|
|
self.assertEqual(nt._offsets, nt2._offsets)
|
|
self.assertEqual(nt._lengths, nt2._lengths)
|
|
|
|
# test device conversion
|
|
other_device = torch.device("cpu")
|
|
nt3 = nt.to(device=other_device)
|
|
self.assertEqual(nt3.device, other_device)
|
|
self.assertEqual(nt.is_contiguous(), nt3.is_contiguous())
|
|
self.assertEqual(nt._values.is_contiguous(), nt3._values.is_contiguous())
|
|
self.assertEqual(nt.shape, nt3.shape)
|
|
# expect device change for offsets / lengths
|
|
self.assertEqual(nt3._offsets.device, other_device)
|
|
if nt._lengths is not None:
|
|
self.assertEqual(nt3._lengths.device, other_device)
|
|
|
|
@dtypes(torch.float32)
|
|
def test_autograd_function_with_None_grad(self, device, dtype):
|
|
class MyFunction(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, inp):
|
|
ctx.save_for_backward(inp)
|
|
out1 = inp + 1
|
|
out2 = inp * 2
|
|
return out1, out2
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_out1, grad_out2):
|
|
(inp,) = ctx.saved_tensors
|
|
return grad_out1 + grad_out2
|
|
|
|
f = MyFunction.apply
|
|
nt = random_nt_from_dims(
|
|
[5, None, 10],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=True,
|
|
)
|
|
|
|
# Only use one of the autograd.Function outputs downstream so that the grad
|
|
# for the other output is None. We're testing that the engine can allocate
|
|
# correctly-shaped (NJT) zeros for the grad of the other output in this case.
|
|
(out1, _) = f(nt)
|
|
out1.backward(torch.ones_like(out1))
|
|
|
|
@dtypes(torch.float64, torch.float32, torch.half)
|
|
def test_jagged_padded_dense_conversion_kernels(self, device, dtype):
|
|
values = torch.randn(10, 5, device=device, dtype=dtype)
|
|
offsets = torch.tensor([0, 1, 3, 8, 10], device=device, dtype=torch.int64)
|
|
max_length = offsets.diff().max().item()
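# offsets [0, 1, 3, 8, 10] give batch lengths [1, 2, 5, 2], so max_length == 5 here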
|
|
padding_value = 1.3
|
|
|
|
# convert jagged -> padded dense
|
|
padded = torch.ops.aten._jagged_to_padded_dense_forward(
|
|
values, [offsets], [max_length], padding_value
|
|
)
|
|
|
|
batch_size = offsets.shape[0] - 1
|
|
expected_padded_shape = (batch_size, max_length, values.shape[-1])
|
|
self.assertEqual(padded.shape, expected_padded_shape)
|
|
|
|
# convert padded dense -> jagged
|
|
total_L = values.shape[0]
|
|
output_jagged = torch.ops.aten._padded_dense_to_jagged_forward(
|
|
padded, [offsets], total_L
|
|
)
|
|
|
|
# should be equivalent to the original values
|
|
self.assertEqual(values, output_jagged)
|
|
|
|
# success case: truncate to max length as needed
|
|
trunc_max_length = max_length - 1
|
|
trunc_padded = torch.ops.aten._jagged_to_padded_dense_forward(
|
|
values, [offsets], [trunc_max_length], padding_value
|
|
)
|
|
self.assertEqual(padded[:, :trunc_max_length, :], trunc_padded)
|
|
|
|
# specific to CPU impls
|
|
if device == "cpu":
|
|
# error case: multiple offsets on CPU, since the CPU kernels currently support only a single jagged dim
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only a single jagged dim is supported"
|
|
):
|
|
torch.ops.aten._jagged_to_padded_dense_forward(
|
|
values, [offsets, offsets], [max_length, max_length], padding_value
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only a single jagged dim is supported"
|
|
):
|
|
torch.ops.aten._padded_dense_to_jagged_forward(
|
|
padded, [offsets, offsets], total_L
|
|
)
|
|
|
|
# error case: > 1D offsets
|
|
offsets2d = offsets.unsqueeze(-1)
|
|
with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
|
|
torch.ops.aten._jagged_to_padded_dense_forward(
|
|
values, [offsets2d], [max_length], padding_value
|
|
)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "expected 1D offsets"):
|
|
torch.ops.aten._padded_dense_to_jagged_forward(
|
|
padded, [offsets2d], total_L
|
|
)
|
|
|
|
# error case: final offset != total_L
|
|
offsets_wrong = offsets.detach().clone()
|
|
offsets_wrong[-1] = total_L + 1
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "final offset should match total_L value"
|
|
):
|
|
torch.ops.aten._padded_dense_to_jagged_forward(
|
|
padded, [offsets_wrong], total_L
|
|
)
|
|
|
|
# error case: 1D padded input
|
|
padded_wrong = padded.flatten().detach().clone()
|
|
with self.assertRaisesRegex(RuntimeError, "expected padded dim >= 2"):
|
|
torch.ops.aten._padded_dense_to_jagged_forward(
|
|
padded_wrong, [offsets], total_L
|
|
)
|
|
|
|
# error case: batch item has length > max length
|
|
# max_length is 5 above; 7 here
|
|
offsets_wrong = torch.tensor(
|
|
[0, 1, 8, 9, 10], device=device, dtype=torch.int64
|
|
)
|
|
with self.assertRaisesRegex(RuntimeError, "found batch item of length"):
|
|
torch.ops.aten._padded_dense_to_jagged_forward(
|
|
padded, [offsets_wrong], total_L
|
|
)
|
|
|
|
@dtypes(torch.float32)
|
|
@skipIfTorchDynamo("Test compiles internally")
|
|
@unittest.skipIf(
|
|
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
|
|
)
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@skipCUDAIfRocm
|
|
def test_compile_preserves_metadata_cache(self, device, dtype):
|
|
# shape (B, *, D)
|
|
nt = random_nt_from_dims(
|
|
[4, None, 3, 16],
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=True,
|
|
)
|
|
|
|
# expect min / max seqlen to be stored here
|
|
cache = dict(nt._metadata_cache)
|
|
|
|
@torch.compile
|
|
def f(nt):
|
|
q = nt.transpose(-3, -2)
|
|
output = F.scaled_dot_product_attention(q, q, q).transpose(-3, -2)
|
|
return output
|
|
|
|
output = f(nt)
|
|
output.backward(torch.ones_like(output))
|
|
self.assertEqual(output._metadata_cache, cache)
|
|
|
|
@dtypes(torch.float32)
|
|
@skipIfTorchDynamo("Test compiles internally")
|
|
@unittest.skipIf(
|
|
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
|
|
)
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@skipCUDAIfRocm
|
|
def test_compile_with_dynamic_max_seq_len(self, device, dtype):
|
|
# shape (B, *, D)
|
|
# max seq len: 18
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 5),
|
|
torch.randn(3, 5),
|
|
torch.randn(18, 5),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
# max seq len: 19
|
|
nt2 = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 5),
|
|
torch.randn(3, 5),
|
|
torch.randn(19, 5),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
def f(nt):
|
|
# TODO: Replace with public API when we can use @properties
|
|
return torch.ones_like(nt) * nt._get_max_seqlen()
|
|
|
|
for dynamic in [False, True, None]:
|
|
self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
|
|
|
|
@dtypes(torch.float32)
|
|
@skipIfTorchDynamo("Test compiles internally")
|
|
@unittest.skipIf(
|
|
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
|
|
)
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@skipCUDAIfRocm
|
|
def test_compile_with_dynamic_min_seq_len(self, device, dtype):
|
|
# shape (B, *, D)
|
|
# min seq len: 7
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(7, 5),
|
|
torch.randn(8, 5),
|
|
torch.randn(9, 5),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
# min seq len: 8
|
|
nt2 = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(8, 5),
|
|
torch.randn(9, 5),
|
|
torch.randn(10, 5),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
def f(nt):
|
|
# TODO: Replace with public API when we can use @properties
|
|
return torch.ones_like(nt) * nt._get_min_seqlen()
|
|
|
|
for dynamic in [False, True, None]:
|
|
self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
|
|
|
|
@dtypes(torch.float32)
|
|
@skipIfTorchDynamo("Test compiles internally")
|
|
@unittest.skipIf(
|
|
sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
|
|
)
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@skipCUDAIfRocm
|
|
def test_compile_with_propagated_dynamic_max_seq_len(self, device, dtype):
|
|
# shape (B, *, D)
|
|
# max seq len: 18
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 5),
|
|
torch.randn(3, 5),
|
|
torch.randn(18, 5),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
# max seq len: 19
|
|
nt2 = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 5),
|
|
torch.randn(3, 5),
|
|
torch.randn(19, 5),
|
|
],
|
|
layout=torch.jagged,
|
|
)
|
|
|
|
def f(nt):
|
|
nt2 = nt.sin() + 1
|
|
# TODO: Replace with public API when we can use @properties
|
|
return torch.ones_like(nt2) * nt2._get_max_seqlen()
|
|
|
|
ref = f(nt)
|
|
output = torch.compile(f, fullgraph=True, dynamic=False)(nt)
|
|
self.assertEqual(ref, output)
|
|
|
|
for dynamic in [False, True, None]:
|
|
self.assertFalse(_recompiles_for_inputs(f, (nt,), (nt2,), dynamic=dynamic))
|
|
|
|
def test_dropout_inference_mode(self, device):
|
|
seq_len = 32
|
|
embed_dim = 128
|
|
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(11, seq_len, embed_dim, device=device),
|
|
torch.randn(11, seq_len, embed_dim, device=device),
|
|
],
|
|
layout=torch.jagged,
|
|
device=device,
|
|
)
|
|
|
|
with torch.inference_mode():
|
|
torch.nn.functional.dropout(nt, p=0.05)
|
|
|
|
@dtypes(torch.float32, torch.double, torch.half)
|
|
def test_unbind_backward(self, device, dtype):
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randn(2, 4, device=device),
|
|
torch.randn(5, 4, device=device),
|
|
torch.randn(3, 4, device=device),
|
|
],
|
|
layout=torch.jagged,
|
|
requires_grad=True,
|
|
)
|
|
|
|
a, b, c = nt.unbind()
|
|
b.sum().backward()
|
|
|
|
@torch._dynamo.disable
|
|
def check(nt):
|
|
expected_grad = torch.zeros_like(nt)
|
|
expected_grad.unbind()[1].add_(1.0)
|
|
self.assertEqual(nt.grad, expected_grad)
|
|
|
|
check(nt)
|
|
|
|
@dtypes(torch.float32, torch.double, torch.half, torch.bool)
|
|
@parametrize("nt_dim", [2, 3, 4])
|
|
@parametrize("requires_grad", [False, True])
|
|
def test_to_padded_tensor(self, device, dtype, nt_dim, requires_grad):
|
|
if dtype is torch.bool and requires_grad:
|
|
# grads not supported for bool
|
|
return
|
|
|
|
if nt_dim == 2:
|
|
post_seq_len_shape = ()
|
|
elif nt_dim == 3:
|
|
post_seq_len_shape = (10,)
|
|
elif nt_dim == 4:
|
|
post_seq_len_shape = (9, 10)
|
|
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
|
|
if dtype is torch.bool
|
|
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
|
for n in range(2, 9)
|
|
],
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
PADDING_VAL = 4.2
|
|
expected_padded = nt._values.new_full((7, 8, *post_seq_len_shape), PADDING_VAL)
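# 7 components with lengths 2..8 -> padded batch dim of 7 and max seq len of 8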
|
|
for i, component in enumerate(nt.unbind()):
|
|
expected_padded[i, : component.shape[0]].copy_(component)
|
|
|
|
padded = nt.to_padded_tensor(PADDING_VAL)
|
|
self.assertEqual(expected_padded, padded)
|
|
|
|
# convert padded dense -> NJT
|
|
from torch.nested._internal.nested_tensor import nested_from_padded
|
|
|
|
nt2 = nested_from_padded(padded, nt.offsets())
|
|
self.assertEqual(nt, nt2)
|
|
|
|
if requires_grad and dtype is not torch.bool:
|
|
# ensure gradients flow through conversions
|
|
nt2.backward(torch.ones_like(nt2))
|
|
self.assertEqual(nt.grad, torch.ones_like(nt))
|
|
|
|
# blows up due to test parametrization otherwise
|
|
@torch._dynamo.utils.disable_cache_limit()
|
|
@skipIfTorchDynamo("SDPA test compiles internally")
|
|
@unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
|
|
@skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
|
|
@skipCUDAIfRocm
|
|
@dtypes(torch.float32, torch.double, torch.half)
|
|
@parametrize("nt_dim", [2, 3, 4])
|
|
@parametrize("requires_grad", [False, True])
|
|
def test_to_padded_tensor_compile(self, device, dtype, nt_dim, requires_grad):
|
|
if dtype is torch.bool and requires_grad:
|
|
# grads not supported for bool
|
|
return
|
|
|
|
if nt_dim == 2:
|
|
post_seq_len_shape = ()
|
|
elif nt_dim == 3:
|
|
post_seq_len_shape = (10,)
|
|
elif nt_dim == 4:
|
|
post_seq_len_shape = (9, 10)
|
|
|
|
nt = torch.nested.nested_tensor(
|
|
[
|
|
torch.randint(2, (n, *post_seq_len_shape), device=device, dtype=dtype)
|
|
if dtype is torch.bool
|
|
else torch.randn(n, *post_seq_len_shape, device=device, dtype=dtype)
|
|
for n in range(2, 9)
|
|
],
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad,
|
|
)
|
|
|
|
def f(x):
|
|
return x.sin() + 1
|
|
|
|
from torch.nested._internal.nested_tensor import nested_from_padded
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def g(nt):
|
|
def _g(nt):
|
|
PADDING_VAL = 4.2
|
|
padded = nt.to_padded_tensor(PADDING_VAL)
|
|
padded = f(padded)
|
|
# NB: sum_S must be specified to use the lowering for dense -> jagged
|
|
# and get full fusion
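# (sum_S is the total number of jagged rows, i.e. nt.values().shape[0], which for a contiguous NJT equals offsets[-1])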
|
|
return nested_from_padded(
|
|
padded, nt.offsets(), sum_S=nt.values().shape[0]
|
|
)
|
|
|
|
# NB: use checkpointing to force fusion
|
|
return torch.utils.checkpoint.checkpoint(_g, nt, use_reentrant=False)
|
|
|
|
expected_output = f(nt)
|
|
if requires_grad:
|
|
expected_output.backward(torch.ones_like(expected_output))
|
|
expected_grad = nt.grad.detach().clone()
|
|
nt.grad = None
|
|
|
|
from torch._inductor.utils import run_and_get_code
|
|
|
|
compiled_output, generated_code = run_and_get_code(g, nt)
|
|
if requires_grad:
|
|
compiled_output.backward(torch.ones_like(compiled_output))
|
|
compiled_grad = nt.grad.detach().clone()
|
|
self.assertEqual(compiled_grad, expected_grad, rtol=1e-3, atol=1e-3)
|
|
|
|
self.assertEqual(compiled_output, expected_output, rtol=1e-3, atol=1e-3)
|
|
|
|
# === Verify that computation fusion happens. ===
|
|
# Fallback op call -> fusion didn't happen.
|
|
fallback_op_calls_present = any(
|
|
"torch.ops.aten._padded_dense_to_jagged_forward.default("
|
|
in generated_code[i]
|
|
or "torch.ops.aten._jagged_to_padded_dense_forward.default("
|
|
in generated_code[i]
|
|
for i in range(len(generated_code))
|
|
)
|
|
|
|
# NB: Fusion isn't supported on CPU.
|
|
self.assertEqual("cuda" in device, not fallback_op_calls_present)
|
|
|
|
for i in range(len(generated_code)):
|
|
# Examine buffer construction lines in the generated code to determine
|
|
# whether fusion occurred. If fusion happens, a 3D buffer with shape
|
|
# (B, max_seqlen, D) should never be materialized.
|
|
buffer_constructions = [
|
|
line.strip()
|
|
for line in generated_code[i].split("\n")
|
|
if "empty_strided_cuda(" in line
|
|
]
|
|
|
|
buffer_dims = [
|
|
# buffer dim == number of elements in the tensor size tuple arg
|
|
len(ast.parse(t).body[0].value.args[0].elts)
|
|
for t in buffer_constructions
|
|
]
|
|
|
|
if "cuda" in device:
|
|
self.assertFalse(any(d == 3 for d in buffer_dims))
|
|
|
|
@dtypes(torch.float32)
|
|
@skipIfTorchDynamo("Test compiles internally")
|
|
    @unittest.skipIf(
        sys.version_info >= (3, 12), "torch.compile is not supported on python 3.12+"
    )
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @skipCUDAIfRocm
    def test_compile_padded_dense_conversion_preserves_metadata_cache(
        self, device, dtype
    ):
        # shape (B, *, D)
        nt = random_nt_from_dims(
            [4, None, 3, 16],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        # expect min / max seqlen to be stored here
        cache = dict(nt._metadata_cache)

        @torch.compile
        def g(nt):
            padded = nt.to_padded_tensor(0.3)
            intermediate = padded.sin() + 1

            from torch.nested._internal.nested_tensor import nested_from_padded

            return nested_from_padded(
                intermediate,
                nt.offsets(),
                min_seqlen=nt._min_seqlen,
                max_seqlen=nt._max_seqlen,
                sum_S=nt.values().shape[0],
            )

        output = g(nt)
        output.backward(torch.ones_like(output))
        self.assertEqual(output._metadata_cache, cache)

    # See https://github.com/pytorch/pytorch/issues/128649
    @dtypes(torch.float32)
    def test_composite_op_in_inference_mode(self, device, dtype):
        # expect view
        nt = random_nt_from_dims(
            [4, None, 48],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )

        with torch.inference_mode():
            output = nt.reshape([4, -1, 3, 16])
            self.assertEqual(output.shape, (4, nt.shape[1], 3, 16))
            self.assertTrue(output._is_view())

        # expect copy
        nt = random_nt_from_dims(
            [4, None, 3, 16],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        ).transpose(-1, -2)

        with torch.inference_mode():
            output = nt.reshape([4, -1, 48])
            self.assertEqual(output.shape, (4, nt.shape[1], 48))
            self.assertFalse(output._is_view())

    @dtypes(torch.float32)
    def test_composite_op_with_custom_mode(self, device, dtype):
        from torch.utils._python_dispatch import TorchDispatchMode

        # simple passthrough TorchDispatchMode
        class CustomDispatchMode(TorchDispatchMode):
            def __torch_dispatch__(self, func, types, args=..., kwargs=None):
                return func(*args, **kwargs)

        nt = random_nt_from_dims(
            [4, None, 2, 3],
            device=device,
            dtype=dtype,
            layout=torch.jagged,
            requires_grad=True,
        )
        with CustomDispatchMode():
            res = nt.reshape(4, -1, 6)

        self.assertEqual(res.shape, (4, nt.shape[1], 6))

    @skipIfTorchDynamo("compiles internally")
    @unittest.skipIf(IS_WINDOWS, reason="Windows not yet supported for torch.compile")
    @skipCUDAIf(not SM70OrLater, "GPU capability is < SM70")
    @dtypes(torch.float32)
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    def test_broadcast_shapes_on_in_graph_constructed_njt(self, device, dtype):
        # Tests that a guard isn't wrongly installed on a freshly-created nested int when
        # broadcast_shapes() is used on NJT shapes.
        # See https://github.com/pytorch/pytorch/issues/145874 for more context.
        nt = torch.nested.nested_tensor(
            [
                torch.randn(2),
                torch.randn(3),
                torch.randn(4),
            ],
            layout=torch.jagged,
            device=device,
            dtype=dtype,
        )

        values = nt._values.detach().clone()
        offsets = nt._offsets.detach().clone()

        @torch.compile(fullgraph=True)
        def f(values, offsets):
            nt = torch.nested.nested_tensor_from_jagged(values, offsets)
            # NB: torch.where() utilizes broadcast_shapes() underneath
            return torch.where(nt > 0.0, torch.ones_like(nt), torch.zeros_like(nt))

        output = f(values, offsets)
        self.assertTrue(output.is_nested)
        self.assertEqual(nt.shape[:-1], output.shape[:-1])
        for nt_component, output_component in zip(nt.unbind(), output.unbind()):
            self.assertEqual(nt_component.shape, output_component.shape)


# The following lists specify skips and xfails for particular SampleInputs. Note that
# rules are matched from top to bottom and at most one will be matched, so order
# matters! (See the illustrative sketch after the first list below.) The guiding
# general principle here should be one xfail / skip per bug if at all possible :)
FORWARD_SKIPS_AND_XFAILS = [
    # not implemented
    XFailRule(
        error_type=NotImplementedError,
        op_match_fn=lambda device, op: op.full_name
        in {
            # unary
            # needs log_sigmoid_forward, which returns a tuple
            "nn.functional.logsigmoid",
            "nn.functional.prelu",
            # needs rrelu_with_noise
            "nn.functional.rrelu",
            # binary
            "__rsub__",
            "complex",
            "floor_divide",
            "polar",
            "rsub",
            # reduction
            "count_nonzero",
            "linalg.vector_norm",
            "nansum",
            "std",
            "std.unbiased",
            "var",
            "var.unbiased",
            "hash_tensor",
        },
        name="not_implemented",
    ),
    # expected: torch.where() support has some limitations
    # 1. condition must be an NJT
    # 2. no dense tensors of higher dim than the NJT
    XFailRule(
        error_type=ValueError,
        error_msg="expected condition to be a jagged layout NestedTensor",
        op_match_fn=lambda device, op: op.full_name == "where",
        sample_match_fn=lambda device, sample: not sample.kwargs["condition"].is_nested,
    ),
    XFailRule(
        error_type=ValueError,
        error_msg="broadcasting nested tensors with dense tensors of equal or higher dim",
        op_match_fn=lambda device, op: op.full_name == "where",
        sample_match_fn=lambda device, sample: (
            (
                not sample.input.is_nested
                and sample.input.dim() >= sample.kwargs["condition"].dim()
            )
            or (
                not sample.kwargs["other"].is_nested
                and sample.kwargs["other"].dim() >= sample.kwargs["condition"].dim()
            )
        ),
    ),
    # expected: masked ops don't support jagged layout
    XFailRule(
        error_type=ValueError,
        error_msg="expects strided",
        op_match_fn=lambda device, op: op.full_name
        in {
            "masked.amax",
            "masked.amin",
            "masked.argmax",
            "masked.argmin",
            "masked.logsumexp",
            "masked.mean",
            "masked.norm",
            "masked.prod",
            "masked.std",
            "masked.sum",
            "masked.var",
        },
        name="no_masked_jagged_support",
    ),
    # Op doesn't support lengths being present
    XFailRule(
        error_type=ValueError,
        error_msg="expected input to be a contiguous jagged layout NestedTensor",
        op_match_fn=lambda device, op: (op.full_name == "nn.functional.linear"),
        sample_match_fn=lambda device, sample: (sample.input._lengths is not None),
        name="no_linear_noncontig_holes_support",
    ),
    # nanmean sometimes hits an unimplemented nansum() path and other times hits an
    # unimplemented sum() path
    XFailRule(
        error_type=NotImplementedError,
        op_match_fn=lambda device, op: (op.full_name == "nanmean"),
        sample_match_fn=lambda device, sample: (
            not (
                "noncontig_holes" in sample.name
                and "dim" in sample.kwargs
                and (
                    (
                        isinstance(sample.kwargs["dim"], int)
                        and sample.kwargs["dim"] == sample.input._ragged_idx
                    )
                    or (
                        isinstance(sample.kwargs["dim"], (tuple, list))
                        and sample.input._ragged_idx in sample.kwargs["dim"]
                    )
                )
            )
        ),
        name="nansum_unimplemented",
    ),
    # expected: reducing across the ragged dimension is not supported for non-contiguous
    # nested tensors with holes
    XFailRule(
        error_type=RuntimeError,
        error_msg=(
            "reducing across the ragged dimension is not supported for non-contiguous "
            "nested tensors with holes"
        ),
        op_match_fn=lambda device, op: (
            # min.reduction_with_dim and max.reduction_with_dim aren't associated with
            # ReductionOpInfo entries sadly even though they're reductions
            isinstance(op, ReductionOpInfo) or "reduction_with_dim" in op.full_name
        ),
        sample_match_fn=lambda device, sample: (
            "noncontig_holes" in sample.name
            and "dim" in sample.kwargs
            and (
                (
                    isinstance(sample.kwargs["dim"], int)
                    and sample.kwargs["dim"] == sample.input._ragged_idx
                )
                or (
                    isinstance(sample.kwargs["dim"], (tuple, list))
                    and sample.input._ragged_idx in sample.kwargs["dim"]
                )
            )
        ),
        name="ragged_dim_reduction_noncontig_holes",
    ),
    # expected: index_put() doesn't work on non-contiguous NJTs without ragged dimension indices
    XFailRule(
        error_type=RuntimeError,
        error_msg="If ragged dimension is not part of indices, this only works on contiguous NJTs",
        op_match_fn=lambda device, op: (op.full_name == "index_put"),
        sample_match_fn=lambda device, sample: (
            not sample.input.is_contiguous()
            and len(sample.kwargs["indices"]) - 1 < sample.input._ragged_idx
        ),
        name="index_put_noncontig_holes_no_ragged_dim_indices",
    ),
    # select() only supports dim=0 for non-contiguous with holes NJTs for now
    XFailRule(
        op_match_fn=lambda device, op: (op.full_name == "select"),
        sample_match_fn=lambda device, sample: (
            sample.kwargs["dim"] != 0 and "noncontig_holes" in sample.name
        ),
        name="unsupported_select_on_non_batch_dim_with_noncontig_holes",
    ),
    # these don't work on non-contiguous NJTs yet
    XFailRule(
        error_type=ValueError,
        error_msg="expected self to be a contiguous jagged layout NestedTensor",
        op_match_fn=lambda device, op: (
            op.full_name
            in {
                "chunk",
                "masked_select",
                "narrow",
                "split",
                "split_with_sizes",
                "squeeze",
            }
        ),
        sample_match_fn=lambda device, sample: (
            sample.input._lengths is not None or sample.input._ragged_idx != 1
        ),
        name="missing_noncontig_support",
    ),
    # these don't work on the ragged dim yet
    XFailRule(
        error_type=RuntimeError,
        error_msg="not supported for NestedTensor on ragged dim",
        op_match_fn=lambda device, op: (
            op.full_name
            in {
                "chunk",
                "narrow",
                "select",
                "split",
            }
        ),
        sample_match_fn=lambda device, sample: "ragged_dim" in sample.name,
        name="ragged_dim_unsupported",
    ),
    XFailRule(
        error_type=RuntimeError,
        # error comes from usage of view() in the decomp
        error_msg="does not support ragged_idx != 1 except when",
        op_match_fn=lambda device, op: (op.full_name == "unflatten"),
        sample_match_fn=lambda device, sample: "noncontig_transposed" in sample.name,
        name="unflatten_ragged_dim_unsupported",
    ),
    # these don't work on the batch dim yet
    XFailRule(
        error_type=RuntimeError,
        error_msg="not supported for NestedTensor on dim=0",
        op_match_fn=lambda device, op: (
            op.full_name
            in {
                "narrow",
                "split",
                "split_with_sizes",
                "unsqueeze",
            }
        ),
        sample_match_fn=lambda device, sample: "batch_dim" in sample.name,
        name="batch_dim_unsupported",
    ),
    XFailRule(
        error_type=RuntimeError,
        # error comes from usage of view() in the decomp
        error_msg="cannot view shape",
        op_match_fn=lambda device, op: (op.full_name == "unflatten"),
        sample_match_fn=lambda device, sample: "batch_dim" in sample.name,
        name="unflatten_batch_dim_unsupported",
    ),
    # expected: bmm / matmul sometimes use a to_padded_tensor() fallback which isn't
    # supported for non-contig NJTs with holes
    XFailRule(
        error_type=RuntimeError,
        error_msg="not supported for nested tensors with holes",
        op_match_fn=lambda device, op: (op.full_name in {"bmm", "matmul"}),
        sample_match_fn=lambda device, sample: (
            "noncontig_holes" in sample.name
            # "other" is the name for the matmul arg and "mat2" is the name for the bmm arg
            and sample.input.dim()
            == sample.kwargs.get("other", sample.kwargs.get("mat2")).dim()
        ),
        name="mm_noncontig_holes",
    ),
    # some jiterator op failures due to unsupported jagged layout
    XFailRule(
        error_type=RuntimeError,
        error_msg="unsupported tensor layout",
        op_match_fn=lambda device, op: op.full_name
        in {
            "jiterator_binary",
            "jiterator_binary_return_by_ref",
            "jiterator_unary",
        },
        name="no_jiterator_jagged_support",
    ),
    # Bug when broadcasting a binary op with non-contiguous with holes NJT + dense
    # tensor with 1 in ragged dim.
    XFailRule(
        error_type=RuntimeError,
        error_msg="cannot call binary pointwise function .* with inputs of shapes",
        op_match_fn=lambda device, op: (isinstance(op, BinaryUfuncInfo)),
        sample_match_fn=lambda device, sample: (
            "noncontig_holes" in sample.name
            and "broadcasting 1 over ragged" in sample.name
        ),
        name="binary_noncontig_holes_broadcasting_1_over_ragged",
    ),
]
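

# Illustrative sketch only (hypothetical helper, not used by the tests in this file):
# the skip / xfail machinery applied via @sample_skips_and_xfails below is expected to
# pick the first rule whose matchers both accept the (op, sample) pair, which is why
# rule order within these lists matters. getattr() is used since some rules omit a
# matcher; the real selection logic lives in sample_skips_and_xfails().
def _example_first_matching_rule(rules, device, op, sample):
    """Hedged illustration of first-match-wins selection over XFailRule / SkipRule lists."""
    for rule in rules:
        op_match_fn = getattr(rule, "op_match_fn", None)
        sample_match_fn = getattr(rule, "sample_match_fn", None)
        if op_match_fn is not None and not op_match_fn(device, op):
            continue
        if sample_match_fn is not None and not sample_match_fn(device, sample):
            continue
        # the first rule matching both the op and the specific SampleInput wins
        return rule
    return None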


BACKWARD_SKIPS_AND_XFAILS = [
    # segfaults, so skip. It's trying to use the NST logic for NJT
    SkipRule(
        op_match_fn=lambda device, op: op.full_name == "split_with_sizes",
        name="split_with_sizes_backward_segfault",
    ),
    *FORWARD_SKIPS_AND_XFAILS,
    # Backwards is generally broken for non-contiguous NJTs with holes. Rather than
    # determine the exceptions in detail, just skip for now. Fix is to ensure
    # that summing over gradients during backwards after broadcasting takes into
    # account holes / lengths.
    SkipRule(
        op_match_fn=lambda device, op: (
            isinstance(op, BinaryUfuncInfo)
            or op.full_name in {"mean", "where", "unsqueeze"}
        ),
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="broken_noncontig_holes_backward",
    ),
    # mean(): need to examine backwards formula
    XFailRule(
        error_type=RuntimeError,
        error_msg="SymIntArrayRef expected to contain only concrete integers",
        op_match_fn=lambda device, op: (op.full_name in {"mean"}),
        sample_match_fn=lambda device, sample: (
            "full reduction" not in sample.name
            and "normal dim reduction" not in sample.name
        ),
        name="broken_mean_backward",
    ),
    # RuntimeError: expand(): cannot expand shape (3, 3, 1, j44) -> [3, 3, 7, j44]
    # with noncontig transposed inputs to mean()
    XFailRule(
        error_type=RuntimeError,
        error_msg="cannot expand shape",
        op_match_fn=lambda device, op: (op.full_name == "mean"),
        sample_match_fn=lambda device, sample: (
            "normal dim reduction" in sample.name
            and "noncontig_transposed" in sample.name
        ),
        name="broken_mean_backward2",
    ),
    # unsqueeze() backward tries to call squeeze with noncontig transposed,
    # but that's not supported
    XFailRule(
        error_type=ValueError,
        error_msg="expected self to be a contiguous jagged layout NestedTensor",
        op_match_fn=lambda device, op: (op.full_name == "unsqueeze"),
        sample_match_fn=lambda device, sample: (
            "noncontig_transposed" in sample.name or "ragged_dim" in sample.name
        ),
        name="broken_unsqueeze_backward",
    ),
    # RuntimeError: view(): cannot view shape (3, j62, 1, 7, 3) as [3, j58, 7, 3]
    # with unflatten()
    XFailRule(
        error_type=RuntimeError,
        error_msg="cannot view shape",
        op_match_fn=lambda device, op: (op.full_name in {"unflatten"}),
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="broken_unflatten_backward",
    ),
    # sum() backward is not implemented for non-full reductions
    XFailRule(
        error_type=NotImplementedError,
        error_msg="aten._nested_sum_backward.default",
        op_match_fn=lambda device, op: (op.full_name == "sum"),
        sample_match_fn=lambda device, sample: ("full reduction" not in sample.name),
        name="broken_sum_backward",
    ),
    # squeeze(): invalid gradient shape; need to check formula
    XFailRule(
        error_type=RuntimeError,
        error_msg="returned an invalid gradient at index 0",
        op_match_fn=lambda device, op: (op.full_name == "squeeze"),
        sample_match_fn=lambda device, sample: (
            sample.name == "5D_contig_with_seqlen_cache: normal_dim"
            and sample.kwargs["dim"] == 3
        ),
        name="broken_squeeze_backward",
    ),
    # sgn() / masked_select(): backwards formulas don't work at all
    XFailRule(
        error_type=RuntimeError,
        error_msg="NestedTensor does not support directly calling torch.ops.aten.size",
        op_match_fn=lambda device, op: (op.full_name in {"sgn", "masked_select"}),
        name="broken_sgn_masked_select_backward",
    ),
    # select(): grad_output is an NJT for non-batch-dim operation
    XFailRule(
        error_type=ValueError,
        error_msg="expected grad_output to be a tensor",
        op_match_fn=lambda device, op: (op.full_name == "select"),
        sample_match_fn=lambda device, sample: ("batch_dim" not in sample.name),
        name="broken_select_backward",
    ),
    # prod(): completely broken in every way
    XFailRule(
        op_match_fn=lambda device, op: (op.full_name == "prod"),
        name="broken_prod_backward",
    ),
    # pow() / float_power(): use where() underneath; broken for (NT, T) broadcasting cases
    XFailRule(
        error_type=ValueError,
        error_msg="expected condition to be a jagged layout NestedTensor",
        op_match_fn=lambda device, op: (op.full_name in {"pow", "float_power"}),
        sample_match_fn=lambda device, sample: ("(NT, T)" in sample.name),
        name="broken_pow_backward",
    ),
    # __rpow__() backward is also broken, but for the reverse (T, NT) broadcasting cases
    XFailRule(
        error_type=ValueError,
        error_msg="expected condition to be a jagged layout NestedTensor",
        op_match_fn=lambda device, op: (op.full_name == "__rpow__"),
        sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
        name="broken_rpow_backward",
    ),
    # linear(): some formula problem when bias is used; seems to be platform-specific
    # (fails locally but not in CI)
    SkipRule(
        # result2.use_count() <= 1 INTERNAL ASSERT FAILED
        op_match_fn=lambda device, op: (op.full_name == "nn.functional.linear"),
        sample_match_fn=lambda device, sample: ("with bias" in sample.name),
        name="broken_linear_backward",
    ),
    # narrow(): unimplemented backward
    XFailRule(
        error_type=RuntimeError,
        error_msg="derivative for aten::narrow is not implemented",
        op_match_fn=lambda device, op: (op.full_name == "narrow"),
        name="broken_narrow_backward",
    ),
    # min / max: need factory function support for ragged dim reductions
    # where the output is dense but sizes still contain a nested int
    XFailRule(
        error_type=RuntimeError,
        error_msg="SymIntArrayRef expected to contain only concrete integers",
        op_match_fn=lambda device, op: (
            op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
        ),
        sample_match_fn=lambda device, sample: ("ragged dim" in sample.name),
        name="broken_min_max_reduction_with_dim_backward_on_ragged_dim",
    ),
    # copysign(): formula is broken for (T, NT) broadcasting
    XFailRule(
        error_type=RuntimeError,
        error_msg="SymIntArrayRef expected to contain only concrete integers",
        op_match_fn=lambda device, op: (op.full_name == "copysign"),
        sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
        name="broken_copysign_backward",
    ),
    # amin() / amax(): broken in a host of ways; I don't think it's a good use of time
    # to try to sift through them all
    SkipRule(
        op_match_fn=lambda device, op: (op.full_name in {"amin", "amax"}),
        name="broken_amin_amax_backward",
    ),
    XFailRule(
        error_type=RuntimeError,
        error_msg="reducing across the ragged dimension is not supported for non-contiguous",
        op_match_fn=lambda device, op: (
            isinstance(op, BinaryUfuncInfo)
            # doesn't happen for these ops for some reason
            and op.full_name
            not in {"copysign", "max.binary", "maximum", "min.binary", "minimum"}
        ),
        sample_match_fn=lambda device, sample: (
            "(NT, T) broadcasting all 1s" in sample.name
            and "noncontig_holes" in sample.name
        ),
        name="binary_noncontig_holes_ragged_dim_reduction",
    ),
    XFailRule(
        error_type=RuntimeError,
        error_msg="reducing across the ragged dimension is not supported for non-contiguous",
        op_match_fn=lambda device, op: (op.full_name == "nn.functional.rms_norm"),
        sample_match_fn=lambda device, sample: (sample.input._lengths is not None),
        name="rms_norm_noncontig_holes_ragged_dim_reduction",
    ),
    # expected: autodiff on complex dtype is not supported
    XFailRule(
        error_type=RuntimeError,
        error_msg=(
            "_nested_view_from_jagged does not support automatic differentiation "
            "for outputs with complex dtype"
        ),
        op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}),
        name="no_complex_autodiff",
    ),
    # Bug: need to use the correct nested int in the return shape
    XFailRule(
        error_type=RuntimeError,
        error_msg="Function CloneBackward0 returned an invalid gradient",
        op_match_fn=lambda device, op: (op.full_name == "clone"),
        sample_match_fn=lambda device, sample: (
            sample.kwargs.get("memory_format", None) == torch.contiguous_format
        ),
        name="clone_wrong_nested_int_for_gradient",
    ),
    # some min / max ops use masked_fill_ underneath sometimes, which isn't implemented
    XFailRule(
        error_type=NotImplementedError,
        error_msg="aten.masked_fill_.Scalar",
        op_match_fn=lambda device, op: (
            op.full_name
            in {"max.binary", "min.binary", "minimum", "maximum", "copysign"}
        ),
        name="unimplemented_masked_fill",
    ),
]

COMPILE_FORWARD_SKIPS_AND_XFAILS = [
    *FORWARD_SKIPS_AND_XFAILS,
    # Bug: cross-device conversions with to() result in new nested ints within compile only
    XFailRule(
        error_type=AssertionError,
        error_msg="The values for attribute 'shape' do not match",
        op_match_fn=lambda device, op: (op.full_name == "to"),
        sample_match_fn=lambda device, sample: ("-> cpu" in sample.name),
        name="cross_device_transfer_wrong_nested_int_in_compile",
    ),
    # clone() -> preserve format on a non-contiguous NJT with holes currently uses
    # unbind(), leading to a data-dependent expression. Should be fixed via torch._check()
    XFailRule(
        error_type=torch._dynamo.exc.Unsupported,
        # Ne(u1, u0) (unhinted: Ne(u1, u0)). (Size-like symbols: u1, u0)
        error_msg="Could not guard on data-dependent expression",
        op_match_fn=lambda device, op: (op.full_name == "clone"),
        sample_match_fn=lambda device, sample: (
            "noncontig_holes" in sample.name
            and sample.kwargs.get("memory_format", None) == torch.contiguous_format
        ),
        name="clone_unbind_data_dependency",
    ),
    # chunk(): broken in several ways on the batch dim; revisit after similar
    # data-dependency issues are handled for narrow()
    SkipRule(
        op_match_fn=lambda device, op: (op.full_name == "chunk"),
        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
        name="broken_chunk_compile_backward_on_batch_dim",
    ),
    # select on batch dim currently uses unbind(), leading to a data-dependent error in
    # torch.compile that needs to be addressed via torch._check()
    XFailRule(
        error_type=torch._dynamo.exc.InternalTorchDynamoError,
        error_msg="Pending unbacked symbols",
        op_match_fn=lambda device, op: (op.full_name == "select"),
        sample_match_fn=lambda device, sample: ("batch_dim" in sample.name),
        name="broken_select_backward_unbacked",
    ),
]

COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
    # non-contiguous with holes inputs + torch.compile doesn't work great today; need
    # torch._check() statements. Skip these and handle them later.
    SkipRule(
        op_match_fn=lambda device, op: True,
        sample_match_fn=lambda device, sample: ("noncontig_holes" in sample.name),
        name="noncontig_holes_data_dependency",
    ),
    # mean(): weird bug
    XFailRule(
        error_type=torch._dynamo.exc.BackendCompilerFailed,
        error_msg="'NestedIntNode' object has no attribute 'sub'",
        op_match_fn=lambda device, op: (op.full_name == "mean"),
        sample_match_fn=lambda device, sample: (
            "full reduction" not in sample.name
            and "normal dim reduction" not in sample.name
        ),
        name="broken_mean_compile_backward",
    ),
    # min() / max(): weird bug
    XFailRule(
        error_type=AttributeError,
        error_msg="'ConstantIntNode' object has no attribute 'add'",
        op_match_fn=lambda device, op: (
            op.full_name in {"max.reduction_with_dim", "min.reduction_with_dim"}
        ),
        sample_match_fn=lambda device, sample: ("ragged dim" in sample.name),
        name="broken_min_max_compile_backward",
    ),
    # to() fails with data-dependent guards OR Unknown layout in record_stream_any_impl;
    # need to fix with torch._check(), etc.
    XFailRule(
        op_match_fn=lambda device, op: (op.full_name == "to"),
        sample_match_fn=lambda device, sample: ("-> cpu" in sample.name),
        name="to_data_dependency",
    ),
    # copysign(): formula is broken for (T, NT) broadcasting
    XFailRule(
        error_type=AttributeError,
        error_msg="'ConstantIntNode' object has no attribute 'add'",
        op_match_fn=lambda device, op: (op.full_name == "copysign"),
        sample_match_fn=lambda device, sample: ("(T, NT)" in sample.name),
        name="broken_copysign_compile_backward",
    ),
    # in compile, these complex ops use view_as_real(), which isn't implemented
    XFailRule(
        error_type=NotImplementedError,
        error_msg="aten.view_as_real.default",
        op_match_fn=lambda device, op: (op.full_name in {"cdouble", "cfloat", "chalf"}),
        name="unimplemented_view_as_real",
    ),
    *COMPILE_FORWARD_SKIPS_AND_XFAILS,
    *BACKWARD_SKIPS_AND_XFAILS,
]

COMPARE_TENSOR_COMPONENT_EQUALITY = {
    # masked_select is expected to output a different shape
    "masked_select",
}
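# (For ops in this set, test_compile_forward / test_compile_backward compare outputs via
# assertEqualIgnoringNestedInts() rather than an exact assertEqual(); see the checks below.)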


# OpInfo-based NJT tests. These tests utilize an NJT-specific op_db generated from the standard
# op_db. Note that certain tradeoffs were made wrt coverage vs. time spent running tests:
# * All tests run with dtype=torch.float32 only
class TestNestedTensorOpInfo(NestedTensorTestCase):
    # TODO: move this
    def _gen_grad_outputs(self, out_val):
        if isinstance(out_val, (list, tuple)):
            need_grad_outs = tuple(o for o in out_val if o.grad_fn is not None)
            grad_outputs = tuple(
                torch.ones_like(o) for o in out_val if o.grad_fn is not None
            )
            return need_grad_outs, grad_outputs
        else:
            return out_val, (torch.ones_like(out_val),)
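
    # Illustrative usage of _gen_grad_outputs() (hedged sketch; the real calls appear in
    # test_backward / test_compile_backward below). For tuple-returning ops, only outputs
    # with a grad_fn get a matching grad_output:
    #
    #     need_grad_outs, grad_outputs = self._gen_grad_outputs(out)
    #     torch.autograd.grad(need_grad_outs, inputs=g_inps, grad_outputs=grad_outputs)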

    @ops(
        [op for op in njt_op_db if op.supports_njt],
        allowed_dtypes=(torch.float32,),
    )
    @tf32_on_and_off(0.005)
    @sample_skips_and_xfails(FORWARD_SKIPS_AND_XFAILS)
    def test_forward(self, device, dtype, op):
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device,
            dtype=dtype,
            requires_grad=False,
            use_subtests=True,
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                # compare to reference, but expect different nested int
                out = op.op(sample.input, *sample.args, **sample.kwargs)
                out_ref = op.ref(op, sample)
                self.assertEqualIgnoringNestedInts(out, out_ref)
                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out
                    )

                # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
                # TODO: Add xfails for other inplace ops instead of hardcoding
                if op.inplace_variant and "index_put" in op.full_name:
                    op.inplace_variant(sample.input, *sample.args, **sample.kwargs)
                    self.assertEqualIgnoringNestedInts(sample.input, out_ref)

    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    @tf32_on_and_off(0.005)
    @sample_skips_and_xfails(BACKWARD_SKIPS_AND_XFAILS)
    def test_backward(self, device, dtype, op):
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device, dtype=dtype, requires_grad=True, use_subtests=True
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                # compare to reference, but expect different nested int
                out = op.op(sample.input, *sample.args, **sample.kwargs)
                out_ref = op.ref(op, sample)
                self.assertEqualIgnoringNestedInts(out, out_ref)
                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out
                    )

                inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
                g_inps = [
                    inp
                    for inp in inps
                    if isinstance(inp, torch.Tensor) and inp.requires_grad
                ]
                if len(g_inps) > 0:
                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out)
                    grads = torch.autograd.grad(
                        need_grad_outs, inputs=g_inps, grad_outputs=grad_outputs
                    )

                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out_ref)
                    grads_ref = torch.autograd.grad(
                        need_grad_outs, inputs=g_inps, grad_outputs=grad_outputs
                    )

                    self.assertEqualNoncontigAware(grads, grads_ref)

    @ops(
        [op for op in njt_op_db if op.supports_njt],
        allowed_dtypes=(torch.float32,),
    )
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    # needed to avoid "data dependent operator: aten._local_scalar_dense.default"
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @sample_skips_and_xfails(COMPILE_FORWARD_SKIPS_AND_XFAILS)
    def test_compile_forward(self, device, dtype, op):
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device, dtype=dtype, requires_grad=False, use_subtests=True
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                torch.compiler.reset()

                op_fn = op.op

                def f(*args, **kwargs):
                    return op_fn(*args, **kwargs)

                compiled_f = torch.compile(
                    f, fullgraph=True, backend="aot_eager_decomp_partition"
                )

                out_ref = f(sample.input, *sample.args, **sample.kwargs)
                out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref
                    )

                if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                    self.assertEqualIgnoringNestedInts(out_compile, out_ref)
                else:
                    self.assertEqual(out_compile, out_ref)

                # TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
                # TODO: Add xfails for other inplace ops instead of hardcoding
                if op.inplace_variant and "index_put" in op.full_name:
                    op_fn = op.inplace_variant

                    def in_f(*args, **kwargs):
                        return op_fn(*args, **kwargs)

                    compiled_in_f = torch.compile(
                        in_f, fullgraph=True, backend="aot_eager_decomp_partition"
                    )

                    compiled_in_f(sample.input, *sample.args, **sample.kwargs)
                    if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                        self.assertEqualIgnoringNestedInts(sample.input, out_ref)
                    else:
                        self.assertEqual(sample.input, out_ref)

    @ops(
        [op for op in njt_op_db if op.supports_njt and op.supports_autograd],
        allowed_dtypes=(torch.float32,),
    )
    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    # needed to avoid "data dependent operator: aten._local_scalar_dense.default"
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @sample_skips_and_xfails(COMPILE_BACKWARD_SKIPS_AND_XFAILS)
    def test_compile_backward(self, device, dtype, op):
        for sample, subtest_ctx, skip_xfail_ctx in op.sample_inputs(
            device=device, dtype=dtype, requires_grad=True, use_subtests=True
        ):
            with subtest_ctx(self), skip_xfail_ctx(self):
                torch.compiler.reset()

                op_fn = op.op

                def f(*args, **kwargs):
                    return op_fn(*args, **kwargs)

                compiled_f = torch.compile(
                    f, fullgraph=True, backend="aot_eager_decomp_partition"
                )

                out_ref = f(sample.input, *sample.args, **sample.kwargs)
                out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
                if op._extra_op_data.is_view:
                    tree_map_only(
                        NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref
                    )

                if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
                    self.assertEqualIgnoringNestedInts(out_compile, out_ref)
                else:
                    self.assertEqual(out_compile, out_ref)

                inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
                g_inps = [
                    inp
                    for inp in inps
                    if isinstance(inp, torch.Tensor) and inp.requires_grad
                ]
                if len(g_inps) > 0:
                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out_compile)
                    grads_compile = torch.autograd.grad(
                        need_grad_outs,
                        inputs=g_inps,
                        grad_outputs=grad_outputs,
                    )

                    need_grad_outs, grad_outputs = self._gen_grad_outputs(out_ref)
                    grads_ref = torch.autograd.grad(
                        need_grad_outs,
                        inputs=g_inps,
                        grad_outputs=grad_outputs,
                    )

                    self.assertEqualNoncontigAware(grads_compile, grads_ref)

    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
    # needed to avoid "data dependent operator: aten._local_scalar_dense.default"
    @torch._dynamo.config.patch(capture_scalar_outputs=True)
    @skipIfTorchDynamo(
        "Dynamo fails on pending unbacked symints at assertEqual(ref_y[0][0][0].item(), 2)"
    )
    def test_nested_tensor_non_contiguous_mutation(self):
        def fn(x, x0):
            x[0, 0, 0] = 2
            return x

        def _inp():
            base = torch.zeros(32, 3)
            v = base.t()
            return torch.nested.nested_tensor_from_jagged(
                v,
                offsets=torch.tensor([0, 2, 3]),
            ), torch.ones(2, 32)

        ref_x, ref_x0 = _inp()
        ref_y = fn(ref_x, ref_x0)

        self.assertEqual(ref_y[0][0][0].item(), 2)

        y = torch.compile(fn, fullgraph=True, backend="aot_eager")(*_inp())
        self.assertEqual(y[0][0][0], 2)

    def test_nested_tensor_input_mutation_backward(self):
        # See Note [AOTAutograd Tangent Subclassness for mutated inputs]
        # The NJT tangent is always a subclass; see torch/csrc/autograd/python_function.cpp,
        # use_zeros_like. This test checks that AOTD correctly guesses the NJT tangent as an NJT.
        def fn(x):
            x.mul_(2)
            return x + 1

        def _inp():
            v = torch.zeros(32, 3, requires_grad=True)
            return torch.nested.nested_tensor_from_jagged(
                v,
                offsets=torch.tensor([0, 2, 3]),
            ).clone()

        ref_x = _inp()
        ref_y = fn(ref_x)
        ref_y.sum().backward()

        x = _inp()
        y = torch.compile(fn, fullgraph=True, backend="aot_eager")(x)
        y.sum().backward()


from torch.nested._internal.nested_int import NestedIntNode


class TestNestedInt(torch.testing._internal.common_utils.TestCase):
    def test_comparisons(self):
        a = torch.SymInt(NestedIntNode(1, 1))
        b = torch.SymInt(NestedIntNode(1, 1))
        c = torch.SymInt(NestedIntNode(2, 1))
        d = 3

        self.assertTrue(a == a)
        self.assertTrue(a == b)
        self.assertFalse(a != a)
        self.assertFalse(a != b)
        self.assertFalse(a == c)
        self.assertTrue(a != c)

        self.assertFalse(a == d)
        self.assertTrue(a != d)
        self.assertFalse(d == a)
        self.assertTrue(d != a)

        # ge
        self.assertTrue(a >= a)
        self.assertTrue(a >= b)
        self.assertTrue(b >= a)
        with self.assertRaises(ValueError):
            _ = a >= c
        with self.assertRaises(ValueError):
            _ = c >= a
        with self.assertRaises(ValueError):
            _ = c >= 3
        self.assertTrue(c >= 2)
        self.assertTrue(c >= 1)
        self.assertFalse(c <= 1)

        # lt
        self.assertFalse(a < a)
        self.assertFalse(a < b)
        self.assertFalse(b < a)
        with self.assertRaises(ValueError):
            _ = a < c
        with self.assertRaises(ValueError):
            _ = c < a
        with self.assertRaises(ValueError):
            _ = 3 < a
        with self.assertRaises(ValueError):
            _ = 2 < a
        self.assertTrue(a > 1)

        # le
        self.assertTrue(a <= a)
        self.assertTrue(b <= a)
        self.assertTrue(a <= b)
        with self.assertRaises(ValueError):
            _ = a <= c
        with self.assertRaises(ValueError):
            _ = c <= a
        with self.assertRaises(ValueError):
            _ = 3 <= c
        self.assertTrue(c >= 2)
        self.assertTrue(c >= 1)
        self.assertFalse(c <= 1)

        # gt
        self.assertFalse(a > a)
        self.assertFalse(b > a)
        self.assertFalse(a > b)
        with self.assertRaises(ValueError):
            _ = a > c
        with self.assertRaises(ValueError):
            _ = c > a
        with self.assertRaises(ValueError):
            _ = a > 3
        with self.assertRaises(ValueError):
            _ = a > 2
        self.assertTrue(a > 1)

    def test_with_factor(self):
        a = torch.SymInt(NestedIntNode(1, 5))
        b = torch.SymInt(NestedIntNode(1, 10))
        # eq
        self.assertFalse(a == b)
        self.assertFalse(a >= b)
        self.assertTrue(b >= a)
        self.assertTrue(a <= b)
        self.assertFalse(b <= a)
        # ne
        self.assertTrue(a != b)
        # mul
        self.assertTrue(a * 2 == b)
        self.assertTrue(a * 3 >= b)
        self.assertTrue(a * 2 == 2 * a)


instantiate_parametrized_tests(TestNestedTensor)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
instantiate_device_type_tests(TestNestedTensorSubclass, globals())
instantiate_device_type_tests(TestNestedTensorOpInfo, globals())


if __name__ == "__main__":
    run_tests()