Files
pytorch/test/test_namedtensor.py
Zsolt Dollenstein b004307252 [codemod][lint][fbcode/c*] Enable BLACK by default
Test Plan: manual inspection & sandcastle

Reviewed By: zertosh

Differential Revision: D30279364

fbshipit-source-id: c1ed77dfe43a3bde358f92737cd5535ae5d13c9a
2021-08-12 10:58:35 -07:00

2344 lines
84 KiB
Python

import functools
import io
import itertools
import pickle
import sys
import unittest
import warnings
from collections import namedtuple, OrderedDict
from multiprocessing.reduction import ForkingPickler
import torch
import torch.nn.functional as F
from torch import Tensor
from torch.testing._internal.common_cuda import TEST_CUDA
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY
def pass_name_to_python_arg_parser(name):
    """Hand ``name`` to a tensor factory as the name of its only dimension.

    A throwaway one-dimensional tensor of size 2 is created with ``name`` as
    its dimension name; the tensor is discarded and nothing is returned.
    """
    torch.empty(2, names=(name,))
def flatten(lst):
    """Concatenate the items of every sub-iterable in ``lst`` into one list.

    Only a single level of nesting is removed.
    """
    flat = []
    for sublist in lst:
        flat.extend(sublist)
    return flat
# Lightweight (name, lambd) record pairing a display name with the callable to
# invoke. Note that the namedtuple's own typename is "TestCase", not "Function".
Function = namedtuple("TestCase", ["name", "lambd"])
def parse_compressed_namedshape(string):
    """Parse a compact textual description of a tensor's (named) shape.

    The metalanguage is a comma-separated list of dimensions:

    * ``'N:3,C:2'``       -> names ``('N', 'C')``, sizes ``(3, 2)``
    * ``'None:3,None:2'`` -> names ``(None, None)``, sizes ``(3, 2)``
    * ``'3,2'``           -> names ``None``, sizes ``[3, 2]``
    * ``''``              -> names ``None``, sizes ``[]``

    Args:
        string: The compressed shape description. Surrounding whitespace and
            whitespace around each name/size is ignored.

    Returns:
        A two-element tuple ``(names, sizes)``. ``names`` is ``None`` when the
        description contains no ``':'``; otherwise it is a tuple in which the
        literal name ``"None"`` has been turned into ``None``.

    Raises:
        ValueError: If a size is not an integer, or if a named description
            contains a dimension that is not of the form ``name:size``.
    """

    def parse_name(maybe_name):
        maybe_name = maybe_name.strip()
        return None if maybe_name == "None" else maybe_name

    string = string.strip()

    # '' -> size: [], names: None
    if not string:
        return None, []

    dims = string.split(",")

    # '3, 2' -> size = [3, 2], None names.
    if ":" not in string:
        return None, [int(size) for size in dims]

    names = []
    sizes = []
    for dim in dims:
        name, size = dim.split(":")
        names.append(parse_name(name))
        sizes.append(int(size))
    # Return a concrete pair rather than a one-shot zip iterator so the result
    # can be indexed, compared and unpacked more than once.
    return tuple(names), tuple(sizes)
def create(namedshape, factory=torch.randn):
    """Build a tensor from a compressed named-shape string.

    Args:
        namedshape: str understood by ``parse_compressed_namedshape``.
        factory: Callable invoked as ``factory(shape, names=names)``.
    """
    parsed = parse_compressed_namedshape(namedshape)
    dim_names, dim_sizes = parsed
    return factory(dim_sizes, names=dim_names)
def out_fn(operator):
    """Wrap ``operator`` so that its ``out=`` target is given positionally first.

    The returned callable takes ``(out, *operands)`` and forwards to
    ``operator(*operands, out=out)``.
    """

    @functools.wraps(operator)
    def fn(*inputs):
        operands = inputs[1:]
        destination = inputs[0]
        return operator(*operands, out=destination)

    return fn
class TestNamedTensor(TestCase):
def test_aaa_must_run_first_check_experimental_warning(self):
# TODO(rzou): It would be nice for this to be a "real" python warning.
# Right now this error message only prints once and doesn't respect
# warnings.simplefilter behavior (where python users can control whether
# or not to display warnings once, all the time, or never).
with warnings.catch_warnings(record=True) as warns:
x = torch.randn(3, 3, names=("N", "C"))
self.assertEqual(len(warns), 1)
self.assertTrue(
str(warns[0].message).startswith(
"Named tensors and all their associated APIs are an experimental feature"
)
)
def test_trivial(self):
pass
def _test_name_inference(
self, op, args=(), expected_names=(), device="cpu", maybe_raises_regex=None
):
casted_args = [
arg.to(device) if isinstance(arg, torch.Tensor) else arg for arg in args
]
if maybe_raises_regex is not None:
with self.assertRaisesRegex(RuntimeError, maybe_raises_regex):
result = op(*args)
return
result = op(*args)
self.assertEqual(
result.names,
expected_names,
msg="Name inference for {} on device {} failed".format(op.__name__, device),
)
# TODO(rzou): Some form of this check should be added to self.assertEqual.
# Right now I don't know what it should look like.
def assertTensorDataAndNamesEqual(self, x, y):
self.assertEqual(x.names, y.names)
unnamed_x = x.rename(None)
unnamed_y = y.rename(None)
self.assertEqual(unnamed_x, unnamed_y)
def _test_factory(self, factory, device):
x = factory([], device=device)
self.assertEqual(x.names, ())
x = factory(1, 2, 3, device=device)
self.assertEqual(x.names, (None, None, None))
x = factory(1, 2, 3, names=None, device=device)
self.assertEqual(x.names, (None, None, None))
x = factory(1, 2, 3, names=("N", "T", "D"), device=device)
self.assertEqual(x.names, ("N", "T", "D"))
x = factory(1, 2, 3, names=("N", None, "D"), device=device)
self.assertEqual(x.names, ("N", None, "D"))
x = factory(1, 2, 3, names=("_1", "batch9", "BATCH_5"), device=device)
self.assertEqual(x.names, ("_1", "batch9", "BATCH_5"))
with self.assertRaisesRegex(RuntimeError, "a valid identifier contains only"):
x = factory(2, names=("1",), device=device)
with self.assertRaisesRegex(RuntimeError, "a valid identifier contains only"):
x = factory(2, names=("?",), device=device)
with self.assertRaisesRegex(RuntimeError, "Number of names"):
x = factory(2, 1, names=("N",), device=device)
with self.assertRaisesRegex(TypeError, "invalid combination of arguments"):
x = factory(2, 1, names="N", device=device)
with self.assertRaisesRegex(
RuntimeError, "construct a tensor with duplicate names"
):
x = factory(2, 1, 1, names=("N", "C", "N"), device=device)
names64 = ["A" * i for i in range(1, 65)]
x = factory([1] * 64, names=names64, device=device)
self.assertEqual(x.names, names64)
with self.assertRaisesRegex(RuntimeError, "only support up to 64 dims"):
names65 = ["A" * i for i in range(1, 66)]
x = factory([1] * 65, names=names64, device=device)
def test_none_names_refcount(self, N=10):
def scope():
unnamed = torch.empty(2, 3)
unnamed.names # materialize [None, None]
prev_none_refcnt = sys.getrefcount(None)
# Ran it N times to reduce flakiness
[scope() for i in range(N)]
after_none_refcnt = sys.getrefcount(None)
self.assertTrue(
after_none_refcnt - prev_none_refcnt < N / 2,
msg="Using tensor.names should not change " "the refcount of Py_None",
)
def test_has_names(self):
unnamed = torch.empty(2, 3)
none_named = torch.empty(2, 3, names=(None, None))
partially_named = torch.empty(2, 3, names=("N", None))
fully_named = torch.empty(2, 3, names=("N", "C"))
self.assertFalse(unnamed.has_names())
self.assertFalse(none_named.has_names())
self.assertTrue(partially_named.has_names())
self.assertTrue(fully_named.has_names())
def test_py3_ellipsis(self):
tensor = torch.randn(2, 3, 5, 7)
output = tensor.refine_names("N", ..., "C")
self.assertEqual(output.names, ["N", None, None, "C"])
def test_refine_names(self):
    """``refine_names`` fills in missing names and rejects conflicting ones."""
    refine = Tensor.refine_names
    check = self._test_name_inference

    # Fully unnamed input: every dimension receives a name.
    check(
        refine,
        [create("None:1,None:2,None:3"), "N", "C", "H"],
        ["N", "C", "H"],
    )

    # Already named input refined to the very same names.
    check(refine, [create("N:1,C:2,H:3"), "N", "C", "H"], ["N", "C", "H"])

    # Partially named input: only some dimensions gain a name.
    check(
        refine,
        [create("None:1,C:2,None:3"), None, "C", "H"],
        [None, "C", "H"],
    )

    # The number of names must match the number of dimensions.
    check(
        refine,
        [create("None:2,None:3"), "N", "C", "H"],
        maybe_raises_regex="different number of dims",
    )

    # Cannot change Tensor[D] to Tensor[N]
    check(refine, [create("D:3"), "N"], maybe_raises_regex="is different from")

    # Cannot change Tensor[D] to Tensor[None]
    check(
        refine,
        [create("D:3"), None],
        maybe_raises_regex="'D' is more specific than None",
    )

    # The "..." glob leaves the leading dimensions untouched.
    check(
        refine,
        [create("None:1,None:1,None:2,None:3"), "...", "C", "H"],
        [None, None, "C", "H"],
    )
def test_detach(self):
names = ["N"]
self._test_name_inference(
Tensor.detach_, [torch.randn(3, requires_grad=True, names=names)], names
)
self._test_name_inference(
Tensor.detach, [torch.randn(3, requires_grad=True, names=names)], names
)
def test_index_fill(self):
for device in torch.testing.get_all_device_types():
expected_names = ("N", "C")
x = torch.randn(3, 5, device=device, names=expected_names)
output = x.index_fill_("C", torch.tensor([0, 1], device=device), 5)
self.assertEqual(output.names, expected_names)
output = x.index_fill_(
"C", torch.tensor([0, 1], device=device), torch.tensor(4.0)
)
self.assertEqual(output.names, expected_names)
output = x.index_fill("C", torch.tensor([0, 1], device=device), 5)
self.assertEqual(output.names, expected_names)
output = x.index_fill(
"C", torch.tensor([0, 1], device=device), torch.tensor(4.0)
)
self.assertEqual(output.names, expected_names)
def test_equal(self):
for device in torch.testing.get_all_device_types():
tensor = torch.randn(2, 3, device=device)
other = tensor.clone()
self.assertTrue(
torch.equal(tensor.rename("N", "C"), other.rename("N", "C"))
)
self.assertFalse(
torch.equal(tensor.rename("M", "C"), other.rename("N", "C"))
)
self.assertFalse(
torch.equal(tensor.rename(None, "C"), other.rename("N", "C"))
)
def test_squeeze(self):
    """``squeeze`` drops the names of the dimensions it removes."""
    x = create("N:3,C:1,H:1,W:1")

    by_name = x.squeeze("C")
    self.assertEqual(by_name.names, ["N", "H", "W"])

    everything = x.squeeze()
    self.assertEqual(everything.names, ["N"])
