[torch][segment_reduce] Add support for initial value (#56923)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56923

Next Steps in order:
- Add backward support for CUDA
- Add support for more aggregation types
- Benchmarking (for cuda mainly)/more testing/documentation
- Support for multi dimension

Test Plan: Updated unit test to include 0 length segment as well.

Reviewed By: ngimel

Differential Revision: D27992228

fbshipit-source-id: 28851811f8a784a63162721c511d69e617a93727
This commit is contained in:
Serhat Yilmaz
2021-04-30 18:00:20 -07:00
committed by Facebook GitHub Bot
parent bd347012ec
commit 20eac093a7
6 changed files with 63 additions and 53 deletions

View File

@ -13,16 +13,24 @@ from torch.testing._internal.common_utils import (
class TestSegmentReductions(TestCase):
def _test_max_simple_1d(self, device, dtype, unsafe, axis):
lengths = torch.tensor([1, 2, 3], device=device)
lengths = torch.tensor([1, 2, 3, 0], device=device)
data = torch.tensor(
[1, float("nan"), 3, 4, 5, 5],
device=device,
dtype=dtype,
requires_grad=True,
)
expected_result = torch.tensor([1, float("nan"), 5], device=device, dtype=dtype)
initial_value = 0
expected_result = torch.tensor(
[1, float("nan"), 5, initial_value], device=device, dtype=dtype
)
actual_result = torch.segment_reduce(
data=data, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
data=data,
reduce="max",
lengths=lengths,
axis=axis,
unsafe=unsafe,
initial=initial_value,
)
self.assertEqual(
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
@ -52,7 +60,12 @@ class TestSegmentReductions(TestCase):
self.assertTrue(
gradcheck(
lambda x: torch.segment_reduce(
data=x, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
data=x,
reduce="max",
lengths=lengths,
axis=axis,
unsafe=unsafe,
initial=initial_value,
),
(data,),
)