mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torch] Add cuda support for segment reduction 'max' (#54175)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54175 Building on top of previous PR. This PR adds cuda support for 1D max reduction. Next steps: - Add support for other major reduction types (e.g. min, sum) for 1D tensor - Documentation for the op - Perf optimizations and benchmark util - Backward support (not high priority) - Support for multi dimensional tensors (on data and lengths) (not high priority) - Support for 'indices' (not high priority) Test Plan: Added unit test Reviewed By: ngimel Differential Revision: D27121170 fbshipit-source-id: 1c2565f42e2903e6fc089d56983ce8857efbfa3c
This commit is contained in:
committed by
Facebook GitHub Bot
parent
778f9eab6c
commit
eb5e1fc713
@ -1,8 +1,8 @@
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
onlyCPU,
|
||||
dtypes,
|
||||
dtypesIfCUDA,
|
||||
)
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestCase,
|
||||
@ -11,25 +11,29 @@ from torch.testing._internal.common_utils import (
|
||||
|
||||
|
||||
class TestSegmentReductions(TestCase):
|
||||
@onlyCPU
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
|
||||
def test_max_simple_1d(self, device, dtype):
|
||||
def _test_max_simple_1d(self, device, dtype, unsafe):
|
||||
lengths = torch.tensor([1, 2, 3], device=device)
|
||||
data = torch.tensor([1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype)
|
||||
expected_result = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)
|
||||
actual_result = torch.segment_reduce(
|
||||
data=data, reduce="max", lengths=lengths, axis=0, unsafe=False
|
||||
data=data, reduce="max", lengths=lengths, axis=0, unsafe=unsafe
|
||||
)
|
||||
self.assertEqual(
|
||||
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
|
||||
)
|
||||
actual_result = torch.segment_reduce(
|
||||
data=data, reduce="max", lengths=lengths, axis=-1, unsafe=False
|
||||
data=data, reduce="max", lengths=lengths, axis=-1, unsafe=unsafe
|
||||
)
|
||||
self.assertEqual(
|
||||
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
|
||||
)
|
||||
|
||||
@dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
|
||||
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
|
||||
def test_max_simple_1d(self, device, dtype):
|
||||
self._test_max_simple_1d(device, dtype, False)
|
||||
self._test_max_simple_1d(device, dtype, True)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestSegmentReductions, globals())
|
||||
|
||||
|
||||
Reference in New Issue
Block a user