Adds python ref consistency test, elementwise unary reference inputs, and formats test files

Per title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76626
Approved by: https://github.com/ngimel
Mike Ruberry
2022-05-01 22:42:46 +00:00
committed by PyTorch MergeBot
parent 33be4c94c0
commit f6bbecf8b5
12 changed files with 3333 additions and 1682 deletions
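The headline change is a new consistency test: every experimental Python reference is run against the torch operator it re-implements, over the op's reference inputs. As a rough standalone sketch of that comparison (ref_add below is a hypothetical stand-in for a torch._refs-style reference; the real test also compares strides, device, and layout exactly):

import torch

def ref_add(a, b):
    # A "Python reference": the operator re-expressed in terms of other torch calls.
    return torch.add(a, b)

for shape in [(3,), (2, 3), (1, 5, 1)]:
    a, b = torch.randn(shape), torch.randn(shape)
    # The new test makes exactly this kind of actual-vs-expected comparison.
    torch.testing.assert_close(ref_add(a, b), torch.add(a, b))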

File diff suppressed because it is too large

View File

@@ -8,17 +8,44 @@ import itertools
 import torch
 from torch.testing import make_tensor
-from torch.testing._internal.common_dtype import floating_and_complex_types_and, all_types_and_complex_and
-from torch.testing._internal.common_utils import \
-    (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper,
-     IS_IN_CI, suppress_warnings, noncontiguous_like,
-     TEST_WITH_ASAN, IS_WINDOWS, IS_FBCODE, first_sample)
-from torch.testing._internal.common_methods_invocations import \
-    (op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, ops_and_refs)
-from torch.testing._internal.common_device_type import \
-    (deviceCountAtLeast, instantiate_device_type_tests, ops,
-     onlyCUDA, onlyNativeDeviceTypes, OpDTypes, skipMeta)
-# import torch._prims as prims
+from torch.testing._internal.common_dtype import (
+    floating_and_complex_types_and,
+    all_types_and_complex_and,
+)
+from torch.testing._internal.common_utils import (
+    TestCase,
+    is_iterable_of_tensors,
+    run_tests,
+    IS_SANDCASTLE,
+    clone_input_helper,
+    IS_IN_CI,
+    suppress_warnings,
+    noncontiguous_like,
+    TEST_WITH_ASAN,
+    IS_WINDOWS,
+    IS_FBCODE,
+    first_sample,
+)
+from torch.testing._internal.common_methods_invocations import (
+    op_db,
+    _NOTHING,
+    UnaryUfuncInfo,
+    ReductionOpInfo,
+    SpectralFuncInfo,
+    ops_and_refs,
+    python_ref_db,
+    BinaryUfuncInfo,
+)
+from torch.testing._internal.common_device_type import (
+    deviceCountAtLeast,
+    instantiate_device_type_tests,
+    ops,
+    onlyCUDA,
+    onlyNativeDeviceTypes,
+    OpDTypes,
+    skipMeta,
+)
+import torch._prims as prims
 import torch.testing._internal.opinfo_helper as opinfo_helper
 from torch.testing._internal import composite_compliance
@@ -28,15 +55,25 @@ torch.set_default_dtype(torch.float32)
 # variant testing is only done with torch.float and torch.cfloat to avoid
 # excessive test times and maximize signal to noise ratio
-_variant_ops = partial(ops, dtypes=OpDTypes.supported,
-                       allowed_dtypes=(torch.float, torch.cfloat))
+_variant_ops = partial(
+    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
+)
 
 # Get names of all the operators which have ref in their entry in OpInfo (testing infra)
-# except for Unary Ufuncs (separately implemented in test/test_unary_ufuncs.py)
+# except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
+# elementwise binary operators (separately implemented in test_binary_ufuncs.py),
+# reduction operations (separately implemented in test_reductions.py),
 # and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
-_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo,
-                                                            SpectralFuncInfo)) and op.ref is not None and op.ref is not _NOTHING, op_db))
+_ref_test_ops = tuple(
+    filter(
+        lambda op: not isinstance(
+            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
+        )
+        and op.ref is not None
+        and op.ref is not _NOTHING,
+        op_db,
+    )
+)
 
 # Tests that apply to all operators and aren't related to any particular
 # system
@@ -49,8 +86,10 @@ class TestCommon(TestCase):
         super().tearDownClass()
 
         if IS_IN_CI:
-            err_msg = ("The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
-                       "This is OK for testing, but be sure to set the dtypes manually before landing your PR!")
+            err_msg = (
+                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
+                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
+            )
             # Assure no opinfo entry has dynamic_dtypes
             filtered_ops = list(filter(opinfo_helper.is_dynamic_dtype_set, op_db))
             for op in filtered_ops:
@@ -68,11 +107,16 @@ class TestCommon(TestCase):
         # Check complex32 support only if the op claims.
         # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
         device_type = torch.device(device).type
-        include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device_type) else ())
+        include_complex32 = (
+            (torch.complex32,)
+            if op.supports_dtype(torch.complex32, device_type)
+            else ()
+        )
 
         # dtypes to try to backward in
         allowed_backward_dtypes = floating_and_complex_types_and(
-            *((torch.half, torch.bfloat16) + include_complex32))
+            *((torch.half, torch.bfloat16) + include_complex32)
+        )
 
         # lists for (un)supported dtypes
         supported_dtypes = set()
@@ -86,11 +130,14 @@ class TestCommon(TestCase):
             unsupported_backward_dtypes.add(dtype)
 
         for dtype in all_types_and_complex_and(
-                *((torch.half, torch.bfloat16, torch.bool) + include_complex32)):
+            *((torch.half, torch.bfloat16, torch.bool) + include_complex32)
+        ):
             # tries to acquire samples - failure indicates lack of support
-            requires_grad = (dtype in allowed_backward_dtypes)
+            requires_grad = dtype in allowed_backward_dtypes
             try:
-                samples = tuple(op.sample_inputs(device, dtype, requires_grad=requires_grad))
+                samples = tuple(
+                    op.sample_inputs(device, dtype, requires_grad=requires_grad)
+                )
             except Exception as e:
                 unsupported(dtype)
                 continue
@@ -113,7 +160,9 @@ class TestCommon(TestCase):
                     result = sample.output_process_fn_grad(result)
                     if isinstance(result, torch.Tensor):
                         backward_tensor = result
-                    elif isinstance(result, Sequence) and isinstance(result[0], torch.Tensor):
+                    elif isinstance(result, Sequence) and isinstance(
+                        result[0], torch.Tensor
+                    ):
                         backward_tensor = result[0]
                     else:
                         continue
@@ -130,14 +179,15 @@ class TestCommon(TestCase):
             except Exception as e:
                 unsupported_backward_dtypes.add(dtype)
 
         # Checks that dtypes are listed correctly and generates an informative
         # error message
         supported_forward = supported_dtypes - unsupported_dtypes
         partially_supported_forward = supported_dtypes & unsupported_dtypes
         unsupported_forward = unsupported_dtypes - supported_dtypes
         supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
-        partially_supported_backward = supported_backward_dtypes & unsupported_backward_dtypes
+        partially_supported_backward = (
+            supported_backward_dtypes & unsupported_backward_dtypes
+        )
         unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes
 
         device_type = torch.device(device).type
@@ -156,17 +206,27 @@ class TestCommon(TestCase):
                 op.name, device_type
             )
             if len(partially_supported_forward) > 0:
-                msg = msg + "The following dtypes only worked on some samples during forward: {0}.\n".format(
-                    partially_supported_forward
-                )
+                msg = (
+                    msg
+                    + "The following dtypes only worked on some samples during forward: {0}.\n".format(
+                        partially_supported_forward
+                    )
+                )
            if len(partially_supported_backward) > 0:
-                msg = msg + "The following dtypes only worked on some samples during backward: {0}.\n".format(
-                    partially_supported_backward
-                )
+                msg = (
+                    msg
+                    + "The following dtypes only worked on some samples during backward: {0}.\n".format(
+                        partially_supported_backward
+                    )
+                )
             print(msg)
 
-        if (len(supported_but_unclaimed_forward) + len(claimed_but_unsupported_forward) +
-                len(supported_but_unclaimed_backward) + len(claimed_but_unsupported_backward)) == 0:
+        if (
+            len(supported_but_unclaimed_forward)
+            + len(claimed_but_unsupported_forward)
+            + len(supported_but_unclaimed_backward)
+            + len(claimed_but_unsupported_backward)
+        ) == 0:
             return
 
         # Generates error msg
@@ -174,20 +234,32 @@ class TestCommon(TestCase):
             op.name, device_type
         )
         if len(supported_but_unclaimed_forward) > 0:
-            msg = msg + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
-                supported_but_unclaimed_forward
-            )
+            msg = (
+                msg
+                + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
+                    supported_but_unclaimed_forward
+                )
+            )
        if len(supported_but_unclaimed_backward) > 0:
-            msg = msg + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
-                supported_but_unclaimed_backward
-            )
+            msg = (
+                msg
+                + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
+                    supported_but_unclaimed_backward
+                )
+            )
        if len(claimed_but_unsupported_forward) > 0:
-            msg = msg + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
-                claimed_but_unsupported_forward
-            )
+            msg = (
+                msg
+                + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
+                    claimed_but_unsupported_forward
+                )
+            )
        if len(claimed_but_unsupported_backward) > 0:
-            msg = msg + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
-                claimed_but_unsupported_backward
-            )
+            msg = (
+                msg
+                + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
+                    claimed_but_unsupported_backward
+                )
+            )
 
         self.fail(msg)
@@ -209,7 +281,9 @@ class TestCommon(TestCase):
         elif is_iterable_of_tensors(result):
             self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
         else:
-            self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
+            self.skipTest(
+                "Skipped! Only supports single tensor or iterable of tensor outputs."
+            )
 
     # Tests that the function and its (ndarray-accepting) reference produce the same
     # values on the tensors from sample_inputs func for the corresponding op.
@@ -226,28 +300,48 @@ class TestCommon(TestCase):
             cur_default = torch.get_default_dtype()
             torch.set_default_dtype(torch.double)
             for sample_input in op.reference_inputs(device, dtype):
-                self.compare_with_reference(op, op.ref, sample_input, exact_dtype=(dtype is not torch.long))
+                self.compare_with_reference(
+                    op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
+                )
         finally:
             torch.set_default_dtype(cur_default)
 
-    # Tests that experimental Python References' can propagate shape, dtype,
+    # Tests that experimental Python References can propagate shape, dtype,
     # and device metadata properly.
     # TODO: include stride propagation.
-    # @onlyNativeDeviceTypes
-    # @ops(python_ref_db)
-    # def test_python_reference_meta_functions(self, device, dtype, op):
-    #     def _to_tensormeta(x):
-    #         if isinstance(x, torch.Tensor):
-    #             return prims.utils.TensorMeta(x)
-
-    #     # TODO: iterate over requires_grad true/false
-    #     for sample in op.reference_inputs(device, dtype, requires_grad=False):
-    #         result = op(sample.input, *sample.args, **sample.kwargs)
-
-    #         meta_sample = sample.transform(_to_tensormeta)
-    #         meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
-
-    #         prims.utils.compare_tensor_meta(result, meta_result)
+    @onlyNativeDeviceTypes
+    @ops(python_ref_db)
+    def test_python_reference_meta_functions(self, device, dtype, op):
+        def _to_tensormeta(x):
+            if isinstance(x, torch.Tensor):
+                return prims.utils.TensorMeta(x)
+
+        # TODO: iterate over requires_grad true/false
+        for sample in op.reference_inputs(device, dtype, requires_grad=False):
+            result = op(sample.input, *sample.args, **sample.kwargs)
+
+            meta_sample = sample.transform(_to_tensormeta)
+            meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
+
+            prims.utils.compare_tensor_meta(result, meta_result)
+
+    # Tests that experimental Python References perform the same computation
+    # as the operators they reference.
+    @onlyNativeDeviceTypes
+    @ops(python_ref_db)
+    def test_python_reference_consistency(self, device, dtype, op):
+        for sample in op.reference_inputs(device, dtype, requires_grad=False):
+            actual = op(sample.input, *sample.args, **sample.kwargs)
+            expected = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
+
+            self.assertEqual(
+                actual,
+                expected,
+                exact_stride=True,
+                exact_device=True,
+                exact_layout=True,
+                exact_is_coalesced=True,
+            )
 
     @skipMeta
     @onlyNativeDeviceTypes
@@ -272,9 +366,17 @@ class TestCommon(TestCase):
         test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
         sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
         for sample_input in sample_inputs:
-            t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
+            t_inp, t_args, t_kwargs = (
+                sample_input.input,
+                sample_input.args,
+                sample_input.kwargs,
+            )
             noncontig_sample = sample_input.noncontiguous()
-            n_inp, n_args, n_kwargs = noncontig_sample.input, noncontig_sample.args, noncontig_sample.kwargs
+            n_inp, n_args, n_kwargs = (
+                noncontig_sample.input,
+                noncontig_sample.args,
+                noncontig_sample.kwargs,
+            )
 
             # Verifies sample input tensors should have no grad or history
             sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0]
@@ -300,10 +402,14 @@ class TestCommon(TestCase):
                 grad_for_actual = noncontiguous_like(grad_for_expected)
             elif isinstance(expected, Sequence):
                 # Filter output elements that do not require grad
-                expected = [t for t in expected
-                            if isinstance(t, torch.Tensor) and t.requires_grad]
-                actual = [n for n in actual
-                          if isinstance(n, torch.Tensor) and n.requires_grad]
+                expected = [
+                    t
+                    for t in expected
+                    if isinstance(t, torch.Tensor) and t.requires_grad
+                ]
+                actual = [
+                    n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
+                ]
                 grad_for_expected = [torch.randn_like(t) for t in expected]
                 grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
             else:
@@ -311,19 +417,35 @@ class TestCommon(TestCase):
                 continue
 
             # Concatenate inputs into a tuple
-            t_inputs = (t_inp,) + t_args if isinstance(t_inp, torch.Tensor) else tuple(t_inp) + t_args
-            n_inputs = (n_inp,) + n_args if isinstance(n_inp, torch.Tensor) else tuple(n_inp) + n_args
+            t_inputs = (
+                (t_inp,) + t_args
+                if isinstance(t_inp, torch.Tensor)
+                else tuple(t_inp) + t_args
+            )
+            n_inputs = (
+                (n_inp,) + n_args
+                if isinstance(n_inp, torch.Tensor)
+                else tuple(n_inp) + n_args
+            )
 
             # Filter the elements that are tensors that require grad
-            t_input_tensors = [t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad]
-            n_input_tensors = [n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad]
+            t_input_tensors = [
+                t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
+            ]
+            n_input_tensors = [
+                n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
+            ]
 
             self.assertEqual(len(t_input_tensors), len(n_input_tensors))
 
             # Some functions may not use all the inputs to generate gradients. One of the
             # few examples of this "odd" behaviour is F.hinge_embedding_loss
-            t_grads = torch.autograd.grad(expected, t_input_tensors, grad_for_expected, allow_unused=True)
-            n_grads = torch.autograd.grad(actual, n_input_tensors, grad_for_actual, allow_unused=True)
+            t_grads = torch.autograd.grad(
+                expected, t_input_tensors, grad_for_expected, allow_unused=True
+            )
+            n_grads = torch.autograd.grad(
+                actual, n_input_tensors, grad_for_actual, allow_unused=True
+            )
 
             msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
             for i, (t, n) in enumerate(zip(t_grads, n_grads)):
@@ -339,7 +461,11 @@ class TestCommon(TestCase):
         supported_dtypes = op.supported_dtypes(self.device_type)
         if len(supported_dtypes) == 0:
             self.skipTest("Skipped! Op has not supported dtypes on this device.")
-        dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
+        dtype = (
+            torch.float32
+            if torch.float32 in supported_dtypes
+            else list(supported_dtypes)[0]
+        )
 
         samples = op.sample_inputs(device, dtype)
         for sample in samples:
@@ -349,8 +475,12 @@ class TestCommon(TestCase):
             # Short-circuits if output is not a single tensor or an
             # iterable of tensors
-            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
-                self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
+            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
+                expected, include_empty=True
+            ):
+                self.skipTest(
+                    "Skipped! Only supports single tensor or iterable of tensor outputs."
+                )
 
             # Validates the op doesn't support out if it claims not to
             if not op.supports_out:
@@ -380,7 +510,7 @@ class TestCommon(TestCase):
             # NOTE: only extracts on the CPU and CUDA device types since some
             # device types don't have storage
             def _extract_data_ptrs(out):
-                if self.device_type != 'cpu' and self.device_type != 'cuda':
+                if self.device_type != "cpu" and self.device_type != "cuda":
                     return ()
 
                 if isinstance(out, torch.Tensor):
@@ -403,7 +533,8 @@ class TestCommon(TestCase):
                 if compare_strides_and_data_ptrs:
                     stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
-                        original_strides, final_strides)
+                        original_strides, final_strides
+                    )
                     self.assertEqual(original_strides, final_strides, msg=stride_msg)
                     self.assertEqual(original_ptrs, final_ptrs)
@@ -433,7 +564,9 @@ class TestCommon(TestCase):
             out = _apply_out_transform(_case_zero_transform, expected)
             msg_fail = "Resized a non-empty tensor but did not warn about it."
             if _any_nonempty(out):
-                with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail):
+                with self.assertWarnsRegex(
+                    UserWarning, "An output with one or more elements", msg=msg_fail
+                ):
                     op_out(out=out)
 
     # Validates ops implement the correct out= behavior
@@ -452,7 +585,11 @@ class TestCommon(TestCase):
         supported_dtypes = op.supported_dtypes(self.device_type)
         if len(supported_dtypes) == 0:
             self.skipTest("Skipped! Op has not supported dtypes on this device.")
-        dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
+        dtype = (
+            torch.float32
+            if torch.float32 in supported_dtypes
+            else list(supported_dtypes)[0]
+        )
 
         samples = op.sample_inputs(device, dtype)
         for sample in samples:
@@ -462,8 +599,12 @@ class TestCommon(TestCase):
             # Short-circuits if output is not a single tensor or an
             # iterable of tensors
-            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
-                self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
+            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
+                expected, include_empty=True
+            ):
+                self.skipTest(
+                    "Skipped! Only supports single tensor or iterable of tensor outputs."
+                )
 
             # Validates the op doesn't support out if it claims not to
             if not op.supports_out:
@@ -493,7 +634,7 @@ class TestCommon(TestCase):
             # NOTE: only extracts on the CPU and CUDA device types since some
             # device types don't have storage
             def _extract_data_ptrs(out):
-                if self.device_type != 'cpu' and self.device_type != 'cuda':
+                if self.device_type != "cpu" and self.device_type != "cuda":
                     return ()
 
                 if isinstance(out, torch.Tensor):
@@ -515,7 +656,8 @@ class TestCommon(TestCase):
                 if compare_strides_and_data_ptrs:
                     stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
-                        original_strides, final_strides)
+                        original_strides, final_strides
+                    )
                     self.assertEqual(original_strides, final_strides, msg=stride_msg)
                     self.assertEqual(original_ptrs, final_ptrs)
@@ -529,7 +671,7 @@ class TestCommon(TestCase):
                 return torch.full_like(t, info.max)
             except TypeError as te:
                 # for non-integer types fills with NaN
-                return torch.full_like(t, float('nan'))
+                return torch.full_like(t, float("nan"))
 
         _compare_out(_case_zero_transform)
@@ -537,10 +679,9 @@ class TestCommon(TestCase):
         # but noncontiguous.
         # Expected behavior: strides are respected and `out` storage is not changed.
         def _case_one_transform(t):
-            return make_tensor(t.shape,
-                               dtype=t.dtype,
-                               device=t.device,
-                               noncontiguous=True)
+            return make_tensor(
+                t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
+            )
 
         _compare_out(_case_one_transform)
@@ -560,16 +701,19 @@ class TestCommon(TestCase):
             # Verifies no warning is a resize warning
             for w in caught:
                 if "An output with one or more elements" in str(w.message):
-                    self.fail("Resizing an out= argument with no elements threw a resize warning!")
+                    self.fail(
+                        "Resizing an out= argument with no elements threw a resize warning!"
+                    )
 
         # Case 3: out= with correct shape and dtype, but wrong device.
         wrong_device = None
-        if torch.device(device).type != 'cpu':
-            wrong_device = 'cpu'
+        if torch.device(device).type != "cpu":
+            wrong_device = "cpu"
         elif torch.cuda.is_available():
-            wrong_device = 'cuda'
+            wrong_device = "cuda"
 
         if wrong_device is not None:
+
             def _case_three_transform(t):
                 return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)
@@ -587,16 +731,28 @@ class TestCommon(TestCase):
             # dtypes, or if an op returns multiple tensors when at least one such
             # tensor is a floating point or complex dtype.
             _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
-            if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or
-                    (not isinstance(expected, torch.Tensor) and any(t.dtype in _dtypes for t in expected))):
+            if (
+                isinstance(expected, torch.Tensor)
+                and expected.dtype in _dtypes
+                or (
+                    not isinstance(expected, torch.Tensor)
+                    and any(t.dtype in _dtypes for t in expected)
+                )
+            ):
+
                 def _case_four_transform(t):
                     return make_tensor(t.shape, dtype=torch.long, device=t.device)
 
                 out = _apply_out_transform(_case_four_transform, expected)
                 msg_fail = "Expected RuntimeError when doing an unsafe cast!"
-                msg_fail = msg_fail if not isinstance(expected, torch.Tensor) else \
-                    ("Expected RuntimeError when doing an unsafe cast from a result of dtype "
-                     f"{expected.dtype} into an out= with dtype torch.long")
+                msg_fail = (
+                    msg_fail
+                    if not isinstance(expected, torch.Tensor)
+                    else (
+                        "Expected RuntimeError when doing an unsafe cast from a result of dtype "
+                        f"{expected.dtype} into an out= with dtype torch.long"
+                    )
+                )
                 with self.assertRaises(RuntimeError, msg=msg_fail):
                     op_out(out=out)
@@ -611,7 +767,9 @@ class TestCommon(TestCase):
         inplace = op.inplace_variant
 
         # list of all inplace ops: inplace variant + alias inplace variants if exist
-        inplace_ops = [inplace, ]
+        inplace_ops = [
+            inplace,
+        ]
         variants = [method, inplace]
 
         for a_op in op.aliases:
@@ -623,30 +781,46 @@ class TestCommon(TestCase):
         inplace_variants = tuple(filter(None, inplace_ops))
         variants = tuple(filter(None, variants))
 
-        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
+        _requires_grad = dtype in op.supported_backward_dtypes(
+            torch.device(device).type
+        )
 
         include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
-        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
+        samples = op.sample_inputs(
+            device,
+            dtype,
+            requires_grad=_requires_grad,
+            include_conjugated_inputs=include_conjugated_inputs,
+        )
         samples = list(samples)
 
         def _test_consistency_helper(samples, variants):
             for sample in samples:
                 # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
-                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+                tensor = (
+                    sample.input
+                    if isinstance(sample.input, torch.Tensor)
+                    else sample.input[0]
+                )
 
                 # Computes function forward and backward values
                 tensor.grad = None
                 expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                 expected_grad = None
 
-                output_process_fn_grad = sample.output_process_fn_grad if sample.output_process_fn_grad \
-                    else lambda x: x
+                output_process_fn_grad = (
+                    sample.output_process_fn_grad
+                    if sample.output_process_fn_grad
+                    else lambda x: x
+                )
 
                 # Skips inplace variants if the output dtype is not the same as
                 # the input dtype
                 skip_inplace = False
-                if (isinstance(expected_forward, torch.Tensor) and
-                        expected_forward.dtype is not tensor.dtype):
+                if (
+                    isinstance(expected_forward, torch.Tensor)
+                    and expected_forward.dtype is not tensor.dtype
+                ):
                     skip_inplace = True
 
                 # TODO: backward consistency only supported for single tensor outputs
@@ -654,8 +828,9 @@ class TestCommon(TestCase):
                 # tensor inputs
                 # TODO: update to handle checking grads of all tensor inputs as
                 # derived from each tensor output
-                if (isinstance(expected_forward, torch.Tensor)
-                        and dtype in op.supported_backward_dtypes(torch.device(device).type)):
+                if isinstance(
+                    expected_forward, torch.Tensor
+                ) and dtype in op.supported_backward_dtypes(torch.device(device).type):
                     output_process_fn_grad(expected_forward).sum().backward()
                     expected_grad = tensor.grad
@@ -668,26 +843,35 @@ class TestCommon(TestCase):
                     # Compares variant's forward
                     # Note: copies the to-be-modified input when testing the inplace variant
                     tensor.grad = None
-                    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
+                    cloned = (
+                        clone_input_helper(sample.input)
+                        if variant in inplace_ops
+                        else sample.input
+                    )
 
                     if variant in inplace_ops and sample.broadcasts_input:
-                        with self.assertRaises(RuntimeError,
-                                               msg=('inplace variant either incorrectly allowed '
-                                                    'resizing or you have marked the sample {}'
-                                                    ' incorrectly with `broadcasts_self=True'.format(sample.summary()))):
-                            variant_forward = variant(cloned,
-                                                      *sample.args,
-                                                      **sample.kwargs)
+                        with self.assertRaises(
+                            RuntimeError,
+                            msg=(
+                                "inplace variant either incorrectly allowed "
+                                "resizing or you have marked the sample {}"
+                                " incorrectly with `broadcasts_self=True".format(
+                                    sample.summary()
+                                )
+                            ),
+                        ):
+                            variant_forward = variant(
+                                cloned, *sample.args, **sample.kwargs
+                            )
                         continue
 
-                    variant_forward = variant(cloned,
-                                              *sample.args,
-                                              **sample.kwargs)
+                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                     self.assertEqual(expected_forward, variant_forward)
 
                     # Compares variant's backward
-                    if expected_grad is not None and \
-                            (variant not in inplace_ops or op.supports_inplace_autograd):
+                    if expected_grad is not None and (
+                        variant not in inplace_ops or op.supports_inplace_autograd
+                    ):
                         output_process_fn_grad(variant_forward).sum().backward()
                         self.assertEqual(expected_grad, tensor.grad)
@@ -698,28 +882,45 @@ class TestCommon(TestCase):
                 # Skips inplace variants if the output dtype is not the same as
                 # the input dtype
                 expected_forward = op(sample.input, *sample.args, **sample.kwargs)
-                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+                tensor = (
+                    sample.input
+                    if isinstance(sample.input, torch.Tensor)
+                    else sample.input[0]
+                )
                 skip_inplace = False
-                if (isinstance(expected_forward, torch.Tensor) and
-                        expected_forward.dtype is not tensor.dtype):
+                if (
+                    isinstance(expected_forward, torch.Tensor)
+                    and expected_forward.dtype is not tensor.dtype
+                ):
                     skip_inplace = True
                 if skip_inplace:
                     return
                 for variant in variants:
-                    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
-                    inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0]
+                    cloned = (
+                        clone_input_helper(sample.input)
+                        if variant in inplace_ops
+                        else sample.input
+                    )
+                    inp_tensor = (
+                        cloned if isinstance(cloned, torch.Tensor) else cloned[0]
+                    )
                     data_ptr = inp_tensor.data_ptr()
-                    variant_forward = variant(cloned,
-                                              *sample.args,
-                                              **sample.kwargs)
+                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                     # TODO Support non-tensor outputs if they exist for inplace ops
-                    if (isinstance(variant_forward, torch.Tensor)):
-                        self.assertEqual(data_ptr, variant_forward.data_ptr(), atol=0, rtol=0)
+                    if isinstance(variant_forward, torch.Tensor):
+                        self.assertEqual(
+                            data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
+                        )
                     else:
-                        self.assertTrue(False, "Non-tensor outputs for inplace ops are not supported")
+                        self.assertTrue(
+                            False,
+                            "Non-tensor outputs for inplace ops are not supported",
+                        )
 
         if len(inplace_ops) > 0:
-            inplace_samples = list(filter(lambda sample: not sample.broadcasts_input, samples))
+            inplace_samples = list(
+                filter(lambda sample: not sample.broadcasts_input, samples)
+            )
             _test_inplace_preserve_storage(inplace_samples, inplace_variants)
# Reference testing for operations in complex32 against complex64. # Reference testing for operations in complex32 against complex64.
@@ -732,14 +933,21 @@ class TestCommon(TestCase):
         for sample in op.sample_inputs(device, dtype):
             actual = op(sample.input, *sample.args, **sample.kwargs)
             transformed_sample = sample.transform(lambda x: x.to(torch.complex64))
-            expected = op(transformed_sample.input, *transformed_sample.args, **transformed_sample.kwargs)
+            expected = op(
+                transformed_sample.input,
+                *transformed_sample.args,
+                **transformed_sample.kwargs,
+            )
             self.assertEqual(actual, expected, exact_dtype=False)
 
+
 class TestCompositeCompliance(TestCase):
     # Checks if the operator (if it is composite) is written to support most
     # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
     # in aten/src/ATen/native/README.md for more details
-    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
+    @unittest.skipIf(
+        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
+    )
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_operator(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=False)
@@ -750,7 +958,9 @@ class TestCompositeCompliance(TestCase):
             composite_compliance.check_with_mode(op, args, kwargs)
             composite_compliance.check_all_permutations(op, args, kwargs)
 
-    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
+    @unittest.skipIf(
+        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
+    )
     @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
     def test_backward(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -760,7 +970,9 @@ class TestCompositeCompliance(TestCase):
             kwargs = sample.kwargs
             composite_compliance.check_backward_formula(op, args, kwargs)
 
-    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
+    @unittest.skipIf(
+        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
+    )
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_forward_ad(self, device, dtype, op):
         if torch.float not in op.supported_backward_dtypes(device):
@@ -776,6 +988,7 @@ class TestCompositeCompliance(TestCase):
             kwargs = sample.kwargs
             composite_compliance.check_forward_ad_formula(op, args, kwargs)
 
+
 class TestMathBits(TestCase):
     # Tests that
     # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
@@ -787,7 +1000,17 @@ class TestMathBits(TestCase):
     # This test only runs for C -> R and C -> C functions
     # TODO: add tests for `R->C` functions
    # Note: This test runs for functions that take both tensors and tensorlists as input.
-    def _test_math_view(self, device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, out_type):
+    def _test_math_view(
+        self,
+        device,
+        dtype,
+        op,
+        samples,
+        math_op_physical,
+        math_op_view,
+        is_bit_set,
+        out_type,
+    ):
         inplace_variant = op.inplace_variant
 
         # helper function to clone and conjugate/negate the input if its a tensor
@@ -796,7 +1019,7 @@ class TestMathBits(TestCase):
         # have its requires_grad set to that value.
         def clone_and_perform_view(input, **kwargs):
             if isinstance(input, torch.Tensor):
-                requires_grad = kwargs.get('requires_grad', input.requires_grad)
+                requires_grad = kwargs.get("requires_grad", input.requires_grad)
                 with torch.no_grad():
                     # Ensure view represents the original sample input
                     input = math_op_physical(input)
@@ -813,7 +1036,11 @@ class TestMathBits(TestCase):
             return tuple(out)
 
         for sample in samples:
-            tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+            tensor = (
+                sample.input
+                if isinstance(sample.input, torch.Tensor)
+                else sample.input[0]
+            )
             cloned1 = clone_and_perform_view(sample.input)
 
             # Computes function forward value with a physically conjugated/negated tensor and
@@ -827,9 +1054,13 @@ class TestMathBits(TestCase):
             # input produces correct output, and the output tensor has the conj/neg bit set to True
             if inplace_variant is not None and not sample.broadcasts_input:
                 cloned2 = clone_and_perform_view(tensor, requires_grad=False)
-                if (isinstance(expected_forward, torch.Tensor) and
-                        expected_forward.dtype is tensor.dtype):
-                    inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
+                if (
+                    isinstance(expected_forward, torch.Tensor)
+                    and expected_forward.dtype is tensor.dtype
+                ):
+                    inplace_forward = inplace_variant(
+                        cloned2, *sample.args, **sample.kwargs
+                    )
                     self.assertTrue(is_bit_set(inplace_forward))
                     self.assertEqual(inplace_forward, expected_forward)
@@ -838,25 +1069,36 @@ class TestMathBits(TestCase):
             # tensor inputs
            # TODO: update to handle checking grads of all tensor inputs as
            # derived from each tensor output
-            if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
+            if (
+                isinstance(expected_forward, torch.Tensor)
+                and expected_forward.requires_grad
+            ):
                 output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
                 expected_forward = output_process_fn_grad(expected_forward)
                 forward_with_mathview = output_process_fn_grad(forward_with_mathview)
 
-                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+                tensor = (
+                    sample.input
+                    if isinstance(sample.input, torch.Tensor)
+                    else sample.input[0]
+                )
                 expected_forward.sum().backward(retain_graph=True)
                 forward_with_mathview.sum().backward(retain_graph=True)
                 if tensor.grad is not None:
-                    cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
+                    cloned1_tensor = (
+                        cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
+                    )
                     self.assertEqual(tensor.grad, cloned1_tensor.grad)
 
                     tensor.grad, cloned1_tensor.grad = None, None
 
                     # a repeat of the above test if output is not complex valued
-                    if (out_type(expected_forward)):
+                    if out_type(expected_forward):
                         grad = torch.randn_like(expected_forward)
                         expected_forward.backward(grad)
-                        forward_with_mathview.backward(math_op_view(math_op_physical(grad)))
+                        forward_with_mathview.backward(
+                            math_op_view(math_op_physical(grad))
+                        )
 
                         self.assertEqual(tensor.grad, cloned1_tensor.grad)
@@ -866,10 +1108,21 @@ class TestMathBits(TestCase):
             self.skipTest("Operation doesn't support conjugated inputs.")
         math_op_physical = torch.conj_physical
         math_op_view = torch.conj
-        _requires_grad = torch.cfloat in op.supported_backward_dtypes(torch.device(device).type)
+        _requires_grad = torch.cfloat in op.supported_backward_dtypes(
+            torch.device(device).type
+        )
         is_bit_set = torch.is_conj
         samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
-        self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, torch.is_complex)
+        self._test_math_view(
+            device,
+            dtype,
+            op,
+            samples,
+            math_op_physical,
+            math_op_view,
+            is_bit_set,
+            torch.is_complex,
+        )
 
     @ops(op_db, allowed_dtypes=(torch.double,))
     def test_neg_view(self, device, dtype, op):
@@ -879,8 +1132,16 @@ class TestMathBits(TestCase):
         math_op_view = torch._neg_view
         is_bit_set = torch.is_neg
         samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
-        self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
-                             lambda x: True)
+        self._test_math_view(
+            device,
+            dtype,
+            op,
+            samples,
+            math_op_physical,
+            math_op_view,
+            is_bit_set,
+            lambda x: True,
+        )
 
     @ops(op_db, allowed_dtypes=(torch.cdouble,))
     def test_neg_conj_view(self, device, dtype, op):
@@ -898,17 +1159,27 @@ class TestMathBits(TestCase):
         def is_bit_set(x):
             return torch.is_neg(x) and torch.is_conj(x)
 
-        _requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
+        _requires_grad = dtype in op.supported_backward_dtypes(
+            torch.device(device).type
+        )
         samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
         # Only test one sample
         samples = itertools.islice(samples, 1)
-        self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
-                             torch.is_complex)
+        self._test_math_view(
+            device,
+            dtype,
+            op,
+            samples,
+            math_op_physical,
+            math_op_view,
+            is_bit_set,
+            torch.is_complex,
+        )
 
 
 instantiate_device_type_tests(TestCommon, globals())
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     run_tests()
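The re-enabled test_python_reference_meta_functions above checks that running an op on metadata-only inputs reproduces the shape, dtype, and device of a real run. A rough standalone illustration of the idea, using the public "meta" device instead of the internal prims.utils.TensorMeta wrapper (an assumption for illustration, not the test's actual mechanism):

import torch

def check_meta_propagation(fn, *args):
    real = fn(*args)
    # Meta tensors carry shape/dtype/device but no data, so fn runs without compute.
    meta_args = [a.to("meta") if isinstance(a, torch.Tensor) else a for a in args]
    meta = fn(*meta_args)
    assert meta.shape == real.shape and meta.dtype == real.dtype

check_meta_propagation(torch.add, torch.randn(2, 3), torch.randn(2, 3))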

View File

@@ -1,85 +1,83 @@
 # Owner(s): ["module: primTorch"]
 
-# TODO: uncomment this file once CI issues with import nvfuser are resolved
-# from functools import partial
-
-# import torch
-# from torch.testing import make_tensor
-# from torch.testing._internal.common_utils import run_tests, TestCase
-# from torch.testing._internal.common_device_type import (
-#     instantiate_device_type_tests,
-#     onlyCUDA,
-#     dtypes,
-# )
-# import torch._prims as prims
-# from torch._prims.executor import make_traced
-
-# class TestPrims(TestCase):
-#     @onlyCUDA
-#     @dtypes(torch.float32)
-#     def test_broadcast_in_dim(self, device, dtype):
-#         def _wrapper(a, shape, broadcast_dimensions):
-#             return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
-
-#         traced = make_traced(_wrapper)
-#         make_arg = partial(make_tensor, device=device, dtype=dtype)
-
-#         # TODO: FIXME:
-#         # for executor in ('aten', 'nvfuser'):
-#         for executor in ("aten",):
-#             fn = partial(traced, executor=executor)
-#             # Same shape
-#             shape = (5, 5)
-#             a = make_arg(shape)
-#             result = fn(a, shape, (0, 1))
-#             self.assertEqual(result.shape, a.shape)
-#             self.assertTrue(result.is_contiguous)
-#             self.assertEqual(a, result)
-
-#             # Error input: reordering dims
-#             with self.assertRaises(Exception):
-#                 result = fn(a, shape, (1, 0))
-
-#             # Adding outermost dimensions
-#             a = make_arg((5, 5))
-#             target_shape = (3, 3, 5, 5)
-#             result = fn(a, target_shape, (2, 3))
-#             self.assertEqual(result.shape, target_shape)
-#             self.assertEqual(a.broadcast_to(target_shape), result)
-
-#             # Expands
-#             a = make_arg((1, 5, 1))
-#             target_shape = (3, 5, 7)
-#             result = fn(a, target_shape, (0, 1, 2))
-#             self.assertEqual(result.shape, target_shape)
-#             self.assertEqual(a.expand_as(result), result)
-
-#             # Unsqueezes
-#             a = make_arg((1, 2, 3))
-#             target_shape = (1, 2, 1, 3)
-#             result = fn(a, target_shape, (0, 1, 3))
-#             self.assertEqual(result.shape, target_shape)
-#             self.assertEqual(a.unsqueeze(2), result)
-
-#             # Adds outermost, expands, and unsqueezes
-#             a = make_arg((1, 2, 3))
-#             target_shape = (4, 1, 7, 2, 3, 3)
-#             result = fn(a, target_shape, (1, 3, 4))
-#             self.assertEqual(result.shape, target_shape)
-#             a.unsqueeze_(3)
-#             a.unsqueeze_(1)
-#             a.unsqueeze_(0)
-#             self.assertEqual(a.expand_as(result), result)
-
-# instantiate_device_type_tests(TestPrims, globals())
-
-# if __name__ == "__main__":
-#     run_tests()
+from functools import partial
+
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_device_type import (
+    instantiate_device_type_tests,
+    onlyCUDA,
+    dtypes,
+)
+import torch._prims as prims
+from torch._prims.executor import make_traced
+
+
+class TestPrims(TestCase):
+    @onlyCUDA
+    @dtypes(torch.float32)
+    def test_broadcast_in_dim(self, device, dtype):
+        def _wrapper(a, shape, broadcast_dimensions):
+            return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
+
+        traced = make_traced(_wrapper)
+        make_arg = partial(make_tensor, device=device, dtype=dtype)
+
+        # TODO: FIXME:
+        # for executor in ('aten', 'nvfuser'):
+        for executor in ("aten",):
+            fn = partial(traced, executor=executor)
+            # Same shape
+            shape = (5, 5)
+            a = make_arg(shape)
+            result = fn(a, shape, (0, 1))
+            self.assertEqual(result.shape, a.shape)
+            self.assertTrue(result.is_contiguous)
+            self.assertEqual(a, result)
+
+            # Error input: reordering dims
+            with self.assertRaises(Exception):
+                result = fn(a, shape, (1, 0))
+
+            # Adding outermost dimensions
+            a = make_arg((5, 5))
+            target_shape = (3, 3, 5, 5)
+            result = fn(a, target_shape, (2, 3))
+            self.assertEqual(result.shape, target_shape)
+            self.assertEqual(a.broadcast_to(target_shape), result)
+
+            # Expands
+            a = make_arg((1, 5, 1))
+            target_shape = (3, 5, 7)
+            result = fn(a, target_shape, (0, 1, 2))
+            self.assertEqual(result.shape, target_shape)
+            self.assertEqual(a.expand_as(result), result)
+
+            # Unsqueezes
+            a = make_arg((1, 2, 3))
+            target_shape = (1, 2, 1, 3)
+            result = fn(a, target_shape, (0, 1, 3))
+            self.assertEqual(result.shape, target_shape)
+            self.assertEqual(a.unsqueeze(2), result)
+
+            # Adds outermost, expands, and unsqueezes
+            a = make_arg((1, 2, 3))
+            target_shape = (4, 1, 7, 2, 3, 3)
+            result = fn(a, target_shape, (1, 3, 4))
+            self.assertEqual(result.shape, target_shape)
+            a.unsqueeze_(3)
+            a.unsqueeze_(1)
+            a.unsqueeze_(0)
+            self.assertEqual(a.expand_as(result), result)
+
+
+instantiate_device_type_tests(TestPrims, globals())
+
+if __name__ == "__main__":
+    run_tests()
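The test above exercises prims.broadcast_in_dim, where broadcast_dimensions[i] names the output dimension that input dimension i maps to; every unmapped output dimension becomes a new size-1 dimension that is then expanded. A small sketch of that equivalence in plain torch (an illustration, not the prims implementation):

import torch

def broadcast_in_dim_equiv(a, shape, broadcast_dimensions):
    result = a
    # Insert size-1 dims for output dims that no input dim maps to.
    for dim in range(len(shape)):
        if dim not in broadcast_dimensions:
            result = result.unsqueeze(dim)
    return result.expand(*shape)

out = broadcast_in_dim_equiv(torch.randn(1, 5, 1), (3, 5, 7), (0, 1, 2))
assert out.shape == (3, 5, 7)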

View File

@ -13,7 +13,7 @@ import random
from torch.testing import make_tensor from torch.testing import make_tensor
from torch.testing._internal.common_utils import ( from torch.testing._internal.common_utils import (
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings, TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
torch_to_numpy_dtype_dict, slowTest, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize) TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize)
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes, expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
@@ -3683,13 +3683,13 @@ class TestBufferProtocol(TestCase):
         self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
         return (numpy_original, torch_frombuffer)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_same_type(self, device, dtype):
         self._run_test((), dtype)
         self._run_test((4,), dtype)
         self._run_test((10, 10), dtype)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_requires_grad(self, device, dtype):
         def _run_test_and_check_grad(requires_grad, *args, **kwargs):
             kwargs["requires_grad"] = requires_grad
@@ -3704,14 +3704,14 @@ class TestBufferProtocol(TestCase):
         _run_test_and_check_grad(False, (4,), dtype)
         _run_test_and_check_grad(False, (10, 10), dtype)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_with_offset(self, device, dtype):
         # Offset should be valid whenever there is, at least,
         # one remaining element
         for i in range(SIZE):
             self._run_test(SHAPE, dtype, first=i)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_with_count(self, device, dtype):
         # Count should be valid for any valid in the interval
         # [-1, len(input)], except for 0
@@ -3719,7 +3719,7 @@ class TestBufferProtocol(TestCase):
             if i != 0:
                 self._run_test(SHAPE, dtype, count=i)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_with_count_and_offset(self, device, dtype):
         # Explicit default count [-1, 1, 2, ..., len]
         for i in range(-1, SIZE + 1):
@@ -3735,7 +3735,7 @@ class TestBufferProtocol(TestCase):
             for j in range(SIZE - i + 1):
                 self._run_test(SHAPE, dtype, count=i, first=j)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_invalid_positional_args(self, device, dtype):
         bytes = get_dtype_size(dtype)
         in_bytes = SIZE * bytes
@@ -3772,7 +3772,7 @@ class TestBufferProtocol(TestCase):
                                     rf"buffer length \({in_bytes} bytes\)"):
             self._run_test(SHAPE, dtype, count=count, first=first)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_shared_buffer(self, device, dtype):
         x = make_tensor((1,), dtype=dtype, device=device)
         # Modify the whole tensor
@@ -3799,13 +3799,13 @@ class TestBufferProtocol(TestCase):
             arr[first] = x.item() - 1
             self.assertEqual(arr[first:last], tensor)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_not_a_buffer(self, device, dtype):
         with self.assertRaisesRegex(ValueError,
                                     r"object does not implement Python buffer protocol."):
             torch.frombuffer([1, 2, 3, 4], dtype=dtype)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_non_writable_buffer(self, device, dtype):
         numpy_arr = make_tensor((1,), dtype=dtype, device=device).numpy()
         byte_arr = numpy_arr.tobytes()
@@ -3910,7 +3910,7 @@ class TestAsArray(TestCase):
         self._test_alias_with_cvt(identity, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_alias_from_numpy(self, device, dtype):
         self._test_alias_with_cvt(to_numpy, device, dtype)
@@ -3921,7 +3921,7 @@ class TestAsArray(TestCase):
         self._test_alias_with_cvt(to_dlpack, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_alias_from_buffer(self, device, dtype):
         self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
@@ -3959,7 +3959,7 @@ class TestAsArray(TestCase):
         self._test_copy_with_cvt(identity, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_copy_from_numpy(self, device, dtype):
         self._test_copy_with_cvt(to_numpy, device, dtype)
@@ -3969,7 +3969,7 @@ class TestAsArray(TestCase):
         self._test_copy_with_cvt(to_dlpack, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_copy_from_buffer(self, device, dtype):
         self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
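The decorator swap above matters because the torch-to-NumPy map gains lossy entries later in this commit (torch.bfloat16 and torch.complex32 have no exact NumPy counterpart), while the values of the NumPy-to-torch map remain exactly the torch dtypes that round-trip through the buffer protocol. A minimal sketch of the difference, assuming the internal helpers are importable:

import torch
from torch.testing._internal.common_utils import (
    numpy_to_torch_dtype_dict,
    torch_to_numpy_dtype_dict,
)

# Deduplicated torch dtypes with an exact NumPy counterpart; safe for
# torch.frombuffer-style tests.
roundtrip_dtypes = set(numpy_to_torch_dtype_dict.values())

# The torch->NumPy map's keys now include the lossy extras, so they are
# no longer interchangeable with roundtrip_dtypes.
assert torch.bfloat16 in torch_to_numpy_dtype_dict
assert torch.bfloat16 not in roundtrip_dtypes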
@@ -7,15 +7,14 @@ import unittest
 import torch
 from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
-                                                  TEST_NUMPY, torch_to_numpy_dtype_dict)
+                                                  TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict)
 from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
                                                         dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta)
 from torch.testing._internal.common_dtype import (
     all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and
 )

-if TEST_NUMPY:
-    import numpy as np
+import numpy as np

 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -812,8 +811,8 @@ class TestTypePromotion(TestCase):
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     @float_double_default_dtype
     @onlyCPU
-    @dtypes(*list(itertools.product(torch_to_numpy_dtype_dict.keys(),
-                                    torch_to_numpy_dtype_dict.keys())))
+    @dtypes(*list(itertools.product(set(numpy_to_torch_dtype_dict.values()),
+                                    set(numpy_to_torch_dtype_dict.values()))))
     def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
         import operator
         np_type = torch_to_numpy_dtype_dict[dtypes[0]]
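The two-argument parametrization above expands to the full Cartesian product of round-trippable dtype pairs, one test instantiation per pair. Roughly, as a standalone sketch using the same internal helper:

import itertools
from torch.testing._internal.common_utils import numpy_to_torch_dtype_dict

roundtrip = set(numpy_to_torch_dtype_dict.values())
# each generated test receives one (torch dtype, torch dtype) pair
dtype_pairs = list(itertools.product(roundtrip, roundtrip))
assert len(dtype_pairs) == len(roundtrip) ** 2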
File diff suppressed because it is too large

@@ -11,7 +11,7 @@ import random
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck,
-    torch_to_numpy_dtype_dict,
+    numpy_to_torch_dtype_dict,
 )
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
@@ -130,7 +130,7 @@ class TestViewOps(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
     def test_view_dtype_new(self, device, dtype):
-        dtypes = torch_to_numpy_dtype_dict.copy()
+        dtypes = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
         del dtypes[torch.bool]

         def generate_inputs():
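Inverting the NumPy-to-torch dict rebuilds the old torch-to-NumPy mapping locally, without the lossy bfloat16/complex32 entries that the shared torch_to_numpy_dtype_dict now carries. A sketch of what the rewritten test line computes:

import torch
from torch.testing._internal.common_utils import numpy_to_torch_dtype_dict

dtypes = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
assert torch.bfloat16 not in dtypes  # the lossy extras never appear here
del dtypes[torch.bool]  # the test then drops bool, as in the hunk above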
@@ -2,14 +2,15 @@ import torch
 from torch import Tensor

 import torch._prims.utils as utils
-from torch._prims.utils import TensorLike, TensorLikeType, TensorMeta, ShapeType
+from torch._prims.utils import (
+    TensorLike,
+    TensorLikeType,
+    TensorMeta,
+    ShapeType,
+    getnvFuserDtype,
+)
 from torch.overrides import has_torch_function, handle_torch_function

-import torch._C._nvfuser as nvfuser  # type: ignore[import]
-
-FusionDefinition = nvfuser.FusionDefinition  # type: ignore[attr-defined]
-DataType = nvfuser.DataType  # type: ignore[attr-defined]
-
 from typing import Sequence, Optional, Union, Callable, List, Tuple, Any
 from numbers import Number
 from functools import reduce
@@ -1033,21 +1034,8 @@ def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
     return a.to(dtype)

-_torch_dtype_to_nvfuser_dtype_map = {
-    torch.cdouble: DataType.ComplexDouble,
-    torch.cfloat: DataType.ComplexFloat,
-    torch.double: DataType.Double,
-    torch.float: DataType.Float,
-    torch.half: DataType.Half,
-    torch.bfloat16: DataType.BFloat16,
-    torch.long: DataType.Int,
-    torch.int: DataType.Int32,
-    torch.bool: DataType.Bool,
-}
-
 def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor:
-    nvfuser_dtype = _torch_dtype_to_nvfuser_dtype_map[dtype]
+    nvfuser_dtype = getnvFuserDtype(dtype)
     return fd.Ops.cast(nvfuser_dtype, a)  # type: ignore[attr-defined]
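The module-level dtype map is gone; conversion now goes through the shared helper that this commit adds to torch._prims.utils. A hedged usage sketch (it requires a CUDA build, since the underlying map is empty otherwise):

import torch
from torch._prims.utils import getnvFuserDtype

if torch.cuda.is_available():
    # returns the nvFuser DataType enum member for a torch dtype
    nv_dtype = getnvFuserDtype(torch.float)  # DataType.Float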
@@ -3,27 +3,11 @@ from typing import Callable
 import torch
 from torch.fx import GraphModule

-from torch._prims.utils import TensorMeta
+from torch._prims.utils import TensorMeta, getnvFuserDtype
 from torch._prims.context import PrimContext

-import torch._C._nvfuser as nvfuser  # type: ignore[import]
-
-DataType = nvfuser.DataType  # type: ignore[attr-defined]
-Fusion = nvfuser.Fusion  # type: ignore[attr-defined]
-FusionDefinition = nvfuser.FusionDefinition  # type: ignore[attr-defined]
-
-# TODO: refactor me into a common place
-_torch_dtype_to_nvfuser_dtype_map = {
-    torch.cdouble: DataType.ComplexDouble,
-    torch.cfloat: DataType.ComplexFloat,
-    torch.double: DataType.Double,
-    torch.float: DataType.Float,
-    torch.half: DataType.Half,
-    torch.bfloat16: DataType.BFloat16,
-    torch.long: DataType.Int,
-    torch.int: DataType.Int32,
-    torch.bool: DataType.Bool,
-}
+if torch.cuda.is_available():
+    from torch._C._nvfuser import Fusion, FusionDefinition  # type: ignore[import]


 def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
@@ -37,6 +21,11 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
         gm = GraphModule({}, ctx.graph)
         return gm.forward(*args, **kwargs)
     elif executor == "nvfuser":
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                "Attempting to use nvFuser trace executor but CUDA is not available!"
+            )
+
         # PROTOTYPE nvfuser executor
         # Only accepts tensor inputs and single tensor outputs
         # Does not handle kwargs
@@ -53,9 +42,7 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
         nv_args = [fd]
         for arg in args:
             if isinstance(arg, torch.Tensor):
-                x = fd.define_tensor(
-                    arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype]
-                )
+                x = fd.define_tensor(arg.ndim, getnvFuserDtype(arg.dtype))
                 fd.add_input(x)
                 nv_args.append(x)
             else:
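The executor now fails fast with an actionable error on CPU-only builds instead of crashing at import time when torch._C._nvfuser is missing. The same guard pattern, reduced to a standalone sketch (the function name is illustrative, not part of the diff):

import torch

def require_cuda_for_nvfuser() -> None:
    # Raise early, before any torch._C._nvfuser symbol is touched.
    if not torch.cuda.is_available():
        raise RuntimeError(
            "Attempting to use nvFuser trace executor but CUDA is not available!"
        )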
@@ -8,6 +8,32 @@ import threading
 import torch
 from torch.fx import Node

+# nvFuser imports are conditional on CUDA being available
+if torch.cuda.is_available():
+    from torch._C._nvfuser import DataType  # type: ignore[import]
+
+    _torch_dtype_to_nvfuser_dtype_map = {
+        torch.cdouble: DataType.ComplexDouble,
+        torch.cfloat: DataType.ComplexFloat,
+        torch.double: DataType.Double,
+        torch.float: DataType.Float,
+        torch.half: DataType.Half,
+        torch.bfloat16: DataType.BFloat16,
+        torch.long: DataType.Int,
+        torch.int: DataType.Int32,
+        torch.bool: DataType.Bool,
+    }
+else:
+    _torch_dtype_to_nvfuser_dtype_map = {}
+
+
+def getnvFuserDtype(dtype: torch.dtype):
+    """
+    Translates from torch.dtype to nvFuser's DataType enum
+    """
+    return _torch_dtype_to_nvfuser_dtype_map[dtype]
+
+
 ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
 StrideType = Union[List[int], Tuple[int, ...]]
 DimsType = Union[int, List[int], Tuple[int, ...]]
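One behavioral consequence of the block above, inferred from the code rather than a documented contract: on CPU-only builds the map is empty, so any lookup raises KeyError.

import torch
from torch._prims.utils import getnvFuserDtype

try:
    getnvFuserDtype(torch.float16)
except KeyError:
    # expected on builds without CUDA: the dtype map is {}
    print("nvFuser dtype mapping unavailable without CUDA")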
File diff suppressed because it is too large

@@ -871,6 +871,10 @@ if IS_WINDOWS:
 # Dict of torch dtype -> NumPy dtype
 torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
+torch_to_numpy_dtype_dict.update({
+    torch.bfloat16: np.float32,
+    torch.complex32: np.complex64
+})


 def skipIfRocm(fn):
     @wraps(fn)
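With these entries the torch-to-NumPy map becomes total over bfloat16 and complex32 but stops being invertible, since two torch dtypes now share np.float32. A quick check, assuming the internal helper is importable:

import numpy as np
import torch
from torch.testing._internal.common_utils import torch_to_numpy_dtype_dict

assert torch_to_numpy_dtype_dict[torch.bfloat16] is np.float32
assert torch_to_numpy_dtype_dict[torch.float32] is np.float32  # collision: not invertible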