[functorch] More OpInfos

Richard Zou
2021-07-21 10:55:30 -07:00
committed by Jon Janzen
parent 6b59f1ad78
commit f28e199609


@@ -35,7 +35,10 @@ from torch.testing._internal.common_utils import \
    torch_to_numpy_dtype_dict, TEST_WITH_ASAN,
    GRADCHECK_NONDET_TOL,)
import torch.testing._internal.opinfo_helper as opinfo_helper
-from torch.testing._internal.common_methods_invocations import OpInfo, SkipInfo, SampleInput
+from torch.testing._internal.common_methods_invocations import (
+    OpInfo, SkipInfo, SampleInput, sample_inputs_hardshrink_hardtanh,
+    sample_inputs_log_softmax,
+)
# List of OpInfos that aren't in PyTorch Core yet.
# They are here because we wanted a fast way of writing OpInfos and may not be
# 100% correct
@@ -207,3 +210,55 @@ additional_op_db.extend([
),
supports_out=False),
])
# https://github.com/pytorch/pytorch/pull/61067
additional_op_db.extend([
    OpInfo('nn.functional.relu',
           aten_name="relu",
           dtypes=all_types(),
           dtypesIfCPU=all_types_and(torch.bfloat16),
           backward_dtypesIfCPU=floating_types(),
           dtypesIfCUDA=all_types_and(torch.float16, torch.bfloat16),
           backward_dtypesIfCUDA=floating_types_and(torch.float16),
           supports_autograd=True,
           assert_autodiffed=True,
           sample_inputs_func=sample_inputs_hardshrink_hardtanh,
           supports_gradgrad=True,
           supports_out=False,
           autodiff_nonfusible_nodes=["aten::relu"]),
    OpInfo(
        'softmax',
        supports_out=False,
        dtypes=floating_types_and(torch.bfloat16),
        dtypesIfCUDA=floating_types_and(torch.float16, torch.bfloat16),
        sample_inputs_func=sample_inputs_log_softmax,
        assert_autodiffed=True),
])
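
The softmax entry above borrows sample_inputs_log_softmax, which works because softmax and log_softmax share the (input, dim) calling convention. A minimal illustration of that equivalence (an editor's sketch, not part of this commit):

# Editor's sketch: softmax and log_softmax take the same (input, dim)
# arguments, so the log_softmax sampler can be reused for softmax verbatim.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3)
dim = 1
# log_softmax(x, dim) agrees with log(softmax(x, dim)) up to rounding.
assert torch.allclose(F.log_softmax(x, dim), F.softmax(x, dim).log())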
# https://github.com/pytorch/pytorch/pull/61068
def sample_inputs_dropout(self, device, dtype, requires_grad):
    samples = []
    # Each tuple binds positionally to dropout's (p, training, inplace).
    dropout_args = [
        (0.6, False, False),
        (1.0, True, False),
        (0.0, True, False),
    ]
    shapes = [(), (2,), (2, 3, 4), (2, 3, 4, 5, 6)]
    for shape in shapes:
        for args in dropout_args:
            samples.append(SampleInput(
                make_tensor(shape, device=device, dtype=dtype,
                            requires_grad=requires_grad, low=-5, high=5),
                args=args))
    return samples
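
Each dropout_args tuple maps positionally onto nn.functional.dropout(input, p, training, inplace). A worked sample, applied by hand (an editor's sketch, not part of this commit):

# Editor's sketch: apply the first dropout_args tuple manually.
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4, requires_grad=True)
p, training, inplace = 0.6, False, False  # first entry of dropout_args
out = F.dropout(x, p, training, inplace)
# With training=False, dropout is the identity, whatever p is.
assert torch.equal(out, x)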
additional_op_db.extend([
    OpInfo('nn.functional.dropout',
           aten_name="dropout",
           supports_autograd=True,
           assert_autodiffed=True,
           sample_inputs_func=sample_inputs_dropout,
           dtypesIfCUDA=floating_types_and(torch.half, torch.bfloat16),
           supports_gradgrad=False,
           supports_forward_ad=True,
           supports_out=False,
           autodiff_nonfusible_nodes=["aten::dropout"]),
])
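
For context, OpInfo databases like additional_op_db are normally consumed through the @ops decorator from torch.testing's device-type harness. A minimal sketch of such a consumer (an editor's assumption, not part of this commit; TestAdditional and test_runs are hypothetical names):

# Editor's sketch: a minimal OpInfo-driven test over additional_op_db.
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests, ops)

class TestAdditional(TestCase):
    @ops(additional_op_db, allowed_dtypes=(torch.float32,))
    def test_runs(self, device, dtype, op):
        # Run every sample through the op's function variant;
        # OpInfo.__call__ dispatches to the underlying function.
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            op(sample.input, *sample.args, **sample.kwargs)

instantiate_device_type_tests(TestAdditional, globals(), only_for='cpu')

if __name__ == '__main__':
    run_tests()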