pytorch/test/test_namedtensor.py
Mike Ruberry 3671036ef3 Adds true_divide function, analogous to Python's, JAX's, and NumPy's (true) division (#34236)
Summary:
See NumPy's division documentation here: https://numpy.org/doc/1.18/reference/generated/numpy.divide.html#numpy.divide.

True division is the same as PyTorch's default division except when both inputs are integer or bool tensors. In the latter case the inputs are (conceptually) cast to the default floating type before the division is performed.
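
A minimal sketch of these semantics (illustrative, not part of the original summary):

    import torch

    t = torch.tensor([1, 2, 3])        # int64 tensor
    torch.true_divide(t, 2)            # tensor([0.5000, 1.0000, 1.5000]): cast to the default float type
    f = torch.tensor([1., 2., 3.])
    torch.true_divide(f, 2)            # floating inputs behave like ordinary division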

The function is implemented for dense and sparse tensors and supports export to ONNX from PyTorch's eager mode or from JIT traces. The function is inherently incompatible with exporting to ONNX via JIT script, which is another data point suggesting we should deprecate exporting scripted graphs to ONNX.
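
A rough sketch of the trace-based export path (the Div wrapper module and in-memory buffer are illustrative assumptions, not from the PR):

    import io
    import torch

    class Div(torch.nn.Module):
        def forward(self, x, y):
            return torch.true_divide(x, y)

    buf = io.BytesIO()
    torch.onnx.export(Div(), (torch.tensor([1, 2]), torch.tensor([2, 2])), buf)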

Tests are added for the type promotion, named tensor, and ONNX export behavior.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34236

Reviewed By: houseroad

Differential Revision: D20334087

Pulled By: mruberry

fbshipit-source-id: 83d00d886f46f713215d7d9e02ffd043164c57f1
2020-03-09 21:06:33 -07:00


import unittest
from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY
from torch.testing._internal.common_cuda import TEST_CUDA
from collections import namedtuple, OrderedDict
import itertools
import functools
import torch
from torch import Tensor
from torch._six import PY2
import torch.nn.functional as F
from multiprocessing.reduction import ForkingPickler
import pickle
import io
import sys
import warnings
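# Helper: routes `name` through the C++ python argument parser by constructing
# a small named tensor. The interned-string refcount tests below rely on this
# side effect rather than on the tensor itself.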
def pass_name_to_python_arg_parser(name):
x = torch.empty(2, names=(name,))
def flatten(lst):
return [item for sublist in lst for item in sublist]
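# A (name, lambd) record describing a single test case. Note that the
# namedtuple's typename is 'TestCase' even though it is bound to `Function`.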
Function = namedtuple('TestCase', ['name', 'lambd'])
def parse_compressed_namedshape(string):
# This is a metalanguage for describing a shape of a tensor compactly.
# 'N:3,C:2' -> size = [3, 2], names: ['N', 'C']
# 'None:3,None:2' -> size = [3, 2], names: ['None', 'None']
# '3,2' -> size = [3, 2], names=None passed to ctor.
def parse_name(maybe_name):
maybe_name = maybe_name.strip()
if maybe_name == 'None':
return None
return maybe_name
string = string.strip()
# '' -> size: [], names:None
if len(string) == 0:
return None, []
# '3, 2' -> size = [3, 2], None names.
if ':' not in string:
return None, [int(size) for size in string.split(',')]
dims = string.split(',')
tuples = [dim.split(':') for dim in dims]
return zip(*[(parse_name(name), int(size)) for name, size in tuples])
def create(namedshape, factory=torch.randn):
# namedshape: str
names, shape = parse_compressed_namedshape(namedshape)
return factory(shape, names=names)
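# out_fn(op) adapts `op` so that the first positional argument is used as the
# out= tensor: out_fn(torch.mm)(result, a, b) is torch.mm(a, b, out=result).
# This lets out= variants share _test_name_inference's calling convention.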
def out_fn(operator):
@functools.wraps(operator)
def fn(*inputs):
return operator(*inputs[1:], out=inputs[0])
return fn
class TestNamedTensor(TestCase):
def test_aaa_must_run_first_check_experimental_warning(self):
# TODO(rzou): It would be nice for this to be a "real" python warning.
# Right now this error message only prints once and doesn't respect
# warnings.simplefilter behavior (where python users can control whether
# or not to display warnings once, all the time, or never).
with warnings.catch_warnings(record=True) as warns:
x = torch.randn(3, 3, names=('N', 'C'))
self.assertEqual(len(warns), 1)
self.assertTrue(str(warns[0].message).startswith(
'Named tensors and all their associated APIs are an experimental feature'))
def test_trivial(self):
pass
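    # Runs op(*args) (with tensor args moved to `device`) and checks that the
    # result's names match expected_names, or, when maybe_raises_regex is set,
    # that the call raises a RuntimeError whose message matches the regex.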
def _test_name_inference(self, op, args=(), expected_names=(), device='cpu',
maybe_raises_regex=None):
casted_args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg
for arg in args]
if maybe_raises_regex is not None:
with self.assertRaisesRegex(RuntimeError, maybe_raises_regex):
                result = op(*casted_args)
            return
        result = op(*casted_args)
self.assertEqual(result.names, expected_names,
message='Name inference for {} on device {} failed'.format(
op.__name__, device))
# TODO(rzou): Some form of this check should be added to self.assertEqual.
# Right now I don't know what it should look like.
def assertTensorDataAndNamesEqual(self, x, y):
self.assertEqual(x.names, y.names)
unnamed_x = x.rename(None)
unnamed_y = y.rename(None)
self.assertEqual(unnamed_x, unnamed_y)
def _test_factory(self, factory, device):
x = factory([], device=device)
self.assertEqual(x.names, ())
x = factory(1, 2, 3, device=device)
self.assertEqual(x.names, (None, None, None))
x = factory(1, 2, 3, names=None, device=device)
self.assertEqual(x.names, (None, None, None))
x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device)
self.assertEqual(x.names, ('N', 'T', 'D'))
x = factory(1, 2, 3, names=('N', None, 'D'), device=device)
self.assertEqual(x.names, ('N', None, 'D'))
with self.assertRaisesRegex(RuntimeError,
'must contain alphabetical characters and/or underscore'):
x = factory(2, names=('?',), device=device)
with self.assertRaisesRegex(RuntimeError, 'Number of names'):
x = factory(2, 1, names=('N',), device=device)
with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'):
x = factory(2, 1, names='N', device=device)
with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'):
x = factory(2, 1, 1, names=('N', 'C', 'N'), device=device)
names64 = ['A' * i for i in range(1, 65)]
x = factory([1] * 64, names=names64, device=device)
self.assertEqual(x.names, names64)
with self.assertRaisesRegex(
RuntimeError,
'only support up to 64 dims'):
names65 = ['A' * i for i in range(1, 66)]
            x = factory([1] * 65, names=names65, device=device)
def test_none_names_refcount(self):
def scope():
unnamed = torch.empty(2, 3)
unnamed.names # materialize [None, None]
prev_none_refcnt = sys.getrefcount(None)
scope()
self.assertEqual(sys.getrefcount(None), prev_none_refcnt,
message='Using tensor.names should not change '
'the refcount of Py_None')
def test_has_names(self):
unnamed = torch.empty(2, 3)
none_named = torch.empty(2, 3, names=(None, None))
partially_named = torch.empty(2, 3, names=('N', None))
fully_named = torch.empty(2, 3, names=('N', 'C'))
self.assertFalse(unnamed.has_names())
self.assertFalse(none_named.has_names())
self.assertTrue(partially_named.has_names())
self.assertTrue(fully_named.has_names())
@unittest.skipIf(PY2, "Ellipsis object not supported in python 2")
def test_py3_ellipsis(self):
# Need to exec or else flake8 will complain about invalid python 2.
tensor = torch.randn(2, 3, 5, 7)
scope = {'tensor': tensor}
code_str = "output = tensor.refine_names('N', ..., 'C')"
exec(code_str, globals(), scope)
self.assertEqual(scope['output'].names, ['N', None, None, 'C'])
def test_refine_names(self):
# Unnamed tensor -> Unnamed tensor
self._test_name_inference(Tensor.refine_names,
[create('None:1,None:2,None:3'), 'N', 'C', 'H'],
['N', 'C', 'H'])
# Named tensor -> Named tensor
self._test_name_inference(Tensor.refine_names,
[create('N:1,C:2,H:3'), 'N', 'C', 'H'],
['N', 'C', 'H'])
# Partially named tensor -> named tensor
self._test_name_inference(Tensor.refine_names,
[create('None:1,C:2,None:3'), None, 'C', 'H'],
[None, 'C', 'H'])
# Too few names
self._test_name_inference(Tensor.refine_names,
[create('None:2,None:3'), 'N', 'C', 'H'],
maybe_raises_regex="different number of dims")
# Cannot change Tensor[D] to Tensor[N]
self._test_name_inference(Tensor.refine_names,
[create('D:3'), 'N'],
maybe_raises_regex="is different from")
# Cannot change Tensor[D] to Tensor[None]
self._test_name_inference(Tensor.refine_names,
[create('D:3'), None],
maybe_raises_regex="'D' is more specific than None")
# globbing behavior exists
self._test_name_inference(Tensor.refine_names,
[create('None:1,None:1,None:2,None:3'), '...', 'C', 'H'],
[None, None, 'C', 'H'])
def test_detach(self):
names = ['N']
self._test_name_inference(
Tensor.detach_,
[torch.randn(3, requires_grad=True, names=names)],
names)
self._test_name_inference(
Tensor.detach,
[torch.randn(3, requires_grad=True, names=names)],
names)
def test_index_fill(self):
for device in torch.testing.get_all_device_types():
expected_names = ('N', 'C')
x = torch.randn(3, 5, device=device, names=expected_names)
output = x.index_fill_('C', torch.tensor([0, 1], device=device), 5)
self.assertEqual(output.names, expected_names)
output = x.index_fill_('C', torch.tensor([0, 1], device=device), torch.tensor(4.))
self.assertEqual(output.names, expected_names)
output = x.index_fill('C', torch.tensor([0, 1], device=device), 5)
self.assertEqual(output.names, expected_names)
output = x.index_fill('C', torch.tensor([0, 1], device=device), torch.tensor(4.))
self.assertEqual(output.names, expected_names)
def test_equal(self):
for device in torch.testing.get_all_device_types():
tensor = torch.randn(2, 3, device=device)
other = tensor.clone()
self.assertTrue(torch.equal(tensor.rename('N', 'C'), other.rename('N', 'C')))
self.assertFalse(torch.equal(tensor.rename('M', 'C'), other.rename('N', 'C')))
self.assertFalse(torch.equal(tensor.rename(None, 'C'), other.rename('N', 'C')))
def test_squeeze(self):
x = create('N:3,C:1,H:1,W:1')
output = x.squeeze('C')
self.assertEqual(output.names, ['N', 'H', 'W'])
output = x.squeeze()
self.assertEqual(output.names, ['N'])
def test_repr(self):
named_tensor = torch.zeros(2, 3).rename_('N', 'C')
expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]], names=('N', 'C'))"
self.assertEqual(repr(named_tensor), expected)
unnamed_tensor = torch.zeros(2, 3)
expected = "tensor([[0., 0., 0.],\n [0., 0., 0.]])"
self.assertEqual(repr(unnamed_tensor), expected)
none_named_tensor = torch.zeros(2, 3).rename_(None, None)
self.assertEqual(repr(none_named_tensor), expected)
def test_diagonal(self):
named_tensor = torch.zeros(2, 3, 5, 7, names=list('ABCD'))
self.assertEqual(named_tensor.diagonal().names, ['C', 'D', None])
self.assertEqual(named_tensor.diagonal(1, 3).names, ['A', 'C', None])
self.assertEqual(named_tensor.diagonal(outdim='E', dim1='B', dim2='D').names,
['A', 'C', 'E'])
def test_max_pooling(self):
def check_tuple_return(op, inputs, expected_names):
values, indices = op(*inputs)
self.assertEqual(values.names, expected_names)
self.assertEqual(indices.names, expected_names)
for device in torch.testing.get_all_device_types():
named_tensor_1d = torch.zeros(2, 3, 5, device=device, names=list('ABC'))
named_tensor_2d = torch.zeros(2, 3, 5, 7, device=device, names=list('ABCD'))
named_tensor_3d = torch.zeros(2, 3, 5, 7, 9, device=device, names=list('ABCDE'))
self.assertEqual(F.max_pool1d(named_tensor_1d, 2).names, named_tensor_1d.names)
self.assertEqual(F.max_pool2d(named_tensor_2d, [2, 2]).names, named_tensor_2d.names)
self.assertEqual(F.max_pool3d(named_tensor_3d, [2, 2, 2]).names, named_tensor_3d.names)
check_tuple_return(F.max_pool1d_with_indices, [named_tensor_1d, 2], named_tensor_1d.names)
check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names)
check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names)
def test_no_save_support(self):
named_tensor = torch.zeros(2, 3, names=('N', 'C'))
buf = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "NYI"):
torch.save(named_tensor, buf)
def test_no_pickle_support(self):
named_tensor = torch.zeros(2, 3, names=('N', 'C'))
with self.assertRaisesRegex(RuntimeError, "NYI"):
serialized = pickle.dumps(named_tensor)
def test_no_multiprocessing_support(self):
named_tensor = torch.zeros(2, 3, names=('N', 'C'))
buf = io.BytesIO()
with self.assertRaisesRegex(RuntimeError, "NYI"):
ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor)
def test_big_tensor_repr_has_names(self):
def check_repr(named_tensor):
unnamed_tensor = named_tensor.rename(None)
names_tag = 'names={}'.format(named_tensor.names)
self.assertIn(names_tag, repr(named_tensor))
check_repr(torch.randn(128, 3, 64, 64, names=('N', 'C', 'H', 'W')))
def test_noncontig_contiguous(self):
# This type of contiguous is special-cased and therefore needs its own test
for device in torch.testing.get_all_device_types():
x = torch.randn(2, 3, device=device).t().rename_('N', 'C')
self.assertEqual(x.contiguous().names, ('N', 'C'))
def test_copy_transpose(self):
# This type of copy is special-cased and therefore needs its own test
def _test(self_names, other_names, expected_names):
x = torch.empty(2, 5, names=self_names)
y = torch.empty(5, 2).t().rename_(*other_names)
x.copy_(y)
self.assertEqual(x.names, expected_names)
_test(('N', 'C'), ('N', 'C'), ('N', 'C'))
_test(None, ('N', 'C'), ('N', 'C'))
def test_rename_(self):
tensor = torch.empty(1, 1, names=('N', 'C'))
self.assertEqual(tensor.rename_(None).names, (None, None))
self.assertEqual(tensor.rename_('H', 'W').names, ('H', 'W'))
with self.assertRaisesRegex(RuntimeError, 'Number of names'):
tensor.rename_('N', 'C', 'W')
with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
tensor.rename_('N', 'N')
def test_rename(self):
tensor = torch.empty(1, 1, names=('N', 'C'))
self.assertEqual(tensor.rename(None).names, (None, None))
self.assertEqual(tensor.rename('H', 'W').names, ('H', 'W'))
# Check that we didn't modify tensor.names
self.assertEqual(tensor.names, ('N', 'C'))
with self.assertRaisesRegex(RuntimeError, 'Number of names'):
tensor.rename('N', 'C', 'W')
with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
tensor.rename('N', 'N')
with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'):
tensor.rename(None, N='batch')
# rename returns a view on the tensor
self.assertEqual(tensor.rename('H', 'W').data_ptr(), tensor.data_ptr())
self.assertEqual(tensor.rename(None).data_ptr(), tensor.data_ptr())
def test_rename_globber(self):
scalar = torch.randn([])
unnamed_tensor = torch.empty(1, 1, 1, 1)
named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))
self.assertEqual(scalar.rename(None).names, [])
self.assertEqual(scalar.rename('...').names, [])
# Check that it works with unnamed tensors
self.assertEqual(unnamed_tensor.rename('...').names, unnamed_tensor.names)
self.assertEqual(unnamed_tensor.rename('...', 'H', 'W').names,
[None, None, 'H', 'W'])
self.assertEqual(unnamed_tensor.rename('N', '...', 'W').names,
['N', None, None, 'W'])
self.assertEqual(unnamed_tensor.rename('N', 'C', '...').names,
['N', 'C', None, None])
# Check that it works with named tensors
self.assertEqual(named_tensor.rename('...').names, named_tensor.names)
self.assertEqual(named_tensor.rename('...', 'width').names,
['N', 'C', 'H', 'width'])
self.assertEqual(named_tensor.rename('batch', 'channels', '...', 'width').names,
['batch', 'channels', 'H', 'width'])
self.assertEqual(named_tensor.rename('batch', '...').names,
['batch', 'C', 'H', 'W'])
# Test empty glob
self.assertEqual(unnamed_tensor.rename('...', None, None, None, None).names,
[None, None, None, None])
self.assertEqual(named_tensor.rename('N', 'C', 'H', '...', 'W').names,
['N', 'C', 'H', 'W'])
# Multiple globs throw
with self.assertRaisesRegex(RuntimeError, 'More than one '):
named_tensor.rename('...', 'channels', '...')
def test_rename_rename_map(self):
scalar = torch.randn([])
unnamed_tensor = torch.empty(1, 1, 1, 1)
named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W'))
with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
scalar.rename(N='batch')
with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"):
unnamed_tensor.rename(N='batch')
with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
named_tensor.rename(B='batch')
with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"):
named_tensor.rename(H='height', B='batch')
self.assertEqual(named_tensor.rename(N='batch').data_ptr(),
named_tensor.data_ptr())
self.assertEqual(named_tensor.rename(N='batch').names,
['batch', 'C', 'H', 'W'])
self.assertEqual(named_tensor.rename(N='batch', H='height').names,
['batch', 'C', 'height', 'W'])
def test_set_names_property(self):
tensor = torch.empty(1, 1, names=('N', 'C'))
tensor.names = None
self.assertEqual(tensor.names, (None, None))
tensor.names = ('N', 'W')
self.assertEqual(tensor.names, ('N', 'W'))
with self.assertRaisesRegex(RuntimeError, 'Number of names'):
tensor.names = ['N', 'C', 'W']
with self.assertRaisesRegex(RuntimeError, 'duplicate names'):
tensor.names = ['N', 'N']
def test_factory_edge_cases(self):
for device in torch.testing.get_all_device_types():
self._test_factory(torch.empty, device)
def test_factory_coverage(self):
def _test(factory, device):
names = ('N', 'T', 'D')
torch.manual_seed(0)
result = factory(1, 2, 3, names=names, device=device)
torch.manual_seed(0)
expected = factory(1, 2, 3, device=device).rename_(*names)
self.assertTensorDataAndNamesEqual(result, expected)
supported = [
torch.ones,
torch.rand,
torch.randn,
torch.zeros,
]
for op, device in itertools.product(supported, torch.testing.get_all_device_types()):
_test(op, device)
# Test torch.full
for device in torch.testing.get_all_device_types():
names = ('N', 'T', 'D')
result = torch.full([1, 2, 3], 2, names=names, device=device)
expected = torch.full([1, 2, 3], 2, device=device).rename_(*names)
self.assertTensorDataAndNamesEqual(result, expected)
def test_tensor_from_lists(self):
names = ('N', 'C')
tensor = torch.tensor([[1]], names=names)
self.assertEqual(tensor.names, names)
names = ('N',)
tensor = torch.tensor([1], names=names)
self.assertEqual(tensor.names, names)
with self.assertRaisesRegex(RuntimeError, 'Number of names'):
names = ('N', 'C')
tensor = torch.tensor([1], names=names)
@unittest.skipIf(not TEST_NUMPY, "no numpy")
def test_tensor_from_numpy(self):
import numpy as np
arr = np.array([[1]])
names = ('N', 'C')
        tensor = torch.tensor(arr, names=names)
self.assertEqual(tensor.names, names)
def test_tensor_from_tensor(self):
x = torch.randn(1, 1)
names = ('N', 'C')
tensor = torch.tensor(x, names=names)
self.assertEqual(tensor.names, names)
def test_tensor_from_named_tensor(self):
x = torch.randn(1, 1, names=('N', 'D'))
tensor = torch.tensor(x)
self.assertEqual(tensor.names, ('N', 'D'))
# there's no way to distinguish between names=None and not passing in names.
# If the user passes in names=None they are asking for trouble.
x = torch.randn(1, 1, names=('N', 'D'))
tensor = torch.tensor(x, names=None)
self.assertEqual(tensor.names, ('N', 'D'))
x = torch.randn(1, 1, names=('N', 'D'))
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
tensor = torch.tensor(x, names=('N', 'C'))
def test_size(self):
t = torch.empty(2, 3, 5, names=('N', None, 'C'))
self.assertEqual(t.size('N'), 2)
self.assertEqual(t.size('C'), 5)
        with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
t.size(None)
with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
t.size('channels')
with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
torch.empty(2, 3, 4).size('N')
def test_stride(self):
t = torch.empty(2, 3, 5, names=('N', None, 'C'))
self.assertEqual(t.stride('N'), 3 * 5)
self.assertEqual(t.stride('C'), 1)
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
t.stride(None)
with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '):
t.stride('channels')
with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '):
torch.empty(2, 3, 4).stride('N')
def test_transpose_variants(self):
t = torch.randn(2, 3, 5, 7, names=('N', 'C', 'H', 'W'))
self.assertEqual(t.transpose('N', 'C').names, ['C', 'N', 'H', 'W'])
self.assertEqual(t.transpose(1, 3).names, ['N', 'W', 'H', 'C'])
t = torch.randn(2, 3, names=('N', 'C'))
self.assertEqual(t.t().names, ['C', 'N'])
def test_resize(self):
for device in torch.testing.get_all_device_types():
named = torch.randn(2, names=('N',), device=device)
named.resize_([2])
self.assertEqual(named.names, ['N'])
with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"):
named.resize_([3])
other_named = torch.randn(2, names=('N',), device=device)
named.resize_as_(other_named)
self.assertEqual(other_named.names, ['N'])
unnamed = torch.randn(2, device=device)
with self.assertRaisesRegex(
RuntimeError, r'names .* are not the same as the computed output names'):
named.resize_as_(unnamed)
unnamed = torch.randn(1, device=device)
unnamed.resize_as_(named)
self.assertEqual(unnamed.names, ['N'])
def test_cdist(self):
for device in torch.testing.get_all_device_types():
tensor = torch.randn(3, 1, 2, 7, names=('M', 'N', 'first_group', 'features'),
device=device)
other = torch.randn(5, 11, 7, names=('N', 'second_group', 'features'),
device=device)
result = torch.cdist(tensor, other)
self.assertEqual(result.names, ['M', 'N', 'first_group', 'second_group'])
def test_info_smoke(self):
# Smoke test for info functions / methods / attributes on named tensors.
tensor = torch.empty(1, 1, names=('N', 'D'))
tensor.device
tensor.dtype
tensor.get_device()
tensor.is_complex()
tensor.is_floating_point()
tensor.is_nonzero()
torch.is_same_size(tensor, tensor)
torch.is_signed(tensor)
tensor.layout
tensor.numel()
tensor.dim()
tensor.element_size()
tensor.is_contiguous()
tensor.is_cuda
tensor.is_leaf
tensor.is_pinned()
tensor.is_shared()
tensor.is_sparse
tensor.ndimension()
tensor.nelement()
tensor.shape
tensor.size()
tensor.size(1)
tensor.storage()
tensor.storage_offset()
tensor.storage_type()
tensor.stride()
tensor.stride(1)
tensor.data
tensor.data_ptr()
tensor.ndim
tensor.item()
tensor.type()
tensor.is_shared()
tensor.is_signed()
def test_autograd_smoke(self):
x = torch.randn(3, 3, names=('N', 'D'), requires_grad=True)
y = x.clone()
y.retain_grad()
y.register_hook(lambda x: x)
y.sum().backward()
# autograd related attributes
tensor = torch.empty(1, 1, names=('N', 'D'), requires_grad=True)
tensor = tensor.relu()
tensor.output_nr
tensor.grad_fn
tensor.requires_grad
def test_split_fns_propagates_names(self):
fns = [
lambda x: x.split(1, 0),
lambda x: x.split([1, 1], 1),
lambda x: x.chunk(2, 0),
]
for device in torch.testing.get_all_device_types():
orig_tensor = torch.empty(2, 2, names=('N', 'D'), device=device)
for fn in fns:
splits = fn(orig_tensor)
for split in splits:
self.assertEqual(split.names, orig_tensor.names)
def test_any_all(self):
for device in torch.testing.get_all_device_types():
x = torch.zeros(3, dtype=torch.bool, device=device, names=('C',))
self.assertEqual(x.any().names, [])
self.assertEqual(x.all().names, [])
def test_addcmul_addcdiv(self):
for device in torch.testing.get_all_device_types():
names = ['N']
a = torch.rand(3, device=device, names=names)
b = torch.rand(3, device=device, names=names)
# avoid division by 0
c = torch.rand(3, device=device, names=names).clamp_min_(0.1)
out = torch.randn(3, device=device, names=names)
self.assertEqual(torch.addcmul(a, b, c).names, names)
self.assertEqual(torch.addcmul(a, b, c, out=out).names, names)
self.assertEqual(a.addcmul_(b, c).names, names)
self.assertEqual(torch.addcdiv(a, b, c).names, names)
self.assertEqual(torch.addcdiv(a, b, c, out=out).names, names)
self.assertEqual(a.addcdiv_(b, c).names, names)
def test_binary_ops(self):
def test_basic(op):
a = torch.empty(2, 3, names=('N', 'C'))
b = torch.empty(3, 2, names=('C', 'N'))
c = torch.empty(3, names=('C',))
d = torch.empty(5, names=('W',))
self.assertEqual(op(a, a).names, ('N', 'C'))
self.assertEqual(op(a, c).names, ('N', 'C'))
with self.assertRaisesRegex(RuntimeError, "do not match"):
op(a, d)
with self.assertRaisesRegex(RuntimeError, "do not match"):
op(a, b)
def test_wildcard(op):
a = torch.empty(2, 3, names=('N', 'C'))
c = torch.empty(2, 3, names=(None, 'C'))
self.assertEqual(op(a, c).names, ('N', 'C'))
b = torch.empty(2, 3)
self.assertEqual(op(a, b).names, ('N', 'C'))
d = torch.empty(2, 3, names=('C', None))
with self.assertRaisesRegex(RuntimeError, "Misaligned"):
op(d, c)
def test_mixed_unnamed_named(op, is_inplace):
named2 = torch.randn(1, 1, names=('N', 'C'))
unnamed1 = torch.randn(1)
unnamed2 = torch.randn(1, 1)
unnamed3 = torch.randn(1, 1, 1)
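            # Name propagation for mixed named/unnamed operands: names come
            # from the named operand; if the unnamed operand has more dims,
            # the result's names are padded with None on the left, mirroring
            # broadcasting alignment.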
def compute_expected_names(tensor, other):
assert tensor.has_names() ^ other.has_names()
named = tensor if tensor.has_names() else other
unnamed = other if tensor.has_names() else tensor
unnamed_dim = unnamed.dim()
if unnamed_dim > named.dim():
return [None] * (unnamed_dim - named.dim()) + list(named.names)
else:
return named.names
inputs = itertools.chain(
itertools.product([named2], [unnamed1, unnamed2, unnamed3]),
itertools.product([unnamed1, unnamed2, unnamed3], [named2]),
)
if is_inplace:
# In-place ops have the constraint that they must not change shape.
inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()]
for tensor, other in inputs:
expected_names = compute_expected_names(tensor, other)
self.assertEqual(op(tensor, other).names, expected_names)
def method(name, *args, **kwargs):
return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))]
def function(name, *args, **kwargs):
return [Function(name, lambda a, b: getattr(torch, name)(a, b, *args, **kwargs))]
def out_function(name, *args, **kwargs):
out_fn = getattr(torch, name)
def fn(a, b):
result = torch.empty([0], dtype=a.dtype, device=a.device)
out_fn(a, b, *args, out=result, **kwargs)
return result
return [Function(name, fn)]
def fn_method_and_inplace(name, *args, **kwargs):
return (
method(name, *args, **kwargs) +
method(name + '_', *args, **kwargs) +
out_function(name, *args, **kwargs)
)
tests = [
fn_method_and_inplace('add'),
fn_method_and_inplace('div'),
fn_method_and_inplace('mul'),
fn_method_and_inplace('sub'),
fn_method_and_inplace('pow'),
fn_method_and_inplace('atan2'),
method('copy_'),
function('floor_divide'),
function('true_divide'),
]
tests = flatten(tests)
for name, op in tests:
test_basic(op)
test_wildcard(op)
test_mixed_unnamed_named(op, is_inplace=name.endswith('_'))
def test_logical_ops(self):
# Implemented via TensorIterator, so just check that each version
# (out-of-place, inplace, out=) propagates names.
def zeros(*args, **kwargs):
return torch.zeros(*args, dtype=torch.bool, **kwargs)
for op in ('logical_xor', 'logical_and', 'logical_or'):
self._test_name_inference(
getattr(torch, op),
(create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
expected_names=['N', 'C'])
self._test_name_inference(
getattr(Tensor, op + '_'),
(create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
expected_names=['N', 'C'])
self._test_name_inference(
lambda out, x, y: getattr(torch, op)(x, y, out=out),
(create('0', zeros), create('N:2,C:3', zeros), create('N:2,C:3', zeros)),
expected_names=['N', 'C'])
def test_pow_special(self):
# There are a few pow cases that don't go through TensorIterator.
# Test them here.
for device in torch.testing.get_all_device_types():
named = torch.randn(2, 3, names=('N', 'C'), device=device)
unnamed = torch.randn([0], device=device)
result = torch.pow(named, 0, out=unnamed.clone())
self.assertEqual(result.names, named.names)
result = torch.pow(named, 1, out=unnamed.clone())
self.assertEqual(result.names, named.names)
result = torch.pow(1, named, out=unnamed.clone())
self.assertEqual(result.names, named.names)
def test_out_fn_semantics(self):
out_fn = torch.abs
unnamed_tensor = torch.randn(3, 2)
none_named_tensor = torch.randn(3, 2, names=(None, None))
named_tensor = torch.randn(3, 2, names=('N', 'C'))
partially_named_tensor = torch.randn(3, 2, names=('N', None))
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(partially_named_tensor, out=named_tensor)
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(named_tensor, out=partially_named_tensor)
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(none_named_tensor, out=named_tensor)
with self.assertRaisesRegex(RuntimeError, "Name mismatch"):
out_fn(unnamed_tensor, out=named_tensor)
output = torch.randn(3, 2)
out_fn(unnamed_tensor, out=output)
self.assertFalse(output.has_names())
output = torch.randn(3, 2, names=(None, None))
out_fn(named_tensor, out=output)
self.assertEqual(output.names, named_tensor.names)
output = torch.randn(3, 2)
out_fn(named_tensor, out=output)
self.assertEqual(output.names, named_tensor.names)
output = torch.randn(3, 2, names=(None, None))
out_fn(unnamed_tensor, out=output)
self.assertFalse(output.has_names())
def test_unary_propagate_names_fns(self):
def _test(testcase, names=('N', 'D'), device='cpu'):
sizes = [2] * len(names)
tensor = torch.empty(sizes, names=names, device=device)
try:
out = testcase.lambd(tensor)
except RuntimeError as err:
# Get a better error message by catching the error and asserting.
raise RuntimeError('{}: {}'.format(testcase.name, err))
self.assertEqual(out.names, tensor.names,
message=testcase.name)
def fn(name, *args, **kwargs):
return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))]
def method(name, *args, **kwargs):
return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))]
def out_function(name, *args, **kwargs):
out_fn = getattr(torch, name)
def fn(tensor):
result = torch.empty([0], dtype=tensor.dtype, device=tensor.device)
out_fn(tensor, *args, out=result, **kwargs)
return result
return [Function(name + '_out', fn)]
def fn_method_and_inplace(name, *args, **kwargs):
return (
method(name, *args, **kwargs) +
method(name + '_', *args, **kwargs) +
out_function(name, *args, **kwargs)
)
# All of these operate on 2x2 tensors.
tests = [
# unary pointwise
fn_method_and_inplace('abs'),
fn_method_and_inplace('acos'),
fn_method_and_inplace('asin'),
fn_method_and_inplace('atan'),
fn_method_and_inplace('ceil'),
fn_method_and_inplace('clamp', -1, 1),
fn_method_and_inplace('clamp_min', -2),
fn_method_and_inplace('clamp_max', 2),
method('cauchy_'),
method('clone'),
method('contiguous'),
fn_method_and_inplace('cos'),
fn_method_and_inplace('cosh'),
fn_method_and_inplace('digamma'),
fn_method_and_inplace('erf'),
fn_method_and_inplace('erfc'),
fn_method_and_inplace('erfinv'),
fn_method_and_inplace('exp'),
fn_method_and_inplace('expm1'),
method('exponential_'),
fn_method_and_inplace('floor'),
fn_method_and_inplace('frac'),
method('geometric_', p=0.5),
fn_method_and_inplace('lgamma'),
fn_method_and_inplace('log'),
fn_method_and_inplace('log10'),
fn_method_and_inplace('log1p'),
fn_method_and_inplace('log2'),
method('log_normal_'),
fn_method_and_inplace('neg'),
method('normal_'),
[Function('polygamma', lambda t: torch.polygamma(1, t))],
method('polygamma_', 1),
fn_method_and_inplace('reciprocal'),
method('random_', 0, 1),
method('random_', 1),
method('random_'),
method('relu_'),
method('requires_grad_'),
method('relu'),
fn_method_and_inplace('round'),
fn_method_and_inplace('rsqrt'),
fn_method_and_inplace('sigmoid'),
fn_method_and_inplace('sign'),
fn_method_and_inplace('sin'),
fn_method_and_inplace('sinh'),
fn_method_and_inplace('sqrt'),
fn_method_and_inplace('tan'),
fn_method_and_inplace('tanh'),
fn('threshold', 0, 1),
fn('threshold_', 0, 1),
out_function('threshold', 0, 1),
fn_method_and_inplace('trunc'),
method('uniform_'),
method('zero_'),
method('fill_', 1),
method('fill_', torch.tensor(3.14)),
# conversions
method('to', dtype=torch.long),
method('to', device='cpu'),
method('to', torch.empty([])),
method('bool'),
method('byte'),
method('char'),
method('cpu'),
method('double'),
method('float'),
method('long'),
method('half'),
method('int'),
method('short'),
method('type', dtype=torch.long),
# cumsum and cumprod
fn('cumsum', 0),
fn('cumsum', 'D'),
out_function('cumsum', 'D'),
fn('cumprod', 0),
fn('cumprod', 'D'),
out_function('cumprod', 'D'),
# views
method('narrow', 0, 0, 1),
# creation functions
fn('empty_like'),
fn('zeros_like'),
fn('ones_like'),
fn('full_like', 3.14),
fn('rand_like'),
fn('randn_like'),
# bernoulli variants
method('bernoulli_', 0.5),
method('bernoulli_', torch.tensor(0.5)),
method('softmax', dim=1),
method('softmax', dim='D'),
method('log_softmax', dim=1),
method('log_softmax', dim='D'),
[Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))],
[Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))],
]
tests = flatten(tests)
for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()):
_test(testcase, device=device)
def test_cummax_cummin(self):
def test_ops(op):
for device in torch.testing.get_all_device_types():
names = ('N', 'D')
tensor = torch.rand(2, 3, names=names)
result = op(tensor, 0)
self.assertEqual(result[0].names, names)
self.assertEqual(result[1].names, names)
test_ops(torch.cummax)
test_ops(torch.cummin)
def test_bitwise_not(self):
for device in torch.testing.get_all_device_types():
names = ('N', 'D')
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
result = torch.empty(0, dtype=torch.bool)
self.assertEqual(tensor.bitwise_not().names, names)
self.assertEqual(torch.bitwise_not(tensor, out=result).names, names)
self.assertEqual(tensor.bitwise_not_().names, names)
def test_logical_not(self):
for device in torch.testing.get_all_device_types():
names = ('N', 'D')
tensor = torch.zeros(2, 3, names=names, dtype=torch.bool)
result = torch.empty(0, dtype=torch.bool)
self.assertEqual(tensor.logical_not().names, names)
self.assertEqual(torch.logical_not(tensor, out=result).names, names)
self.assertEqual(tensor.logical_not_().names, names)
def test_bernoulli(self):
for device in torch.testing.get_all_device_types():
names = ('N', 'D')
tensor = torch.rand(2, 3, names=names)
result = torch.empty(0)
self.assertEqual(tensor.bernoulli().names, names)
torch.bernoulli(tensor, out=result)
self.assertEqual(result.names, names)
def test_flatten(self):
tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W'))
# basic
out = tensor.flatten('D', 'W', 'features')
self.assertEqual(out.names, ['N', 'C', 'features'])
self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))
# int overload
out = tensor.flatten(2, 4, 'features')
self.assertEqual(out.names, ['N', 'C', 'features'])
self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))
# list overload
out = tensor.flatten(['D', 'H', 'W'], 'features')
self.assertEqual(out.names, ['N', 'C', 'features'])
self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1))
# Non-contiguous flatten: N and H are not "adjacent" in memory.
sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D'))
sentences = sentences.transpose('T', 'H')
out = sentences.flatten('N', 'H', 'N_H')
self.assertEqual(out.names, ['N_H', 'T', 'D'])
with self.assertRaisesRegex(RuntimeError, "Name 'L' not found in"):
tensor.flatten(['D', 'L'], 'features')
with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
tensor.flatten(['D', 'W'], 'features')
with self.assertRaisesRegex(RuntimeError, "must be consecutive in"):
tensor.flatten(['H', 'D', 'W'], 'features')
def test_unflatten(self):
tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K'))
# accepts iterable of tuples
out = tensor.unflatten('D', (('C', 2), ('H', 3), ('W', 5)))
self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K'))
self.assertEqual(out.shape, (7, 2, 3, 5, 11))
# accepts OrderedDict
out = tensor.unflatten('D', OrderedDict((('C', 2), ('H', 3), ('W', 5))))
self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K'))
self.assertEqual(out.shape, (7, 2, 3, 5, 11))
# Unflatten left-most
out = tensor.unflatten('N', (('N', 7), ('H', 1)))
self.assertEqual(out.names, ('N', 'H', 'D', 'K'))
self.assertEqual(out.shape, (7, 1, 2 * 3 * 5, 11))
# Unflatten right-most
out = tensor.unflatten('K', (('K', 11), ('H', 1)))
self.assertEqual(out.names, ('N', 'D', 'K', 'H'))
self.assertEqual(out.shape, (7, 2 * 3 * 5, 11, 1))
# takes positional dim
out = tensor.unflatten(1, (('C', 2), ('H', 3), ('W', 5)))
self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K'))
self.assertEqual(out.shape, (7, 2, 3, 5, 11))
# takes negative positional dim
out = tensor.unflatten(-2, (('C', 2), ('H', 3), ('W', 5)))
self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K'))
self.assertEqual(out.shape, (7, 2, 3, 5, 11))
with self.assertRaisesRegex(RuntimeError, "don't multiply up to"):
tensor.unflatten('D', (('H', 3), ('W', 5)))
with self.assertRaisesRegex(RuntimeError, 'OrderedDict or iterable of tuples'):
tensor.unflatten('D', None)
with self.assertRaisesRegex(RuntimeError, 'non-empty'):
tensor.unflatten('D', OrderedDict())
def test_unsupported_op_error_msg(self):
named = torch.randn(3, 3, names=('N', 'C'))
with self.assertRaisesRegex(
RuntimeError, "pdist is not yet supported with named tensors"):
torch.pdist(named)
def test_reduction_fns(self):
def check_output(output, expected_names):
if isinstance(output, torch.Tensor):
self.assertEqual(output.names, expected_names)
return
for out in output:
self.assertEqual(out.names, expected_names)
def sum_all_outputs(output):
if isinstance(output, torch.Tensor):
return output.sum()
result = 0
for out in output:
result = out + result
return result.sum()
def test_simple_reduce(op, device):
t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
check_output(op(t, 1), ['N', 'L'])
check_output(op(t, -1), ['N', 'C'])
check_output(op(t, 'C'), ['N', 'L'])
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
op(t, None)
with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'):
op(t, 'H')
def test_autograd_supports_dimname_overload(op, device):
t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device, requires_grad=True)
sum_all_outputs(op(t, 'C')).backward()
self.assertIsNotNone(t.grad)
def test_complete_reduce(op, device):
t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
check_output(op(t), [])
def test_multidim_reduce(op, device):
t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
check_output(op(t, [1, 2]), ['N'])
check_output(op(t, [0, -1]), ['C'])
check_output(op(t, ['C', 'L']), ['N'])
with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'):
op(t, [None, 'C'])
def test_out_variant(op, output_lambda, device):
t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
if output_lambda:
out = output_lambda(t)
else:
out = torch.empty([0], device=device)
op(t, 'C', out=out)
check_output(out, ['N', 'L'])
def test_keepdim(op, device):
t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device)
check_output(op(t, 'C', keepdim=True), ['N', 'C', 'L'])
def values_and_indices(t):
return (torch.empty([0], device=t.device),
torch.empty([0], device=t.device, dtype=torch.long))
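        # Builds the (values, indices) pair of empty out= tensors expected by
        # reductions such as min/max/mode/median, whose out= variant fills a
        # values tensor and a long indices tensor.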
def kthvalue_wrapper(tensor, *args, **kwargs):
            # Return the smallest value (k=1 is the first order statistic)
return torch.kthvalue(tensor, 1, *args, **kwargs)
Case = namedtuple('Case', [
'op',
'supports_complete_reduce',
'supports_multidim_reduce',
'supports_out_variant',
'supports_keepdim',
'output_lambda',
])
tests = [
Case(torch.sum, True, True, True, True, None),
Case(torch.prod, True, False, True, True, None),
Case(torch.mean, True, True, True, True, None),
Case(torch.var, True, True, True, True, None),
Case(torch.std, True, True, True, True, None),
Case(torch.std_mean, True, True, False, True, None),
Case(torch.var_mean, True, True, False, True, None),
Case(torch.min, True, False, True, True, values_and_indices),
Case(torch.max, True, False, True, True, values_and_indices),
Case(torch.unbind, False, False, False, False, None),
Case(torch.logsumexp, False, True, True, True, None),
Case(torch.mode, False, False, True, True, values_and_indices),
Case(kthvalue_wrapper, False, False, True, True, values_and_indices),
Case(torch.median, True, False, True, True, values_and_indices),
]
for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()):
op = testcase.op
test_simple_reduce(op, device)
test_autograd_supports_dimname_overload(op, device)
if testcase.supports_keepdim:
test_keepdim(op, device)
if testcase.supports_out_variant:
test_out_variant(op, testcase.output_lambda, device)
if testcase.supports_complete_reduce:
test_complete_reduce(op, device)
if testcase.supports_multidim_reduce:
test_multidim_reduce(op, device)
def test_masked_select(self):
# simple
self._test_name_inference(
torch.masked_select,
(create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')),
expected_names=[None])
# left broadcast
self._test_name_inference(
torch.masked_select,
(create('C:3'), (create('2,3') > 0).rename('N', 'C')),
expected_names=[None])
# right broadcast
self._test_name_inference(
torch.masked_select,
(create('N:2,C:3'), (create('3') > 0).rename('C')),
expected_names=[None])
# error
self._test_name_inference(
torch.masked_select,
(create('N:2,C:3'), (create('3') > 0).rename('D')),
maybe_raises_regex='do not match')
# out=
self._test_name_inference(
out_fn(torch.masked_select),
(create('0'), create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')),
expected_names=[None])
def test_cat(self):
# simple
self._test_name_inference(
torch.cat,
[[create('N:2,C:3'), create('N:2,C:3')]],
expected_names=['N', 'C'])
# error: zero dim
self._test_name_inference(
torch.cat,
[[create(''), create('')]],
maybe_raises_regex='zero-dim')
# error: names don't match
self._test_name_inference(
torch.cat,
[[create('N:2,C:3'), create('C:3,N:2')]],
maybe_raises_regex='do not match')
# error: different number of dims
self._test_name_inference(
torch.cat,
[[create('N:2,C:3'), create('C:3')]],
maybe_raises_regex='must have same number of dimensions')
# out=
self._test_name_inference(
out_fn(torch.cat),
[create('0'), [create('N:2,C:3'), create('N:2,C:3')]],
expected_names=['N', 'C'])
def test_masked_fill(self):
# simple
self._test_name_inference(
Tensor.masked_fill,
(create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
expected_names=['N', 'C'])
# left broadcast
self._test_name_inference(
Tensor.masked_fill,
(create('C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
maybe_raises_regex="must be less than or equal to")
# right broadcast
self._test_name_inference(
Tensor.masked_fill,
(create('N:2,C:3'), (create('3') > 0).rename('C'), 3.14),
expected_names=['N', 'C'])
# error
self._test_name_inference(
Tensor.masked_fill,
(create('N:2,C:3'), (create('3') > 0).rename('D'), 3.14),
maybe_raises_regex='do not match')
# inplace
self._test_name_inference(
Tensor.masked_fill_,
(create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
expected_names=['N', 'C'])
# inplace, computed names don't match output tensor names
self._test_name_inference(
Tensor.masked_fill_,
(create('N:2,None:3'), (create('2,3') > 0).rename('N', 'C'), 3.14),
maybe_raises_regex="not the same as the computed output names")
def test_using_seen_interned_string_doesnt_bump_refcount(self):
def see_name():
seen_name = 'N'
pass_name_to_python_arg_parser(seen_name)
see_name()
seen_name = 'N'
old_refcnt = sys.getrefcount(seen_name)
pass_name_to_python_arg_parser(seen_name)
new_refcnt = sys.getrefcount(seen_name)
self.assertEqual(new_refcnt, old_refcnt)
def test_using_unseen_interned_string_bumps_refcount_permanently(self):
# Please don't use this as a name in a different test.
unseen_name = 'abcdefghi'
old_refcnt = sys.getrefcount(unseen_name)
pass_name_to_python_arg_parser(unseen_name)
new_refcnt = sys.getrefcount(unseen_name)
self.assertEqual(new_refcnt, old_refcnt + 1)
def test_using_unseen_uninterned_string_refcounts(self):
# Please don't use this as a name in a different test.
# non-compile-time constants are not interned
unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl'])
interned_unseen_name = 'abcdefghijkl'
self.assertFalse(unseen_name is interned_unseen_name)
old_uninterned_refcnt = sys.getrefcount(unseen_name)
old_interned_refcnt = sys.getrefcount(interned_unseen_name)
pass_name_to_python_arg_parser(unseen_name)
new_uninterned_refcnt = sys.getrefcount(unseen_name)
new_interned_refcnt = sys.getrefcount(interned_unseen_name)
# Internally, PyTorch should not hold a reference to the uninterned string
self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)
# Instead, we should hold a new reference to the interned version.
self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)
def _test_select(self, device):
x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)
y = x.select(1, 1)
self.assertEqual(y.names, ('N', 'H', 'W'))
y = x.select('C', 1)
self.assertEqual(y.names, ('N', 'H', 'W'))
with self.assertRaisesRegex(
RuntimeError, 'Please look up dimensions by name'):
y = x.select(None, 1)
def test_select(self):
self._test_select('cpu')
@unittest.skipIf(not TEST_CUDA, 'no CUDA')
def test_select_cuda(self):
self._test_select('cuda')
def _test_as_strided(self, device):
x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device)
y = x.as_strided([2 * 3 * 4 * 5], [1])
self.assertEqual(y.names, (None,))
def test_as_strided(self):
self._test_as_strided('cpu')
@unittest.skipIf(not TEST_CUDA, 'no CUDA')
def test_as_strided_cuda(self):
self._test_as_strided('cuda')
def test_no_jit_tracer_support(self):
def foo(x):
return torch.full(x.shape, 2, names=('N',))
with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
x = torch.randn(3)
torch.jit.trace(foo, example_inputs=x)
def bar(x):
return x.select('N', 1)
with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'):
x = torch.randn(3)
torch.jit.trace(bar, example_inputs=x)
def test_no_jit_script_support(self):
@torch.jit.script
def foo(x):
return x + 1
with self.assertRaisesRegex(RuntimeError, 'NYI'):
foo(torch.randn(2, 3, names=('N', 'C')))
@torch.jit.ignore
def add_names(x):
x.names = ('N', 'C')
@torch.jit.script
def return_named_tensor(input):
add_names(input)
return input
with self.assertRaisesRegex(RuntimeError, "NYI"):
return_named_tensor(torch.randn(1, 1))
def test_align_to(self):
# trivial
tensor = create('N:3')
output = tensor.align_to('N')
self.assertEqual(output.names, ['N'])
self.assertEqual(output.shape, [3])
# unsqueeze behavior
tensor = create('N:3')
output = tensor.align_to('N', 'D')
self.assertEqual(output.names, ['N', 'D'])
self.assertEqual(output.shape, [3, 1])
# transpose behavior
tensor = create('N:3,C:2')
output = tensor.align_to('C', 'N')
self.assertEqual(output.names, ['C', 'N'])
self.assertEqual(output.shape, [2, 3])
# unsqueeze / transpose
tensor = create('C:2,N:3,H:5')
output = tensor.align_to('N', 'H', 'W', 'C')
self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
self.assertEqual(output.shape, [3, 5, 1, 2])
# All input dimensions must be named
with self.assertRaisesRegex(RuntimeError, "All input dims must be named. Found unnamed dim at index 0"):
create('None:2,C:3').align_to('N', 'C')
# not enough names
with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'N'"):
create('N:2,C:3').align_to('C')
# names not found
with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'C'"):
create('N:2,C:3').align_to('D', 'N')
def test_align_to_ellipsis(self):
tensor = create('N:7,H:3,W:5,C:2')
# ... = ['N', 'H', 'W', 'C']
output = tensor.align_to('...')
self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
self.assertEqual(output.shape, [7, 3, 5, 2])
# ... = ['H', 'C']
output = tensor.align_to('...', 'W', 'N')
self.assertEqual(output.names, ['H', 'C', 'W', 'N'])
self.assertEqual(output.shape, [3, 2, 5, 7])
# ... = ['N', 'W']
output = tensor.align_to('H', 'C', '...')
self.assertEqual(output.names, ['H', 'C', 'N', 'W'])
self.assertEqual(output.shape, [3, 2, 7, 5])
# ... = ['H', 'C']
output = tensor.align_to('W', '...', 'N')
self.assertEqual(output.names, ['W', 'H', 'C', 'N'])
self.assertEqual(output.shape, [5, 3, 2, 7])
# ... = []
output = tensor.align_to('N', '...', 'C', 'D', 'H', 'W')
self.assertEqual(output.names, ['N', 'C', 'D', 'H', 'W'])
self.assertEqual(output.shape, [7, 2, 1, 3, 5])
# Input tensor partially named
partially_named = create('None:2,None:3,None:5,C:7')
output = partially_named.align_to('C', '...')
self.assertEqual(output.names, ['C', None, None, None])
self.assertEqual(output.shape, [7, 2, 3, 5])
with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"):
partially_named.align_to('C', None, '...')
# Input order partially named
with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"):
tensor.align_to('...', 'N', None)
# Input order duplicate names
with self.assertRaisesRegex(RuntimeError, "duplicate names"):
tensor.align_to('...', 'N', 'N')
def test_align_as(self):
# align_as calls align_to internally. align_to has pretty substantial tests,
# so just test some basic things here.
tensor = create('C:2,N:3,H:5')
other = create('N:1,H:1,W:1,C:1')
output = tensor.align_as(other)
self.assertEqual(output.names, ['N', 'H', 'W', 'C'])
self.assertEqual(output.shape, [3, 5, 1, 2])
@unittest.skip("Not implemented yet")
def test_align_tensors_two_inputs(self):
def _test(tensor_namedshape, align_names, expected_sizes, expected_error):
tensor_names, tensor_sizes = tensor_namedshape
tensor = torch.empty(*tensor_sizes, names=tensor_names)
other = torch.empty([1] * len(align_names), names=align_names)
if expected_error is not None:
with self.assertRaisesRegex(RuntimeError, expected_error):
torch.align_tensors(tensor, other)
return
output, _ = torch.align_tensors(tensor, other)
self.assertEqual(output.shape, expected_sizes)
self.assertEqual(output.names, align_names)
Case = namedtuple('Case', [
'tensor_namedshape',
'align_names',
'expected_sizes',
'expected_error',
])
tests = [
# basic tests
Case(tensor_namedshape=(['C'], [2]),
align_names=['C'],
expected_sizes=[2],
expected_error=None),
Case(tensor_namedshape=(['C'], [2]),
align_names=['D'],
expected_sizes=None,
expected_error='not a subsequence'),
# single-dim alignment test
Case(tensor_namedshape=(['C'], [2]),
align_names=['N', 'C'],
expected_sizes=[1, 2],
expected_error=None),
Case(tensor_namedshape=[['N'], [2]],
align_names=['N', 'C'],
expected_sizes=[2, 1],
expected_error=None),
# multiple dim alignment test
Case(tensor_namedshape=[['N', 'C'], [2, 3]],
align_names=['N', 'H', 'C', 'W'],
expected_sizes=[2, 1, 3, 1],
expected_error=None),
Case(tensor_namedshape=[['N', 'C'], [2, 3]],
align_names=['C', 'H', 'N', 'W'],
expected_sizes=None,
expected_error='not a subsequence'),
# scalar tensor tests
Case(tensor_namedshape=[None, [[]]],
align_names=['N', 'C'],
expected_sizes=[1, 1],
expected_error=None),
Case(tensor_namedshape=[[], [[]]],
align_names=[None, None],
expected_sizes=[1, 1],
expected_error=None),
# unnamed tensor tests
Case(tensor_namedshape=[None, [2, 3]],
align_names=[None, None],
expected_sizes=[2, 3],
expected_error=None),
Case(tensor_namedshape=[None, [2, 3]],
align_names=[None, None, None],
expected_sizes=[1, 2, 3],
expected_error=None),
Case(tensor_namedshape=[None, [2]],
align_names=['N'],
expected_sizes=None,
expected_error='not a subsequence'),
# unnamed dim alignment tests
Case(tensor_namedshape=[[None], [2]],
align_names=['N', None],
expected_sizes=[1, 2],
expected_error=None),
Case(tensor_namedshape=[[None], [2]],
align_names=['N', None, None, None],
expected_sizes=[1, 1, 1, 2],
expected_error=None),
Case(tensor_namedshape=[['N'], [2]],
align_names=['N', None, None, None],
expected_sizes=[2, 1, 1, 1],
expected_error=None),
Case(tensor_namedshape=[[None, 'N', None], [2, 3, 5]],
align_names=[None, None, 'N', None],
expected_sizes=[1, 2, 3, 5],
expected_error=None),
Case(tensor_namedshape=[[None], [2]],
align_names=[None, 'N'],
expected_sizes=None,
expected_error='absolute position from the right'),
Case(tensor_namedshape=[None, [2]],
align_names=[None, 'N'],
expected_sizes=None,
expected_error='absolute position from the right'),
Case(tensor_namedshape=[[None, 'N'], [2, 3]],
align_names=[None, 'C', 'N'],
expected_sizes=None,
expected_error='absolute position from the right'),
]
for test in tests:
_test(*test)
@unittest.skip("Not implemented yet")
def test_align_tensors(self):
def reference_fn(*tensors):
longest_names = tensors[0].names
for tensor in tensors:
if len(tensor.names) > len(longest_names):
longest_names = tensor.names
return [tensor.align_to(*longest_names) for tensor in tensors]
x = torch.empty(1, 1, names=('N', 'H'))
y = torch.empty(2, 3, 5, names=('N', 'C', 'H'))
z = torch.empty(2, names=('N',))
output = torch.align_tensors(x, y, z)
expected_tensors = reference_fn(x, y, z)
for tensor, expected in zip(output, expected_tensors):
self.assertTensorDataAndNamesEqual(tensor, expected)
def test_mm(self):
for device in torch.testing.get_all_device_types():
self._test_name_inference(
torch.mm, device=device,
args=(create('N:3,C:2'), create('W:2,H:5')),
expected_names=('N', 'H'))
# left arg is unnamed
self._test_name_inference(
torch.mm, device=device,
args=(create('3,2'), create('W:2,H:5')),
expected_names=(None, 'H'))
# right arg is unnamed
self._test_name_inference(
torch.mm, device=device,
args=(create('N:3,C:2'), create('2,5')),
expected_names=('N', None))
# out=
self._test_name_inference(
out_fn(torch.mm), device=device,
args=(create('0'), create('N:3,C:2'), create('W:2,H:5')),
expected_names=('N', 'H'))
self._test_name_inference(
torch.mm, device=device,
args=(create('N:3,C:2'), create('W:2,N:5')),
maybe_raises_regex='with duplicate names')
def test_expand(self):
for device in torch.testing.get_all_device_types():
self._test_name_inference(
Tensor.expand, device=device,
                args=(create('D:1'), [3]), expected_names=('D',))
self._test_name_inference(
Tensor.expand, device=device,
args=(create('H:3,W:2'), [10, 3, 3, 2]),
expected_names=(None, None, 'H', 'W'))
self._test_name_inference(
Tensor.expand, device=device,
args=(create('3, 2'), [10, 3, 3, 2]),
expected_names=(None, None, None, None))
def test_addmm(self):
for device in torch.testing.get_all_device_types():
# full names
self._test_name_inference(
torch.addmm, device=device,
args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
expected_names=('N', 'H'))
# no name on bias
self._test_name_inference(
torch.addmm, device=device,
args=(create('3,5'), create('N:3,C:2'), create('W:2,H:5')),
expected_names=('N', 'H'))
# partially named bias
self._test_name_inference(
torch.addmm, device=device,
args=(create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
expected_names=('N', 'H'))
# out=
self._test_name_inference(
out_fn(torch.addmm), device=device,
args=(create('0'), create('N:3,None:5'), create('N:3,C:2'), create('W:2,H:5')),
expected_names=('N', 'H'))
# inplace
self._test_name_inference(
torch.Tensor.addmm_, device=device,
args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,H:5')),
expected_names=('N', 'H'))
self._test_name_inference(
torch.addmm, device=device,
args=(create('N:3,H:5'), create('N:3,C:2'), create('W:2,N:5')),
maybe_raises_regex='with duplicate names')
def test_bmm(self):
for device in torch.testing.get_all_device_types():
# full names
self._test_name_inference(
torch.bmm, device=device,
args=(create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
expected_names=('N', 'A', 'B'))
# no name on left tensor
self._test_name_inference(
torch.bmm, device=device,
args=(create('7,3,2'), create('N:7,A:2,B:5')),
expected_names=('N', None, 'B'))
# no name on right tensor
self._test_name_inference(
torch.bmm, device=device,
args=(create('N:7,A:3,B:2'), create('7,2,5')),
expected_names=('N', 'A', None))
# out=
self._test_name_inference(
out_fn(torch.bmm), device=device,
args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
expected_names=('N', 'A', 'B'))
# duplicate names after mm
self._test_name_inference(
torch.bmm, device=device,
args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
maybe_raises_regex='with duplicate names')
# matching error (batch dimensions must be alignable)
self._test_name_inference(
torch.bmm, device=device,
args=(create('N:3,A:3,B:3'), create('M:3,A:3,B:3')),
maybe_raises_regex='do not match')
# misalignment (batch dimension is getting contracted)
self._test_name_inference(
torch.bmm, device=device,
args=(create('N:3,A:3,B:3'), create('None:3,N:3,B:3')),
maybe_raises_regex='misaligned')
def test_matmul(self):
for device in torch.testing.get_all_device_types():
# input tensors are less than 1D
self._test_name_inference(
torch.matmul, device=device,
args=(create(''), create('A:2')),
maybe_raises_regex='at least 1D')
self._test_name_inference(
torch.matmul, device=device,
args=(create('A:2'), create('')),
maybe_raises_regex='at least 1D')
# 1D @ 1D
self._test_name_inference(
torch.matmul, device=device,
args=(create('A:2'), create('B:2')),
expected_names=[])
# ND @ 1D
self._test_name_inference(
torch.matmul, device=device,
args=(create('A:3,C:2'), create('B:2')),
expected_names=['A'])
self._test_name_inference(
torch.matmul, device=device,
args=(create('A:5,C:3,D:2'), create('B:2')),
expected_names=['A', 'C'])
# 1D @ ND
self._test_name_inference(
torch.matmul, device=device,
args=(create('C:2'), create('A:2,B:3')),
expected_names=['B'])
self._test_name_inference(
torch.matmul, device=device,
args=(create('C:2'), create('A:3,B:2,D:5')),
expected_names=['A', 'D'])
# 2D @ 2D
self._test_name_inference(
torch.matmul, device=device,
args=(create('A:3,B:2'), create('A:2,B:3')),
expected_names=['A', 'B'])
self._test_name_inference(
torch.matmul, device=device,
args=(create('A:3,B:2'), create('B:2,A:5')),
maybe_raises_regex='with duplicate names')
# ND @ ND where N >= 2
self._test_name_inference(
torch.matmul, device=device,
args=(create('C:5,A:3,B:2'), create('A:2,B:3')),
expected_names=['C', 'A', 'B'])
self._test_name_inference(
torch.matmul, device=device,
args=(create('C:5,A:3,B:2'), create('None:1,A:2,B:3')),
expected_names=['C', 'A', 'B'])
self._test_name_inference(
torch.matmul, device=device,
args=(create('C:5,A:3,B:2'), create('None:2,None:1,A:2,B:3')),
expected_names=[None, 'C', 'A', 'B'])
# out=
self._test_name_inference(
out_fn(torch.matmul), device=device,
args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
expected_names=('N', 'A', 'B'))
# duplicate names after mm
self._test_name_inference(
torch.bmm, device=device,
args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
maybe_raises_regex='with duplicate names')
# misalignment (batch dimension is getting contracted)
self._test_name_inference(
torch.matmul, device=device,
args=(create('N:3,A:3,B:3'), create('A:3,N:3,B:3')),
maybe_raises_regex='do not match')
def test_mv(self):
for device in torch.testing.get_all_device_types():
self._test_name_inference(
torch.mv, device=device,
args=(create('N:3,C:2'), create('W:2')),
expected_names=('N',))
# left arg is unnamed
self._test_name_inference(
torch.mv, device=device,
args=(create('3,2'), create('W:2')),
expected_names=(None,))
# right arg is unnamed
self._test_name_inference(
torch.mv, device=device,
args=(create('N:3,C:2'), create('2')),
expected_names=('N',))
# out=
self._test_name_inference(
out_fn(torch.mv), device=device,
args=(create('0'), create('N:3,C:2'), create('W:2')),
expected_names=('N',))
def test_addmv(self):
for device in torch.testing.get_all_device_types():
# full names
self._test_name_inference(
torch.addmv, device=device,
args=(create('N:3'), create('N:3,C:2'), create('H:2')),
expected_names=['N'])
# no name on bias
self._test_name_inference(
torch.addmv, device=device,
args=(create('3'), create('N:3,C:2'), create('H:2')),
expected_names=('N',))
# out=
self._test_name_inference(
out_fn(torch.addmv), device=device,
args=(create('0'), create('N:3'), create('N:3,C:2'), create('H:2')),
expected_names=('N',))
# inplace
self._test_name_inference(
torch.Tensor.addmv_, device=device,
args=(create('N:3'), create('N:3,C:2'), create('H:2')),
expected_names=('N',))
def test_autograd_ignores_names(self):
# sigmoid forward is supported by named tensors, but sigmoid_backward
# is not (see native_functions.yaml). Test that autograd ignores names
# and that the sigmoid_backward succeeds.
x = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
x.sigmoid().sum().backward()
def test_tensor_grad_is_unnamed(self):
x = torch.randn(3, 3, names=(None, None), requires_grad=True)
y = torch.randn(3, 3, names=('N', 'C'), requires_grad=True)
(x * y).sum().backward()
# Check that names weren't propagated
self.assertEqual(y.grad.names, [None, None])
self.assertEqual(x.grad.names, [None, None])
def test_autograd_warns_named_grad(self):
base = torch.randn(3, 3, names=('N', 'C'))
named_grad = base.clone()
base.requires_grad_()
with warnings.catch_warnings(record=True) as warns:
# Cause all warnings to always be triggered.
warnings.simplefilter("always")
base.clone().backward(named_grad)
self.assertEqual(len(warns), 1)
self.assertTrue(
str(warns[0].message).startswith('Autograd was passed a named grad tensor'))
def test_nyi_dimname_overload_msg(self):
x = torch.randn(3, 3)
with self.assertRaisesRegex(RuntimeError, "squeeze: You passed a dimname"):
x.squeeze_("N")
def test_dot(self):
for device in torch.testing.get_all_device_types():
# torch.dot ignores the names of both tensors
self._test_name_inference(
torch.dot, device=device,
args=(create('C:2'), create('W:2')),
expected_names=[])
def test_comparison_ops(self):
for device in torch.testing.get_all_device_types():
a = torch.randn(3, 3, names=('N', 'C'), device=device)
b = torch.randn(3, 3, names=('N', 'C'), device=device)
scalar = torch.randn([], device=device)
self.assertEqual((a == b).names, ['N', 'C'])
self.assertEqual((a != b).names, ['N', 'C'])
self.assertEqual((a > b).names, ['N', 'C'])
self.assertEqual((a < b).names, ['N', 'C'])
self.assertEqual((a >= b).names, ['N', 'C'])
self.assertEqual((a <= b).names, ['N', 'C'])
self.assertEqual((a == 1).names, ['N', 'C'])
self.assertEqual((a != 1).names, ['N', 'C'])
self.assertEqual((a > 1).names, ['N', 'C'])
self.assertEqual((a < 1).names, ['N', 'C'])
self.assertEqual((a >= 1).names, ['N', 'C'])
self.assertEqual((a <= 1).names, ['N', 'C'])
self.assertEqual((a == scalar).names, ['N', 'C'])
self.assertEqual((a != scalar).names, ['N', 'C'])
self.assertEqual((a > scalar).names, ['N', 'C'])
self.assertEqual((a < scalar).names, ['N', 'C'])
self.assertEqual((a >= scalar).names, ['N', 'C'])
self.assertEqual((a <= scalar).names, ['N', 'C'])
res = torch.empty(3, 3, dtype=torch.bool, device=device)
torch.eq(a, b, out=res)
self.assertEqual(res.names, ['N', 'C'])
torch.ne(a, b, out=res)
self.assertEqual(res.names, ['N', 'C'])
torch.lt(a, b, out=res)
self.assertEqual(res.names, ['N', 'C'])
torch.gt(a, b, out=res)
self.assertEqual(res.names, ['N', 'C'])
torch.le(a, b, out=res)
self.assertEqual(res.names, ['N', 'C'])
torch.ge(a, b, out=res)
self.assertEqual(res.names, ['N', 'C'])
res = torch.isnan(a)
self.assertEqual(res.names, ['N', 'C'])
res = torch.isinf(a)
self.assertEqual(res.names, ['N', 'C'])
if __name__ == '__main__':
run_tests()