Allow data size equal to 0 for SegmentReduce (#99733)

Summary:
Support special case that data size can be 0 for SegmentReduce.

Example code below:
```
x = torch.ones((0, 6)).cuda()
lengths = torch.tensor([0, 0]).cuda()
torch.segment_reduce(x, "sum", lengths=lengths, unsafe=False, initial=0)
```
Previously, this failed with the error message: "Expected data.numel() > 0 to be true, but got false."
Now it is expected to return 0 for each segment.

Test Plan: contbuild & OSS CI

Differential Revision: D45133827

Pull Request resolved: https://github.com/pytorch/pytorch/pull/99733
Approved by: https://github.com/ngimel
This commit is contained in:
Mengchi Zhang
2023-04-23 01:59:45 +00:00
committed by PyTorch MergeBot
parent 7a8d0ccddf
commit efed5a4969
2 changed files with 58 additions and 1 deletions

View File

@ -174,6 +174,63 @@ class TestSegmentReductions(TestCase):
length_type,
)
@dtypes(
    *product(
        (torch.half, torch.bfloat16, torch.float, torch.double),
        (torch.int, torch.int64),
    )
)
def test_simple_zero_length(self, device, dtypes):
    """Segment-reduce with zero-element data and all-zero segment lengths.

    Regression test for the case where ``data.numel() == 0``: every
    reduction should return the reduction's default value (derived from
    ``initial`` when given) for each segment, with an empty gradient,
    instead of raising ``Expected data.numel() > 0``.
    """
    val_dtype, length_type = dtypes
    lengths = [0, 0]
    # (0,) spells the empty 1-D shape explicitly; torch.ones((0)) is the same tensor.
    data = torch.ones((0,))
    for reduction in reductions:
        for initial in [0, None]:
            # Backward is only checked when an explicit initial value is supplied.
            check_backward = initial is not None
            initial_value = initial
            # min/prod need a non-trivial initial so the expected default is
            # well-defined: 0 would zero out everything for prod, and min
            # needs some high number as its starting bound.
            if initial is not None:
                if reduction == "min":
                    initial_value = 1000  # some high number
                elif reduction == "prod":
                    initial_value = 2  # 0 initial_value will zero out everything for prod
            default_value = get_default_value(initial_value, reduction)
            # With no data, every segment reduces to the default value and
            # there is nothing to propagate gradients to.
            expected_result = [default_value, default_value]
            expected_grad = []
            for axis in [0]:
                for unsafe in [True, False]:
                    self._test_common(
                        reduction,
                        device,
                        val_dtype,
                        unsafe,
                        axis,
                        initial_value,
                        data,
                        lengths,
                        expected_result,
                        expected_grad,
                        check_backward,
                        length_type,
                    )
@dtypes(
*product(
(torch.half, torch.bfloat16, torch.float, torch.double),