Clean-up test_indexing.py after Tensor/Variable merge (#6433)

This commit is contained in:
Sam Gross
2018-04-10 14:03:14 -04:00
committed by GitHub
parent aea31131e5
commit 64e94814da

View File

@ -1,88 +1,87 @@
from common import TestCase, run_tests
import unittest
import torch
import warnings
from torch.autograd import Variable
from torch import tensor
class TestIndexing(TestCase):
def test_single_int(self):
v = Variable(torch.randn(5, 7, 3))
v = torch.randn(5, 7, 3)
self.assertEqual(v[4].shape, (7, 3))
def test_multiple_int(self):
v = Variable(torch.randn(5, 7, 3))
v = torch.randn(5, 7, 3)
self.assertEqual(v[4].shape, (7, 3))
self.assertEqual(v[4, :, 1].shape, (7,))
def test_none(self):
v = Variable(torch.randn(5, 7, 3))
v = torch.randn(5, 7, 3)
self.assertEqual(v[None].shape, (1, 5, 7, 3))
self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
def test_step(self):
v = Variable(torch.arange(10))
v = torch.arange(10)
self.assertEqual(v[::1], v)
self.assertEqual(v[::2].data.tolist(), [0, 2, 4, 6, 8])
self.assertEqual(v[::3].data.tolist(), [0, 3, 6, 9])
self.assertEqual(v[::11].data.tolist(), [0])
self.assertEqual(v[1:6:2].data.tolist(), [1, 3, 5])
self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
self.assertEqual(v[::11].tolist(), [0])
self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
def test_step_assignment(self):
v = Variable(torch.zeros(4, 4))
v[0, 1::2] = Variable(torch.Tensor([3, 4]))
self.assertEqual(v[0].data.tolist(), [0, 3, 0, 4])
self.assertEqual(v[1:].data.sum(), 0)
v = torch.zeros(4, 4)
v[0, 1::2] = torch.tensor([3., 4.])
self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
self.assertEqual(v[1:].sum(), 0)
def test_byte_mask(self):
v = Variable(torch.randn(5, 7, 3))
mask = Variable(torch.ByteTensor([1, 0, 1, 1, 0]))
v = torch.randn(5, 7, 3)
mask = torch.ByteTensor([1, 0, 1, 1, 0])
self.assertEqual(v[mask].shape, (3, 7, 3))
self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
v = Variable(torch.Tensor([1]))
self.assertEqual(v[v == 0], Variable(torch.Tensor()))
v = torch.tensor([1.])
self.assertEqual(v[v == 0], torch.tensor([]))
def test_multiple_byte_mask(self):
v = Variable(torch.randn(5, 7, 3))
v = torch.randn(5, 7, 3)
# note: these broadcast together and are transposed to the first dim
mask1 = Variable(torch.ByteTensor([1, 0, 1, 1, 0]))
mask2 = Variable(torch.ByteTensor([1, 1, 1]))
mask1 = torch.ByteTensor([1, 0, 1, 1, 0])
mask2 = torch.ByteTensor([1, 1, 1])
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
def test_byte_mask2d(self):
v = Variable(torch.randn(5, 7, 3))
c = Variable(torch.randn(5, 7))
num_ones = (c > 0).data.sum()
v = torch.randn(5, 7, 3)
c = torch.randn(5, 7)
num_ones = (c > 0).sum()
r = v[c > 0]
self.assertEqual(r.shape, (num_ones, 3))
def test_int_indices(self):
v = Variable(torch.randn(5, 7, 3))
v = torch.randn(5, 7, 3)
self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
def test_int_indices2d(self):
# From the NumPy indexing example
x = Variable(torch.arange(0, 12).view(4, 3))
rows = Variable(torch.LongTensor([[0, 0], [3, 3]]))
columns = Variable(torch.LongTensor([[0, 2], [0, 2]]))
self.assertEqual(x[rows, columns].data.tolist(), [[0, 2], [9, 11]])
x = torch.arange(0, 12).view(4, 3)
rows = torch.tensor([[0, 0], [3, 3]])
columns = torch.tensor([[0, 2], [0, 2]])
self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
def test_int_indices_broadcast(self):
# From the NumPy indexing example
x = Variable(torch.arange(0, 12).view(4, 3))
rows = Variable(torch.LongTensor([0, 3]))
columns = Variable(torch.LongTensor([0, 2]))
x = torch.arange(0, 12).view(4, 3)
rows = torch.tensor([0, 3])
columns = torch.tensor([0, 2])
result = x[rows[:, None], columns]
self.assertEqual(result.data.tolist(), [[0, 2], [9, 11]])
self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
def test_empty_index(self):
x = Variable(torch.arange(0, 12).view(4, 3))
idx = Variable(torch.LongTensor())
x = torch.arange(0, 12).view(4, 3)
idx = torch.tensor([], dtype=torch.long)
self.assertEqual(x[idx].numel(), 0)
# empty assignment should have no effect but not throw an exception
@ -98,7 +97,7 @@ class TestIndexing(TestCase):
true = torch.tensor(1, dtype=torch.uint8)
false = torch.tensor(0, dtype=torch.uint8)
tensors = [Variable(torch.randn(2, 3)), torch.tensor(3)]
tensors = [torch.randn(2, 3), torch.tensor(3)]
for a in tensors:
self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
@ -112,7 +111,7 @@ class TestIndexing(TestCase):
true = torch.tensor(1, dtype=torch.uint8)
false = torch.tensor(0, dtype=torch.uint8)
tensors = [Variable(torch.randn(2, 3)), torch.tensor(3)]
tensors = [torch.randn(2, 3), torch.tensor(3)]
for a in tensors:
# prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
@ -136,21 +135,21 @@ class TestIndexing(TestCase):
a[:] = neg_ones_expanded * 5
def test_setitem_expansion_error(self):
true = torch.tensor(1, dtype=torch.uint8)
a = Variable(torch.randn(2, 3))
true = torch.tensor(True)
a = torch.randn(2, 3)
# check prefix with non-1s doesn't work
a_expanded = a.expand(torch.Size([5, 1]) + a.size())
with self.assertRaises(RuntimeError):
a[True] = a_expanded
with self.assertRaises(RuntimeError):
a[true] = torch.autograd.Variable(a_expanded)
a[true] = a_expanded
def test_getitem_scalars(self):
zero = torch.tensor(0, dtype=torch.int64)
one = torch.tensor(1, dtype=torch.int64)
# non-scalar indexed with scalars
a = Variable(torch.randn(2, 3))
a = torch.randn(2, 3)
self.assertEqual(a[0], a[zero])
self.assertEqual(a[0][1], a[zero][one])
self.assertEqual(a[0, 1], a[zero, one])
@ -173,10 +172,10 @@ class TestIndexing(TestCase):
zero = torch.tensor(0, dtype=torch.int64)
# non-scalar indexed with scalars
a = Variable(torch.randn(2, 3))
a = torch.randn(2, 3)
a_set_with_number = a.clone()
a_set_with_scalar = a.clone()
b = Variable(torch.randn(3))
b = torch.randn(3)
a_set_with_number[0] = b
a_set_with_scalar[zero] = b
@ -195,9 +194,9 @@ class TestIndexing(TestCase):
def test_basic_advanced_combined(self):
# From the NumPy indexing example
x = Variable(torch.arange(0, 12).view(4, 3))
x = torch.arange(0, 12).view(4, 3)
self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
self.assertEqual(x[1:2, 1:3].data.tolist(), [[4, 5]])
self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
# Check that it is a copy
unmodified = x.clone()
@ -210,33 +209,33 @@ class TestIndexing(TestCase):
self.assertNotEqual(x, unmodified)
def test_int_assignment(self):
x = Variable(torch.arange(0, 4).view(2, 2))
x = torch.arange(0, 4).view(2, 2)
x[1] = 5
self.assertEqual(x.data.tolist(), [[0, 1], [5, 5]])
self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
x = Variable(torch.arange(0, 4).view(2, 2))
x[1] = Variable(torch.arange(5, 7))
self.assertEqual(x.data.tolist(), [[0, 1], [5, 6]])
x = torch.arange(0, 4).view(2, 2)
x[1] = torch.arange(5, 7)
self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
def test_byte_tensor_assignment(self):
x = Variable(torch.arange(0, 16).view(4, 4))
b = Variable(torch.ByteTensor([True, False, True, False]))
value = Variable(torch.Tensor([3, 4, 5, 6]))
x = torch.arange(0, 16).view(4, 4)
b = torch.ByteTensor([True, False, True, False])
value = torch.tensor([3., 4., 5., 6.])
x[b] = value
self.assertEqual(x[0], value)
self.assertEqual(x[1].data, torch.arange(4, 8))
self.assertEqual(x[1], torch.arange(4, 8))
self.assertEqual(x[2], value)
self.assertEqual(x[3].data, torch.arange(12, 16))
self.assertEqual(x[3], torch.arange(12, 16))
def test_variable_slicing(self):
x = Variable(torch.arange(0, 16).view(4, 4))
indices = Variable(torch.IntTensor([0, 1]))
x = torch.arange(0, 16).view(4, 4)
indices = torch.IntTensor([0, 1])
i, j = indices
self.assertEqual(x[i:j], x[0:1])
def test_ellipsis_tensor(self):
x = Variable(torch.arange(0, 9).view(3, 3))
idx = Variable(torch.LongTensor([0, 2]))
x = torch.arange(0, 9).view(3, 3)
idx = torch.tensor([0, 2])
self.assertEqual(x[..., idx].tolist(), [[0, 2],
[3, 5],
[6, 8]])
@ -244,7 +243,7 @@ class TestIndexing(TestCase):
[6, 7, 8]])
def test_invalid_index(self):
x = Variable(torch.arange(0, 16).view(4, 4))
x = torch.arange(0, 16).view(4, 4)
self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
def test_zero_dim_index(self):
@ -256,22 +255,6 @@ class TestIndexing(TestCase):
self.assertEqual(len(w), 1)
def tensor(*args, **kwargs):
return Variable(torch.Tensor(*args, **kwargs))
def byteTensor(data):
return Variable(torch.ByteTensor(data))
def ones(*args):
return Variable(torch.ones(*args))
def zeros(*args):
return Variable(torch.zeros(*args))
# The tests below are from NumPy test_indexing.py with some modifications to
# make them compatible with PyTorch. It's licensed under the BDS license below:
#
@ -309,7 +292,7 @@ def zeros(*args):
class NumpyTests(TestCase):
def test_index_no_floats(self):
a = Variable(torch.Tensor([[[5]]]))
a = torch.tensor([[[5.]]])
self.assertRaises(IndexError, lambda: a[0.0])
self.assertRaises(IndexError, lambda: a[0, 0.0])
@ -348,10 +331,10 @@ class NumpyTests(TestCase):
def test_empty_fancy_index(self):
# Empty list index creates an empty array
a = tensor([1, 2, 3])
self.assertEqual(a[[]], Variable(torch.Tensor()))
self.assertEqual(a[[]], torch.tensor([]))
b = tensor([]).long()
self.assertEqual(a[[]], Variable(torch.LongTensor()))
self.assertEqual(a[[]], torch.tensor([], dtype=torch.long))
b = tensor([]).float()
self.assertRaises(RuntimeError, lambda: a[b])
@ -386,8 +369,8 @@ class NumpyTests(TestCase):
[4, 5, 6],
[7, 8, 9]])
self.assertEqual(a[0].data, [1, 2, 3])
self.assertEqual(a[-1].data, [7, 8, 9])
self.assertEqual(a[0], [1, 2, 3])
self.assertEqual(a[-1], [7, 8, 9])
# Index out of bounds produces IndexError
self.assertRaises(IndexError, a.__getitem__, 1 << 30)
@ -404,16 +387,16 @@ class NumpyTests(TestCase):
self.assertEqual(a[False], a[None][0:0])
def test_boolean_shape_mismatch(self):
arr = ones((5, 4, 3))
arr = torch.ones((5, 4, 3))
# TODO: prefer IndexError
index = byteTensor([True])
index = tensor([True])
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
index = byteTensor([False] * 6)
index = tensor([False] * 6)
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
index = Variable(torch.ByteTensor(4, 4)).zero_()
index = torch.ByteTensor(4, 4).zero_()
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[(slice(None), index)])
@ -422,7 +405,7 @@ class NumpyTests(TestCase):
# Indexing a 2-dimensional array with
# boolean array of length one
a = tensor([[0., 0., 0.]])
b = byteTensor([True])
b = tensor([True])
self.assertEqual(a[b], a)
# boolean assignment
a[b] = 1.
@ -431,7 +414,7 @@ class NumpyTests(TestCase):
def test_boolean_assignment_value_mismatch(self):
# A boolean assignment should fail when the shape of the values
# cannot be broadcast to the subscription. (see also gh-3458)
a = Variable(torch.arange(0, 4))
a = torch.arange(0, 4)
def f(a, v):
a[a > -1] = tensor(v)
@ -446,9 +429,9 @@ class NumpyTests(TestCase):
a = tensor([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
b = byteTensor([[True, False, True],
[False, True, False],
[True, False, True]])
b = tensor([[True, False, True],
[False, True, False],
[True, False, True]])
self.assertEqual(a[b], tensor([1, 3, 5, 7, 9]))
self.assertEqual(a[b[1]], tensor([[4, 5, 6]]))
self.assertEqual(a[b[0]], a[b[2]])
@ -461,39 +444,39 @@ class NumpyTests(TestCase):
def test_everything_returns_views(self):
# Before `...` would return a itself.
a = tensor(5)
a = tensor([5])
self.assertIsNot(a, a[()])
self.assertIsNot(a, a[...])
self.assertIsNot(a, a[:])
def test_broaderrors_indexing(self):
a = zeros(5, 5)
a = torch.zeros(5, 5)
self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2]))
self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
def test_trivial_fancy_out_of_bounds(self):
a = zeros(5)
ind = ones(20).long()
a = torch.zeros(5)
ind = torch.ones(20, dtype=torch.int64)
ind[-1] = 10
self.assertRaises(RuntimeError, a.__getitem__, ind)
self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
ind = ones(20).long()
ind = torch.ones(20, dtype=torch.int64)
ind[0] = 11
self.assertRaises(RuntimeError, a.__getitem__, ind)
self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
def test_index_is_larger(self):
# Simple case of fancy index broadcasting of the index.
a = zeros((5, 5))
a[[[0], [1], [2]], [0, 1, 2]] = tensor([2, 3, 4])
a = torch.zeros((5, 5))
a[[[0], [1], [2]], [0, 1, 2]] = tensor([2., 3., 4.])
self.assertTrue((a[:3, :3] == tensor([2, 3, 4])).all())
self.assertTrue((a[:3, :3] == tensor([2., 3., 4.])).all())
def test_broadcast_subspace(self):
a = zeros((100, 100))
v = Variable(torch.arange(0, 100))[:, None]
b = Variable(torch.arange(99, -1, -1).long())
a = torch.zeros((100, 100))
v = torch.arange(0, 100)[:, None]
b = torch.arange(99, -1, -1).long()
a[b] = v
expected = b.double().unsqueeze(1).expand(100, 100)
self.assertEqual(a, expected)