OpInfo for *_like functions (#65941)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65941

OpInfos for: empty_like, zeros_like, ones_like, full_like, randn_like

Test Plan:
- run tests

Reviewed By: dagitses

Differential Revision: D31452625

Pulled By: zou3519

fbshipit-source-id: 5e6c45918694853f9252488d62bb7f4ccfa1f1e4
commit d810e738b9 (parent 5d4452937d)
committed by Facebook GitHub Bot
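For orientation, every op covered here shares the *_like signature: it takes a reference tensor and returns a new tensor of the same shape, with dtype and device optionally overridden via keyword arguments (full_like additionally takes a fill value). A minimal illustration using only the public torch API (not part of the diff):

    import torch

    ref = torch.randn(2, 3)                        # reference tensor
    z = torch.zeros_like(ref)                      # same shape/dtype/device, zero-filled
    o = torch.ones_like(ref, dtype=torch.float64)  # dtype overridden, shape kept
    f = torch.full_like(ref, 7.0)                  # extra fill_value argument
    e = torch.empty_like(ref)                      # allocated but uninitialized data
    r = torch.randn_like(ref)                      # i.i.d. standard-normal entries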
mypy.ini (4 changed lines)
@@ -66,6 +66,10 @@ ignore_missing_imports = True
[mypy-test_torch]
check_untyped_defs = False

# Excluded from mypy due to OpInfos being annoying to type
[mypy-torch.testing._internal.common_methods_invocations.*]
ignore_errors = True

[mypy-torch.testing._internal.hypothesis_utils.*]
ignore_errors = True
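These per-module sections take effect whenever mypy runs with this config (for example, mypy --config-file mypy.ini). With ignore_errors = True, mypy still parses the matched modules but suppresses every reported error, which is the intended treatment for the OpInfo file.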
@@ -3220,6 +3220,11 @@ class TestOperatorSignatures(JitTestCase):
            'int',
            'long',
            'short',
            'empty_like',
            'ones_like',
            'randn_like',
            'zeros_like',
            'full_like',
            '__getitem__',
            '__radd__',
            '__rsub__',
@@ -1500,6 +1500,11 @@ class TestNormalizeOperators(JitTestCase):
            'int',
            'long',
            'short',
            'empty_like',
            'ones_like',
            'randn_like',
            'zeros_like',
            'full_like',
            "__getitem__",
            "__radd__",
            "__rsub__",
@@ -2024,6 +2024,57 @@ def sample_inputs_logsumexp(self, device, dtype, requires_grad):
    return tuple(samples)

def sample_inputs_like_fns(self, device, dtype, requires_grad, **kwargs):
    inputs = [
        ((), {}),
        ((S, S), {}),
        ((0, S, 0), {}),
        ((S,), {'dtype': dtype, 'device': device}),
        # Hard-code some dtypes/devices. We want to test cases where the
        # (dtype, device) is different from the input's (dtype, device)
        ((S,), {'dtype': torch.double}),
        ((S,), {'device': 'cpu'}),
        ((S,), {'dtype': torch.double, 'device': 'cpu'}),
    ]
    if torch.cuda.is_available():
        inputs.append(((S,), {'device': 'cuda'}))

    samples = []
    for shape, kwargs in inputs:
        t = make_tensor(shape, device, dtype,
                        low=None, high=None,
                        requires_grad=requires_grad)
        samples.append(SampleInput(t, kwargs=kwargs))

    return tuple(samples)

def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
    def get_val(dtype):
        return make_tensor([], 'cpu', dtype).item()

    inputs = [
        ((), get_val(dtype), {}),
        ((S, S), get_val(dtype), {}),
        ((0, S, 0), get_val(dtype), {}),
        ((S,), get_val(dtype), {'dtype': dtype, 'device': device}),
        # Hard-code some dtypes/devices. We want to test cases where the
        # (dtype, device) is different from the input's (dtype, device)
        ((S,), get_val(torch.double), {'dtype': torch.double}),
        ((S,), get_val(dtype), {'device': 'cpu'}),
        ((S,), get_val(torch.double), {'dtype': torch.double, 'device': 'cpu'}),
    ]
    if torch.cuda.is_available():
        inputs.append(((S,), get_val(dtype), {'device': 'cuda'}))

    samples = []
    for shape, fill_value, kwargs in inputs:
        t = make_tensor(shape, device, dtype,
                        low=None, high=None,
                        requires_grad=requires_grad)
        samples.append(SampleInput(t, args=(fill_value,), kwargs=kwargs))

    return tuple(samples)

def sample_inputs_logcumsumexp(self, device, dtype, requires_grad):
    inputs = (
        ((S, S, S), 0),
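To make the role of these helpers concrete, the sketch below imitates how a generated test consumes such samples: build a few (shape, kwargs) cases, call the op, and check the metadata contract that *_like ops inherit shape from the input while honoring dtype overrides. This harness is an assumption for illustration, not PyTorch's actual test machinery:

    import torch

    S = 5  # small size, mirroring the S constant used in the samples above

    def check_like_op(op, device='cpu', dtype=torch.float32):
        cases = [((), {}), ((S, S), {}), ((S,), {'dtype': torch.double})]
        for shape, kwargs in cases:
            t = torch.empty(shape, device=device, dtype=dtype)
            out = op(t, **kwargs)
            assert out.shape == t.shape                       # shape inherited
            assert out.dtype == kwargs.get('dtype', t.dtype)  # override respected

    check_like_op(torch.zeros_like)
    check_like_op(torch.ones_like)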
@@ -9656,6 +9707,45 @@ op_db: List[OpInfo] = [
               # RuntimeError: attribute lookup is not defined on builtin
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
           )),
    OpInfo('empty_like',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           supports_out=False,
           sample_inputs_func=sample_inputs_like_fns,
           supports_autograd=False,
           skips=(
               # Empty tensor data is garbage so it's hard to make comparisons with it.
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
               # Empty tensor data is garbage so it's hard to make comparisons with it.
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_conj_view'),
               # Empty tensor data is garbage so it's hard to make comparisons with it.
               DecorateInfo(unittest.skip("Skipped!"), 'TestMathBits', 'test_neg_view'),
           )),
    OpInfo('zeros_like',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           supports_out=False,
           sample_inputs_func=sample_inputs_like_fns,
           supports_autograd=False),
    OpInfo('ones_like',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           supports_out=False,
           sample_inputs_func=sample_inputs_like_fns,
           supports_autograd=False),
    OpInfo('randn_like',
           dtypes=floating_types_and(torch.half, torch.bfloat16, torch.complex64, torch.complex128),
           op=lambda inp, *args, **kwargs:
               wrapper_set_seed(torch.randn_like, inp, *args, **kwargs),
           supports_out=False,
           sample_inputs_func=sample_inputs_like_fns,
           supports_autograd=False,
           skips=(
               # AssertionError: JIT Test does not execute any logic
               DecorateInfo(unittest.skip("Skipped!"), 'TestJit', 'test_variant_consistency_jit'),
           )),
    OpInfo('full_like',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           supports_out=False,
           sample_inputs_func=sample_inputs_full_like,
           supports_autograd=False),
    OpInfo('scatter_add',
           dtypes=all_types_and_complex_and(torch.bool, torch.half, torch.bfloat16),
           sample_inputs_func=sample_inputs_scatter_add,
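The wrapper_set_seed indirection on randn_like exists because the op is random: two invocations only produce comparable outputs if the RNG is reset first. A plausible sketch of such a wrapper (the real helper lives in torch.testing._internal; the seed value here is an assumption):

    import torch

    def wrapper_set_seed(op, *args, **kwargs):
        torch.manual_seed(42)  # reset RNG so repeated calls match (seed is illustrative)
        return op(*args, **kwargs)

    a = wrapper_set_seed(torch.randn_like, torch.empty(3))
    b = wrapper_set_seed(torch.randn_like, torch.empty(3))
    assert torch.equal(a, b)   # identical draws thanks to the fixed seed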