[BE] Move indexing tests to test_indexing (#160994)

Which enables them on MPS device
- xfail all `test_index_reduce` on MPS, as op is not implemented
- xfail all `test_index_copy` on MPS due to the silent correctness problems, see https://github.com/pytorch/pytorch/issues/160993
- Fixed hard crash in `index_fill` and replaced `skipIfMPS` with `expectedFailueMPS`
- Created issue for the lack of deterministic algorithms for MPS backend
Pull Request resolved: https://github.com/pytorch/pytorch/pull/160994
Approved by: https://github.com/manuelcandales
ghstack dependencies: #160850, #160889, #160926
This commit is contained in:
Nikita Shulga
2025-08-20 13:54:29 -07:00
committed by PyTorch MergeBot
parent 667245dc60
commit 3e3e83418d
3 changed files with 341 additions and 262 deletions

View File

@ -3448,267 +3448,6 @@ else:
actual = torch.narrow_copy(inp, 1, 0, 10)
self.assertEqual(expected, actual)
# FIXME: move to indexing test suite
@parametrize("reduce", ['prod', 'amin', 'amax', 'mean'])
@dtypes(*all_types_and(torch.half, torch.bfloat16))
def test_index_reduce(self, device, dtype, reduce):
size = (3, 4, 5)
index_dtypes = [torch.int, torch.long]
include_selfs = [True, False]
amin_init = float('inf') if dtype.is_floating_point else torch.iinfo(dtype).max
amax_init = -float('inf') if dtype.is_floating_point else torch.iinfo(dtype).min
reduction_init = {'prod': 1, 'mean': 0, 'amin': amin_init, 'amax': amax_init}
for dest_noncontig, src_noncontig, index_noncontig in product([True, False], repeat=3):
for idx_dtype, include_self in product(index_dtypes, include_selfs):
for dim in range(len(size)):
num_src = np.random.randint(10)
num_dest = size[dim]
dest = make_tensor(size, device=device, dtype=dtype, noncontiguous=dest_noncontig)
src_size = size[:dim] + (num_src,) + size[dim + 1:]
src = make_tensor(src_size, device=device, dtype=dtype, noncontiguous=src_noncontig)
idx = torch.testing.make_tensor(
num_src, low=0, high=num_dest, dtype=idx_dtype, device=device, noncontiguous=index_noncontig
)
expected = dest.clone()
dest.index_reduce_(dim, idx, src, reduce, include_self=include_self)
# fill rows in idx with reduction inits if include_self=False
if (not include_self):
expected.index_fill_(dim, idx.long(), reduction_init[reduce])
expected = expected.transpose(0, dim)
src = src.transpose(0, dim)
for i in range(num_src):
if reduce == 'prod':
expected[idx[i]] *= src[i]
elif reduce == 'amin':
torch.minimum(expected[idx[i]], src[i], out=expected[idx[i]])
elif reduce == 'amax':
torch.maximum(expected[idx[i]], src[i], out=expected[idx[i]])
else:
expected[idx[i]] += src[i]
if reduce == 'mean':
counts = torch.ones_like(expected) if include_self else torch.zeros_like(expected)
counts.index_add_(0, idx, torch.ones_like(src))
counts.masked_fill_(counts == 0, 1)
if (dtype.is_floating_point):
expected.div_(counts)
else:
expected.div_(counts, rounding_mode="floor")
expected = expected.transpose(0, dim)
self.assertEqual(dest, expected)
# FIXME: move to test indexing
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_index_copy(self, device, dtype):
# We just test for num_copy <= num_dest, as otherwise there are repeated indices
# and the behavior is undefined
num_copy, num_dest = 3, 5
def make_arg(batch_sizes, n, dim, contig):
size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
def ref_index_copy(tgt, dim, idx, src):
for i in range(idx.size(0)):
idx_dest = dim * (slice(None),) + (idx[i],)
idx_src = dim * (slice(None),) + (i,)
tgt[idx_dest] = src[idx_src]
# More thorough testing as in index_add
for dest_contig, src_contig, index_contig in product([True, False], repeat=3):
for other_sizes in ((), (4, 5)):
for dim in range(len(other_sizes)):
dest = make_arg(other_sizes, num_dest, dim, dest_contig)
src = make_arg(other_sizes, num_copy, dim, src_contig)
idx = torch.randperm(num_dest, dtype=torch.int64, device=device)[:num_copy]
if not index_contig:
idx = torch.repeat_interleave(idx, 2, dim=-1)
idx = idx[..., ::2]
dest2 = dest.clone()
dest.index_copy_(dim, idx, src)
ref_index_copy(dest2, dim, idx, src)
self.assertEqual(dest, dest2)
# FIXME: move to test indexing
# onlyNativeDeviceTypes due to an XLA error:
# https://github.com/pytorch/pytorch/issues/53256
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
def test_index_copy_scalars(self, device, dtype):
# Create the 8 possible combinations of scalar sizes for target / index / source
scalars = ((make_tensor(size_t, dtype=dtype, device=device, low=None, high=None),
make_tensor(size_i, dtype=torch.int64, device=device, low=0, high=1),
make_tensor(size_s, dtype=dtype, device=device, low=None, high=None))
for size_t, size_i, size_s in product([(), (1,)], repeat=3))
for target, idx, source in scalars:
target.index_copy_(0, idx, source)
self.assertEqual(target.item(), source.item())
# FIXME: move to test indexing
@onlyCPU
def test_errors_index_copy(self, device):
# We do not test the GPU as the CUDA_ASSERT would break the CUDA context
idx_dim = 8
tgt_dim = 5
batch_dim = 3
# Too large of an index
a = torch.randn(batch_dim, tgt_dim, device=device)
idx = torch.full((idx_dim,), tgt_dim, device=device)
c = torch.zeros(batch_dim, idx_dim, device=device)
with self.assertRaises(IndexError):
a.index_copy_(1, idx, c)
# Too small (negative indices)
idx = torch.full((idx_dim,), -1, device=device)
with self.assertRaises(IndexError):
a.index_copy_(1, idx, c)
# Too small (very negative indices) - they should be unsupported even
# when support for negative indices is implemented for index_copy_
idx = torch.full((idx_dim,), -tgt_dim - 1, device=device)
with self.assertRaises(IndexError):
a.index_copy_(1, idx, c)
def _prepare_data_for_index_copy_and_add_deterministic(
self, dim: int, device: torch.device
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert (dim >= 0 and dim < 3)
a = [5, 4, 3]
a[dim] = 2000
x = torch.zeros(a, device=device)
b = a.copy()
elems = a[dim] * 20
b[dim] = elems
src = torch.rand(b, device=device)
index = torch.randint(a[dim], (elems,), device=device)
return (x, index, src)
# FIXME: move to test indexing
@onlyNativeDeviceTypes
def test_index_copy_deterministic(self, device: torch.device) -> None:
for dim in range(3):
x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
with DeterministicGuard(True):
y0 = torch.index_copy(x, dim, index, src)
x0 = x.detach().clone()
index_list = index.tolist()
for i in range(len(index_list)):
if dim == 0:
x0[index_list[i], :, :] = src[i, :, :]
elif dim == 1:
x0[:, index_list[i], :] = src[:, i, :]
elif dim == 2:
x0[:, :, index_list[i]] = src[:, :, i]
self.assertEqual(x0, y0, atol=0, rtol=0)
# FIXME: move to test indexing
@onlyNativeDeviceTypes
def test_index_add_deterministic(self, device: torch.device) -> None:
for dim in range(3):
x, index, src = self._prepare_data_for_index_copy_and_add_deterministic(dim, device)
alpha = random.random() + 1
# on CPU it should be deterministic regardless of the deterministic mode
with DeterministicGuard(True):
y0 = torch.index_add(x, dim, index, src, alpha=alpha)
for _ in range(3):
y = torch.index_add(x, dim, index, src, alpha=alpha)
self.assertEqual(y, y0, atol=0, rtol=0)
with DeterministicGuard(False):
for _ in range(3):
y_nd = torch.index_add(x, dim, index, src, alpha=alpha)
self.assertEqual(y_nd, y0, atol=1e-3, rtol=1e-5)
# FIXME: find a test suite for the put operator
@onlyNativeDeviceTypes
def test_index_put_non_accumulate_deterministic(self, device) -> None:
with DeterministicGuard(True):
for i in range(3):
m = random.randint(10, 20)
elems = random.randint(20000, 30000)
values = torch.rand(elems, device=device)
indices = torch.randint(m, (elems,), device=device)
input = torch.rand(m, device=device)
output = input.index_put((indices,), values, accumulate=False)
input_list = input.tolist()
indices_list = indices.tolist()
values_list = values.tolist()
for i, v in zip(indices_list, values_list):
input_list[i] = v
self.assertEqual(output, input_list)
# FIXME: move to test indexing
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@skipIfMPS
def test_index_fill(self, device, dtype):
x = torch.tensor([[1, 2], [4, 5]], dtype=dtype, device=device)
index = torch.tensor([0], device=device)
x.index_fill_(1, index, 0)
self.assertEqual(x, torch.tensor([[0, 2], [0, 5]], dtype=dtype, device=device))
if not x.is_complex() and not device == "meta":
with self.assertRaisesRegex(RuntimeError, r"Scalar"):
x.index_fill_(1, index, 1 + 1j)
# Make sure that the result stays 0-dim while applied to
# a 0-dim input
x = torch.tensor(1, dtype=dtype, device=device)
self.assertEqual(0, x.index_fill(0, index, -1).dim())
self.assertEqual(0, x.index_fill_(0, index, -1).dim())
# FIXME: move to test indexing
# The test fails for zero-dimensional tensors on XLA
@onlyNativeDeviceTypes
@dtypes(*all_types_complex_float8_and(torch.half, torch.bool, torch.bfloat16))
def test_index_select(self, device, dtype):
num_src, num_out = 3, 5
def make_arg(batch_sizes, n, dim, contig):
size_arg = batch_sizes[:dim] + (n,) + batch_sizes[dim:]
return make_tensor(size_arg, dtype=dtype, device=device, low=None, high=None, noncontiguous=not contig)
def ref_index_select(src, dim, idx):
# some types not supported on numpy
not_np_dtypes = (torch.bfloat16, torch.float8_e5m2, torch.float8_e5m2fnuz, torch.float8_e4m3fn, torch.float8_e4m3fnuz)
if dtype in not_np_dtypes:
src = src.float()
out = torch.from_numpy(np.take(src.cpu().numpy(), idx.cpu().numpy(), axis=dim))
if dtype in not_np_dtypes:
out = out.to(device=device, dtype=dtype)
return out
for src_contig, idx_contig in product([True, False], repeat=2):
for other_sizes in ((), (4, 5)):
for dim in range(len(other_sizes)):
src = make_arg(other_sizes, num_src, dim, src_contig)
idx = make_tensor(
(num_out,), dtype=torch.int64, device=device, low=0, high=num_src, noncontiguous=not idx_contig
)
out = torch.index_select(src, dim, idx)
out2 = ref_index_select(src, dim, idx)
self.assertEqual(out, out2)
for idx_type in (torch.int32, torch.int64):
other_sizes = (3, 2)
dim = 1
src = make_arg(other_sizes, num_src, dim, True)
idx = make_tensor((num_out,), dtype=idx_type, device=device, low=0, high=num_src, noncontiguous=False)
out = torch.index_select(src, dim, idx)
out2 = ref_index_select(src, dim, idx)
self.assertEqual(out, out2)
# Create the 4 possible combinations of scalar sizes for index / source
scalars = ((make_tensor(size_s, dtype=dtype, device=device),
torch.zeros(size_i, dtype=torch.int64, device=device))
for size_s, size_i in product([(), (1,)], repeat=2))
for source, idx in scalars:
out = source.index_select(0, idx)
self.assertEqual(out.item(), source.item())
# FIXME: find a test suite for the take operator
@dtypes(*all_types_and_complex_and(torch.half, torch.bool, torch.bfloat16))
@slowTestIf(IS_WINDOWS)