mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[torch][segment_reduce] Update default values when initial value is not set (#61266)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/61266 Same as title. Mainly this concludes the initially planned features from the op. Only missing functionality is to do reduction on any axis (currently axis 0 only is supported). Test Plan: Updated unit test. Reviewed By: ngimel Differential Revision: D29552037 fbshipit-source-id: 023c7cbf750a0671f76082708f14c05739dda07a
This commit is contained in:
committed by
Facebook GitHub Bot
parent
a78ad5dc4c
commit
f84a441718
@ -92,7 +92,8 @@ Tensor _segment_reduce_cpu_kernel(
|
||||
// ===== step3: finalize reduction
|
||||
TORCH_CHECK(lengths_data[i] >= 0);
|
||||
|
||||
if (lengths_data[i] == 0 && !initial.has_value()) {
|
||||
if (lengths_data[i] == 0 && !initial.has_value() &&
|
||||
reduction == SegmentReductionType::MEAN) {
|
||||
initial_value = static_cast<scalar_t>(NAN);
|
||||
} else if (
|
||||
reduction == SegmentReductionType::MEAN &&
|
||||
@ -229,7 +230,6 @@ Tensor segment_reduce_kernel(
|
||||
if (!unsafe) {
|
||||
auto min_length = lengths_value.min().item<int64_t>();
|
||||
TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
|
||||
TORCH_CHECK(min_length != 0 || initial.has_value());
|
||||
TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.size(axis));
|
||||
}
|
||||
|
||||
|
@ -126,7 +126,8 @@ __global__ void segment_reduce_forward_kernel(
|
||||
|
||||
// ===== step3: finalize reduction
|
||||
CUDA_KERNEL_ASSERT(lengths_data[row_id] >= 0);
|
||||
if (lengths_data[row_id] == 0 && !is_initial_set) {
|
||||
if (lengths_data[row_id] == 0 && !is_initial_set &&
|
||||
reduction == SegmentReductionType::MEAN) {
|
||||
initial_value = static_cast<scalar_t>(NAN);
|
||||
} else if (
|
||||
reduction == SegmentReductionType::MEAN && lengths_data[row_id] > 0 &&
|
||||
|
@ -16,6 +16,19 @@ from torch.testing._internal.common_utils import (
|
||||
reductions = ["max", "mean", "min", "sum"]
|
||||
|
||||
|
||||
def get_default_value(initial_value, reduction):
|
||||
if initial_value is not None:
|
||||
return initial_value
|
||||
if reduction == "max":
|
||||
return -float("Inf")
|
||||
elif reduction == "mean":
|
||||
return float("nan")
|
||||
elif reduction == "min":
|
||||
return float("Inf")
|
||||
elif reduction == "sum":
|
||||
return 0.0
|
||||
|
||||
|
||||
class TestSegmentReductions(TestCase):
|
||||
def _test_common(
|
||||
self,
|
||||
@ -98,28 +111,29 @@ class TestSegmentReductions(TestCase):
|
||||
val_dtype, length_type = dtypes
|
||||
lengths = [1, 2, 3, 0]
|
||||
data = [1, float("nan"), 3, 4, 5, 5]
|
||||
check_backward = True
|
||||
|
||||
for reduction in reductions:
|
||||
for initial in [0, None]:
|
||||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "max":
|
||||
initial_value = 0
|
||||
expected_result = [1, float("nan"), 5, initial_value]
|
||||
expected_result = [1, float("nan"), 5, default_value]
|
||||
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
|
||||
elif reduction == "mean":
|
||||
initial_value = 0
|
||||
expected_result = [1, float("nan"), 4.666, initial_value]
|
||||
expected_result = [1, float("nan"), 4.666, default_value]
|
||||
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
|
||||
elif reduction == "min":
|
||||
if initial is not None:
|
||||
initial_value = 1000 # some high number
|
||||
expected_result = [1, float("nan"), 4, initial_value]
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
expected_result = [1, float("nan"), 4, default_value]
|
||||
expected_grad = [1.0, 1.0, 0, 1, 0, 0]
|
||||
elif reduction == "sum":
|
||||
initial_value = 0
|
||||
expected_result = [1, float("nan"), 14, initial_value]
|
||||
expected_result = [1, float("nan"), 14, default_value]
|
||||
expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
|
||||
for axis in [0, -1]:
|
||||
for unsafe in [True, False]:
|
||||
for initial in [initial_value, None]:
|
||||
self._test_common(
|
||||
reduction,
|
||||
device,
|
||||
@ -143,19 +157,21 @@ class TestSegmentReductions(TestCase):
|
||||
)
|
||||
def test_multi_d_simple(self, device, dtypes):
|
||||
val_dtype, length_type = dtypes
|
||||
check_backward = True
|
||||
axis = 0
|
||||
lengths = [1, 2, 3, 0]
|
||||
data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]
|
||||
|
||||
for reduction in reductions:
|
||||
for initial in [0, None]:
|
||||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "max":
|
||||
initial_value = 0
|
||||
expected_result = [
|
||||
[1, 1],
|
||||
[float("nan"), float("nan")],
|
||||
[4, 3],
|
||||
[initial_value, initial_value],
|
||||
[default_value, default_value],
|
||||
]
|
||||
expected_grad = [
|
||||
[1, 1],
|
||||
@ -166,12 +182,11 @@ class TestSegmentReductions(TestCase):
|
||||
[0, 1],
|
||||
]
|
||||
elif reduction == "mean":
|
||||
initial_value = 0
|
||||
expected_result = [
|
||||
[1, 1],
|
||||
[float("nan"), float("nan")],
|
||||
[3, 2],
|
||||
[initial_value, initial_value],
|
||||
[default_value, default_value],
|
||||
]
|
||||
expected_grad = [
|
||||
[1.0, 1.0],
|
||||
@ -182,12 +197,14 @@ class TestSegmentReductions(TestCase):
|
||||
[0.333, 0.333],
|
||||
]
|
||||
elif reduction == "min":
|
||||
if initial is not None:
|
||||
initial_value = 1000 # some high number
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
expected_result = [
|
||||
[1, 1],
|
||||
[float("nan"), float("nan")],
|
||||
[2, 1],
|
||||
[initial_value, initial_value],
|
||||
[default_value, default_value],
|
||||
]
|
||||
expected_grad = [
|
||||
[1.0, 1.0],
|
||||
@ -198,12 +215,11 @@ class TestSegmentReductions(TestCase):
|
||||
[1, 0],
|
||||
]
|
||||
elif reduction == "sum":
|
||||
initial_value = 0
|
||||
expected_result = [
|
||||
[1, 1],
|
||||
[float("nan"), float("nan")],
|
||||
[9, 6],
|
||||
[initial_value, initial_value],
|
||||
[default_value, default_value],
|
||||
]
|
||||
expected_grad = [
|
||||
[1.0, 1.0],
|
||||
@ -214,7 +230,6 @@ class TestSegmentReductions(TestCase):
|
||||
[1.0, 1.0],
|
||||
]
|
||||
for unsafe in [True, False]:
|
||||
for initial in [initial_value, None]:
|
||||
self._test_common(
|
||||
reduction,
|
||||
device,
|
||||
@ -246,14 +261,13 @@ class TestSegmentReductions(TestCase):
|
||||
check_backward = False
|
||||
|
||||
for reduction in reductions:
|
||||
if reduction == "max":
|
||||
initial_value = 0
|
||||
if reduction == "max":
|
||||
expected_result = [
|
||||
np.full((2, 5), initial_value).tolist(),
|
||||
np.max(data, axis=0).tolist(),
|
||||
]
|
||||
elif reduction == "mean":
|
||||
initial_value = 0
|
||||
expected_result = [
|
||||
np.full((2, 5), initial_value).tolist(),
|
||||
np.mean(data, axis=0).tolist(),
|
||||
@ -265,13 +279,11 @@ class TestSegmentReductions(TestCase):
|
||||
np.min(data, axis=0).tolist(),
|
||||
]
|
||||
elif reduction == "sum":
|
||||
initial_value = 0
|
||||
expected_result = [
|
||||
np.full((2, 5), initial_value).tolist(),
|
||||
np.sum(data, axis=0).tolist(),
|
||||
]
|
||||
for unsafe in [True, False]:
|
||||
for initial in [initial_value, None]:
|
||||
self._test_common(
|
||||
reduction,
|
||||
device,
|
||||
|
Reference in New Issue
Block a user