def test_repr(self):
named_tensor = torch.zeros(2, 3).rename_("N", "C")
expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]], names=('N', 'C'))"
self.assertEqual(repr(named_tensor), expected)
unnamed_tensor = torch.zeros(2, 3)
expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]])"
self.assertEqual(repr(unnamed_tensor), expected)
none_named_tensor = torch.zeros(2, 3).rename_(None, None)
self.assertEqual(repr(none_named_tensor), expected)
def test_diagonal(self):
named_tensor = torch.zeros(2, 3, 5, 7, names=list("ABCD"))
self.assertEqual(named_tensor.diagonal().names, ["C", "D", None])
self.assertEqual(named_tensor.diagonal(1, 3).names, ["A", "C", None])
self.assertEqual(
named_tensor.diagonal(outdim="E", dim1="B", dim2="D").names, ["A", "C", "E"]
)
def test_max_pooling(self):
def check_tuple_return(op, inputs, expected_names):
values, indices = op(*inputs)
self.assertEqual(values.names, expected_names)
self.assertEqual(indices.names, expected_names)
for device in torch.testing.get_all_device_types():
named_tensor_1d = torch.zeros(2, 3, 5, device=device, names=list("ABC"))
named_tensor_2d = torch.zeros(2, 3, 5, 7, device=device, names=list("ABCD"))
named_tensor_3d = torch.zeros(
2, 3, 5, 7, 9, device=device, names=list("ABCDE")
)
self.assertEqual(
F.max_pool1d(named_tensor_1d, 2).names, named_tensor_1d.names
)
self.assertEqual(
F.max_pool2d(named_tensor_2d, [2, 2]).names, named_tensor_2d.names
)
self.assertEqual(
F.max_pool3d(named_tensor_3d, [2, 2, 2]).names, named_tensor_3d.names
)
check_tuple_return(
F.max_pool1d_with_indices, [named_tensor_1d, 2], named_tensor_1d.names
)
check_tuple_return(
F.max_pool2d_with_indices,
[named_tensor_2d, [2, 2]],
named_tensor_2d.names,
)
check_tuple_return(
F.max_pool3d_with_indices,
[named_tensor_3d, [2, 2, 2]],
named_tensor_3d.names,
)
def test_max_pooling_without_names_does_not_warn(self):
for device in torch.testing.get_all_device_types():
tensor_2d = torch.zeros(2, 3, 5, 7, device=device, requires_grad=True)
with warnings.catch_warnings(record=True) as warns:
warnings.simplefilter("always")
result = F.max_pool2d(tensor_2d, [2, 2])
result.sum().backward()
self.assertEqual(len(warns), 0)
def test_no_save_support(self):
named_tensor = torch.zeros(2, 3, names=("N", "C"))
buf = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "NYI"):
torch.save(named_tensor, buf)
def test_no_pickle_support(self):
named_tensor = torch.zeros(2, 3, names=("N", "C"))
with self.assertRaisesRegex(RuntimeError, "NYI"):
serialized = pickle.dumps(named_tensor)
def test_no_multiprocessing_support(self):
named_tensor = torch.zeros(2, 3, names=("N", "C"))
buf = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "NYI"):
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor)
def test_big_tensor_repr_has_names(self):
def check_repr(named_tensor):
unnamed_tensor = named_tensor.rename(None)
names_tag = "names={}".format(named_tensor.names)
self.assertIn(names_tag, repr(named_tensor))
check_repr(torch.randn(128, 3, 64, 64, names=("N", "C", "H", "W")))
def test_noncontig_contiguous(self):
# This type of contiguous is special-cased and therefore needs its own test
for device in torch.testing.get_all_device_types():
x = torch.randn(2, 3, device=device).t().rename_("N", "C")
self.assertEqual(x.contiguous().names, ("N", "C"))
def test_copy_transpose(self):
# This type of copy is special-cased and therefore needs its own test
def _test(self_names, other_names, expected_names):
x = torch.empty(2, 5, names=self_names)
y = torch.empty(5, 2).t().rename_(*other_names)
x.copy_(y)
self.assertEqual(x.names, expected_names)
_test(("N", "C"), ("N", "C"), ("N", "C"))
_test(None, ("N", "C"), ("N", "C"))
def test_rename_(self):
tensor = torch.empty(1, 1, names=("N", "C"))
self.assertEqual(tensor.rename_(None).names, (None, None))
self.assertEqual(tensor.rename_("H", "W").names, ("H", "W"))
with self.assertRaisesRegex(RuntimeError, "Number of names"):
tensor.rename_("N", "C", "W")
with self.assertRaisesRegex(RuntimeError, "duplicate names"):
tensor.rename_("N", "N")
def test_rename(self):
tensor = torch.empty(1, 1, names=("N", "C"))
self.assertEqual(tensor.rename(None).names, (None, None))
self.assertEqual(tensor.rename("H", "W").names, ("H", "W"))
# Check that we didn't modify tensor.names
self.assertEqual(tensor.names, ("N", "C"))
with self.assertRaisesRegex(RuntimeError, "Number of names"):
tensor.rename("N", "C", "W")
with self.assertRaisesRegex(RuntimeError, "duplicate names"):
tensor.rename("N", "N")
with self.assertRaisesRegex(
RuntimeError, "either positional args or keyword args"
):
tensor.rename(None, N="batch")
# rename returns a view on the tensor
self.assertEqual(tensor.rename("H", "W").data_ptr(), tensor.data_ptr())
self.assertEqual(tensor.rename(None).data_ptr(), tensor.data_ptr())
def test_rename_globber(self):
scalar = torch.randn([])
unnamed_tensor = torch.empty(1, 1, 1, 1)
named_tensor = torch.empty(1, 1, 1, 1, names=("N", "C", "H", "W"))
self.assertEqual(scalar.rename(None).names, [])
self.assertEqual(scalar.rename("...").names, [])
# Check that it works with unnamed tensors
self.assertEqual(unnamed_tensor.rename("...").names, unnamed_tensor.names)
self.assertEqual(
unnamed_tensor.rename("...", "H", "W").names, [None, None, "H", "W"]
)
self.assertEqual(
unnamed_tensor.rename("N", "...", "W").names, ["N", None, None, "W"]
)
self.assertEqual(
unnamed_tensor.rename("N", "C", "...").names, ["N", "C", None, None]
)
# Check that it works with named tensors
self.assertEqual(named_tensor.rename("...").names, named_tensor.names)
self.assertEqual(
named_tensor.rename("...", "width").names, ["N", "C", "H", "width"]
)
self.assertEqual(
named_tensor.rename("batch", "channels", "...", "width").names,
["batch", "channels", "H", "width"],
)
self.assertEqual(
named_tensor.rename("batch", "...").names, ["batch", "C", "H", "W"]
)
# Test empty glob
self.assertEqual(
unnamed_tensor.rename("...", None, None, None, None).names,
[None, None, None, None],
)
self.assertEqual(
named_tensor.rename("N", "C", "H", "...", "W").names, ["N", "C", "H", "W"]
)
# Multiple globs throw
with self.assertRaisesRegex(RuntimeError, "More than one "):
named_tensor.rename("...", "channels", "...")
def test_rename_rename_map(self):
scalar = torch.randn([])
unnamed_tensor = torch.empty(1, 1, 1, 1)
named_tensor = torch.empty(1, 1, 1, 1, names=("N", "C", "H", "W"))
with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
scalar.rename(N="batch")
with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
unnamed_tensor.rename(N="batch")
with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
named_tensor.rename(B="batch")
with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
named_tensor.rename(H="height", B="batch")
self.assertEqual(
named_tensor.rename(N="batch").data_ptr(), named_tensor.data_ptr()
)
self.assertEqual(named_tensor.rename(N="batch").names, ["batch", "C", "H", "W"])
self.assertEqual(
named_tensor.rename(N="batch", H="height").names,
["batch", "C", "height", "W"],
)
def test_set_names_property(self):
tensor = torch.empty(1, 1, names=("N", "C"))
tensor.names = None
self.assertEqual(tensor.names, (None, None))
tensor.names = ("N", "W")
self.assertEqual(tensor.names, ("N", "W"))
with self.assertRaisesRegex(RuntimeError, "Number of names"):
tensor.names = ["N", "C", "W"]
with self.assertRaisesRegex(RuntimeError, "duplicate names"):
tensor.names = ["N", "N"]
def test_factory_edge_cases(self):
for device in torch.testing.get_all_device_types():
self._test_factory(torch.empty, device)
def test_factory_coverage(self):
def _test(factory, device):
names = ("N", "T", "D")
torch.manual_seed(0)
result = factory(1, 2, 3, names=names, device=device)
torch.manual_seed(0)
expected = factory(1, 2, 3, device=device).rename_(*names)
self.assertTensorDataAndNamesEqual(result, expected)
supported = [
torch.ones,
torch.rand,
torch.randn,
torch.zeros,
]
for op, device in itertools.product(
supported, torch.testing.get_all_device_types()
):
_test(op, device)
# Test torch.full
for device in torch.testing.get_all_device_types():
names = ("N", "T", "D")
result = torch.full([1, 2, 3], 2.0, names=names, device=device)
expected = torch.full([1, 2, 3], 2.0, device=device).rename_(*names)
self.assertTensorDataAndNamesEqual(result, expected)
def test_tensor_from_lists(self):
names = ("N", "C")
tensor = torch.tensor([[1]], names=names)
self.assertEqual(tensor.names, names)
names = ("N",)
tensor = torch.tensor([1], names=names)
self.assertEqual(tensor.names, names)
with self.assertRaisesRegex(RuntimeError, "Number of names"):
names = ("N", "C")
tensor = torch.tensor([1], names=names)
@unittest.skipIf(not TEST_NUMPY, "no numpy")
def test_tensor_from_numpy(self):
import numpy as np
arr = np.array([[1]])
names = ("N", "C")
tensor = torch.tensor([[1]], names=names)
self.assertEqual(tensor.names, names)
def test_tensor_from_tensor(self):
x = torch.randn(1, 1)
names = ("N", "C")
tensor = torch.tensor(x, names=names)
self.assertEqual(tensor.names, names)
def test_tensor_from_named_tensor(self):
x = torch.randn(1, 1, names=("N", "D"))
tensor = torch.tensor(x)
self.assertEqual(tensor.names, ("N", "D"))
# there's no way to distinguish between names=None and not passing in names.
# If the user passes in names=None they are asking for trouble.
x = torch.randn(1, 1, names=("N", "D"))
tensor = torch.tensor(x, names=None)
self.assertEqual(tensor.names, ("N", "D"))
x = torch.randn(1, 1, names=("N", "D"))
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
tensor = torch.tensor(x, names=("N", "C"))
def test_size(self):
t = torch.empty(2, 3, 5, names=("N", None, "C"))
self.assertEqual(t.size("N"), 2)
self.assertEqual(t.size("C"), 5)
with self.assertRaisesRegex(RuntimeError, "Please look up dimensions by name*"):
t.size(None)
with self.assertRaisesRegex(RuntimeError, "Name 'channels' not found in "):
t.size("channels")
with self.assertRaisesRegex(RuntimeError, "Name 'N' not found in "):
torch.empty(2, 3, 4).size("N")
def test_stride(self):
t = torch.empty(2, 3, 5, names=("N", None, "C"))
self.assertEqual(t.stride("N"), 3 * 5)
self.assertEqual(t.stride("C"), 1)
with self.assertRaisesRegex(RuntimeError, "Please look up dimensions by name"):
t.stride(None)
with self.assertRaisesRegex(RuntimeError, "Name 'channels' not found in "):
t.stride("channels")
with self.assertRaisesRegex(RuntimeError, "Name 'N' not found in "):
torch.empty(2, 3, 4).stride("N")
def test_transpose_variants(self):
t = torch.randn(2, 3, 5, 7, names=("N", "C", "H", "W"))
self.assertEqual(t.transpose("N", "C").names, ["C", "N", "H", "W"])
self.assertEqual(t.transpose(1, 3).names, ["N", "W", "H", "C"])
t = torch.randn(2, 3, names=("N", "C"))
self.assertEqual(t.t().names, ["C", "N"])
def test_resize(self):
for device in torch.testing.get_all_device_types():
named = torch.randn(2, names=("N",), device=device)
named.resize_([2])
self.assertEqual(named.names, ["N"])
with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"):
named.resize_([3])
other_named = torch.randn(2, names=("N",), device=device)
named.resize_as_(other_named)
self.assertEqual(other_named.names, ["N"])
unnamed = torch.randn(2, device=device)
with self.assertRaisesRegex(
RuntimeError, r"names .* are not the same as the computed output names"
):
named.resize_as_(unnamed)
unnamed = torch.randn(1, device=device)
unnamed.resize_as_(named)
self.assertEqual(unnamed.names, ["N"])
def test_cdist(self):
for device in torch.testing.get_all_device_types():
tensor = torch.randn(
3, 1, 2, 7, names=("M", "N", "first_group", "features"), device=device
)
other = torch.randn(
5, 11, 7, names=("N", "second_group", "features"), device=device
)
result = torch.cdist(tensor, other)
self.assertEqual(result.names, ["M", "N", "first_group", "second_group"])
def test_info_smoke(self):
# Smoke test for info functions / methods / attributes on named tensors.
tensor = torch.empty(1, 1, names=("N", "D"))
tensor.device
tensor.dtype
tensor.get_device()
tensor.is_complex()
tensor.is_floating_point()
tensor.is_nonzero()
torch.is_same_size(tensor, tensor)
torch.is_signed(tensor)
tensor.layout
tensor.numel()
tensor.dim()
tensor.element_size()
tensor.is_contiguous()
tensor.is_cuda
tensor.is_leaf
tensor.is_pinned()
tensor.is_shared()
tensor.is_sparse
tensor.ndimension()
tensor.nelement()
tensor.shape
tensor.size()
tensor.size(1)
tensor.storage()
tensor.storage_offset()
tensor.storage_type()
tensor.stride()
tensor.stride(1)
tensor.data
tensor.data_ptr()
tensor.ndim
tensor.item()
tensor.type()
tensor.is_shared()
tensor.is_signed()
def test_autograd_smoke(self):
x = torch.randn(3, 3, names=("N", "D"), requires_grad=True)
y = x.clone()
y.retain_grad()
y.register_hook(lambda x: x)
y.sum().backward()
# autograd related attributes
tensor = torch.empty(1, 1, names=("N", "D"), requires_grad=True)
tensor = tensor.relu()
tensor.output_nr
tensor.grad_fn
tensor.requires_grad
def test_split_fns_propagates_names(self):
fns = [
lambda x: x.split(1, 0),
lambda x: x.split([1, 1], 1),
lambda x: x.chunk(2, 0),
]
for device in torch.testing.get_all_device_types():
orig_tensor = torch.empty(2, 2, names=("N", "D"), device=device)
for fn in fns:
splits = fn(orig_tensor)
for split in splits:
self.assertEqual(split.names, orig_tensor.names)
def test_any_all(self):
for device in torch.testing.get_all_device_types():
x = torch.zeros(3, dtype=torch.bool, device=device, names=("C",))
self.assertEqual(x.any().names, [])
self.assertEqual(x.all().names, [])
def test_addcmul_addcdiv(self):
for device in torch.testing.get_all_device_types():
names = ["N"]
a = torch.rand(3, device=device, names=names)
b = torch.rand(3, device=device, names=names)
# avoid division by 0
c = torch.rand(3, device=device, names=names).clamp_min_(0.1)
out = torch.randn(3, device=device, names=names)
self.assertEqual(torch.addcmul(a, b, c).names, names)
self.assertEqual(torch.addcmul(a, b, c, out=out).names, names)
self.assertEqual(a.addcmul_(b, c).names, names)
self.assertEqual(torch.addcdiv(a, b, c).names, names)
self.assertEqual(torch.addcdiv(a, b, c, out=out).names, names)
self.assertEqual(a.addcdiv_(b, c).names, names)
def test_binary_ops(self):
    """Binary ops unify operand names in method, in-place and out= form."""

    def check_basic(op):
        # Matching names pass; differing or permuted names are rejected.
        named_nc = torch.empty(2, 3, names=("N", "C"))
        named_cn = torch.empty(3, 2, names=("C", "N"))
        named_c = torch.empty(3, names=("C",))
        named_w = torch.empty(5, names=("W",))

        self.assertEqual(op(named_nc, named_nc).names, ("N", "C"))
        self.assertEqual(op(named_nc, named_c).names, ("N", "C"))

        for mismatched in (named_w, named_cn):
            with self.assertRaisesRegex(RuntimeError, "do not match"):
                op(named_nc, mismatched)

    def check_wildcard(op):
        # A None name unifies with any name in the same position.
        fully_named = torch.empty(2, 3, names=("N", "C"))
        partially_named = torch.empty(2, 3, names=(None, "C"))
        self.assertEqual(op(fully_named, partially_named).names, ("N", "C"))

        unnamed = torch.empty(2, 3)
        self.assertEqual(op(fully_named, unnamed).names, ("N", "C"))

        misaligned = torch.empty(2, 3, names=("C", None))
        with self.assertRaisesRegex(RuntimeError, "Misaligned"):
            op(misaligned, partially_named)

    def check_mixed_unnamed_named(op, is_inplace):
        named2 = torch.randn(1, 1, names=("N", "C"))
        unnamed1 = torch.randn(1)
        unnamed2 = torch.randn(1, 1)
        unnamed3 = torch.randn(1, 1, 1)
        all_unnamed = [unnamed1, unnamed2, unnamed3]

        def expected_names_for(tensor, other):
            # Exactly one of the two operands carries names.
            assert tensor.has_names() ^ other.has_names()
            if tensor.has_names():
                named, unnamed = tensor, other
            else:
                named, unnamed = other, tensor
            extra_dims = unnamed.dim() - named.dim()
            if extra_dims > 0:
                return [None] * extra_dims + list(named.names)
            return named.names

        pairs = [(named2, unnamed) for unnamed in all_unnamed]
        pairs += [(unnamed, named2) for unnamed in all_unnamed]
        if is_inplace:
            # In-place ops have the constraint that they must not change shape.
            pairs = [(a, b) for (a, b) in pairs if a.dim() >= b.dim()]

        for tensor, other in pairs:
            # Compute the expectation before calling op: op may rename
            # ``tensor`` when it operates in place.
            expected_names = expected_names_for(tensor, other)
            self.assertEqual(op(tensor, other).names, expected_names)

    def as_method(name, *args, **kwargs):
        def call(a, b):
            return getattr(a, name)(b, *args, **kwargs)

        return [Function(name, call)]

    def as_function(name, *args, **kwargs):
        def call(a, b):
            return getattr(torch, name)(a, b, *args, **kwargs)

        return [Function(name, call)]

    def as_out_function(name, *args, **kwargs):
        torch_op = getattr(torch, name)

        def call(a, b):
            result = torch.empty([0], dtype=a.dtype, device=a.device)
            torch_op(a, b, *args, out=result, **kwargs)
            return result

        return [Function(name, call)]

    def method_inplace_and_out(name, *args, **kwargs):
        variants = as_method(name, *args, **kwargs)
        variants += as_method(name + "_", *args, **kwargs)
        variants += as_out_function(name, *args, **kwargs)
        return variants

    cases = flatten(
        [
            method_inplace_and_out("add"),
            method_inplace_and_out("div"),
            method_inplace_and_out("mul"),
            method_inplace_and_out("sub"),
            method_inplace_and_out("pow"),
            method_inplace_and_out("atan2"),
            as_method("copy_"),
            as_function("floor_divide"),
            as_function("true_divide"),
        ]
    )

    for case in cases:
        check_basic(case.lambd)
        check_wildcard(case.lambd)
        check_mixed_unnamed_named(case.lambd, is_inplace=case.name.endswith("_"))
