mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Revert "Support multi-dimensional lengths in segment_reduce to support pytorch_scatter.segment_* functionalities (CUDA)"
This reverts commit 40f7ef1f3db9717d8149a0bd1e8b8c80c8600753.
Reverted https://github.com/pytorch/pytorch/pull/77061 on behalf of https://github.com/janeyx99 due to Broke segment_reduce tests on trunk, e.g., 40f7ef1f3d
This commit is contained in:
@ -7,12 +7,13 @@ import torch
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
dtypes,
|
||||
onlyCPU
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
run_tests,
|
||||
gradcheck,
|
||||
parametrize,
|
||||
parametrize
|
||||
|
||||
)
|
||||
|
||||
@ -299,6 +300,7 @@ class TestSegmentReductions(TestCase):
|
||||
)
|
||||
)
|
||||
@parametrize("reduce", ['sum', 'prod', 'min', 'max', 'mean'])
|
||||
@onlyCPU # will be removed in next PR where CUDA implementation of segment_reduce is adjusted
|
||||
def test_pytorch_scatter_test_cases(self, device, dtypes, reduce):
|
||||
val_dtype, length_dtype = dtypes
|
||||
# zero-length segments are filled with reduction inits contrary to pytorch_scatter.
|
||||
@ -382,6 +384,7 @@ class TestSegmentReductions(TestCase):
|
||||
axis=dim,
|
||||
unsafe=True,
|
||||
)
|
||||
|
||||
self.assertEqual(actual_result, expected)
|
||||
|
||||
if val_dtype == torch.float64:
|
||||
@ -466,19 +469,20 @@ class TestSegmentReductions(TestCase):
|
||||
check_backward,
|
||||
)
|
||||
|
||||
@onlyCPU
|
||||
@dtypes(torch.int, torch.int64)
|
||||
def test_unsafe_flag(self, device, dtype):
|
||||
length_type = dtype
|
||||
lengths = torch.tensor([0, 2, 3, 0], device=device, dtype=length_type)
|
||||
data = torch.arange(6, dtype=torch.float, device=device)
|
||||
lengths = torch.tensor([0, 2, 3, 0], dtype=length_type)
|
||||
data = torch.arange(6).float()
|
||||
|
||||
# test for error on 1-D lenghts
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
|
||||
torch.segment_reduce(data, 'sum', lengths=lengths, axis=0, unsafe=False)
|
||||
|
||||
# test for error on multi-D lengths
|
||||
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type, device=device)
|
||||
nd_data = torch.arange(12, dtype=torch.float, device=device).reshape(2, 6)
|
||||
nd_lengths = torch.tensor([[0, 3, 3, 0], [2, 3, 0, 0]], dtype=length_type)
|
||||
nd_data = torch.arange(12).reshape(2, 6).float()
|
||||
with self.assertRaisesRegex(RuntimeError, "Expected all rows of lengths along axis"):
|
||||
torch.segment_reduce(nd_data, 'sum', lengths=nd_lengths, axis=1, unsafe=False)
|
||||
|
||||
|
Reference in New Issue
Block a user