mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 13:44:15 +08:00
Clean-up test_indexing.py after Tensor/Variable merge (#6433)
This commit is contained in:
@ -1,88 +1,87 @@
|
||||
from common import TestCase, run_tests
|
||||
import unittest
|
||||
import torch
|
||||
import warnings
|
||||
from torch.autograd import Variable
|
||||
from torch import tensor
|
||||
|
||||
|
||||
class TestIndexing(TestCase):
|
||||
def test_single_int(self):
|
||||
v = Variable(torch.randn(5, 7, 3))
|
||||
v = torch.randn(5, 7, 3)
|
||||
self.assertEqual(v[4].shape, (7, 3))
|
||||
|
||||
def test_multiple_int(self):
|
||||
v = Variable(torch.randn(5, 7, 3))
|
||||
v = torch.randn(5, 7, 3)
|
||||
self.assertEqual(v[4].shape, (7, 3))
|
||||
self.assertEqual(v[4, :, 1].shape, (7,))
|
||||
|
||||
def test_none(self):
|
||||
v = Variable(torch.randn(5, 7, 3))
|
||||
v = torch.randn(5, 7, 3)
|
||||
self.assertEqual(v[None].shape, (1, 5, 7, 3))
|
||||
self.assertEqual(v[:, None].shape, (5, 1, 7, 3))
|
||||
self.assertEqual(v[:, None, None].shape, (5, 1, 1, 7, 3))
|
||||
self.assertEqual(v[..., None].shape, (5, 7, 3, 1))
|
||||
|
||||
def test_step(self):
|
||||
v = Variable(torch.arange(10))
|
||||
v = torch.arange(10)
|
||||
self.assertEqual(v[::1], v)
|
||||
self.assertEqual(v[::2].data.tolist(), [0, 2, 4, 6, 8])
|
||||
self.assertEqual(v[::3].data.tolist(), [0, 3, 6, 9])
|
||||
self.assertEqual(v[::11].data.tolist(), [0])
|
||||
self.assertEqual(v[1:6:2].data.tolist(), [1, 3, 5])
|
||||
self.assertEqual(v[::2].tolist(), [0, 2, 4, 6, 8])
|
||||
self.assertEqual(v[::3].tolist(), [0, 3, 6, 9])
|
||||
self.assertEqual(v[::11].tolist(), [0])
|
||||
self.assertEqual(v[1:6:2].tolist(), [1, 3, 5])
|
||||
|
||||
def test_step_assignment(self):
|
||||
v = Variable(torch.zeros(4, 4))
|
||||
v[0, 1::2] = Variable(torch.Tensor([3, 4]))
|
||||
self.assertEqual(v[0].data.tolist(), [0, 3, 0, 4])
|
||||
self.assertEqual(v[1:].data.sum(), 0)
|
||||
v = torch.zeros(4, 4)
|
||||
v[0, 1::2] = torch.tensor([3., 4.])
|
||||
self.assertEqual(v[0].tolist(), [0, 3, 0, 4])
|
||||
self.assertEqual(v[1:].sum(), 0)
|
||||
|
||||
def test_byte_mask(self):
|
||||
v = Variable(torch.randn(5, 7, 3))
|
||||
mask = Variable(torch.ByteTensor([1, 0, 1, 1, 0]))
|
||||
v = torch.randn(5, 7, 3)
|
||||
mask = torch.ByteTensor([1, 0, 1, 1, 0])
|
||||
self.assertEqual(v[mask].shape, (3, 7, 3))
|
||||
self.assertEqual(v[mask], torch.stack([v[0], v[2], v[3]]))
|
||||
|
||||
v = Variable(torch.Tensor([1]))
|
||||
self.assertEqual(v[v == 0], Variable(torch.Tensor()))
|
||||
v = torch.tensor([1.])
|
||||
self.assertEqual(v[v == 0], torch.tensor([]))
|
||||
|
||||
def test_multiple_byte_mask(self):
|
||||
v = Variable(torch.randn(5, 7, 3))
|
||||
v = torch.randn(5, 7, 3)
|
||||
# note: these broadcast together and are transposed to the first dim
|
||||
mask1 = Variable(torch.ByteTensor([1, 0, 1, 1, 0]))
|
||||
mask2 = Variable(torch.ByteTensor([1, 1, 1]))
|
||||
mask1 = torch.ByteTensor([1, 0, 1, 1, 0])
|
||||
mask2 = torch.ByteTensor([1, 1, 1])
|
||||
self.assertEqual(v[mask1, :, mask2].shape, (3, 7))
|
||||
|
||||
def test_byte_mask2d(self):
|
||||
v = Variable(torch.randn(5, 7, 3))
|
||||
c = Variable(torch.randn(5, 7))
|
||||
num_ones = (c > 0).data.sum()
|
||||
v = torch.randn(5, 7, 3)
|
||||
c = torch.randn(5, 7)
|
||||
num_ones = (c > 0).sum()
|
||||
r = v[c > 0]
|
||||
self.assertEqual(r.shape, (num_ones, 3))
|
||||
|
||||
def test_int_indices(self):
|
||||
v = Variable(torch.randn(5, 7, 3))
|
||||
v = torch.randn(5, 7, 3)
|
||||
self.assertEqual(v[[0, 4, 2]].shape, (3, 7, 3))
|
||||
self.assertEqual(v[:, [0, 4, 2]].shape, (5, 3, 3))
|
||||
self.assertEqual(v[:, [[0, 1], [4, 3]]].shape, (5, 2, 2, 3))
|
||||
|
||||
def test_int_indices2d(self):
|
||||
# From the NumPy indexing example
|
||||
x = Variable(torch.arange(0, 12).view(4, 3))
|
||||
rows = Variable(torch.LongTensor([[0, 0], [3, 3]]))
|
||||
columns = Variable(torch.LongTensor([[0, 2], [0, 2]]))
|
||||
self.assertEqual(x[rows, columns].data.tolist(), [[0, 2], [9, 11]])
|
||||
x = torch.arange(0, 12).view(4, 3)
|
||||
rows = torch.tensor([[0, 0], [3, 3]])
|
||||
columns = torch.tensor([[0, 2], [0, 2]])
|
||||
self.assertEqual(x[rows, columns].tolist(), [[0, 2], [9, 11]])
|
||||
|
||||
def test_int_indices_broadcast(self):
|
||||
# From the NumPy indexing example
|
||||
x = Variable(torch.arange(0, 12).view(4, 3))
|
||||
rows = Variable(torch.LongTensor([0, 3]))
|
||||
columns = Variable(torch.LongTensor([0, 2]))
|
||||
x = torch.arange(0, 12).view(4, 3)
|
||||
rows = torch.tensor([0, 3])
|
||||
columns = torch.tensor([0, 2])
|
||||
result = x[rows[:, None], columns]
|
||||
self.assertEqual(result.data.tolist(), [[0, 2], [9, 11]])
|
||||
self.assertEqual(result.tolist(), [[0, 2], [9, 11]])
|
||||
|
||||
def test_empty_index(self):
|
||||
x = Variable(torch.arange(0, 12).view(4, 3))
|
||||
idx = Variable(torch.LongTensor())
|
||||
x = torch.arange(0, 12).view(4, 3)
|
||||
idx = torch.tensor([], dtype=torch.long)
|
||||
self.assertEqual(x[idx].numel(), 0)
|
||||
|
||||
# empty assignment should have no effect but not throw an exception
|
||||
@ -98,7 +97,7 @@ class TestIndexing(TestCase):
|
||||
true = torch.tensor(1, dtype=torch.uint8)
|
||||
false = torch.tensor(0, dtype=torch.uint8)
|
||||
|
||||
tensors = [Variable(torch.randn(2, 3)), torch.tensor(3)]
|
||||
tensors = [torch.randn(2, 3), torch.tensor(3)]
|
||||
|
||||
for a in tensors:
|
||||
self.assertNotEqual(a.data_ptr(), a[True].data_ptr())
|
||||
@ -112,7 +111,7 @@ class TestIndexing(TestCase):
|
||||
true = torch.tensor(1, dtype=torch.uint8)
|
||||
false = torch.tensor(0, dtype=torch.uint8)
|
||||
|
||||
tensors = [Variable(torch.randn(2, 3)), torch.tensor(3)]
|
||||
tensors = [torch.randn(2, 3), torch.tensor(3)]
|
||||
|
||||
for a in tensors:
|
||||
# prefix with a 1,1, to ensure we are compatible with numpy which cuts off prefix 1s
|
||||
@ -136,21 +135,21 @@ class TestIndexing(TestCase):
|
||||
a[:] = neg_ones_expanded * 5
|
||||
|
||||
def test_setitem_expansion_error(self):
|
||||
true = torch.tensor(1, dtype=torch.uint8)
|
||||
a = Variable(torch.randn(2, 3))
|
||||
true = torch.tensor(True)
|
||||
a = torch.randn(2, 3)
|
||||
# check prefix with non-1s doesn't work
|
||||
a_expanded = a.expand(torch.Size([5, 1]) + a.size())
|
||||
with self.assertRaises(RuntimeError):
|
||||
a[True] = a_expanded
|
||||
with self.assertRaises(RuntimeError):
|
||||
a[true] = torch.autograd.Variable(a_expanded)
|
||||
a[true] = a_expanded
|
||||
|
||||
def test_getitem_scalars(self):
|
||||
zero = torch.tensor(0, dtype=torch.int64)
|
||||
one = torch.tensor(1, dtype=torch.int64)
|
||||
|
||||
# non-scalar indexed with scalars
|
||||
a = Variable(torch.randn(2, 3))
|
||||
a = torch.randn(2, 3)
|
||||
self.assertEqual(a[0], a[zero])
|
||||
self.assertEqual(a[0][1], a[zero][one])
|
||||
self.assertEqual(a[0, 1], a[zero, one])
|
||||
@ -173,10 +172,10 @@ class TestIndexing(TestCase):
|
||||
zero = torch.tensor(0, dtype=torch.int64)
|
||||
|
||||
# non-scalar indexed with scalars
|
||||
a = Variable(torch.randn(2, 3))
|
||||
a = torch.randn(2, 3)
|
||||
a_set_with_number = a.clone()
|
||||
a_set_with_scalar = a.clone()
|
||||
b = Variable(torch.randn(3))
|
||||
b = torch.randn(3)
|
||||
|
||||
a_set_with_number[0] = b
|
||||
a_set_with_scalar[zero] = b
|
||||
@ -195,9 +194,9 @@ class TestIndexing(TestCase):
|
||||
|
||||
def test_basic_advanced_combined(self):
|
||||
# From the NumPy indexing example
|
||||
x = Variable(torch.arange(0, 12).view(4, 3))
|
||||
x = torch.arange(0, 12).view(4, 3)
|
||||
self.assertEqual(x[1:2, 1:3], x[1:2, [1, 2]])
|
||||
self.assertEqual(x[1:2, 1:3].data.tolist(), [[4, 5]])
|
||||
self.assertEqual(x[1:2, 1:3].tolist(), [[4, 5]])
|
||||
|
||||
# Check that it is a copy
|
||||
unmodified = x.clone()
|
||||
@ -210,33 +209,33 @@ class TestIndexing(TestCase):
|
||||
self.assertNotEqual(x, unmodified)
|
||||
|
||||
def test_int_assignment(self):
|
||||
x = Variable(torch.arange(0, 4).view(2, 2))
|
||||
x = torch.arange(0, 4).view(2, 2)
|
||||
x[1] = 5
|
||||
self.assertEqual(x.data.tolist(), [[0, 1], [5, 5]])
|
||||
self.assertEqual(x.tolist(), [[0, 1], [5, 5]])
|
||||
|
||||
x = Variable(torch.arange(0, 4).view(2, 2))
|
||||
x[1] = Variable(torch.arange(5, 7))
|
||||
self.assertEqual(x.data.tolist(), [[0, 1], [5, 6]])
|
||||
x = torch.arange(0, 4).view(2, 2)
|
||||
x[1] = torch.arange(5, 7)
|
||||
self.assertEqual(x.tolist(), [[0, 1], [5, 6]])
|
||||
|
||||
def test_byte_tensor_assignment(self):
|
||||
x = Variable(torch.arange(0, 16).view(4, 4))
|
||||
b = Variable(torch.ByteTensor([True, False, True, False]))
|
||||
value = Variable(torch.Tensor([3, 4, 5, 6]))
|
||||
x = torch.arange(0, 16).view(4, 4)
|
||||
b = torch.ByteTensor([True, False, True, False])
|
||||
value = torch.tensor([3., 4., 5., 6.])
|
||||
x[b] = value
|
||||
self.assertEqual(x[0], value)
|
||||
self.assertEqual(x[1].data, torch.arange(4, 8))
|
||||
self.assertEqual(x[1], torch.arange(4, 8))
|
||||
self.assertEqual(x[2], value)
|
||||
self.assertEqual(x[3].data, torch.arange(12, 16))
|
||||
self.assertEqual(x[3], torch.arange(12, 16))
|
||||
|
||||
def test_variable_slicing(self):
|
||||
x = Variable(torch.arange(0, 16).view(4, 4))
|
||||
indices = Variable(torch.IntTensor([0, 1]))
|
||||
x = torch.arange(0, 16).view(4, 4)
|
||||
indices = torch.IntTensor([0, 1])
|
||||
i, j = indices
|
||||
self.assertEqual(x[i:j], x[0:1])
|
||||
|
||||
def test_ellipsis_tensor(self):
|
||||
x = Variable(torch.arange(0, 9).view(3, 3))
|
||||
idx = Variable(torch.LongTensor([0, 2]))
|
||||
x = torch.arange(0, 9).view(3, 3)
|
||||
idx = torch.tensor([0, 2])
|
||||
self.assertEqual(x[..., idx].tolist(), [[0, 2],
|
||||
[3, 5],
|
||||
[6, 8]])
|
||||
@ -244,7 +243,7 @@ class TestIndexing(TestCase):
|
||||
[6, 7, 8]])
|
||||
|
||||
def test_invalid_index(self):
|
||||
x = Variable(torch.arange(0, 16).view(4, 4))
|
||||
x = torch.arange(0, 16).view(4, 4)
|
||||
self.assertRaisesRegex(TypeError, 'slice indices', lambda: x["0":"1"])
|
||||
|
||||
def test_zero_dim_index(self):
|
||||
@ -256,22 +255,6 @@ class TestIndexing(TestCase):
|
||||
self.assertEqual(len(w), 1)
|
||||
|
||||
|
||||
def tensor(*args, **kwargs):
|
||||
return Variable(torch.Tensor(*args, **kwargs))
|
||||
|
||||
|
||||
def byteTensor(data):
|
||||
return Variable(torch.ByteTensor(data))
|
||||
|
||||
|
||||
def ones(*args):
|
||||
return Variable(torch.ones(*args))
|
||||
|
||||
|
||||
def zeros(*args):
|
||||
return Variable(torch.zeros(*args))
|
||||
|
||||
|
||||
# The tests below are from NumPy test_indexing.py with some modifications to
|
||||
# make them compatible with PyTorch. It's licensed under the BDS license below:
|
||||
#
|
||||
@ -309,7 +292,7 @@ def zeros(*args):
|
||||
|
||||
class NumpyTests(TestCase):
|
||||
def test_index_no_floats(self):
|
||||
a = Variable(torch.Tensor([[[5]]]))
|
||||
a = torch.tensor([[[5.]]])
|
||||
|
||||
self.assertRaises(IndexError, lambda: a[0.0])
|
||||
self.assertRaises(IndexError, lambda: a[0, 0.0])
|
||||
@ -348,10 +331,10 @@ class NumpyTests(TestCase):
|
||||
def test_empty_fancy_index(self):
|
||||
# Empty list index creates an empty array
|
||||
a = tensor([1, 2, 3])
|
||||
self.assertEqual(a[[]], Variable(torch.Tensor()))
|
||||
self.assertEqual(a[[]], torch.tensor([]))
|
||||
|
||||
b = tensor([]).long()
|
||||
self.assertEqual(a[[]], Variable(torch.LongTensor()))
|
||||
self.assertEqual(a[[]], torch.tensor([], dtype=torch.long))
|
||||
|
||||
b = tensor([]).float()
|
||||
self.assertRaises(RuntimeError, lambda: a[b])
|
||||
@ -386,8 +369,8 @@ class NumpyTests(TestCase):
|
||||
[4, 5, 6],
|
||||
[7, 8, 9]])
|
||||
|
||||
self.assertEqual(a[0].data, [1, 2, 3])
|
||||
self.assertEqual(a[-1].data, [7, 8, 9])
|
||||
self.assertEqual(a[0], [1, 2, 3])
|
||||
self.assertEqual(a[-1], [7, 8, 9])
|
||||
|
||||
# Index out of bounds produces IndexError
|
||||
self.assertRaises(IndexError, a.__getitem__, 1 << 30)
|
||||
@ -404,16 +387,16 @@ class NumpyTests(TestCase):
|
||||
self.assertEqual(a[False], a[None][0:0])
|
||||
|
||||
def test_boolean_shape_mismatch(self):
|
||||
arr = ones((5, 4, 3))
|
||||
arr = torch.ones((5, 4, 3))
|
||||
|
||||
# TODO: prefer IndexError
|
||||
index = byteTensor([True])
|
||||
index = tensor([True])
|
||||
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
|
||||
|
||||
index = byteTensor([False] * 6)
|
||||
index = tensor([False] * 6)
|
||||
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
|
||||
|
||||
index = Variable(torch.ByteTensor(4, 4)).zero_()
|
||||
index = torch.ByteTensor(4, 4).zero_()
|
||||
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[index])
|
||||
|
||||
self.assertRaisesRegex(RuntimeError, 'mask', lambda: arr[(slice(None), index)])
|
||||
@ -422,7 +405,7 @@ class NumpyTests(TestCase):
|
||||
# Indexing a 2-dimensional array with
|
||||
# boolean array of length one
|
||||
a = tensor([[0., 0., 0.]])
|
||||
b = byteTensor([True])
|
||||
b = tensor([True])
|
||||
self.assertEqual(a[b], a)
|
||||
# boolean assignment
|
||||
a[b] = 1.
|
||||
@ -431,7 +414,7 @@ class NumpyTests(TestCase):
|
||||
def test_boolean_assignment_value_mismatch(self):
|
||||
# A boolean assignment should fail when the shape of the values
|
||||
# cannot be broadcast to the subscription. (see also gh-3458)
|
||||
a = Variable(torch.arange(0, 4))
|
||||
a = torch.arange(0, 4)
|
||||
|
||||
def f(a, v):
|
||||
a[a > -1] = tensor(v)
|
||||
@ -446,9 +429,9 @@ class NumpyTests(TestCase):
|
||||
a = tensor([[1, 2, 3],
|
||||
[4, 5, 6],
|
||||
[7, 8, 9]])
|
||||
b = byteTensor([[True, False, True],
|
||||
[False, True, False],
|
||||
[True, False, True]])
|
||||
b = tensor([[True, False, True],
|
||||
[False, True, False],
|
||||
[True, False, True]])
|
||||
self.assertEqual(a[b], tensor([1, 3, 5, 7, 9]))
|
||||
self.assertEqual(a[b[1]], tensor([[4, 5, 6]]))
|
||||
self.assertEqual(a[b[0]], a[b[2]])
|
||||
@ -461,39 +444,39 @@ class NumpyTests(TestCase):
|
||||
|
||||
def test_everything_returns_views(self):
|
||||
# Before `...` would return a itself.
|
||||
a = tensor(5)
|
||||
a = tensor([5])
|
||||
|
||||
self.assertIsNot(a, a[()])
|
||||
self.assertIsNot(a, a[...])
|
||||
self.assertIsNot(a, a[:])
|
||||
|
||||
def test_broaderrors_indexing(self):
|
||||
a = zeros(5, 5)
|
||||
a = torch.zeros(5, 5)
|
||||
self.assertRaisesRegex(RuntimeError, 'match the size', a.__getitem__, ([0, 1], [0, 1, 2]))
|
||||
self.assertRaisesRegex(RuntimeError, 'match the size', a.__setitem__, ([0, 1], [0, 1, 2]), 0)
|
||||
|
||||
def test_trivial_fancy_out_of_bounds(self):
|
||||
a = zeros(5)
|
||||
ind = ones(20).long()
|
||||
a = torch.zeros(5)
|
||||
ind = torch.ones(20, dtype=torch.int64)
|
||||
ind[-1] = 10
|
||||
self.assertRaises(RuntimeError, a.__getitem__, ind)
|
||||
self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
|
||||
ind = ones(20).long()
|
||||
ind = torch.ones(20, dtype=torch.int64)
|
||||
ind[0] = 11
|
||||
self.assertRaises(RuntimeError, a.__getitem__, ind)
|
||||
self.assertRaises(RuntimeError, a.__setitem__, ind, 0)
|
||||
|
||||
def test_index_is_larger(self):
|
||||
# Simple case of fancy index broadcasting of the index.
|
||||
a = zeros((5, 5))
|
||||
a[[[0], [1], [2]], [0, 1, 2]] = tensor([2, 3, 4])
|
||||
a = torch.zeros((5, 5))
|
||||
a[[[0], [1], [2]], [0, 1, 2]] = tensor([2., 3., 4.])
|
||||
|
||||
self.assertTrue((a[:3, :3] == tensor([2, 3, 4])).all())
|
||||
self.assertTrue((a[:3, :3] == tensor([2., 3., 4.])).all())
|
||||
|
||||
def test_broadcast_subspace(self):
|
||||
a = zeros((100, 100))
|
||||
v = Variable(torch.arange(0, 100))[:, None]
|
||||
b = Variable(torch.arange(99, -1, -1).long())
|
||||
a = torch.zeros((100, 100))
|
||||
v = torch.arange(0, 100)[:, None]
|
||||
b = torch.arange(99, -1, -1).long()
|
||||
a[b] = v
|
||||
expected = b.double().unsqueeze(1).expand(100, 100)
|
||||
self.assertEqual(a, expected)
|
||||
|
Reference in New Issue
Block a user