def test_logical_ops(self):
    """Logical xor/and/or keep names in functional, in-place and out= form."""
    # Implemented via TensorIterator, so just check that each version
    # (out-of-place, inplace, out=) propagates names.

    def bool_zeros(*args, **kwargs):
        return torch.zeros(*args, dtype=torch.bool, **kwargs)

    def named_operand():
        return create("N:2,C:3", bool_zeros)

    expected = ["N", "C"]
    for op_name in ("logical_xor", "logical_and", "logical_or"):
        self._test_name_inference(
            getattr(torch, op_name),
            (named_operand(), named_operand()),
            expected_names=expected,
        )

        self._test_name_inference(
            getattr(Tensor, op_name + "_"),
            (named_operand(), named_operand()),
            expected_names=expected,
        )

        # The lambda is invoked within this iteration, so ``op_name`` still
        # refers to the operator being checked.
        self._test_name_inference(
            lambda out, x, y: getattr(torch, op_name)(x, y, out=out),
            (create("0", bool_zeros), named_operand(), named_operand()),
            expected_names=expected,
        )
def test_pow_special(self):
# There are a few pow cases that don't go through TensorIterator.
# Test them here.
for device in torch.testing.get_all_device_types():
named = torch.randn(2, 3, names=("N", "C"), device=device)
unnamed = torch.randn([0], device=device)
result = torch.pow(named, 0, out=unnamed.clone())
self.assertEqual(result.names, named.names)
result = torch.pow(named, 1, out=unnamed.clone())
self.assertEqual(result.names, named.names)
result = torch.pow(1, named, out=unnamed.clone())
self.assertEqual(result.names, named.names)
def test_out_fn_semantics(self):
out_fn = torch.abs
unnamed_tensor = torch.randn(3, 2)
none_named_tensor = torch.randn(3, 2, names=(None, None))
named_tensor = torch.randn(3, 2, names=("N", "C"))
partially_named_tensor = torch.randn(3, 2, names=("N", None))
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(partially_named_tensor, out=named_tensor)
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(named_tensor, out=partially_named_tensor)
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(none_named_tensor, out=named_tensor)
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(unnamed_tensor, out=named_tensor)
output = torch.randn(3, 2)
out_fn(unnamed_tensor, out=output)
self.assertFalse(output.has_names())
output = torch.randn(3, 2, names=(None, None))
out_fn(named_tensor, out=output)
self.assertEqual(output.names, named_tensor.names)
output = torch.randn(3, 2)
out_fn(named_tensor, out=output)
self.assertEqual(output.names, named_tensor.names)
output = torch.randn(3, 2, names=(None, None))
out_fn(unnamed_tensor, out=output)
self.assertFalse(output.has_names())
def test_unary_propagate_names_fns(self):
    """Unary ops, conversions and factories-like keep the input's names."""

    def run_case(testcase, names=("N", "D"), device="cpu"):
        tensor = torch.empty([2] * len(names), names=names, device=device)
        try:
            out = testcase.lambd(tensor)
        except RuntimeError as err:
            # Get a better error message by catching the error and asserting.
            raise RuntimeError("{}: {}".format(testcase.name, err)) from err
        self.assertEqual(out.names, tensor.names, msg=testcase.name)

    def torch_fn(name, *args, **kwargs):
        def call(t):
            return getattr(torch, name)(t, *args, **kwargs)

        return [Function(name, call)]

    def tensor_method(name, *args, **kwargs):
        def call(t):
            return getattr(t, name)(*args, **kwargs)

        return [Function(name, call)]

    def out_variant(name, *args, **kwargs):
        torch_op = getattr(torch, name)

        def call(tensor):
            result = torch.empty([0], dtype=tensor.dtype, device=tensor.device)
            torch_op(tensor, *args, out=result, **kwargs)
            return result

        return [Function(name + "_out", call)]

    def all_variants(name, *args, **kwargs):
        # Method, in-place method and out= form of the same operator.
        variants = tensor_method(name, *args, **kwargs)
        variants += tensor_method(name + "_", *args, **kwargs)
        variants += out_variant(name, *args, **kwargs)
        return variants

    # All of these operate on 2x2 tensors.
    groups = [
        # unary pointwise
        all_variants("abs"),
        all_variants("acos"),
        all_variants("asin"),
        all_variants("atan"),
        all_variants("ceil"),
        all_variants("clamp", -1, 1),
        all_variants("clamp_min", -2),
        all_variants("clamp_max", 2),
        tensor_method("cauchy_"),
        tensor_method("clone"),
        tensor_method("contiguous"),
        all_variants("cos"),
        all_variants("cosh"),
        all_variants("digamma"),
        all_variants("erf"),
        all_variants("erfc"),
        all_variants("erfinv"),
        all_variants("exp"),
        all_variants("expm1"),
        tensor_method("exponential_"),
        all_variants("floor"),
        all_variants("frac"),
        tensor_method("geometric_", p=0.5),
        all_variants("lgamma"),
        all_variants("log"),
        all_variants("log10"),
        all_variants("log1p"),
        all_variants("log2"),
        tensor_method("log_normal_"),
        all_variants("neg"),
        tensor_method("normal_"),
        [Function("polygamma", lambda t: torch.polygamma(1, t))],
        tensor_method("polygamma_", 1),
        all_variants("reciprocal"),
        tensor_method("random_", 0, 1),
        tensor_method("random_", 1),
        tensor_method("random_"),
        tensor_method("relu_"),
        tensor_method("requires_grad_"),
        tensor_method("relu"),
        all_variants("round"),
        all_variants("rsqrt"),
        all_variants("sigmoid"),
        all_variants("sign"),
        all_variants("sin"),
        all_variants("sinh"),
        all_variants("sqrt"),
        all_variants("tan"),
        all_variants("tanh"),
        torch_fn("threshold", 0, 1),
        torch_fn("threshold_", 0, 1),
        out_variant("threshold", 0, 1),
        all_variants("trunc"),
        tensor_method("uniform_"),
        tensor_method("zero_"),
        tensor_method("fill_", 1),
        tensor_method("fill_", torch.tensor(3.14)),
        # conversions
        tensor_method("to", dtype=torch.long),
        tensor_method("to", device="cpu"),
        tensor_method("to", torch.empty([])),
        tensor_method("bool"),
        tensor_method("byte"),
        tensor_method("char"),
        tensor_method("cpu"),
        tensor_method("double"),
        tensor_method("float"),
        tensor_method("long"),
        tensor_method("half"),
        tensor_method("int"),
        tensor_method("short"),
        tensor_method("type", dtype=torch.long),
        # cumsum and cumprod
        torch_fn("cumsum", 0),
        torch_fn("cumsum", "D"),
        out_variant("cumsum", "D"),
        torch_fn("cumprod", 0),
        torch_fn("cumprod", "D"),
        out_variant("cumprod", "D"),
        # views
        tensor_method("narrow", 0, 0, 1),
        # creation functions
        torch_fn("empty_like"),
        torch_fn("zeros_like"),
        torch_fn("ones_like"),
        torch_fn("full_like", 3.14),
        torch_fn("rand_like"),
        torch_fn("randn_like"),
        # bernoulli variants
        tensor_method("bernoulli_", 0.5),
        tensor_method("bernoulli_", torch.tensor(0.5)),
        tensor_method("softmax", dim=1),
        tensor_method("softmax", dim="D"),
        tensor_method("log_softmax", dim=1),
        tensor_method("log_softmax", dim="D"),
        [
            Function(
                "F.dropout(inplace)", lambda t: F.dropout(t, p=0.5, inplace=True)
            )
        ],
        [
            Function(
                "F.dropout(outplace)", lambda t: F.dropout(t, p=0.5, inplace=False)
            )
        ],
    ]
    cases = flatten(groups)

    devices = torch.testing.get_all_device_types()
    for testcase in cases:
        for device in devices:
            run_case(testcase, device=device)
