mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Adds python ref consistency test, elementwise unary reference inputs, and formats test files
Per title. Pull Request resolved: https://github.com/pytorch/pytorch/pull/76626 Approved by: https://github.com/ngimel
This commit is contained in:
committed by
PyTorch MergeBot
parent
33be4c94c0
commit
f6bbecf8b5
File diff suppressed because it is too large
Load Diff
571
test/test_ops.py
571
test/test_ops.py
@ -8,17 +8,44 @@ import itertools
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_dtype import floating_and_complex_types_and, all_types_and_complex_and
|
from torch.testing._internal.common_dtype import (
|
||||||
from torch.testing._internal.common_utils import \
|
floating_and_complex_types_and,
|
||||||
(TestCase, is_iterable_of_tensors, run_tests, IS_SANDCASTLE, clone_input_helper,
|
all_types_and_complex_and,
|
||||||
IS_IN_CI, suppress_warnings, noncontiguous_like,
|
)
|
||||||
TEST_WITH_ASAN, IS_WINDOWS, IS_FBCODE, first_sample)
|
from torch.testing._internal.common_utils import (
|
||||||
from torch.testing._internal.common_methods_invocations import \
|
TestCase,
|
||||||
(op_db, _NOTHING, UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, ops_and_refs)
|
is_iterable_of_tensors,
|
||||||
from torch.testing._internal.common_device_type import \
|
run_tests,
|
||||||
(deviceCountAtLeast, instantiate_device_type_tests, ops,
|
IS_SANDCASTLE,
|
||||||
onlyCUDA, onlyNativeDeviceTypes, OpDTypes, skipMeta)
|
clone_input_helper,
|
||||||
# import torch._prims as prims
|
IS_IN_CI,
|
||||||
|
suppress_warnings,
|
||||||
|
noncontiguous_like,
|
||||||
|
TEST_WITH_ASAN,
|
||||||
|
IS_WINDOWS,
|
||||||
|
IS_FBCODE,
|
||||||
|
first_sample,
|
||||||
|
)
|
||||||
|
from torch.testing._internal.common_methods_invocations import (
|
||||||
|
op_db,
|
||||||
|
_NOTHING,
|
||||||
|
UnaryUfuncInfo,
|
||||||
|
ReductionOpInfo,
|
||||||
|
SpectralFuncInfo,
|
||||||
|
ops_and_refs,
|
||||||
|
python_ref_db,
|
||||||
|
BinaryUfuncInfo,
|
||||||
|
)
|
||||||
|
from torch.testing._internal.common_device_type import (
|
||||||
|
deviceCountAtLeast,
|
||||||
|
instantiate_device_type_tests,
|
||||||
|
ops,
|
||||||
|
onlyCUDA,
|
||||||
|
onlyNativeDeviceTypes,
|
||||||
|
OpDTypes,
|
||||||
|
skipMeta,
|
||||||
|
)
|
||||||
|
import torch._prims as prims
|
||||||
|
|
||||||
import torch.testing._internal.opinfo_helper as opinfo_helper
|
import torch.testing._internal.opinfo_helper as opinfo_helper
|
||||||
from torch.testing._internal import composite_compliance
|
from torch.testing._internal import composite_compliance
|
||||||
@ -28,15 +55,25 @@ torch.set_default_dtype(torch.float32)
|
|||||||
|
|
||||||
# variant testing is only done with torch.float and torch.cfloat to avoid
|
# variant testing is only done with torch.float and torch.cfloat to avoid
|
||||||
# excessive test times and maximize signal to noise ratio
|
# excessive test times and maximize signal to noise ratio
|
||||||
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
|
_variant_ops = partial(
|
||||||
allowed_dtypes=(torch.float, torch.cfloat))
|
ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
|
||||||
|
)
|
||||||
|
|
||||||
# Get names of all the operators which have ref in their entry in OpInfo (testing infra)
|
# Get names of all the operators which have ref in their entry in OpInfo (testing infra)
|
||||||
# except for Unary Ufuncs (separately implemented in test/test_unary_ufuncs.py)
|
# except for elementwise unary operators (separately implemented in test/test_unary_ufuncs.py),
|
||||||
|
# elementwise binary operators (separately implemented in test_binary_ufuncs.py),
|
||||||
|
# reduction operations (separately impelemented in test_reductions.py),
|
||||||
# and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
|
# and Spectral Functions (separately implemented for only 1D as of now, in test/test_spectral_ops.py)
|
||||||
_ref_test_ops = list(filter(lambda op: not isinstance(op, (UnaryUfuncInfo, ReductionOpInfo,
|
_ref_test_ops = tuple(
|
||||||
SpectralFuncInfo)) and op.ref is not None and op.ref is not _NOTHING, op_db))
|
filter(
|
||||||
|
lambda op: not isinstance(
|
||||||
|
op, (UnaryUfuncInfo, ReductionOpInfo, SpectralFuncInfo, BinaryUfuncInfo)
|
||||||
|
)
|
||||||
|
and op.ref is not None
|
||||||
|
and op.ref is not _NOTHING,
|
||||||
|
op_db,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
# Tests that apply to all operators and aren't related to any particular
|
# Tests that apply to all operators and aren't related to any particular
|
||||||
# system
|
# system
|
||||||
@ -49,8 +86,10 @@ class TestCommon(TestCase):
|
|||||||
super().tearDownClass()
|
super().tearDownClass()
|
||||||
|
|
||||||
if IS_IN_CI:
|
if IS_IN_CI:
|
||||||
err_msg = ("The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
|
err_msg = (
|
||||||
"This is OK for testing, but be sure to set the dtypes manually before landing your PR!")
|
"The operator(s) below is(are) using dynamic_dtypes in the OpInfo entries."
|
||||||
|
"This is OK for testing, but be sure to set the dtypes manually before landing your PR!"
|
||||||
|
)
|
||||||
# Assure no opinfo entry has dynamic_dtypes
|
# Assure no opinfo entry has dynamic_dtypes
|
||||||
filtered_ops = list(filter(opinfo_helper.is_dynamic_dtype_set, op_db))
|
filtered_ops = list(filter(opinfo_helper.is_dynamic_dtype_set, op_db))
|
||||||
for op in filtered_ops:
|
for op in filtered_ops:
|
||||||
@ -68,11 +107,16 @@ class TestCommon(TestCase):
|
|||||||
# Check complex32 support only if the op claims.
|
# Check complex32 support only if the op claims.
|
||||||
# TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
|
# TODO: Once the complex32 support is better, we should add check for complex32 unconditionally.
|
||||||
device_type = torch.device(device).type
|
device_type = torch.device(device).type
|
||||||
include_complex32 = ((torch.complex32,) if op.supports_dtype(torch.complex32, device_type) else ())
|
include_complex32 = (
|
||||||
|
(torch.complex32,)
|
||||||
|
if op.supports_dtype(torch.complex32, device_type)
|
||||||
|
else ()
|
||||||
|
)
|
||||||
|
|
||||||
# dtypes to try to backward in
|
# dtypes to try to backward in
|
||||||
allowed_backward_dtypes = floating_and_complex_types_and(
|
allowed_backward_dtypes = floating_and_complex_types_and(
|
||||||
*((torch.half, torch.bfloat16) + include_complex32))
|
*((torch.half, torch.bfloat16) + include_complex32)
|
||||||
|
)
|
||||||
|
|
||||||
# lists for (un)supported dtypes
|
# lists for (un)supported dtypes
|
||||||
supported_dtypes = set()
|
supported_dtypes = set()
|
||||||
@ -86,11 +130,14 @@ class TestCommon(TestCase):
|
|||||||
unsupported_backward_dtypes.add(dtype)
|
unsupported_backward_dtypes.add(dtype)
|
||||||
|
|
||||||
for dtype in all_types_and_complex_and(
|
for dtype in all_types_and_complex_and(
|
||||||
*((torch.half, torch.bfloat16, torch.bool) + include_complex32)):
|
*((torch.half, torch.bfloat16, torch.bool) + include_complex32)
|
||||||
|
):
|
||||||
# tries to acquire samples - failure indicates lack of support
|
# tries to acquire samples - failure indicates lack of support
|
||||||
requires_grad = (dtype in allowed_backward_dtypes)
|
requires_grad = dtype in allowed_backward_dtypes
|
||||||
try:
|
try:
|
||||||
samples = tuple(op.sample_inputs(device, dtype, requires_grad=requires_grad))
|
samples = tuple(
|
||||||
|
op.sample_inputs(device, dtype, requires_grad=requires_grad)
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
unsupported(dtype)
|
unsupported(dtype)
|
||||||
continue
|
continue
|
||||||
@ -113,7 +160,9 @@ class TestCommon(TestCase):
|
|||||||
result = sample.output_process_fn_grad(result)
|
result = sample.output_process_fn_grad(result)
|
||||||
if isinstance(result, torch.Tensor):
|
if isinstance(result, torch.Tensor):
|
||||||
backward_tensor = result
|
backward_tensor = result
|
||||||
elif isinstance(result, Sequence) and isinstance(result[0], torch.Tensor):
|
elif isinstance(result, Sequence) and isinstance(
|
||||||
|
result[0], torch.Tensor
|
||||||
|
):
|
||||||
backward_tensor = result[0]
|
backward_tensor = result[0]
|
||||||
else:
|
else:
|
||||||
continue
|
continue
|
||||||
@ -130,14 +179,15 @@ class TestCommon(TestCase):
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
unsupported_backward_dtypes.add(dtype)
|
unsupported_backward_dtypes.add(dtype)
|
||||||
|
|
||||||
|
|
||||||
# Checks that dtypes are listed correctly and generates an informative
|
# Checks that dtypes are listed correctly and generates an informative
|
||||||
# error message
|
# error message
|
||||||
supported_forward = supported_dtypes - unsupported_dtypes
|
supported_forward = supported_dtypes - unsupported_dtypes
|
||||||
partially_supported_forward = supported_dtypes & unsupported_dtypes
|
partially_supported_forward = supported_dtypes & unsupported_dtypes
|
||||||
unsupported_forward = unsupported_dtypes - supported_dtypes
|
unsupported_forward = unsupported_dtypes - supported_dtypes
|
||||||
supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
|
supported_backward = supported_backward_dtypes - unsupported_backward_dtypes
|
||||||
partially_supported_backward = supported_backward_dtypes & unsupported_backward_dtypes
|
partially_supported_backward = (
|
||||||
|
supported_backward_dtypes & unsupported_backward_dtypes
|
||||||
|
)
|
||||||
unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes
|
unsupported_backward = unsupported_backward_dtypes - supported_backward_dtypes
|
||||||
|
|
||||||
device_type = torch.device(device).type
|
device_type = torch.device(device).type
|
||||||
@ -156,17 +206,27 @@ class TestCommon(TestCase):
|
|||||||
op.name, device_type
|
op.name, device_type
|
||||||
)
|
)
|
||||||
if len(partially_supported_forward) > 0:
|
if len(partially_supported_forward) > 0:
|
||||||
msg = msg + "The following dtypes only worked on some samples during forward: {0}.\n".format(
|
msg = (
|
||||||
partially_supported_forward
|
msg
|
||||||
|
+ "The following dtypes only worked on some samples during forward: {0}.\n".format(
|
||||||
|
partially_supported_forward
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if len(partially_supported_backward) > 0:
|
if len(partially_supported_backward) > 0:
|
||||||
msg = msg + "The following dtypes only worked on some samples during backward: {0}.\n".format(
|
msg = (
|
||||||
partially_supported_backward
|
msg
|
||||||
|
+ "The following dtypes only worked on some samples during backward: {0}.\n".format(
|
||||||
|
partially_supported_backward
|
||||||
|
)
|
||||||
)
|
)
|
||||||
print(msg)
|
print(msg)
|
||||||
|
|
||||||
if (len(supported_but_unclaimed_forward) + len(claimed_but_unsupported_forward) +
|
if (
|
||||||
len(supported_but_unclaimed_backward) + len(claimed_but_unsupported_backward)) == 0:
|
len(supported_but_unclaimed_forward)
|
||||||
|
+ len(claimed_but_unsupported_forward)
|
||||||
|
+ len(supported_but_unclaimed_backward)
|
||||||
|
+ len(claimed_but_unsupported_backward)
|
||||||
|
) == 0:
|
||||||
return
|
return
|
||||||
|
|
||||||
# Generates error msg
|
# Generates error msg
|
||||||
@ -174,20 +234,32 @@ class TestCommon(TestCase):
|
|||||||
op.name, device_type
|
op.name, device_type
|
||||||
)
|
)
|
||||||
if len(supported_but_unclaimed_forward) > 0:
|
if len(supported_but_unclaimed_forward) > 0:
|
||||||
msg = msg + "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
|
msg = (
|
||||||
supported_but_unclaimed_forward
|
msg
|
||||||
|
+ "The following dtypes worked in forward but are not listed by the OpInfo: {0}.\n".format(
|
||||||
|
supported_but_unclaimed_forward
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if len(supported_but_unclaimed_backward) > 0:
|
if len(supported_but_unclaimed_backward) > 0:
|
||||||
msg = msg + "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
|
msg = (
|
||||||
supported_but_unclaimed_backward
|
msg
|
||||||
|
+ "The following dtypes worked in backward but are not listed by the OpInfo: {0}.\n".format(
|
||||||
|
supported_but_unclaimed_backward
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if len(claimed_but_unsupported_forward) > 0:
|
if len(claimed_but_unsupported_forward) > 0:
|
||||||
msg = msg + "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
|
msg = (
|
||||||
claimed_but_unsupported_forward
|
msg
|
||||||
|
+ "The following dtypes did not work in forward but are listed by the OpInfo: {0}.\n".format(
|
||||||
|
claimed_but_unsupported_forward
|
||||||
|
)
|
||||||
)
|
)
|
||||||
if len(claimed_but_unsupported_backward) > 0:
|
if len(claimed_but_unsupported_backward) > 0:
|
||||||
msg = msg + "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
|
msg = (
|
||||||
claimed_but_unsupported_backward
|
msg
|
||||||
|
+ "The following dtypes did not work in backward but are listed by the OpInfo: {0}.\n".format(
|
||||||
|
claimed_but_unsupported_backward
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.fail(msg)
|
self.fail(msg)
|
||||||
@ -209,7 +281,9 @@ class TestCommon(TestCase):
|
|||||||
elif is_iterable_of_tensors(result):
|
elif is_iterable_of_tensors(result):
|
||||||
self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
|
self.assertTrue(all(map(lambda t: t.device == cuda_device, result)))
|
||||||
else:
|
else:
|
||||||
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
|
self.skipTest(
|
||||||
|
"Skipped! Only supports single tensor or iterable of tensor outputs."
|
||||||
|
)
|
||||||
|
|
||||||
# Tests that the function and its (ndarray-accepting) reference produce the same
|
# Tests that the function and its (ndarray-accepting) reference produce the same
|
||||||
# values on the tensors from sample_inputs func for the corresponding op.
|
# values on the tensors from sample_inputs func for the corresponding op.
|
||||||
@ -226,28 +300,48 @@ class TestCommon(TestCase):
|
|||||||
cur_default = torch.get_default_dtype()
|
cur_default = torch.get_default_dtype()
|
||||||
torch.set_default_dtype(torch.double)
|
torch.set_default_dtype(torch.double)
|
||||||
for sample_input in op.reference_inputs(device, dtype):
|
for sample_input in op.reference_inputs(device, dtype):
|
||||||
self.compare_with_reference(op, op.ref, sample_input, exact_dtype=(dtype is not torch.long))
|
self.compare_with_reference(
|
||||||
|
op, op.ref, sample_input, exact_dtype=(dtype is not torch.long)
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
torch.set_default_dtype(cur_default)
|
torch.set_default_dtype(cur_default)
|
||||||
|
|
||||||
# Tests that experimental Python References' can propagate shape, dtype,
|
# Tests that experimental Python References can propagate shape, dtype,
|
||||||
# and device metadata properly.
|
# and device metadata properly.
|
||||||
# TODO: include stride propagation.
|
# TODO: include stride propagation.
|
||||||
# @onlyNativeDeviceTypes
|
@onlyNativeDeviceTypes
|
||||||
# @ops(python_ref_db)
|
@ops(python_ref_db)
|
||||||
# def test_python_reference_meta_functions(self, device, dtype, op):
|
def test_python_reference_meta_functions(self, device, dtype, op):
|
||||||
# def _to_tensormeta(x):
|
def _to_tensormeta(x):
|
||||||
# if isinstance(x, torch.Tensor):
|
if isinstance(x, torch.Tensor):
|
||||||
# return prims.utils.TensorMeta(x)
|
return prims.utils.TensorMeta(x)
|
||||||
|
|
||||||
# # TODO: iterate over requires_grad true/false
|
# TODO: iterate over requires_grad true/false
|
||||||
# for sample in op.reference_inputs(device, dtype, requires_grad=False):
|
for sample in op.reference_inputs(device, dtype, requires_grad=False):
|
||||||
# result = op(sample.input, *sample.args, **sample.kwargs)
|
result = op(sample.input, *sample.args, **sample.kwargs)
|
||||||
|
|
||||||
# meta_sample = sample.transform(_to_tensormeta)
|
meta_sample = sample.transform(_to_tensormeta)
|
||||||
# meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
|
meta_result = op(meta_sample.input, *meta_sample.args, **meta_sample.kwargs)
|
||||||
|
|
||||||
# prims.utils.compare_tensor_meta(result, meta_result)
|
prims.utils.compare_tensor_meta(result, meta_result)
|
||||||
|
|
||||||
|
# Tests that experimental Python References perform the same computation
|
||||||
|
# as the operators they reference.
|
||||||
|
@onlyNativeDeviceTypes
|
||||||
|
@ops(python_ref_db)
|
||||||
|
def test_python_reference_consistency(self, device, dtype, op):
|
||||||
|
for sample in op.reference_inputs(device, dtype, requires_grad=False):
|
||||||
|
actual = op(sample.input, *sample.args, **sample.kwargs)
|
||||||
|
expected = op.torch_opinfo(sample.input, *sample.args, **sample.kwargs)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
actual,
|
||||||
|
expected,
|
||||||
|
exact_stride=True,
|
||||||
|
exact_device=True,
|
||||||
|
exact_layout=True,
|
||||||
|
exact_is_coalesced=True,
|
||||||
|
)
|
||||||
|
|
||||||
@skipMeta
|
@skipMeta
|
||||||
@onlyNativeDeviceTypes
|
@onlyNativeDeviceTypes
|
||||||
@ -272,9 +366,17 @@ class TestCommon(TestCase):
|
|||||||
test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
|
test_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
|
||||||
sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
|
sample_inputs = op.sample_inputs(device, dtype, requires_grad=test_grad)
|
||||||
for sample_input in sample_inputs:
|
for sample_input in sample_inputs:
|
||||||
t_inp, t_args, t_kwargs = sample_input.input, sample_input.args, sample_input.kwargs
|
t_inp, t_args, t_kwargs = (
|
||||||
|
sample_input.input,
|
||||||
|
sample_input.args,
|
||||||
|
sample_input.kwargs,
|
||||||
|
)
|
||||||
noncontig_sample = sample_input.noncontiguous()
|
noncontig_sample = sample_input.noncontiguous()
|
||||||
n_inp, n_args, n_kwargs = noncontig_sample.input, noncontig_sample.args, noncontig_sample.kwargs
|
n_inp, n_args, n_kwargs = (
|
||||||
|
noncontig_sample.input,
|
||||||
|
noncontig_sample.args,
|
||||||
|
noncontig_sample.kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
# Verifies sample input tensors should have no grad or history
|
# Verifies sample input tensors should have no grad or history
|
||||||
sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0]
|
sample_tensor = t_inp if isinstance(t_inp, torch.Tensor) else t_inp[0]
|
||||||
@ -300,10 +402,14 @@ class TestCommon(TestCase):
|
|||||||
grad_for_actual = noncontiguous_like(grad_for_expected)
|
grad_for_actual = noncontiguous_like(grad_for_expected)
|
||||||
elif isinstance(expected, Sequence):
|
elif isinstance(expected, Sequence):
|
||||||
# Filter output elements that do not require grad
|
# Filter output elements that do not require grad
|
||||||
expected = [t for t in expected
|
expected = [
|
||||||
if isinstance(t, torch.Tensor) and t.requires_grad]
|
t
|
||||||
actual = [n for n in actual
|
for t in expected
|
||||||
if isinstance(n, torch.Tensor) and n.requires_grad]
|
if isinstance(t, torch.Tensor) and t.requires_grad
|
||||||
|
]
|
||||||
|
actual = [
|
||||||
|
n for n in actual if isinstance(n, torch.Tensor) and n.requires_grad
|
||||||
|
]
|
||||||
grad_for_expected = [torch.randn_like(t) for t in expected]
|
grad_for_expected = [torch.randn_like(t) for t in expected]
|
||||||
grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
|
grad_for_actual = [noncontiguous_like(n) for n in grad_for_expected]
|
||||||
else:
|
else:
|
||||||
@ -311,19 +417,35 @@ class TestCommon(TestCase):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Concatenate inputs into a tuple
|
# Concatenate inputs into a tuple
|
||||||
t_inputs = (t_inp,) + t_args if isinstance(t_inp, torch.Tensor) else tuple(t_inp) + t_args
|
t_inputs = (
|
||||||
n_inputs = (n_inp,) + n_args if isinstance(n_inp, torch.Tensor) else tuple(n_inp) + n_args
|
(t_inp,) + t_args
|
||||||
|
if isinstance(t_inp, torch.Tensor)
|
||||||
|
else tuple(t_inp) + t_args
|
||||||
|
)
|
||||||
|
n_inputs = (
|
||||||
|
(n_inp,) + n_args
|
||||||
|
if isinstance(n_inp, torch.Tensor)
|
||||||
|
else tuple(n_inp) + n_args
|
||||||
|
)
|
||||||
|
|
||||||
# Filter the elemnts that are tensors that require grad
|
# Filter the elemnts that are tensors that require grad
|
||||||
t_input_tensors = [t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad]
|
t_input_tensors = [
|
||||||
n_input_tensors = [n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad]
|
t for t in t_inputs if isinstance(t, torch.Tensor) and t.requires_grad
|
||||||
|
]
|
||||||
|
n_input_tensors = [
|
||||||
|
n for n in n_inputs if isinstance(n, torch.Tensor) and n.requires_grad
|
||||||
|
]
|
||||||
|
|
||||||
self.assertEqual(len(t_input_tensors), len(n_input_tensors))
|
self.assertEqual(len(t_input_tensors), len(n_input_tensors))
|
||||||
|
|
||||||
# Some functions may not use all the inputs to generate gradients. One of the
|
# Some functions may not use all the inputs to generate gradients. One of the
|
||||||
# few examples of this "odd" behaviour is F.hinge_embedding_loss
|
# few examples of this "odd" behaviour is F.hinge_embedding_loss
|
||||||
t_grads = torch.autograd.grad(expected, t_input_tensors, grad_for_expected, allow_unused=True)
|
t_grads = torch.autograd.grad(
|
||||||
n_grads = torch.autograd.grad(actual, n_input_tensors, grad_for_actual, allow_unused=True)
|
expected, t_input_tensors, grad_for_expected, allow_unused=True
|
||||||
|
)
|
||||||
|
n_grads = torch.autograd.grad(
|
||||||
|
actual, n_input_tensors, grad_for_actual, allow_unused=True
|
||||||
|
)
|
||||||
|
|
||||||
msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
|
msg = "Got different gradients for contiguous / non-contiguous inputs wrt input {}."
|
||||||
for i, (t, n) in enumerate(zip(t_grads, n_grads)):
|
for i, (t, n) in enumerate(zip(t_grads, n_grads)):
|
||||||
@ -339,7 +461,11 @@ class TestCommon(TestCase):
|
|||||||
supported_dtypes = op.supported_dtypes(self.device_type)
|
supported_dtypes = op.supported_dtypes(self.device_type)
|
||||||
if len(supported_dtypes) == 0:
|
if len(supported_dtypes) == 0:
|
||||||
self.skipTest("Skipped! Op has not supported dtypes on this device.")
|
self.skipTest("Skipped! Op has not supported dtypes on this device.")
|
||||||
dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
|
dtype = (
|
||||||
|
torch.float32
|
||||||
|
if torch.float32 in supported_dtypes
|
||||||
|
else list(supported_dtypes)[0]
|
||||||
|
)
|
||||||
|
|
||||||
samples = op.sample_inputs(device, dtype)
|
samples = op.sample_inputs(device, dtype)
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
@ -349,8 +475,12 @@ class TestCommon(TestCase):
|
|||||||
|
|
||||||
# Short-circuits if output is not a single tensor or an
|
# Short-circuits if output is not a single tensor or an
|
||||||
# iterable of tensors
|
# iterable of tensors
|
||||||
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
|
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
|
||||||
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
|
expected, include_empty=True
|
||||||
|
):
|
||||||
|
self.skipTest(
|
||||||
|
"Skipped! Only supports single tensor or iterable of tensor outputs."
|
||||||
|
)
|
||||||
|
|
||||||
# Validates the op doesn't support out if it claims not to
|
# Validates the op doesn't support out if it claims not to
|
||||||
if not op.supports_out:
|
if not op.supports_out:
|
||||||
@ -380,7 +510,7 @@ class TestCommon(TestCase):
|
|||||||
# NOTE: only extracts on the CPU and CUDA device types since some
|
# NOTE: only extracts on the CPU and CUDA device types since some
|
||||||
# device types don't have storage
|
# device types don't have storage
|
||||||
def _extract_data_ptrs(out):
|
def _extract_data_ptrs(out):
|
||||||
if self.device_type != 'cpu' and self.device_type != 'cuda':
|
if self.device_type != "cpu" and self.device_type != "cuda":
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
if isinstance(out, torch.Tensor):
|
if isinstance(out, torch.Tensor):
|
||||||
@ -403,7 +533,8 @@ class TestCommon(TestCase):
|
|||||||
|
|
||||||
if compare_strides_and_data_ptrs:
|
if compare_strides_and_data_ptrs:
|
||||||
stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
|
stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
|
||||||
original_strides, final_strides)
|
original_strides, final_strides
|
||||||
|
)
|
||||||
self.assertEqual(original_strides, final_strides, msg=stride_msg)
|
self.assertEqual(original_strides, final_strides, msg=stride_msg)
|
||||||
self.assertEqual(original_ptrs, final_ptrs)
|
self.assertEqual(original_ptrs, final_ptrs)
|
||||||
|
|
||||||
@ -433,7 +564,9 @@ class TestCommon(TestCase):
|
|||||||
out = _apply_out_transform(_case_zero_transform, expected)
|
out = _apply_out_transform(_case_zero_transform, expected)
|
||||||
msg_fail = "Resized a non-empty tensor but did not warn about it."
|
msg_fail = "Resized a non-empty tensor but did not warn about it."
|
||||||
if _any_nonempty(out):
|
if _any_nonempty(out):
|
||||||
with self.assertWarnsRegex(UserWarning, "An output with one or more elements", msg=msg_fail):
|
with self.assertWarnsRegex(
|
||||||
|
UserWarning, "An output with one or more elements", msg=msg_fail
|
||||||
|
):
|
||||||
op_out(out=out)
|
op_out(out=out)
|
||||||
|
|
||||||
# Validates ops implement the correct out= behavior
|
# Validates ops implement the correct out= behavior
|
||||||
@ -452,7 +585,11 @@ class TestCommon(TestCase):
|
|||||||
supported_dtypes = op.supported_dtypes(self.device_type)
|
supported_dtypes = op.supported_dtypes(self.device_type)
|
||||||
if len(supported_dtypes) == 0:
|
if len(supported_dtypes) == 0:
|
||||||
self.skipTest("Skipped! Op has not supported dtypes on this device.")
|
self.skipTest("Skipped! Op has not supported dtypes on this device.")
|
||||||
dtype = torch.float32 if torch.float32 in supported_dtypes else list(supported_dtypes)[0]
|
dtype = (
|
||||||
|
torch.float32
|
||||||
|
if torch.float32 in supported_dtypes
|
||||||
|
else list(supported_dtypes)[0]
|
||||||
|
)
|
||||||
|
|
||||||
samples = op.sample_inputs(device, dtype)
|
samples = op.sample_inputs(device, dtype)
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
@ -462,8 +599,12 @@ class TestCommon(TestCase):
|
|||||||
|
|
||||||
# Short-circuits if output is not a single tensor or an
|
# Short-circuits if output is not a single tensor or an
|
||||||
# iterable of tensors
|
# iterable of tensors
|
||||||
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(expected, include_empty=True):
|
if not isinstance(expected, torch.Tensor) and not is_iterable_of_tensors(
|
||||||
self.skipTest("Skipped! Only supports single tensor or iterable of tensor outputs.")
|
expected, include_empty=True
|
||||||
|
):
|
||||||
|
self.skipTest(
|
||||||
|
"Skipped! Only supports single tensor or iterable of tensor outputs."
|
||||||
|
)
|
||||||
|
|
||||||
# Validates the op doesn't support out if it claims not to
|
# Validates the op doesn't support out if it claims not to
|
||||||
if not op.supports_out:
|
if not op.supports_out:
|
||||||
@ -493,7 +634,7 @@ class TestCommon(TestCase):
|
|||||||
# NOTE: only extracts on the CPU and CUDA device types since some
|
# NOTE: only extracts on the CPU and CUDA device types since some
|
||||||
# device types don't have storage
|
# device types don't have storage
|
||||||
def _extract_data_ptrs(out):
|
def _extract_data_ptrs(out):
|
||||||
if self.device_type != 'cpu' and self.device_type != 'cuda':
|
if self.device_type != "cpu" and self.device_type != "cuda":
|
||||||
return ()
|
return ()
|
||||||
|
|
||||||
if isinstance(out, torch.Tensor):
|
if isinstance(out, torch.Tensor):
|
||||||
@ -515,7 +656,8 @@ class TestCommon(TestCase):
|
|||||||
|
|
||||||
if compare_strides_and_data_ptrs:
|
if compare_strides_and_data_ptrs:
|
||||||
stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
|
stride_msg = "Strides are not the same! Original strides were {0} and strides are now {1}".format(
|
||||||
original_strides, final_strides)
|
original_strides, final_strides
|
||||||
|
)
|
||||||
self.assertEqual(original_strides, final_strides, msg=stride_msg)
|
self.assertEqual(original_strides, final_strides, msg=stride_msg)
|
||||||
self.assertEqual(original_ptrs, final_ptrs)
|
self.assertEqual(original_ptrs, final_ptrs)
|
||||||
|
|
||||||
@ -529,7 +671,7 @@ class TestCommon(TestCase):
|
|||||||
return torch.full_like(t, info.max)
|
return torch.full_like(t, info.max)
|
||||||
except TypeError as te:
|
except TypeError as te:
|
||||||
# for non-integer types fills with NaN
|
# for non-integer types fills with NaN
|
||||||
return torch.full_like(t, float('nan'))
|
return torch.full_like(t, float("nan"))
|
||||||
|
|
||||||
_compare_out(_case_zero_transform)
|
_compare_out(_case_zero_transform)
|
||||||
|
|
||||||
@ -537,10 +679,9 @@ class TestCommon(TestCase):
|
|||||||
# but noncontiguous.
|
# but noncontiguous.
|
||||||
# Expected behavior: strides are respected and `out` storage is not changed.
|
# Expected behavior: strides are respected and `out` storage is not changed.
|
||||||
def _case_one_transform(t):
|
def _case_one_transform(t):
|
||||||
return make_tensor(t.shape,
|
return make_tensor(
|
||||||
dtype=t.dtype,
|
t.shape, dtype=t.dtype, device=t.device, noncontiguous=True
|
||||||
device=t.device,
|
)
|
||||||
noncontiguous=True)
|
|
||||||
|
|
||||||
_compare_out(_case_one_transform)
|
_compare_out(_case_one_transform)
|
||||||
|
|
||||||
@ -560,16 +701,19 @@ class TestCommon(TestCase):
|
|||||||
# Verifies no warning is a resize warning
|
# Verifies no warning is a resize warning
|
||||||
for w in caught:
|
for w in caught:
|
||||||
if "An output with one or more elements" in str(w.message):
|
if "An output with one or more elements" in str(w.message):
|
||||||
self.fail("Resizing an out= argument with no elements threw a resize warning!")
|
self.fail(
|
||||||
|
"Resizing an out= argument with no elements threw a resize warning!"
|
||||||
|
)
|
||||||
|
|
||||||
# Case 3: out= with correct shape and dtype, but wrong device.
|
# Case 3: out= with correct shape and dtype, but wrong device.
|
||||||
wrong_device = None
|
wrong_device = None
|
||||||
if torch.device(device).type != 'cpu':
|
if torch.device(device).type != "cpu":
|
||||||
wrong_device = 'cpu'
|
wrong_device = "cpu"
|
||||||
elif torch.cuda.is_available():
|
elif torch.cuda.is_available():
|
||||||
wrong_device = 'cuda'
|
wrong_device = "cuda"
|
||||||
|
|
||||||
if wrong_device is not None:
|
if wrong_device is not None:
|
||||||
|
|
||||||
def _case_three_transform(t):
|
def _case_three_transform(t):
|
||||||
return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)
|
return make_tensor(t.shape, dtype=t.dtype, device=wrong_device)
|
||||||
|
|
||||||
@ -587,16 +731,28 @@ class TestCommon(TestCase):
|
|||||||
# dtypes, or if an op returns multiple tensors when at least one such
|
# dtypes, or if an op returns multiple tensors when at least one such
|
||||||
# tensor is a floating point or complex dtype.
|
# tensor is a floating point or complex dtype.
|
||||||
_dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
|
_dtypes = floating_and_complex_types_and(torch.float16, torch.bfloat16)
|
||||||
if (isinstance(expected, torch.Tensor) and expected.dtype in _dtypes or
|
if (
|
||||||
(not isinstance(expected, torch.Tensor) and any(t.dtype in _dtypes for t in expected))):
|
isinstance(expected, torch.Tensor)
|
||||||
|
and expected.dtype in _dtypes
|
||||||
|
or (
|
||||||
|
not isinstance(expected, torch.Tensor)
|
||||||
|
and any(t.dtype in _dtypes for t in expected)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
|
||||||
def _case_four_transform(t):
|
def _case_four_transform(t):
|
||||||
return make_tensor(t.shape, dtype=torch.long, device=t.device)
|
return make_tensor(t.shape, dtype=torch.long, device=t.device)
|
||||||
|
|
||||||
out = _apply_out_transform(_case_four_transform, expected)
|
out = _apply_out_transform(_case_four_transform, expected)
|
||||||
msg_fail = "Expected RuntimeError when doing an unsafe cast!"
|
msg_fail = "Expected RuntimeError when doing an unsafe cast!"
|
||||||
msg_fail = msg_fail if not isinstance(expected, torch.Tensor) else \
|
msg_fail = (
|
||||||
("Expected RuntimeError when doing an unsafe cast from a result of dtype "
|
msg_fail
|
||||||
f"{expected.dtype} into an out= with dtype torch.long")
|
if not isinstance(expected, torch.Tensor)
|
||||||
|
else (
|
||||||
|
"Expected RuntimeError when doing an unsafe cast from a result of dtype "
|
||||||
|
f"{expected.dtype} into an out= with dtype torch.long"
|
||||||
|
)
|
||||||
|
)
|
||||||
with self.assertRaises(RuntimeError, msg=msg_fail):
|
with self.assertRaises(RuntimeError, msg=msg_fail):
|
||||||
op_out(out=out)
|
op_out(out=out)
|
||||||
|
|
||||||
@ -611,7 +767,9 @@ class TestCommon(TestCase):
|
|||||||
inplace = op.inplace_variant
|
inplace = op.inplace_variant
|
||||||
|
|
||||||
# list of all inplace ops: inplace variant + alias inplace variants if exist
|
# list of all inplace ops: inplace variant + alias inplace variants if exist
|
||||||
inplace_ops = [inplace, ]
|
inplace_ops = [
|
||||||
|
inplace,
|
||||||
|
]
|
||||||
variants = [method, inplace]
|
variants = [method, inplace]
|
||||||
|
|
||||||
for a_op in op.aliases:
|
for a_op in op.aliases:
|
||||||
@ -623,30 +781,46 @@ class TestCommon(TestCase):
|
|||||||
inplace_variants = tuple(filter(None, inplace_ops))
|
inplace_variants = tuple(filter(None, inplace_ops))
|
||||||
variants = tuple(filter(None, variants))
|
variants = tuple(filter(None, variants))
|
||||||
|
|
||||||
_requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
|
_requires_grad = dtype in op.supported_backward_dtypes(
|
||||||
|
torch.device(device).type
|
||||||
|
)
|
||||||
|
|
||||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
samples = op.sample_inputs(
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
requires_grad=_requires_grad,
|
||||||
|
include_conjugated_inputs=include_conjugated_inputs,
|
||||||
|
)
|
||||||
samples = list(samples)
|
samples = list(samples)
|
||||||
|
|
||||||
def _test_consistency_helper(samples, variants):
|
def _test_consistency_helper(samples, variants):
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
# TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
|
# TODO: Check grad for all Tensors requiring grad if sample.input is TensorList
|
||||||
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
tensor = (
|
||||||
|
sample.input
|
||||||
|
if isinstance(sample.input, torch.Tensor)
|
||||||
|
else sample.input[0]
|
||||||
|
)
|
||||||
|
|
||||||
# Computes function forward and backward values
|
# Computes function forward and backward values
|
||||||
tensor.grad = None
|
tensor.grad = None
|
||||||
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
|
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
|
||||||
expected_grad = None
|
expected_grad = None
|
||||||
|
|
||||||
output_process_fn_grad = sample.output_process_fn_grad if sample.output_process_fn_grad \
|
output_process_fn_grad = (
|
||||||
|
sample.output_process_fn_grad
|
||||||
|
if sample.output_process_fn_grad
|
||||||
else lambda x: x
|
else lambda x: x
|
||||||
|
)
|
||||||
|
|
||||||
# Skips inplace variants if the output dtype is not the same as
|
# Skips inplace variants if the output dtype is not the same as
|
||||||
# the input dtype
|
# the input dtype
|
||||||
skip_inplace = False
|
skip_inplace = False
|
||||||
if (isinstance(expected_forward, torch.Tensor) and
|
if (
|
||||||
expected_forward.dtype is not tensor.dtype):
|
isinstance(expected_forward, torch.Tensor)
|
||||||
|
and expected_forward.dtype is not tensor.dtype
|
||||||
|
):
|
||||||
skip_inplace = True
|
skip_inplace = True
|
||||||
|
|
||||||
# TODO: backward consistency only supported for single tensor outputs
|
# TODO: backward consistency only supported for single tensor outputs
|
||||||
@ -654,8 +828,9 @@ class TestCommon(TestCase):
|
|||||||
# tensor inputs
|
# tensor inputs
|
||||||
# TODO: update to handle checking grads of all tensor inputs as
|
# TODO: update to handle checking grads of all tensor inputs as
|
||||||
# derived from each tensor output
|
# derived from each tensor output
|
||||||
if (isinstance(expected_forward, torch.Tensor)
|
if isinstance(
|
||||||
and dtype in op.supported_backward_dtypes(torch.device(device).type)):
|
expected_forward, torch.Tensor
|
||||||
|
) and dtype in op.supported_backward_dtypes(torch.device(device).type):
|
||||||
output_process_fn_grad(expected_forward).sum().backward()
|
output_process_fn_grad(expected_forward).sum().backward()
|
||||||
expected_grad = tensor.grad
|
expected_grad = tensor.grad
|
||||||
|
|
||||||
@ -668,26 +843,35 @@ class TestCommon(TestCase):
|
|||||||
# Compares variant's forward
|
# Compares variant's forward
|
||||||
# Note: copies the to-be-modified input when testing the inplace variant
|
# Note: copies the to-be-modified input when testing the inplace variant
|
||||||
tensor.grad = None
|
tensor.grad = None
|
||||||
cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
|
cloned = (
|
||||||
|
clone_input_helper(sample.input)
|
||||||
|
if variant in inplace_ops
|
||||||
|
else sample.input
|
||||||
|
)
|
||||||
|
|
||||||
if variant in inplace_ops and sample.broadcasts_input:
|
if variant in inplace_ops and sample.broadcasts_input:
|
||||||
with self.assertRaises(RuntimeError,
|
with self.assertRaises(
|
||||||
msg=('inplace variant either incorrectly allowed '
|
RuntimeError,
|
||||||
'resizing or you have marked the sample {}'
|
msg=(
|
||||||
' incorrectly with `broadcasts_self=True'.format(sample.summary()))):
|
"inplace variant either incorrectly allowed "
|
||||||
variant_forward = variant(cloned,
|
"resizing or you have marked the sample {}"
|
||||||
*sample.args,
|
" incorrectly with `broadcasts_self=True".format(
|
||||||
**sample.kwargs)
|
sample.summary()
|
||||||
|
)
|
||||||
|
),
|
||||||
|
):
|
||||||
|
variant_forward = variant(
|
||||||
|
cloned, *sample.args, **sample.kwargs
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
variant_forward = variant(cloned,
|
variant_forward = variant(cloned, *sample.args, **sample.kwargs)
|
||||||
*sample.args,
|
|
||||||
**sample.kwargs)
|
|
||||||
self.assertEqual(expected_forward, variant_forward)
|
self.assertEqual(expected_forward, variant_forward)
|
||||||
|
|
||||||
# Compares variant's backward
|
# Compares variant's backward
|
||||||
if expected_grad is not None and \
|
if expected_grad is not None and (
|
||||||
(variant not in inplace_ops or op.supports_inplace_autograd):
|
variant not in inplace_ops or op.supports_inplace_autograd
|
||||||
|
):
|
||||||
output_process_fn_grad(variant_forward).sum().backward()
|
output_process_fn_grad(variant_forward).sum().backward()
|
||||||
self.assertEqual(expected_grad, tensor.grad)
|
self.assertEqual(expected_grad, tensor.grad)
|
||||||
|
|
||||||
@ -698,28 +882,45 @@ class TestCommon(TestCase):
|
|||||||
# Skips inplace variants if the output dtype is not the same as
|
# Skips inplace variants if the output dtype is not the same as
|
||||||
# the input dtype
|
# the input dtype
|
||||||
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
|
expected_forward = op(sample.input, *sample.args, **sample.kwargs)
|
||||||
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
tensor = (
|
||||||
|
sample.input
|
||||||
|
if isinstance(sample.input, torch.Tensor)
|
||||||
|
else sample.input[0]
|
||||||
|
)
|
||||||
skip_inplace = False
|
skip_inplace = False
|
||||||
if (isinstance(expected_forward, torch.Tensor) and
|
if (
|
||||||
expected_forward.dtype is not tensor.dtype):
|
isinstance(expected_forward, torch.Tensor)
|
||||||
|
and expected_forward.dtype is not tensor.dtype
|
||||||
|
):
|
||||||
skip_inplace = True
|
skip_inplace = True
|
||||||
if skip_inplace:
|
if skip_inplace:
|
||||||
return
|
return
|
||||||
for variant in variants:
|
for variant in variants:
|
||||||
cloned = clone_input_helper(sample.input) if variant in inplace_ops else sample.input
|
cloned = (
|
||||||
inp_tensor = cloned if isinstance(cloned, torch.Tensor) else cloned[0]
|
clone_input_helper(sample.input)
|
||||||
|
if variant in inplace_ops
|
||||||
|
else sample.input
|
||||||
|
)
|
||||||
|
inp_tensor = (
|
||||||
|
cloned if isinstance(cloned, torch.Tensor) else cloned[0]
|
||||||
|
)
|
||||||
data_ptr = inp_tensor.data_ptr()
|
data_ptr = inp_tensor.data_ptr()
|
||||||
variant_forward = variant(cloned,
|
variant_forward = variant(cloned, *sample.args, **sample.kwargs)
|
||||||
*sample.args,
|
|
||||||
**sample.kwargs)
|
|
||||||
# TODO Support non-tensor outputs if they exist for inplace ops
|
# TODO Support non-tensor outputs if they exist for inplace ops
|
||||||
if (isinstance(variant_forward, torch.Tensor)):
|
if isinstance(variant_forward, torch.Tensor):
|
||||||
self.assertEqual(data_ptr, variant_forward.data_ptr(), atol=0, rtol=0)
|
self.assertEqual(
|
||||||
|
data_ptr, variant_forward.data_ptr(), atol=0, rtol=0
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self.assertTrue(False, "Non-tensor outputs for inplace ops are not supported")
|
self.assertTrue(
|
||||||
|
False,
|
||||||
|
"Non-tensor outputs for inplace ops are not supported",
|
||||||
|
)
|
||||||
|
|
||||||
if len(inplace_ops) > 0:
|
if len(inplace_ops) > 0:
|
||||||
inplace_samples = list(filter(lambda sample: not sample.broadcasts_input, samples))
|
inplace_samples = list(
|
||||||
|
filter(lambda sample: not sample.broadcasts_input, samples)
|
||||||
|
)
|
||||||
_test_inplace_preserve_storage(inplace_samples, inplace_variants)
|
_test_inplace_preserve_storage(inplace_samples, inplace_variants)
|
||||||
|
|
||||||
# Reference testing for operations in complex32 against complex64.
|
# Reference testing for operations in complex32 against complex64.
|
||||||
@ -732,14 +933,21 @@ class TestCommon(TestCase):
|
|||||||
for sample in op.sample_inputs(device, dtype):
|
for sample in op.sample_inputs(device, dtype):
|
||||||
actual = op(sample.input, *sample.args, **sample.kwargs)
|
actual = op(sample.input, *sample.args, **sample.kwargs)
|
||||||
transformed_sample = sample.transform(lambda x: x.to(torch.complex64))
|
transformed_sample = sample.transform(lambda x: x.to(torch.complex64))
|
||||||
expected = op(transformed_sample.input, *transformed_sample.args, **transformed_sample.kwargs)
|
expected = op(
|
||||||
|
transformed_sample.input,
|
||||||
|
*transformed_sample.args,
|
||||||
|
**transformed_sample.kwargs,
|
||||||
|
)
|
||||||
self.assertEqual(actual, expected, exact_dtype=False)
|
self.assertEqual(actual, expected, exact_dtype=False)
|
||||||
|
|
||||||
|
|
||||||
class TestCompositeCompliance(TestCase):
|
class TestCompositeCompliance(TestCase):
|
||||||
# Checks if the operator (if it is composite) is written to support most
|
# Checks if the operator (if it is composite) is written to support most
|
||||||
# backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
|
# backends and Tensor subclasses. See "CompositeImplicitAutograd Compliance"
|
||||||
# in aten/src/ATen/native/README.md for more details
|
# in aten/src/ATen/native/README.md for more details
|
||||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
|
@unittest.skipIf(
|
||||||
|
IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
|
||||||
|
)
|
||||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||||
def test_operator(self, device, dtype, op):
|
def test_operator(self, device, dtype, op):
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
||||||
@ -750,7 +958,9 @@ class TestCompositeCompliance(TestCase):
|
|||||||
composite_compliance.check_with_mode(op, args, kwargs)
|
composite_compliance.check_with_mode(op, args, kwargs)
|
||||||
composite_compliance.check_all_permutations(op, args, kwargs)
|
composite_compliance.check_all_permutations(op, args, kwargs)
|
||||||
|
|
||||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
|
@unittest.skipIf(
|
||||||
|
IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
|
||||||
|
)
|
||||||
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
|
@ops([op for op in op_db if op.supports_autograd], allowed_dtypes=(torch.float,))
|
||||||
def test_backward(self, device, dtype, op):
|
def test_backward(self, device, dtype, op):
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
||||||
@ -760,7 +970,9 @@ class TestCompositeCompliance(TestCase):
|
|||||||
kwargs = sample.kwargs
|
kwargs = sample.kwargs
|
||||||
composite_compliance.check_backward_formula(op, args, kwargs)
|
composite_compliance.check_backward_formula(op, args, kwargs)
|
||||||
|
|
||||||
@unittest.skipIf(IS_FBCODE or IS_SANDCASTLE, '__torch_dispatch__ does not work in fbcode')
|
@unittest.skipIf(
|
||||||
|
IS_FBCODE or IS_SANDCASTLE, "__torch_dispatch__ does not work in fbcode"
|
||||||
|
)
|
||||||
@ops(op_db, allowed_dtypes=(torch.float,))
|
@ops(op_db, allowed_dtypes=(torch.float,))
|
||||||
def test_forward_ad(self, device, dtype, op):
|
def test_forward_ad(self, device, dtype, op):
|
||||||
if torch.float not in op.supported_backward_dtypes(device):
|
if torch.float not in op.supported_backward_dtypes(device):
|
||||||
@ -776,6 +988,7 @@ class TestCompositeCompliance(TestCase):
|
|||||||
kwargs = sample.kwargs
|
kwargs = sample.kwargs
|
||||||
composite_compliance.check_forward_ad_formula(op, args, kwargs)
|
composite_compliance.check_forward_ad_formula(op, args, kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TestMathBits(TestCase):
|
class TestMathBits(TestCase):
|
||||||
# Tests that
|
# Tests that
|
||||||
# 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
|
# 1. The operator's output for physically conjugated/negated tensors and conjugate/negative view tensors
|
||||||
@ -787,7 +1000,17 @@ class TestMathBits(TestCase):
|
|||||||
# This test only runs for C -> R and C -> C functions
|
# This test only runs for C -> R and C -> C functions
|
||||||
# TODO: add tests for `R->C` functions
|
# TODO: add tests for `R->C` functions
|
||||||
# Note: This test runs for functions that take both tensors and tensorlists as input.
|
# Note: This test runs for functions that take both tensors and tensorlists as input.
|
||||||
def _test_math_view(self, device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, out_type):
|
def _test_math_view(
|
||||||
|
self,
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
op,
|
||||||
|
samples,
|
||||||
|
math_op_physical,
|
||||||
|
math_op_view,
|
||||||
|
is_bit_set,
|
||||||
|
out_type,
|
||||||
|
):
|
||||||
inplace_variant = op.inplace_variant
|
inplace_variant = op.inplace_variant
|
||||||
|
|
||||||
# helper function to clone and conjugate/negate the input if its a tensor
|
# helper function to clone and conjugate/negate the input if its a tensor
|
||||||
@ -796,7 +1019,7 @@ class TestMathBits(TestCase):
|
|||||||
# have its requires_grad set to that value.
|
# have its requires_grad set to that value.
|
||||||
def clone_and_perform_view(input, **kwargs):
|
def clone_and_perform_view(input, **kwargs):
|
||||||
if isinstance(input, torch.Tensor):
|
if isinstance(input, torch.Tensor):
|
||||||
requires_grad = kwargs.get('requires_grad', input.requires_grad)
|
requires_grad = kwargs.get("requires_grad", input.requires_grad)
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
# Ensure view represents the original sample input
|
# Ensure view represents the original sample input
|
||||||
input = math_op_physical(input)
|
input = math_op_physical(input)
|
||||||
@ -813,7 +1036,11 @@ class TestMathBits(TestCase):
|
|||||||
return tuple(out)
|
return tuple(out)
|
||||||
|
|
||||||
for sample in samples:
|
for sample in samples:
|
||||||
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
tensor = (
|
||||||
|
sample.input
|
||||||
|
if isinstance(sample.input, torch.Tensor)
|
||||||
|
else sample.input[0]
|
||||||
|
)
|
||||||
cloned1 = clone_and_perform_view(sample.input)
|
cloned1 = clone_and_perform_view(sample.input)
|
||||||
|
|
||||||
# Computes function forward value with a physically conjugated/negated tensor and
|
# Computes function forward value with a physically conjugated/negated tensor and
|
||||||
@ -827,9 +1054,13 @@ class TestMathBits(TestCase):
|
|||||||
# input produces correct output, and the output tensor has the conj/neg bit set to True
|
# input produces correct output, and the output tensor has the conj/neg bit set to True
|
||||||
if inplace_variant is not None and not sample.broadcasts_input:
|
if inplace_variant is not None and not sample.broadcasts_input:
|
||||||
cloned2 = clone_and_perform_view(tensor, requires_grad=False)
|
cloned2 = clone_and_perform_view(tensor, requires_grad=False)
|
||||||
if (isinstance(expected_forward, torch.Tensor) and
|
if (
|
||||||
expected_forward.dtype is tensor.dtype):
|
isinstance(expected_forward, torch.Tensor)
|
||||||
inplace_forward = inplace_variant(cloned2, *sample.args, **sample.kwargs)
|
and expected_forward.dtype is tensor.dtype
|
||||||
|
):
|
||||||
|
inplace_forward = inplace_variant(
|
||||||
|
cloned2, *sample.args, **sample.kwargs
|
||||||
|
)
|
||||||
self.assertTrue(is_bit_set(inplace_forward))
|
self.assertTrue(is_bit_set(inplace_forward))
|
||||||
self.assertEqual(inplace_forward, expected_forward)
|
self.assertEqual(inplace_forward, expected_forward)
|
||||||
|
|
||||||
@ -838,25 +1069,36 @@ class TestMathBits(TestCase):
|
|||||||
# tensor inputs
|
# tensor inputs
|
||||||
# TODO: update to handle checking grads of all tensor inputs as
|
# TODO: update to handle checking grads of all tensor inputs as
|
||||||
# derived from each tensor output
|
# derived from each tensor output
|
||||||
if isinstance(expected_forward, torch.Tensor) and expected_forward.requires_grad:
|
if (
|
||||||
|
isinstance(expected_forward, torch.Tensor)
|
||||||
|
and expected_forward.requires_grad
|
||||||
|
):
|
||||||
output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
|
output_process_fn_grad = sample.output_process_fn_grad or (lambda x: x)
|
||||||
expected_forward = output_process_fn_grad(expected_forward)
|
expected_forward = output_process_fn_grad(expected_forward)
|
||||||
forward_with_mathview = output_process_fn_grad(forward_with_mathview)
|
forward_with_mathview = output_process_fn_grad(forward_with_mathview)
|
||||||
|
|
||||||
tensor = sample.input if isinstance(sample.input, torch.Tensor) else sample.input[0]
|
tensor = (
|
||||||
|
sample.input
|
||||||
|
if isinstance(sample.input, torch.Tensor)
|
||||||
|
else sample.input[0]
|
||||||
|
)
|
||||||
expected_forward.sum().backward(retain_graph=True)
|
expected_forward.sum().backward(retain_graph=True)
|
||||||
forward_with_mathview.sum().backward(retain_graph=True)
|
forward_with_mathview.sum().backward(retain_graph=True)
|
||||||
if tensor.grad is not None:
|
if tensor.grad is not None:
|
||||||
cloned1_tensor = cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
|
cloned1_tensor = (
|
||||||
|
cloned1 if isinstance(cloned1, torch.Tensor) else cloned1[0]
|
||||||
|
)
|
||||||
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
||||||
|
|
||||||
tensor.grad, cloned1_tensor.grad = None, None
|
tensor.grad, cloned1_tensor.grad = None, None
|
||||||
|
|
||||||
# a repeat of the above test if output is not complex valued
|
# a repeat of the above test if output is not complex valued
|
||||||
if (out_type(expected_forward)):
|
if out_type(expected_forward):
|
||||||
grad = torch.randn_like(expected_forward)
|
grad = torch.randn_like(expected_forward)
|
||||||
expected_forward.backward(grad)
|
expected_forward.backward(grad)
|
||||||
forward_with_mathview.backward(math_op_view(math_op_physical(grad)))
|
forward_with_mathview.backward(
|
||||||
|
math_op_view(math_op_physical(grad))
|
||||||
|
)
|
||||||
|
|
||||||
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
self.assertEqual(tensor.grad, cloned1_tensor.grad)
|
||||||
|
|
||||||
@ -866,10 +1108,21 @@ class TestMathBits(TestCase):
|
|||||||
self.skipTest("Operation doesn't support conjugated inputs.")
|
self.skipTest("Operation doesn't support conjugated inputs.")
|
||||||
math_op_physical = torch.conj_physical
|
math_op_physical = torch.conj_physical
|
||||||
math_op_view = torch.conj
|
math_op_view = torch.conj
|
||||||
_requires_grad = torch.cfloat in op.supported_backward_dtypes(torch.device(device).type)
|
_requires_grad = torch.cfloat in op.supported_backward_dtypes(
|
||||||
|
torch.device(device).type
|
||||||
|
)
|
||||||
is_bit_set = torch.is_conj
|
is_bit_set = torch.is_conj
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
||||||
self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set, torch.is_complex)
|
self._test_math_view(
|
||||||
|
device,
|
||||||
|
dtype,
|
||||||
|
op,
|
||||||
|
samples,
|
||||||
|
math_op_physical,
|
||||||
|
math_op_view,
|
||||||
|
is_bit_set,
|
||||||
|
torch.is_complex,
|
||||||
|
)
|
||||||
|
|
||||||
@ops(op_db, allowed_dtypes=(torch.double,))
|
@ops(op_db, allowed_dtypes=(torch.double,))
|
||||||
def test_neg_view(self, device, dtype, op):
|
def test_neg_view(self, device, dtype, op):
|
||||||
@ -879,8 +1132,16 @@ class TestMathBits(TestCase):
|
|||||||
math_op_view = torch._neg_view
|
math_op_view = torch._neg_view
|
||||||
is_bit_set = torch.is_neg
|
is_bit_set = torch.is_neg
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
|
samples = op.sample_inputs(device, dtype, requires_grad=op.supports_autograd)
|
||||||
self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
|
self._test_math_view(
|
||||||
lambda x: True)
|
device,
|
||||||
|
dtype,
|
||||||
|
op,
|
||||||
|
samples,
|
||||||
|
math_op_physical,
|
||||||
|
math_op_view,
|
||||||
|
is_bit_set,
|
||||||
|
lambda x: True,
|
||||||
|
)
|
||||||
|
|
||||||
@ops(op_db, allowed_dtypes=(torch.cdouble,))
|
@ops(op_db, allowed_dtypes=(torch.cdouble,))
|
||||||
def test_neg_conj_view(self, device, dtype, op):
|
def test_neg_conj_view(self, device, dtype, op):
|
||||||
@ -898,17 +1159,27 @@ class TestMathBits(TestCase):
|
|||||||
def is_bit_set(x):
|
def is_bit_set(x):
|
||||||
return torch.is_neg(x) and torch.is_conj(x)
|
return torch.is_neg(x) and torch.is_conj(x)
|
||||||
|
|
||||||
_requires_grad = dtype in op.supported_backward_dtypes(torch.device(device).type)
|
_requires_grad = dtype in op.supported_backward_dtypes(
|
||||||
|
torch.device(device).type
|
||||||
|
)
|
||||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad)
|
||||||
# Only test one sample
|
# Only test one sample
|
||||||
samples = itertools.islice(samples, 1)
|
samples = itertools.islice(samples, 1)
|
||||||
self._test_math_view(device, dtype, op, samples, math_op_physical, math_op_view, is_bit_set,
|
self._test_math_view(
|
||||||
torch.is_complex)
|
device,
|
||||||
|
dtype,
|
||||||
|
op,
|
||||||
|
samples,
|
||||||
|
math_op_physical,
|
||||||
|
math_op_view,
|
||||||
|
is_bit_set,
|
||||||
|
torch.is_complex,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
instantiate_device_type_tests(TestCommon, globals())
|
instantiate_device_type_tests(TestCommon, globals())
|
||||||
instantiate_device_type_tests(TestCompositeCompliance, globals())
|
instantiate_device_type_tests(TestCompositeCompliance, globals())
|
||||||
instantiate_device_type_tests(TestMathBits, globals())
|
instantiate_device_type_tests(TestMathBits, globals())
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
run_tests()
|
run_tests()
|
||||||
|
@ -1,85 +1,83 @@
|
|||||||
# Owner(s): ["module: primTorch"]
|
# Owner(s): ["module: primTorch"]
|
||||||
|
|
||||||
# TODO: uncomment this file once CI issues with import nvfuser are resolved
|
from functools import partial
|
||||||
|
|
||||||
# from functools import partial
|
import torch
|
||||||
|
from torch.testing import make_tensor
|
||||||
# import torch
|
from torch.testing._internal.common_utils import run_tests, TestCase
|
||||||
# from torch.testing import make_tensor
|
from torch.testing._internal.common_device_type import (
|
||||||
# from torch.testing._internal.common_utils import run_tests, TestCase
|
instantiate_device_type_tests,
|
||||||
# from torch.testing._internal.common_device_type import (
|
onlyCUDA,
|
||||||
# instantiate_device_type_tests,
|
dtypes,
|
||||||
# onlyCUDA,
|
)
|
||||||
# dtypes,
|
import torch._prims as prims
|
||||||
# )
|
from torch._prims.executor import make_traced
|
||||||
# import torch._prims as prims
|
|
||||||
# from torch._prims.executor import make_traced
|
|
||||||
|
|
||||||
|
|
||||||
# class TestPrims(TestCase):
|
class TestPrims(TestCase):
|
||||||
# @onlyCUDA
|
@onlyCUDA
|
||||||
# @dtypes(torch.float32)
|
@dtypes(torch.float32)
|
||||||
# def test_broadcast_in_dim(self, device, dtype):
|
def test_broadcast_in_dim(self, device, dtype):
|
||||||
# def _wrapper(a, shape, broadcast_dimensions):
|
def _wrapper(a, shape, broadcast_dimensions):
|
||||||
# return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
|
return prims.broadcast_in_dim(a, shape, broadcast_dimensions)
|
||||||
|
|
||||||
# traced = make_traced(_wrapper)
|
traced = make_traced(_wrapper)
|
||||||
# make_arg = partial(make_tensor, device=device, dtype=dtype)
|
make_arg = partial(make_tensor, device=device, dtype=dtype)
|
||||||
|
|
||||||
# # TODO: FIXME:
|
# TODO: FIXME:
|
||||||
# # for executor in ('aten', 'nvfuser'):
|
# for executor in ('aten', 'nvfuser'):
|
||||||
# for executor in ("aten",):
|
for executor in ("aten",):
|
||||||
# fn = partial(traced, executor=executor)
|
fn = partial(traced, executor=executor)
|
||||||
# # Same shape
|
# Same shape
|
||||||
# shape = (5, 5)
|
shape = (5, 5)
|
||||||
# a = make_arg(shape)
|
a = make_arg(shape)
|
||||||
# result = fn(a, shape, (0, 1))
|
result = fn(a, shape, (0, 1))
|
||||||
|
|
||||||
# self.assertEqual(result.shape, a.shape)
|
self.assertEqual(result.shape, a.shape)
|
||||||
# self.assertTrue(result.is_contiguous)
|
self.assertTrue(result.is_contiguous)
|
||||||
# self.assertEqual(a, result)
|
self.assertEqual(a, result)
|
||||||
|
|
||||||
# # Error input: reordering dims
|
# Error input: reordering dims
|
||||||
# with self.assertRaises(Exception):
|
with self.assertRaises(Exception):
|
||||||
# result = fn(a, shape, (1, 0))
|
result = fn(a, shape, (1, 0))
|
||||||
|
|
||||||
# # Adding outermost dimensions
|
# Adding outermost dimensions
|
||||||
# a = make_arg((5, 5))
|
a = make_arg((5, 5))
|
||||||
# target_shape = (3, 3, 5, 5)
|
target_shape = (3, 3, 5, 5)
|
||||||
# result = fn(a, target_shape, (2, 3))
|
result = fn(a, target_shape, (2, 3))
|
||||||
|
|
||||||
# self.assertEqual(result.shape, target_shape)
|
self.assertEqual(result.shape, target_shape)
|
||||||
# self.assertEqual(a.broadcast_to(target_shape), result)
|
self.assertEqual(a.broadcast_to(target_shape), result)
|
||||||
|
|
||||||
# # Expands
|
# Expands
|
||||||
# a = make_arg((1, 5, 1))
|
a = make_arg((1, 5, 1))
|
||||||
# target_shape = (3, 5, 7)
|
target_shape = (3, 5, 7)
|
||||||
# result = fn(a, target_shape, (0, 1, 2))
|
result = fn(a, target_shape, (0, 1, 2))
|
||||||
|
|
||||||
# self.assertEqual(result.shape, target_shape)
|
self.assertEqual(result.shape, target_shape)
|
||||||
# self.assertEqual(a.expand_as(result), result)
|
self.assertEqual(a.expand_as(result), result)
|
||||||
|
|
||||||
# # Unsqueezes
|
# Unsqueezes
|
||||||
# a = make_arg((1, 2, 3))
|
a = make_arg((1, 2, 3))
|
||||||
# target_shape = (1, 2, 1, 3)
|
target_shape = (1, 2, 1, 3)
|
||||||
# result = fn(a, target_shape, (0, 1, 3))
|
result = fn(a, target_shape, (0, 1, 3))
|
||||||
|
|
||||||
# self.assertEqual(result.shape, target_shape)
|
self.assertEqual(result.shape, target_shape)
|
||||||
# self.assertEqual(a.unsqueeze(2), result)
|
self.assertEqual(a.unsqueeze(2), result)
|
||||||
|
|
||||||
# # Adds outermost, expands, and unsqueezes
|
# Adds outermost, expands, and unsqueezes
|
||||||
# a = make_arg((1, 2, 3))
|
a = make_arg((1, 2, 3))
|
||||||
# target_shape = (4, 1, 7, 2, 3, 3)
|
target_shape = (4, 1, 7, 2, 3, 3)
|
||||||
# result = fn(a, target_shape, (1, 3, 4))
|
result = fn(a, target_shape, (1, 3, 4))
|
||||||
|
|
||||||
# self.assertEqual(result.shape, target_shape)
|
self.assertEqual(result.shape, target_shape)
|
||||||
# a.unsqueeze_(3)
|
a.unsqueeze_(3)
|
||||||
# a.unsqueeze_(1)
|
a.unsqueeze_(1)
|
||||||
# a.unsqueeze_(0)
|
a.unsqueeze_(0)
|
||||||
# self.assertEqual(a.expand_as(result), result)
|
self.assertEqual(a.expand_as(result), result)
|
||||||
|
|
||||||
|
|
||||||
# instantiate_device_type_tests(TestPrims, globals())
|
instantiate_device_type_tests(TestPrims, globals())
|
||||||
|
|
||||||
# if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# run_tests()
|
run_tests()
|
||||||
|
@ -13,7 +13,7 @@ import random
|
|||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
|
TestCase, run_tests, do_test_empty_full, TEST_WITH_ROCM, suppress_warnings,
|
||||||
torch_to_numpy_dtype_dict, slowTest,
|
torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict, slowTest,
|
||||||
TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize)
|
TEST_SCIPY, IS_MACOS, IS_PPC, IS_WINDOWS, parametrize)
|
||||||
from torch.testing._internal.common_device_type import (
|
from torch.testing._internal.common_device_type import (
|
||||||
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
|
expectedFailureMeta, instantiate_device_type_tests, deviceCountAtLeast, onlyNativeDeviceTypes,
|
||||||
@ -3683,13 +3683,13 @@ class TestBufferProtocol(TestCase):
|
|||||||
self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
|
self.assertEqual(numpy_frombuffer.__array_interface__["data"][0], torch_frombuffer.data_ptr())
|
||||||
return (numpy_original, torch_frombuffer)
|
return (numpy_original, torch_frombuffer)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_same_type(self, device, dtype):
|
def test_same_type(self, device, dtype):
|
||||||
self._run_test((), dtype)
|
self._run_test((), dtype)
|
||||||
self._run_test((4,), dtype)
|
self._run_test((4,), dtype)
|
||||||
self._run_test((10, 10), dtype)
|
self._run_test((10, 10), dtype)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_requires_grad(self, device, dtype):
|
def test_requires_grad(self, device, dtype):
|
||||||
def _run_test_and_check_grad(requires_grad, *args, **kwargs):
|
def _run_test_and_check_grad(requires_grad, *args, **kwargs):
|
||||||
kwargs["requires_grad"] = requires_grad
|
kwargs["requires_grad"] = requires_grad
|
||||||
@ -3704,14 +3704,14 @@ class TestBufferProtocol(TestCase):
|
|||||||
_run_test_and_check_grad(False, (4,), dtype)
|
_run_test_and_check_grad(False, (4,), dtype)
|
||||||
_run_test_and_check_grad(False, (10, 10), dtype)
|
_run_test_and_check_grad(False, (10, 10), dtype)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_with_offset(self, device, dtype):
|
def test_with_offset(self, device, dtype):
|
||||||
# Offset should be valid whenever there is, at least,
|
# Offset should be valid whenever there is, at least,
|
||||||
# one remaining element
|
# one remaining element
|
||||||
for i in range(SIZE):
|
for i in range(SIZE):
|
||||||
self._run_test(SHAPE, dtype, first=i)
|
self._run_test(SHAPE, dtype, first=i)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_with_count(self, device, dtype):
|
def test_with_count(self, device, dtype):
|
||||||
# Count should be valid for any valid in the interval
|
# Count should be valid for any valid in the interval
|
||||||
# [-1, len(input)], except for 0
|
# [-1, len(input)], except for 0
|
||||||
@ -3719,7 +3719,7 @@ class TestBufferProtocol(TestCase):
|
|||||||
if i != 0:
|
if i != 0:
|
||||||
self._run_test(SHAPE, dtype, count=i)
|
self._run_test(SHAPE, dtype, count=i)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_with_count_and_offset(self, device, dtype):
|
def test_with_count_and_offset(self, device, dtype):
|
||||||
# Explicit default count [-1, 1, 2, ..., len]
|
# Explicit default count [-1, 1, 2, ..., len]
|
||||||
for i in range(-1, SIZE + 1):
|
for i in range(-1, SIZE + 1):
|
||||||
@ -3735,7 +3735,7 @@ class TestBufferProtocol(TestCase):
|
|||||||
for j in range(SIZE - i + 1):
|
for j in range(SIZE - i + 1):
|
||||||
self._run_test(SHAPE, dtype, count=i, first=j)
|
self._run_test(SHAPE, dtype, count=i, first=j)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_invalid_positional_args(self, device, dtype):
|
def test_invalid_positional_args(self, device, dtype):
|
||||||
bytes = get_dtype_size(dtype)
|
bytes = get_dtype_size(dtype)
|
||||||
in_bytes = SIZE * bytes
|
in_bytes = SIZE * bytes
|
||||||
@ -3772,7 +3772,7 @@ class TestBufferProtocol(TestCase):
|
|||||||
rf"buffer length \({in_bytes} bytes\)"):
|
rf"buffer length \({in_bytes} bytes\)"):
|
||||||
self._run_test(SHAPE, dtype, count=count, first=first)
|
self._run_test(SHAPE, dtype, count=count, first=first)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_shared_buffer(self, device, dtype):
|
def test_shared_buffer(self, device, dtype):
|
||||||
x = make_tensor((1,), dtype=dtype, device=device)
|
x = make_tensor((1,), dtype=dtype, device=device)
|
||||||
# Modify the whole tensor
|
# Modify the whole tensor
|
||||||
@ -3799,13 +3799,13 @@ class TestBufferProtocol(TestCase):
|
|||||||
arr[first] = x.item() - 1
|
arr[first] = x.item() - 1
|
||||||
self.assertEqual(arr[first:last], tensor)
|
self.assertEqual(arr[first:last], tensor)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_not_a_buffer(self, device, dtype):
|
def test_not_a_buffer(self, device, dtype):
|
||||||
with self.assertRaisesRegex(ValueError,
|
with self.assertRaisesRegex(ValueError,
|
||||||
r"object does not implement Python buffer protocol."):
|
r"object does not implement Python buffer protocol."):
|
||||||
torch.frombuffer([1, 2, 3, 4], dtype=dtype)
|
torch.frombuffer([1, 2, 3, 4], dtype=dtype)
|
||||||
|
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_non_writable_buffer(self, device, dtype):
|
def test_non_writable_buffer(self, device, dtype):
|
||||||
numpy_arr = make_tensor((1,), dtype=dtype, device=device).numpy()
|
numpy_arr = make_tensor((1,), dtype=dtype, device=device).numpy()
|
||||||
byte_arr = numpy_arr.tobytes()
|
byte_arr = numpy_arr.tobytes()
|
||||||
@ -3910,7 +3910,7 @@ class TestAsArray(TestCase):
|
|||||||
self._test_alias_with_cvt(identity, device, dtype)
|
self._test_alias_with_cvt(identity, device, dtype)
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_alias_from_numpy(self, device, dtype):
|
def test_alias_from_numpy(self, device, dtype):
|
||||||
self._test_alias_with_cvt(to_numpy, device, dtype)
|
self._test_alias_with_cvt(to_numpy, device, dtype)
|
||||||
|
|
||||||
@ -3921,7 +3921,7 @@ class TestAsArray(TestCase):
|
|||||||
self._test_alias_with_cvt(to_dlpack, device, dtype)
|
self._test_alias_with_cvt(to_dlpack, device, dtype)
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_alias_from_buffer(self, device, dtype):
|
def test_alias_from_buffer(self, device, dtype):
|
||||||
self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
|
self._test_alias_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
|
||||||
|
|
||||||
@ -3959,7 +3959,7 @@ class TestAsArray(TestCase):
|
|||||||
self._test_copy_with_cvt(identity, device, dtype)
|
self._test_copy_with_cvt(identity, device, dtype)
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_copy_from_numpy(self, device, dtype):
|
def test_copy_from_numpy(self, device, dtype):
|
||||||
self._test_copy_with_cvt(to_numpy, device, dtype)
|
self._test_copy_with_cvt(to_numpy, device, dtype)
|
||||||
|
|
||||||
@ -3969,7 +3969,7 @@ class TestAsArray(TestCase):
|
|||||||
self._test_copy_with_cvt(to_dlpack, device, dtype)
|
self._test_copy_with_cvt(to_dlpack, device, dtype)
|
||||||
|
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
@dtypes(*torch_to_numpy_dtype_dict.keys())
|
@dtypes(*set(numpy_to_torch_dtype_dict.values()))
|
||||||
def test_copy_from_buffer(self, device, dtype):
|
def test_copy_from_buffer(self, device, dtype):
|
||||||
self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
|
self._test_copy_with_cvt(to_memview, device, dtype, shape=(5,), only_with_dtype=True)
|
||||||
|
|
||||||
|
@ -7,15 +7,14 @@ import unittest
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
|
from torch.testing._internal.common_utils import (TestCase, run_tests, load_tests,
|
||||||
TEST_NUMPY, torch_to_numpy_dtype_dict)
|
TEST_NUMPY, torch_to_numpy_dtype_dict, numpy_to_torch_dtype_dict)
|
||||||
from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
|
from torch.testing._internal.common_device_type import (instantiate_device_type_tests, onlyNativeDeviceTypes,
|
||||||
dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta)
|
dtypes, dtypesIfCUDA, onlyCPU, expectedFailureMeta, skipMeta)
|
||||||
from torch.testing._internal.common_dtype import (
|
from torch.testing._internal.common_dtype import (
|
||||||
all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and
|
all_types_and_complex_and, all_types_and, get_all_math_dtypes, integral_types_and, floating_types_and
|
||||||
)
|
)
|
||||||
|
|
||||||
if TEST_NUMPY:
|
import numpy as np
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
|
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
|
||||||
# sharding on sandcastle. This line silences flake warnings
|
# sharding on sandcastle. This line silences flake warnings
|
||||||
@ -812,8 +811,8 @@ class TestTypePromotion(TestCase):
|
|||||||
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
|
||||||
@float_double_default_dtype
|
@float_double_default_dtype
|
||||||
@onlyCPU
|
@onlyCPU
|
||||||
@dtypes(*list(itertools.product(torch_to_numpy_dtype_dict.keys(),
|
@dtypes(*list(itertools.product(set(numpy_to_torch_dtype_dict.values()),
|
||||||
torch_to_numpy_dtype_dict.keys())))
|
set(numpy_to_torch_dtype_dict.values()))))
|
||||||
def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
|
def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
|
||||||
import operator
|
import operator
|
||||||
np_type = torch_to_numpy_dtype_dict[dtypes[0]]
|
np_type = torch_to_numpy_dtype_dict[dtypes[0]]
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -11,7 +11,7 @@ import random
|
|||||||
from torch.testing import make_tensor
|
from torch.testing import make_tensor
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck,
|
TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck,
|
||||||
torch_to_numpy_dtype_dict,
|
numpy_to_torch_dtype_dict,
|
||||||
)
|
)
|
||||||
from torch.testing._internal.common_device_type import \
|
from torch.testing._internal.common_device_type import \
|
||||||
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
|
(instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta)
|
||||||
@ -130,7 +130,7 @@ class TestViewOps(TestCase):
|
|||||||
@onlyNativeDeviceTypes
|
@onlyNativeDeviceTypes
|
||||||
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
@dtypes(*all_types_and_complex_and(torch.half, torch.bool))
|
||||||
def test_view_dtype_new(self, device, dtype):
|
def test_view_dtype_new(self, device, dtype):
|
||||||
dtypes = torch_to_numpy_dtype_dict.copy()
|
dtypes = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
|
||||||
del dtypes[torch.bool]
|
del dtypes[torch.bool]
|
||||||
|
|
||||||
def generate_inputs():
|
def generate_inputs():
|
||||||
|
@ -2,14 +2,15 @@ import torch
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
import torch._prims.utils as utils
|
import torch._prims.utils as utils
|
||||||
from torch._prims.utils import TensorLike, TensorLikeType, TensorMeta, ShapeType
|
from torch._prims.utils import (
|
||||||
|
TensorLike,
|
||||||
|
TensorLikeType,
|
||||||
|
TensorMeta,
|
||||||
|
ShapeType,
|
||||||
|
getnvFuserDtype,
|
||||||
|
)
|
||||||
from torch.overrides import has_torch_function, handle_torch_function
|
from torch.overrides import has_torch_function, handle_torch_function
|
||||||
|
|
||||||
import torch._C._nvfuser as nvfuser # type: ignore[import]
|
|
||||||
|
|
||||||
FusionDefinition = nvfuser.FusionDefinition # type: ignore[attr-defined]
|
|
||||||
DataType = nvfuser.DataType # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
from typing import Sequence, Optional, Union, Callable, List, Tuple, Any
|
from typing import Sequence, Optional, Union, Callable, List, Tuple, Any
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
@ -1033,21 +1034,8 @@ def _convert_element_type_aten(a: Tensor, dtype: torch.dtype) -> Tensor:
|
|||||||
return a.to(dtype)
|
return a.to(dtype)
|
||||||
|
|
||||||
|
|
||||||
_torch_dtype_to_nvfuser_dtype_map = {
|
|
||||||
torch.cdouble: DataType.ComplexDouble,
|
|
||||||
torch.cfloat: DataType.ComplexFloat,
|
|
||||||
torch.double: DataType.Double,
|
|
||||||
torch.float: DataType.Float,
|
|
||||||
torch.half: DataType.Half,
|
|
||||||
torch.bfloat16: DataType.BFloat16,
|
|
||||||
torch.long: DataType.Int,
|
|
||||||
torch.int: DataType.Int32,
|
|
||||||
torch.bool: DataType.Bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor:
|
def _convert_element_type_nvfuser(fd: Any, a: Tensor, dtype: torch.dtype) -> Tensor:
|
||||||
nvfuser_dtype = _torch_dtype_to_nvfuser_dtype_map[dtype]
|
nvfuser_dtype = getnvFuserDtype(dtype)
|
||||||
return fd.Ops.cast(nvfuser_dtype, a) # type: ignore[attr-defined]
|
return fd.Ops.cast(nvfuser_dtype, a) # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
|
||||||
|
@ -3,27 +3,11 @@ from typing import Callable
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from torch.fx import GraphModule
|
from torch.fx import GraphModule
|
||||||
from torch._prims.utils import TensorMeta
|
from torch._prims.utils import TensorMeta, getnvFuserDtype
|
||||||
from torch._prims.context import PrimContext
|
from torch._prims.context import PrimContext
|
||||||
|
|
||||||
import torch._C._nvfuser as nvfuser # type: ignore[import]
|
if torch.cuda.is_available():
|
||||||
|
from torch._C._nvfuser import Fusion, FusionDefinition # type: ignore[import]
|
||||||
DataType = nvfuser.DataType # type: ignore[attr-defined]
|
|
||||||
Fusion = nvfuser.Fusion # type: ignore[attr-defined]
|
|
||||||
FusionDefinition = nvfuser.FusionDefinition # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
# TODO: refactor me into a common place
|
|
||||||
_torch_dtype_to_nvfuser_dtype_map = {
|
|
||||||
torch.cdouble: DataType.ComplexDouble,
|
|
||||||
torch.cfloat: DataType.ComplexFloat,
|
|
||||||
torch.double: DataType.Double,
|
|
||||||
torch.float: DataType.Float,
|
|
||||||
torch.half: DataType.Half,
|
|
||||||
torch.bfloat16: DataType.BFloat16,
|
|
||||||
torch.long: DataType.Int,
|
|
||||||
torch.int: DataType.Int32,
|
|
||||||
torch.bool: DataType.Bool,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
|
def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
|
||||||
@ -37,6 +21,11 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
|
|||||||
gm = GraphModule({}, ctx.graph)
|
gm = GraphModule({}, ctx.graph)
|
||||||
return gm.forward(*args, **kwargs)
|
return gm.forward(*args, **kwargs)
|
||||||
elif executor == "nvfuser":
|
elif executor == "nvfuser":
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
raise RuntimeError(
|
||||||
|
"Attempting to use nvFuser trace executor but CUDA is not available!"
|
||||||
|
)
|
||||||
|
|
||||||
# PROTOTYPE nvfuser executor
|
# PROTOTYPE nvfuser executor
|
||||||
# Only accepts tensor inputs and single tensor outputs
|
# Only accepts tensor inputs and single tensor outputs
|
||||||
# Does not handle kwargs
|
# Does not handle kwargs
|
||||||
@ -53,9 +42,7 @@ def execute(ctx: PrimContext, *args, executor: str = "aten", **kwargs):
|
|||||||
nv_args = [fd]
|
nv_args = [fd]
|
||||||
for arg in args:
|
for arg in args:
|
||||||
if isinstance(arg, torch.Tensor):
|
if isinstance(arg, torch.Tensor):
|
||||||
x = fd.define_tensor(
|
x = fd.define_tensor(arg.ndim, getnvFuserDtype(arg.dtype))
|
||||||
arg.ndim, _torch_dtype_to_nvfuser_dtype_map[arg.dtype]
|
|
||||||
)
|
|
||||||
fd.add_input(x)
|
fd.add_input(x)
|
||||||
nv_args.append(x)
|
nv_args.append(x)
|
||||||
else:
|
else:
|
||||||
|
@ -8,6 +8,32 @@ import threading
|
|||||||
import torch
|
import torch
|
||||||
from torch.fx import Node
|
from torch.fx import Node
|
||||||
|
|
||||||
|
# nvFuser imports are conditional on CUDA being available
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
from torch._C._nvfuser import DataType # type: ignore[import]
|
||||||
|
|
||||||
|
_torch_dtype_to_nvfuser_dtype_map = {
|
||||||
|
torch.cdouble: DataType.ComplexDouble,
|
||||||
|
torch.cfloat: DataType.ComplexFloat,
|
||||||
|
torch.double: DataType.Double,
|
||||||
|
torch.float: DataType.Float,
|
||||||
|
torch.half: DataType.Half,
|
||||||
|
torch.bfloat16: DataType.BFloat16,
|
||||||
|
torch.long: DataType.Int,
|
||||||
|
torch.int: DataType.Int32,
|
||||||
|
torch.bool: DataType.Bool,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
_torch_dtype_to_nvfuser_dtype_map = {}
|
||||||
|
|
||||||
|
|
||||||
|
def getnvFuserDtype(dtype: torch.dtype):
|
||||||
|
"""
|
||||||
|
Translates from torch.dtype to nvFuser's DataType enum
|
||||||
|
"""
|
||||||
|
return _torch_dtype_to_nvfuser_dtype_map[dtype]
|
||||||
|
|
||||||
|
|
||||||
ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
|
ShapeType = Union[torch.Size, List[int], Tuple[int, ...]]
|
||||||
StrideType = Union[List[int], Tuple[int, ...]]
|
StrideType = Union[List[int], Tuple[int, ...]]
|
||||||
DimsType = Union[int, List[int], Tuple[int, ...]]
|
DimsType = Union[int, List[int], Tuple[int, ...]]
|
||||||
|
File diff suppressed because it is too large
Load Diff
@ -871,6 +871,10 @@ if IS_WINDOWS:
|
|||||||
|
|
||||||
# Dict of torch dtype -> NumPy dtype
|
# Dict of torch dtype -> NumPy dtype
|
||||||
torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
|
torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}
|
||||||
|
torch_to_numpy_dtype_dict.update({
|
||||||
|
torch.bfloat16: np.float32,
|
||||||
|
torch.complex32: np.complex64
|
||||||
|
})
|
||||||
|
|
||||||
def skipIfRocm(fn):
|
def skipIfRocm(fn):
|
||||||
@wraps(fn)
|
@wraps(fn)
|
||||||
|
Reference in New Issue
Block a user