mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Per title, improve x[index] cuda perf for the common case of indexing along the first dim, using vectorized gather kernel Pull Request resolved: https://github.com/pytorch/pytorch/pull/151753 Approved by: https://github.com/eqy
420 lines
20 KiB
Python
420 lines
20 KiB
Python
# Owner(s): ["module: scatter & gather ops"]
|
|
|
|
import random
|
|
|
|
import torch
|
|
|
|
from torch.testing import make_tensor
|
|
from torch.testing._internal.common_utils import \
|
|
(parametrize, run_tests, TestCase, DeterministicGuard)
|
|
from torch.testing._internal.common_device_type import \
|
|
(instantiate_device_type_tests, onlyCPU, dtypes, dtypesIfCUDA,
|
|
toleranceOverride, tol,)
|
|
from torch.testing._internal.common_dtype import \
|
|
(get_all_dtypes,)
|
|
|
|
# Fail fast at import time if anything imported above has changed torch's
# default floating-point dtype away from float32.
assert torch.float32 is torch.get_default_dtype()
|
|
|
|
|
|
# Note: test_scatter_gather_ops.py
|
|
# This test file tests scatter and gather operations,
|
|
# like torch.scatter and torch.gather.
|
|
|
|
class TestScatterGather(TestCase):
    """Device-generic tests for torch.gather and the scatter family.

    Results of torch.gather, Tensor.scatter_, Tensor.scatter_add_ and
    Tensor.scatter_reduce_ are checked against explicit element-by-element
    Python reference loops, plus targeted edge cases (empty indices,
    expanded indices, nan/inf propagation, large vectorized shapes).
    """

    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o, unique_indices=True):
        """Overwrite the 3-D index tensor ``idx`` in place with valid indices.

        For every position of the two axes other than ``dim``, the whole 1-D
        slice of ``idx`` along ``dim`` is filled with ``elems_per_row`` values
        in ``[0, dim_size)``. With ``unique_indices=True`` the values within a
        slice are distinct (prefix of ``torch.randperm``); otherwise they are
        drawn with replacement (``torch.randint``) so duplicates can occur.
        ``m``, ``n``, ``o`` are the loop extents of the three axes; the ``dim``
        axis is visited only once because the slice covers it entirely.
        """
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    # Select the full extent of idx along `dim`.
                    ii[dim] = slice(None)
                    if unique_indices:
                        idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
                    else:
                        idx[tuple(ii)] = torch.randint(dim_size, (elems_per_row,))

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)
        # Reference: expected[i, j, k] is src at (i, j, k) with the `dim`
        # coordinate replaced by idx[i, j, k].
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if not dtype.is_complex:
            src = make_tensor((3, 4, 5), device=device, dtype=dtype)
            expected, idx = src.max(2, True)
            actual = torch.gather(src, 2, idx)
            self.assertEqual(actual, expected, atol=0, rtol=0)

    @dtypes(torch.int8, torch.bfloat16)
    def test_gather_large(self, device, dtype):
        # test larger shapes to check vectorized implementation
        for (m, n, k) in ((4096, 3072, 4096), (4096, 3072, 4100)):
            src = make_tensor((m, k), device=device, dtype=dtype)
            # Same values as `src`, laid out with stride 2 along dim 1.
            alloc0 = torch.empty(src.nelement() * 2, device=device, dtype=dtype)
            discontig = alloc0.view(m, 2 * k)[:, ::2].copy_(src)
            # Same values as `src`, starting one element past the start of its
            # allocation (presumably defeating aligned vector loads).
            alloc1 = torch.empty(src.nelement() + 1, device=device, dtype=dtype)
            misaligned = alloc1[1:].view(m, k).copy_(src)
            num_ind = n
            for dim in (0, 1):
                max_ind = src.shape[dim]
                ind0 = torch.randint(max_ind, (num_ind,), device=device)
                # Broadcast the 1-D index across the other dim so that gathering
                # along `dim` matches src[ind0] (dim 0) or src[:, ind0] (dim 1).
                shape_ind = [1] * src.ndim
                shape_ind[dim] = ind0.shape[0]
                shape_out = list(src.shape)
                shape_out[dim] = ind0.shape[0]
                ind = ind0.view(shape_ind).expand(shape_out)
                res = torch.gather(src, dim=dim, index=ind)
                ref = src[ind0] if dim == 0 else src[:, ind0]
                self.assertEqual(res, ref, atol=0, rtol=0)
                if res.device.type == "cuda":
                    ref_cpu = src.cpu()[ind0.cpu()] if dim == 0 else src.cpu()[:, ind0.cpu()]
                    self.assertEqual(res.cpu(), ref_cpu, atol=0, rtol=0)
                # Negative indices (i - size) must select the same elements.
                ind_neg = ind0 - max_ind
                res_ind_neg = src[ind_neg] if dim == 0 else src[:, ind_neg]
                self.assertEqual(res_ind_neg, ref, atol=0, rtol=0)
                res = torch.gather(discontig, dim=dim, index=ind)
                self.assertEqual(res, ref, atol=0, rtol=0)
                res_ind = discontig[ind0] if dim == 0 else discontig[:, ind0]
                self.assertEqual(res_ind, ref, atol=0, rtol=0)
                res = torch.gather(misaligned, dim=dim, index=ind)
                self.assertEqual(res, ref, atol=0, rtol=0)
                res_ind = misaligned[ind0] if dim == 0 else misaligned[:, ind0]
                self.assertEqual(res_ind, ref, atol=0, rtol=0)

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        actual = torch.gather(src, 1, idx)
        expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)
        self.assertEqual(actual, expected, atol=0, rtol=0)

    @parametrize("sparse_grad", [False, True])
    @dtypes(torch.float32, torch.float64)
    def test_gather_backward_with_empty_index_tensor(self, device, dtype, sparse_grad):
        # An empty index gathers nothing, so the gradient w.r.t. the input
        # must be all zeros (checked after densifying a sparse gradient).
        dim = -1
        inp = torch.rand([10, 5], dtype=dtype, device=device, requires_grad=True)
        index = torch.randint(0, 2, [3, 0], dtype=torch.int64, device=device)
        res = torch.gather(inp, dim, index, sparse_grad=sparse_grad)
        res.sum().backward()
        grad = inp.grad.to_dense() if sparse_grad else inp.grad
        expected_grad = torch.zeros_like(inp, requires_grad=False)
        self.assertEqual(grad, expected_grad, atol=0, rtol=0)

    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction,
                           unique_indices=True, include_self=True):
        """Check a scatter-style op ``fn`` against a Python reference loop.

        ``fn`` is called as ``fn(self_tensor, dim, index, src, ...)`` on a
        random (m, n, o) base with a random ``dim``. ``src`` is a Python float
        when ``is_scalar`` else a tensor at least as large as the index.
        ``reduction`` (or None) is passed as ``reduce=``; ``include_self`` is
        forwarded only when ``fn`` is ``torch.Tensor.scatter_reduce_``.
        ``unique_indices=False`` allows duplicate indices so that reductions
        combine several values into one position. Finally checks that an
        empty index leaves the destination unchanged.
        """
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, (m, n, o)[dim], elems_per_row, m, n, o, unique_indices)

        if is_scalar:
            src = random.random()
        else:
            # src is larger than idx in every dim; the reference loop below only
            # reads its leading idx_size region.
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)
        if reduction is not None:
            if fn is torch.Tensor.scatter_reduce_:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction, include_self=include_self)
            else:
                actual = fn(base.clone(), dim, idx, src, reduce=reduction)
        else:
            actual = fn(base.clone(), dim, idx, src)

        expected = base.clone()
        # counts[p] = number of values folded into position p so far; the
        # original base value counts as one only when include_self is set.
        counts = torch.zeros(base.shape, dtype=torch.long, device=device) + include_self
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    pos = tuple(ii)
                    if fn is torch.Tensor.scatter_add_:
                        expected[pos] += src[i, j, k]
                    else:
                        # method may be 'scatter_', 'scatter', 'scatter_reduce'
                        # or 'scatter_reduce_', the former two might have a reduction argument
                        # while the latter two always do
                        value = src if is_scalar else src[i, j, k]

                        if not include_self and counts[pos] == 0:
                            # First contribution to a position whose base value
                            # is excluded from the reduction: just take it.
                            expected[pos] = value
                        elif reduction in ("add", "sum", "mean"):
                            # "mean" accumulates a sum here and divides below.
                            expected[pos] += value
                        elif reduction in ("multiply", "prod"):
                            expected[pos] *= value
                        elif reduction == "amax":
                            expected[pos] = max(expected[pos], value)
                        elif reduction == "amin":
                            expected[pos] = min(expected[pos], value)
                        else:
                            expected[pos] = value

                        counts[pos] += 1

        if reduction == "mean":
            # Untouched positions (count 0) keep their value: divide by 1.
            counts.masked_fill_(counts == 0, 1)
            if dtype.is_floating_point or dtype.is_complex:
                expected /= counts
            else:
                expected.div_(counts, rounding_mode="floor")

        if dtype == torch.float16 or dtype == torch.bfloat16:
            # Some CUDA kernels (e.g. indexing_backward_kernel_stride_1) that are called during
            # the test use fp32 for internal accumulation for improved accuracy. When using 16 bit
            # precision types can be small differences
            self.assertEqual(actual, expected, atol=0.04, rtol=0.05)
        else:
            self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index: the destination must come back unchanged.
        dst = make_tensor((2, 2), device=device, dtype=dtype)
        # `fn` may be in-place and return `dst` itself, so compare against a
        # snapshot taken before the call; comparing to `dst` would be vacuous.
        dst_before = dst.clone()
        idx = torch.tensor((), device=device, dtype=torch.long)
        src = make_tensor((2, 2), device=device, dtype=dtype)
        if reduction is not None:
            if fn is torch.Tensor.scatter_reduce_:
                actual = fn(dst, 0, idx, src, reduce=reduction, include_self=include_self)
            else:
                actual = fn(dst, 0, idx, src, reduce=reduction)
        else:
            actual = fn(dst, 0, idx, src)
        self.assertEqual(actual, dst_before, atol=0, rtol=0)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=True, reduction=None)

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        for reduction in ("add", "multiply"):
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=False, reduction=reduction)
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=True, reduction=reduction)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
                                        is_scalar=False, reduction=None)

    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        # Every source element targets index 0, so row 0 (dim 0) / column 0
        # (dim 1) must accumulate the full count of ones.
        for deterministic in [False, True]:
            with DeterministicGuard(deterministic):
                m, n = 30, 40
                idx = torch.zeros(m, n, device=device, dtype=torch.long)
                src = torch.ones(m, n, device=device, dtype=dtype)
                res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
                res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)

                self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
                self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)

    # FIXME: discrepancy between bool ReduceAdd on CUDA and CPU (a + b on CPU and buggy a && b on CUDA)
    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    def test_scatter_reduce_sum(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='sum', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_prod(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='prod', unique_indices=False,
                                    include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_bool=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_mean(self, device, dtype):
        for include_self in (True, False):
            for deterministic in [False, True]:
                with DeterministicGuard(deterministic):
                    self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                            is_scalar=False, reduction='mean', unique_indices=False,
                                            include_self=include_self)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amax(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amax', unique_indices=False,
                                    include_self=include_self)

            # simple test for nan/inf propagation
            if dtype.is_floating_point:
                inp = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -float('inf'), -float('inf'), 2, float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
                inp.scatter_reduce_(0, idx, src, 'amax', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if include_self:
                    # max(0, -inf, -inf) == 0 once the base zero participates.
                    expected_result[1] = 0
                self.assertEqual(inp, expected_result)

    @dtypes(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False))
    @dtypesIfCUDA(*get_all_dtypes(include_half=True, include_bfloat16=True, include_complex=False, include_bool=False))
    def test_scatter_reduce_amin(self, device, dtype):
        for include_self in (True, False):
            self._test_scatter_base(torch.Tensor.scatter_reduce_, device=device, dtype=dtype,
                                    is_scalar=False, reduction='amin', unique_indices=False,
                                    include_self=include_self)

            # simple test for nan/inf propagation
            if dtype.is_floating_point:
                inp = torch.zeros(3, device=device, dtype=dtype)
                src = torch.tensor([1, float('nan'), -2, -float('inf'), float('inf'), float('inf')], device=device, dtype=dtype)
                idx = torch.tensor([0, 0, 1, 1, 2, 2], device=device)
                inp.scatter_reduce_(0, idx, src, 'amin', include_self=include_self)
                expected_result = torch.tensor([float('nan'), -float('inf'), float('inf')], device=device, dtype=dtype)
                if include_self:
                    # min(0, inf, inf) == 0 once the base zero participates.
                    expected_result[2] = 0
                self.assertEqual(inp, expected_result)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16, torch.float16)
    def test_scatter_expanded_index(self, device, dtype):
        def helper(input_size, idx_size):
            # Compare scatter_add / scatter_reduce with an expanded (stride-0)
            # index against the same index materialized contiguously.
            inp = torch.randn(input_size, device=device).to(dtype=dtype)
            inp2 = inp.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # The fast path on scatter when index is expanded
            # will depend on sorted index where the collected src indice
            # for each row in input will be mapped to rowptrs in a CSR format.
            # Create some empty rows by masking:
            mask = (idx > 1) & (idx < 4)
            idx[mask] = 0

            # Copy so the caller's `input_size` list is not mutated.
            expanded_shape = list(input_size)
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()
            src = torch.randn(expanded_shape, device=device).to(dtype=dtype)

            out = inp.scatter_add(0, idx, src)
            out2 = inp2.scatter_add(0, idx2, src)
            self.assertEqual(out, out2)

            for reduce in ["sum", "prod", "mean", "amax", "amin"]:
                for include_self in [True, False]:
                    out = inp.scatter_reduce(0, idx, src, reduce=reduce, include_self=include_self)
                    out2 = inp2.scatter_reduce(0, idx2, src, reduce=reduce, include_self=include_self)
                    self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)

    @onlyCPU
    @dtypes(torch.float32, torch.float64, torch.bfloat16)
    def test_gather_expanded_index(self, device, dtype):
        # Test an index of shape [N, 1] (here contiguous, strides [1, 1]);
        # it should be excluded from the fast path used when index is expanded.
        src_int = torch.arange(25).view(5, 5)
        src_cast = src_int.to(dtype=dtype)

        idx = torch.arange(5).view(5, 1)
        out = torch.gather(src_int, 0, idx)
        out2 = torch.gather(src_cast, 0, idx)

        self.assertEqual(out.to(dtype=dtype), out2)

        def helper(input_size, idx_size):
            # Compare gather with an expanded index against the same index
            # materialized contiguously.
            inp = torch.randn(input_size, device=device).to(dtype=dtype)
            inp2 = inp.clone()

            shape = [1] * len(input_size)
            shape[0] = idx_size
            dim_size = input_size[0]
            idx = torch.randint(0, dim_size, shape)

            # Test the fast path on gather when index is expanded.
            # Copy so the caller's `input_size` list is not mutated.
            expanded_shape = list(input_size)
            expanded_shape[0] = idx_size
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()

            out = torch.gather(inp, 0, idx)
            out2 = torch.gather(inp2, 0, idx2)

            self.assertEqual(out, out2)

            # test unsqueezed index
            # expanded_index kernel can not handle the case:
            # the size > 1 and stride == 1 at a dimension.
            # for example: the index with size of [1, 8, 7], stride of [1, 1, 0].
            # see https://github.com/pytorch/pytorch/issues/129093
            def unsqueeze_helper(idx, dim):
                # Reshape a 1-D index of length S to [1, S, 1, ..., 1] with
                # `dim` total dimensions.
                if dim == 2:
                    return idx.unsqueeze(1).t()
                else:
                    return unsqueeze_helper(idx, dim - 1).unsqueeze(dim - 1)

            idx = torch.randint(0, dim_size, (inp.shape[1],))
            idx = unsqueeze_helper(idx, len(input_size))
            expanded_shape[0] = 1
            idx = idx.expand(expanded_shape)
            idx2 = idx.contiguous()
            out = torch.gather(inp, 0, idx)
            out2 = torch.gather(inp2, 0, idx2)
            self.assertEqual(out, out2)

        helper([50, 17], 100)
        helper([50, 1], 100)
        helper([50, 8, 7], 100)
        helper([50, 3, 4, 5], 100)
|
|
|
|
# Instantiate the device-generic tests in TestScatterGather through the
# generic device test framework, placing the results in this module's
# globals(). See https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
# for details.
instantiate_device_type_tests(TestScatterGather, globals())

if __name__ == "__main__":
    run_tests()
|