def test_cummax_cummin(self):
def test_ops(op):
for device in torch.testing.get_all_device_types():
names = ("N", "D")
tensor = torch.rand(2, 3, names=names)
result = op(tensor, 0)
self.assertEqual(result[0].names, names)
self.assertEqual(result[1].names, names)
test_ops(torch.cummax)
test_ops(torch.cummin)
def test_logcumsumexp(self):
    """Check that logcumsumexp along a named dim preserves all names."""
    for device in torch.testing.get_all_device_types():
        names = ("N", "D")
        # Allocate on the device under test; the loop variable was
        # previously unused, so only the CPU path was ever exercised.
        tensor = torch.rand(2, 3, names=names, device=device)
        result = torch.logcumsumexp(tensor, "D")
        self.assertEqual(result.names, names)
def test_bitwise_not(self):
    """Check name propagation for the method, out= and in-place bitwise_not."""
    for device in torch.testing.get_all_device_types():
        names = ("N", "D")
        # Input and out= buffer both live on the device under test; the
        # loop variable was previously unused (CPU-only coverage).
        tensor = torch.zeros(
            2, 3, names=names, dtype=torch.bool, device=device
        )
        result = torch.empty(0, dtype=torch.bool, device=device)

        self.assertEqual(tensor.bitwise_not().names, names)
        self.assertEqual(torch.bitwise_not(tensor, out=result).names, names)
        self.assertEqual(tensor.bitwise_not_().names, names)
def test_logical_not(self):
    """Check name propagation for the method, out= and in-place logical_not."""
    for device in torch.testing.get_all_device_types():
        names = ("N", "D")
        # Input and out= buffer both live on the device under test; the
        # loop variable was previously unused (CPU-only coverage).
        tensor = torch.zeros(
            2, 3, names=names, dtype=torch.bool, device=device
        )
        result = torch.empty(0, dtype=torch.bool, device=device)

        self.assertEqual(tensor.logical_not().names, names)
        self.assertEqual(torch.logical_not(tensor, out=result).names, names)
        self.assertEqual(tensor.logical_not_().names, names)
def test_bernoulli(self):
    """Check that bernoulli (method and out= variant) propagates names."""
    for device in torch.testing.get_all_device_types():
        names = ("N", "D")
        # Input and out= buffer both live on the device under test; the
        # loop variable was previously unused (CPU-only coverage).
        tensor = torch.rand(2, 3, names=names, device=device)
        result = torch.empty(0, device=device)

        self.assertEqual(tensor.bernoulli().names, names)

        torch.bernoulli(tensor, out=result)
        self.assertEqual(result.names, names)
def test_flatten(self):
    """Exercise named ``Tensor.flatten``: its overloads and its error paths."""
    source = torch.randn(2, 3, 5, 7, 11, names=("N", "C", "D", "H", "W"))

    # Name-range, int-range and name-list overloads: each one merges
    # D, H, W into a single "features" dimension.
    overload_args = (
        ("D", "W", "features"),
        (2, 4, "features"),
        (["D", "H", "W"], "features"),
    )
    for flatten_args in overload_args:
        flattened = source.flatten(*flatten_args)
        self.assertEqual(flattened.names, ["N", "C", "features"])
        self.assertEqual(
            flattened.rename(None), source.rename(None).view(2, 3, -1)
        )

    # Non-contiguous flatten: N and H are not "adjacent" in memory.
    sentences = torch.randn(2, 3, 5, 7, names=("N", "T", "H", "D"))
    transposed = sentences.transpose("T", "H")
    merged = transposed.flatten("N", "H", "N_H")
    self.assertEqual(merged.names, ["N_H", "T", "D"])

    # Unknown names and non-consecutive name lists are rejected.
    rejected_dim_lists = (
        (["D", "L"], "Name 'L' not found in"),
        (["D", "W"], "must be consecutive in"),
        (["H", "D", "W"], "must be consecutive in"),
    )
    for dim_names, error_regex in rejected_dim_lists:
        with self.assertRaisesRegex(RuntimeError, error_regex):
            source.flatten(dim_names, "features")
