Adds python ref consistency test, elementwise unary reference inputs, and formats test files

Per title.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76626
Approved by: https://github.com/ngimel
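For context, a minimal, hypothetical sketch of the pattern the new consistency test follows (the real test, test_python_reference_consistency, is added in test/test_ops.py below and uses the OpInfo machinery; the helper name here is illustrative only): run a Python reference and the torch operator it mirrors on the same inputs and compare the results.

import torch

def check_ref_consistency(python_ref, torch_op, *args, **kwargs):
    # The reference should compute the same values as the torch op it mirrors;
    # the real test additionally checks strides, device, and layout.
    actual = python_ref(*args, **kwargs)
    expected = torch_op(*args, **kwargs)
    torch.testing.assert_close(actual, expected)

# Illustrative usage with a stand-in "reference" for torch.abs:
check_ref_consistency(lambda x: torch.abs(x), torch.abs, torch.tensor([-1.0, 2.0]))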
This commit is contained in:
Mike Ruberry
2022-05-01 22:42:46 +00:00
committed by PyTorch MergeBot
parent 33be4c94c0
commit f6bbecf8b5
12 changed files with 3333 additions and 1682 deletions

File diff suppressed because it is too large

View File

@ -8,17 +8,44 @@ import itertools
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_dtype import floating_and_complex_types_and, all_types_and_complex_and
from torch.testing._internal.common_utils import \
(TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper,
IS_IN_CI, suppress_warnings, noncontiguous_like,
TEST_WITH_ASAN, IS_WINDOWS, IS_FBCODE, first_sample)
from torch.testing._internal.common_methods_invocations import \
(op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, ops_and_refs)
from torch.testing._internal.common_device_type import \
(deviceCountAtLeast, instantiate_device_type_tests, ops,
onlyCUDA, onlyNativeDeviceTypes, OpDTypes, skipMeta)
# import torch._prims as prims
from torch.testing._internal.common_dtype import (
floating_and_complex_types_and,
all_types_and_complex_and,
)
from torch.testing._internal.common_utils import (
TestCase,
is_iterable_of_tensors,
run_tests,
IS_SANDCASTLE,
clone_input_helper,
IS_IN_CI,
suppress_warnings,
noncontiguous_like,
TEST_WITH_ASAN,
IS_WINDOWS,
IS_FBCODE,
first_sample,
)
from torch.testing._internal.common_methods_invocations import (
op_db,
_NOTHING,
UnaryUfuncInfo,
ReductionOpInfo,
SpectralFuncInfo,
ops_and_refs,
python_ref_db,
BinaryUfuncInfo,
)
from torch.testing._internal.common_device_type import (
deviceCountAtLeast,
instantiate_device_type_tests,
ops,
onlyCUDA,
onlyNativeDeviceTypes,
OpDTypes,
skipMeta,
)
import torch._prims as prims
import torch.testing._internal.opinfo_helper as opinfo_helper
from torch.testing._internal import composite_compliance
@ -28,15 +55,25 @@ torch.set_default_dtype(torch.float32)
# variant testing is only done with torch.float and torch.cfloat to avoid
# excessive test times and maximize signal to noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float, torch.cfloat))
_variant_ops = partial(
ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
)
# Get names of all the operators which have ref in their entry in OpInfo (testing infra)
# except for Unary Ufuncs (separately implemented in test/test_unary_ufuncs.py)
# except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
# elementwise binary operators (separately implemented in test_binary_ufuncs.py),
# reduction operations (separately implemented in test_reductions.py),
# and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo,
SpectralFuncInfo)) and op.ref is not None and op.ref is not _NOTHING, op_db))
_ref_test_ops = tuple(
filter(
lambda op: not isinstance(
op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
)
and op.ref is not None
and op.ref is not _NOTHING,
op_db,
)
)
# Tests that apply to all operators and aren't related to any particular
# system
@ -49,8 +86,10 @@ class TestCommon(TestCase):
super().tearDownClass()
if IS_IN_CI:
err_msg = ("The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
"This is OK for testing, but be sure to set the dtypes manually before landing your PR!")
err_msg = (
"The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
"This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
)
# Assure no opinfo entry has dynamic_dtypes
filtered_ops = list(filter(opinfo_helper.is_dynamic_dtype_set, op_db))
for op in filtered_ops:
@ -68,11 +107,16 @@ class TestCommon(TestCase):
# Check complex32 support only if the op claims.
# TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
device_type = torch.device(device).type
include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device_type) else ())
include_complex32 = (
(torch.complex32,)
if op.supports_dtype(torch.complex32, device_type)
else ()
)
# dtypes to try to backward in
allowed_backward_dtypes = floating_and_complex_types_and(
*((torch.half, torch.bfloat16) + include_complex32))
*((torch.half, torch.bfloat16) + include_complex32)
)
# lists for (un)supported dtypes
supported_dtypes = set()
@ -86,11 +130,14 @@ class TestCommon(TestCase):
unsupported_backward_dtypes.add(dtype)
for dtype in all_types_and_complex_and(
*((torch.half, torch.bfloat16, torch.bool) + include_complex32)):
*((torch.half, torch.bfloat16, torch.bool) + include_complex32)
):
# tries to acquire samples - failure indicates lack of support
requires_grad = (dtype in allowed_backward_dtypes)
requires_grad = dtype in allowed_backward_dtypes
try:
samples = tuple(op.sample_inputs(device, dtype, requires_grad=requires_grad))
samples = tuple(
op.sample_inputs(device, dtype, requires_grad=requires_grad)
)
except Exception as e:
unsupported(dtype)
continue
@ -113,7 +160,9 @@ class TestCommon(TestCase):
result = sample.output_process_fn_grad(result)
if isinstance(result, torch.Tensor):
backward_tensor = result
elif isinstance(result, Sequence) and isinstance(result[0], torch.Tensor):
elif isinstance(result, Sequence) and isinstance(
result[0], torch.Tensor
):
backward_tensor = result[0]
else:
continue
@ -130,14 +179,15 @@ class TestCommon(TestCase):
except Exception as e:
unsupported_backward_dtypes.add(dtype)
# Checks that dtypes are listed correctly and generates an informative
# error message
supported_forward = supported_dtypes - unsupported_dtypes
partially_supported_forward = supported_dtypes & unsupported_dtypes
unsupported_forward = unsupported_dtypes - supported_dtypes
supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
partially_supported_backward = supported_backward_dtypes & unsupported_backward_dtypes
partially_supported_backward = (
supported_backward_dtypes & unsupported_backward_dtypes
)
unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes
device_type = torch.device(device).type
@ -156,17 +206,27 @@ class TestCommon(TestCase):
op.name, device_type
)
if len(partially_supported_forward) > 0:
msg = msg + "The following dtypes only worked on some samples during forward: {0}.\n".format(
msg = (
msg
+ "The following dtypes only worked on some samples during forward: {0}.\n".format(
partially_supported_forward
)
)
if len(partially_supported_backward) > 0:
msg = msg + "The following dtypes only worked on some samples during backward: {0}.\n".format(
msg = (
msg
+ "The following dtypes only worked on some samples during backward: {0}.\n".format(
partially_supported_backward
)
)
print(msg)
if (len(supported_but_unclaimed_forward) + len(claimed_but_unsupported_forward) +
len(supported_but_unclaimed_backward) + len(claimed_but_unsupported_backward)) == 0:
if (
len(supported_but_unclaimed_forward)
+ len(claimed_but_unsupported_forward)
+ len(supported_but_unclaimed_backward)
+ len(claimed_but_unsupported_backward)
) == 0:
return
# Generates error msg
@ -174,21 +234,33 @@ class TestCommon(TestCase):
op.name, device_type
)
if len(supported_but_unclaimed_forward) > 0:
msg = msg + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
msg = (
msg
+ "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
supported_but_unclaimed_forward
)
)
if len(supported_but_unclaimed_backward) > 0:
msg = msg + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
msg = (
msg
+ "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
supported_but_unclaimed_backward
)
)
if len(claimed_but_unsupported_forward) > 0:
msg = msg + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
msg = (
msg
+ "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
claimed_but_unsupported_forward
)
)
if len(claimed_but_unsupported_backward) > 0:
msg = msg + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
msg = (
msg
+ "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
claimed_but_unsupported_backward
)
)
self.fail(msg)
@ -209,7 +281,9 @@ class TestCommon(TestCase):
elif is_iterable_of_tensors(result):
self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
else:
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
self.skipTest(
"Skipped! Only supports single tensor or iterable of tensor outputs."
)
# Tests that the function and its (ndarray-accepting) reference produce the same
# values on the tensors from sample_inputs func for the corresponding op.
@ -226,28 +300,48 @@ class TestCommon(TestCase):
cur_default = torch.get_default_dtype()
torch.set_default_dtype(torch.double)
for sample_input in op.reference_inputs(device, dtype):
self.compare_with_reference(op, op.ref, sample_input, exact_dtype=(dtype is not torch.long))
self.compare_with_reference(
op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
)
finally:
torch.set_default_dtype(cur_default)
# Tests that experimental Python References' can propagate shape, dtype,
# Tests that experimental Python References can propagate shape, dtype,
# and device metadata properly.
# TODO: include stride propagation.
# @onlyNativeDeviceTypes
# @ops(python_ref_db)
# def test_python_reference_meta_functions(self, device, dtype, op):
# def _to_tensormeta(x):
# if isinstance(x, torch.Tensor):
# return prims.utils.TensorMeta(x)
@onlyNativeDeviceTypes
@ops(python_ref_db)
def test_python_reference_meta_functions(self, device, dtype, op):
def _to_tensormeta(x):
if isinstance(x, torch.Tensor):
return prims.utils.TensorMeta(x)
# # TODO: iterate over requires_grad true/false
# for sample in op.reference_inputs(device, dtype, requires_grad=False):
# result = op(sample.input, *sample.args, **sample.kwargs)
# TODO: iterate over requires_grad true/false
for sample in op.reference_inputs(device, dtype, requires_grad=False):
result = op(sample.input, *sample.args, **sample.kwargs)
# meta_sample = sample.transform(_to_tensormeta)
# meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
meta_sample = sample.transform(_to_tensormeta)
meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
# prims.utils.compare_tensor_meta(result, meta_result)
prims.utils.compare_tensor_meta(result, meta_result)
# Tests that experimental Python References perform the same computation
# as the operators they reference.
@onlyNativeDeviceTypes
@ops(python_ref_db)
def test_python_reference_consistency(self, device, dtype, op):
for sample in op.reference_inputs(device, dtype, requires_grad=False):
actual = op(sample.input, *sample.args, **sample.kwargs)
expected = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
self.assertEqual(
actual,
expected,
exact_stride=True,
exact_device=True,
exact_layout=True,
exact_is_coalesced=True,
)
@skipMeta
@onlyNativeDeviceTypes
@ -272,9 +366,17 @@ class TestCommon(TestCase):
test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
for sample_input in sample_inputs:
t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
t_inp, t_args, t_kwargs = (
sample_input.input,
sample_input.args,
sample_input.kwargs,
)
noncontig_sample = sample_input.noncontiguous()
n_inp, n_args, n_kwargs = noncontig_sample.input, noncontig_sample.args, noncontig_sample.kwargs
n_inp, n_args, n_kwargs = (
noncontig_sample.input,
noncontig_sample.args,
noncontig_sample.kwargs,
)
# Verifies sample input tensors should have no grad or history
sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0]
@ -300,10 +402,14 @@ class TestCommon(TestCase):
grad_for_actual = noncontiguous_like(grad_for_expected)
elif isinstance(expected, Sequence):
# Filter output elements that do not require grad
expected = [t for t in expected
if isinstance(t, torch.Tensor) and t.requires_grad]
actual = [n for n in actual
if isinstance(n, torch.Tensor) and n.requires_grad]
expected = [
t
for t in expected
if isinstance(t, torch.Tensor) and t.requires_grad
]
actual = [
n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
]
grad_for_expected = [torch.randn_like(t) for t in expected]
grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
else:
@ -311,19 +417,35 @@ class TestCommon(TestCase):
continue
# Concatenate inputs into a tuple
t_inputs = (t_inp,) + t_args if isinstance(t_inp, torch.Tensor) else tuple(t_inp) + t_args
n_inputs = (n_inp,) + n_args if isinstance(n_inp, torch.Tensor) else tuple(n_inp) + n_args
t_inputs = (
(t_inp,) + t_args
if isinstance(t_inp, torch.Tensor)
else tuple(t_inp) + t_args
)
n_inputs = (
(n_inp,) + n_args
if isinstance(n_inp, torch.Tensor)
else tuple(n_inp) + n_args
)
# Filter the elements that are tensors that require grad
t_input_tensors = [t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad]
n_input_tensors = [n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad]
t_input_tensors = [
t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
]
n_input_tensors = [
n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
]
self.assertEqual(len(t_input_tensors), len(n_input_tensors))
# Some functions may not use all the inputs to generate gradients. One of the
# few examples of this "odd" behaviour is F.hinge_embedding_loss
t_grads = torch.autograd.grad(expected, t_input_tensors, grad_for_expected, allow_unused=True)
n_grads = torch.autograd.grad(actual, n_input_tensors, grad_for_actual, allow_unused=True)
t_grads = torch.autograd.grad(
expected, t_input_tensors, grad_for_expected, allow_unused=True
)
n_grads = torch.autograd.grad(
actual, n_input_tensors, grad_for_actual, allow_unused=True
)
msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
for i, (t, n) in enumerate(zip(t_grads, n_grads)):
@ -339,7 +461,11 @@ class TestCommon(TestCase):
supported_dtypes = op.supported_dtypes(self.device_type)
if len(supported_dtypes) == 0:
self.skipTest("Skipped! Op has not supported dtypes on this device.")
dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
dtype = (
torch.float32
if torch.float32 in supported_dtypes
else list(supported_dtypes)[0]
)
samples = op.sample_inputs(device, dtype)
for sample in samples:
@ -349,8 +475,12 @@ class TestCommon(TestCase):
# Short-circuits if output is not a single tensor or an
# iterable of tensors
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
expected, include_empty=True
):
self.skipTest(
"Skipped! Only supports single tensor or iterable of tensor outputs."
)
# Validates the op doesn't support out if it claims not to
if not op.supports_out:
@ -380,7 +510,7 @@ class TestCommon(TestCase):
# NOTE: only extracts on the CPU and CUDA device types since some
# device types don't have storage
def _extract_data_ptrs(out):
if self.device_type != 'cpu' and self.device_type != 'cuda':
if self.device_type != "cpu" and self.device_type != "cuda":
return ()
if isinstance(out, torch.Tensor):
@ -403,7 +533,8 @@ class TestCommon(TestCase):
if compare_strides_and_data_ptrs:
stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
original_strides, final_strides)
original_strides, final_strides
)
self.assertEqual(original_strides, final_strides, msg=stride_msg)
self.assertEqual(original_ptrs, final_ptrs)
@ -433,7 +564,9 @@ class TestCommon(TestCase):
out = _apply_out_transform(_case_zero_transform, expected)
msg_fail = "Resized a non-empty tensor but did not warn about it."
if _any_nonempty(out):
with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail):
with self.assertWarnsRegex(
UserWarning, "An output with one or more elements", msg=msg_fail
):
op_out(out=out)
# Validates ops implement the correct out= behavior
@ -452,7 +585,11 @@ class TestCommon(TestCase):
supported_dtypes = op.supported_dtypes(self.device_type)
if len(supported_dtypes) == 0:
self.skipTest("Skipped! Op has not supported dtypes on this device.")
dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
dtype = (
torch.float32
if torch.float32 in supported_dtypes
else list(supported_dtypes)[0]
)
samples = op.sample_inputs(device, dtype)
for sample in samples:
@ -462,8 +599,12 @@ class TestCommon(TestCase):
# Short-circuits if output is not a single tensor or an
# iterable of tensors
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
expected, include_empty=True
):
self.skipTest(
"Skipped! Only supports single tensor or iterable of tensor outputs."
)
# Validates the op doesn't support out if it claims not to
if not op.supports_out:
@ -493,7 +634,7 @@ class TestCommon(TestCase):
# NOTE: only extracts on the CPU and CUDA device types since some
# device types don't have storage
def _extract_data_ptrs(out):
if self.device_type != 'cpu' and self.device_type != 'cuda':
if self.device_type != "cpu" and self.device_type != "cuda":
return ()
if isinstance(out, torch.Tensor):
@ -515,7 +656,8 @@ class TestCommon(TestCase):
if compare_strides_and_data_ptrs:
stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
original_strides, final_strides)
original_strides, final_strides
)
self.assertEqual(original_strides, final_strides, msg=stride_msg)
self.assertEqual(original_ptrs, final_ptrs)
@ -529,7 +671,7 @@ class TestCommon(TestCase):
return torch.full_like(t, info.max)
except TypeError as te:
# for non-integer types fills with NaN
return torch.full_like(t, float('nan'))
return torch.full_like(t, float("nan"))
_compare_out(_case_zero_transform)
@ -537,10 +679,9 @@ class TestCommon(TestCase):
# but noncontiguous.
# Expected behavior: strides are respected and `out` storage is not changed.
def _case_one_transform(t):
return make_tensor(t.shape,
dtype=t.dtype,
device=t.device,
noncontiguous=True)
return make_tensor(
t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
)
_compare_out(_case_one_transform)
@ -560,16 +701,19 @@ class TestCommon(TestCase):
# Verifies no warning is a resize warning
for w in caught:
if "An output with one or more elements" in str(w.message):
self.fail("Resizing an out= argument with no elements threw a resize warning!")
self.fail(
"Resizing an out= argument with no elements threw a resize warning!"
)
# Case 3: out= with correct shape and dtype, but wrong device.
wrong_device = None
if torch.device(device).type != 'cpu':
wrong_device = 'cpu'
if torch.device(device).type != "cpu":
wrong_device = "cpu"
elif torch.cuda.is_available():
wrong_device = 'cuda'
wrong_device = "cuda"
if wrong_device is not None:
def _case_three_transform(t):
return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)
@ -587,16 +731,28 @@ class TestCommon(TestCase):
# dtypes, or if an op returns multiple tensors when at least one such
# tensor is a floating point or complex dtype.
_dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or
(not isinstance(expected, torch.Tensor) and any(t.dtype in _dtypes for t in expected))):
if (
isinstance(expected, torch.Tensor)
and expected.dtype in _dtypes
or (
not isinstance(expected, torch.Tensor)
and any(t.dtype in _dtypes for t in expected)
)
):
def _case_four_transform(t):
return make_tensor(t.shape, dtype=torch.long, device=t.device)
out = _apply_out_transform(_case_four_transform, expected)
msg_fail = "Expected RuntimeError when doing an unsafe cast!"
msg_fail = msg_fail if not isinstance(expected, torch.Tensor) else \
("Expected RuntimeError when doing an unsafe cast from a result of dtype "
f"{expected.dtype} into an out= with dtype torch.long")
msg_fail = (
msg_fail
if not isinstance(expected, torch.Tensor)
else (
"Expected RuntimeError when doing an unsafe cast from a result of dtype "
f"{expected.dtype} into an out= with dtype torch.long"
)
)
with self.assertRaises(RuntimeError, msg=msg_fail):
op_out(out=out)
@ -611,7 +767,9 @@ class TestCommon(TestCase):
inplace = op.inplace_variant
# list of all inplace ops: inplace variant + alias inplace variants if exist
inplace_ops = [inplace, ]
inplace_ops = [
inplace,
]
variants = [method, inplace]
for a_op in op.aliases:
@ -623,30 +781,46 @@ class TestCommon(TestCase):
inplace_variants = tuple(filter(None, inplace_ops))
variants = tuple(filter(None, variants))
_requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
_requires_grad = dtype in op.supported_backward_dtypes(
torch.device(device).type
)
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
samples = op.sample_inputs(
device,
dtype,
requires_grad=_requires_grad,
include_conjugated_inputs=include_conjugated_inputs,
)
samples = list(samples)
def _test_consistency_helper(samples, variants):
for sample in samples:
# TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
tensor = (
sample.input
if isinstance(sample.input, torch.Tensor)
else sample.input[0]
)
# Computes function forward and backward values
tensor.grad = None
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
expected_grad = None
output_process_fn_grad = sample.output_process_fn_grad if sample.output_process_fn_grad \
output_process_fn_grad = (
sample.output_process_fn_grad
if sample.output_process_fn_grad
else lambda x: x
)
# Skips inplace variants if the output dtype is not the same as
# the input dtype
skip_inplace = False
if (isinstance(expected_forward, torch.Tensor) and
expected_forward.dtype is not tensor.dtype):
if (
isinstance(expected_forward, torch.Tensor)
and expected_forward.dtype is not tensor.dtype
):
skip_inplace = True
# TODO: backward consistency only supported for single tensor outputs
@ -654,8 +828,9 @@ class TestCommon(TestCase):
# tensor inputs
# TODO: update to handle checking grads of all tensor inputs as
# derived from each tensor output
if (isinstance(expected_forward, torch.Tensor)
and dtype in op.supported_backward_dtypes(torch.device(device).type)):
if isinstance(
expected_forward, torch.Tensor
) and dtype in op.supported_backward_dtypes(torch.device(device).type):
output_process_fn_grad(expected_forward).sum().backward()
expected_grad = tensor.grad
@ -668,26 +843,35 @@ class TestCommon(TestCase):
# Compares variant's forward
# Note: copies the to-be-modified input when testing the inplace variant
tensor.grad = None
cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
cloned = (
clone_input_helper(sample.input)
if variant in inplace_ops
else sample.input
)
if variant in inplace_ops and sample.broadcasts_input:
with self.assertRaises(RuntimeError,
msg=('inplace variant either incorrectly allowed '
'resizing or you have marked the sample {}'
' incorrectly with `broadcasts_self=True'.format(sample.summary()))):
variant_forward = variant(cloned,
*sample.args,
**sample.kwargs)
with self.assertRaises(
RuntimeError,
msg=(
"inplace variant either incorrectly allowed "
"resizing or you have marked the sample {}"
" incorrectly with `broadcasts_self=True".format(
sample.summary()
)
),
):
variant_forward = variant(
cloned, *sample.args, **sample.kwargs
)
continue
variant_forward = variant(cloned,
*sample.args,
**sample.kwargs)
variant_forward = variant(cloned, *sample.args, **sample.kwargs)
self.assertEqual(expected_forward, variant_forward)
# Compares variant's backward
if expected_grad is not None and \
(variant not in inplace_ops or op.supports_inplace_autograd):
if expected_grad is not None and (
variant not in inplace_ops or op.supports_inplace_autograd
):
output_process_fn_grad(variant_forward).sum().backward()
self.assertEqual(expected_grad, tensor.grad)
@ -698,28 +882,45 @@ class TestCommon(TestCase):
# Skips inplace variants if the output dtype is not the same as
# the input dtype
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
tensor = (
sample.input
if isinstance(sample.input, torch.Tensor)
else sample.input[0]
)
skip_inplace = False
if (isinstance(expected_forward, torch.Tensor) and
expected_forward.dtype is not tensor.dtype):
if (
isinstance(expected_forward, torch.Tensor)
and expected_forward.dtype is not tensor.dtype
):
skip_inplace = True
if skip_inplace:
return
for variant in variants:
cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0]
cloned = (
clone_input_helper(sample.input)
if variant in inplace_ops
else sample.input
)
inp_tensor = (
cloned if isinstance(cloned, torch.Tensor) else cloned[0]
)
data_ptr = inp_tensor.data_ptr()
variant_forward = variant(cloned,
*sample.args,
**sample.kwargs)
variant_forward = variant(cloned, *sample.args, **sample.kwargs)
# TODO Support non-tensor outputs if they exist for inplace ops
if (isinstance(variant_forward, torch.Tensor)):
self.assertEqual(data_ptr, variant_forward.data_ptr(), atol=0, rtol=0)
if isinstance(variant_forward, torch.Tensor):
self.assertEqual(
data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
)
else:
self.assertTrue(False, "Non-tensor outputs for inplace ops are not supported")
self.assertTrue(
False,
"Non-tensor outputs for inplace ops are not supported",
)
if len(inplace_ops) > 0:
inplace_samples = list(filter(lambda sample: not sample.broadcasts_input, samples))
inplace_samples = list(
filter(lambda sample: not sample.broadcasts_input, samples)
)
_test_inplace_preserve_storage(inplace_samples, inplace_variants)
# Reference testing for operations in complex32 against complex64.
@ -732,14 +933,21 @@ class TestCommon(TestCase):
for sample in op.sample_inputs(device, dtype):
actual = op(sample.input, *sample.args, **sample.kwargs)
transformed_sample = sample.transform(lambda x: x.to(torch.complex64))
expected = op(transformed_sample.input, *transformed_sample.args, **transformed_sample.kwargs)
expected = op(
transformed_sample.input,
*transformed_sample.args,
**transformed_sample.kwargs,
)
self.assertEqual(actual, expected, exact_dtype=False)
class TestCompositeCompliance(TestCase):
# Checks if the operator (if it is composite) is written to support most
# backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
# in aten/src/ATen/native/README.md for more details
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
@unittest.skipIf(
IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
)
@ops(op_db, allowed_dtypes=(torch.float,))
def test_operator(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=False)
@ -750,7 +958,9 @@ class TestCompositeCompliance(TestCase):
composite_compliance.check_with_mode(op, args, kwargs)
composite_compliance.check_all_permutations(op, args, kwargs)
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
@unittest.skipIf(
IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
)
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
def test_backward(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
@ -760,7 +970,9 @@ class TestCompositeCompliance(TestCase):
kwargs = sample.kwargs
composite_compliance.check_backward_formula(op, args, kwargs)
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
@unittest.skipIf(
IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
)
@ops(op_db, allowed_dtypes=(torch.float,))
def test_forward_ad(self, device, dtype, op):
if torch.float not in op.supported_backward_dtypes(device):
@ -776,6 +988,7 @@ class TestCompositeCompliance(TestCase):
kwargs = sample.kwargs
composite_compliance.check_forward_ad_formula(op, args, kwargs)
class TestMathBits(TestCase):
# Tests that
# 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
@ -787,7 +1000,17 @@ class TestMathBits(TestCase):
# This test only runs for C -> R and C -> C functions
# TODO: add tests for `R->C` functions
# Note: This test runs for functions that take both tensors and tensorlists as input.
def _test_math_view(self, device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, out_type):
def _test_math_view(
self,
device,
dtype,
op,
samples,
math_op_physical,
math_op_view,
is_bit_set,
out_type,
):
inplace_variant = op.inplace_variant
# helper function to clone and conjugate/negate the input if it's a tensor
@ -796,7 +1019,7 @@ class TestMathBits(TestCase):
# have its requires_grad set to that value.
def clone_and_perform_view(input, **kwargs):
if isinstance(input, torch.Tensor):
requires_grad = kwargs.get('requires_grad', input.requires_grad)
requires_grad = kwargs.get("requires_grad", input.requires_grad)
with torch.no_grad():
# Ensure view represents the original sample input
input = math_op_physical(input)
@ -813,7 +1036,11 @@ class TestMathBits(TestCase):
return tuple(out)
for sample in samples:
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
tensor = (
sample.input
if isinstance(sample.input, torch.Tensor)
else sample.input[0]
)
cloned1 = clone_and_perform_view(sample.input)
# Computes function forward value with a physically conjugated/negated tensor and
@ -827,9 +1054,13 @@ class TestMathBits(TestCase):
# input produces correct output, and the output tensor has the conj/neg bit set to True
if inplace_variant is not None and not sample.broadcasts_input:
cloned2 = clone_and_perform_view(tensor, requires_grad=False)
if (isinstance(expected_forward, torch.Tensor) and
expected_forward.dtype is tensor.dtype):
inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
if (
isinstance(expected_forward, torch.Tensor)
and expected_forward.dtype is tensor.dtype
):
inplace_forward = inplace_variant(
cloned2, *sample.args, **sample.kwargs
)
self.assertTrue(is_bit_set(inplace_forward))
self.assertEqual(inplace_forward, expected_forward)
@ -838,25 +1069,36 @@ class TestMathBits(TestCase):
# tensor inputs
# TODO: update to handle checking grads of all tensor inputs as
# derived from each tensor output
if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
if (
isinstance(expected_forward, torch.Tensor)
and expected_forward.requires_grad
):
output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
expected_forward = output_process_fn_grad(expected_forward)
forward_with_mathview = output_process_fn_grad(forward_with_mathview)
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
tensor = (
sample.input
if isinstance(sample.input, torch.Tensor)
else sample.input[0]
)
expected_forward.sum().backward(retain_graph=True)
forward_with_mathview.sum().backward(retain_graph=True)
if tensor.grad is not None:
cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
cloned1_tensor = (
cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
)
self.assertEqual(tensor.grad, cloned1_tensor.grad)
tensor.grad, cloned1_tensor.grad = None, None
# a repeat of the above test if output is not complex valued
if (out_type(expected_forward)):
if out_type(expected_forward):
grad = torch.randn_like(expected_forward)
expected_forward.backward(grad)
forward_with_mathview.backward(math_op_view(math_op_physical(grad)))
forward_with_mathview.backward(
math_op_view(math_op_physical(grad))
)
self.assertEqual(tensor.grad, cloned1_tensor.grad)
@ -866,10 +1108,21 @@ class TestMathBits(TestCase):
self.skipTest("Operation doesn't support conjugated inputs.")
math_op_physical = torch.conj_physical
math_op_view = torch.conj
_requires_grad = torch.cfloat in op.supported_backward_dtypes(torch.device(device).type)
_requires_grad = torch.cfloat in op.supported_backward_dtypes(
torch.device(device).type
)
is_bit_set = torch.is_conj
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, torch.is_complex)
self._test_math_view(
device,
dtype,
op,
samples,
math_op_physical,
math_op_view,
is_bit_set,
torch.is_complex,
)
@ops(op_db, allowed_dtypes=(torch.double,))
def test_neg_view(self, device, dtype, op):
@ -879,8 +1132,16 @@ class TestMathBits(TestCase):
math_op_view = torch._neg_view
is_bit_set = torch.is_neg
samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
lambda x: True)
self._test_math_view(
device,
dtype,
op,
samples,
math_op_physical,
math_op_view,
is_bit_set,
lambda x: True,
)
@ops(op_db, allowed_dtypes=(torch.cdouble,))
def test_neg_conj_view(self, device, dtype, op):
@ -898,17 +1159,27 @@ class TestMathBits(TestCase):
def is_bit_set(x):
return torch.is_neg(x) and torch.is_conj(x)
_requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
_requires_grad = dtype in op.supported_backward_dtypes(
torch.device(device).type
)
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
# Only test one sample
samples = itertools.islice(samples, 1)
self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
torch.is_complex)
self._test_math_view(
device,
dtype,
op,
samples,
math_op_physical,
math_op_view,
is_bit_set,
torch.is_complex,
)
instantiate_device_type_tests(TestCommon, globals())
instantiate_device_type_tests(TestCompositeCompliance, globals())
instantiate_device_type_tests(TestMathBits, globals())
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -1,85 +1,83 @@
# Owner(s): ["module: primTorch"]
# TODO: uncomment this file once CI issues with import nvfuser are resolved
from functools import partial
# from functools import partial
# import torch
# from torch.testing import make_tensor
# from torch.testing._internal.common_utils import run_tests, TestCase
# from torch.testing._internal.common_device_type import (
# instantiate_device_type_tests,
# onlyCUDA,
# dtypes,
# )
# import torch._prims as prims
# from torch._prims.executor import make_traced
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
onlyCUDA,
dtypes,
)
import torch._prims as prims
from torch._prims.executor import make_traced
# class TestPrims(TestCase):
# @onlyCUDA
# @dtypes(torch.float32)
# def test_broadcast_in_dim(self, device, dtype):
# def _wrapper(a, shape, broadcast_dimensions):
# return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
class TestPrims(TestCase):
@onlyCUDA
@dtypes(torch.float32)
def test_broadcast_in_dim(self, device, dtype):
def _wrapper(a, shape, broadcast_dimensions):
return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
# traced = make_traced(_wrapper)
# make_arg = partial(make_tensor, device=device, dtype=dtype)
traced = make_traced(_wrapper)
make_arg = partial(make_tensor, device=device, dtype=dtype)
# # TODO: FIXME:
# # for executor in ('aten', 'nvfuser'):
# for executor in ("aten",):
# fn = partial(traced, executor=executor)
# # Same shape
# shape = (5, 5)
# a = make_arg(shape)
# result = fn(a, shape, (0, 1))
# TODO: FIXME:
# for executor in ('aten', 'nvfuser'):
for executor in ("aten",):
fn = partial(traced, executor=executor)
# Same shape
shape = (5, 5)
a = make_arg(shape)
result = fn(a, shape, (0, 1))
# self.assertEqual(result.shape, a.shape)
# self.assertTrue(result.is_contiguous)
# self.assertEqual(a, result)
self.assertEqual(result.shape, a.shape)
self.assertTrue(result.is_contiguous)
self.assertEqual(a, result)
# # Error input: reordering dims
# with self.assertRaises(Exception):
# result = fn(a, shape, (1, 0))
# Error input: reordering dims
with self.assertRaises(Exception):
result = fn(a, shape, (1, 0))
# # Adding outermost dimensions
# a = make_arg((5, 5))
# target_shape = (3, 3, 5, 5)
# result = fn(a, target_shape, (2, 3))
# Adding outermost dimensions
a = make_arg((5, 5))
target_shape = (3, 3, 5, 5)
result = fn(a, target_shape, (2, 3))
# self.assertEqual(result.shape, target_shape)
# self.assertEqual(a.broadcast_to(target_shape), result)
self.assertEqual(result.shape, target_shape)
self.assertEqual(a.broadcast_to(target_shape), result)
# # Expands
# a = make_arg((1, 5, 1))
# target_shape = (3, 5, 7)
# result = fn(a, target_shape, (0, 1, 2))
# Expands
a = make_arg((1, 5, 1))
target_shape = (3, 5, 7)
result = fn(a, target_shape, (0, 1, 2))
# self.assertEqual(result.shape, target_shape)
# self.assertEqual(a.expand_as(result), result)
self.assertEqual(result.shape, target_shape)
self.assertEqual(a.expand_as(result), result)
# # Unsqueezes
# a = make_arg((1, 2, 3))
# target_shape = (1, 2, 1, 3)
# result = fn(a, target_shape, (0, 1, 3))
# Unsqueezes
a = make_arg((1, 2, 3))
target_shape = (1, 2, 1, 3)
result = fn(a, target_shape, (0, 1, 3))
# self.assertEqual(result.shape, target_shape)
# self.assertEqual(a.unsqueeze(2), result)
self.assertEqual(result.shape, target_shape)
self.assertEqual(a.unsqueeze(2), result)
# # Adds outermost, expands, and unsqueezes
# a = make_arg((1, 2, 3))
# target_shape = (4, 1, 7, 2, 3, 3)
# result = fn(a, target_shape, (1, 3, 4))
# Adds outermost, expands, and unsqueezes
a = make_arg((1, 2, 3))
target_shape = (4, 1, 7, 2, 3, 3)
result = fn(a, target_shape, (1, 3, 4))
# self.assertEqual(result.shape, target_shape)
# a.unsqueeze_(3)
# a.unsqueeze_(1)
# a.unsqueeze_(0)
# self.assertEqual(a.expand_as(result), result)
self.assertEqual(result.shape, target_shape)
a.unsqueeze_(3)
a.unsqueeze_(1)
a.unsqueeze_(0)
self.assertEqual(a.expand_as(result), result)
# instantiate_device_type_tests(TestPrims, globals())
instantiate_device_type_tests(TestPrims, globals())
# if __name__ == "__main__":
# run_tests()
if __name__ == "__main__":
run_tests()

View File

@ -13,7 +13,7 @@ import random
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
torch_to_numpy_dtype_dict, slowTest,
torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize)
from torch.testing._internal.common_device_type import (
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
@ -3683,13 +3683,13 @@ class TestBufferProtocol(TestCase):
self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
return (numpy_original, torch_frombuffer)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_same_type(self, device, dtype):
self._run_test((), dtype)
self._run_test((4,), dtype)
self._run_test((10, 10), dtype)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_requires_grad(self, device, dtype):
def _run_test_and_check_grad(requires_grad, *args, **kwargs):
kwargs["requires_grad"] = requires_grad
@ -3704,14 +3704,14 @@ class TestBufferProtocol(TestCase):
_run_test_and_check_grad(False, (4,), dtype)
_run_test_and_check_grad(False, (10, 10), dtype)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_with_offset(self, device, dtype):
# Offset should be valid whenever there is, at least,
# one remaining element
for i in range(SIZE):
self._run_test(SHAPE, dtype, first=i)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_with_count(self, device, dtype):
# Count should be valid for any value in the interval
# [-1, len(input)], except for 0
@ -3719,7 +3719,7 @@ class TestBufferProtocol(TestCase):
if i != 0:
self._run_test(SHAPE, dtype, count=i)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_with_count_and_offset(self, device, dtype):
# Explicit default count [-1, 1, 2, ..., len]
for i in range(-1, SIZE + 1):
@ -3735,7 +3735,7 @@ class TestBufferProtocol(TestCase):
for j in range(SIZE - i + 1):
self._run_test(SHAPE, dtype, count=i, first=j)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_invalid_positional_args(self, device, dtype):
bytes = get_dtype_size(dtype)
in_bytes = SIZE * bytes
@ -3772,7 +3772,7 @@ class TestBufferProtocol(TestCase):
rf"buffer length \({in_bytes} bytes\)"):
self._run_test(SHAPE, dtype, count=count, first=first)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_shared_buffer(self, device, dtype):
x = make_tensor((1,), dtype=dtype, device=device)
# Modify the whole tensor
@ -3799,13 +3799,13 @@ class TestBufferProtocol(TestCase):
arr[first] = x.item() - 1
self.assertEqual(arr[first:last], tensor)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_not_a_buffer(self, device, dtype):
with self.assertRaisesRegex(ValueError,
r"object does not implement Python buffer protocol."):
torch.frombuffer([1, 2, 3, 4], dtype=dtype)
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_non_writable_buffer(self, device, dtype):
numpy_arr = make_tensor((1,), dtype=dtype, device=device).numpy()
byte_arr = numpy_arr.tobytes()
@ -3910,7 +3910,7 @@ class TestAsArray(TestCase):
self._test_alias_with_cvt(identity, device, dtype)
@onlyCPU
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_alias_from_numpy(self, device, dtype):
self._test_alias_with_cvt(to_numpy, device, dtype)
@ -3921,7 +3921,7 @@ class TestAsArray(TestCase):
self._test_alias_with_cvt(to_dlpack, device, dtype)
@onlyCPU
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_alias_from_buffer(self, device, dtype):
self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
@ -3959,7 +3959,7 @@ class TestAsArray(TestCase):
self._test_copy_with_cvt(identity, device, dtype)
@onlyCPU
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_copy_from_numpy(self, device, dtype):
self._test_copy_with_cvt(to_numpy, device, dtype)
@ -3969,7 +3969,7 @@ class TestAsArray(TestCase):
self._test_copy_with_cvt(to_dlpack, device, dtype)
@onlyCPU
@dtypes(*torch_to_numpy_dtype_dict.keys())
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
def test_copy_from_buffer(self, device, dtype):
self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)

View File

@ -7,15 +7,14 @@ import unittest
import torch
from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
TEST_NUMPY, torch_to_numpy_dtype_dict)
TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict)
from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta)
from torch.testing._internal.common_dtype import (
all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and
)
if TEST_NUMPY:
import numpy as np
import numpy as np
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -812,8 +811,8 @@ class TestTypePromotion(TestCase):
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@float_double_default_dtype
@onlyCPU
@dtypes(*list(itertools.product(torch_to_numpy_dtype_dict.keys(),
torch_to_numpy_dtype_dict.keys())))
@dtypes(*list(itertools.product(set(numpy_to_torch_dtype_dict.values()),
set(numpy_to_torch_dtype_dict.values()))))
def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
import operator
np_type = torch_to_numpy_dtype_dict[dtypes[0]]

File diff suppressed because it is too large

View File

@ -11,7 +11,7 @@ import random
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck,
torch_to_numpy_dtype_dict,
numpy_to_torch_dtype_dict,
)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
@ -130,7 +130,7 @@ class TestViewOps(TestCase):
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
def test_view_dtype_new(self, device, dtype):
dtypes = torch_to_numpy_dtype_dict.copy()
dtypes = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
del dtypes[torch.bool]
def generate_inputs():

View File

@ -2,14 +2,15 @@ import torch
from torch import Tensor
import torch._prims.utils as utils
from torch._prims.utils import TensorLike, TensorLikeType, TensorMeta, ShapeType
from torch._prims.utils import (
TensorLike,
TensorLikeType,
TensorMeta,
ShapeType,
getnvFuserDtype,
)
from torch.overrides import has_torch_function, handle_torch_function
import torch._C._nvfuser as nvfuser # type: ignore[import]
FusionDefinition = nvfuser.FusionDefinition # type: ignore[attr-defined]
DataType = nvfuser.DataType # type: ignore[attr-defined]
from typing import Sequence, Optional, Union, Callable, List, Tuple, Any
from numbers import Number
from functools import reduce
@ -1033,21 +1034,8 @@ def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
return a.to(dtype)
_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
torch.cfloat: DataType.ComplexFloat,
torch.double: DataType.Double,
torch.float: DataType.Float,
torch.half: DataType.Half,
torch.bfloat16: DataType.BFloat16,
torch.long: DataType.Int,
torch.int: DataType.Int32,
torch.bool: DataType.Bool,
}
def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor:
nvfuser_dtype = _torch_dtype_to_nvfuser_dtype_map[dtype]
nvfuser_dtype = getnvFuserDtype(dtype)
return fd.Ops.cast(nvfuser_dtype, a) # type: ignore[attr-defined]

View File

@ -3,27 +3,11 @@ from typing import Callable
import torch
from torch.fx import GraphModule
from torch._prims.utils import TensorMeta
from torch._prims.utils import TensorMeta, getnvFuserDtype
from torch._prims.context import PrimContext
import torch._C._nvfuser as nvfuser # type: ignore[import]
DataType = nvfuser.DataType # type: ignore[attr-defined]
Fusion = nvfuser.Fusion # type: ignore[attr-defined]
FusionDefinition = nvfuser.FusionDefinition # type: ignore[attr-defined]
# TODO: refactor me into a common place
_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
torch.cfloat: DataType.ComplexFloat,
torch.double: DataType.Double,
torch.float: DataType.Float,
torch.half: DataType.Half,
torch.bfloat16: DataType.BFloat16,
torch.long: DataType.Int,
torch.int: DataType.Int32,
torch.bool: DataType.Bool,
}
if torch.cuda.is_available():
from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import]
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
@ -37,6 +21,11 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
gm = GraphModule({}, ctx.graph)
return gm.forward(*args, **kwargs)
elif executor == "nvfuser":
if not torch.cuda.is_available():
raise RuntimeError(
"Attempting to use nvFuser trace executor but CUDA is not available!"
)
# PROTOTYPE nvfuser executor
# Only accepts tensor inputs and single tensor outputs
# Does not handle kwargs
@ -53,9 +42,7 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
nv_args = [fd]
for arg in args:
if isinstance(arg, torch.Tensor):
x = fd.define_tensor(
arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype]
)
x = fd.define_tensor(arg.ndim, getnvFuserDtype(arg.dtype))
fd.add_input(x)
nv_args.append(x)
else:

View File

@ -8,6 +8,32 @@ import threading
import torch
from torch.fx import Node
# nvFuser imports are conditional on CUDA being available
if torch.cuda.is_available():
from torch._C._nvfuser import DataType # type: ignore[import]
_torch_dtype_to_nvfuser_dtype_map = {
torch.cdouble: DataType.ComplexDouble,
torch.cfloat: DataType.ComplexFloat,
torch.double: DataType.Double,
torch.float: DataType.Float,
torch.half: DataType.Half,
torch.bfloat16: DataType.BFloat16,
torch.long: DataType.Int,
torch.int: DataType.Int32,
torch.bool: DataType.Bool,
}
else:
_torch_dtype_to_nvfuser_dtype_map = {}
def getnvFuserDtype(dtype: torch.dtype):
"""
Translates from torch.dtype to nvFuser's DataType enum
"""
return _torch_dtype_to_nvfuser_dtype_map[dtype]
ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
StrideType = Union[List[int], Tuple[int, ...]]
DimsType = Union[int, List[int], Tuple[int, ...]]

File diff suppressed because it is too large

View File

@ -871,6 +871,10 @@ if IS_WINDOWS:
# Dict of torch dtype -> NumPy dtype
torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
torch_to_numpy_dtype_dict.update({
torch.bfloat16: np.float32,
torch.complex32: np.complex64
})
def skipIfRocm(fn):
@wraps(fn)
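For illustration, a minimal sketch of how the two dtype maps relate after this change (the dict contents below are abbreviated stand-ins, not the full mapping in common_utils): the torch-to-NumPy map now contains entries with no exact NumPy counterpart, which is presumably why the @dtypes decorators above switched from torch_to_numpy_dtype_dict.keys() to set(numpy_to_torch_dtype_dict.values()) for tests that need loss-free round-tripping.

import numpy as np
import torch

# Abbreviated stand-in for the real numpy -> torch map (assumption, not the full dict):
numpy_to_torch_dtype_dict = {
    np.float32: torch.float32,
    np.float64: torch.float64,
    np.int64: torch.int64,
}

# Inverse map, extended with torch dtypes NumPy cannot represent exactly:
torch_to_numpy_dtype_dict = {value: key for (key, value) in numpy_to_torch_dtype_dict.items()}
torch_to_numpy_dtype_dict.update({
    torch.bfloat16: np.float32,     # NumPy has no bfloat16
    torch.complex32: np.complex64,  # NumPy has no complex32
})

# Dtypes that round-trip exactly -- what the updated @dtypes decorators iterate over:
round_trippable = set(numpy_to_torch_dtype_dict.values())
assert torch.bfloat16 in torch_to_numpy_dtype_dict
assert torch.bfloat16 not in round_trippable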