mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56704 This is a resubmission of PR https://github.com/pytorch/pytorch/pull/54175. Main changes compared to the original PR: - Switch to importing "<ATen/cuda/cub.cuh>". - Use CUB_WRAPPER to reduce boilerplate code. Test Plan: Will check CI status to make sure all checks pass. Added a unit test. Reviewed By: ngimel Differential Revision: D27941257 fbshipit-source-id: 24a0e0c7f6c46126d2606fe42ed03dca15684415
42 lines
1.4 KiB
Python
42 lines
1.4 KiB
Python
import torch
|
|
from torch.testing._internal.common_device_type import (
|
|
instantiate_device_type_tests,
|
|
dtypes,
|
|
dtypesIfCUDA,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
TestCase,
|
|
run_tests,
|
|
)
|
|
|
|
|
|
class TestSegmentReductions(TestCase):
|
|
def _test_max_simple_1d(self, device, dtype, unsafe):
|
|
lengths = torch.tensor([1, 2, 3], device=device)
|
|
data = torch.tensor([1, float("nan"), 3, 4, 5, 6], device=device, dtype=dtype)
|
|
expected_result = torch.tensor([1, float("nan"), 6], device=device, dtype=dtype)
|
|
actual_result = torch.segment_reduce(
|
|
data=data, reduce="max", lengths=lengths, axis=0, unsafe=unsafe
|
|
)
|
|
self.assertEqual(
|
|
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
|
|
)
|
|
actual_result = torch.segment_reduce(
|
|
data=data, reduce="max", lengths=lengths, axis=-1, unsafe=unsafe
|
|
)
|
|
self.assertEqual(
|
|
expected_result, actual_result, rtol=1e-03, atol=1e-05, equal_nan=True
|
|
)
|
|
|
|
@dtypesIfCUDA(torch.half, torch.bfloat16, torch.float, torch.double)
|
|
@dtypes(torch.half, torch.bfloat16, torch.float, torch.double)
|
|
def test_max_simple_1d(self, device, dtype):
|
|
self._test_max_simple_1d(device, dtype, False)
|
|
self._test_max_simple_1d(device, dtype, True)
|
|
|
|
|
|
# Instantiate concrete variants of the generic TestSegmentReductions class
# (whose test methods take `device`/`dtype` arguments) into this module's
# namespace, which is passed in via globals(). Presumably this is what makes
# the per-device tests discoverable by the runner — confirm against
# common_device_type if relying on the exact class names produced.
instantiate_device_type_tests(TestSegmentReductions, globals())


if __name__ == "__main__":
    # Run this module's tests when the file is executed directly as a script.
    run_tests()
|