def test_unflatten(self):
    """Cover named ``Tensor.unflatten``.

    Checks the accepted argument forms, inference of a ``-1`` size, and
    the error raised for each kind of invalid request.
    """
    pair_ab = torch.ones(2, 2, names=("A", "B"))
    triple = torch.ones(2, 2, 2, names=("A", "B1", "B2"))

    # Each row: (input, dim, namedshape, expected result).
    equality_cases = [
        # test args: tensor, int, namedshape
        (torch.ones(4), 0, (("A", 2), ("B", 2)), pair_ab),
        (torch.ones(4), 0, [("A", 2), ("B", 2)], pair_ab),
        (torch.ones(4), 0, (["A", 2], ["B", 2]), pair_ab),
        (torch.ones(4), -1, (["A", 2], ["B", 2]), pair_ab),
        (torch.ones(4), -1, (["A", -1], ["B", 2]), pair_ab),
        (torch.ones(4), -1, (["A", 2], ["B", -1]), pair_ab),
        (
            torch.ones(2, 10, names=("A", "B")),
            "B",
            (["B1", -1],),
            torch.ones(2, 10, names=("A", "B1")),
        ),
        (
            torch.ones(2, 3 * 4 * 5 * 6, names=("A", "B")),
            "B",
            (["B1", 3], ["B2", 4], ["B3", -1], ["B4", 6]),
            torch.ones(2, 3, 4, 5, 6, names=("A", "B1", "B2", "B3", "B4")),
        ),
        (
            torch.ones(2, 0, names=("A", "B")),
            "B",
            (["B1", 3], ["B2", -1], ["B3", 4]),
            torch.ones(2, 3, 0, 4, names=("A", "B1", "B2", "B3")),
        ),
        # test args: namedtensor, int, namedshape
        (
            torch.ones(2, 4, names=("A", "B")),
            1,
            (("B1", 2), ("B2", 2)),
            triple,
        ),
        # test args: namedtensor, str, namedshape
        (
            torch.ones(2, 4, names=("A", "B")),
            "B",
            (("B1", 2), ("B2", 2)),
            triple,
        ),
    ]
    for source, dim, namedshape, expected in equality_cases:
        self.assertTrue(
            torch.equal(source.unflatten(dim, namedshape), expected)
        )

    # Each row: (exception type, message regex, call that must raise).
    failing_calls = [
        # test invalid args: namedtensor, str, sizes
        (
            TypeError,
            r"received an invalid combination of arguments",
            lambda: torch.tensor([1], names=("A",)).unflatten("A", (1, 1)),
        ),
        # test invalid args: namedtensor, int, sizes
        (
            RuntimeError,
            r"input is a named tensor but no names were given for unflattened sizes",
            lambda: torch.tensor([1], names=("A",)).unflatten(0, (1, 1)),
        ),
        (
            RuntimeError,
            r"Provided sizes \[3, -1\] don't multiply up to the "
            r"size of dim 1 \('B': 4\) in Tensor\['A', 'B'\]",
            lambda: torch.ones(2, 4, names=("A", "B")).unflatten(
                "B", (("B1", 3), ("B2", -1))
            ),
        ),
        (
            RuntimeError,
            r"the unspecified dimension size -1 can be any value and is ambiguous",
            lambda: torch.ones(2, 0, names=("A", "B")).unflatten(
                "B", (("B1", 0), ("B2", -1))
            ),
        ),
    ]
    for exc_type, message_regex, failing_call in failing_calls:
        with self.assertRaisesRegex(exc_type, message_regex):
            failing_call()

    batch = torch.randn(7, 2 * 3 * 5, 11, names=("N", "D", "K"))

    # Each row: (dim, namedshape, expected names, expected shape).
    layout_cases = [
        # accepts OrderedDict
        (
            "D",
            OrderedDict((("C", 2), ("H", 3), ("W", 5))),
            ("N", "C", "H", "W", "K"),
            (7, 2, 3, 5, 11),
        ),
        # Unflatten left-most
        (
            "N",
            (("N", 7), ("H", 1)),
            ("N", "H", "D", "K"),
            (7, 1, 2 * 3 * 5, 11),
        ),
        # Unflatten right-most
        (
            "K",
            (("K", 11), ("H", 1)),
            ("N", "D", "K", "H"),
            (7, 2 * 3 * 5, 11, 1),
        ),
    ]
    for dim, namedshape, expected_names, expected_shape in layout_cases:
        unflattened = batch.unflatten(dim, namedshape)
        self.assertEqual(unflattened.names, expected_names)
        self.assertEqual(unflattened.shape, expected_shape)

    # Sizes that do not multiply up, or that are missing/empty, are errors.
    rejected_shapes = [
        ((("H", 3), ("W", 5)), "don't multiply up to"),
        (None, "sizes must be non-empty"),
        (OrderedDict(), "non-empty"),
    ]
    for namedshape, message_regex in rejected_shapes:
        with self.assertRaisesRegex(RuntimeError, message_regex):
            batch.unflatten("D", namedshape)
def test_unsupported_op_error_msg(self):
    """Ops without named-tensor support must fail with a descriptive error."""
    named = torch.randn(3, 3, names=("N", "C"))

    # Each row: (expected message regex, call on the named tensor).
    unsupported_calls = (
        (
            r"pdist.+is not yet supported with named tensors",
            lambda: torch.pdist(named),
        ),
        (
            r"as_strided_.+is not yet supported with named tensors",
            lambda: named.as_strided_((3, 3), (3, 1)),
        ),
    )
    for message_regex, unsupported_call in unsupported_calls:
        with self.assertRaisesRegex(RuntimeError, message_regex):
            unsupported_call()
def test_reduction_fns(self):
    """Check name inference for a table of reduction ops.

    Every op is run through the simple-reduce and autograd checks; the
    keepdim, out=, complete-reduce and multi-dim checks run only when
    the op's ``Case`` entry says it supports them. All checks use a
    tensor named ("N", "C", "L") of shape (2, 3, 5).
    """

    def check_output(output, expected_names):
        # Accepts a single tensor or an iterable of tensors; each tensor
        # must have ``expected_names``.
        if isinstance(output, torch.Tensor):
            self.assertEqual(output.names, expected_names)
            return
        for out in output:
            self.assertEqual(out.names, expected_names)

    def sum_all_outputs(output):
        # Collapse a tensor (or an iterable of tensors) into one scalar so
        # that ``.backward()`` can be called on it.
        if isinstance(output, torch.Tensor):
            return output.sum()
        result = 0
        for out in output:
            result = out + result
        return result.sum()

    def test_simple_reduce(op, device):
        # Reduce over one dim given as positive int, negative int and name;
        # a None dim and an unknown name must both be rejected.
        t = torch.empty(2, 3, 5, names=("N", "C", "L"), device=device)
        check_output(op(t, 1), ["N", "L"])
        check_output(op(t, -1), ["N", "C"])
        check_output(op(t, "C"), ["N", "L"])
        with self.assertRaisesRegex(
            RuntimeError, "Please look up dimensions by name"
        ):
            op(t, None)
        with self.assertRaisesRegex(RuntimeError, "Name 'H' not found"):
            op(t, "H")

    def test_autograd_supports_dimname_overload(op, device):
        # Backward through the dimname overload must populate ``t.grad``.
        t = torch.empty(
            2, 3, 5, names=("N", "C", "L"), device=device, requires_grad=True
        )
        sum_all_outputs(op(t, "C")).backward()
        self.assertIsNotNone(t.grad)

    def test_complete_reduce(op, device):
        # Reducing without a dim argument must yield no names.
        t = torch.empty(2, 3, 5, names=("N", "C", "L"), device=device)
        check_output(op(t), [])

    def test_multidim_reduce(op, device):
        # Reduce over several dims given as ints or names; a None entry in
        # the list must be rejected.
        t = torch.empty(2, 3, 5, names=("N", "C", "L"), device=device)
        check_output(op(t, [1, 2]), ["N"])
        check_output(op(t, [0, -1]), ["C"])
        check_output(op(t, ["C", "L"]), ["N"])
        with self.assertRaisesRegex(
            RuntimeError, "Please look up dimensions by name"
        ):
            op(t, [None, "C"])

    def test_out_variant(op, output_lambda, device):
        # ``output_lambda`` builds the out= argument for ops that need a
        # custom one; otherwise a single empty tensor is used.
        t = torch.empty(2, 3, 5, names=("N", "C", "L"), device=device)
        if output_lambda:
            out = output_lambda(t)
        else:
            out = torch.empty([0], device=device)
        op(t, "C", out=out)
        check_output(out, ["N", "L"])

    def test_keepdim(op, device):
        # With keepdim=True the reduced dim stays, so all names survive.
        t = torch.empty(2, 3, 5, names=("N", "C", "L"), device=device)
        check_output(op(t, "C", keepdim=True), ["N", "C", "L"])

    def values_and_indices(t):
        # out= pair: a default-dtype tensor and a long tensor, both empty
        # and on the same device as ``t``.
        return (
            torch.empty([0], device=t.device),
            torch.empty([0], device=t.device, dtype=torch.long),
        )

    def kthvalue_wrapper(tensor, *args, **kwargs):
        # Pin k=1 so kthvalue can be called as op(tensor, dim, ...) like the
        # other reduction ops in the table below.
        return torch.kthvalue(tensor, 1, *args, **kwargs)

    # One row per op; the boolean flags select which optional checks run.
    Case = namedtuple(
        "Case",
        [
            "op",
            "supports_complete_reduce",
            "supports_multidim_reduce",
            "supports_out_variant",
            "supports_keepdim",
            "output_lambda",
        ],
    )

    tests = [
        Case(torch.sum, True, True, True, True, None),
        Case(torch.prod, True, False, True, True, None),
        Case(torch.mean, True, True, True, True, None),
        Case(torch.var, True, True, True, True, None),
        Case(torch.std, True, True, True, True, None),
        Case(torch.std_mean, True, True, False, True, None),
        Case(torch.var_mean, True, True, False, True, None),
        Case(torch.min, True, False, True, True, values_and_indices),
        Case(torch.max, True, False, True, True, values_and_indices),
        Case(torch.unbind, False, False, False, False, None),
        Case(torch.logsumexp, False, True, True, True, None),
        Case(torch.mode, False, False, True, True, values_and_indices),
        Case(kthvalue_wrapper, False, False, True, True, values_and_indices),
        Case(torch.median, True, False, True, True, values_and_indices),
        Case(torch.nanmedian, True, False, True, True, values_and_indices),
    ]

    # Run every case on every available device type.
    for testcase, device in itertools.product(
        tests, torch.testing.get_all_device_types()
    ):
        op = testcase.op
        test_simple_reduce(op, device)
        test_autograd_supports_dimname_overload(op, device)

        if testcase.supports_keepdim:
            test_keepdim(op, device)
        if testcase.supports_out_variant:
            test_out_variant(op, testcase.output_lambda, device)
        if testcase.supports_complete_reduce:
            test_complete_reduce(op, device)
        if testcase.supports_multidim_reduce:
            test_multidim_reduce(op, device)
