[torch][segment_reduce] Update default values when initial value is not set (#61266)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/61266

Same as title.
Mainly this concludes the initially planned features from the op. Only missing functionality is to do reduction on any axis (currently axis 0 only is supported).

Test Plan: Updated unit test.

Reviewed By: ngimel

Differential Revision: D29552037

fbshipit-source-id: 023c7cbf750a0671f76082708f14c05739dda07a
This commit is contained in:
Serhat Yilmaz
2021-07-07 13:07:12 -07:00
committed by Facebook GitHub Bot
parent a78ad5dc4c
commit f84a441718
3 changed files with 120 additions and 107 deletions

View File

@ -92,7 +92,8 @@ Tensor _segment_reduce_cpu_kernel(
// ===== step3: finalize reduction
TORCH_CHECK(lengths_data[i] >= 0);
if (lengths_data[i] == 0 && !initial.has_value()) {
if (lengths_data[i] == 0 && !initial.has_value() &&
reduction == SegmentReductionType::MEAN) {
initial_value = static_cast<scalar_t>(NAN);
} else if (
reduction == SegmentReductionType::MEAN &&
@ -229,7 +230,6 @@ Tensor segment_reduce_kernel(
if (!unsafe) {
auto min_length = lengths_value.min().item<int64_t>();
TORCH_CHECK((min_length >= 0), "lengths contains negative value!");
TORCH_CHECK(min_length != 0 || initial.has_value());
TORCH_CHECK(lengths_value.sum().item<int64_t>() == data.size(axis));
}

View File

@ -126,7 +126,8 @@ __global__ void segment_reduce_forward_kernel(
// ===== step3: finalize reduction
CUDA_KERNEL_ASSERT(lengths_data[row_id] >= 0);
if (lengths_data[row_id] == 0 && !is_initial_set) {
if (lengths_data[row_id] == 0 && !is_initial_set &&
reduction == SegmentReductionType::MEAN) {
initial_value = static_cast<scalar_t>(NAN);
} else if (
reduction == SegmentReductionType::MEAN && lengths_data[row_id] > 0 &&

View File

@ -16,6 +16,19 @@ from torch.testing._internal.common_utils import (
reductions = ["max", "mean", "min", "sum"]
def get_default_value(initial_value, reduction):
if initial_value is not None:
return initial_value
if reduction == "max":
return -float("Inf")
elif reduction == "mean":
return float("nan")
elif reduction == "min":
return float("Inf")
elif reduction == "sum":
return 0.0
class TestSegmentReductions(TestCase):
def _test_common(
self,
@ -98,28 +111,29 @@ class TestSegmentReductions(TestCase):
val_dtype, length_type = dtypes
lengths = [1, 2, 3, 0]
data = [1, float("nan"), 3, 4, 5, 5]
check_backward = True
for reduction in reductions:
for initial in [0, None]:
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "max":
initial_value = 0
expected_result = [1, float("nan"), 5, initial_value]
expected_result = [1, float("nan"), 5, default_value]
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
elif reduction == "mean":
initial_value = 0
expected_result = [1, float("nan"), 4.666, initial_value]
expected_result = [1, float("nan"), 4.666, default_value]
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
elif reduction == "min":
if initial is not None:
initial_value = 1000 # some high number
expected_result = [1, float("nan"), 4, initial_value]
default_value = get_default_value(initial_value, reduction)
expected_result = [1, float("nan"), 4, default_value]
expected_grad = [1.0, 1.0, 0, 1, 0, 0]
elif reduction == "sum":
initial_value = 0
expected_result = [1, float("nan"), 14, initial_value]
expected_result = [1, float("nan"), 14, default_value]
expected_grad = [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]
for axis in [0, -1]:
for unsafe in [True, False]:
for initial in [initial_value, None]:
self._test_common(
reduction,
device,
@ -143,19 +157,21 @@ class TestSegmentReductions(TestCase):
)
def test_multi_d_simple(self, device, dtypes):
val_dtype, length_type = dtypes
check_backward = True
axis = 0
lengths = [1, 2, 3, 0]
data = [[1, 1], [float("nan"), 1], [3, float("nan")], [4, 1], [3, 2], [2, 3]]
for reduction in reductions:
for initial in [0, None]:
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "max":
initial_value = 0
expected_result = [
[1, 1],
[float("nan"), float("nan")],
[4, 3],
[initial_value, initial_value],
[default_value, default_value],
]
expected_grad = [
[1, 1],
@ -166,12 +182,11 @@ class TestSegmentReductions(TestCase):
[0, 1],
]
elif reduction == "mean":
initial_value = 0
expected_result = [
[1, 1],
[float("nan"), float("nan")],
[3, 2],
[initial_value, initial_value],
[default_value, default_value],
]
expected_grad = [
[1.0, 1.0],
@ -182,12 +197,14 @@ class TestSegmentReductions(TestCase):
[0.333, 0.333],
]
elif reduction == "min":
if initial is not None:
initial_value = 1000 # some high number
default_value = get_default_value(initial_value, reduction)
expected_result = [
[1, 1],
[float("nan"), float("nan")],
[2, 1],
[initial_value, initial_value],
[default_value, default_value],
]
expected_grad = [
[1.0, 1.0],
@ -198,12 +215,11 @@ class TestSegmentReductions(TestCase):
[1, 0],
]
elif reduction == "sum":
initial_value = 0
expected_result = [
[1, 1],
[float("nan"), float("nan")],
[9, 6],
[initial_value, initial_value],
[default_value, default_value],
]
expected_grad = [
[1.0, 1.0],
@ -214,7 +230,6 @@ class TestSegmentReductions(TestCase):
[1.0, 1.0],
]
for unsafe in [True, False]:
for initial in [initial_value, None]:
self._test_common(
reduction,
device,
@ -246,14 +261,13 @@ class TestSegmentReductions(TestCase):
check_backward = False
for reduction in reductions:
if reduction == "max":
initial_value = 0
if reduction == "max":
expected_result = [
np.full((2, 5), initial_value).tolist(),
np.max(data, axis=0).tolist(),
]
elif reduction == "mean":
initial_value = 0
expected_result = [
np.full((2, 5), initial_value).tolist(),
np.mean(data, axis=0).tolist(),
@ -265,13 +279,11 @@ class TestSegmentReductions(TestCase):
np.min(data, axis=0).tolist(),
]
elif reduction == "sum":
initial_value = 0
expected_result = [
np.full((2, 5), initial_value).tolist(),
np.sum(data, axis=0).tolist(),
]
for unsafe in [True, False]:
for initial in [initial_value, None]:
self._test_common(
reduction,
device,