OpInfo for *_like functions (#65941)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65941

OpInfos for: empty_like, zeros_like, ones_like, full_like, randn_like

Test Plan: - run tests

Reviewed By: dagitses

Differential Revision: D31452625

Pulled By: zou3519

fbshipit-source-id: 5e6c45918694853f9252488d62bb7f4ccfa1f1e4
This commit is contained in:
Richard Zou
2021-10-14 09:11:42 -07:00
committed by Facebook GitHub Bot
parent 5d4452937d
commit d810e738b9
4 changed files with 104 additions and 0 deletions

View File

@ -66,6 +66,10 @@ ignore_missing_imports = True
[mypy-test_torch]
check_untyped_defs = False
# Excluded from mypy due to OpInfos being annoying to type
[mypy-torch.testing._internal.common_methods_invocations.*]
ignore_errors = True
[mypy-torch.testing._internal.hypothesis_utils.*]
ignore_errors = True

View File

@ -3220,6 +3220,11 @@ class TestOperatorSignatures(JitTestCase):
'int',
'long',
'short',
'empty_like',
'ones_like',
'randn_like',
'zeros_like',
'full_like',
'__getitem__',
'__radd__',
'__rsub__',

View File

@ -1500,6 +1500,11 @@ class TestNormalizeOperators(JitTestCase):
'int',
'long',
'short',
'empty_like',
'ones_like',
'randn_like',
'zeros_like',
'full_like',
"__getitem__",
"__radd__",
"__rsub__",

View File

@ -2024,6 +2024,57 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad):
return tuple(samples)
def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
inputs = [
((), {}),
((S, S), {}),
((0, S, 0), {}),
((S,), {'dtype': dtype, 'device': device}),
# Hard-code some dtypes/devices. We want to test cases where the
# (dtype, device) is different from the input's (dtype, device)
((S,), {'dtype': torch.double}),
((S,), {'device': 'cpu'}),
((S,), {'dtype': torch.double, 'device': 'cpu'}),
]
if torch.cuda.is_available():
inputs.append(((S,), {'device': 'cuda'}))
samples = []
for shape, kwargs in inputs:
t = make_tensor(shape, device, dtype,
low=None, high=None,
requires_grad=requires_grad)
samples.append(SampleInput(t, kwargs=kwargs))
return tuple(samples)
def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
def get_val(dtype):
return make_tensor([], 'cpu', dtype).item()
inputs = [
((), get_val(dtype), {}),
((S, S), get_val(dtype), {}),
((0, S, 0), get_val(dtype), {}),
((S,), get_val(dtype), {'dtype': dtype, 'device': device}),
# Hard-code some dtypes/devices. We want to test cases where the
# (dtype, device) is different from the input's (dtype, device)
((S,), get_val(torch.double), {'dtype': torch.double}),
((S,), get_val(dtype), {'device': 'cpu'}),
((S,), get_val(torch.double), {'dtype': torch.double, 'device': 'cpu'}),
]
if torch.cuda.is_available():
inputs.append(((S,), get_val(dtype), {'device': 'cuda'}))
samples = []
for shape, fill_value, kwargs in inputs:
t = make_tensor(shape, device, dtype,
low=None, high=None,
requires_grad=requires_grad)
samples.append(SampleInput(t, args=(fill_value,), kwargs=kwargs))
return tuple(samples)
def sample_inputs_logcumsumexp(self, device, dtype, requires_grad):
inputs = (
((S, S, S), 0),
@ -9656,6 +9707,45 @@ op_db: List[OpInfo] = [
# RuntimeError: attribute lookup is not defined on builtin
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
)),
OpInfo('empty_like',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_out=False,
sample_inputs_func=sample_inputs_like_fns,
supports_autograd=False,
skips=(
# Empty tensor data is garbage so it's hard to make comparisons with it.
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
# Empty tensor data is garbage so it's hard to make comparisons with it.
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
# Empty tensor data is garbage so it's hard to make comparisons with it.
DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
)),
OpInfo('zeros_like',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_out=False,
sample_inputs_func=sample_inputs_like_fns,
supports_autograd=False),
OpInfo('ones_like',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_out=False,
sample_inputs_func=sample_inputs_like_fns,
supports_autograd=False),
OpInfo('randn_like',
dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex64, torch.complex128),
op=lambda inp, *args, **kwargs:
wrapper_set_seed(torch.randn_like, inp, *args, **kwargs),
supports_out=False,
sample_inputs_func=sample_inputs_like_fns,
supports_autograd=False,
skips=(
# AssertionError: JIT Test does not execute any logic
DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
)),
OpInfo('full_like',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
supports_out=False,
sample_inputs_func=sample_inputs_full_like,
supports_autograd=False),
OpInfo('scatter_add',
dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
sample_inputs_func=sample_inputs_scatter_add,