pytorch/test/test_scatter_gather_ops.py
Natalia Gimelshein 541f8cd34c faster gather implementation (#151490)
So far it's only for `gather`, but we'll move index_select and index to this implementation too. Torchtitan and fbgemm have noticed that gather/index_select perf is bad; this PR brings the core implementation on par with those customized implementations. Added benefits: all dtypes are supported, and it is a bit less strict on tensor dimensions/contiguity because we pick the fast path after TensorIterator has collapsed the dimensions.

The biggest part of this PR is not even the kernel (it's dumb; plain vectorized loads are enough), but moving the utilities for vectorized loads and stores out of SymmetricMemory so they are generally accessible in MemoryAccess.cuh.
Additional tests are coming to make sure this implementation doesn't break anything.
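
To illustrate the contiguity/alignment constraints mentioned above, here is a minimal sketch of the kind of eligibility check a vectorized kernel performs. This is not code from the PR; `can_vectorize` and the 16-byte vector width are illustrative assumptions:

```python
import torch

# Hypothetical helper, not a PyTorch API: a kernel can only use wide
# (e.g. 16-byte) loads if the innermost dimension is contiguous and the
# data pointer is aligned to the vector width.
def can_vectorize(t: torch.Tensor, vec_bytes: int = 16) -> bool:
    return t.stride(-1) == 1 and t.data_ptr() % vec_bytes == 0
```

The `test_gather_large` test in the file below constructs both a discontiguous and a misaligned tensor precisely to exercise the fallback taken when such a check fails.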

`gather` is equivalent to `x[indices]` for 1d indices via:
```python
def fn_gather(x, indices):
    return torch.gather(x, dim=0, index=indices.unsqueeze(1).expand(-1, x.shape[1]))

def fn_index(x, indices):
    return x[indices]
```
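
For a quick sanity check of the speedup, a minimal timing sketch (not part of the PR; assumes a CUDA device and the two functions defined above):

```python
import torch
from torch.utils.benchmark import Timer

x = torch.randn(4096, 3072, device="cuda")
indices = torch.randint(0, x.shape[0], (4096,), device="cuda")

# Timer handles warmup and CUDA synchronization internally.
for fn in (fn_gather, fn_index):
    t = Timer(stmt="fn(x, indices)",
              globals={"fn": fn, "x": x, "indices": indices})
    print(fn.__name__, t.blocked_autorange().median)  # seconds per call
```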

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151490
Approved by: https://github.com/Skylion007, https://github.com/eqy
2025-04-18 07:48:31 +00:00

# Owner(s): ["module: scatter & gather ops"]

import random

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
    (parametrize, run_tests, TestCase, DeterministicGuard)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA,
     toleranceOverride, tol,)
from torch.testing._internal.common_dtype import \
    (get_all_dtypes,)

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32

# Note: test_scatter_gather_ops.py
# This test file tests scatter and gather operations,
# like torch.scatter and torch.gather.

class TestScatterGather(TestCase):
    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o, unique_indices=True):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    if unique_indices:
                        idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
                    else:
                        idx[tuple(ii)] = torch.randint(dim_size, (elems_per_row,))

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if not dtype.is_complex:
            src = make_tensor((3, 4, 5), device=device, dtype=dtype)
            expected, idx = src.max(2, True)
            actual = torch.gather(src, 2, idx)
            self.assertEqual(actual, expected, atol=0, rtol=0)

    @dtypes(torch.int8, torch.bfloat16)
    def test_gather_large(self, device, dtype):
        # test larger shapes to check vectorized implementation
        for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)):
            src = make_tensor((m, k), device=device, dtype=dtype)
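            # Note (added): `discontig` holds the same values as `src` but with
            # stride 2 along dim 1, and `misaligned` starts one element into its
            # storage, so its data pointer is not vector-width aligned. Both
            # should fall off the vectorized fast path yet match the reference.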
            alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype)
            discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src)
            alloc1 = torch.empty(src.nelement() + 1, device=device, dtype=dtype)
            misaligned = alloc1[1:].view(m, k).copy_(src)
            num_ind = n
            for dim in (0, 1):
                max_ind = src.shape[dim]
                ind0 = torch.randint(max_ind, (num_ind,), device=device)
                shape_ind = [1] * src.ndim
                shape_ind[dim] = ind0.shape[0]
                shape_out = list(src.shape)
                shape_out[dim] = ind0.shape[0]
                ind = ind0.view(shape_ind).expand(shape_out)
                res = torch.gather(src, dim=dim, index=ind)
                ref = src[ind0] if dim == 0 else src[:, ind0]
                self.assertEqual(res, ref, atol=0, rtol=0)
                res = torch.gather(discontig, dim=dim, index=ind)
                self.assertEqual(res, ref, atol=0, rtol=0)
                res = torch.gather(misaligned, dim=dim, index=ind)
                self.assertEqual(res, ref, atol=0, rtol=0)

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        actual = torch.gather(src, 1, idx)
        expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)
        self.assertEqual(actual, expected, atol=0, rtol=0)

    @parametrize("sparse_grad", [False, True])
    @dtypes(torch.float32, torch.float64)
    def test_gather_backward_with_empty_index_tensor(self, device, dtype, sparse_grad):
        dim = -1
        input = torch.rand([10, 5], dtype=dtype, device=device, requires_grad=True)
        index = torch.randint(0, 2, [3, 0], dtype=torch.int64, device=device)
        res = torch.gather(input, dim, index, sparse_grad=sparse_grad)
        res.sum().backward()
        grad = input.grad.to_dense() if sparse_grad else input.grad
        expected_grad = torch.zeros_like(input, requires_grad=False)
        self.assertEqual(grad, expected_grad, atol=0, rtol=0)

    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction,
                           unique_indices=True, include_self=True):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o, unique_indices)

        if is_scalar:
            src = random.random()
        else:
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)
        if reduction is not None:
            if fn is torch.Tensor.scatter_reduce_:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction, include_self=include_self)
            else:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction)
        else:
            actual = fn(base.clone(), dim, idx, src)

        expected = base.clone()
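        # Note (added): counts tracks how many values land in each destination
        # element; seeding it with include_self (1 or 0) gives the "mean"
        # reduction the right denominator, and zero entries mark destinations
        # whose first write should overwrite rather than reduce.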
        counts = torch.zeros(base.shape, dtype=torch.long, device=device) + include_self
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    if fn is torch.Tensor.scatter_add_:
                        expected[tuple(ii)] += src[i, j, k]
                    else:
                        # method may be 'scatter_', 'scatter', 'scatter_reduce'
                        # or 'scatter_reduce_'; the former two might have a reduction
                        # argument, while the latter two always do
                        value = src if is_scalar else src[i, j, k]
                        if ((not include_self) and counts[tuple(ii)] == 0):
                            expected[tuple(ii)] = value
                        else:
                            if reduction == "add" or reduction == "sum":
                                expected[tuple(ii)] += value
                            elif reduction == "multiply" or reduction == "prod":
                                expected[tuple(ii)] *= value
                            elif reduction == "amax":
                                expected[tuple(ii)] = max(expected[tuple(ii)], value)
                            elif reduction == "amin":
                                expected[tuple(ii)] = min(expected[tuple(ii)], value)
                            elif reduction == "mean":
                                expected[tuple(ii)] += value
                            else:
                                expected[tuple(ii)] = value
                        counts[tuple(ii)] += 1

        if (reduction == "mean"):
            counts.masked_fill_(counts == 0, 1)
            if (dtype.is_floating_point or dtype.is_complex):
                expected /= counts
            else:
                expected.div_(counts, rounding_mode="floor")

        if dtype == torch.float16 or dtype == torch.bfloat16:
            # Some CUDA kernels (e.g. indexing_backward_kernel_stride_1) that are
            # called during the test use fp32 for internal accumulation for improved
            # accuracy, so small differences can appear with 16-bit precision types.
            self.assertEqual(actual, expected, atol=0.04, rtol=0.05)
        else:
            self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index
        dst = make_tensor((2, 2), device=device, dtype=dtype)
        idx = torch.tensor((), device=device, dtype=torch.long)
        src = make_tensor((2, 2), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(dst, 0, idx, src, reduce=reduction)
        else:
            actual = fn(dst, 0, idx, src)
        self.assertEqual(actual, dst, atol=0, rtol=0)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=True, reduction=None)

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        for reduction in ("add", "multiply"):
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=False, reduction=reduction)
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=True, reduction=reduction)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                m, n = 30, 40
                idx = torch.zeros(m, n, device=device, dtype=torch.long)
                src = torch.ones(m, n, device=device, dtype=dtype)
                res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
                res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)
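                # Note (added): every index is 0, so all m rows (n columns) of
                # ones accumulate into row 0 of res0 (column 0 of res1).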
                self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
                self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)

    # FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    def test_scatter_reduce_sum(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='sum', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_prod(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='prod', unique_indices=False,
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_mean(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='mean', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amax(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amax', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -float('inf'), -float('inf'), 2, float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
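                # Note (added): NaN wins every 'amax' comparison, so destination
                # 0 becomes NaN; when include_self=True the initial 0 beats the
                # two -inf values at destination 1.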
                input.scatter_reduce_(0, idx, src, 'amax', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[1] = 0
                self.assertEqual(input, expected_result)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amin(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amin', unique_indices=False,
                                    include_self=include_self)
            # simple test for nan/inf propagation
            if (dtype.is_floating_point):
                input = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -2, -float('inf'), float('inf'), float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
                input.scatter_reduce_(0, idx, src, 'amin', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if (include_self):
                    expected_result[2] = 0
                self.assertEqual(input, expected_result)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    def test_scatter_expanded_index(self, device, dtype):
        def helper(input_size, idx_size):
            input = torch.randn(input_size, device=device).to(dtype=dtype)
            input2 = input.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # The fast path on scatter when the index is expanded
            # depends on a sorted index, where the collected src indices
            # for each row of input are mapped to row pointers in a CSR format.
            # Create some empty rows by masking:
            mask = (idx > 1) * (idx < 4)
            idx[mask] = 0

            expanded_shape = input_size
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()
            src = torch.randn(expanded_shape, device=device).to(dtype=dtype)

            out = input.scatter_add(0, idx, src)
            out2 = input2.scatter_add(0, idx2, src)
            self.assertEqual(out, out2)

            for reduce in ["sum", "prod", "mean", "amax", "amin"]:
                for include_self in [True, False]:
                    out = input.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
                    out2 = input2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
                    self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_gather_expanded_index(self, device, dtype):
        # Test when index is [N, 1], which would have stride [1, 0];
        # it should be excluded from the fast path when the index is expanded.
        input = torch.arange(25).view(5, 5)
        input2 = input.to(dtype=dtype)
        idx = torch.arange(5).view(5, 1)
        out = torch.gather(input, 0, idx)
        out2 = torch.gather(input2, 0, idx)
        self.assertEqual(out.to(dtype=dtype), out2)

        def helper(input_size, idx_size):
            input = torch.randn(input_size, device=device).to(dtype=dtype)
            input2 = input.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # Test the fast path on gather when the index is expanded
            expanded_shape = input_size
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()
            out = torch.gather(input, 0, idx)
            out2 = torch.gather(input2, 0, idx2)
            self.assertEqual(out, out2)

            # test unsqueezed index
            # The expanded_index kernel cannot handle a dimension with
            # size > 1 and stride == 1, for example an index with size
            # [1, 8, 7] and stride [1, 1, 0].
            # See https://github.com/pytorch/pytorch/issues/129093
            def unsqueeze_helper(idx, dim):
                if dim == 2:
                    return idx.unsqueeze(1).t()
                else:
                    return unsqueeze_helper(idx, dim - 1).unsqueeze(dim - 1)

            idx = torch.randint(0, dim_size, (input.shape[1],))
            idx = unsqueeze_helper(idx, len(input_size))
            expanded_shape[0] = 1
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()
            out = torch.gather(input, 0, idx)
            out2 = torch.gather(input2, 0, idx2)
            self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)

# Generic Device Test Framework instantiation, see
# https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
# for details.
instantiate_device_type_tests(TestScatterGather, globals())

if __name__ == '__main__':
    run_tests()