def test_masked_select(self):
    """Name inference for ``torch.masked_select``, including out= and errors."""

    def named_mask(namedshape, *names):
        # Boolean mask built from a fresh tensor, then given ``names``.
        return (create(namedshape) > 0).rename(*names)

    # Each row: (op, factory for the positional args, expectation kwargs).
    # Args are built lazily so tensors are created right before each check.
    cases = [
        # simple
        (
            torch.masked_select,
            lambda: (create("N:2,C:3"), named_mask("2,3", "N", "C")),
            {"expected_names": [None]},
        ),
        # left broadcast
        (
            torch.masked_select,
            lambda: (create("C:3"), named_mask("2,3", "N", "C")),
            {"expected_names": [None]},
        ),
        # right broadcast
        (
            torch.masked_select,
            lambda: (create("N:2,C:3"), named_mask("3", "C")),
            {"expected_names": [None]},
        ),
        # error
        (
            torch.masked_select,
            lambda: (create("N:2,C:3"), named_mask("3", "D")),
            {"maybe_raises_regex": "do not match"},
        ),
        # out=
        (
            out_fn(torch.masked_select),
            lambda: (
                create("0"),
                create("N:2,C:3"),
                named_mask("2,3", "N", "C"),
            ),
            {"expected_names": [None]},
        ),
    ]
    for op, make_args, expectation in cases:
        self._test_name_inference(op, make_args(), **expectation)
def test_cat(self):
    """Name inference for ``torch.cat``, including out= and error cases."""
    # Each row: (op, factory for the positional args, expectation kwargs).
    # Args are built lazily so tensors are created right before each check.
    cases = [
        # simple
        (
            torch.cat,
            lambda: [[create("N:2,C:3"), create("N:2,C:3")]],
            {"expected_names": ["N", "C"]},
        ),
        # error: zero dim
        (
            torch.cat,
            lambda: [[create(""), create("")]],
            {"maybe_raises_regex": "zero-dim"},
        ),
        # error: names don't match
        (
            torch.cat,
            lambda: [[create("N:2,C:3"), create("C:3,N:2")]],
            {"maybe_raises_regex": "do not match"},
        ),
        # error: different number of dims
        (
            torch.cat,
            lambda: [[create("N:2,C:3"), create("C:3")]],
            {"maybe_raises_regex": "must have same number of dimensions"},
        ),
        # out=
        (
            out_fn(torch.cat),
            lambda: [create("0"), [create("N:2,C:3"), create("N:2,C:3")]],
            {"expected_names": ["N", "C"]},
        ),
    ]
    for op, make_args, expectation in cases:
        self._test_name_inference(op, make_args(), **expectation)
def test_masked_fill(self):
    """Name inference for ``Tensor.masked_fill`` and its in-place variant."""

    def named_mask(namedshape, *names):
        # Boolean mask built from a fresh tensor, then given ``names``.
        return (create(namedshape) > 0).rename(*names)

    # Each row: (op, factory for the positional args, expectation kwargs).
    # Args are built lazily so tensors are created right before each check.
    cases = [
        # simple
        (
            Tensor.masked_fill,
            lambda: (create("N:2,C:3"), named_mask("2,3", "N", "C"), 3.14),
            {"expected_names": ["N", "C"]},
        ),
        # left broadcast
        (
            Tensor.masked_fill,
            lambda: (create("C:3"), named_mask("2,3", "N", "C"), 3.14),
            {"maybe_raises_regex": "must be less than or equal to"},
        ),
        # right broadcast
        (
            Tensor.masked_fill,
            lambda: (create("N:2,C:3"), named_mask("3", "C"), 3.14),
            {"expected_names": ["N", "C"]},
        ),
        # error
        (
            Tensor.masked_fill,
            lambda: (create("N:2,C:3"), named_mask("3", "D"), 3.14),
            {"maybe_raises_regex": "do not match"},
        ),
        # inplace
        (
            Tensor.masked_fill_,
            lambda: (create("N:2,C:3"), named_mask("2,3", "N", "C"), 3.14),
            {"expected_names": ["N", "C"]},
        ),
        # inplace, computed names don't match output tensor names
        (
            Tensor.masked_fill_,
            lambda: (
                create("N:2,None:3"),
                named_mask("2,3", "N", "C"),
                3.14,
            ),
            {
                "maybe_raises_regex": "not the same as the computed output names"
            },
        ),
    ]
    for op, make_args, expectation in cases:
        self._test_name_inference(op, make_args(), **expectation)
def test_using_seen_interned_string_doesnt_bump_refcount(self):
    """Passing an already-seen name again must leave its refcount unchanged."""

    def see_name():
        seen_name = "N"
        pass_name_to_python_arg_parser(seen_name)

    # First use of "N" happens here, before the refcount is sampled.
    see_name()
    seen_name = "N"
    # Sample the refcount, pass the name again, then compare. Do not add
    # extra bindings of ``seen_name`` in between: they would change the
    # counts being measured.
    old_refcnt = sys.getrefcount(seen_name)

    pass_name_to_python_arg_parser(seen_name)

    new_refcnt = sys.getrefcount(seen_name)
    self.assertEqual(new_refcnt, old_refcnt)
def test_using_unseen_interned_string_bumps_refcount_permanently(self):
    """First use of a never-seen name must add exactly one reference to it."""
    # Please don't use this as a name in a different test.
    unseen_name = "abcdefghi"
    # Sample the refcount on either side of the first use of the name.
    old_refcnt = sys.getrefcount(unseen_name)

    pass_name_to_python_arg_parser(unseen_name)

    new_refcnt = sys.getrefcount(unseen_name)
    self.assertEqual(new_refcnt, old_refcnt + 1)
def test_using_unseen_uninterned_string_refcounts(self):
    """Passing an uninterned name must reference its interned twin instead.

    The refcount of the runtime-built string must stay the same, while the
    refcount of the equal compile-time literal must grow by one.
    """
    # Please don't use this as a name in a different test.
    # non-compile-time constants are not interned
    unseen_name = "".join(["abc", "def", "ghi", "jkl"])
    interned_unseen_name = "abcdefghijkl"
    # Sanity check: equal text, but two distinct string objects.
    self.assertFalse(unseen_name is interned_unseen_name)

    old_uninterned_refcnt = sys.getrefcount(unseen_name)
    old_interned_refcnt = sys.getrefcount(interned_unseen_name)

    pass_name_to_python_arg_parser(unseen_name)

    new_uninterned_refcnt = sys.getrefcount(unseen_name)
    new_interned_refcnt = sys.getrefcount(interned_unseen_name)

    # Internally, PyTorch should not hold a reference to the uninterned string
    self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)

    # Instead, we should hold a new reference to the interned version.
    self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)
def _test_select(self, device):
    """Check ``select`` by index and by name on ``device``; a None dim fails."""
    source = torch.empty(2, 3, 4, 5, names=("N", "C", "H", "W"), device=device)

    # Selecting along "C" (dim 1), by position or by name, drops that name.
    for dim in (1, "C"):
        selected = source.select(dim, 1)
        self.assertEqual(selected.names, ("N", "H", "W"))

    with self.assertRaisesRegex(RuntimeError, "Please look up dimensions by name"):
        source.select(None, 1)
def test_select(self):
    """Run the shared select checks on the CPU."""
    self._test_select("cpu")
@unittest.skipIf(not TEST_CUDA, "no CUDA")
def test_select_cuda(self):
    """Run the shared select checks on CUDA; skipped when TEST_CUDA is false."""
    self._test_select("cuda")
def _test_as_strided(self, device):
    """``as_strided`` on a named tensor must return an unnamed result."""
    source = torch.empty(2, 3, 4, 5, names=("N", "C", "H", "W"), device=device)
    flat_size = [2 * 3 * 4 * 5]
    unit_stride = [1]
    flat_view = source.as_strided(flat_size, unit_stride)
    self.assertEqual(flat_view.names, (None,))
def test_as_strided(self):
    """Run the shared as_strided check on the CPU."""
    self._test_as_strided("cpu")
@unittest.skipIf(not TEST_CUDA, "no CUDA")
def test_as_strided_cuda(self):
    """Run the shared as_strided check on CUDA; skipped when TEST_CUDA is false."""
    self._test_as_strided("cuda")
def test_no_jit_tracer_support(self):
    """Tracing a function that uses named-tensor features must raise."""

    def foo(x):
        # Creates a named tensor inside the traced function.
        return torch.full(x.shape, 2.0, names=("N",))

    def bar(x):
        # Uses a dimname overload inside the traced function.
        return x.select("N", 1)

    for traced_fn in (foo, bar):
        with self.assertRaisesRegex(RuntimeError, "not supported with the tracer"):
            example = torch.randn(3)
            torch.jit.trace(traced_fn, example_inputs=example)
def test_no_jit_script_support(self):
    """Scripted functions must raise "NYI" when they meet a named tensor.

    Two paths are checked: a named tensor passed in as an argument, and a
    tensor that is given names by an ignored helper and then returned.
    """

    @torch.jit.script
    def foo(x):
        return x + 1

    # Named tensor as input to a scripted function.
    with self.assertRaisesRegex(RuntimeError, "NYI"):
        foo(torch.randn(2, 3, names=("N", "C")))

    @torch.jit.ignore
    def add_names(x):
        x.names = ("N", "C")

    @torch.jit.script
    def return_named_tensor(input):
        add_names(input)
        return input

    # Unnamed input that is named by add_names before being returned.
    with self.assertRaisesRegex(RuntimeError, "NYI"):
        return_named_tensor(torch.randn(1, 1))
def test_align_to(self):
    """``align_to`` with explicit name orders: valid layouts and errors."""
    # Each row: (input namedshape, requested order, expected names, expected shape).
    aligned_cases = (
        # trivial
        ("N:3", ("N",), ["N"], [3]),
        # unsqueeze behavior
        ("N:3", ("N", "D"), ["N", "D"], [3, 1]),
        # transpose behavior
        ("N:3,C:2", ("C", "N"), ["C", "N"], [2, 3]),
        # unsqueeze / transpose
        ("C:2,N:3,H:5", ("N", "H", "W", "C"), ["N", "H", "W", "C"], [3, 5, 1, 2]),
    )
    for namedshape, order, expected_names, expected_shape in aligned_cases:
        aligned = create(namedshape).align_to(*order)
        self.assertEqual(aligned.names, expected_names)
        self.assertEqual(aligned.shape, expected_shape)

    # Each row: (expected message regex, input namedshape, requested order).
    rejected_cases = (
        # All input dimensions must be named
        (
            "All input dims must be named. Found unnamed dim at index 0",
            "None:2,C:3",
            ("N", "C"),
        ),
        # not enough names
        ("Cannot find dim 'N'", "N:2,C:3", ("C",)),
        # names not found
        ("Cannot find dim 'C'", "N:2,C:3", ("D", "N")),
    )
    for message_regex, namedshape, order in rejected_cases:
        with self.assertRaisesRegex(RuntimeError, message_regex):
            create(namedshape).align_to(*order)
