Fix BC-breaking change introduced by #91499 (#93091)

This fixes BC-breaking changes introduced by https://github.com/pytorch/pytorch/pull/91499
Make enum accept both `min` and `amin` values
Reinstante testing

To reiterate
454361435c/torch/masked/_ops.py (L786)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93091
Approved by: https://github.com/ngimel
This commit is contained in:
Nikita Shulga
2023-01-27 03:58:32 +00:00
committed by PyTorch MergeBot
parent 7fade4f771
commit 661800a2cf
3 changed files with 8 additions and 8 deletions

View File

@ -18,17 +18,17 @@ from torch.testing._internal.common_utils import (
)
reductions = ["amax", "mean", "amin", "sum", "prod"]
reductions = ["max", "mean", "min", "sum", "prod"]
def get_default_value(initial_value, reduction):
if initial_value is not None:
return initial_value
if reduction == "amax":
if reduction == "max":
return -float("Inf")
elif reduction == "mean":
return float("nan")
elif reduction == "amin":
elif reduction == "min":
return float("Inf")
elif reduction == "sum":
return 0.0
@ -133,13 +133,13 @@ class TestSegmentReductions(TestCase):
check_backward = True if initial is not None else False
initial_value = initial
default_value = get_default_value(initial_value, reduction)
if reduction == "amax":
if reduction == "max":
expected_result = [1, float("nan"), 5, default_value]
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
elif reduction == "mean":
expected_result = [1, float("nan"), 4.666, default_value]
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
elif reduction == "amin":
elif reduction == "min":
if initial is not None:
initial_value = 1000 # some high number
default_value = get_default_value(initial_value, reduction)