Revert "Support multi-dimensional lengths in segment_reduce to support pytorch_scatter.segment_* functionalities (CUDA)"

This reverts commit 40f7ef1f3db9717d8149a0bd1e8b8c80c8600753.

Reverted https://github.com/pytorch/pytorch/pull/77061 on behalf of https://github.com/janeyx99 due to Broke segment_reduce tests on trunk, e.g., 40f7ef1f3d
This commit is contained in:
PyTorch MergeBot
2022-06-10 01:57:34 +00:00
parent 46234df5f1
commit 87a5ecced2
3 changed files with 78 additions and 169 deletions

View File

@ -7,12 +7,13 @@ import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
dtypes,
onlyCPU
)
from torch.testing._internal.common_utils import (
TestCase,
run_tests,
gradcheck,
parametrize,
parametrize
)
@ -299,6 +300,7 @@ class TestSegmentReductions(TestCase):
)
)
@parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
@onlyCPU # will be removed in next PR where CUDA implementation of segment_reduce is adjusted
def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
val_dtype, length_dtype = dtypes
# zero-length segments are filled with reduction inits contrary to pytorch_scatter.
@ -382,6 +384,7 @@ class TestSegmentReductions(TestCase):
axis=dim,
unsafe=True,
)
self.assertEqual(actual_result, expected)
if val_dtype == torch.float64:
@ -466,19 +469,20 @@ class TestSegmentReductions(TestCase):
check_backward,
)
@onlyCPU
@dtypes(torch.int, torch.int64)
def test_unsafe_flag(self, device, dtype):
length_type = dtype
lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
data = torch.arange(6, dtype=torch.float, device=device)
lengths = torch.tensor([0, 2, 3, 0], dtype=length_type)
data = torch.arange(6).float()
# test for error on 1-D lenghts
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)
# test for error on multi-D lengths
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type)
nd_data = torch.arange(12).reshape(2, 6).float()
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)