def test_align_to_ellipsis(self):
    """``align_to`` with "..." standing in for the unmentioned dimensions."""
    tensor = create("N:7,H:3,W:5,C:2")

    # Each row: (requested order, expected names, expected shape).
    ellipsis_cases = (
        # ... = ['N', 'H', 'W', 'C']
        (("...",), ["N", "H", "W", "C"], [7, 3, 5, 2]),
        # ... = ['H', 'C']
        (("...", "W", "N"), ["H", "C", "W", "N"], [3, 2, 5, 7]),
        # ... = ['N', 'W']
        (("H", "C", "..."), ["H", "C", "N", "W"], [3, 2, 7, 5]),
        # ... = ['H', 'C']
        (("W", "...", "N"), ["W", "H", "C", "N"], [5, 3, 2, 7]),
        # ... = []
        (
            ("N", "...", "C", "D", "H", "W"),
            ["N", "C", "D", "H", "W"],
            [7, 2, 1, 3, 5],
        ),
    )
    for order, expected_names, expected_shape in ellipsis_cases:
        aligned = tensor.align_to(*order)
        self.assertEqual(aligned.names, expected_names)
        self.assertEqual(aligned.shape, expected_shape)

    # Input tensor partially named
    partially_named = create("None:2,None:3,None:5,C:7")
    aligned = partially_named.align_to("C", "...")
    self.assertEqual(aligned.names, ["C", None, None, None])
    self.assertEqual(aligned.shape, [7, 2, 3, 5])

    with self.assertRaisesRegex(
        RuntimeError, "order of dimensions cannot contain a None"
    ):
        partially_named.align_to("C", None, "...")

    # Each row: (expected message regex, invalid order for the fully named tensor).
    invalid_orders = (
        # Input order partially named
        ("cannot contain a None name", ("...", "N", None)),
        # Input order duplicate names
        ("duplicate names", ("...", "N", "N")),
    )
    for message_regex, order in invalid_orders:
        with self.assertRaisesRegex(RuntimeError, message_regex):
            tensor.align_to(*order)
def test_align_as(self):
    """Smoke test for ``align_as``."""
    # align_as calls align_to internally. align_to has pretty substantial tests,
    # so just test some basic things here.
    source = create("C:2,N:3,H:5")
    template = create("N:1,H:1,W:1,C:1")

    aligned = source.align_as(template)

    self.assertEqual(aligned.names, ["N", "H", "W", "C"])
    self.assertEqual(aligned.shape, [3, 5, 1, 2])
@unittest.skip("Not implemented yet")
def test_align_tensors_two_inputs(self):
    """Table-driven checks for ``torch.align_tensors`` with two inputs.

    Currently skipped. Each case aligns a tensor against a second tensor
    whose names are ``align_names`` and checks either the resulting
    shape and names or the expected error message.
    """

    def _test(tensor_namedshape, align_names, expected_sizes, expected_error):
        # tensor_namedshape is a (names, sizes) pair for the input tensor.
        tensor_names, tensor_sizes = tensor_namedshape
        tensor = torch.empty(*tensor_sizes, names=tensor_names)
        # The alignment target has size 1 in every dim named by align_names.
        other = torch.empty([1] * len(align_names), names=align_names)
        if expected_error is not None:
            with self.assertRaisesRegex(RuntimeError, expected_error):
                torch.align_tensors(tensor, other)
            return

        output, _ = torch.align_tensors(tensor, other)
        self.assertEqual(output.shape, expected_sizes)
        self.assertEqual(output.names, align_names)

    # Exactly one of expected_sizes / expected_error is non-None per case.
    Case = namedtuple(
        "Case",
        [
            "tensor_namedshape",
            "align_names",
            "expected_sizes",
            "expected_error",
        ],
    )

    tests = [
        # basic tests
        Case(
            tensor_namedshape=(["C"], [2]),
            align_names=["C"],
            expected_sizes=[2],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=(["C"], [2]),
            align_names=["D"],
            expected_sizes=None,
            expected_error="not a subsequence",
        ),
        # single-dim alignment test
        Case(
            tensor_namedshape=(["C"], [2]),
            align_names=["N", "C"],
            expected_sizes=[1, 2],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[["N"], [2]],
            align_names=["N", "C"],
            expected_sizes=[2, 1],
            expected_error=None,
        ),
        # multiple dim alignment test
        Case(
            tensor_namedshape=[["N", "C"], [2, 3]],
            align_names=["N", "H", "C", "W"],
            expected_sizes=[2, 1, 3, 1],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[["N", "C"], [2, 3]],
            align_names=["C", "H", "N", "W"],
            expected_sizes=None,
            expected_error="not a subsequence",
        ),
        # scalar tensor tests
        Case(
            tensor_namedshape=[None, [[]]],
            align_names=["N", "C"],
            expected_sizes=[1, 1],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[[], [[]]],
            align_names=[None, None],
            expected_sizes=[1, 1],
            expected_error=None,
        ),
        # unnamed tensor tests
        Case(
            tensor_namedshape=[None, [2, 3]],
            align_names=[None, None],
            expected_sizes=[2, 3],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[None, [2, 3]],
            align_names=[None, None, None],
            expected_sizes=[1, 2, 3],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[None, [2]],
            align_names=["N"],
            expected_sizes=None,
            expected_error="not a subsequence",
        ),
        # unnamed dim alignment tests
        Case(
            tensor_namedshape=[[None], [2]],
            align_names=["N", None],
            expected_sizes=[1, 2],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[[None], [2]],
            align_names=["N", None, None, None],
            expected_sizes=[1, 1, 1, 2],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[["N"], [2]],
            align_names=["N", None, None, None],
            expected_sizes=[2, 1, 1, 1],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[[None, "N", None], [2, 3, 5]],
            align_names=[None, None, "N", None],
            expected_sizes=[1, 2, 3, 5],
            expected_error=None,
        ),
        Case(
            tensor_namedshape=[[None], [2]],
            align_names=[None, "N"],
            expected_sizes=None,
            expected_error="absolute position from the right",
        ),
        Case(
            tensor_namedshape=[None, [2]],
            align_names=[None, "N"],
            expected_sizes=None,
            expected_error="absolute position from the right",
        ),
        Case(
            tensor_namedshape=[[None, "N"], [2, 3]],
            align_names=[None, "C", "N"],
            expected_sizes=None,
            expected_error="absolute position from the right",
        ),
    ]

    for test in tests:
        _test(*test)
@unittest.skip("Not implemented yet")
def test_align_tensors(self):
    """Compare ``torch.align_tensors`` with aligning each input by hand.

    Currently skipped.
    """

    def reference_fn(*tensors):
        # Align every input to the names of the input with the most dims.
        # max() returns the first maximal item, so ties go to the earliest.
        target_names = max((t.names for t in tensors), key=len)
        return [t.align_to(*target_names) for t in tensors]

    x = torch.empty(1, 1, names=("N", "H"))
    y = torch.empty(2, 3, 5, names=("N", "C", "H"))
    z = torch.empty(2, names=("N",))

    actual_tensors = torch.align_tensors(x, y, z)
    expected_tensors = reference_fn(x, y, z)

    for actual, expected in zip(actual_tensors, expected_tensors):
        self.assertTensorDataAndNamesEqual(actual, expected)
def test_mm(self):
    """Name inference for ``torch.mm`` on every available device type."""
    # Each row: (op, namedshapes of the args, expectation kwargs).
    # Tensors are created from the namedshapes right before each check.
    cases = (
        (torch.mm, ("N:3,C:2", "W:2,H:5"), {"expected_names": ("N", "H")}),
        # left arg is unnamed
        (torch.mm, ("3,2", "W:2,H:5"), {"expected_names": (None, "H")}),
        # right arg is unnamed
        (torch.mm, ("N:3,C:2", "2,5"), {"expected_names": ("N", None)}),
        # out=
        (
            out_fn(torch.mm),
            ("0", "N:3,C:2", "W:2,H:5"),
            {"expected_names": ("N", "H")},
        ),
        (
            torch.mm,
            ("N:3,C:2", "W:2,N:5"),
            {"maybe_raises_regex": "with duplicate names"},
        ),
    )
    for device in torch.testing.get_all_device_types():
        for op, namedshapes, expectation in cases:
            self._test_name_inference(
                op,
                device=device,
                args=tuple(create(shape) for shape in namedshapes),
                **expectation,
            )
def test_expand(self):
    """Name inference for ``Tensor.expand`` on every available device type.

    Existing dims keep their names; dims that expand adds on the left are
    expected to be unnamed.
    """
    for device in torch.testing.get_all_device_types():
        self._test_name_inference(
            Tensor.expand,
            device=device,
            args=(create("D:1"), [3]),
            # One-element tuple. The previous ("D") was just the string "D"
            # and only matched because the comparison is element-wise.
            expected_names=("D",),
        )

        self._test_name_inference(
            Tensor.expand,
            device=device,
            args=(create("H:3,W:2"), [10, 3, 3, 2]),
            expected_names=(None, None, "H", "W"),
        )

        self._test_name_inference(
            Tensor.expand,
            device=device,
            args=(create("3, 2"), [10, 3, 3, 2]),
            expected_names=(None, None, None, None),
        )
def test_addmm(self):
    """Name inference for ``addmm``: functional, out= and in-place forms."""
    # Each row: (op, namedshapes of the args, expectation kwargs).
    # Tensors are created from the namedshapes right before each check.
    cases = (
        # full names
        (
            torch.addmm,
            ("N:3,H:5", "N:3,C:2", "W:2,H:5"),
            {"expected_names": ("N", "H")},
        ),
        # no name on bias
        (
            torch.addmm,
            ("3,5", "N:3,C:2", "W:2,H:5"),
            {"expected_names": ("N", "H")},
        ),
        # partially named bias
        (
            torch.addmm,
            ("N:3,None:5", "N:3,C:2", "W:2,H:5"),
            {"expected_names": ("N", "H")},
        ),
        # out=
        (
            out_fn(torch.addmm),
            ("0", "N:3,None:5", "N:3,C:2", "W:2,H:5"),
            {"expected_names": ("N", "H")},
        ),
        # inplace
        (
            torch.Tensor.addmm_,
            ("N:3,H:5", "N:3,C:2", "W:2,H:5"),
            {"expected_names": ("N", "H")},
        ),
        (
            torch.addmm,
            ("N:3,H:5", "N:3,C:2", "W:2,N:5"),
            {"maybe_raises_regex": "with duplicate names"},
        ),
    )
    for device in torch.testing.get_all_device_types():
        for op, namedshapes, expectation in cases:
            self._test_name_inference(
                op,
                device=device,
                args=tuple(create(shape) for shape in namedshapes),
                **expectation,
            )
