Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Adds python ref consistency test, elementwise unary reference inputs, and formats test files
Per title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76626 Approved by: https://github.com/ngimel
Committed by: PyTorch MergeBot
Parent: 33be4c94c0
Commit: f6bbecf8b5
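For context on the first item in the title: a "Python reference" reimplements a torch operator out of other torch operations, and the new consistency test asserts that it computes the same values as the operator it mirrors. Below is a minimal standalone sketch of that check — the ref_abs function is a hypothetical stand-in for an entry in python_ref_db, not PyTorch code:

import torch

def ref_abs(a: torch.Tensor) -> torch.Tensor:
    # hypothetical Python reference for torch.abs, built from simpler ops
    return torch.where(a < 0, -a, a)

sample = torch.randn(4, 4)
actual = ref_abs(sample)
expected = torch.abs(sample)
# the new test makes the analogous assertion for every op in python_ref_db
torch.testing.assert_close(actual, expected)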
File diff suppressed because it is too large
test/test_ops.py (559 changed lines)
@@ -8,17 +8,44 @@ import itertools
 import torch

 from torch.testing import make_tensor
-from torch.testing._internal.common_dtype import floating_and_complex_types_and, all_types_and_complex_and
-from torch.testing._internal.common_utils import \
-    (TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper,
-     IS_IN_CI, suppress_warnings, noncontiguous_like,
-     TEST_WITH_ASAN, IS_WINDOWS, IS_FBCODE, first_sample)
-from torch.testing._internal.common_methods_invocations import \
-    (op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, ops_and_refs)
-from torch.testing._internal.common_device_type import \
-    (deviceCountAtLeast, instantiate_device_type_tests, ops,
-     onlyCUDA, onlyNativeDeviceTypes, OpDTypes, skipMeta)
-# import torch._prims as prims
+from torch.testing._internal.common_dtype import (
+    floating_and_complex_types_and,
+    all_types_and_complex_and,
+)
+from torch.testing._internal.common_utils import (
+    TestCase,
+    is_iterable_of_tensors,
+    run_tests,
+    IS_SANDCASTLE,
+    clone_input_helper,
+    IS_IN_CI,
+    suppress_warnings,
+    noncontiguous_like,
+    TEST_WITH_ASAN,
+    IS_WINDOWS,
+    IS_FBCODE,
+    first_sample,
+)
+from torch.testing._internal.common_methods_invocations import (
+    op_db,
+    _NOTHING,
+    UnaryUfuncInfo,
+    ReductionOpInfo,
+    SpectralFuncInfo,
+    ops_and_refs,
+    python_ref_db,
+    BinaryUfuncInfo,
+)
+from torch.testing._internal.common_device_type import (
+    deviceCountAtLeast,
+    instantiate_device_type_tests,
+    ops,
+    onlyCUDA,
+    onlyNativeDeviceTypes,
+    OpDTypes,
+    skipMeta,
+)
+import torch._prims as prims
+
 import torch.testing._internal.opinfo_helper as opinfo_helper
 from torch.testing._internal import composite_compliance
@@ -28,15 +55,25 @@ torch.set_default_dtype(torch.float32)

 # variant testing is only done with torch.float and torch.cfloat to avoid
 # excessive test times and maximize signal to noise ratio
-_variant_ops = partial(ops, dtypes=OpDTypes.supported,
-                       allowed_dtypes=(torch.float, torch.cfloat))
+_variant_ops = partial(
+    ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
+)

 # Get names of all the operators which have ref in their entry in OpInfo (testing infra)
-# except for Unary Ufuncs (separately implemented in test/test_unary_ufuncs.py)
+# except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
+# elementwise binary operators (separately implemented in test_binary_ufuncs.py),
+# reduction operations (separately implemented in test_reductions.py),
 # and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
-_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo,
-                            SpectralFuncInfo)) and op.ref is not None and op.ref is not _NOTHING, op_db))
+_ref_test_ops = tuple(
+    filter(
+        lambda op: not isinstance(
+            op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
+        )
+        and op.ref is not None
+        and op.ref is not _NOTHING,
+        op_db,
+    )
+)

 # Tests that apply to all operators and aren't related to any particular
 # system
@@ -49,8 +86,10 @@ class TestCommon(TestCase):
         super().tearDownClass()

         if IS_IN_CI:
-            err_msg = ("The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
-                       "This is OK for testing, but be sure to set the dtypes manually before landing your PR!")
+            err_msg = (
+                "The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
+                "This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
+            )
             # Assure no opinfo entry has dynamic_dtypes
             filtered_ops = list(filter(opinfo_helper.is_dynamic_dtype_set, op_db))
             for op in filtered_ops:
@@ -68,11 +107,16 @@ class TestCommon(TestCase):
         # Check complex32 support only if the op claims.
         # TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
         device_type = torch.device(device).type
-        include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device_type) else ())
+        include_complex32 = (
+            (torch.complex32,)
+            if op.supports_dtype(torch.complex32, device_type)
+            else ()
+        )

         # dtypes to try to backward in
         allowed_backward_dtypes = floating_and_complex_types_and(
-            *((torch.half, torch.bfloat16) + include_complex32))
+            *((torch.half, torch.bfloat16) + include_complex32)
+        )

         # lists for (un)supported dtypes
         supported_dtypes = set()
@@ -86,11 +130,14 @@ class TestCommon(TestCase):
             unsupported_backward_dtypes.add(dtype)

         for dtype in all_types_and_complex_and(
-                *((torch.half, torch.bfloat16, torch.bool) + include_complex32)):
+            *((torch.half, torch.bfloat16, torch.bool) + include_complex32)
+        ):
             # tries to acquire samples - failure indicates lack of support
-            requires_grad = (dtype in allowed_backward_dtypes)
+            requires_grad = dtype in allowed_backward_dtypes
             try:
-                samples = tuple(op.sample_inputs(device, dtype, requires_grad=requires_grad))
+                samples = tuple(
+                    op.sample_inputs(device, dtype, requires_grad=requires_grad)
+                )
             except Exception as e:
                 unsupported(dtype)
                 continue
@@ -113,7 +160,9 @@ class TestCommon(TestCase):
                     result = sample.output_process_fn_grad(result)
                     if isinstance(result, torch.Tensor):
                         backward_tensor = result
-                    elif isinstance(result, Sequence) and isinstance(result[0], torch.Tensor):
+                    elif isinstance(result, Sequence) and isinstance(
+                        result[0], torch.Tensor
+                    ):
                         backward_tensor = result[0]
                     else:
                         continue
@@ -130,14 +179,15 @@ class TestCommon(TestCase):
                 except Exception as e:
                     unsupported_backward_dtypes.add(dtype)

         # Checks that dtypes are listed correctly and generates an informative
         # error message
         supported_forward = supported_dtypes - unsupported_dtypes
         partially_supported_forward = supported_dtypes & unsupported_dtypes
         unsupported_forward = unsupported_dtypes - supported_dtypes
         supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
-        partially_supported_backward = supported_backward_dtypes & unsupported_backward_dtypes
+        partially_supported_backward = (
+            supported_backward_dtypes & unsupported_backward_dtypes
+        )
         unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes

         device_type = torch.device(device).type
@@ -156,17 +206,27 @@ class TestCommon(TestCase):
                 op.name, device_type
             )
             if len(partially_supported_forward) > 0:
-                msg = msg + "The following dtypes only worked on some samples during forward: {0}.\n".format(
+                msg = (
+                    msg
+                    + "The following dtypes only worked on some samples during forward: {0}.\n".format(
                         partially_supported_forward
                     )
+                )
             if len(partially_supported_backward) > 0:
-                msg = msg + "The following dtypes only worked on some samples during backward: {0}.\n".format(
+                msg = (
+                    msg
+                    + "The following dtypes only worked on some samples during backward: {0}.\n".format(
                         partially_supported_backward
                     )
+                )
             print(msg)

-        if (len(supported_but_unclaimed_forward) + len(claimed_but_unsupported_forward) +
-                len(supported_but_unclaimed_backward) + len(claimed_but_unsupported_backward)) == 0:
+        if (
+            len(supported_but_unclaimed_forward)
+            + len(claimed_but_unsupported_forward)
+            + len(supported_but_unclaimed_backward)
+            + len(claimed_but_unsupported_backward)
+        ) == 0:
             return

         # Generates error msg
@@ -174,21 +234,33 @@ class TestCommon(TestCase):
             op.name, device_type
         )
         if len(supported_but_unclaimed_forward) > 0:
-            msg = msg + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
+            msg = (
+                msg
+                + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
                     supported_but_unclaimed_forward
                 )
+            )
         if len(supported_but_unclaimed_backward) > 0:
-            msg = msg + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
+            msg = (
+                msg
+                + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
                     supported_but_unclaimed_backward
                 )
+            )
         if len(claimed_but_unsupported_forward) > 0:
-            msg = msg + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
+            msg = (
+                msg
+                + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
                     claimed_but_unsupported_forward
                 )
+            )
         if len(claimed_but_unsupported_backward) > 0:
-            msg = msg + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
+            msg = (
+                msg
+                + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
                     claimed_but_unsupported_backward
                 )
+            )

         self.fail(msg)

@@ -209,7 +281,9 @@ class TestCommon(TestCase):
         elif is_iterable_of_tensors(result):
             self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
         else:
-            self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
+            self.skipTest(
+                "Skipped! Only supports single tensor or iterable of tensor outputs."
+            )

     # Tests that the function and its (ndarray-accepting) reference produce the same
     # values on the tensors from sample_inputs func for the corresponding op.
@@ -226,28 +300,48 @@ class TestCommon(TestCase):
             cur_default = torch.get_default_dtype()
             torch.set_default_dtype(torch.double)
             for sample_input in op.reference_inputs(device, dtype):
-                self.compare_with_reference(op, op.ref, sample_input, exact_dtype=(dtype is not torch.long))
+                self.compare_with_reference(
+                    op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
+                )
         finally:
             torch.set_default_dtype(cur_default)

-    # Tests that experimental Python References' can propagate shape, dtype,
+    # Tests that experimental Python References can propagate shape, dtype,
     # and device metadata properly.
     # TODO: include stride propagation.
-    # @onlyNativeDeviceTypes
-    # @ops(python_ref_db)
-    # def test_python_reference_meta_functions(self, device, dtype, op):
-    #     def _to_tensormeta(x):
-    #         if isinstance(x, torch.Tensor):
-    #             return prims.utils.TensorMeta(x)
+    @onlyNativeDeviceTypes
+    @ops(python_ref_db)
+    def test_python_reference_meta_functions(self, device, dtype, op):
+        def _to_tensormeta(x):
+            if isinstance(x, torch.Tensor):
+                return prims.utils.TensorMeta(x)

-    #     # TODO: iterate over requires_grad true/false
-    #     for sample in op.reference_inputs(device, dtype, requires_grad=False):
-    #         result = op(sample.input, *sample.args, **sample.kwargs)
+        # TODO: iterate over requires_grad true/false
+        for sample in op.reference_inputs(device, dtype, requires_grad=False):
+            result = op(sample.input, *sample.args, **sample.kwargs)

-    #         meta_sample = sample.transform(_to_tensormeta)
-    #         meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
+            meta_sample = sample.transform(_to_tensormeta)
+            meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)

-    #         prims.utils.compare_tensor_meta(result, meta_result)
+            prims.utils.compare_tensor_meta(result, meta_result)

+    # Tests that experimental Python References perform the same computation
+    # as the operators they reference.
+    @onlyNativeDeviceTypes
+    @ops(python_ref_db)
+    def test_python_reference_consistency(self, device, dtype, op):
+        for sample in op.reference_inputs(device, dtype, requires_grad=False):
+            actual = op(sample.input, *sample.args, **sample.kwargs)
+            expected = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
+
+            self.assertEqual(
+                actual,
+                expected,
+                exact_stride=True,
+                exact_device=True,
+                exact_layout=True,
+                exact_is_coalesced=True,
+            )

     @skipMeta
     @onlyNativeDeviceTypes
@@ -272,9 +366,17 @@ class TestCommon(TestCase):
         test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
         sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
         for sample_input in sample_inputs:
-            t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
+            t_inp, t_args, t_kwargs = (
+                sample_input.input,
+                sample_input.args,
+                sample_input.kwargs,
+            )
             noncontig_sample = sample_input.noncontiguous()
-            n_inp, n_args, n_kwargs = noncontig_sample.input, noncontig_sample.args, noncontig_sample.kwargs
+            n_inp, n_args, n_kwargs = (
+                noncontig_sample.input,
+                noncontig_sample.args,
+                noncontig_sample.kwargs,
+            )

             # Verifies sample input tensors should have no grad or history
             sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0]
@@ -300,10 +402,14 @@ class TestCommon(TestCase):
                 grad_for_actual = noncontiguous_like(grad_for_expected)
             elif isinstance(expected, Sequence):
                 # Filter output elements that do not require grad
-                expected = [t for t in expected
-                            if isinstance(t, torch.Tensor) and t.requires_grad]
-                actual = [n for n in actual
-                          if isinstance(n, torch.Tensor) and n.requires_grad]
+                expected = [
+                    t
+                    for t in expected
+                    if isinstance(t, torch.Tensor) and t.requires_grad
+                ]
+                actual = [
+                    n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
+                ]
                 grad_for_expected = [torch.randn_like(t) for t in expected]
                 grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
             else:
@@ -311,19 +417,35 @@ class TestCommon(TestCase):
                 continue

             # Concatenate inputs into a tuple
-            t_inputs = (t_inp,) + t_args if isinstance(t_inp, torch.Tensor) else tuple(t_inp) + t_args
-            n_inputs = (n_inp,) + n_args if isinstance(n_inp, torch.Tensor) else tuple(n_inp) + n_args
+            t_inputs = (
+                (t_inp,) + t_args
+                if isinstance(t_inp, torch.Tensor)
+                else tuple(t_inp) + t_args
+            )
+            n_inputs = (
+                (n_inp,) + n_args
+                if isinstance(n_inp, torch.Tensor)
+                else tuple(n_inp) + n_args
+            )

             # Filter the elements that are tensors that require grad
-            t_input_tensors = [t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad]
-            n_input_tensors = [n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad]
+            t_input_tensors = [
+                t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
+            ]
+            n_input_tensors = [
+                n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
+            ]

             self.assertEqual(len(t_input_tensors), len(n_input_tensors))

             # Some functions may not use all the inputs to generate gradients. One of the
             # few examples of this "odd" behaviour is F.hinge_embedding_loss
-            t_grads = torch.autograd.grad(expected, t_input_tensors, grad_for_expected, allow_unused=True)
-            n_grads = torch.autograd.grad(actual, n_input_tensors, grad_for_actual, allow_unused=True)
+            t_grads = torch.autograd.grad(
+                expected, t_input_tensors, grad_for_expected, allow_unused=True
+            )
+            n_grads = torch.autograd.grad(
+                actual, n_input_tensors, grad_for_actual, allow_unused=True
+            )

             msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
             for i, (t, n) in enumerate(zip(t_grads, n_grads)):
@@ -339,7 +461,11 @@ class TestCommon(TestCase):
         supported_dtypes = op.supported_dtypes(self.device_type)
         if len(supported_dtypes) == 0:
             self.skipTest("Skipped! Op has not supported dtypes on this device.")
-        dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
+        dtype = (
+            torch.float32
+            if torch.float32 in supported_dtypes
+            else list(supported_dtypes)[0]
+        )

         samples = op.sample_inputs(device, dtype)
         for sample in samples:
@@ -349,8 +475,12 @@ class TestCommon(TestCase):

             # Short-circuits if output is not a single tensor or an
             # iterable of tensors
-            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
-                self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
+            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
+                expected, include_empty=True
+            ):
+                self.skipTest(
+                    "Skipped! Only supports single tensor or iterable of tensor outputs."
+                )

             # Validates the op doesn't support out if it claims not to
             if not op.supports_out:
@@ -380,7 +510,7 @@ class TestCommon(TestCase):
             # NOTE: only extracts on the CPU and CUDA device types since some
             # device types don't have storage
             def _extract_data_ptrs(out):
-                if self.device_type != 'cpu' and self.device_type != 'cuda':
+                if self.device_type != "cpu" and self.device_type != "cuda":
                     return ()

                 if isinstance(out, torch.Tensor):
@@ -403,7 +533,8 @@ class TestCommon(TestCase):

                 if compare_strides_and_data_ptrs:
                     stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
-                        original_strides, final_strides)
+                        original_strides, final_strides
+                    )
                     self.assertEqual(original_strides, final_strides, msg=stride_msg)
                     self.assertEqual(original_ptrs, final_ptrs)

@@ -433,7 +564,9 @@ class TestCommon(TestCase):
             out = _apply_out_transform(_case_zero_transform, expected)
             msg_fail = "Resized a non-empty tensor but did not warn about it."
             if _any_nonempty(out):
-                with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail):
+                with self.assertWarnsRegex(
+                    UserWarning, "An output with one or more elements", msg=msg_fail
+                ):
                     op_out(out=out)

     # Validates ops implement the correct out= behavior
@@ -452,7 +585,11 @@ class TestCommon(TestCase):
         supported_dtypes = op.supported_dtypes(self.device_type)
         if len(supported_dtypes) == 0:
             self.skipTest("Skipped! Op has not supported dtypes on this device.")
-        dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
+        dtype = (
+            torch.float32
+            if torch.float32 in supported_dtypes
+            else list(supported_dtypes)[0]
+        )

         samples = op.sample_inputs(device, dtype)
         for sample in samples:
@@ -462,8 +599,12 @@ class TestCommon(TestCase):

             # Short-circuits if output is not a single tensor or an
             # iterable of tensors
-            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
-                self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
+            if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
+                expected, include_empty=True
+            ):
+                self.skipTest(
+                    "Skipped! Only supports single tensor or iterable of tensor outputs."
+                )

             # Validates the op doesn't support out if it claims not to
             if not op.supports_out:
@@ -493,7 +634,7 @@ class TestCommon(TestCase):
             # NOTE: only extracts on the CPU and CUDA device types since some
             # device types don't have storage
             def _extract_data_ptrs(out):
-                if self.device_type != 'cpu' and self.device_type != 'cuda':
+                if self.device_type != "cpu" and self.device_type != "cuda":
                     return ()

                 if isinstance(out, torch.Tensor):
@@ -515,7 +656,8 @@ class TestCommon(TestCase):

                 if compare_strides_and_data_ptrs:
                     stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
-                        original_strides, final_strides)
+                        original_strides, final_strides
+                    )
                     self.assertEqual(original_strides, final_strides, msg=stride_msg)
                     self.assertEqual(original_ptrs, final_ptrs)

@@ -529,7 +671,7 @@ class TestCommon(TestCase):
                 return torch.full_like(t, info.max)
             except TypeError as te:
                 # for non-integer types fills with NaN
-                return torch.full_like(t, float('nan'))
+                return torch.full_like(t, float("nan"))

         _compare_out(_case_zero_transform)

@@ -537,10 +679,9 @@ class TestCommon(TestCase):
         # but noncontiguous.
         # Expected behavior: strides are respected and `out` storage is not changed.
         def _case_one_transform(t):
-            return make_tensor(t.shape,
-                               dtype=t.dtype,
-                               device=t.device,
-                               noncontiguous=True)
+            return make_tensor(
+                t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
+            )

         _compare_out(_case_one_transform)

@@ -560,16 +701,19 @@ class TestCommon(TestCase):
             # Verifies no warning is a resize warning
             for w in caught:
                 if "An output with one or more elements" in str(w.message):
-                    self.fail("Resizing an out= argument with no elements threw a resize warning!")
+                    self.fail(
+                        "Resizing an out= argument with no elements threw a resize warning!"
+                    )

         # Case 3: out= with correct shape and dtype, but wrong device.
         wrong_device = None
-        if torch.device(device).type != 'cpu':
-            wrong_device = 'cpu'
+        if torch.device(device).type != "cpu":
+            wrong_device = "cpu"
         elif torch.cuda.is_available():
-            wrong_device = 'cuda'
+            wrong_device = "cuda"

         if wrong_device is not None:
+
             def _case_three_transform(t):
                 return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)

@@ -587,16 +731,28 @@ class TestCommon(TestCase):
         # dtypes, or if an op returns multiple tensors when at least one such
         # tensor is a floating point or complex dtype.
         _dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
-        if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or
-                (not isinstance(expected, torch.Tensor) and any(t.dtype in _dtypes for t in expected))):
+        if (
+            isinstance(expected, torch.Tensor)
+            and expected.dtype in _dtypes
+            or (
+                not isinstance(expected, torch.Tensor)
+                and any(t.dtype in _dtypes for t in expected)
+            )
+        ):
+
             def _case_four_transform(t):
                 return make_tensor(t.shape, dtype=torch.long, device=t.device)

             out = _apply_out_transform(_case_four_transform, expected)
             msg_fail = "Expected RuntimeError when doing an unsafe cast!"
-            msg_fail = msg_fail if not isinstance(expected, torch.Tensor) else \
-                ("Expected RuntimeError when doing an unsafe cast from a result of dtype "
-                 f"{expected.dtype} into an out= with dtype torch.long")
+            msg_fail = (
+                msg_fail
+                if not isinstance(expected, torch.Tensor)
+                else (
+                    "Expected RuntimeError when doing an unsafe cast from a result of dtype "
+                    f"{expected.dtype} into an out= with dtype torch.long"
+                )
+            )
             with self.assertRaises(RuntimeError, msg=msg_fail):
                 op_out(out=out)

@@ -611,7 +767,9 @@ class TestCommon(TestCase):
         inplace = op.inplace_variant

         # list of all inplace ops: inplace variant + alias inplace variants if exist
-        inplace_ops = [inplace, ]
+        inplace_ops = [
+            inplace,
+        ]
         variants = [method, inplace]

         for a_op in op.aliases:
@@ -623,30 +781,46 @@ class TestCommon(TestCase):
         inplace_variants = tuple(filter(None, inplace_ops))
         variants = tuple(filter(None, variants))

-        _requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
+        _requires_grad = dtype in op.supported_backward_dtypes(
+            torch.device(device).type
+        )

         include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
-        samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
+        samples = op.sample_inputs(
+            device,
+            dtype,
+            requires_grad=_requires_grad,
+            include_conjugated_inputs=include_conjugated_inputs,
+        )
         samples = list(samples)

         def _test_consistency_helper(samples, variants):
             for sample in samples:
                 # TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
-                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+                tensor = (
+                    sample.input
+                    if isinstance(sample.input, torch.Tensor)
+                    else sample.input[0]
+                )

                 # Computes function forward and backward values
                 tensor.grad = None
                 expected_forward = op(sample.input, *sample.args, **sample.kwargs)
                 expected_grad = None

-                output_process_fn_grad = sample.output_process_fn_grad if sample.output_process_fn_grad \
+                output_process_fn_grad = (
+                    sample.output_process_fn_grad
+                    if sample.output_process_fn_grad
+                    else lambda x: x
+                )

                 # Skips inplace variants if the output dtype is not the same as
                 # the input dtype
                 skip_inplace = False
-                if (isinstance(expected_forward, torch.Tensor) and
-                        expected_forward.dtype is not tensor.dtype):
+                if (
+                    isinstance(expected_forward, torch.Tensor)
+                    and expected_forward.dtype is not tensor.dtype
+                ):
                     skip_inplace = True

                 # TODO: backward consistency only supported for single tensor outputs
@@ -654,8 +828,9 @@ class TestCommon(TestCase):
                 # tensor inputs
                 # TODO: update to handle checking grads of all tensor inputs as
                 # derived from each tensor output
-                if (isinstance(expected_forward, torch.Tensor)
-                        and dtype in op.supported_backward_dtypes(torch.device(device).type)):
+                if isinstance(
+                    expected_forward, torch.Tensor
+                ) and dtype in op.supported_backward_dtypes(torch.device(device).type):
                     output_process_fn_grad(expected_forward).sum().backward()
                     expected_grad = tensor.grad

@@ -668,26 +843,35 @@ class TestCommon(TestCase):
                     # Compares variant's forward
                     # Note: copies the to-be-modified input when testing the inplace variant
                     tensor.grad = None
-                    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
+                    cloned = (
+                        clone_input_helper(sample.input)
+                        if variant in inplace_ops
+                        else sample.input
+                    )

                     if variant in inplace_ops and sample.broadcasts_input:
-                        with self.assertRaises(RuntimeError,
-                                               msg=('inplace variant either incorrectly allowed '
-                                                    'resizing or you have marked the sample {}'
-                                                    ' incorrectly with `broadcasts_self=True'.format(sample.summary()))):
-                            variant_forward = variant(cloned,
-                                                      *sample.args,
-                                                      **sample.kwargs)
+                        with self.assertRaises(
+                            RuntimeError,
+                            msg=(
+                                "inplace variant either incorrectly allowed "
+                                "resizing or you have marked the sample {}"
+                                " incorrectly with `broadcasts_self=True".format(
+                                    sample.summary()
+                                )
+                            ),
+                        ):
+                            variant_forward = variant(
+                                cloned, *sample.args, **sample.kwargs
+                            )
                         continue

-                    variant_forward = variant(cloned,
-                                              *sample.args,
-                                              **sample.kwargs)
+                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                     self.assertEqual(expected_forward, variant_forward)

                     # Compares variant's backward
-                    if expected_grad is not None and \
-                            (variant not in inplace_ops or op.supports_inplace_autograd):
+                    if expected_grad is not None and (
+                        variant not in inplace_ops or op.supports_inplace_autograd
+                    ):
                         output_process_fn_grad(variant_forward).sum().backward()
                         self.assertEqual(expected_grad, tensor.grad)

@@ -698,28 +882,45 @@ class TestCommon(TestCase):
                 # Skips inplace variants if the output dtype is not the same as
                 # the input dtype
                 expected_forward = op(sample.input, *sample.args, **sample.kwargs)
-                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+                tensor = (
+                    sample.input
+                    if isinstance(sample.input, torch.Tensor)
+                    else sample.input[0]
+                )
                 skip_inplace = False
-                if (isinstance(expected_forward, torch.Tensor) and
-                        expected_forward.dtype is not tensor.dtype):
+                if (
+                    isinstance(expected_forward, torch.Tensor)
+                    and expected_forward.dtype is not tensor.dtype
+                ):
                     skip_inplace = True
                 if skip_inplace:
                     return
                 for variant in variants:
-                    cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
-                    inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0]
+                    cloned = (
+                        clone_input_helper(sample.input)
+                        if variant in inplace_ops
+                        else sample.input
+                    )
+                    inp_tensor = (
+                        cloned if isinstance(cloned, torch.Tensor) else cloned[0]
+                    )
                     data_ptr = inp_tensor.data_ptr()
-                    variant_forward = variant(cloned,
-                                              *sample.args,
-                                              **sample.kwargs)
+                    variant_forward = variant(cloned, *sample.args, **sample.kwargs)
                     # TODO Support non-tensor outputs if they exist for inplace ops
-                    if (isinstance(variant_forward, torch.Tensor)):
-                        self.assertEqual(data_ptr, variant_forward.data_ptr(), atol=0, rtol=0)
+                    if isinstance(variant_forward, torch.Tensor):
+                        self.assertEqual(
+                            data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
+                        )
                     else:
-                        self.assertTrue(False, "Non-tensor outputs for inplace ops are not supported")
+                        self.assertTrue(
+                            False,
+                            "Non-tensor outputs for inplace ops are not supported",
+                        )

         if len(inplace_ops) > 0:
-            inplace_samples = list(filter(lambda sample: not sample.broadcasts_input, samples))
+            inplace_samples = list(
+                filter(lambda sample: not sample.broadcasts_input, samples)
+            )
             _test_inplace_preserve_storage(inplace_samples, inplace_variants)

     # Reference testing for operations in complex32 against complex64.
@@ -732,14 +933,21 @@ class TestCommon(TestCase):
         for sample in op.sample_inputs(device, dtype):
             actual = op(sample.input, *sample.args, **sample.kwargs)
             transformed_sample = sample.transform(lambda x: x.to(torch.complex64))
-            expected = op(transformed_sample.input, *transformed_sample.args, **transformed_sample.kwargs)
+            expected = op(
+                transformed_sample.input,
+                *transformed_sample.args,
+                **transformed_sample.kwargs,
+            )
             self.assertEqual(actual, expected, exact_dtype=False)


 class TestCompositeCompliance(TestCase):
     # Checks if the operator (if it is composite) is written to support most
     # backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
     # in aten/src/ATen/native/README.md for more details
-    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
+    @unittest.skipIf(
+        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
+    )
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_operator(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=False)
@@ -750,7 +958,9 @@ class TestCompositeCompliance(TestCase):
             composite_compliance.check_with_mode(op, args, kwargs)
             composite_compliance.check_all_permutations(op, args, kwargs)

-    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
+    @unittest.skipIf(
+        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
+    )
     @ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
     def test_backward(self, device, dtype, op):
         samples = op.sample_inputs(device, dtype, requires_grad=True)
@@ -760,7 +970,9 @@ class TestCompositeCompliance(TestCase):
             kwargs = sample.kwargs
             composite_compliance.check_backward_formula(op, args, kwargs)

-    @unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
+    @unittest.skipIf(
+        IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
+    )
     @ops(op_db, allowed_dtypes=(torch.float,))
     def test_forward_ad(self, device, dtype, op):
         if torch.float not in op.supported_backward_dtypes(device):
@@ -776,6 +988,7 @@ class TestCompositeCompliance(TestCase):
             kwargs = sample.kwargs
             composite_compliance.check_forward_ad_formula(op, args, kwargs)

+
 class TestMathBits(TestCase):
     # Tests that
     # 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
@@ -787,7 +1000,17 @@ class TestMathBits(TestCase):
     # This test only runs for C -> R and C -> C functions
     # TODO: add tests for `R->C` functions
     # Note: This test runs for functions that take both tensors and tensorlists as input.
-    def _test_math_view(self, device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, out_type):
+    def _test_math_view(
+        self,
+        device,
+        dtype,
+        op,
+        samples,
+        math_op_physical,
+        math_op_view,
+        is_bit_set,
+        out_type,
+    ):
         inplace_variant = op.inplace_variant

         # helper function to clone and conjugate/negate the input if its a tensor
@@ -796,7 +1019,7 @@ class TestMathBits(TestCase):
         # have its requires_grad set to that value.
         def clone_and_perform_view(input, **kwargs):
             if isinstance(input, torch.Tensor):
-                requires_grad = kwargs.get('requires_grad', input.requires_grad)
+                requires_grad = kwargs.get("requires_grad", input.requires_grad)
                 with torch.no_grad():
                     # Ensure view represents the original sample input
                     input = math_op_physical(input)
@@ -813,7 +1036,11 @@ class TestMathBits(TestCase):
             return tuple(out)

         for sample in samples:
-            tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+            tensor = (
+                sample.input
+                if isinstance(sample.input, torch.Tensor)
+                else sample.input[0]
+            )
             cloned1 = clone_and_perform_view(sample.input)

             # Computes function forward value with a physically conjugated/negated tensor and
@@ -827,9 +1054,13 @@ class TestMathBits(TestCase):
             # input produces correct output, and the output tensor has the conj/neg bit set to True
             if inplace_variant is not None and not sample.broadcasts_input:
                 cloned2 = clone_and_perform_view(tensor, requires_grad=False)
-                if (isinstance(expected_forward, torch.Tensor) and
-                        expected_forward.dtype is tensor.dtype):
-                    inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
+                if (
+                    isinstance(expected_forward, torch.Tensor)
+                    and expected_forward.dtype is tensor.dtype
+                ):
+                    inplace_forward = inplace_variant(
+                        cloned2, *sample.args, **sample.kwargs
+                    )
                     self.assertTrue(is_bit_set(inplace_forward))
                     self.assertEqual(inplace_forward, expected_forward)

@@ -838,25 +1069,36 @@ class TestMathBits(TestCase):
             # tensor inputs
             # TODO: update to handle checking grads of all tensor inputs as
             # derived from each tensor output
-            if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
+            if (
+                isinstance(expected_forward, torch.Tensor)
+                and expected_forward.requires_grad
+            ):
                 output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
                 expected_forward = output_process_fn_grad(expected_forward)
                 forward_with_mathview = output_process_fn_grad(forward_with_mathview)

-                tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
+                tensor = (
+                    sample.input
+                    if isinstance(sample.input, torch.Tensor)
+                    else sample.input[0]
+                )
                 expected_forward.sum().backward(retain_graph=True)
                 forward_with_mathview.sum().backward(retain_graph=True)
                 if tensor.grad is not None:
-                    cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
+                    cloned1_tensor = (
+                        cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
+                    )
                     self.assertEqual(tensor.grad, cloned1_tensor.grad)

                     tensor.grad, cloned1_tensor.grad = None, None

                     # a repeat of the above test if output is not complex valued
-                    if (out_type(expected_forward)):
+                    if out_type(expected_forward):
                         grad = torch.randn_like(expected_forward)
                         expected_forward.backward(grad)
-                        forward_with_mathview.backward(math_op_view(math_op_physical(grad)))
+                        forward_with_mathview.backward(
+                            math_op_view(math_op_physical(grad))
+                        )

                         self.assertEqual(tensor.grad, cloned1_tensor.grad)

@@ -866,10 +1108,21 @@ class TestMathBits(TestCase):
             self.skipTest("Operation doesn't support conjugated inputs.")
         math_op_physical = torch.conj_physical
         math_op_view = torch.conj
-        _requires_grad = torch.cfloat in op.supported_backward_dtypes(torch.device(device).type)
+        _requires_grad = torch.cfloat in op.supported_backward_dtypes(
+            torch.device(device).type
+        )
         is_bit_set = torch.is_conj
         samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
-        self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, torch.is_complex)
+        self._test_math_view(
+            device,
+            dtype,
+            op,
+            samples,
+            math_op_physical,
+            math_op_view,
+            is_bit_set,
+            torch.is_complex,
+        )

     @ops(op_db, allowed_dtypes=(torch.double,))
     def test_neg_view(self, device, dtype, op):
@@ -879,8 +1132,16 @@ class TestMathBits(TestCase):
         math_op_view = torch._neg_view
         is_bit_set = torch.is_neg
         samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
-        self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
-                             lambda x: True)
+        self._test_math_view(
+            device,
+            dtype,
+            op,
+            samples,
+            math_op_physical,
+            math_op_view,
+            is_bit_set,
+            lambda x: True,
+        )

     @ops(op_db, allowed_dtypes=(torch.cdouble,))
     def test_neg_conj_view(self, device, dtype, op):
@@ -898,17 +1159,27 @@ class TestMathBits(TestCase):
         def is_bit_set(x):
             return torch.is_neg(x) and torch.is_conj(x)

-        _requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
+        _requires_grad = dtype in op.supported_backward_dtypes(
+            torch.device(device).type
+        )
         samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
         # Only test one sample
         samples = itertools.islice(samples, 1)
-        self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
-                             torch.is_complex)
+        self._test_math_view(
+            device,
+            dtype,
+            op,
+            samples,
+            math_op_physical,
+            math_op_view,
+            is_bit_set,
+            torch.is_complex,
+        )


 instantiate_device_type_tests(TestCommon, globals())
 instantiate_device_type_tests(TestCompositeCompliance, globals())
 instantiate_device_type_tests(TestMathBits, globals())

-if __name__ == '__main__':
+if __name__ == "__main__":
     run_tests()
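The "elementwise unary reference inputs" mentioned in the commit title extend an op's plain sample inputs with the cases where an implementation and its reference most often disagree. A standalone illustration of the idea — the generator name below is hypothetical, not PyTorch test infrastructure:

import numpy as np
import torch

def make_unary_reference_inputs(dtype, device="cpu"):
    # hypothetical generator sketching the kinds of inputs reference tests sweep
    yield torch.randn(4, 4, device=device).to(dtype)       # generic values
    yield torch.empty(0, device=device, dtype=dtype)       # empty tensor
    yield torch.randn(6, 6, device=device).to(dtype)[::2]  # noncontiguous view
    if dtype.is_floating_point:
        # extremal values are a common source of divergence from the reference
        yield torch.tensor([float("inf"), -float("inf"), float("nan")],
                           device=device, dtype=dtype)

for t in make_unary_reference_inputs(torch.float32):
    expected = torch.from_numpy(np.exp(t.numpy()))
    torch.testing.assert_close(torch.exp(t), expected, equal_nan=True)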
@@ -1,85 +1,83 @@
 # Owner(s): ["module: primTorch"]

-# TODO: uncomment this file once CI issues with import nvfuser are resolved
+from functools import partial

-# from functools import partial
-
-# import torch
-# from torch.testing import make_tensor
-# from torch.testing._internal.common_utils import run_tests, TestCase
-# from torch.testing._internal.common_device_type import (
-#     instantiate_device_type_tests,
-#     onlyCUDA,
-#     dtypes,
-# )
-# import torch._prims as prims
-# from torch._prims.executor import make_traced
+import torch
+from torch.testing import make_tensor
+from torch.testing._internal.common_utils import run_tests, TestCase
+from torch.testing._internal.common_device_type import (
+    instantiate_device_type_tests,
+    onlyCUDA,
+    dtypes,
+)
+import torch._prims as prims
+from torch._prims.executor import make_traced


-# class TestPrims(TestCase):
-#     @onlyCUDA
-#     @dtypes(torch.float32)
-#     def test_broadcast_in_dim(self, device, dtype):
-#         def _wrapper(a, shape, broadcast_dimensions):
-#             return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
+class TestPrims(TestCase):
+    @onlyCUDA
+    @dtypes(torch.float32)
+    def test_broadcast_in_dim(self, device, dtype):
+        def _wrapper(a, shape, broadcast_dimensions):
+            return prims.broadcast_in_dim(a, shape, broadcast_dimensions)

-#         traced = make_traced(_wrapper)
-#         make_arg = partial(make_tensor, device=device, dtype=dtype)
+        traced = make_traced(_wrapper)
+        make_arg = partial(make_tensor, device=device, dtype=dtype)

-#         # TODO: FIXME:
-#         # for executor in ('aten', 'nvfuser'):
-#         for executor in ("aten",):
-#             fn = partial(traced, executor=executor)
-#             # Same shape
-#             shape = (5, 5)
-#             a = make_arg(shape)
-#             result = fn(a, shape, (0, 1))
+        # TODO: FIXME:
+        # for executor in ('aten', 'nvfuser'):
+        for executor in ("aten",):
+            fn = partial(traced, executor=executor)
+            # Same shape
+            shape = (5, 5)
+            a = make_arg(shape)
+            result = fn(a, shape, (0, 1))

-#             self.assertEqual(result.shape, a.shape)
-#             self.assertTrue(result.is_contiguous)
-#             self.assertEqual(a, result)
+            self.assertEqual(result.shape, a.shape)
+            self.assertTrue(result.is_contiguous)
+            self.assertEqual(a, result)

-#             # Error input: reordering dims
-#             with self.assertRaises(Exception):
-#                 result = fn(a, shape, (1, 0))
+            # Error input: reordering dims
+            with self.assertRaises(Exception):
+                result = fn(a, shape, (1, 0))

-#             # Adding outermost dimensions
-#             a = make_arg((5, 5))
-#             target_shape = (3, 3, 5, 5)
-#             result = fn(a, target_shape, (2, 3))
+            # Adding outermost dimensions
+            a = make_arg((5, 5))
+            target_shape = (3, 3, 5, 5)
+            result = fn(a, target_shape, (2, 3))

-#             self.assertEqual(result.shape, target_shape)
-#             self.assertEqual(a.broadcast_to(target_shape), result)
+            self.assertEqual(result.shape, target_shape)
+            self.assertEqual(a.broadcast_to(target_shape), result)

-#             # Expands
-#             a = make_arg((1, 5, 1))
-#             target_shape = (3, 5, 7)
-#             result = fn(a, target_shape, (0, 1, 2))
+            # Expands
+            a = make_arg((1, 5, 1))
+            target_shape = (3, 5, 7)
+            result = fn(a, target_shape, (0, 1, 2))

-#             self.assertEqual(result.shape, target_shape)
-#             self.assertEqual(a.expand_as(result), result)
+            self.assertEqual(result.shape, target_shape)
+            self.assertEqual(a.expand_as(result), result)

-#             # Unsqueezes
-#             a = make_arg((1, 2, 3))
-#             target_shape = (1, 2, 1, 3)
-#             result = fn(a, target_shape, (0, 1, 3))
+            # Unsqueezes
+            a = make_arg((1, 2, 3))
+            target_shape = (1, 2, 1, 3)
+            result = fn(a, target_shape, (0, 1, 3))

-#             self.assertEqual(result.shape, target_shape)
-#             self.assertEqual(a.unsqueeze(2), result)
+            self.assertEqual(result.shape, target_shape)
+            self.assertEqual(a.unsqueeze(2), result)

-#             # Adds outermost, expands, and unsqueezes
-#             a = make_arg((1, 2, 3))
-#             target_shape = (4, 1, 7, 2, 3, 3)
-#             result = fn(a, target_shape, (1, 3, 4))
+            # Adds outermost, expands, and unsqueezes
+            a = make_arg((1, 2, 3))
+            target_shape = (4, 1, 7, 2, 3, 3)
+            result = fn(a, target_shape, (1, 3, 4))

-#             self.assertEqual(result.shape, target_shape)
-#             a.unsqueeze_(3)
-#             a.unsqueeze_(1)
-#             a.unsqueeze_(0)
-#             self.assertEqual(a.expand_as(result), result)
+            self.assertEqual(result.shape, target_shape)
+            a.unsqueeze_(3)
+            a.unsqueeze_(1)
+            a.unsqueeze_(0)
+            self.assertEqual(a.expand_as(result), result)


-# instantiate_device_type_tests(TestPrims, globals())
+instantiate_device_type_tests(TestPrims, globals())

-# if __name__ == "__main__":
-#     run_tests()
+if __name__ == "__main__":
+    run_tests()
@@ -13,7 +13,7 @@ import random
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
-    torch_to_numpy_dtype_dict, slowTest,
+    torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
     TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize)
 from torch.testing._internal.common_device_type import (
     expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
@@ -3683,13 +3683,13 @@ class TestBufferProtocol(TestCase):
         self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
         return (numpy_original, torch_frombuffer)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_same_type(self, device, dtype):
         self._run_test((), dtype)
         self._run_test((4,), dtype)
         self._run_test((10, 10), dtype)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_requires_grad(self, device, dtype):
         def _run_test_and_check_grad(requires_grad, *args, **kwargs):
             kwargs["requires_grad"] = requires_grad
@@ -3704,14 +3704,14 @@ class TestBufferProtocol(TestCase):
         _run_test_and_check_grad(False, (4,), dtype)
         _run_test_and_check_grad(False, (10, 10), dtype)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_with_offset(self, device, dtype):
         # Offset should be valid whenever there is, at least,
         # one remaining element
         for i in range(SIZE):
             self._run_test(SHAPE, dtype, first=i)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_with_count(self, device, dtype):
         # Count should be valid for any valid in the interval
         # [-1, len(input)], except for 0
@@ -3719,7 +3719,7 @@ class TestBufferProtocol(TestCase):
             if i != 0:
                 self._run_test(SHAPE, dtype, count=i)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_with_count_and_offset(self, device, dtype):
         # Explicit default count [-1, 1, 2, ..., len]
         for i in range(-1, SIZE + 1):
@@ -3735,7 +3735,7 @@ class TestBufferProtocol(TestCase):
             for j in range(SIZE - i + 1):
                 self._run_test(SHAPE, dtype, count=i, first=j)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_invalid_positional_args(self, device, dtype):
         bytes = get_dtype_size(dtype)
         in_bytes = SIZE * bytes
@@ -3772,7 +3772,7 @@ class TestBufferProtocol(TestCase):
                                     rf"buffer length \({in_bytes} bytes\)"):
             self._run_test(SHAPE, dtype, count=count, first=first)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_shared_buffer(self, device, dtype):
         x = make_tensor((1,), dtype=dtype, device=device)
         # Modify the whole tensor
@@ -3799,13 +3799,13 @@ class TestBufferProtocol(TestCase):
             arr[first] = x.item() - 1
             self.assertEqual(arr[first:last], tensor)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_not_a_buffer(self, device, dtype):
         with self.assertRaisesRegex(ValueError,
                                     r"object does not implement Python buffer protocol."):
            torch.frombuffer([1, 2, 3, 4], dtype=dtype)

-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_non_writable_buffer(self, device, dtype):
         numpy_arr = make_tensor((1,), dtype=dtype, device=device).numpy()
         byte_arr = numpy_arr.tobytes()
@@ -3910,7 +3910,7 @@ class TestAsArray(TestCase):
         self._test_alias_with_cvt(identity, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_alias_from_numpy(self, device, dtype):
         self._test_alias_with_cvt(to_numpy, device, dtype)

@@ -3921,7 +3921,7 @@ class TestAsArray(TestCase):
         self._test_alias_with_cvt(to_dlpack, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_alias_from_buffer(self, device, dtype):
         self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)

@@ -3959,7 +3959,7 @@ class TestAsArray(TestCase):
         self._test_copy_with_cvt(identity, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_copy_from_numpy(self, device, dtype):
         self._test_copy_with_cvt(to_numpy, device, dtype)

@@ -3969,7 +3969,7 @@ class TestAsArray(TestCase):
         self._test_copy_with_cvt(to_dlpack, device, dtype)

     @onlyCPU
-    @dtypes(*torch_to_numpy_dtype_dict.keys())
+    @dtypes(*set(numpy_to_torch_dtype_dict.values()))
     def test_copy_from_buffer(self, device, dtype):
         self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
@@ -7,15 +7,14 @@ import unittest
 import torch

 from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
-                                                  TEST_NUMPY, torch_to_numpy_dtype_dict)
+                                                  TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict)
 from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
                                                         dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta)
 from torch.testing._internal.common_dtype import (
     all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and
 )

 if TEST_NUMPY:
     import numpy as np

 # load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
 # sharding on sandcastle. This line silences flake warnings
@@ -812,8 +811,8 @@ class TestTypePromotion(TestCase):
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     @float_double_default_dtype
     @onlyCPU
-    @dtypes(*list(itertools.product(torch_to_numpy_dtype_dict.keys(),
-                                    torch_to_numpy_dtype_dict.keys())))
+    @dtypes(*list(itertools.product(set(numpy_to_torch_dtype_dict.values()),
+                                    set(numpy_to_torch_dtype_dict.values()))))
     def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
         import operator
         np_type = torch_to_numpy_dtype_dict[dtypes[0]]
File diff suppressed because it is too large
@@ -11,7 +11,7 @@ import random
 from torch.testing import make_tensor
 from torch.testing._internal.common_utils import (
     TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck,
-    torch_to_numpy_dtype_dict,
+    numpy_to_torch_dtype_dict,
 )
 from torch.testing._internal.common_device_type import \
     (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
@@ -130,7 +130,7 @@ class TestViewOps(TestCase):
     @onlyNativeDeviceTypes
     @dtypes(*all_types_and_complex_and(torch.half, torch.bool))
     def test_view_dtype_new(self, device, dtype):
-        dtypes = torch_to_numpy_dtype_dict.copy()
+        dtypes = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
         del dtypes[torch.bool]

         def generate_inputs():
@@ -2,14 +2,15 @@ import torch
 from torch import Tensor

 import torch._prims.utils as utils
-from torch._prims.utils import TensorLike, TensorLikeType, TensorMeta, ShapeType
+from torch._prims.utils import (
+    TensorLike,
+    TensorLikeType,
+    TensorMeta,
+    ShapeType,
+    getnvFuserDtype,
+)
 from torch.overrides import has_torch_function, handle_torch_function

-import torch._C._nvfuser as nvfuser  # type: ignore[import]
-
-FusionDefinition = nvfuser.FusionDefinition  # type: ignore[attr-defined]
-DataType = nvfuser.DataType  # type: ignore[attr-defined]
-
 from typing import Sequence, Optional, Union, Callable, List, Tuple, Any
 from numbers import Number
 from functools import reduce
@@ -1033,21 +1034,8 @@ def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
     return a.to(dtype)


-_torch_dtype_to_nvfuser_dtype_map = {
-    torch.cdouble: DataType.ComplexDouble,
-    torch.cfloat: DataType.ComplexFloat,
-    torch.double: DataType.Double,
-    torch.float: DataType.Float,
-    torch.half: DataType.Half,
-    torch.bfloat16: DataType.BFloat16,
-    torch.long: DataType.Int,
-    torch.int: DataType.Int32,
-    torch.bool: DataType.Bool,
-}
-
-
 def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor:
-    nvfuser_dtype = _torch_dtype_to_nvfuser_dtype_map[dtype]
+    nvfuser_dtype = getnvFuserDtype(dtype)
     return fd.Ops.cast(nvfuser_dtype, a)  # type: ignore[attr-defined]

@@ -3,27 +3,11 @@ from typing import Callable
 import torch

 from torch.fx import GraphModule
-from torch._prims.utils import TensorMeta
+from torch._prims.utils import TensorMeta, getnvFuserDtype
 from torch._prims.context import PrimContext

-import torch._C._nvfuser as nvfuser  # type: ignore[import]
-
-DataType = nvfuser.DataType  # type: ignore[attr-defined]
-Fusion = nvfuser.Fusion  # type: ignore[attr-defined]
-FusionDefinition = nvfuser.FusionDefinition  # type: ignore[attr-defined]
-
-# TODO: refactor me into a common place
-_torch_dtype_to_nvfuser_dtype_map = {
-    torch.cdouble: DataType.ComplexDouble,
-    torch.cfloat: DataType.ComplexFloat,
-    torch.double: DataType.Double,
-    torch.float: DataType.Float,
-    torch.half: DataType.Half,
-    torch.bfloat16: DataType.BFloat16,
-    torch.long: DataType.Int,
-    torch.int: DataType.Int32,
-    torch.bool: DataType.Bool,
-}
+if torch.cuda.is_available():
+    from torch._C._nvfuser import Fusion, FusionDefinition  # type: ignore[import]


 def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
@@ -37,6 +21,11 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
         gm = GraphModule({}, ctx.graph)
         return gm.forward(*args, **kwargs)
     elif executor == "nvfuser":
+        if not torch.cuda.is_available():
+            raise RuntimeError(
+                "Attempting to use nvFuser trace executor but CUDA is not available!"
+            )
+
         # PROTOTYPE nvfuser executor
         # Only accepts tensor inputs and single tensor outputs
         # Does not handle kwargs
@@ -53,9 +42,7 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
         nv_args = [fd]
         for arg in args:
             if isinstance(arg, torch.Tensor):
-                x = fd.define_tensor(
-                    arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype]
-                )
+                x = fd.define_tensor(arg.ndim, getnvFuserDtype(arg.dtype))
                 fd.add_input(x)
                 nv_args.append(x)
             else:
|
@ -8,6 +8,32 @@ import threading
|
||||
import torch
|
||||
from torch.fx import Node
|
||||
|
||||
# nvFuser imports are conditional on CUDA being available
|
||||
if torch.cuda.is_available():
|
||||
from torch._C._nvfuser import DataType # type: ignore[import]
|
||||
|
||||
_torch_dtype_to_nvfuser_dtype_map = {
|
||||
torch.cdouble: DataType.ComplexDouble,
|
||||
torch.cfloat: DataType.ComplexFloat,
|
||||
torch.double: DataType.Double,
|
||||
torch.float: DataType.Float,
|
||||
torch.half: DataType.Half,
|
||||
torch.bfloat16: DataType.BFloat16,
|
||||
torch.long: DataType.Int,
|
||||
torch.int: DataType.Int32,
|
||||
torch.bool: DataType.Bool,
|
||||
}
|
||||
else:
|
||||
_torch_dtype_to_nvfuser_dtype_map = {}
|
||||
|
||||
|
||||
def getnvFuserDtype(dtype: torch.dtype):
|
||||
"""
|
||||
Translates from torch.dtype to nvFuser's DataType enum
|
||||
"""
|
||||
return _torch_dtype_to_nvfuser_dtype_map[dtype]
|
||||
|
||||
|
||||
ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
|
||||
StrideType = Union[List[int], Tuple[int, ...]]
|
||||
DimsType = Union[int, List[int], Tuple[int, ...]]
|
||||
|
File diff suppressed because it is too large
@@ -871,6 +871,10 @@ if IS_WINDOWS:

 # Dict of torch dtype -> NumPy dtype
 torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
+torch_to_numpy_dtype_dict.update({
+    torch.bfloat16: np.float32,
+    torch.complex32: np.complex64
+})

 def skipIfRocm(fn):
     @wraps(fn)
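A standalone sketch of why the @dtypes decorators in the test files above moved from torch_to_numpy_dtype_dict.keys() to set(numpy_to_torch_dtype_dict.values()): with the update in this last hunk, the torch-to-NumPy dict also maps torch.bfloat16 and torch.complex32 to approximate NumPy dtypes, so its keys are no longer only the dtypes that round-trip exactly through NumPy. The two-entry numpy_to_torch dict below is an illustrative excerpt, not the real table:

import numpy as np
import torch

numpy_to_torch = {np.float32: torch.float32, np.float64: torch.float64}
torch_to_numpy = {value: key for (key, value) in numpy_to_torch.items()}
torch_to_numpy.update({torch.bfloat16: np.float32, torch.complex32: np.complex64})

# bfloat16 is now a key, but NumPy has no native bfloat16, so tests that need
# an exact NumPy counterpart iterate the round-trippable set instead
exact_round_trip = set(numpy_to_torch.values())
assert torch.bfloat16 in torch_to_numpy
assert torch.bfloat16 not in exact_round_trip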