mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
This fixes BC-breaking changes introduced by https://github.com/pytorch/pytorch/pull/91499
Make enum accept both `min` and `amin` values
Reinstante testing
To reiterate
454361435c/torch/masked/_ops.py (L786)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93091
Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
7fade4f771
commit
661800a2cf
@ -18,17 +18,17 @@ from torch.testing._internal.common_utils import (
|
||||
)
|
||||
|
||||
|
||||
reductions = ["amax", "mean", "amin", "sum", "prod"]
|
||||
reductions = ["max", "mean", "min", "sum", "prod"]
|
||||
|
||||
|
||||
def get_default_value(initial_value, reduction):
|
||||
if initial_value is not None:
|
||||
return initial_value
|
||||
if reduction == "amax":
|
||||
if reduction == "max":
|
||||
return -float("Inf")
|
||||
elif reduction == "mean":
|
||||
return float("nan")
|
||||
elif reduction == "amin":
|
||||
elif reduction == "min":
|
||||
return float("Inf")
|
||||
elif reduction == "sum":
|
||||
return 0.0
|
||||
@ -133,13 +133,13 @@ class TestSegmentReductions(TestCase):
|
||||
check_backward = True if initial is not None else False
|
||||
initial_value = initial
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
if reduction == "amax":
|
||||
if reduction == "max":
|
||||
expected_result = [1, float("nan"), 5, default_value]
|
||||
expected_grad = [1, 1, 0, 0, 0.5, 0.5]
|
||||
elif reduction == "mean":
|
||||
expected_result = [1, float("nan"), 4.666, default_value]
|
||||
expected_grad = [1.0, 0.5, 0.5, 0.333, 0.333, 0.333]
|
||||
elif reduction == "amin":
|
||||
elif reduction == "min":
|
||||
if initial is not None:
|
||||
initial_value = 1000 # some high number
|
||||
default_value = get_default_value(initial_value, reduction)
|
||||
|
Reference in New Issue
Block a user