def test_bmm(self):
    """Name inference for ``torch.bmm``, including its three error cases."""
    # Each row: (op, namedshapes of the args, expectation kwargs).
    # Tensors are created from the namedshapes right before each check.
    cases = (
        # full names
        (
            torch.bmm,
            ("N:7,A:3,B:2", "N:7,A:2,B:5"),
            {"expected_names": ("N", "A", "B")},
        ),
        # no name on left tensor
        (
            torch.bmm,
            ("7,3,2", "N:7,A:2,B:5"),
            {"expected_names": ("N", None, "B")},
        ),
        # no name on right tensor
        (
            torch.bmm,
            ("N:7,A:3,B:2", "7,2,5"),
            {"expected_names": ("N", "A", None)},
        ),
        # out=
        (
            out_fn(torch.bmm),
            ("0", "N:7,A:3,B:2", "N:7,A:2,B:5"),
            {"expected_names": ("N", "A", "B")},
        ),
        # duplicate names after mm
        (
            torch.bmm,
            ("N:7,A:3,B:2", "N:7,B:2,A:5"),
            {"maybe_raises_regex": "with duplicate names"},
        ),
        # matching error (batch dimensions must be alignable)
        (
            torch.bmm,
            ("N:3,A:3,B:3", "M:3,A:3,B:3"),
            {"maybe_raises_regex": "do not match"},
        ),
        # misalignment (batch dimension is getting contracted)
        (
            torch.bmm,
            ("N:3,A:3,B:3", "None:3,N:3,B:3"),
            {"maybe_raises_regex": "misaligned"},
        ),
    )
    for device in torch.testing.get_all_device_types():
        for op, namedshapes, expectation in cases:
            self._test_name_inference(
                op,
                device=device,
                args=tuple(create(shape) for shape in namedshapes),
                **expectation,
            )
def test_matmul(self):
    """Name inference for ``torch.matmul`` over the input ranks it accepts.

    Covers inputs below 1D (error), 1D @ 1D, ND @ 1D, 1D @ ND, 2D @ 2D,
    batched ND @ ND, the out= variant, and the error cases. Every case is
    run on every available device type.
    """
    # Each row: (op, namedshapes of the args, expectation kwargs).
    # Tensors are created from the namedshapes right before each check.
    cases = (
        # input tensors are less than 1D
        (torch.matmul, ("", "A:2"), {"maybe_raises_regex": "at least 1D"}),
        (torch.matmul, ("A:2", ""), {"maybe_raises_regex": "at least 1D"}),
        # 1D @ 1D
        (torch.matmul, ("A:2", "B:2"), {"expected_names": []}),
        # ND @ 1D
        (torch.matmul, ("A:3,C:2", "B:2"), {"expected_names": ["A"]}),
        (
            torch.matmul,
            ("A:5,C:3,D:2", "B:2"),
            {"expected_names": ["A", "C"]},
        ),
        # 1D @ ND
        (torch.matmul, ("C:2", "A:2,B:3"), {"expected_names": ["B"]}),
        (
            torch.matmul,
            ("C:2", "A:3,B:2,D:5"),
            {"expected_names": ["A", "D"]},
        ),
        # 2D @ 2D
        (
            torch.matmul,
            ("A:3,B:2", "A:2,B:3"),
            {"expected_names": ["A", "B"]},
        ),
        (
            torch.matmul,
            ("A:3,B:2", "B:2,A:5"),
            {"maybe_raises_regex": "with duplicate names"},
        ),
        # ND @ ND where N >= 2
        (
            torch.matmul,
            ("C:5,A:3,B:2", "A:2,B:3"),
            {"expected_names": ["C", "A", "B"]},
        ),
        (
            torch.matmul,
            ("C:5,A:3,B:2", "None:1,A:2,B:3"),
            {"expected_names": ["C", "A", "B"]},
        ),
        (
            torch.matmul,
            ("C:5,A:3,B:2", "None:2,None:1,A:2,B:3"),
            {"expected_names": [None, "C", "A", "B"]},
        ),
        # out=
        (
            out_fn(torch.matmul),
            ("0", "N:7,A:3,B:2", "N:7,A:2,B:5"),
            {"expected_names": ("N", "A", "B")},
        ),
        # duplicate names after mm
        # This case used to call torch.bmm (copied from test_bmm), so the
        # batched duplicate-name path of matmul itself was never checked.
        (
            torch.matmul,
            ("N:7,A:3,B:2", "N:7,B:2,A:5"),
            {"maybe_raises_regex": "with duplicate names"},
        ),
        # misalignment (batch dimension is getting contracted)
        (
            torch.matmul,
            ("N:3,A:3,B:3", "A:3,N:3,B:3"),
            {"maybe_raises_regex": "do not match"},
        ),
    )
    for device in torch.testing.get_all_device_types():
        for op, namedshapes, expectation in cases:
            self._test_name_inference(
                op,
                device=device,
                args=tuple(create(shape) for shape in namedshapes),
                **expectation,
            )
def test_mv(self):
    """Name inference for ``torch.mv`` on every available device type."""
    # Each row: (op, namedshapes of the args, expected output names).
    # Tensors are created from the namedshapes right before each check.
    cases = (
        (torch.mv, ("N:3,C:2", "W:2"), ("N",)),
        # left arg is unnamed
        (torch.mv, ("3,2", "W:2"), (None,)),
        # right arg is unnamed
        (torch.mv, ("N:3,C:2", "2"), ("N",)),
        # out=
        (out_fn(torch.mv), ("0", "N:3,C:2", "W:2"), ("N",)),
    )
    for device in torch.testing.get_all_device_types():
        for op, namedshapes, expected in cases:
            self._test_name_inference(
                op,
                device=device,
                args=tuple(create(shape) for shape in namedshapes),
                expected_names=expected,
            )
def test_addmv(self):
    """Name inference for ``addmv``: functional, out= and in-place forms."""
    # Each row: (op, namedshapes of the args, expected output names).
    # Tensors are created from the namedshapes right before each check.
    cases = (
        # full names
        (torch.addmv, ("N:3", "N:3,C:2", "H:2"), ["N"]),
        # no name on bias
        (torch.addmv, ("3", "N:3,C:2", "H:2"), ("N",)),
        # out=
        (out_fn(torch.addmv), ("0", "N:3", "N:3,C:2", "H:2"), ("N",)),
        # inplace
        (torch.Tensor.addmv_, ("N:3", "N:3,C:2", "H:2"), ("N",)),
    )
    for device in torch.testing.get_all_device_types():
        for op, namedshapes, expected in cases:
            self._test_name_inference(
                op,
                device=device,
                args=tuple(create(shape) for shape in namedshapes),
                expected_names=expected,
            )
def test_autograd_ignores_names(self):
    """Backward through sigmoid on a named tensor must not raise."""
    # sigmoid forward is supported by named tensors, but sigmoid_backward
    # is not (see native_functions.yaml). Test that autograd ignores names
    # and that the sigmoid_backward succeeds.
    named_input = torch.randn(3, 3, names=("N", "C"), requires_grad=True)
    activated = named_input.sigmoid()
    loss = activated.sum()
    loss.backward()
def test_tensor_grad_is_unnamed(self):
    """Gradients must be unnamed even when the forward tensors have names."""
    unnamed = torch.randn(3, 3, names=(None, None), requires_grad=True)
    named = torch.randn(3, 3, names=("N", "C"), requires_grad=True)

    product = unnamed * named
    product.sum().backward()

    # Check that names weren't propagated
    self.assertEqual(named.grad.names, [None, None])
    self.assertEqual(unnamed.grad.names, [None, None])
def test_autograd_warns_named_grad(self):
    """Passing a named grad tensor to ``backward`` must warn exactly once."""
    base = torch.randn(3, 3, names=("N", "C"))
    named_grad = base.clone()
    base.requires_grad_()

    with warnings.catch_warnings(record=True) as caught:
        # Cause all warnings to always be triggered.
        warnings.simplefilter("always")
        cloned = base.clone()
        cloned.backward(named_grad)

        self.assertEqual(len(caught), 1)
        first_message = str(caught[0].message)
        self.assertTrue(
            first_message.startswith("Autograd was passed a named grad tensor")
        )
def test_nyi_dimname_overload_msg(self):
    """A dimname passed to ``squeeze_`` must produce a targeted error message."""
    unnamed = torch.randn(3, 3)
    expected_message = "squeeze: You passed a dimname"
    with self.assertRaisesRegex(RuntimeError, expected_message):
        unnamed.squeeze_("N")
def test_dot(self):
    """``torch.dot`` of two named vectors must give a result with no names."""
    for device in torch.testing.get_all_device_types():
        # torch.dot ignores the names of both tensors
        operands = (create("C:2"), create("W:2"))
        self._test_name_inference(
            torch.dot, device=device, args=operands, expected_names=[]
        )
def test_comparison_ops(self):
    """Comparison operators and functions must keep the left operand's names."""
    # Operator forms, in the order ==, !=, >, <, >=, <=.
    operator_forms = (
        lambda lhs, rhs: lhs == rhs,
        lambda lhs, rhs: lhs != rhs,
        lambda lhs, rhs: lhs > rhs,
        lambda lhs, rhs: lhs < rhs,
        lambda lhs, rhs: lhs >= rhs,
        lambda lhs, rhs: lhs <= rhs,
    )
    out_variants = (torch.eq, torch.ne, torch.lt, torch.gt, torch.le, torch.ge)
    unary_checks = (torch.isnan, torch.isinf)

    for device in torch.testing.get_all_device_types():
        a = torch.randn(3, 3, names=("N", "C"), device=device)
        b = torch.randn(3, 3, names=("N", "C"), device=device)
        scalar = torch.randn([], device=device)

        # Right-hand side: a named tensor, a Python number, a 0-dim tensor.
        for rhs in (b, 1, scalar):
            for compare in operator_forms:
                self.assertEqual(compare(a, rhs).names, ["N", "C"])

        # out= variants all write into the same bool buffer.
        res = torch.empty(3, 3, dtype=torch.bool, device=device)
        for compare_into in out_variants:
            compare_into(a, b, out=res)
            self.assertEqual(res.names, ["N", "C"])

        for unary_check in unary_checks:
            res = unary_check(a)
            self.assertEqual(res.names, ["N", "C"])
# Run the test suite only when this file is executed as a script.
if __name__ == "__main__":
    run_tests()