mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torch][segment_reduce] Add support for initial value (#56923)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56923 Next Steps in order: - Add backward support for CUDA - Add support for more aggregation types - Benchmarking (for cuda mainly)/more testing/documentation - Support for multi dimension Test Plan: Updated unit test to include 0 length segment as well. Reviewed By: ngimel Differential Revision: D27992228 fbshipit-source-id: 28851811f8a784a63162721c511d69e617a93727
This commit is contained in:
committed by
Facebook GitHub Bot
parent
bd347012ec
commit
20eac093a7
@ -13,16 +13,24 @@ from torch.testing._internal.common_utils import (
|
||||
|
||||
class TestSegmentReductions(TestCase):
|
||||
def _test_max_simple_1d(self, device, dtype, unsafe, axis):
|
||||
lengths = torch.tensor([1, 2, 3], device=device)
|
||||
lengths = torch.tensor([1, 2, 3, 0], device=device)
|
||||
data = torch.tensor(
|
||||
[1, float("nan"), 3, 4, 5, 5],
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
requires_grad=True,
|
||||
)
|
||||
expected_result = torch.tensor([1, float("nan"), 5], device=device, dtype=dtype)
|
||||
initial_value = 0
|
||||
expected_result = torch.tensor(
|
||||
[1, float("nan"), 5, initial_value], device=device, dtype=dtype
|
||||
)
|
||||
actual_result = torch.segment_reduce(
|
||||
data=data, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
|
||||
data=data,
|
||||
reduce="max",
|
||||
lengths=lengths,
|
||||
axis=axis,
|
||||
unsafe=unsafe,
|
||||
initial=initial_value,
|
||||
)
|
||||
self.assertEqual(
|
||||
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
|
||||
@ -52,7 +60,12 @@ class TestSegmentReductions(TestCase):
|
||||
self.assertTrue(
|
||||
gradcheck(
|
||||
lambda x: torch.segment_reduce(
|
||||
data=x, reduce="max", lengths=lengths, axis=axis, unsafe=unsafe
|
||||
data=x,
|
||||
reduce="max",
|
||||
lengths=lengths,
|
||||
axis=axis,
|
||||
unsafe=unsafe,
|
||||
initial=initial_value,
|
||||
),
|
||||
(data,),
|
||||
)
|
||||
|
Reference in New Issue
Block a user