mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Allow data size equal to 0 for SegmentReduce (#99733)
Summary: Support special case that data size can be 0 for SegmentReduce. Example code below: ``` x = torch.ones((0, 6)).cuda() lengths = torch.tensor([0, 0]).cuda() torch.segment_reduce(x, "sum", lengths=lengths, unsafe=False, initial=0) ``` Previously, error message: Expected data.numel() > 0 to be true, but got false. Now expect to return 0. Test Plan: contbuild & OSS CI Differential Revision: D45133827 Pull Request resolved: https://github.com/pytorch/pytorch/pull/99733 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
7a8d0ccddf
commit
efed5a4969
@ -174,6 +174,63 @@ class TestSegmentReductions(TestCase):
|
||||
length_type,
|
||||
)
|
||||
|
||||
@dtypes(
|
||||
*product(
|
||||
(torch.half, torch.bfloat16, torch.float, torch.double),
|
||||
(torch.int, torch.int64),
|
||||
)
|
||||
)
|
||||
def test_simple_zero_length(self, device, dtypes):
|
||||
val_dtype, length_type = dtypes
|
||||
lengths = [0, 0]
|
||||
data = torch.ones((0))
|
||||
|
||||
for reduction in reductions:
|
||||
for initial in [0, None]:
|
||||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "max":
|
||||
expected_result = [default_value, default_value]
|
||||
expected_grad = []
|
||||
elif reduction == "mean":
|
||||
expected_result = [default_value, default_value]
|
||||
expected_grad = []
|
||||
elif reduction == "min":
|
||||
if initial is not None:
|
||||
initial_value = 1000 # some high number
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
expected_result = [default_value, default_value]
|
||||
expected_grad = []
|
||||
elif reduction == "sum":
|
||||
expected_result = [default_value, default_value]
|
||||
expected_grad = []
|
||||
elif reduction == "prod":
|
||||
if initial is not None:
|
||||
initial_value = 2 # 0 initial_value will zero out everything for prod
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
expected_result = [default_value, default_value]
|
||||
expected_grad = []
|
||||
else:
|
||||
expected_result = [default_value, default_value]
|
||||
expected_grad = []
|
||||
for axis in [0]:
|
||||
for unsafe in [True, False]:
|
||||
self._test_common(
|
||||
reduction,
|
||||
device,
|
||||
val_dtype,
|
||||
unsafe,
|
||||
axis,
|
||||
initial_value,
|
||||
data,
|
||||
lengths,
|
||||
expected_result,
|
||||
expected_grad,
|
||||
check_backward,
|
||||
length_type,
|
||||
)
|
||||
|
||||
@dtypes(
|
||||
*product(
|
||||
(torch.half, torch.bfloat16, torch.float, torch.double),
|
||||
|
Reference in New Issue
Block a user