Support multi-dimensional lengths in segment_reduce to support pytorch_scatter.segment_* functionalities (CUDA)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77061

Approved by: https://github.com/cpuhrsch
This commit is contained in:
Mikayla Gawarecki
2022-06-10 22:33:06 +00:00
committed by PyTorch MergeBot
parent 38350acf8f
commit e727539c29
4 changed files with 170 additions and 80 deletions

View File

@ -7,13 +7,12 @@ import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
dtypes,
onlyCPU
)
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
gradcheck,
parametrize
parametrize,
)
@ -300,7 +299,6 @@ class TestSegmentReductions(TestCase):
)
)
@parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
@onlyCPU # will be removed in next PR where CUDA implementation of segment_reduce is adjusted
def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
val_dtype, length_dtype = dtypes
# zero-length segments are filled with reduction inits contrary to pytorch_scatter.
@ -384,7 +382,6 @@ class TestSegmentReductions(TestCase):
axis=dim,
unsafe=True,
)
self.assertEqual(actual_result, expected)
if val_dtype == torch.float64:
@ -469,20 +466,19 @@ class TestSegmentReductions(TestCase):
check_backward,
)
@onlyCPU
@dtypes(torch.int, torch.int64)
def test_unsafe_flag(self, device, dtype):
length_type = dtype
lengths = torch.tensor([0, 2, 3, 0], dtype=length_type)
data = torch.arange(6).float()
lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
data = torch.arange(6, dtype=torch.float, device=device)
# test for error on 1-D lenghts
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)
# test for error on multi-D lengths
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type)
nd_data = torch.arange(12).reshape(2, 6).float()
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)