mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[functorch] More OpInfos
This commit is contained in:
@ -35,7 +35,10 @@ from torch.testing._internal.common_utils import \
|
||||
torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
|
||||
GRADCHECK_NONDET_TOL,)
|
||||
import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||
from torch.testing._internal.common_methods_invocations import OpInfo, SkipInfo, SampleInput
|
||||
from torch.testing._internal.common_methods_invocations import (
|
||||
OpInfo, SkipInfo, SampleInput, sample_inputs_hardshrink_hardtanh,
|
||||
sample_inputs_log_softmax,
|
||||
)
|
||||
|
||||
# List of OpInfos that aren't in PyTorch Core yet.
|
||||
# They are here because we wanted a fast way of writing OpInfos and may not be
|
||||
@ -207,3 +210,55 @@ additional_op_db.extend([
|
||||
),
|
||||
supports_out=False),
|
||||
])
|
||||
|
||||
# https://github.com/pytorch/pytorch/pull/61067
|
||||
additional_op_db.extend([
|
||||
OpInfo('nn.functional.relu',
|
||||
aten_name="relu",
|
||||
dtypes=all_types(),
|
||||
dtypesIfCPU=all_types_and(torch.bfloat16),
|
||||
backward_dtypesIfCPU=floating_types(),
|
||||
dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
|
||||
backward_dtypesIfCUDA=floating_types_and(torch.float16),
|
||||
supports_autograd=True,
|
||||
assert_autodiffed=True,
|
||||
sample_inputs_func=sample_inputs_hardshrink_hardtanh,
|
||||
supports_gradgrad=True,
|
||||
supports_out=False,
|
||||
autodiff_nonfusible_nodes=["aten::relu"]),
|
||||
OpInfo(
|
||||
'softmax',
|
||||
supports_out=False,
|
||||
dtypes=floating_types_and(torch.bfloat16),
|
||||
dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
|
||||
sample_inputs_func=sample_inputs_log_softmax,
|
||||
assert_autodiffed=True),
|
||||
])
|
||||
|
||||
# https://github.com/pytorch/pytorch/pull/61068
|
||||
def sample_inputs_dropout(self, device, dtype, requires_grad):
|
||||
samples = []
|
||||
dropout_args = [
|
||||
(0.6, False, False),
|
||||
(1.0, True, False),
|
||||
(0.0, True, False)
|
||||
]
|
||||
shapes = [(), (2,), (2, 3, 4), (2, 3, 4, 5, 6)]
|
||||
for rank in [1, 3, 5]:
|
||||
for shape in shapes:
|
||||
for args in dropout_args:
|
||||
samples.append(SampleInput(make_tensor(shape, device=device, dtype=dtype, requires_grad=requires_grad, low=-5, high=5), args=args))
|
||||
return samples
|
||||
|
||||
additional_op_db.extend([
|
||||
OpInfo('nn.functional.dropout',
|
||||
aten_name="dropout",
|
||||
supports_autograd=True,
|
||||
assert_autodiffed=True,
|
||||
sample_inputs_func=sample_inputs_dropout,
|
||||
dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
|
||||
supports_gradgrad=False,
|
||||
supports_forward_ad=True,
|
||||
supports_out=False,
|
||||
autodiff_nonfusible_nodes=["aten::dropout"]),
|
||||
])
|
||||
|
Reference in New Issue
Block a user