Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Refactors test_torch.py to be fewer than 10k lines (#47356)
Summary: Creates multiple new test suites to have fewer tests in test_torch.py, consistent with previous test suite creation like test_unary_ufuncs.py and test_linalg.py.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/47356
Reviewed By: ngimel
Differential Revision: D25202268
Pulled By: mruberry
fbshipit-source-id: 75fde3ca76545d1b32b86d432a5cb7a5ba8f5bb6
Committed by: Facebook GitHub Bot
Parent: 272f4db043
Commit: 36c87f1243
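The new suites created by this commit all follow the same device-generic skeleton used by the earlier split-out files (test_unary_ufuncs.py, test_linalg.py) and visible in the new files below. A minimal sketch of that skeleton, for orientation only (the class name TestExample and the test method are hypothetical, not part of this commit):

import torch

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, dtypes)


class TestExample(TestCase):
    # Each test takes a `device` (and optionally a `dtype`) argument so the
    # same body can be generated for CPU, CUDA, and other device types.
    @dtypes(torch.float, torch.double)
    def test_add_matches_expected(self, device, dtype):
        x = torch.tensor([1, 2, 3], dtype=dtype, device=device)
        self.assertEqual(x + x, torch.tensor([2, 4, 6], dtype=dtype, device=device))


# Generates per-device variants (TestExampleCPU, TestExampleCUDA, ...).
instantiate_device_type_tests(TestExample, globals())

if __name__ == '__main__':
    run_tests()

The per-file diffs below show the actual contents added by the commit.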
@@ -22,6 +22,7 @@ from typing import Dict, Optional
TESTS = [
    'test_autograd',
    'benchmark_utils/test_benchmark_utils',
    'test_binary_ufuncs',
    'test_bundled_inputs',
    'test_complex',
    'test_cpp_api_parity',
@@ -63,20 +64,26 @@ TESTS = [
    'test_quantization',
    'test_spectral_ops',
    'test_serialization',
    'test_shape_ops',
    'test_show_pickle',
    'test_sort_and_select',
    'test_tensor_creation_ops',
    'test_testing',
    'test_torch',
    'test_type_info',
    'test_type_hints',
    'test_unary_ufuncs',
    'test_utils',
    'test_view_ops',
    'test_vmap',
    'test_namedtuple_return_api',
    'test_numpy_interop',
    'test_jit_profiling',
    'test_jit_legacy',
    'test_jit_fuser_legacy',
    'test_tensorboard',
    'test_namedtensor',
    'test_reductions',
    'test_type_promotion',
    'test_jit_disabled',
    'test_function_schema',
@@ -100,6 +100,25 @@ def graph_desc(fn):

class TestAutograd(TestCase):

    def test_tensor_grad_warnings(self):
        dummy = torch.empty(1)

        with warnings.catch_warnings(record=True) as w:
            # Accessing .grad on leaf
            dummy.requires_grad_()
            foo = dummy.grad
            self.assertEqual(len(w), 0)

            # Accessing .grad on non-leaf
            dummy = dummy.clone()
            foo = dummy.grad
            self.assertEqual(len(w), 1)

            # Accessing .grad on non-leaf that retains gradients
            dummy.retain_grad()
            foo = dummy.grad
            self.assertEqual(len(w), 1)

    def _function_test(self, cls):
        x = torch.randn(5, 5, requires_grad=True)
        y = torch.randn(5, 5, requires_grad=True)
2398  test/test_binary_ufuncs.py  (new file)
File diff suppressed because it is too large.
@@ -1,11 +1,649 @@
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA
import torch
from torch import tensor

import unittest
import warnings
import random
from functools import reduce

from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCUDA, dtypes, dtypesIfCPU, dtypesIfCUDA,
    onlyOnCPUAndCUDA)


class TestIndexing(TestCase):
    def test_index(self, device):

        def consec(size, start=1):
            sequence = torch.ones(int(torch.Tensor(size).prod(0))).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size)

        reference = consec((3, 3, 3)).to(device)

        # empty tensor indexing
        self.assertEqual(reference[torch.LongTensor().to(device)], reference.new(0, 3, 3))

        self.assertEqual(reference[0], consec((3, 3)), atol=0, rtol=0)
        self.assertEqual(reference[1], consec((3, 3), 10), atol=0, rtol=0)
        self.assertEqual(reference[2], consec((3, 3), 19), atol=0, rtol=0)
        self.assertEqual(reference[0, 1], consec((3,), 4), atol=0, rtol=0)
        self.assertEqual(reference[0:2], consec((2, 3, 3)), atol=0, rtol=0)
        self.assertEqual(reference[2, 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[:], consec((3, 3, 3)), atol=0, rtol=0)

        # indexing with Ellipsis
        self.assertEqual(reference[..., 2], torch.Tensor([[3, 6, 9],
                                                          [12, 15, 18],
                                                          [21, 24, 27]]), atol=0, rtol=0)
        self.assertEqual(reference[0, ..., 2], torch.Tensor([3, 6, 9]), atol=0, rtol=0)
        self.assertEqual(reference[..., 2], reference[:, :, 2], atol=0, rtol=0)
        self.assertEqual(reference[0, ..., 2], reference[0, :, 2], atol=0, rtol=0)
        self.assertEqual(reference[0, 2, ...], reference[0, 2], atol=0, rtol=0)
        self.assertEqual(reference[..., 2, 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, ..., 2, 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, 2, ..., 2], 27, atol=0, rtol=0)
        self.assertEqual(reference[2, 2, 2, ...], 27, atol=0, rtol=0)
        self.assertEqual(reference[...], reference, atol=0, rtol=0)

        reference_5d = consec((3, 3, 3, 3, 3)).to(device)
        self.assertEqual(reference_5d[..., 1, 0], reference_5d[:, :, :, 1, 0], atol=0, rtol=0)
        self.assertEqual(reference_5d[2, ..., 1, 0], reference_5d[2, :, :, 1, 0], atol=0, rtol=0)
        self.assertEqual(reference_5d[2, 1, 0, ..., 1], reference_5d[2, 1, 0, :, 1], atol=0, rtol=0)
        self.assertEqual(reference_5d[...], reference_5d, atol=0, rtol=0)

        # LongTensor indexing
        reference = consec((5, 5, 5)).to(device)
        idx = torch.LongTensor([2, 4]).to(device)
        self.assertEqual(reference[idx], torch.stack([reference[2], reference[4]]))
        # TODO: enable one indexing is implemented like in numpy
        # self.assertEqual(reference[2, idx], torch.stack([reference[2, 2], reference[2, 4]]))
        # self.assertEqual(reference[3, idx, 1], torch.stack([reference[3, 2], reference[3, 4]])[:, 1])

        # None indexing
        self.assertEqual(reference[2, None], reference[2].unsqueeze(0))
        self.assertEqual(reference[2, None, None], reference[2].unsqueeze(0).unsqueeze(0))
        self.assertEqual(reference[2:4, None], reference[2:4].unsqueeze(1))
        self.assertEqual(reference[None, 2, None, None], reference.unsqueeze(0)[:, 2].unsqueeze(0).unsqueeze(0))
        self.assertEqual(reference[None, 2:5, None, None], reference.unsqueeze(0)[:, 2:5].unsqueeze(2).unsqueeze(2))

        # indexing 0-length slice
        self.assertEqual(torch.empty(0, 5, 5), reference[slice(0)])
        self.assertEqual(torch.empty(0, 5), reference[slice(0), 2])
        self.assertEqual(torch.empty(0, 5), reference[2, slice(0)])
        self.assertEqual(torch.tensor([]), reference[2, 1:1, 2])

        # indexing with step
        reference = consec((10, 10, 10)).to(device)
        self.assertEqual(reference[1:5:2], torch.stack([reference[1], reference[3]], 0))
        self.assertEqual(reference[1:6:2], torch.stack([reference[1], reference[3], reference[5]], 0))
        self.assertEqual(reference[1:9:4], torch.stack([reference[1], reference[5]], 0))
        self.assertEqual(reference[2:4, 1:5:2], torch.stack([reference[2:4, 1], reference[2:4, 3]], 1))
        self.assertEqual(reference[3, 1:6:2], torch.stack([reference[3, 1], reference[3, 3], reference[3, 5]], 0))
        self.assertEqual(reference[None, 2, 1:9:4], torch.stack([reference[2, 1], reference[2, 5]], 0).unsqueeze(0))
        self.assertEqual(reference[:, 2, 1:6:2],
                         torch.stack([reference[:, 2, 1], reference[:, 2, 3], reference[:, 2, 5]], 1))

        lst = [list(range(i, i + 10)) for i in range(0, 100, 10)]
        tensor = torch.DoubleTensor(lst).to(device)
        for _i in range(100):
            idx1_start = random.randrange(10)
            idx1_end = idx1_start + random.randrange(1, 10 - idx1_start + 1)
            idx1_step = random.randrange(1, 8)
            idx1 = slice(idx1_start, idx1_end, idx1_step)
            if random.randrange(2) == 0:
                idx2_start = random.randrange(10)
                idx2_end = idx2_start + random.randrange(1, 10 - idx2_start + 1)
                idx2_step = random.randrange(1, 8)
                idx2 = slice(idx2_start, idx2_end, idx2_step)
                lst_indexed = [l[idx2] for l in lst[idx1]]
                tensor_indexed = tensor[idx1, idx2]
            else:
                lst_indexed = lst[idx1]
                tensor_indexed = tensor[idx1]
            self.assertEqual(torch.DoubleTensor(lst_indexed), tensor_indexed)

        self.assertRaises(ValueError, lambda: reference[1:9:0])
        self.assertRaises(ValueError, lambda: reference[1:9:-1])

        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1])
        self.assertRaises(IndexError, lambda: reference[1, 1, 1, 1:1])
        self.assertRaises(IndexError, lambda: reference[3, 3, 3, 3, 3, 3, 3, 3])

        self.assertRaises(IndexError, lambda: reference[0.0])
        self.assertRaises(TypeError, lambda: reference[0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, ..., 0.0:2.0])
        self.assertRaises(IndexError, lambda: reference[0.0, :, 0.0])

        def delitem():
            del reference[0]

        self.assertRaises(TypeError, delitem)

    @onlyOnCPUAndCUDA
    @dtypes(torch.half, torch.double)
    def test_advancedindex(self, device, dtype):
        # Tests for Integer Array Indexing, Part I - Purely integer array
        # indexing

        def consec(size, start=1):
            # Creates the sequence in float since CPU half doesn't support the
            # needed operations. Converts to dtype before returning.
            numel = reduce(lambda x, y: x * y, size, 1)
            sequence = torch.ones(numel, dtype=torch.float, device=device).cumsum(0)
            sequence.add_(start - 1)
            return sequence.view(*size).to(dtype=dtype)

        # pick a random valid indexer type
        def ri(indices):
            choice = random.randint(0, 2)
            if choice == 0:
                return torch.LongTensor(indices).to(device)
            elif choice == 1:
                return list(indices)
            else:
                return tuple(indices)

        def validate_indexing(x):
            self.assertEqual(x[[0]], consec((1,)))
            self.assertEqual(x[ri([0]), ], consec((1,)))
            self.assertEqual(x[ri([3]), ], consec((1,), 4))
            self.assertEqual(x[[2, 3, 4]], consec((3,), 3))
            self.assertEqual(x[ri([2, 3, 4]), ], consec((3,), 3))
            self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([1, 3, 5], dtype=dtype, device=device))

        def validate_setting(x):
            x[[0]] = -2
            self.assertEqual(x[[0]], torch.tensor([-2], dtype=dtype, device=device))
            x[[0]] = -1
            self.assertEqual(x[ri([0]), ], torch.tensor([-1], dtype=dtype, device=device))
            x[[2, 3, 4]] = 4
            self.assertEqual(x[[2, 3, 4]], torch.tensor([4, 4, 4], dtype=dtype, device=device))
            x[ri([2, 3, 4]), ] = 3
            self.assertEqual(x[ri([2, 3, 4]), ], torch.tensor([3, 3, 3], dtype=dtype, device=device))
            x[ri([0, 2, 4]), ] = torch.tensor([5, 4, 3], dtype=dtype, device=device)
            self.assertEqual(x[ri([0, 2, 4]), ], torch.tensor([5, 4, 3], dtype=dtype, device=device))

        # Only validates indexing and setting for halfs
        if dtype == torch.half:
            reference = consec((10,))
            validate_indexing(reference)
            validate_setting(reference)
            return

        # Case 1: Purely Integer Array Indexing
        reference = consec((10,))
        validate_indexing(reference)

        # setting values
        validate_setting(reference)

        # Tensor with stride != 1
        # strided is [1, 3, 5, 7]
        reference = consec((10,))
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), storage_offset=0,
                     size=torch.Size([4]), stride=[2])

        self.assertEqual(strided[[0]], torch.tensor([1], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0]), ], torch.tensor([1], dtype=dtype, device=device))
        self.assertEqual(strided[ri([3]), ], torch.tensor([7], dtype=dtype, device=device))
        self.assertEqual(strided[[1, 2]], torch.tensor([3, 5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([1, 2]), ], torch.tensor([3, 5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([[2, 1], [0, 3]]), ],
                         torch.tensor([[5, 3], [1, 7]], dtype=dtype, device=device))

        # stride is [4, 8]
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), storage_offset=4,
                     size=torch.Size([2]), stride=[4])
        self.assertEqual(strided[[0]], torch.tensor([5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0]), ], torch.tensor([5], dtype=dtype, device=device))
        self.assertEqual(strided[ri([1]), ], torch.tensor([9], dtype=dtype, device=device))
        self.assertEqual(strided[[0, 1]], torch.tensor([5, 9], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0, 1]), ], torch.tensor([5, 9], dtype=dtype, device=device))
        self.assertEqual(strided[ri([[0, 1], [1, 0]]), ],
                         torch.tensor([[5, 9], [9, 5]], dtype=dtype, device=device))

        # reference is 1 2
        #              3 4
        #              5 6
        reference = consec((3, 2))
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])], torch.tensor([1, 3, 5], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0, 1, 2]), ri([1])], torch.tensor([2, 4, 6], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0]), ri([0])], consec((1,)))
        self.assertEqual(reference[ri([2]), ri([1])], consec((1,), 6))
        self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]], torch.tensor([1, 2], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 1, 1, 0, 2]), ri([1])]],
                         torch.tensor([2, 4, 4, 2, 6], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.tensor([1, 2, 3, 3], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = [0],
        self.assertEqual(reference[rows, columns], torch.tensor([[1, 1],
                                                                 [3, 5]], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(reference[rows, columns], torch.tensor([[2, 1],
                                                                 [4, 5]], dtype=dtype, device=device))
        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([[0, 1],
                      [1, 0]])
        self.assertEqual(reference[rows, columns], torch.tensor([[1, 2],
                                                                 [4, 5]], dtype=dtype, device=device))

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(reference[ri([0]), ri([1])], torch.tensor([-1], dtype=dtype, device=device))
        reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device)
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.tensor([-1, 2, -4], dtype=dtype, device=device))
        reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device))

        # Verify still works with Transposed (i.e. non-contiguous) Tensors

        reference = torch.tensor([[0, 1, 2, 3],
                                  [4, 5, 6, 7],
                                  [8, 9, 10, 11]], dtype=dtype, device=device).t_()

        # Transposed: [[0, 4, 8],
        #              [1, 5, 9],
        #              [2, 6, 10],
        #              [3, 7, 11]]

        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.tensor([0, 1, 2], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0, 1, 2]), ri([1])],
                         torch.tensor([4, 5, 6], dtype=dtype, device=device))
        self.assertEqual(reference[ri([0]), ri([0])],
                         torch.tensor([0], dtype=dtype, device=device))
        self.assertEqual(reference[ri([2]), ri([1])],
                         torch.tensor([6], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 0]), ri([0, 1])]],
                         torch.tensor([0, 4], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 1, 1, 0, 3]), ri([1])]],
                         torch.tensor([4, 5, 5, 4, 7], dtype=dtype, device=device))
        self.assertEqual(reference[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.tensor([0, 4, 1, 1], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = [0],
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[0, 0], [1, 2]], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 2]])
        columns = ri([1, 0])
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[4, 0], [5, 2]], dtype=dtype, device=device))
        rows = ri([[0, 0],
                   [1, 3]])
        columns = ri([[0, 1],
                      [1, 2]])
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[0, 4], [5, 11]], dtype=dtype, device=device))

        # setting values
        reference[ri([0]), ri([1])] = -1
        self.assertEqual(reference[ri([0]), ri([1])],
                         torch.tensor([-1], dtype=dtype, device=device))
        reference[ri([0, 1, 2]), ri([0])] = torch.tensor([-1, 2, -4], dtype=dtype, device=device)
        self.assertEqual(reference[ri([0, 1, 2]), ri([0])],
                         torch.tensor([-1, 2, -4], dtype=dtype, device=device))
        reference[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)
        self.assertEqual(reference[rows, columns],
                         torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device))

        # stride != 1

        # strided is [[1 3 5 7],
        #             [9 11 13 15]]

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 1, size=torch.Size([2, 4]),
                     stride=[8, 2])

        self.assertEqual(strided[ri([0, 1]), ri([0])],
                         torch.tensor([1, 9], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0, 1]), ri([1])],
                         torch.tensor([3, 11], dtype=dtype, device=device))
        self.assertEqual(strided[ri([0]), ri([0])],
                         torch.tensor([1], dtype=dtype, device=device))
        self.assertEqual(strided[ri([1]), ri([3])],
                         torch.tensor([15], dtype=dtype, device=device))
        self.assertEqual(strided[[ri([0, 0]), ri([0, 3])]],
                         torch.tensor([1, 7], dtype=dtype, device=device))
        self.assertEqual(strided[[ri([1]), ri([0, 1, 1, 0, 3])]],
                         torch.tensor([9, 11, 11, 9, 15], dtype=dtype, device=device))
        self.assertEqual(strided[[ri([0, 0, 1, 1]), ri([0, 1, 0, 0])]],
                         torch.tensor([1, 3, 9, 9], dtype=dtype, device=device))

        rows = ri([[0, 0],
                   [1, 1]])
        columns = [0],
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[1, 1], [9, 9]], dtype=dtype, device=device))

        rows = ri([[0, 1],
                   [1, 0]])
        columns = ri([1, 2])
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[3, 13], [11, 5]], dtype=dtype, device=device))
        rows = ri([[0, 0],
                   [1, 1]])
        columns = ri([[0, 1],
                      [1, 2]])
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[1, 3], [11, 13]], dtype=dtype, device=device))

        # setting values

        # strided is [[10, 11],
        #             [17, 18]]

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])
        self.assertEqual(strided[ri([0]), ri([1])],
                         torch.tensor([11], dtype=dtype, device=device))
        strided[ri([0]), ri([1])] = -1
        self.assertEqual(strided[ri([0]), ri([1])],
                         torch.tensor([-1], dtype=dtype, device=device))

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])
        self.assertEqual(strided[ri([0, 1]), ri([1, 0])],
                         torch.tensor([11, 17], dtype=dtype, device=device))
        strided[ri([0, 1]), ri([1, 0])] = torch.tensor([-1, 2], dtype=dtype, device=device)
        self.assertEqual(strided[ri([0, 1]), ri([1, 0])],
                         torch.tensor([-1, 2], dtype=dtype, device=device))

        reference = torch.arange(0., 24, dtype=dtype, device=device).view(3, 8)
        strided = torch.tensor((), dtype=dtype, device=device)
        strided.set_(reference.storage(), 10, size=torch.Size([2, 2]),
                     stride=[7, 1])

        rows = ri([[0],
                   [1]])
        columns = ri([[0, 1],
                      [0, 1]])
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[10, 11], [17, 18]], dtype=dtype, device=device))
        strided[rows, columns] = torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device)
        self.assertEqual(strided[rows, columns],
                         torch.tensor([[4, 6], [2, 3]], dtype=dtype, device=device))

        # Tests using less than the number of dims, and ellipsis

        # reference is 1 2
        #              3 4
        #              5 6
        reference = consec((3, 2))
        self.assertEqual(reference[ri([0, 2]), ],
                         torch.tensor([[1, 2], [5, 6]], dtype=dtype, device=device))
        self.assertEqual(reference[ri([1]), ...],
                         torch.tensor([[3, 4]], dtype=dtype, device=device))
        self.assertEqual(reference[..., ri([1])],
                         torch.tensor([[2], [4], [6]], dtype=dtype, device=device))

        # verify too many indices fails
        with self.assertRaises(IndexError):
            reference[ri([1]), ri([0, 2]), ri([3])]

        # test invalid index fails
        reference = torch.empty(10, dtype=dtype, device=device)
        # can't test cuda because it is a device assert
        if not reference.is_cuda:
            for err_idx in (10, -11):
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[err_idx]
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[torch.LongTensor([err_idx]).to(device)]
                with self.assertRaisesRegex(IndexError, r'out of'):
                    reference[[err_idx]]

        def tensor_indices_to_np(tensor, indices):
            # convert the Torch Tensor to a numpy array
            tensor = tensor.to(device='cpu')
            npt = tensor.numpy()

            # convert indices
            idxs = tuple(i.tolist() if isinstance(i, torch.LongTensor) else
                         i for i in indices)

            return npt, idxs

        def get_numpy(tensor, indices):
            npt, idxs = tensor_indices_to_np(tensor, indices)

            # index and return as a Torch Tensor
            return torch.tensor(npt[idxs], dtype=dtype, device=device)

        def set_numpy(tensor, indices, value):
            if not isinstance(value, int):
                if self.device_type != 'cpu':
                    value = value.cpu()
                value = value.numpy()

            npt, idxs = tensor_indices_to_np(tensor, indices)
            npt[idxs] = value
            return npt

        def assert_get_eq(tensor, indexer):
            self.assertEqual(tensor[indexer], get_numpy(tensor, indexer))

        def assert_set_eq(tensor, indexer, val):
            pyt = tensor.clone()
            numt = tensor.clone()
            pyt[indexer] = val
            numt = torch.tensor(set_numpy(numt, indexer, val), dtype=dtype, device=device)
            self.assertEqual(pyt, numt)

        def assert_backward_eq(tensor, indexer):
            cpu = tensor.float().clone().detach().requires_grad_(True)
            outcpu = cpu[indexer]
            gOcpu = torch.rand_like(outcpu)
            outcpu.backward(gOcpu)
            dev = cpu.to(device).detach().requires_grad_(True)
            outdev = dev[indexer]
            outdev.backward(gOcpu.to(device))
            self.assertEqual(cpu.grad, dev.grad)

        def get_set_tensor(indexed, indexer):
            set_size = indexed[indexer].size()
            set_count = indexed[indexer].numel()
            set_tensor = torch.randperm(set_count).view(set_size).double().to(device)
            return set_tensor

        # Tensor is  0  1  2  3  4
        #            5  6  7  8  9
        #           10 11 12 13 14
        #           15 16 17 18 19
        reference = torch.arange(0., 20, dtype=dtype, device=device).view(4, 5)

        indices_to_test = [
            # grab the second, fourth columns
            [slice(None), [1, 3]],

            # first, third rows,
            [[0, 2], slice(None)],

            # weird shape
            [slice(None), [[0, 1],
                           [2, 3]]],
            # negatives
            [[-1], [0]],
            [[0, 2], [-1]],
            [slice(None), [-1]],
        ]

        # only test dupes on gets
        get_indices_to_test = indices_to_test + [[slice(None), [0, 1, 1, 2, 2]]]

        for indexer in get_indices_to_test:
            assert_get_eq(reference, indexer)
            if self.device_type != 'cpu':
                assert_backward_eq(reference, indexer)

        for indexer in indices_to_test:
            assert_set_eq(reference, indexer, 44)
            assert_set_eq(reference,
                          indexer,
                          get_set_tensor(reference, indexer))

        reference = torch.arange(0., 160, dtype=dtype, device=device).view(4, 8, 5)

        indices_to_test = [
            [slice(None), slice(None), [0, 3, 4]],
            [slice(None), [2, 4, 5, 7], slice(None)],
            [[2, 3], slice(None), slice(None)],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0], [1, 2, 4]],
            [slice(None), [0, 1, 3], [4]],
            [slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [[0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4], slice(None)],
            [[0, 1, 3], [4], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 3]], slice(None)],
            [[[0, 1], [2, 3]], [[0]], slice(None)],
            [[[2, 1]], [[0, 3], [4, 4]], slice(None)],
            [[[2]], [[0, 3], [4, 1]], slice(None)],
            # non-contiguous indexing subspace
            [[0, 2, 3], slice(None), [1, 3, 4]],

            # less dim, ellipsis
            [[0, 2], ],
            [[0, 2], slice(None)],
            [[0, 2], Ellipsis],
            [[0, 2], slice(None), Ellipsis],
            [[0, 2], Ellipsis, slice(None)],
            [[0, 2], [1, 3]],
            [[0, 2], [1, 3], Ellipsis],
            [Ellipsis, [1, 3], [2, 3]],
            [Ellipsis, [2, 3, 4]],
            [Ellipsis, slice(None), [2, 3, 4]],
            [slice(None), Ellipsis, [2, 3, 4]],

            # ellipsis counts for nothing
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, slice(None), [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), [0, 3, 4], Ellipsis],
            [Ellipsis, [[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], Ellipsis, slice(None)],
            [[[0, 1], [1, 0]], [[2, 1], [3, 5]], slice(None), Ellipsis],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 212)
            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
            if torch.cuda.is_available():
                assert_backward_eq(reference, indexer)

        reference = torch.arange(0., 1296, dtype=dtype, device=device).view(3, 9, 8, 6)

        indices_to_test = [
            [slice(None), slice(None), slice(None), [0, 3, 4]],
            [slice(None), slice(None), [2, 4, 5, 7], slice(None)],
            [slice(None), [2, 3], slice(None), slice(None)],
            [[1, 2], slice(None), slice(None), slice(None)],
            [slice(None), slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), slice(None), [0], [1, 2, 4]],
            [slice(None), slice(None), [0, 1, 3], [4]],
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3]]],
            [slice(None), slice(None), [[0, 1], [2, 3]], [[0]]],
            [slice(None), slice(None), [[5, 6]], [[0, 3], [4, 4]]],
            [slice(None), [0, 2, 3], [1, 3, 4], slice(None)],
            [slice(None), [0], [1, 2, 4], slice(None)],
            [slice(None), [0, 1, 3], [4], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3], [0, 1]], slice(None)],
            [slice(None), [[0, 1], [3, 4]], [[2, 3]], slice(None)],
            [slice(None), [[0, 1], [3, 2]], [[0]], slice(None)],
            [slice(None), [[2, 1]], [[0, 3], [6, 4]], slice(None)],
            [slice(None), [[2]], [[0, 3], [4, 2]], slice(None)],
            [[0, 1, 2], [1, 3, 4], slice(None), slice(None)],
            [[0], [1, 2, 4], slice(None), slice(None)],
            [[0, 1, 2], [4], slice(None), slice(None)],
            [[[0, 1], [0, 2]], [[2, 4], [1, 5]], slice(None), slice(None)],
            [[[0, 1], [1, 2]], [[2, 0]], slice(None), slice(None)],
            [[[2, 2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [[[2]], [[0, 3], [4, 5]], slice(None), slice(None)],
            [slice(None), [3, 4, 6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 4], [1, 3, 4], [4]],
            [slice(None), [0, 1, 3], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1, 3, 4]],
            [slice(None), [2, 3, 5], [3], [4]],
            [slice(None), [0], [4], [1, 3, 4]],
            [slice(None), [6], [0, 2, 3], [1]],
            [slice(None), [[0, 3], [3, 6]], [[0, 1], [1, 3]], [[5, 3], [1, 2]]],
            [[2, 2, 1], [0, 2, 3], [1, 3, 4], slice(None)],
            [[2, 0, 1], [1, 2, 3], [4], slice(None)],
            [[0, 1, 2], [4], [1, 3, 4], slice(None)],
            [[0], [0, 2, 3], [1, 3, 4], slice(None)],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0], [4], [1, 3, 4], slice(None)],
            [[1], [0, 2, 3], [1], slice(None)],
            [[[1, 2], [1, 2]], [[0, 1], [2, 3]], [[2, 3], [3, 5]], slice(None)],

            # less dim, ellipsis
            [Ellipsis, [0, 3, 4]],
            [Ellipsis, slice(None), [0, 3, 4]],
            [Ellipsis, slice(None), slice(None), [0, 3, 4]],
            [slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), slice(None), Ellipsis, [0, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4]],
            [slice(None), [0, 2, 3], [1, 3, 4], Ellipsis],
            [Ellipsis, [0, 2, 3], [1, 3, 4], slice(None)],
            [[0], [1, 2, 4]],
            [[0], [1, 2, 4], slice(None)],
            [[0], [1, 2, 4], Ellipsis],
            [[0], [1, 2, 4], Ellipsis, slice(None)],
            [[1], ],
            [[0, 2, 1], [3], [4]],
            [[0, 2, 1], [3], [4], slice(None)],
            [[0, 2, 1], [3], [4], Ellipsis],
            [Ellipsis, [0, 2, 1], [3], [4]],
        ]

        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)
            assert_set_eq(reference, indexer, get_set_tensor(reference, indexer))
        indices_to_test += [
            [slice(None), slice(None), [[0, 1], [1, 0]], [[2, 3], [3, 0]]],
            [slice(None), slice(None), [[2]], [[0, 3], [4, 4]]],
        ]
        for indexer in indices_to_test:
            assert_get_eq(reference, indexer)
            assert_set_eq(reference, indexer, 1333)
            if self.device_type != 'cpu':
                assert_backward_eq(reference, indexer)

    def test_advancedindex_big(self, device):
        reference = torch.arange(0, 123344, dtype=torch.int, device=device)

        self.assertEqual(reference[[0, 123, 44488, 68807, 123343], ],
                         torch.tensor([0, 123, 44488, 68807, 123343], dtype=torch.int))

    def test_single_int(self, device):
        v = torch.randn(5, 7, 3, device=device)
        self.assertEqual(v[4].shape, (7, 3))

3090  test/test_linalg.py
File diff suppressed because it is too large.
417  test/test_numpy_interop.py  (new file)
@@ -0,0 +1,417 @@
import torch
import numpy as np

from itertools import product

from torch.testing._internal.common_utils import \
    (TestCase, run_tests)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, onlyCPU, dtypes)

# For testing handling NumPy objects and sending tensors to / accepting
# arrays from NumPy.
class TestNumPyInterop(TestCase):
    # Note: the warning this tests for only appears once per program, so
    # other instances of this warning should be addressed to avoid
    # the tests depending on the order in which they're run.
    @onlyCPU
    def test_numpy_non_writeable(self, device):
        arr = np.zeros(5)
        arr.flags['WRITEABLE'] = False
        self.assertWarns(UserWarning, lambda: torch.from_numpy(arr))

    @onlyCPU
    def test_numpy_unresizable(self, device) -> None:
        x = np.zeros((2, 2))
        y = torch.from_numpy(x)
        with self.assertRaises(ValueError):
            x.resize((5, 5))

        z = torch.randn(5, 5)
        w = z.numpy()
        with self.assertRaises(RuntimeError):
            z.resize_(10, 10)
        with self.assertRaises(ValueError):
            w.resize((10, 10))

    @onlyCPU
    def test_to_numpy(self, device) -> None:
        def get_castable_tensor(shape, dtype):
            if dtype.is_floating_point:
                dtype_info = torch.finfo(dtype)
                # can't directly use min and max, because for double, max - min
                # is greater than double range and sampling always gives inf.
                low = max(dtype_info.min, -1e10)
                high = min(dtype_info.max, 1e10)
                t = torch.empty(shape, dtype=torch.float64).uniform_(low, high)
            else:
                # can't directly use min and max, because for int64_t, max - min
                # is greater than int64_t range and triggers UB.
                dtype_info = torch.iinfo(dtype)
                low = max(dtype_info.min, int(-1e10))
                high = min(dtype_info.max, int(1e10))
                dtype_info = torch.iinfo(dtype)
                t = torch.empty(shape, dtype=torch.int64).random_(low, high)
            return t.to(dtype)

        dtypes = [
            torch.uint8,
            torch.int8,
            torch.short,
            torch.int,
            torch.half,
            torch.float,
            torch.double,
            torch.long,
        ]

        for dtp in dtypes:
            # 1D
            sz = 10
            x = get_castable_tensor(sz, dtp)
            y = x.numpy()
            for i in range(sz):
                self.assertEqual(x[i], y[i])

            # 1D > 0 storage offset
            xm = get_castable_tensor(sz * 2, dtp)
            x = xm.narrow(0, sz - 1, sz)
            self.assertTrue(x.storage_offset() > 0)
            y = x.numpy()
            for i in range(sz):
                self.assertEqual(x[i], y[i])

            def check2d(x, y):
                for i in range(sz1):
                    for j in range(sz2):
                        self.assertEqual(x[i][j], y[i][j])

            # empty
            x = torch.Tensor().to(dtp)
            y = x.numpy()
            self.assertEqual(y.size, 0)

            # contiguous 2D
            sz1 = 3
            sz2 = 5
            x = get_castable_tensor((sz1, sz2), dtp)
            y = x.numpy()
            check2d(x, y)
            self.assertTrue(y.flags['C_CONTIGUOUS'])

            # with storage offset
            xm = get_castable_tensor((sz1 * 2, sz2), dtp)
            x = xm.narrow(0, sz1 - 1, sz1)
            y = x.numpy()
            self.assertTrue(x.storage_offset() > 0)
            check2d(x, y)
            self.assertTrue(y.flags['C_CONTIGUOUS'])

            # non-contiguous 2D
            x = get_castable_tensor((sz2, sz1), dtp).t()
            y = x.numpy()
            check2d(x, y)
            self.assertFalse(y.flags['C_CONTIGUOUS'])

            # with storage offset
            xm = get_castable_tensor((sz2 * 2, sz1), dtp)
            x = xm.narrow(0, sz2 - 1, sz2).t()
            y = x.numpy()
            self.assertTrue(x.storage_offset() > 0)
            check2d(x, y)

            # non-contiguous 2D with holes
            xm = get_castable_tensor((sz2 * 2, sz1 * 2), dtp)
            x = xm.narrow(0, sz2 - 1, sz2).narrow(1, sz1 - 1, sz1).t()
            y = x.numpy()
            self.assertTrue(x.storage_offset() > 0)
            check2d(x, y)

            if dtp != torch.half:
                # check writeable
                x = get_castable_tensor((3, 4), dtp)
                y = x.numpy()
                self.assertTrue(y.flags.writeable)
                y[0][1] = 3
                self.assertTrue(x[0][1] == 3)
                y = x.t().numpy()
                self.assertTrue(y.flags.writeable)
                y[0][1] = 3
                self.assertTrue(x[0][1] == 3)

    def test_to_numpy_bool(self, device) -> None:
        x = torch.tensor([True, False], dtype=torch.bool)
        self.assertEqual(x.dtype, torch.bool)

        y = x.numpy()
        self.assertEqual(y.dtype, np.bool)
        for i in range(len(x)):
            self.assertEqual(x[i], y[i])

        x = torch.tensor([True], dtype=torch.bool)
        self.assertEqual(x.dtype, torch.bool)

        y = x.numpy()
        self.assertEqual(y.dtype, np.bool)
        self.assertEqual(x[0], y[0])

    def test_from_numpy(self, device) -> None:
        dtypes = [
            np.double,
            np.float,
            np.float16,
            np.complex64,
            np.complex128,
            np.int64,
            np.int32,
            np.int16,
            np.int8,
            np.uint8,
            np.longlong,
            np.bool,
        ]
        complex_dtypes = [
            np.complex64,
            np.complex128,
        ]

        for dtype in dtypes:
            array = np.array([1, 2, 3, 4], dtype=dtype)
            tensor_from_array = torch.from_numpy(array)
            # TODO: change to tensor equality check once HalfTensor
            # implements `==`
            for i in range(len(array)):
                self.assertEqual(tensor_from_array[i], array[i])
            # ufunc 'remainder' not supported for complex dtypes
            if dtype not in complex_dtypes:
                # This is a special test case for Windows
                # https://github.com/pytorch/pytorch/issues/22615
                array2 = array % 2
                tensor_from_array2 = torch.from_numpy(array2)
                for i in range(len(array2)):
                    self.assertEqual(tensor_from_array2[i], array2[i])

        # Test unsupported type
        array = np.array([1, 2, 3, 4], dtype=np.uint16)
        with self.assertRaises(TypeError):
            tensor_from_array = torch.from_numpy(array)

        # check storage offset
        x = np.linspace(1, 125, 125)
        x.shape = (5, 5, 5)
        x = x[1]
        expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[1]
        self.assertEqual(torch.from_numpy(x), expected)

        # check noncontiguous
        x = np.linspace(1, 25, 25)
        x.shape = (5, 5)
        expected = torch.arange(1, 26, dtype=torch.float64).view(5, 5).t()
        self.assertEqual(torch.from_numpy(x.T), expected)

        # check noncontiguous with holes
        x = np.linspace(1, 125, 125)
        x.shape = (5, 5, 5)
        x = x[:, 1]
        expected = torch.arange(1, 126, dtype=torch.float64).view(5, 5, 5)[:, 1]
        self.assertEqual(torch.from_numpy(x), expected)

        # check zero dimensional
        x = np.zeros((0, 2))
        self.assertEqual(torch.from_numpy(x).shape, (0, 2))
        x = np.zeros((2, 0))
        self.assertEqual(torch.from_numpy(x).shape, (2, 0))

        # check ill-sized strides raise exception
        x = np.array([3., 5., 8.])
        x.strides = (3,)
        self.assertRaises(ValueError, lambda: torch.from_numpy(x))

    @onlyCPU
    def test_ctor_with_numpy_scalar_ctor(self, device) -> None:
        dtypes = [
            np.double,
            np.float,
            np.float16,
            np.int64,
            np.int32,
            np.int16,
            np.uint8,
            np.bool,
        ]
        for dtype in dtypes:
            self.assertEqual(dtype(42), torch.tensor(dtype(42)).item())

    @onlyCPU
    def test_numpy_index(self, device):
        i = np.int32([0, 1, 2])
        x = torch.randn(5, 5)
        for idx in i:
            self.assertFalse(isinstance(idx, int))
            self.assertEqual(x[idx], x[int(idx)])

    @onlyCPU
    def test_numpy_array_interface(self, device):
        types = [
            torch.DoubleTensor,
            torch.FloatTensor,
            torch.HalfTensor,
            torch.LongTensor,
            torch.IntTensor,
            torch.ShortTensor,
            torch.ByteTensor,
        ]
        dtypes = [
            np.float64,
            np.float32,
            np.float16,
            np.int64,
            np.int32,
            np.int16,
            np.uint8,
        ]
        for tp, dtype in zip(types, dtypes):
            if np.dtype(dtype).kind == 'u':
                x = torch.Tensor([1, 2, 3, 4]).type(tp)
                array = np.array([1, 2, 3, 4], dtype=dtype)
            else:
                x = torch.Tensor([1, -2, 3, -4]).type(tp)
                array = np.array([1, -2, 3, -4], dtype=dtype)

            # Test __array__ w/o dtype argument
            asarray = np.asarray(x)
            self.assertIsInstance(asarray, np.ndarray)
            self.assertEqual(asarray.dtype, dtype)
            for i in range(len(x)):
                self.assertEqual(asarray[i], x[i])

            # Test __array_wrap__, same dtype
            abs_x = np.abs(x)
            abs_array = np.abs(array)
            self.assertIsInstance(abs_x, tp)
            for i in range(len(x)):
                self.assertEqual(abs_x[i], abs_array[i])

        # Test __array__ with dtype argument
        for dtype in dtypes:
            x = torch.IntTensor([1, -2, 3, -4])
            asarray = np.asarray(x, dtype=dtype)
            self.assertEqual(asarray.dtype, dtype)
            if np.dtype(dtype).kind == 'u':
                wrapped_x = np.array([1, -2, 3, -4], dtype=dtype)
                for i in range(len(x)):
                    self.assertEqual(asarray[i], wrapped_x[i])
            else:
                for i in range(len(x)):
                    self.assertEqual(asarray[i], x[i])

        # Test some math functions with float types
        float_types = [torch.DoubleTensor, torch.FloatTensor]
        float_dtypes = [np.float64, np.float32]
        for tp, dtype in zip(float_types, float_dtypes):
            x = torch.Tensor([1, 2, 3, 4]).type(tp)
            array = np.array([1, 2, 3, 4], dtype=dtype)
            for func in ['sin', 'sqrt', 'ceil']:
                ufunc = getattr(np, func)
                res_x = ufunc(x)
                res_array = ufunc(array)
                self.assertIsInstance(res_x, tp)
                for i in range(len(x)):
                    self.assertEqual(res_x[i], res_array[i])

        # Test functions with boolean return value
        for tp, dtype in zip(types, dtypes):
            x = torch.Tensor([1, 2, 3, 4]).type(tp)
            array = np.array([1, 2, 3, 4], dtype=dtype)
            geq2_x = np.greater_equal(x, 2)
            geq2_array = np.greater_equal(array, 2).astype('uint8')
            self.assertIsInstance(geq2_x, torch.ByteTensor)
            for i in range(len(x)):
                self.assertEqual(geq2_x[i], geq2_array[i])

    @onlyCPU
    def test_multiplication_numpy_scalar(self, device) -> None:
        for np_dtype in [np.float32, np.float64, np.int32, np.int64, np.int16, np.uint8]:
            for t_dtype in [torch.float, torch.double]:
                np_sc = np_dtype(2.0)
                t = torch.ones(2, requires_grad=True, dtype=t_dtype)
                r1 = t * np_sc
                self.assertIsInstance(r1, torch.Tensor)
                self.assertTrue(r1.dtype == t_dtype)
                self.assertTrue(r1.requires_grad)
                r2 = np_sc * t
                self.assertIsInstance(r2, torch.Tensor)
                self.assertTrue(r2.dtype == t_dtype)
                self.assertTrue(r2.requires_grad)

    @onlyCPU
    def test_parse_numpy_int(self, device):
        self.assertRaisesRegex(RuntimeError, "Overflow",
                               lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)))
        # https://github.com/pytorch/pytorch/issues/29252
        for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
            scalar = 3
            np_arr = np.array([scalar], dtype=nptype)
            np_val = np_arr[0]

            # np integral type can be treated as a python int in native functions with
            # int parameters:
            self.assertEqual(torch.ones(5).diag(scalar), torch.ones(5).diag(np_val))
            self.assertEqual(torch.ones([2, 2, 2, 2]).mean(scalar), torch.ones([2, 2, 2, 2]).mean(np_val))

            # numpy integral type parses like a python int in custom python bindings:
            self.assertEqual(torch.Storage(np_val).size(), scalar)

            tensor = torch.tensor([2], dtype=torch.int)
            tensor[0] = np_val
            self.assertEqual(tensor[0], np_val)

            # Original reported issue, np integral type parses to the correct
            # PyTorch integral type when passed for a `Scalar` parameter in
            # arithmetic operations:
            t = torch.from_numpy(np_arr)
            self.assertEqual((t + np_val).dtype, t.dtype)
            self.assertEqual((np_val + t).dtype, t.dtype)

    def test_has_storage_numpy(self, device):
        for dtype in [np.float32, np.float64, np.int64,
                      np.int32, np.int16, np.uint8]:
            arr = np.array([1], dtype=dtype)
            self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.float32).storage())
            self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.double).storage())
            self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.int).storage())
            self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.long).storage())
            self.assertIsNotNone(torch.tensor(arr, device=device, dtype=torch.uint8).storage())

    @dtypes(*torch.testing.get_all_dtypes())
    def test_numpy_scalar_cmp(self, device, dtype):
        if dtype.is_complex:
            tensors = (torch.tensor(complex(1, 3), dtype=dtype, device=device),
                       torch.tensor([complex(1, 3), 0, 2j], dtype=dtype, device=device),
                       torch.tensor([[complex(3, 1), 0], [-1j, 5]], dtype=dtype, device=device))
        else:
            tensors = (torch.tensor(3, dtype=dtype, device=device),
                       torch.tensor([1, 0, -3], dtype=dtype, device=device),
                       torch.tensor([[3, 0, -1], [3, 5, 4]], dtype=dtype, device=device))

        for tensor in tensors:
            if dtype == torch.bfloat16:
                with self.assertRaises(TypeError):
                    np_array = tensor.cpu().numpy()
                continue

            np_array = tensor.cpu().numpy()
            for t, a in product((tensor.flatten()[0], tensor.flatten()[0].item()),
                                (np_array.flatten()[0], np_array.flatten()[0].item())):
                self.assertEqual(t, a)
                if dtype == torch.complex64 and torch.is_tensor(t) and type(a) == np.complex64:
                    # TODO: Imaginary part is dropped in this case. Need fix.
                    # https://github.com/pytorch/pytorch/issues/43579
                    self.assertFalse(t == a)
                else:
                    self.assertTrue(t == a)

instantiate_device_type_tests(TestNumPyInterop, globals())

if __name__ == '__main__':
    run_tests()
2263  test/test_reductions.py  (new file)
File diff suppressed because it is too large.

599  test/test_shape_ops.py  (new file)
@@ -0,0 +1,599 @@
import torch
import numpy as np

from itertools import product, combinations, permutations
from functools import partial
import random

from torch._six import nan
from torch.testing._internal.common_utils import (
    TestCase, run_tests, make_tensor, torch_to_numpy_dtype_dict)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, onlyCPU, dtypes, onlyOnCPUAndCUDA,
    dtypesIfCPU, dtypesIfCUDA)

# TODO: replace with make_tensor
def _generate_input(shape, dtype, device, with_extremal):
    if shape == ():
        x = torch.tensor((), dtype=dtype, device=device)
    else:
        if dtype.is_floating_point or dtype.is_complex:
            # work around torch.randn not being implemented for bfloat16
            if dtype == torch.bfloat16:
                x = torch.randn(*shape, device=device) * random.randint(30, 100)
                x = x.to(torch.bfloat16)
            else:
                x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100)
            x[torch.randn(*shape) > 0.5] = 0
            if with_extremal and dtype.is_floating_point:
                # Use extremal values
                x[torch.randn(*shape) > 0.5] = float('nan')
                x[torch.randn(*shape) > 0.5] = float('inf')
                x[torch.randn(*shape) > 0.5] = float('-inf')
            elif with_extremal and dtype.is_complex:
                x[torch.randn(*shape) > 0.5] = complex('nan')
                x[torch.randn(*shape) > 0.5] = complex('inf')
                x[torch.randn(*shape) > 0.5] = complex('-inf')
        elif dtype == torch.bool:
            x = torch.zeros(shape, dtype=dtype, device=device)
            x[torch.randn(*shape) > 0.5] = True
        else:
            x = torch.randint(15, 100, shape, dtype=dtype, device=device)

    return x

class TestShapeOps(TestCase):

    # TODO: update to work on CUDA, too
    @onlyCPU
    def test_unbind(self, device):
        x = torch.rand(2, 3, 4, 5)
        for dim in range(4):
            res = torch.unbind(x, dim)
            res2 = x.unbind(dim)
            self.assertEqual(x.size(dim), len(res))
            self.assertEqual(x.size(dim), len(res2))
            for i in range(dim):
                self.assertEqual(x.select(dim, i), res[i])
                self.assertEqual(x.select(dim, i), res2[i])

    # TODO: update to work on CUDA, too?
    @onlyCPU
    def test_tolist(self, device):
        list0D = []
        tensor0D = torch.Tensor(list0D)
        self.assertEqual(tensor0D.tolist(), list0D)

        table1D = [1, 2, 3]
        tensor1D = torch.Tensor(table1D)
        storage = torch.Storage(table1D)
        self.assertEqual(tensor1D.tolist(), table1D)
        self.assertEqual(storage.tolist(), table1D)
        self.assertEqual(tensor1D.tolist(), table1D)
        self.assertEqual(storage.tolist(), table1D)

        table2D = [[1, 2], [3, 4]]
        tensor2D = torch.Tensor(table2D)
        self.assertEqual(tensor2D.tolist(), table2D)

        tensor3D = torch.Tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
        tensorNonContig = tensor3D.select(1, 1)
        self.assertFalse(tensorNonContig.is_contiguous())
        self.assertEqual(tensorNonContig.tolist(), [[3, 4], [7, 8]])

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim_invalid(self, device, dtype):
        shape = self._rand_shape(4, min_size=5, max_size=10)
        x = _generate_input(shape, dtype, device, False)

        # Invalid `source` and `destination` dimension
        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            torch.movedim(x, 5, 0)

        with self.assertRaisesRegex(IndexError, "Dimension out of range"):
            torch.movedim(x, 0, 5)

        # Mismatch in size of `source` and `destination`
        with self.assertRaisesRegex(RuntimeError, "movedim: Invalid source or destination dims:"):
            torch.movedim(x, (1, 0), (0, ))

        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
            torch.movedim(x, (0, 0), (0, 1))

        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `source`"):
            torch.movedim(x, (0, 1, 0), (0, 1, 2))

        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
            torch.movedim(x, (0, 1), (1, 1))

        with self.assertRaisesRegex(RuntimeError, "movedim: repeated dim in `destination`"):
            torch.movedim(x, (0, 1, 2), (1, 0, 1))

    @dtypes(torch.int64, torch.float, torch.complex128)
    def test_movedim(self, device, dtype):
        for nd in range(5):
            shape = self._rand_shape(nd, min_size=5, max_size=10)
            x = _generate_input(shape, dtype, device, with_extremal=False)
            for random_negative in [True, False]:
                for src_dim, dst_dim in permutations(range(nd), r=2):
                    random_prob = random.random()

                    if random_negative and random_prob > 0.66:
                        src_dim = src_dim - nd
                    elif random_negative and random_prob > 0.33:
                        dst_dim = dst_dim - nd
                    elif random_negative:
                        src_dim = src_dim - nd
                        dst_dim = dst_dim - nd

                    # Integer `source` and `destination`
                    torch_fn = partial(torch.movedim, source=src_dim, destination=dst_dim)
                    np_fn = partial(np.moveaxis, source=src_dim, destination=dst_dim)
                    self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

                if nd == 0:
                    continue

                def make_index_negative(sequence, idx):
                    sequence = list(sequence)
                    sequence[random_idx] = sequence[random_idx] - nd
                    return tuple(src_sequence)

                for src_sequence in permutations(range(nd), r=random.randint(1, nd)):
                    # Sequence `source` and `destination`
                    dst_sequence = tuple(random.sample(range(nd), len(src_sequence)))

                    # Randomly change a dim to a negative dim representation of itself.
                    random_prob = random.random()
                    if random_negative and random_prob > 0.66:
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        src_sequence = make_index_negative(src_sequence, random_idx)
                    elif random_negative and random_prob > 0.33:
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        dst_sequence = make_index_negative(dst_sequence, random_idx)
                    elif random_negative:
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        dst_sequence = make_index_negative(dst_sequence, random_idx)
                        random_idx = random.randint(0, len(src_sequence) - 1)
                        src_sequence = make_index_negative(src_sequence, random_idx)

                    torch_fn = partial(torch.movedim, source=src_sequence, destination=dst_sequence)
                    np_fn = partial(np.moveaxis, source=src_sequence, destination=dst_sequence)
                    self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

        # Move dim to same position
        x = torch.randn(2, 3, 5, 7, 11)
        torch_fn = partial(torch.movedim, source=(0, 1), destination=(0, 1))
        np_fn = partial(np.moveaxis, source=(0, 1), destination=(0, 1))
        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

        torch_fn = partial(torch.movedim, source=1, destination=1)
        np_fn = partial(np.moveaxis, source=1, destination=1)
        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

        # Empty Sequence
        torch_fn = partial(torch.movedim, source=(), destination=())
        np_fn = partial(np.moveaxis, source=(), destination=())
        self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None)

@dtypes(torch.float, torch.bool)
|
||||
def test_diag(self, device, dtype):
|
||||
if dtype is torch.bool:
|
||||
x = torch.rand(100, 100, device=device) >= 0.5
|
||||
else:
|
||||
x = torch.rand(100, 100, dtype=dtype, device=device)
|
||||
|
||||
res1 = torch.diag(x)
|
||||
res2 = torch.tensor((), dtype=dtype, device=device)
|
||||
torch.diag(x, out=res2)
|
||||
self.assertEqual(res1, res2)
|
||||
|
||||
def test_diagonal(self, device):
|
||||
x = torch.randn((100, 100), device=device)
|
||||
result = torch.diagonal(x)
|
||||
expected = torch.diag(x)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
x = torch.randn((100, 100), device=device)
|
||||
result = torch.diagonal(x, 17)
|
||||
expected = torch.diag(x, 17)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.float)
|
||||
def test_diagonal_multidim(self, device, dtype):
|
||||
x = torch.randn(10, 11, 12, 13, dtype=dtype, device=device)
|
||||
xn = x.numpy()
|
||||
for args in [(2, 2, 3),
|
||||
(2,),
|
||||
(-2, 1, 2),
|
||||
(0, -2, -1)]:
|
||||
result = torch.diagonal(x, *args)
|
||||
expected = xn.diagonal(*args)
|
||||
self.assertEqual(expected.shape, result.shape)
|
||||
self.assertEqual(expected, result)
|
||||
# test non-continguous
|
||||
xp = x.permute(1, 2, 3, 0)
|
||||
result = torch.diagonal(xp, 0, -2, -1)
|
||||
expected = xp.numpy().diagonal(0, -2, -1)
|
||||
self.assertEqual(expected.shape, result.shape)
|
||||
self.assertEqual(expected, result)
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypesIfCPU(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_half=False,
|
||||
include_bfloat16=False))
|
||||
@dtypesIfCUDA(*torch.testing.get_all_dtypes(include_complex=False, include_bool=False, include_bfloat16=False))
|
||||
def test_trace(self, device, dtype):
|
||||
def test(shape):
|
||||
tensor = make_tensor(shape, device, dtype, low=-9, high=9)
|
||||
expected_dtype = tensor.sum().dtype
|
||||
expected_dtype = torch_to_numpy_dtype_dict[expected_dtype]
|
||||
|
||||
result = np.trace(tensor.cpu().numpy(), dtype=expected_dtype)
|
||||
expected = torch.tensor(result, device=device)
|
||||
self.assertEqual(tensor.trace(), expected)
|
||||
|
||||
shapes = (
|
||||
[10, 1],
|
||||
[1, 10],
|
||||
[100, 100],
|
||||
[20, 100],
|
||||
[100, 20],
|
||||
)
|
||||
for shape in shapes:
|
||||
test(shape)
|
||||
|
||||
def generate_clamp_baseline(self, device, dtype, *, min_vals, max_vals, with_nans):
|
||||
"""
|
||||
Creates a random tensor for a given device and dtype, and computes the expected clamped
|
||||
values given the min_vals and/or max_vals.
|
||||
If with_nans is provided, then some values are randomly set to nan.
|
||||
"""
|
||||
X = torch.rand(100, device=device).mul(50).add(-25) # uniform in [-25, 25]
|
||||
X = X.to(dtype)
|
||||
if with_nans:
|
||||
mask = torch.randint(0, 2, X.shape, dtype=torch.bool, device=device)
|
||||
X[mask] = nan
|
||||
|
||||
if isinstance(min_vals, torch.Tensor):
|
||||
min_vals = min_vals.cpu().numpy()
|
||||
|
||||
if isinstance(max_vals, torch.Tensor):
|
||||
max_vals = max_vals.cpu().numpy()
|
||||
|
||||
# Use NumPy implementation as reference
|
||||
X_clamped = torch.tensor(np.clip(X.cpu().numpy(), a_min=min_vals, a_max=max_vals), device=device)
|
||||
return X, X_clamped
|
||||
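# Illustrative sketch (not part of the original test) of the NumPy baseline
# used above, with a hypothetical input:
#   x = torch.tensor([-3., 0., 3.])
#   torch.clamp(x, min=-1, max=1)           -> tensor([-1., 0., 1.])
#   np.clip(x.numpy(), a_min=-1, a_max=1)   -> array([-1., 0., 1.])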
|
||||
# Tests clamp and its alias, clip
|
||||
@dtypes(torch.int64, torch.float32)
|
||||
def test_clamp(self, device, dtype):
|
||||
op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
|
||||
torch.clip, torch.Tensor.clip, torch.Tensor.clip_)
|
||||
|
||||
# min/max argument product
|
||||
args = product((-10, None), (10, None))
|
||||
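# product((-10, None), (10, None)) yields (-10, 10), (-10, None), (None, 10)
# and (None, None); the last pair is skipped below because clamp requires at
# least one bound.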
|
||||
for op in op_list:
|
||||
for min_val, max_val in args:
|
||||
if min_val is None and max_val is None:
|
||||
continue
|
||||
|
||||
X, Y_expected = self.generate_clamp_baseline(device, dtype,
|
||||
min_vals=min_val,
|
||||
max_vals=max_val,
|
||||
with_nans=False)
|
||||
|
||||
# Test op
|
||||
X1 = X.clone() # So that the in-place ops do not change X
|
||||
Y_actual = op(X1, min_val, max_val)
|
||||
self.assertEqual(Y_expected, Y_actual)
|
||||
|
||||
# Test op-out behavior (out does not exist for method versions)
|
||||
if op in (torch.clamp, torch.clip):
|
||||
Y_out = torch.empty_like(X)
|
||||
op(X, min=min_val, max=max_val, out=Y_out)
|
||||
self.assertEqual(Y_expected, Y_out)
|
||||
|
||||
def test_clamp_propagates_nans(self, device):
|
||||
op_list = (torch.clamp, torch.Tensor.clamp, torch.Tensor.clamp_,
|
||||
torch.clip, torch.Tensor.clip, torch.Tensor.clip_)
|
||||
|
||||
# min/max argument product
|
||||
args = product((-10, None), (10, None))
|
||||
|
||||
for op in op_list:
|
||||
for min_val, max_val in args:
|
||||
if min_val is None and max_val is None:
|
||||
continue
|
||||
|
||||
X, Y_expected = self.generate_clamp_baseline(device, torch.float,
|
||||
min_vals=min_val,
|
||||
max_vals=max_val,
|
||||
with_nans=True)
|
||||
Y_expected = torch.isnan(Y_expected)
|
||||
|
||||
# Test op
|
||||
X1 = X.clone() # So that the in-place ops do not change X
|
||||
Y_actual = op(X1, min_val, max_val)
|
||||
self.assertEqual(Y_expected, torch.isnan(Y_actual))
|
||||
|
||||
# Test op-out behavior (out does not exist for method versions)
|
||||
if op in (torch.clamp, torch.clip):
|
||||
Y_out = torch.empty_like(X)
|
||||
op(X, min_val, max_val, out=Y_out)
|
||||
self.assertEqual(Y_expected, torch.isnan(Y_out))
|
||||
|
||||
def test_clamp_raises_arg_errors(self, device):
|
||||
X = torch.randn(100, dtype=torch.float, device=device)
|
||||
error_msg = 'At least one of \'min\' or \'max\' must not be None'
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
X.clamp()
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
X.clamp_()
|
||||
with self.assertRaisesRegex(RuntimeError, error_msg):
|
||||
torch.clamp(X)
|
||||
|
||||
def test_flip(self, device):
|
||||
data = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8], device=device).view(2, 2, 2)
|
||||
|
||||
self.assertEqual(torch.tensor([5, 6, 7, 8, 1, 2, 3, 4]).view(2, 2, 2), data.flip(0))
|
||||
self.assertEqual(torch.tensor([3, 4, 1, 2, 7, 8, 5, 6]).view(2, 2, 2), data.flip(1))
|
||||
self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(2))
|
||||
self.assertEqual(torch.tensor([7, 8, 5, 6, 3, 4, 1, 2]).view(2, 2, 2), data.flip(0, 1))
|
||||
self.assertEqual(torch.tensor([8, 7, 6, 5, 4, 3, 2, 1]).view(2, 2, 2), data.flip(0, 1, 2))
|
||||
|
||||
# check for wrap dim
|
||||
self.assertEqual(torch.tensor([2, 1, 4, 3, 6, 5, 8, 7]).view(2, 2, 2), data.flip(-1))
|
||||
# check for permute
|
||||
self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(0, 2))
|
||||
self.assertEqual(torch.tensor([6, 5, 8, 7, 2, 1, 4, 3]).view(2, 2, 2), data.flip(2, 0))
|
||||
|
||||
# flipping the same dim more than once is not allowed
|
||||
self.assertRaises(RuntimeError, lambda: data.flip(0, 1, 1))
|
||||
# calling flip without any dims is not allowed
|
||||
self.assertRaises(TypeError, lambda: data.flip())
|
||||
|
||||
# the number of flip dims cannot exceed the number of tensor dims
|
||||
self.assertRaises(IndexError, lambda: data.flip(0, 1, 2, 3))
|
||||
# flip dims must be within the valid dim range
|
||||
self.assertRaises(IndexError, lambda: data.flip(3))
|
||||
|
||||
# test for non-contiguous case
|
||||
expanded_data = torch.arange(1, 4, device=device).view(3, 1).expand(3, 2)
|
||||
transposed_data = torch.arange(1, 9, device=device).view(2, 2, 2).transpose(0, 1)
|
||||
self.assertEqual(torch.tensor([3, 3, 2, 2, 1, 1]).view(3, 2), expanded_data.flip(0))
|
||||
self.assertEqual(torch.tensor([8, 7, 4, 3, 6, 5, 2, 1]).view(2, 2, 2), transposed_data.flip(0, 1, 2))
|
||||
|
||||
# test for shape
|
||||
data = torch.randn(2, 3, 4, device=device)
|
||||
size = [2, 3, 4]
|
||||
test_dims = []
|
||||
for i in range(1, 3):
|
||||
test_dims += combinations(range(len(size)), i)
|
||||
|
||||
for ds in test_dims:
|
||||
self.assertEqual(size, list(data.flip(ds).size()))
|
||||
|
||||
# test rectangular case
|
||||
data = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3).to(device)
|
||||
flip0_result = torch.tensor([[4, 5, 6], [1, 2, 3]]).to(device)
|
||||
flip1_result = torch.tensor([[3, 2, 1], [6, 5, 4]]).to(device)
|
||||
|
||||
self.assertEqual(flip0_result, data.flip(0))
|
||||
self.assertEqual(flip1_result, data.flip(1))
|
||||
|
||||
# test empty tensor, should just return an empty tensor of the same shape
|
||||
data = torch.tensor([])
|
||||
self.assertEqual(data, data.flip(0))
|
||||
|
||||
# test bool tensor
|
||||
a = torch.tensor([False, True])
|
||||
self.assertEqual(a.flip(0), torch.tensor([True, False]))
|
||||
|
||||
def _rand_shape(self, dim, min_size, max_size):
|
||||
shape = []
|
||||
for i in range(dim):
|
||||
shape.append(random.randint(min_size, max_size))
|
||||
return tuple(shape)
|
||||
|
||||
@dtypes(torch.cfloat, torch.cdouble)
|
||||
def test_complex_flip(self, device, dtype):
|
||||
rand_dim = random.randint(3, 4)
|
||||
shape = self._rand_shape(rand_dim, 5, 10)
|
||||
|
||||
# Axis to sample for given shape.
|
||||
for i in range(1, rand_dim):
|
||||
# Check all combinations of `i` axis.
|
||||
for flip_dim in combinations(range(rand_dim), i):
|
||||
data = torch.randn(*shape, device=device, dtype=dtype)
|
||||
torch_fn = partial(torch.flip, dims=flip_dim)
|
||||
np_fn = partial(np.flip, axis=flip_dim)
|
||||
self.compare_with_numpy(torch_fn, np_fn, data)
|
||||
|
||||
def _test_fliplr_flipud(self, torch_fn, np_fn, min_dim, max_dim, device, dtype):
|
||||
for dim in range(min_dim, max_dim + 1):
|
||||
shape = self._rand_shape(dim, 5, 10)
|
||||
# Randomly scale the input
|
||||
if dtype.is_floating_point or dtype.is_complex:
|
||||
data = torch.randn(*shape, device=device, dtype=dtype)
|
||||
else:
|
||||
data = torch.randint(0, 10, shape, device=device, dtype=dtype)
|
||||
self.compare_with_numpy(torch_fn, np_fn, data)
|
||||
|
||||
@dtypes(torch.int64, torch.double, torch.cdouble)
|
||||
def test_fliplr(self, device, dtype):
|
||||
self._test_fliplr_flipud(torch.fliplr, np.fliplr, 2, 4, device, dtype)
|
||||
|
||||
@dtypes(torch.int64, torch.double, torch.cdouble)
|
||||
def test_fliplr_invalid(self, device, dtype):
|
||||
x = torch.randn(42).to(dtype)
|
||||
with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
|
||||
torch.fliplr(x)
|
||||
with self.assertRaisesRegex(RuntimeError, "Input must be >= 2-d."):
|
||||
torch.fliplr(torch.tensor(42, device=device, dtype=dtype))
|
||||
|
||||
@dtypes(torch.int64, torch.double, torch.cdouble)
|
||||
def test_flipud(self, device, dtype):
|
||||
self._test_fliplr_flipud(torch.flipud, np.flipud, 1, 4, device, dtype)
|
||||
|
||||
@dtypes(torch.int64, torch.double, torch.cdouble)
|
||||
def test_flipud_invalid(self, device, dtype):
|
||||
with self.assertRaisesRegex(RuntimeError, "Input must be >= 1-d."):
|
||||
torch.flipud(torch.tensor(42, device=device, dtype=dtype))
|
||||
|
||||
def test_rot90(self, device):
|
||||
data = torch.arange(1, 5, device=device).view(2, 2)
|
||||
self.assertEqual(torch.tensor([1, 2, 3, 4]).view(2, 2), data.rot90(0, [0, 1]))
|
||||
self.assertEqual(torch.tensor([2, 4, 1, 3]).view(2, 2), data.rot90(1, [0, 1]))
|
||||
self.assertEqual(torch.tensor([4, 3, 2, 1]).view(2, 2), data.rot90(2, [0, 1]))
|
||||
self.assertEqual(torch.tensor([3, 1, 4, 2]).view(2, 2), data.rot90(3, [0, 1]))
|
||||
|
||||
# test for default args k=1, dims=[0, 1]
|
||||
self.assertEqual(data.rot90(), data.rot90(1, [0, 1]))
|
||||
|
||||
# test for reversed order of dims
|
||||
self.assertEqual(data.rot90(3, [0, 1]), data.rot90(1, [1, 0]))
|
||||
|
||||
# test for modulo of k
|
||||
self.assertEqual(data.rot90(5, [0, 1]), data.rot90(1, [0, 1]))
|
||||
self.assertEqual(data.rot90(3, [0, 1]), data.rot90(-1, [0, 1]))
|
||||
self.assertEqual(data.rot90(-5, [0, 1]), data.rot90(-1, [0, 1]))
|
||||
|
||||
# test for dims out-of-range error
|
||||
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, -3]))
|
||||
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 2]))
|
||||
|
||||
# test tensor with more than 2D
|
||||
data = torch.arange(1, 9, device=device).view(2, 2, 2)
|
||||
self.assertEqual(torch.tensor([2, 4, 1, 3, 6, 8, 5, 7]).view(2, 2, 2), data.rot90(1, [1, 2]))
|
||||
self.assertEqual(data.rot90(1, [1, -1]), data.rot90(1, [1, 2]))
|
||||
|
||||
# test for errors
|
||||
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 3]))
|
||||
self.assertRaises(RuntimeError, lambda: data.rot90(1, [1, 1]))
|
||||
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0, 1, 2]))
|
||||
self.assertRaises(RuntimeError, lambda: data.rot90(1, [0]))
|
||||
|
||||
@dtypes(torch.cfloat, torch.cdouble)
|
||||
def test_complex_rot90(self, device, dtype):
|
||||
shape = self._rand_shape(random.randint(2, 4), 5, 10)
|
||||
for rot_times in range(4):
|
||||
data = torch.randn(*shape, device=device, dtype=dtype)
|
||||
torch_fn = partial(torch.rot90, k=rot_times, dims=[0, 1])
|
||||
np_fn = partial(np.rot90, k=rot_times, axes=[0, 1])
|
||||
self.compare_with_numpy(torch_fn, np_fn, data)
|
||||
|
||||
@dtypes(*torch.testing.get_all_dtypes(include_complex=False))
|
||||
def test_nonzero(self, device, dtype):
|
||||
|
||||
shapes = [
|
||||
torch.Size((12,)),
|
||||
torch.Size((12, 1)),
|
||||
torch.Size((1, 12)),
|
||||
torch.Size((6, 2)),
|
||||
torch.Size((3, 2, 2)),
|
||||
torch.Size((5, 5, 5)),
|
||||
]
|
||||
|
||||
def gen_nontrivial_input(shape, dtype, device):
|
||||
if dtype != torch.bfloat16:
|
||||
return torch.randint(2, shape, device=device, dtype=dtype)
|
||||
else:
|
||||
# randint does not support bfloat16 on Windows, so sample as float and cast
|
||||
return torch.randint(2, shape, device=device, dtype=torch.float).to(dtype)
|
||||
|
||||
for shape in shapes:
|
||||
tensor = gen_nontrivial_input(shape, dtype, device)
|
||||
dst1 = torch.nonzero(tensor, as_tuple=False)
|
||||
dst2 = tensor.nonzero(as_tuple=False)
|
||||
dst3 = torch.empty([], dtype=torch.long, device=device)
|
||||
torch.nonzero(tensor, out=dst3)
|
||||
if self.device_type != 'xla':
|
||||
# xla does not raise runtime error
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"scalar type Long",
|
||||
lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.float))
|
||||
)
|
||||
if self.device_type == 'cuda':
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"on the same device",
|
||||
lambda: torch.nonzero(tensor, out=torch.empty([], dtype=torch.long))
|
||||
)
|
||||
np_array = tensor.cpu().numpy() if dtype != torch.bfloat16 else tensor.float().cpu().numpy()
|
||||
np_result = torch.from_numpy(np.stack(np_array.nonzero())).t()
|
||||
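# np.nonzero returns a tuple of per-dimension index arrays; stacking and
# transposing it gives the same (num_nonzero, ndim) layout that
# torch.nonzero(as_tuple=False) produces, so the results compare directly.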
self.assertEqual(dst1.cpu(), np_result, atol=0, rtol=0)
|
||||
self.assertEqual(dst2.cpu(), np_result, atol=0, rtol=0)
|
||||
self.assertEqual(dst3.cpu(), np_result, atol=0, rtol=0)
|
||||
tup1 = torch.nonzero(tensor, as_tuple=True)
|
||||
tup2 = tensor.nonzero(as_tuple=True)
|
||||
tup1 = torch.stack(tup1).t().cpu()
|
||||
tup2 = torch.stack(tup2).t().cpu()
|
||||
self.assertEqual(tup1, np_result, atol=0, rtol=0)
|
||||
self.assertEqual(tup2, np_result, atol=0, rtol=0)
|
||||
|
||||
def test_nonzero_astuple_out(self, device):
|
||||
t = torch.randn((3, 3, 3), device=device)
|
||||
out = torch.empty_like(t, dtype=torch.long)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.nonzero(t, as_tuple=True, out=out)
|
||||
|
||||
self.assertEqual(torch.nonzero(t, as_tuple=False, out=out), torch.nonzero(t, out=out))
|
||||
|
||||
# Verifies that JIT script cannot handle the as_tuple kwarg
|
||||
# See Issue https://github.com/pytorch/pytorch/issues/45499.
|
||||
def _foo(t):
|
||||
tuple_result = torch.nonzero(t, as_tuple=True)
|
||||
nontuple_result = torch.nonzero(t, as_tuple=False)
|
||||
out = torch.empty_like(nontuple_result)
|
||||
torch.nonzero(t, as_tuple=False, out=out)
|
||||
return tuple_result, nontuple_result, out
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
scripted_foo = torch.jit.script(_foo)
|
||||
|
||||
# Verifies that JIT tracing works fine
|
||||
traced_foo = torch.jit.trace(_foo, t)
|
||||
traced_tuple, traced_nontuple, traced_out = traced_foo(t)
|
||||
expected_tuple = torch.nonzero(t, as_tuple=True)
|
||||
expected_nontuple = torch.nonzero(t)
|
||||
|
||||
self.assertEqual(traced_tuple, expected_tuple)
|
||||
self.assertEqual(traced_nontuple, expected_nontuple)
|
||||
self.assertEqual(traced_out, expected_nontuple)
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
def test_nonzero_discontiguous(self, device):
|
||||
shape = (4, 4)
|
||||
tensor = torch.randint(2, shape, device=device)
|
||||
tensor_nc = torch.empty(shape[0], shape[1] * 2, device=device)[:, ::2].copy_(tensor)
|
||||
dst1 = tensor.nonzero(as_tuple=False)
|
||||
dst2 = tensor_nc.nonzero(as_tuple=False)
|
||||
self.assertEqual(dst1, dst2, atol=0, rtol=0)
|
||||
dst3 = torch.empty_like(dst1)
|
||||
data_ptr = dst3.data_ptr()
|
||||
# expect dst3 storage to be reused
|
||||
torch.nonzero(tensor, out=dst3)
|
||||
self.assertEqual(data_ptr, dst3.data_ptr())
|
||||
self.assertEqual(dst1, dst3, atol=0, rtol=0)
|
||||
# discontiguous out
|
||||
dst4 = torch.empty(dst1.size(0), dst1.size(1) * 2, dtype=torch.long, device=device)[:, ::2]
|
||||
data_ptr = dst4.data_ptr()
|
||||
strides = dst4.stride()
|
||||
torch.nonzero(tensor, out=dst4)
|
||||
self.assertEqual(data_ptr, dst4.data_ptr())
|
||||
self.assertEqual(dst1, dst4, atol=0, rtol=0)
|
||||
self.assertEqual(strides, dst4.stride())
|
||||
|
||||
def test_nonzero_non_diff(self, device):
|
||||
x = torch.randn(10, requires_grad=True)
|
||||
nz = x.nonzero()
|
||||
self.assertFalse(nz.requires_grad)
|
||||
|
||||
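# instantiate_device_type_tests generates per-device variants of the generic
# suite above (e.g. TestShapeOpsCPU, TestShapeOpsCUDA) so each test runs on
# every available device type.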
instantiate_device_type_tests(TestShapeOps, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
640
test/test_sort_and_select.py
Normal file
@ -0,0 +1,640 @@
|
||||
import torch
|
||||
|
||||
import random
|
||||
from torch._six import nan
|
||||
from itertools import product
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
|
||||
skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA)
|
||||
|
||||
# TODO: remove this
|
||||
SIZE = 100
|
||||
|
||||
class TestSortAndSelect(TestCase):
|
||||
|
||||
def assertIsOrdered(self, order, x, mxx, ixx, task):
|
||||
SIZE = 4
|
||||
if order == 'descending':
|
||||
def check_order(a, b):
|
||||
# `a != a` because we put NaNs
|
||||
# at the end of ascending sorted lists,
|
||||
# and the beginning of descending ones.
|
||||
return a != a or a >= b
|
||||
elif order == 'ascending':
|
||||
def check_order(a, b):
|
||||
# see above
|
||||
return b != b or a <= b
|
||||
else:
|
||||
raise ValueError('unknown order "{}", must be "ascending" or "descending"'.format(order))
|
||||
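# Example of the NaN convention described above: sorting
# torch.tensor([1., float('nan'), 0.]) yields values [0., 1., nan] in
# ascending order and [nan, 1., 0.] with descending=True.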
|
||||
are_ordered = True
|
||||
for j, k in product(range(SIZE), range(1, SIZE)):
|
||||
self.assertTrue(check_order(mxx[j][k - 1], mxx[j][k]),
|
||||
'torch.sort ({}) values unordered for {}'.format(order, task))
|
||||
|
||||
seen = set()
|
||||
indicesCorrect = True
|
||||
size = x.size(x.dim() - 1)
|
||||
for k in range(size):
|
||||
seen.clear()
|
||||
for j in range(size):
|
||||
self.assertEqual(x[k][ixx[k][j]], mxx[k][j],
|
||||
msg='torch.sort ({}) indices wrong for {}'.format(order, task))
|
||||
seen.add(ixx[k][j])
|
||||
self.assertEqual(len(seen), size)
|
||||
|
||||
def test_sort(self, device):
|
||||
SIZE = 4
|
||||
x = torch.rand(SIZE, SIZE, device=device)
|
||||
res1val, res1ind = torch.sort(x)
|
||||
|
||||
# Test use of result tensor
|
||||
res2val = torch.tensor((), device=device)
|
||||
res2ind = torch.tensor((), device=device, dtype=torch.long)
|
||||
torch.sort(x, out=(res2val, res2ind))
|
||||
self.assertEqual(res1val, res2val, atol=0, rtol=0)
|
||||
self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
|
||||
self.assertEqual(torch.argsort(x), res1ind)
|
||||
self.assertEqual(x.argsort(), res1ind)
|
||||
|
||||
# Test sorting of random numbers
|
||||
self.assertIsOrdered('ascending', x, res2val, res2ind, 'random')
|
||||
|
||||
# Test simple sort
|
||||
self.assertEqual(
|
||||
torch.sort(torch.tensor((50, 40, 30, 20, 10), device=device))[0],
|
||||
torch.tensor((10, 20, 30, 40, 50), device=device),
|
||||
atol=0, rtol=0
|
||||
)
|
||||
|
||||
# Test that we still have proper sorting with duplicate keys
|
||||
x = torch.floor(torch.rand(SIZE, SIZE, device=device) * 10)
|
||||
torch.sort(x, out=(res2val, res2ind))
|
||||
self.assertIsOrdered('ascending', x, res2val, res2ind, 'random with duplicate keys')
|
||||
|
||||
# DESCENDING SORT
|
||||
x = torch.rand(SIZE, SIZE, device=device)
|
||||
res1val, res1ind = torch.sort(x, x.dim() - 1, True)
|
||||
|
||||
# Test use of result tensor
|
||||
res2val = torch.tensor((), device=device)
|
||||
res2ind = torch.tensor((), device=device, dtype=torch.long)
|
||||
torch.sort(x, x.dim() - 1, True, out=(res2val, res2ind))
|
||||
self.assertEqual(res1val, res2val, atol=0, rtol=0)
|
||||
self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
|
||||
self.assertEqual(torch.argsort(x, x.dim() - 1, True), res1ind)
|
||||
self.assertEqual(x.argsort(x.dim() - 1, True), res1ind)
|
||||
|
||||
# Test sorting of random numbers
|
||||
self.assertIsOrdered('descending', x, res2val, res2ind, 'random')
|
||||
|
||||
# Test simple sort task
|
||||
self.assertEqual(
|
||||
torch.sort(torch.tensor((10, 20, 30, 40, 50), device=device), 0, True)[0],
|
||||
torch.tensor((50, 40, 30, 20, 10), device=device),
|
||||
atol=0, rtol=0
|
||||
)
|
||||
|
||||
# Test that we still have proper sorting with duplicate keys
|
||||
self.assertIsOrdered('descending', x, res2val, res2ind, 'random with duplicate keys')
|
||||
|
||||
# Test sorting with NaNs
|
||||
x = torch.rand(SIZE, SIZE, device=device)
|
||||
x[1][2] = float('NaN')
|
||||
x[3][0] = float('NaN')
|
||||
torch.sort(x, out=(res2val, res2ind))
|
||||
self.assertIsOrdered('ascending', x, res2val, res2ind,
|
||||
'random with NaNs')
|
||||
torch.sort(x, out=(res2val, res2ind), descending=True)
|
||||
self.assertIsOrdered('descending', x, res2val, res2ind,
|
||||
'random with NaNs')
|
||||
|
||||
def test_topk(self, device):
|
||||
def topKViaSort(t, k, dim, dir):
|
||||
sorted, indices = t.sort(dim, dir)
|
||||
return sorted.narrow(dim, 0, k), indices.narrow(dim, 0, k)
|
||||
|
||||
def compareTensors(t, res1, ind1, res2, ind2, dim):
|
||||
# Values should be exactly equivalent
|
||||
self.assertEqual(res1, res2, atol=0, rtol=0)
|
||||
|
||||
# Indices might differ based on the implementation, since there is
|
||||
# no guarantee of the relative order of selection
|
||||
if not ind1.eq(ind2).all():
|
||||
# To verify that the indices represent equivalent elements,
|
||||
# gather from the input using the topk indices and compare against
|
||||
# the sort indices
|
||||
vals = t.gather(dim, ind2)
|
||||
self.assertEqual(res1, vals, atol=0, rtol=0)
|
||||
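# Sketch of the idea: if topk returned indices [2, 0] for a row [5., 1., 5.],
# gather recovers [5., 5.], which still matches the sorted top-k values even
# though the tie between the two 5s could be broken either way.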
|
||||
def compare(t, k, dim, dir):
|
||||
topKVal, topKInd = t.topk(k, dim, dir, True)
|
||||
sortKVal, sortKInd = topKViaSort(t, k, dim, dir)
|
||||
compareTensors(t, sortKVal, sortKInd, topKVal, topKInd, dim)
|
||||
|
||||
t = torch.rand(random.randint(1, SIZE),
|
||||
random.randint(1, SIZE),
|
||||
random.randint(1, SIZE), device=device)
|
||||
|
||||
for _kTries in range(3):
|
||||
for _dimTries in range(3):
|
||||
for transpose in (True, False):
|
||||
for dir in (True, False):
|
||||
testTensor = t
|
||||
if transpose:
|
||||
dim1 = random.randrange(t.ndimension())
|
||||
dim2 = dim1
|
||||
while dim1 == dim2:
|
||||
dim2 = random.randrange(t.ndimension())
|
||||
|
||||
testTensor = t.transpose(dim1, dim2)
|
||||
|
||||
dim = random.randrange(testTensor.ndimension())
|
||||
k = random.randint(1, testTensor.size(dim))
|
||||
compare(testTensor, k, dim, dir)
|
||||
|
||||
def test_topk_arguments(self, device):
|
||||
q = torch.randn(10, 2, 10, device=device)
|
||||
# Make sure True isn't mistakenly taken as the 2nd dimension (interpreted as 1)
|
||||
self.assertRaises(TypeError, lambda: q.topk(4, True))
|
||||
|
||||
@skipCUDAIfRocm
|
||||
def test_unique_dim(self, device):
|
||||
self.assertFalse(hasattr(torch, 'unique_dim'))
|
||||
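# (The assert above documents that the dim-aware unique overload, presumably
# an internal helper, is not exposed as a public torch attribute.)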
|
||||
def run_test(device, dtype):
|
||||
x = torch.tensor([[[1., 1.],
|
||||
[0., 1.],
|
||||
[2., 1.],
|
||||
[0., 1.]],
|
||||
[[1., 1.],
|
||||
[0., 1.],
|
||||
[2., 1.],
|
||||
[0., 1.]]],
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
x_empty = torch.empty(5, 0, dtype=dtype, device=device)
|
||||
x_ill_formed_empty = torch.empty(5, 0, 0, dtype=dtype, device=device)
|
||||
x_ill_formed_empty_another = torch.empty(5, 0, 5, dtype=dtype, device=device)
|
||||
expected_unique_dim0 = torch.tensor([[[1., 1.],
|
||||
[0., 1.],
|
||||
[2., 1.],
|
||||
[0., 1.]]],
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
expected_inverse_dim0 = torch.tensor([0, 0])
|
||||
expected_counts_dim0 = torch.tensor([2])
|
||||
expected_unique_dim1 = torch.tensor([[[0., 1.],
|
||||
[1., 1.],
|
||||
[2., 1.]],
|
||||
[[0., 1.],
|
||||
[1., 1.],
|
||||
[2., 1.]]],
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
expected_unique_dim1_bool = torch.tensor([[[False, True], [True, True]],
|
||||
[[False, True], [True, True]]],
|
||||
dtype=torch.bool,
|
||||
device=device)
|
||||
expected_inverse_dim1 = torch.tensor([1, 0, 2, 0])
|
||||
expected_inverse_dim1_bool = torch.tensor([1, 0, 1, 0])
|
||||
expected_counts_dim1 = torch.tensor([2, 1, 1])
|
||||
expected_counts_dim1_bool = torch.tensor([2, 2])
|
||||
expected_unique_dim2 = torch.tensor([[[1., 1.],
|
||||
[0., 1.],
|
||||
[2., 1.],
|
||||
[0., 1.]],
|
||||
[[1., 1.],
|
||||
[0., 1.],
|
||||
[2., 1.],
|
||||
[0., 1.]]],
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
expected_inverse_dim2 = torch.tensor([0, 1])
|
||||
expected_counts_dim2 = torch.tensor([1, 1])
|
||||
expected_unique_empty = torch.tensor([], dtype=dtype, device=device)
|
||||
expected_inverse_empty = torch.tensor([], dtype=torch.long, device=device)
|
||||
expected_counts_empty = torch.tensor([], dtype=torch.long, device=device)
|
||||
# dim0
|
||||
x_unique = torch.unique(x, dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
|
||||
x_unique, x_inverse = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
self.assertEqual(expected_inverse_dim0, x_inverse)
|
||||
|
||||
x_unique, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=False,
|
||||
return_counts=True,
|
||||
dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
self.assertEqual(expected_counts_dim0, x_counts)
|
||||
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=0)
|
||||
self.assertEqual(expected_unique_dim0, x_unique)
|
||||
self.assertEqual(expected_inverse_dim0, x_inverse)
|
||||
self.assertEqual(expected_counts_dim0, x_counts)
|
||||
|
||||
# dim1
|
||||
x_unique = torch.unique(x, dim=1)
|
||||
if x.dtype == torch.bool:
|
||||
self.assertEqual(expected_unique_dim1_bool, x_unique)
|
||||
else:
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
|
||||
x_unique, x_inverse = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
dim=1)
|
||||
if x.dtype == torch.bool:
|
||||
self.assertEqual(expected_unique_dim1_bool, x_unique)
|
||||
self.assertEqual(expected_inverse_dim1_bool, x_inverse)
|
||||
else:
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
self.assertEqual(expected_inverse_dim1, x_inverse)
|
||||
|
||||
x_unique, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=False,
|
||||
return_counts=True,
|
||||
dim=1)
|
||||
if x.dtype == torch.bool:
|
||||
self.assertEqual(expected_unique_dim1_bool, x_unique)
|
||||
self.assertEqual(expected_counts_dim1_bool, x_counts)
|
||||
else:
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
self.assertEqual(expected_counts_dim1, x_counts)
|
||||
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=1)
|
||||
if x.dtype == torch.bool:
|
||||
self.assertEqual(expected_unique_dim1_bool, x_unique)
|
||||
self.assertEqual(expected_inverse_dim1_bool, x_inverse)
|
||||
self.assertEqual(expected_counts_dim1_bool, x_counts)
|
||||
else:
|
||||
self.assertEqual(expected_unique_dim1, x_unique)
|
||||
self.assertEqual(expected_inverse_dim1, x_inverse)
|
||||
self.assertEqual(expected_counts_dim1, x_counts)
|
||||
|
||||
# dim2
|
||||
x_unique = torch.unique(x, dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
|
||||
x_unique, x_inverse = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
self.assertEqual(expected_inverse_dim2, x_inverse)
|
||||
|
||||
x_unique, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=False,
|
||||
return_counts=True,
|
||||
dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
self.assertEqual(expected_counts_dim2, x_counts)
|
||||
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=2)
|
||||
self.assertEqual(expected_unique_dim2, x_unique)
|
||||
self.assertEqual(expected_inverse_dim2, x_inverse)
|
||||
self.assertEqual(expected_counts_dim2, x_counts)
|
||||
|
||||
# test empty tensor
|
||||
x_unique, x_inverse, x_counts = torch.unique(
|
||||
x_empty,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=1)
|
||||
self.assertEqual(expected_unique_empty, x_unique)
|
||||
self.assertEqual(expected_inverse_empty, x_inverse)
|
||||
self.assertEqual(expected_counts_empty, x_counts)
|
||||
|
||||
# test not a well formed tensor
|
||||
# Checking for runtime error, as this is the expected behaviour
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.unique(
|
||||
x_ill_formed_empty,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=1)
|
||||
|
||||
# test along dim2
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.unique(
|
||||
x_ill_formed_empty_another,
|
||||
return_inverse=True,
|
||||
return_counts=True,
|
||||
dim=2)
|
||||
|
||||
# test consecutive version
|
||||
y = torch.tensor(
|
||||
[[0, 1],
|
||||
[0, 1],
|
||||
[0, 1],
|
||||
[1, 2],
|
||||
[1, 2],
|
||||
[3, 4],
|
||||
[0, 1],
|
||||
[0, 1],
|
||||
[3, 4],
|
||||
[1, 2]],
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
expected_y_unique = torch.tensor(
|
||||
[[0, 1],
|
||||
[1, 2],
|
||||
[3, 4],
|
||||
[0, 1],
|
||||
[3, 4],
|
||||
[1, 2]],
|
||||
dtype=dtype,
|
||||
device=device
|
||||
)
|
||||
expected_y_inverse = torch.tensor([0, 0, 0, 1, 1, 2, 3, 3, 4, 5], dtype=torch.int64, device=device)
|
||||
expected_y_counts = torch.tensor([3, 2, 1, 2, 1, 1], dtype=torch.int64, device=device)
|
||||
expected_y_inverse_bool = torch.tensor([0, 0, 0, 1, 1, 1, 2, 2, 3, 3], dtype=torch.int64, device=device)
|
||||
expected_y_counts_bool = torch.tensor([3, 3, 2, 2], dtype=torch.int64, device=device)
|
||||
y_unique, y_inverse, y_counts = torch.unique_consecutive(y, return_inverse=True, return_counts=True, dim=0)
|
||||
if x.dtype == torch.bool:
|
||||
self.assertEqual(expected_y_inverse_bool, y_inverse)
|
||||
self.assertEqual(expected_y_counts_bool, y_counts)
|
||||
else:
|
||||
self.assertEqual(expected_y_inverse, y_inverse)
|
||||
self.assertEqual(expected_y_counts, y_counts)
|
||||
|
||||
run_test(device, torch.float)
|
||||
run_test(device, torch.double)
|
||||
run_test(device, torch.long)
|
||||
run_test(device, torch.uint8)
|
||||
run_test(device, torch.bool)
|
||||
|
||||
@onlyCUDA
|
||||
def test_topk_noncontiguous_gpu(self, device):
|
||||
t = torch.randn(20, device=device)[::2]
|
||||
top1, idx1 = t.topk(5)
|
||||
top2, idx2 = t.contiguous().topk(5)
|
||||
self.assertEqual(top1, top2)
|
||||
self.assertEqual(idx1, idx2)
|
||||
|
||||
@dtypes(torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64)
|
||||
def test_topk_integral(self, device, dtype):
|
||||
a = torch.randint(torch.iinfo(dtype).min, torch.iinfo(dtype).max, size=(10,),
|
||||
dtype=dtype, device=device)
|
||||
sort_topk = a.sort()[0][-5:].flip(0)
|
||||
topk = a.topk(5)
|
||||
self.assertEqual(sort_topk, topk[0]) # check values
|
||||
self.assertEqual(sort_topk, a[topk[1]]) # check indices
|
||||
|
||||
@dtypesIfCUDA(*torch.testing.get_all_fp_dtypes())
|
||||
@dtypes(torch.float, torch.double)
|
||||
def test_topk_nonfinite(self, device, dtype):
|
||||
x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype)
|
||||
val, idx = x.topk(4)
|
||||
expect = torch.tensor([float('nan'), float('inf'), 1e4, 0], device=device, dtype=dtype)
|
||||
self.assertEqual(val, expect)
|
||||
self.assertEqual(idx, [0, 1, 2, 3])
|
||||
|
||||
val, idx = x.topk(4, largest=False)
|
||||
expect = torch.tensor([-float('inf'), -1e4, 0, 1e4], device=device, dtype=dtype)
|
||||
self.assertEqual(val, expect)
|
||||
self.assertEqual(idx, [5, 4, 3, 2])
|
||||
|
||||
def test_topk_4d(self, device):
|
||||
x = torch.ones(2, 3072, 2, 2, device=device)
|
||||
x[:, 1, :, :] *= 2.
|
||||
x[:, 10, :, :] *= 1.5
|
||||
val, ind = torch.topk(x, k=2, dim=1)
|
||||
expected_ind = torch.ones(2, 2, 2, 2, dtype=torch.long, device=device)
|
||||
expected_ind[:, 1, :, :] = 10
|
||||
expected_val = torch.ones(2, 2, 2, 2, device=device)
|
||||
expected_val[:, 0, :, :] *= 2.
|
||||
expected_val[:, 1, :, :] *= 1.5
|
||||
self.assertEqual(val, expected_val, atol=0, rtol=0)
|
||||
self.assertEqual(ind, expected_ind, atol=0, rtol=0)
|
||||
|
||||
def _test_unique_scalar_empty(self, dtype, device, f):
|
||||
# test scalar
|
||||
x = torch.tensor(0, dtype=dtype, device=device)
|
||||
unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
|
||||
expected_unique = torch.tensor([0], dtype=dtype, device=device)
|
||||
expected_inverse = torch.tensor(0, device=device)
|
||||
expected_counts = torch.tensor([1], device=device)
|
||||
self.assertEqual(unique, expected_unique)
|
||||
self.assertEqual(inverse, expected_inverse)
|
||||
self.assertEqual(counts, expected_counts)
|
||||
|
||||
# test zero sized tensor
|
||||
x = torch.zeros((0, 0, 3), dtype=dtype, device=device)
|
||||
unique, inverse, counts = f(x, return_inverse=True, return_counts=True)
|
||||
expected_unique = torch.tensor([], dtype=dtype, device=device)
|
||||
expected_inverse = torch.empty((0, 0, 3), dtype=torch.long, device=device)
|
||||
expected_counts = torch.tensor([], dtype=torch.long, device=device)
|
||||
self.assertEqual(unique, expected_unique)
|
||||
self.assertEqual(inverse, expected_inverse)
|
||||
self.assertEqual(counts, expected_counts)
|
||||
|
||||
def _test_unique_with_expects(self, device, dtype, f, x, expected_unique, expected_inverse, expected_counts, additional_shape):
|
||||
def ensure_tuple(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
for return_inverse in [True, False]:
|
||||
for return_counts in [True, False]:
|
||||
# test with expected
|
||||
ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
|
||||
self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
|
||||
self.assertEqual(expected_unique, ret[0])
|
||||
if return_inverse:
|
||||
self.assertEqual(expected_inverse, ret[1])
|
||||
if return_counts:
|
||||
count_index = 1 + int(return_inverse)
|
||||
self.assertEqual(expected_counts, ret[count_index])
|
||||
|
||||
# tests per-element unique on a higher rank tensor.
|
||||
y = x.view(additional_shape)
|
||||
y_unique, y_inverse, y_counts = f(y, return_inverse=True, return_counts=True)
|
||||
self.assertEqual(expected_unique, y_unique)
|
||||
self.assertEqual(expected_inverse.view(additional_shape), y_inverse)
|
||||
self.assertEqual(expected_counts, y_counts)
|
||||
|
||||
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
|
||||
def test_unique(self, device, dtype):
|
||||
if dtype is torch.half and self.device_type == 'cpu':
|
||||
return # CPU does not have half support
|
||||
|
||||
def ensure_tuple(x):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return (x,)
|
||||
return x
|
||||
|
||||
if dtype is torch.bool:
|
||||
x = torch.tensor([True, False, False, False, True, False, True, False], dtype=torch.bool, device=device)
|
||||
expected_unique = torch.tensor([False, True], dtype=torch.bool, device=device)
|
||||
expected_inverse = torch.tensor([1, 0, 0, 0, 1, 0, 1, 0], dtype=torch.long, device=device)
|
||||
expected_counts = torch.tensor([5, 3], dtype=torch.long, device=device)
|
||||
else:
|
||||
x = torch.tensor([1, 2, 3, 2, 8, 5, 2, 3], dtype=dtype, device=device)
|
||||
expected_unique = torch.tensor([1, 2, 3, 5, 8], dtype=dtype, device=device)
|
||||
expected_inverse = torch.tensor([0, 1, 2, 1, 4, 3, 1, 2], device=device)
|
||||
expected_counts = torch.tensor([1, 3, 2, 1, 1], device=device)
|
||||
|
||||
# test sorted unique
|
||||
fs = [
|
||||
lambda x, **kwargs: torch.unique(x, sorted=True, **kwargs),
|
||||
lambda x, **kwargs: x.unique(sorted=True, **kwargs),
|
||||
]
|
||||
for f in fs:
|
||||
self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (2, 2, 2))
|
||||
self._test_unique_scalar_empty(dtype, device, f)
|
||||
|
||||
# test unsorted unique
|
||||
fs = [
|
||||
lambda x, **kwargs: torch.unique(x, sorted=False, **kwargs),
|
||||
lambda x, **kwargs: x.unique(sorted=False, **kwargs)
|
||||
]
|
||||
for f in fs:
|
||||
self._test_unique_scalar_empty(dtype, device, f)
|
||||
for return_inverse in [True, False]:
|
||||
for return_counts in [True, False]:
|
||||
ret = ensure_tuple(f(x, return_inverse=return_inverse, return_counts=return_counts))
|
||||
self.assertEqual(len(ret), 1 + int(return_inverse) + int(return_counts))
|
||||
x_list = x.tolist()
|
||||
x_unique_list = ret[0].tolist()
|
||||
self.assertEqual(expected_unique.tolist(), sorted(x_unique_list))
|
||||
if return_inverse:
|
||||
x_inverse_list = ret[1].tolist()
|
||||
for i, j in enumerate(x_inverse_list):
|
||||
self.assertEqual(x_list[i], x_unique_list[j])
|
||||
if return_counts:
|
||||
count_index = 1 + int(return_inverse)
|
||||
x_counts_list = ret[count_index].tolist()
|
||||
for i, j in zip(x_unique_list, x_counts_list):
|
||||
count = 0
|
||||
for k in x_list:
|
||||
if k == i:
|
||||
count += 1
|
||||
self.assertEqual(j, count)
|
||||
|
||||
@dtypes(*set(torch.testing.get_all_dtypes()) - {torch.bfloat16, torch.complex64, torch.complex128})
|
||||
def test_unique_consecutive(self, device, dtype):
|
||||
if dtype is torch.half and self.device_type == 'cpu':
|
||||
return # CPU does not have half support
|
||||
|
||||
if dtype is torch.bool:
|
||||
x = torch.tensor([True, False, False, False, True, True, False, False, False], dtype=torch.bool, device=device)
|
||||
expected_unique = torch.tensor([True, False, True, False], dtype=torch.bool, device=device)
|
||||
expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 3], dtype=torch.long, device=device)
|
||||
expected_counts = torch.tensor([1, 3, 2, 3], dtype=torch.long, device=device)
|
||||
else:
|
||||
x = torch.tensor([1, 2, 2, 2, 5, 5, 2, 2, 3], dtype=dtype, device=device)
|
||||
expected_unique = torch.tensor([1, 2, 5, 2, 3], dtype=dtype, device=device)
|
||||
expected_inverse = torch.tensor([0, 1, 1, 1, 2, 2, 3, 3, 4], device=device)
|
||||
expected_counts = torch.tensor([1, 3, 2, 2, 1], device=device)
|
||||
|
||||
for f in [torch.unique_consecutive, lambda x, **kwargs: x.unique_consecutive(**kwargs)]:
|
||||
self._test_unique_with_expects(device, dtype, f, x, expected_unique, expected_inverse, expected_counts, (3, 3))
|
||||
self._test_unique_scalar_empty(dtype, device, f)
|
||||
|
||||
@dtypes(torch.double)
|
||||
def test_kthvalue(self, device, dtype):
|
||||
SIZE = 50
|
||||
x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device)
|
||||
x0 = x.clone()
|
||||
|
||||
k = random.randint(1, SIZE)
|
||||
res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
|
||||
res2val, res2ind = torch.sort(x)
|
||||
|
||||
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0)
|
||||
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0)
|
||||
# test use of result tensors
|
||||
k = random.randint(1, SIZE)
|
||||
res1val = torch.tensor([], dtype=dtype, device=device)
|
||||
res1ind = torch.tensor([], dtype=torch.long, device=device)
|
||||
torch.kthvalue(x, k, keepdim=False, out=(res1val, res1ind))
|
||||
res2val, res2ind = torch.sort(x)
|
||||
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0)
|
||||
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0)
|
||||
|
||||
# test non-default dim
|
||||
k = random.randint(1, SIZE)
|
||||
res1val, res1ind = torch.kthvalue(x, k, 0, keepdim=False)
|
||||
res2val, res2ind = torch.sort(x, 0)
|
||||
self.assertEqual(res1val, res2val[k - 1], atol=0, rtol=0)
|
||||
self.assertEqual(res1ind, res2ind[k - 1], atol=0, rtol=0)
|
||||
|
||||
# non-contiguous
|
||||
y = x.narrow(1, 0, 1)
|
||||
y0 = y.contiguous()
|
||||
k = random.randint(1, SIZE)
|
||||
res1val, res1ind = torch.kthvalue(y, k)
|
||||
res2val, res2ind = torch.kthvalue(y0, k)
|
||||
self.assertEqual(res1val, res2val, atol=0, rtol=0)
|
||||
self.assertEqual(res1ind, res2ind, atol=0, rtol=0)
|
||||
|
||||
# non-contiguous [Reference: https://github.com/pytorch/pytorch/issues/45721]
|
||||
non_contig_t = torch.tensor([0, -1, 1, -2, 2], dtype=dtype, device=device)[::2]
|
||||
expected_val, expected_ind = non_contig_t.contiguous().kthvalue(2)
|
||||
non_contig_cpu_t = non_contig_t.cpu()
|
||||
expected_val_cpu, expected_ind_cpu = non_contig_cpu_t.kthvalue(2)
|
||||
|
||||
out_val, out_ind = non_contig_t.kthvalue(2)
|
||||
self.assertEqual(expected_val, out_val, atol=0, rtol=0)
|
||||
self.assertEqual(expected_ind, out_ind, atol=0, rtol=0)
|
||||
self.assertEqual(expected_val_cpu, out_val, atol=0, rtol=0)
|
||||
self.assertEqual(expected_ind_cpu, out_ind, atol=0, rtol=0)
|
||||
|
||||
# check that the input wasn't modified
|
||||
self.assertEqual(x, x0, atol=0, rtol=0)
|
||||
|
||||
# simple test case (with repetitions)
|
||||
y = torch.tensor((3., 5, 4, 1, 1, 5), dtype=dtype, device=device)
|
||||
self.assertEqual(torch.kthvalue(y, 3)[0], 3, atol=0, rtol=0)
|
||||
self.assertEqual(torch.kthvalue(y, 2)[0], 1, atol=0, rtol=0)
|
||||
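# Worked check: sorted y is [1, 1, 3, 4, 5, 5], so the 3rd smallest value is 3
# and the 2nd smallest is 1, matching the assertions above.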
|
||||
# simple test case (with NaN)
|
||||
SIZE = 50
|
||||
x = torch.rand(SIZE, SIZE, SIZE, dtype=dtype, device=device)
|
||||
x[torch.arange(SIZE), :, torch.randint(50, (50,))] = nan
|
||||
ks = [random.randint(1, SIZE), 1, SIZE, SIZE - 1]
|
||||
res2val, res2ind = torch.sort(x)
|
||||
for k in ks:
|
||||
res1val, res1ind = torch.kthvalue(x, k, keepdim=False)
|
||||
self.assertEqual(res1val[:, :], res2val[:, :, k - 1], atol=0, rtol=0)
|
||||
self.assertEqual(res1ind[:, :], res2ind[:, :, k - 1], atol=0, rtol=0)
|
||||
|
||||
@dtypes(torch.float)
|
||||
@onlyOnCPUAndCUDA # Fails on XLA
|
||||
def test_kthvalue_scalar(self, device, dtype):
|
||||
# Test scalar input (test case from https://github.com/pytorch/pytorch/issues/30818)
|
||||
# Tests that passing a scalar tensor or 1D tensor with 1 element work either way
|
||||
res = torch.tensor(2, device=device, dtype=dtype).kthvalue(1)
|
||||
ref = torch.tensor([2], device=device, dtype=dtype).kthvalue(1)
|
||||
self.assertEqual(res[0], ref[0].squeeze())
|
||||
self.assertEqual(res[1], ref[1].squeeze())
|
||||
|
||||
instantiate_device_type_tests(TestSortAndSelect, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
File diff suppressed because it is too large
438
test/test_testing.py
Normal file
@ -0,0 +1,438 @@
|
||||
import torch
|
||||
|
||||
import math
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests, make_tensor)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, onlyOnCPUAndCUDA, dtypes)
|
||||
|
||||
# For testing TestCase methods and torch.testing functions
|
||||
class TestTesting(TestCase):
|
||||
# Ensure that assertEqual handles numpy arrays properly
|
||||
@dtypes(*(torch.testing.get_all_dtypes(include_half=True, include_bfloat16=False,
|
||||
include_bool=True, include_complex=True)))
|
||||
def test_assertEqual_numpy(self, device, dtype):
|
||||
S = 10
|
||||
test_sizes = [
|
||||
(),
|
||||
(0,),
|
||||
(S,),
|
||||
(S, S),
|
||||
(0, S),
|
||||
(S, 0)]
|
||||
for test_size in test_sizes:
|
||||
a = make_tensor(test_size, device, dtype, low=-5, high=5)
|
||||
a_n = a.cpu().numpy()
|
||||
msg = f'size: {test_size}'
|
||||
self.assertEqual(a_n, a, rtol=0, atol=0, msg=msg)
|
||||
self.assertEqual(a, a_n, rtol=0, atol=0, msg=msg)
|
||||
self.assertEqual(a_n, a_n, rtol=0, atol=0, msg=msg)
|
||||
|
||||
# Tests that when rtol or atol (including self.precision) is set, then
|
||||
# the other is zeroed.
|
||||
# TODO: this is legacy behavior and should be updated after test
|
||||
# precisions are reviewed to be consistent with torch.isclose.
|
||||
@onlyOnCPUAndCUDA
|
||||
def test__comparetensors_legacy(self, device):
|
||||
a = torch.tensor((10000000.,))
|
||||
b = torch.tensor((10000002.,))
|
||||
|
||||
x = torch.tensor((1.,))
|
||||
y = torch.tensor((1. + 1e-5,))
|
||||
|
||||
# Helper for reusing the tensor values as scalars
|
||||
def _scalar_helper(a, b, rtol=None, atol=None):
|
||||
return self._compareScalars(a.item(), b.item(), rtol=rtol, atol=atol)
|
||||
|
||||
for op in (self._compareTensors, _scalar_helper):
|
||||
# Tests default
|
||||
result, debug_msg = op(a, b)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Tests setting atol
|
||||
result, debug_msg = op(a, b, atol=2, rtol=0)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Tests setting atol too small
|
||||
result, debug_msg = op(a, b, atol=1, rtol=0)
|
||||
self.assertFalse(result)
|
||||
|
||||
# Tests setting rtol (large enough to compare as equal)
|
||||
result, debug_msg = op(x, y, atol=0, rtol=1.05e-5)
|
||||
self.assertTrue(result)
|
||||
|
||||
# Tests setting rtol too small
|
||||
result, debug_msg = op(x, y, atol=0, rtol=1e-5)
|
||||
self.assertFalse(result)
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
def test__comparescalars_debug_msg(self, device):
|
||||
# float x float
|
||||
result, debug_msg = self._compareScalars(4., 7.)
|
||||
expected_msg = ("Comparing 4.0 and 7.0 gives a difference of 3.0, "
|
||||
"but the allowed difference with rtol=1.3e-06 and "
|
||||
"atol=1e-05 is only 1.9100000000000003e-05!")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# complex x complex, real difference
|
||||
result, debug_msg = self._compareScalars(complex(1, 3), complex(3, 1))
|
||||
expected_msg = ("Comparing the real part 1.0 and 3.0 gives a difference "
|
||||
"of 2.0, but the allowed difference with rtol=1.3e-06 "
|
||||
"and atol=1e-05 is only 1.39e-05!")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# complex x complex, imaginary difference
|
||||
result, debug_msg = self._compareScalars(complex(1, 3), complex(1, 5.5))
|
||||
expected_msg = ("Comparing the imaginary part 3.0 and 5.5 gives a "
|
||||
"difference of 2.5, but the allowed difference with "
|
||||
"rtol=1.3e-06 and atol=1e-05 is only 1.715e-05!")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# complex x int
|
||||
result, debug_msg = self._compareScalars(complex(1, -2), 1)
|
||||
expected_msg = ("Comparing the imaginary part -2.0 and 0.0 gives a "
|
||||
"difference of 2.0, but the allowed difference with "
|
||||
"rtol=1.3e-06 and atol=1e-05 is only 1e-05!")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# NaN x NaN, equal_nan=False
|
||||
result, debug_msg = self._compareScalars(float('nan'), float('nan'), equal_nan=False)
|
||||
expected_msg = ("Found nan and nan while comparing and either one is "
|
||||
"nan and the other isn't, or both are nan and equal_nan "
|
||||
"is False")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks that compareTensors provides the correct debug info
|
||||
@onlyOnCPUAndCUDA
|
||||
def test__comparetensors_debug_msg(self, device):
|
||||
# Acquires atol that will be used
|
||||
atol = max(1e-05, self.precision)
|
||||
|
||||
# Checks float tensor comparisons (2D tensor)
|
||||
a = torch.tensor(((0, 6), (7, 9)), device=device, dtype=torch.float32)
|
||||
b = torch.tensor(((0, 7), (7, 22)), device=device, dtype=torch.float32)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 4) "
|
||||
"whose difference(s) exceeded the margin of error (including 0 nan comparisons). "
|
||||
"The greatest difference was 13.0 (9.0 vs. 22.0), "
|
||||
"which occurred at index (1, 1).").format(atol)
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks float tensor comparisons (with extremal values)
|
||||
a = torch.tensor((float('inf'), 5, float('inf')), device=device, dtype=torch.float32)
|
||||
b = torch.tensor((float('inf'), float('nan'), float('-inf')), device=device, dtype=torch.float32)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 3) "
|
||||
"whose difference(s) exceeded the margin of error (including 1 nan comparisons). "
|
||||
"The greatest difference was nan (5.0 vs. nan), "
|
||||
"which occurred at index 1.").format(atol)
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks float tensor comparisons (with finite vs nan differences)
|
||||
a = torch.tensor((20, -6), device=device, dtype=torch.float32)
|
||||
b = torch.tensor((-1, float('nan')), device=device, dtype=torch.float32)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("With rtol=1.3e-06 and atol={0}, found 2 element(s) (out of 2) "
|
||||
"whose difference(s) exceeded the margin of error (including 1 nan comparisons). "
|
||||
"The greatest difference was nan (-6.0 vs. nan), "
|
||||
"which occurred at index 1.").format(atol)
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks int tensor comparisons (1D tensor)
|
||||
a = torch.tensor((1, 2, 3, 4), device=device)
|
||||
b = torch.tensor((2, 5, 3, 4), device=device)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("Found 2 different element(s) (out of 4), "
|
||||
"with the greatest difference of 3 (2 vs. 5) "
|
||||
"occuring at index 1.")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks bool tensor comparisons (0D tensor)
|
||||
a = torch.tensor((True), device=device)
|
||||
b = torch.tensor((False), device=device)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("Found 1 different element(s) (out of 1), "
|
||||
"with the greatest difference of 1 (1 vs. 0) "
|
||||
"occuring at index 0.")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks complex tensor comparisons (real part)
|
||||
a = torch.tensor((1 - 1j, 4 + 3j), device=device)
|
||||
b = torch.tensor((1 - 1j, 1 + 3j), device=device)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("Real parts failed to compare as equal! "
|
||||
"With rtol=1.3e-06 and atol={0}, "
|
||||
"found 1 element(s) (out of 2) whose difference(s) exceeded the "
|
||||
"margin of error (including 0 nan comparisons). The greatest difference was "
|
||||
"3.0 (4.0 vs. 1.0), which occurred at index 1.").format(atol)
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks complex tensor comparisons (imaginary part)
|
||||
a = torch.tensor((1 - 1j, 4 + 3j), device=device)
|
||||
b = torch.tensor((1 - 1j, 4 - 21j), device=device)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("Imaginary parts failed to compare as equal! "
|
||||
"With rtol=1.3e-06 and atol={0}, "
|
||||
"found 1 element(s) (out of 2) whose difference(s) exceeded the "
|
||||
"margin of error (including 0 nan comparisons). The greatest difference was "
|
||||
"24.0 (3.0 vs. -21.0), which occurred at index 1.").format(atol)
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks size mismatch
|
||||
a = torch.tensor((1, 2), device=device)
|
||||
b = torch.tensor((3), device=device)
|
||||
result, debug_msg = self._compareTensors(a, b)
|
||||
expected_msg = ("Attempted to compare equality of tensors "
|
||||
"with different sizes. Got sizes torch.Size([2]) and torch.Size([]).")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks dtype mismatch
|
||||
a = torch.tensor((1, 2), device=device, dtype=torch.long)
|
||||
b = torch.tensor((1, 2), device=device, dtype=torch.float32)
|
||||
result, debug_msg = self._compareTensors(a, b, exact_dtype=True)
|
||||
expected_msg = ("Attempted to compare equality of tensors "
|
||||
"with different dtypes. Got dtypes torch.int64 and torch.float32.")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Checks device mismatch
|
||||
if self.device_type == 'cuda':
|
||||
a = torch.tensor((5), device='cpu')
|
||||
b = torch.tensor((5), device=device)
|
||||
result, debug_msg = self._compareTensors(a, b, exact_device=True)
|
||||
expected_msg = ("Attempted to compare equality of tensors "
|
||||
"on different devices! Got devices cpu and cuda:0.")
|
||||
self.assertEqual(debug_msg, expected_msg)
|
||||
|
||||
# Helper for testing _compareTensors and _compareScalars
|
||||
# Works on single element tensors
|
||||
def _comparetensors_helper(self, tests, device, dtype, equal_nan, exact_dtype=True, atol=1e-08, rtol=1e-05):
|
||||
for test in tests:
|
||||
a = torch.tensor((test[0],), device=device, dtype=dtype)
|
||||
b = torch.tensor((test[1],), device=device, dtype=dtype)
|
||||
|
||||
# Tensor x Tensor comparison
|
||||
compare_result, debug_msg = self._compareTensors(a, b, rtol=rtol, atol=atol,
|
||||
equal_nan=equal_nan,
|
||||
exact_dtype=exact_dtype)
|
||||
self.assertEqual(compare_result, test[2])
|
||||
|
||||
# Scalar x Scalar comparison
|
||||
compare_result, debug_msg = self._compareScalars(a.item(), b.item(),
|
||||
rtol=rtol, atol=atol,
|
||||
equal_nan=equal_nan)
|
||||
self.assertEqual(compare_result, test[2])
|
||||
|
||||
def _isclose_helper(self, tests, device, dtype, equal_nan, atol=1e-08, rtol=1e-05):
|
||||
for test in tests:
|
||||
a = torch.tensor((test[0],), device=device, dtype=dtype)
|
||||
b = torch.tensor((test[1],), device=device, dtype=dtype)
|
||||
|
||||
actual = torch.isclose(a, b, equal_nan=equal_nan, atol=atol, rtol=rtol)
|
||||
expected = test[2]
|
||||
self.assertEqual(actual.item(), expected)
|
||||
|
||||
# torch.isclose is not implemented for bool tensors
|
||||
# see https://github.com/pytorch/pytorch/issues/33048
|
||||
def test_isclose_comparetensors_bool(self, device):
|
||||
tests = (
|
||||
(True, True, True),
|
||||
(False, False, True),
|
||||
(True, False, False),
|
||||
(False, True, False),
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
self._isclose_helper(tests, device, torch.bool, False)
|
||||
|
||||
self._comparetensors_helper(tests, device, torch.bool, False)
|
||||
|
||||
@dtypes(torch.uint8,
|
||||
torch.int8, torch.int16, torch.int32, torch.int64)
|
||||
def test_isclose_comparetensors_integer(self, device, dtype):
|
||||
tests = (
|
||||
(0, 0, True),
|
||||
(0, 1, False),
|
||||
(1, 0, False),
|
||||
)
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False)
|
||||
|
||||
# atol and rtol tests
|
||||
tests = [
|
||||
(0, 1, True),
|
||||
(1, 0, False),
|
||||
(1, 3, True),
|
||||
]
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
|
||||
if dtype is torch.uint8:
|
||||
tests = [
|
||||
(-1, 1, False),
|
||||
(1, -1, False)
|
||||
]
|
||||
else:
|
||||
tests = [
|
||||
(-1, 1, True),
|
||||
(1, -1, True)
|
||||
]
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False, atol=1.5, rtol=.5)
|
||||
self._comparetensors_helper(tests, device, dtype, False, atol=1.5, rtol=.5)
|
||||
|
||||
@onlyOnCPUAndCUDA
|
||||
@dtypes(torch.float16, torch.float32, torch.float64)
|
||||
def test_isclose_comparetensors_float(self, device, dtype):
|
||||
tests = (
|
||||
(0, 0, True),
|
||||
(0, -1, False),
|
||||
(float('inf'), float('inf'), True),
|
||||
(-float('inf'), float('inf'), False),
|
||||
(float('inf'), float('nan'), False),
|
||||
(float('nan'), float('nan'), False),
|
||||
(0, float('nan'), False),
|
||||
(1, 1, True),
|
||||
)
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False)
|
||||
self._comparetensors_helper(tests, device, dtype, False)
|
||||
|
||||
# atol and rtol tests
|
||||
eps = 1e-2 if dtype is torch.half else 1e-6
|
||||
tests = (
|
||||
(0, 1, True),
|
||||
(0, 1 + eps, False),
|
||||
(1, 0, False),
|
||||
(1, 3, True),
|
||||
(1 - eps, 3, False),
|
||||
(-.25, .5, True),
|
||||
(-.25 - eps, .5, False),
|
||||
(.25, -.5, True),
|
||||
(.25 + eps, -.5, False),
|
||||
)
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
|
||||
# equal_nan = True tests
|
||||
tests = (
|
||||
(0, float('nan'), False),
|
||||
(float('inf'), float('nan'), False),
|
||||
(float('nan'), float('nan'), True),
|
||||
)
|
||||
|
||||
self._isclose_helper(tests, device, dtype, True)
|
||||
|
||||
self._comparetensors_helper(tests, device, dtype, True)
|
||||
|
||||
# torch.isclose with equal_nan=True is not implemented for complex inputs
|
||||
# see https://github.com/numpy/numpy/issues/15959
|
||||
# Note: compareTensor will compare the real and imaginary parts of a
|
||||
# complex tensors separately, unlike isclose.
|
||||
@dtypes(torch.complex64, torch.complex128)
|
||||
def test_isclose_comparetensors_complex(self, device, dtype):
|
||||
tests = (
|
||||
(complex(1, 1), complex(1, 1 + 1e-8), True),
|
||||
(complex(0, 1), complex(1, 1), False),
|
||||
(complex(1, 1), complex(1, 0), False),
|
||||
(complex(1, 1), complex(1, float('nan')), False),
|
||||
(complex(1, float('nan')), complex(1, float('nan')), False),
|
||||
(complex(1, 1), complex(1, float('inf')), False),
|
||||
(complex(float('inf'), 1), complex(1, float('inf')), False),
|
||||
(complex(-float('inf'), 1), complex(1, float('inf')), False),
|
||||
(complex(-float('inf'), 1), complex(float('inf'), 1), False),
|
||||
(complex(float('inf'), 1), complex(float('inf'), 1), True),
|
||||
(complex(float('inf'), 1), complex(float('inf'), 1 + 1e-4), False),
|
||||
)
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False)
|
||||
self._comparetensors_helper(tests, device, dtype, False)
|
||||
|
||||
# atol and rtol tests
|
||||
|
||||
# atol and rtol tests
|
||||
eps = 1e-6
|
||||
tests = (
|
||||
# Complex versions of float tests (real part)
|
||||
(complex(0, 0), complex(1, 0), True),
|
||||
(complex(0, 0), complex(1 + eps, 0), False),
|
||||
(complex(1, 0), complex(0, 0), False),
|
||||
(complex(1, 0), complex(3, 0), True),
|
||||
(complex(1 - eps, 0), complex(3, 0), False),
|
||||
(complex(-.25, 0), complex(.5, 0), True),
|
||||
(complex(-.25 - eps, 0), complex(.5, 0), False),
|
||||
(complex(.25, 0), complex(-.5, 0), True),
|
||||
(complex(.25 + eps, 0), complex(-.5, 0), False),
|
||||
# Complex versions of float tests (imaginary part)
|
||||
(complex(0, 0), complex(0, 1), True),
|
||||
(complex(0, 0), complex(0, 1 + eps), False),
|
||||
(complex(0, 1), complex(0, 0), False),
|
||||
(complex(0, 1), complex(0, 3), True),
|
||||
(complex(0, 1 - eps), complex(0, 3), False),
|
||||
(complex(0, -.25), complex(0, .5), True),
|
||||
(complex(0, -.25 - eps), complex(0, .5), False),
|
||||
(complex(0, .25), complex(0, -.5), True),
|
||||
(complex(0, .25 + eps), complex(0, -.5), False),
|
||||
)
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
|
||||
# atol and rtol tests for isclose
|
||||
tests = (
|
||||
# Complex-specific tests
|
||||
(complex(1, -1), complex(-1, 1), False),
|
||||
(complex(1, -1), complex(2, -2), True),
|
||||
(complex(-math.sqrt(2), math.sqrt(2)),
|
||||
complex(-math.sqrt(.5), math.sqrt(.5)), True),
|
||||
(complex(-math.sqrt(2), math.sqrt(2)),
|
||||
complex(-math.sqrt(.501), math.sqrt(.499)), False),
|
||||
(complex(2, 4), complex(1., 8.8523607), True),
|
||||
(complex(2, 4), complex(1., 8.8523607 + eps), False),
|
||||
(complex(1, 99), complex(4, 100), True),
|
||||
)
|
||||
|
||||
self._isclose_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
|
||||
# atol and rtol tests for compareTensors
|
||||
tests = (
|
||||
(complex(1, -1), complex(-1, 1), False),
|
||||
(complex(1, -1), complex(2, -2), True),
|
||||
(complex(1, 99), complex(4, 100), False),
|
||||
)
|
||||
|
||||
self._comparetensors_helper(tests, device, dtype, False, atol=.5, rtol=.5)
|
||||
|
||||
# equal_nan = True tests
|
||||
tests = (
|
||||
(complex(1, 1), complex(1, float('nan')), False),
|
||||
(complex(float('nan'), 1), complex(1, float('nan')), False),
|
||||
(complex(float('nan'), 1), complex(float('nan'), 1), True),
|
||||
)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
self._isclose_helper(tests, device, dtype, True)
|
||||
|
||||
self._comparetensors_helper(tests, device, dtype, True)
|
||||
|
||||
# Tests that isclose with rtol or atol values less than zero throws a
|
||||
# RuntimeError
|
||||
@dtypes(torch.bool, torch.uint8,
|
||||
torch.int8, torch.int16, torch.int32, torch.int64,
|
||||
torch.float16, torch.float32, torch.float64)
|
||||
def test_isclose_atol_rtol_greater_than_zero(self, device, dtype):
|
||||
t = torch.tensor((1,), device=device, dtype=dtype)
|
||||
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.isclose(t, t, atol=-1, rtol=1)
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.isclose(t, t, atol=1, rtol=-1)
|
||||
with self.assertRaises(RuntimeError):
|
||||
torch.isclose(t, t, atol=-1, rtol=-1)
|
||||
|
||||
instantiate_device_type_tests(TestTesting, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
14450
test/test_torch.py
File diff suppressed because it is too large
1218
test/test_view_ops.py
Normal file
File diff suppressed because it is too large
@ -16,6 +16,8 @@ import math # noqa: F401
|
||||
|
||||
# Testing utils
|
||||
from torch._six import inf
|
||||
|
||||
# TODO: include files like this should not set the default dtype
|
||||
torch.set_default_dtype(torch.double)
|
||||
|
||||
L = 20
|
||||
|