Add torch.nn.functional.threshold ref (#79808)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79808
Approved by: https://github.com/mruberry
This commit is contained in:
Will Constable
2022-06-27 20:30:41 +00:00
committed by PyTorch MergeBot
parent ab8797d69b
commit f72d867b70
2 changed files with 36 additions and 5 deletions

View File

@ -20,7 +20,7 @@ from torch._refs import (
_make_elementwise_binary_reference,
)
from typing import Optional
from typing import Optional, Union
__all__ = [
"celu",
@ -36,6 +36,7 @@ __all__ = [
"softplus",
"softshrink",
"tanhshrink",
"threshold",
]
Tensor = torch.Tensor
@ -357,6 +358,27 @@ def tanhshrink(a: TensorLikeType) -> TensorLikeType:
return refs.sub(a, refs.tanh(a))
@register_decomposition(torch.ops.aten.threshold)
@elementwise_type_promotion_wrapper(
type_promoting_args=("a",),
type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
)
def threshold(
a: TensorLikeType,
threshold: NumberType,
value: Union[bool, int, float],
inplace: bool = False,
) -> TensorLikeType:
"""
Reference implementation of torch.nn.functional.threshold
"""
if inplace:
raise NotImplementedError
return torch.where(a <= threshold, value, a)
@register_decomposition(torch.ops.aten.hardtanh)
@elementwise_unary_scalar_wrapper
@elementwise_type_promotion_wrapper(

View File

@ -14732,18 +14732,22 @@ op_db: List[OpInfo] = [
active_if=(IS_MACOS or IS_WINDOWS)),
),
),
OpInfo(
UnaryUfuncInfo(
'nn.functional.threshold',
aten_backward_name='threshold_backward',
ref=lambda x, threshold, value: np.where(x > threshold, x, value).astype(x.dtype),
ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype),
dtypes=all_types_and(torch.bfloat16),
dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
supports_autograd=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
assert_autodiffed=False,
supports_gradgrad=True,
supports_out=False,
sample_kwargs=lambda device, dtype, input: ({'threshold': 0.123,
'value': -9},
{'threshold': 0.123,
'value': -9}),
# TODO(whc) should not need sample_inputs_func, but without it
# kwargs aren't being hooked up properly
sample_inputs_func=sample_inputs_threshold,
),
OpInfo(
@ -20333,6 +20337,11 @@ python_ref_db = [
"_refs.nn.functional.celu",
torch_opinfo_name="nn.functional.celu",
),
ElementwiseUnaryPythonRefInfo(
"_refs.nn.functional.threshold",
torch_opinfo_name="nn.functional.threshold",
supports_nvfuser=False,
),
PythonRefInfo(
"_refs.nn.functional.dropout",
torch_opinfo_name="nn.functional.dropout",