Add torch.nn.functional.threshold ref (#79808)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79808
Approved by: https://github.com/mruberry

commit f72d867b70 (parent ab8797d69b), committed by PyTorch MergeBot
@@ -20,7 +20,7 @@ from torch._refs import (
     _make_elementwise_binary_reference,
 )
 
-from typing import Optional
+from typing import Optional, Union
 
 __all__ = [
     "celu",
@@ -36,6 +36,7 @@ __all__ = [
     "softplus",
     "softshrink",
     "tanhshrink",
+    "threshold",
 ]
 
 Tensor = torch.Tensor
@@ -357,6 +358,27 @@ def tanhshrink(a: TensorLikeType) -> TensorLikeType:
     return refs.sub(a, refs.tanh(a))
 
 
+@register_decomposition(torch.ops.aten.threshold)
+@elementwise_type_promotion_wrapper(
+    type_promoting_args=("a",),
+    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
+)
+def threshold(
+    a: TensorLikeType,
+    threshold: NumberType,
+    value: Union[bool, int, float],
+    inplace: bool = False,
+) -> TensorLikeType:
+    """
+    Reference implementation of torch.nn.functional.threshold
+    """
+
+    if inplace:
+        raise NotImplementedError
+
+    return torch.where(a <= threshold, value, a)
+
+
 @register_decomposition(torch.ops.aten.hardtanh)
 @elementwise_unary_scalar_wrapper
 @elementwise_type_promotion_wrapper(
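The reference keeps every element strictly greater than `threshold` and replaces the rest with `value`, mirroring torch.nn.functional.threshold. A minimal standalone sketch of the same semantics, written here for illustration with made-up inputs (plain torch.where, outside the _refs decorator machinery):

    import torch

    def threshold_sketch(a: torch.Tensor, threshold: float, value: float) -> torch.Tensor:
        # Elements with a <= threshold become `value`; the rest pass through.
        # NaN <= threshold evaluates to False, so NaNs are kept, not replaced.
        return torch.where(a <= threshold, value, a)

    a = torch.tensor([-1.0, 0.123, 0.5, float("nan")])
    print(threshold_sketch(a, 0.123, -9.0))
    # tensor([-9.0000, -9.0000,  0.5000,     nan])
    print(torch.nn.functional.threshold(a, 0.123, -9.0))  # same result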
@@ -14732,18 +14732,22 @@ op_db: List[OpInfo] = [
                        active_if=(IS_MACOS or IS_WINDOWS)),
         ),
     ),
-    OpInfo(
+    UnaryUfuncInfo(
         'nn.functional.threshold',
         aten_backward_name='threshold_backward',
-        ref=lambda x, threshold, value: np.where(x > threshold, x, value).astype(x.dtype),
+        ref=lambda x, threshold, value: np.where(x <= threshold, value, x).astype(x.dtype),
         dtypes=all_types_and(torch.bfloat16),
         dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
         supports_autograd=True,
         supports_forward_ad=True,
         supports_fwgrad_bwgrad=True,
         assert_autodiffed=False,
         supports_gradgrad=True,
         supports_out=False,
         sample_kwargs=lambda device, dtype, input: ({'threshold': 0.123,
                                                      'value': -9},
                                                     {'threshold': 0.123,
                                                      'value': -9}),
         # TODO(whc) should not need sample_inputs_func, but without it
         # kwargs aren't being hooked up properly
         sample_inputs_func=sample_inputs_threshold,
     ),
     OpInfo(
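Note the flipped comparison in the NumPy ref. For ordinary values the two lambdas are equivalent: np.where(x > t, x, v) and np.where(x <= t, v, x) select the same branch everywhere. They differ only at NaN, where both comparisons are False: the old form substituted `value`, while the new form keeps the NaN, which is what the torch.where-based reference above does. Presumably that mismatch is what motivated the change; a small NumPy illustration with made-up values:

    import numpy as np

    x = np.array([0.5, np.nan], dtype=np.float32)
    old_ref = np.where(x > 0.123, x, -9)    # NaN > 0.123 is False -> -9
    new_ref = np.where(x <= 0.123, -9, x)   # NaN <= 0.123 is False -> NaN kept
    print(old_ref)  # [ 0.5 -9. ]
    print(new_ref)  # [0.5 nan]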
@@ -20333,6 +20337,11 @@
         "_refs.nn.functional.celu",
         torch_opinfo_name="nn.functional.celu",
     ),
+    ElementwiseUnaryPythonRefInfo(
+        "_refs.nn.functional.threshold",
+        torch_opinfo_name="nn.functional.threshold",
+        supports_nvfuser=False,
+    ),
     PythonRefInfo(
         "_refs.nn.functional.dropout",
         torch_opinfo_name="nn.functional.dropout",
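Registering the ElementwiseUnaryPythonRefInfo is what opts the new reference into the generic ref-consistency tests, which compare _refs.nn.functional.threshold against nn.functional.threshold over the OpInfo samples. Stripped of the harness, the check amounts to something like this hand-written sketch (not the actual test code):

    import torch
    import torch._refs.nn.functional as refs_nnf

    a = torch.randn(4, 4)
    expected = torch.nn.functional.threshold(a, threshold=0.123, value=-9.0)
    actual = refs_nnf.threshold(a, threshold=0.123, value=-9.0)
    torch.testing.assert_close(actual, expected)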