mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 12:54:11 +08:00
Enable UFMT on test/test_ops* (#123935)
Part of https://github.com/pytorch/pytorch/issues/123062 Pull Request resolved: https://github.com/pytorch/pytorch/pull/123935 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
71b8363f40
commit
2216068559
@ -1476,10 +1476,6 @@ exclude_patterns = [
|
||||
'test/test_nvfuser_dynamo.py',
|
||||
'test/test_nvfuser_frontend.py',
|
||||
'test/test_openmp.py',
|
||||
'test/test_ops.py',
|
||||
'test/test_ops_fwd_gradients.py',
|
||||
'test/test_ops_gradients.py',
|
||||
'test/test_ops_jit.py',
|
||||
'test/test_optim.py',
|
||||
'test/test_out_dtype_op.py',
|
||||
'test/test_overrides.py',
|
||||
|
800
test/test_ops.py
800
test/test_ops.py
File diff suppressed because it is too large
Load Diff
@ -1,16 +1,25 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
from functools import partial
|
||||
import platform
|
||||
from functools import partial
|
||||
from unittest import skipIf as skipif
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
TestGradients, run_tests, skipIfTorchInductor, IS_MACOS, TestCase)
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
OpDTypes,
|
||||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
IS_MACOS,
|
||||
run_tests,
|
||||
skipIfTorchInductor,
|
||||
TestCase,
|
||||
TestGradients,
|
||||
unMarkDynamoStrictTest,
|
||||
)
|
||||
|
||||
# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033
|
||||
# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The
|
||||
@ -19,8 +28,10 @@ if IS_MACOS:
|
||||
torch.set_num_threads(1)
|
||||
|
||||
# gradcheck requires double precision
|
||||
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=[torch.double, torch.cdouble])
|
||||
_gradcheck_ops = partial(
|
||||
ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]
|
||||
)
|
||||
|
||||
|
||||
@unMarkDynamoStrictTest
|
||||
class TestFwdGradients(TestGradients):
|
||||
@ -33,31 +44,46 @@ class TestFwdGradients(TestGradients):
|
||||
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
|
||||
else:
|
||||
err_msg = r"Trying to use forward AD with .* that does not support it"
|
||||
hint_msg = ("Running forward-over-backward gradgrad for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
|
||||
hint_msg = (
|
||||
"Running forward-over-backward gradgrad for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True."
|
||||
)
|
||||
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
|
||||
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
|
||||
|
||||
|
||||
def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
|
||||
# TODO: clean up how attributes are passed to gradcheck from OpInfos
|
||||
def call_grad_test_helper():
|
||||
check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
|
||||
(op.check_inplace_batched_forward_grad and is_inplace))
|
||||
self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
|
||||
check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
|
||||
check_batched_forward_grad = (
|
||||
op.check_batched_forward_grad and not is_inplace
|
||||
) or (op.check_inplace_batched_forward_grad and is_inplace)
|
||||
self._grad_test_helper(
|
||||
device,
|
||||
dtype,
|
||||
op,
|
||||
variant,
|
||||
check_forward_ad=True,
|
||||
check_backward_ad=False,
|
||||
check_batched_grad=False,
|
||||
check_batched_forward_grad=check_batched_forward_grad,
|
||||
)
|
||||
|
||||
if op.supports_forward_ad:
|
||||
call_grad_test_helper()
|
||||
else:
|
||||
err_msg = r"Trying to use forward AD with .* that does not support it"
|
||||
hint_msg = ("Running forward AD for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
|
||||
hint_msg = (
|
||||
"Running forward AD for an OP that has does not support it did not "
|
||||
"raise any error. If your op supports forward AD, you should set supports_forward_ad=True"
|
||||
)
|
||||
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
|
||||
call_grad_test_helper()
|
||||
|
||||
@_gradcheck_ops(op_db)
|
||||
@skipif(platform.machine() == "s390x",
|
||||
reason="Different precision of openblas functions: https://github.com/OpenMathLib/OpenBLAS/issues/4194")
|
||||
@skipif(
|
||||
platform.machine() == "s390x",
|
||||
reason="Different precision of openblas functions: https://github.com/OpenMathLib/OpenBLAS/issues/4194",
|
||||
)
|
||||
def test_forward_mode_AD(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
|
||||
@ -71,10 +97,13 @@ class TestFwdGradients(TestGradients):
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
|
||||
self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True)
|
||||
self._forward_grad_helper(
|
||||
device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True
|
||||
)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestFwdGradients, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
TestCase._default_dtype_check_enabled = True
|
||||
run_tests()
|
||||
|
@ -1,19 +1,29 @@
|
||||
# Owner(s): ["module: unknown"]
|
||||
|
||||
from functools import partial
|
||||
import torch
|
||||
|
||||
from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase
|
||||
import torch
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
OpDTypes,
|
||||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
|
||||
from torch.testing._internal.common_utils import (
|
||||
run_tests,
|
||||
TestCase,
|
||||
TestGradients,
|
||||
unMarkDynamoStrictTest,
|
||||
)
|
||||
from torch.testing._internal.custom_op_db import custom_op_db
|
||||
from torch.testing._internal.hop_db import hop_db
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, ops, OpDTypes)
|
||||
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
|
||||
|
||||
# gradcheck requires double precision
|
||||
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=[torch.double, torch.cdouble])
|
||||
_gradcheck_ops = partial(
|
||||
ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]
|
||||
)
|
||||
|
||||
|
||||
@unMarkDynamoStrictTest
|
||||
class TestBwdGradients(TestGradients):
|
||||
@ -49,16 +59,20 @@ class TestBwdGradients(TestGradients):
|
||||
result = inplace(sample)
|
||||
result.sum().backward()
|
||||
else:
|
||||
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
|
||||
self._grad_test_helper(
|
||||
device, dtype, op, self._get_safe_inplace(op.get_inplace())
|
||||
)
|
||||
|
||||
# Test that gradients of gradients are computed correctly
|
||||
@_gradcheck_ops(op_db + hop_db + custom_op_db)
|
||||
def test_fn_gradgrad(self, device, dtype, op):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.supports_gradgrad:
|
||||
self.skipTest("Op claims it doesn't support gradgrad. This is not verified.")
|
||||
self.skipTest(
|
||||
"Op claims it doesn't support gradgrad. This is not verified."
|
||||
)
|
||||
else:
|
||||
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
|
||||
self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")
|
||||
|
||||
# Test that gradients of gradients are properly raising
|
||||
@_gradcheck_ops(op_db + custom_op_db)
|
||||
@ -69,7 +83,7 @@ class TestBwdGradients(TestGradients):
|
||||
|
||||
err_msg = r"derivative for .* is not implemented"
|
||||
with self.assertRaisesRegex(RuntimeError, err_msg):
|
||||
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
|
||||
self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")
|
||||
|
||||
# Method gradgrad (and grad, see above) tests are disabled since they're
|
||||
# costly and redundant with function gradgrad (and grad) tests
|
||||
@ -83,11 +97,13 @@ class TestBwdGradients(TestGradients):
|
||||
self._skip_helper(op, device, dtype)
|
||||
if not op.inplace_variant or not op.supports_inplace_autograd:
|
||||
self.skipTest("Skipped! Operation does not support inplace autograd.")
|
||||
self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad")
|
||||
self._check_helper(
|
||||
device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad"
|
||||
)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestBwdGradients, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
TestCase._default_dtype_check_enabled = True
|
||||
run_tests()
|
||||
|
@ -6,20 +6,39 @@ from textwrap import dedent
|
||||
import torch
|
||||
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_utils import \
|
||||
(run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, TestCase)
|
||||
from torch.testing._internal.common_device_type import (
|
||||
instantiate_device_type_tests,
|
||||
OpDTypes,
|
||||
ops,
|
||||
)
|
||||
from torch.testing._internal.common_jit import (
|
||||
check_against_reference,
|
||||
JitCommonTestCase,
|
||||
)
|
||||
from torch.testing._internal.common_methods_invocations import op_db
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
|
||||
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
|
||||
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, check_alias_annotation
|
||||
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
|
||||
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
|
||||
from torch.testing._internal.common_utils import (
|
||||
clone_input_helper,
|
||||
first_sample,
|
||||
IS_SANDCASTLE,
|
||||
run_tests,
|
||||
TestCase,
|
||||
unMarkDynamoStrictTest,
|
||||
)
|
||||
from torch.testing._internal.jit_metaprogramming_utils import (
|
||||
check_alias_annotation,
|
||||
create_script_fn,
|
||||
create_traced_fn,
|
||||
)
|
||||
from torch.testing._internal.jit_utils import (
|
||||
disable_autodiff_subgraph_inlining,
|
||||
is_lambda,
|
||||
)
|
||||
|
||||
# variant testing is only done with torch.float and torch.cfloat to avoid
|
||||
# excessive test times and maximize signal to noise ratio
|
||||
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=(torch.float, torch.cfloat))
|
||||
|
||||
_variant_ops = partial(
|
||||
ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
|
||||
)
|
||||
|
||||
|
||||
# Tests operators for consistency between JIT and eager, also checks
|
||||
@ -37,17 +56,25 @@ class TestJit(JitCommonTestCase):
|
||||
# TODO WARNING: inplace x {traced, scripted} not currently tested
|
||||
@_variant_ops(op_db)
|
||||
def test_variant_consistency_jit(self, device, dtype, op):
|
||||
_requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
|
||||
_requires_grad = dtype in op.supported_backward_dtypes(
|
||||
torch.device(device).type
|
||||
)
|
||||
|
||||
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
|
||||
samples = op.sample_inputs(
|
||||
device,
|
||||
dtype,
|
||||
requires_grad=_requires_grad,
|
||||
include_conjugated_inputs=include_conjugated_inputs,
|
||||
)
|
||||
|
||||
# Acquires variants to test
|
||||
func = op.get_op()
|
||||
method = op.get_method()
|
||||
variants = {
|
||||
# TODO: inplace tests currently fail, fix and add inplace variant
|
||||
'function': func, 'method': method,
|
||||
"function": func,
|
||||
"method": method,
|
||||
}
|
||||
|
||||
# scripting strips the torch.ops prefix from these operators
|
||||
@ -57,13 +84,12 @@ class TestJit(JitCommonTestCase):
|
||||
self.skipTest("variant consistency doesn't work on torch.ops")
|
||||
|
||||
# TODO: find better way to standardize on op registration itself..
|
||||
has_fake_function = op.name in ["resize_", 'resize_as_']
|
||||
has_fake_function = op.name in ["resize_", "resize_as_"]
|
||||
|
||||
if has_fake_function:
|
||||
variants = {'method': getattr(torch.Tensor, op.name)}
|
||||
variants = {"method": getattr(torch.Tensor, op.name)}
|
||||
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
||||
|
||||
|
||||
tested = False
|
||||
for sample in samples:
|
||||
# Test traced and scripted consistency
|
||||
@ -80,22 +106,30 @@ class TestJit(JitCommonTestCase):
|
||||
|
||||
tested = True
|
||||
try:
|
||||
self.indiv_variant_test_jit(device, dtype, op, sample, func_type, variant, has_fake_function)
|
||||
self.indiv_variant_test_jit(
|
||||
device, dtype, op, sample, func_type, variant, has_fake_function
|
||||
)
|
||||
except Exception as e:
|
||||
variant_error_info = dedent(f"""
|
||||
variant_error_info = dedent(
|
||||
f"""
|
||||
Error testing {op.name} {func_type} variant
|
||||
with dtype: {dtype}
|
||||
with inputs {sample}:
|
||||
""")
|
||||
"""
|
||||
)
|
||||
raise Exception(variant_error_info) from e
|
||||
|
||||
assert tested, "JIT Test does not execute any logic"
|
||||
|
||||
def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant, has_fake_function):
|
||||
_requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
|
||||
def indiv_variant_test_jit(
|
||||
self, device, dtype, op, sample, func_type, variant, has_fake_function
|
||||
):
|
||||
_requires_grad = dtype in op.supported_backward_dtypes(
|
||||
torch.device(device).type
|
||||
)
|
||||
support_script = op.supports_scripting
|
||||
# Create accessor for script function variant
|
||||
name = op.name + '_' if func_type == 'inplace' else op.name
|
||||
name = op.name + "_" if func_type == "inplace" else op.name
|
||||
|
||||
# run with disable_autodiff_subgraph_inlining(True) to test
|
||||
# autodiff support. Context manager forces the graph to contain
|
||||
@ -112,16 +146,23 @@ class TestJit(JitCommonTestCase):
|
||||
return output
|
||||
|
||||
def get_sample():
|
||||
return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input
|
||||
return (
|
||||
clone_input_helper(sample.input)
|
||||
if op.name[-1] == "_"
|
||||
else sample.input
|
||||
)
|
||||
|
||||
if support_script:
|
||||
check_against_reference(self,
|
||||
script_fn,
|
||||
op.get_op(),
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
|
||||
check_against_reference(
|
||||
self,
|
||||
script_fn,
|
||||
op.get_op(),
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad,
|
||||
no_gradgrad=not op.supports_gradgrad,
|
||||
)
|
||||
|
||||
# Check traced forward, grad, and grad grad
|
||||
# TODO: fix tracing here
|
||||
@ -131,13 +172,16 @@ class TestJit(JitCommonTestCase):
|
||||
|
||||
if supports_tracing:
|
||||
traced_fn = create_traced_fn(self, variant)
|
||||
check_against_reference(self,
|
||||
traced_fn,
|
||||
op.get_op(),
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
|
||||
check_against_reference(
|
||||
self,
|
||||
traced_fn,
|
||||
op.get_op(),
|
||||
out_fn,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
no_grad=not _requires_grad,
|
||||
no_gradgrad=not op.supports_gradgrad,
|
||||
)
|
||||
|
||||
# Check alias annotation schema for correctness (make
|
||||
# sure inputs that aren't supposed to be modified aren't)
|
||||
@ -146,8 +190,13 @@ class TestJit(JitCommonTestCase):
|
||||
if dtype == torch.float32:
|
||||
# TODO: no reason why we cant run this with tracing graph
|
||||
if support_script and op.name != "rsub":
|
||||
check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
|
||||
func_type=func_type, aten_name=op.aten_name)
|
||||
check_alias_annotation(
|
||||
name,
|
||||
(get_sample(),) + sample.args,
|
||||
sample.kwargs,
|
||||
func_type=func_type,
|
||||
aten_name=op.aten_name,
|
||||
)
|
||||
|
||||
# TODO: use script graph as well
|
||||
checked_shape_analysis = False
|
||||
@ -156,14 +205,18 @@ class TestJit(JitCommonTestCase):
|
||||
|
||||
# right now, tuple of outputs and tensor output supported
|
||||
# TODO: list of tensor outputs
|
||||
tuple_of_tensors = isinstance(out, tuple) and all(isinstance(elem, torch.Tensor) for elem in out)
|
||||
tuple_of_tensors = isinstance(out, tuple) and all(
|
||||
isinstance(elem, torch.Tensor) for elem in out
|
||||
)
|
||||
|
||||
if isinstance(out, torch.Tensor) or tuple_of_tensors:
|
||||
if tuple_of_tensors:
|
||||
sizes = [elem.size() for elem in out]
|
||||
else:
|
||||
sizes = out.size()
|
||||
self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
|
||||
self.checkShapeAnalysis(
|
||||
sizes, traced_fn.graph, op.assert_jit_shape_analysis
|
||||
)
|
||||
checked_shape_analysis = True
|
||||
if op.assert_jit_shape_analysis:
|
||||
self.assertTrue(checked_shape_analysis)
|
||||
@ -173,20 +226,31 @@ class TestJit(JitCommonTestCase):
|
||||
# Sandcastle doesn't fuse nodes
|
||||
if IS_SANDCASTLE:
|
||||
# fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
|
||||
nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
|
||||
nonfusible_nodes = (
|
||||
op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
|
||||
)
|
||||
fusible_nodes = []
|
||||
else:
|
||||
nonfusible_nodes = op.autodiff_nonfusible_nodes
|
||||
fusible_nodes = op.autodiff_fusible_nodes
|
||||
|
||||
if supports_tracing:
|
||||
self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
|
||||
self.assertAutodiffNode(
|
||||
traced_fn.last_graph,
|
||||
op.assert_autodiffed,
|
||||
nonfusible_nodes,
|
||||
fusible_nodes,
|
||||
)
|
||||
if support_script:
|
||||
self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
|
||||
self.assertAutodiffNode(
|
||||
script_fn.last_graph,
|
||||
op.assert_autodiffed,
|
||||
nonfusible_nodes,
|
||||
fusible_nodes,
|
||||
)
|
||||
|
||||
# alias testing is only done with torch.float for the same reason
|
||||
_alias_ops = partial(ops, dtypes=OpDTypes.supported,
|
||||
allowed_dtypes=(torch.float,))
|
||||
_alias_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float,))
|
||||
|
||||
@_alias_ops(op for op in op_db if op.aliases)
|
||||
def test_jit_alias_remapping(self, device, dtype, op):
|
||||
@ -209,16 +273,18 @@ class TestJit(JitCommonTestCase):
|
||||
|
||||
return str(v)
|
||||
|
||||
args_kw = args + \
|
||||
[f"{v}" for v in sample.args] + \
|
||||
[f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
|
||||
args_kw = (
|
||||
args
|
||||
+ [f"{v}" for v in sample.args]
|
||||
+ [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
|
||||
)
|
||||
|
||||
# Prepare data for test tracing
|
||||
sample_args_kwargs = ()
|
||||
if len(sample.args) > 0:
|
||||
sample_args_kwargs += (sample.args, )
|
||||
sample_args_kwargs += (sample.args,)
|
||||
if len(sample.kwargs) > 0:
|
||||
sample_args_kwargs += (sample.kwargs, )
|
||||
sample_args_kwargs += (sample.kwargs,)
|
||||
|
||||
original_name = op.aten_name
|
||||
original_name_inplace = original_name + "_"
|
||||
@ -227,7 +293,11 @@ class TestJit(JitCommonTestCase):
|
||||
for a_op in op.aliases:
|
||||
inplace = a_op.inplace_variant
|
||||
method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
|
||||
variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)
|
||||
variants = (
|
||||
v
|
||||
for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
|
||||
if v is not None
|
||||
)
|
||||
|
||||
# Test scripting:
|
||||
for variant in variants:
|
||||
@ -235,10 +305,10 @@ class TestJit(JitCommonTestCase):
|
||||
op_name = original_name_inplace if variant is inplace else original_name
|
||||
|
||||
if variant in method_or_inplace:
|
||||
fn_template = '''
|
||||
fn_template = """
|
||||
def _fn(t0{c}):
|
||||
return t0.{alias_name}({args_kw})
|
||||
'''
|
||||
"""
|
||||
# remove the first input tensor
|
||||
script = fn_template.format(
|
||||
c=", " if len(args_kw[1:]) > 1 else "",
|
||||
@ -246,10 +316,10 @@ class TestJit(JitCommonTestCase):
|
||||
alias_name=variant_name,
|
||||
)
|
||||
else:
|
||||
fn_template = '''
|
||||
fn_template = """
|
||||
def _fn({args}):
|
||||
return variant({args_kw})
|
||||
'''
|
||||
"""
|
||||
script = fn_template.format(
|
||||
args=", ".join(args),
|
||||
args_kw=", ".join(args_kw),
|
||||
@ -261,13 +331,15 @@ class TestJit(JitCommonTestCase):
|
||||
|
||||
scripted = torch.jit.CompilationUnit(script)._fn
|
||||
|
||||
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
|
||||
if variant is inplace and not torch.can_cast(expected_dtype, dtype):
|
||||
try:
|
||||
inp = clone_input_helper(sample.input)
|
||||
scripted(inp)
|
||||
except Exception as e:
|
||||
continue
|
||||
self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")
|
||||
self.fail(
|
||||
"Inplace operation on integer tensor that should be promoted to float didn't fail!"
|
||||
)
|
||||
|
||||
inp = clone_input_helper(sample.input)
|
||||
scripted(inp)
|
||||
@ -294,6 +366,6 @@ class TestJit(JitCommonTestCase):
|
||||
|
||||
instantiate_device_type_tests(TestJit, globals())
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
TestCase._default_dtype_check_enabled = True
|
||||
run_tests()
|
||||
|
Reference in New Issue
Block a user