Enable UFMT on test/test_ops* (#123935)

Part of https://github.com/pytorch/pytorch/issues/123062

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123935
Approved by: https://github.com/ezyang
Author: statelesshz
Date: 2024-04-13 03:31:56 +00:00
Committed by: PyTorch MergeBot
Parent: 71b8363f40
Commit: 2216068559
5 changed files with 692 additions and 417 deletions
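UFMT combines usort (import sorting) and black (code formatting), so the reformatted hunks below follow black's output style. The sketch below is only an illustration of the re-wrapping seen throughout this diff, not the actual lint invocation (that runs through lintrunner); it assumes black is installed and uses black's default settings, which may not exactly match the configuration pinned in the repo.

# Illustrative sketch: run the black half of UFMT over one snippet taken
# from this diff. Requires black to be installed; settings are black defaults.
import black

before = (
    "_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,\n"
    "                         allowed_dtypes=[torch.double, torch.cdouble])\n"
)

# black joins the parenthesized continuation and re-wraps the call at the
# outer parentheses, producing the multi-line form seen in the new files.
after = black.format_str(before, mode=black.Mode())
print(after)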

.lintrunner.toml

@@ -1476,10 +1476,6 @@ exclude_patterns = [
'test/test_nvfuser_dynamo.py',
'test/test_nvfuser_frontend.py',
'test/test_openmp.py',
'test/test_ops.py',
'test/test_ops_fwd_gradients.py',
'test/test_ops_gradients.py',
'test/test_ops_jit.py',
'test/test_optim.py',
'test/test_out_dtype_op.py',
'test/test_overrides.py',
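Dropping the four test/test_ops* entries from the UFMT exclude_patterns list in .lintrunner.toml is what brings these files under the formatter; the remaining hunks are the resulting reformat. A hedged sketch for checking locally whether a path is still excluded is below; it assumes Python 3.11+ for tomllib, assumes the config uses [[linter]] tables with "code" and "exclude_patterns" keys as the hunk above suggests, and uses fnmatch only as a rough stand-in for lintrunner's own pattern matching.

# Hedged sketch: is a given file still excluded from the UFMT linter?
import tomllib
from fnmatch import fnmatch

with open(".lintrunner.toml", "rb") as f:
    config = tomllib.load(f)

# Assumes one [[linter]] table per lint, keyed by its "code" field.
ufmt = next(linter for linter in config["linter"] if linter["code"] == "UFMT")
path = "test/test_ops_jit.py"
excluded = any(fnmatch(path, pattern) for pattern in ufmt.get("exclude_patterns", []))
print(f"{path} excluded from UFMT: {excluded}")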

test/test_ops.py (diff suppressed because it is too large)

test/test_ops_fwd_gradients.py

@@ -1,16 +1,25 @@
# Owner(s): ["module: unknown"]
from functools import partial
import platform
from functools import partial
from unittest import skipIf as skipif
import torch
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
TestGradients, run_tests, skipIfTorchInductor, IS_MACOS, TestCase)
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
OpDTypes,
ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)
from torch.testing._internal.common_utils import (
IS_MACOS,
run_tests,
skipIfTorchInductor,
TestCase,
TestGradients,
unMarkDynamoStrictTest,
)
# TODO: mitigate flaky issue on macOS https://github.com/pytorch/pytorch/issues/66033
# AFAIK, c10::ThreadPool looks correct in the way it uses condition_variable wait. The
@@ -19,8 +28,10 @@ if IS_MACOS:
torch.set_num_threads(1)
# gradcheck requires double precision
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=[torch.double, torch.cdouble])
_gradcheck_ops = partial(
ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]
)
@unMarkDynamoStrictTest
class TestFwdGradients(TestGradients):
@@ -33,31 +44,46 @@ class TestFwdGradients(TestGradients):
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
else:
err_msg = r"Trying to use forward AD with .* that does not support it"
hint_msg = ("Running forward-over-backward gradgrad for an OP that has does not support it did not "
"raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True.")
hint_msg = (
"Running forward-over-backward gradgrad for an OP that has does not support it did not "
"raise any error. If your op supports forward AD, you should set supports_fwgrad_bwgrad=True."
)
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
self._check_helper(device, dtype, op, op.get_op(), "fwgrad_bwgrad")
def _forward_grad_helper(self, device, dtype, op, variant, is_inplace):
# TODO: clean up how attributes are passed to gradcheck from OpInfos
def call_grad_test_helper():
check_batched_forward_grad = ((op.check_batched_forward_grad and not is_inplace) or
(op.check_inplace_batched_forward_grad and is_inplace))
self._grad_test_helper(device, dtype, op, variant, check_forward_ad=True, check_backward_ad=False,
check_batched_grad=False, check_batched_forward_grad=check_batched_forward_grad)
check_batched_forward_grad = (
op.check_batched_forward_grad and not is_inplace
) or (op.check_inplace_batched_forward_grad and is_inplace)
self._grad_test_helper(
device,
dtype,
op,
variant,
check_forward_ad=True,
check_backward_ad=False,
check_batched_grad=False,
check_batched_forward_grad=check_batched_forward_grad,
)
if op.supports_forward_ad:
call_grad_test_helper()
else:
err_msg = r"Trying to use forward AD with .* that does not support it"
hint_msg = ("Running forward AD for an OP that has does not support it did not "
"raise any error. If your op supports forward AD, you should set supports_forward_ad=True")
hint_msg = (
"Running forward AD for an OP that has does not support it did not "
"raise any error. If your op supports forward AD, you should set supports_forward_ad=True"
)
with self.assertRaisesRegex(NotImplementedError, err_msg, msg=hint_msg):
call_grad_test_helper()
@_gradcheck_ops(op_db)
@skipif(platform.machine() == "s390x",
reason="Different precision of openblas functions: https://github.com/OpenMathLib/OpenBLAS/issues/4194")
@skipif(
platform.machine() == "s390x",
reason="Different precision of openblas functions: https://github.com/OpenMathLib/OpenBLAS/issues/4194",
)
def test_forward_mode_AD(self, device, dtype, op):
self._skip_helper(op, device, dtype)
@@ -71,10 +97,13 @@ class TestFwdGradients(TestGradients):
if not op.inplace_variant or not op.supports_inplace_autograd:
self.skipTest("Skipped! Operation does not support inplace autograd.")
self._forward_grad_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True)
self._forward_grad_helper(
device, dtype, op, self._get_safe_inplace(op.get_inplace()), is_inplace=True
)
instantiate_device_type_tests(TestFwdGradients, globals())
if __name__ == '__main__':
if __name__ == "__main__":
TestCase._default_dtype_check_enabled = True
run_tests()
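The changes to this file are formatting-only: imports are sorted and exploded, long calls are re-wrapped, and single-quoted strings become double-quoted (the if __name__ == "__main__" line above). A quick, hedged way to convince yourself that such a rewrite is behavior-preserving is to compare the ASTs of the old and new forms, which is essentially the safety check black itself performs; the two snippet strings below are illustrative.

# Hedged sketch: quote normalization does not change the parsed AST.
import ast

before = "if __name__ == '__main__':\n    run_tests()\n"
after = 'if __name__ == "__main__":\n    run_tests()\n'

# ast.dump ignores quoting style, so the two sources parse identically.
assert ast.dump(ast.parse(before)) == ast.dump(ast.parse(after))
print("ASTs match")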

test/test_ops_gradients.py

@@ -1,19 +1,29 @@
# Owner(s): ["module: unknown"]
from functools import partial
import torch
from torch.testing._internal.common_utils import TestGradients, run_tests, TestCase
import torch
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
OpDTypes,
ops,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
run_tests,
TestCase,
TestGradients,
unMarkDynamoStrictTest,
)
from torch.testing._internal.custom_op_db import custom_op_db
from torch.testing._internal.hop_db import hop_db
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, ops, OpDTypes)
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
# gradcheck requires double precision
_gradcheck_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=[torch.double, torch.cdouble])
_gradcheck_ops = partial(
ops, dtypes=OpDTypes.supported, allowed_dtypes=[torch.double, torch.cdouble]
)
@unMarkDynamoStrictTest
class TestBwdGradients(TestGradients):
@@ -49,16 +59,20 @@ class TestBwdGradients(TestGradients):
result = inplace(sample)
result.sum().backward()
else:
self._grad_test_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()))
self._grad_test_helper(
device, dtype, op, self._get_safe_inplace(op.get_inplace())
)
# Test that gradients of gradients are computed correctly
@_gradcheck_ops(op_db + hop_db + custom_op_db)
def test_fn_gradgrad(self, device, dtype, op):
self._skip_helper(op, device, dtype)
if not op.supports_gradgrad:
self.skipTest("Op claims it doesn't support gradgrad. This is not verified.")
self.skipTest(
"Op claims it doesn't support gradgrad. This is not verified."
)
else:
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")
# Test that gradients of gradients are properly raising
@_gradcheck_ops(op_db + custom_op_db)
@@ -69,7 +83,7 @@ class TestBwdGradients(TestGradients):
err_msg = r"derivative for .* is not implemented"
with self.assertRaisesRegex(RuntimeError, err_msg):
self._check_helper(device, dtype, op, op.get_op(), 'bwgrad_bwgrad')
self._check_helper(device, dtype, op, op.get_op(), "bwgrad_bwgrad")
# Method gradgrad (and grad, see above) tests are disabled since they're
# costly and redundant with function gradgrad (and grad) tests
@@ -83,11 +97,13 @@ class TestBwdGradients(TestGradients):
self._skip_helper(op, device, dtype)
if not op.inplace_variant or not op.supports_inplace_autograd:
self.skipTest("Skipped! Operation does not support inplace autograd.")
self._check_helper(device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad")
self._check_helper(
device, dtype, op, self._get_safe_inplace(op.get_inplace()), "bwgrad_bwgrad"
)
instantiate_device_type_tests(TestBwdGradients, globals())
if __name__ == '__main__':
if __name__ == "__main__":
TestCase._default_dtype_check_enabled = True
run_tests()
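As in the previous file, the merged common_utils import above lists names in case-insensitive alphabetical order (run_tests, TestCase, TestGradients, unMarkDynamoStrictTest), which appears to be the ordering usort enforces. A small sketch reproducing that ordering, with the names taken from the hunk above:

# Sketch: the import-name order above matches a case-insensitive sort.
names = ["TestGradients", "run_tests", "TestCase", "unMarkDynamoStrictTest"]
print(sorted(names, key=str.lower))
# ['run_tests', 'TestCase', 'TestGradients', 'unMarkDynamoStrictTest']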

test/test_ops_jit.py

@@ -6,20 +6,39 @@ from textwrap import dedent
import torch
from torch.testing import FileCheck
from torch.testing._internal.common_utils import \
(run_tests, IS_SANDCASTLE, clone_input_helper, first_sample, TestCase)
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
OpDTypes,
ops,
)
from torch.testing._internal.common_jit import (
check_against_reference,
JitCommonTestCase,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_device_type import instantiate_device_type_tests, ops, OpDTypes
from torch.testing._internal.common_jit import JitCommonTestCase, check_against_reference
from torch.testing._internal.jit_metaprogramming_utils import create_script_fn, create_traced_fn, check_alias_annotation
from torch.testing._internal.jit_utils import disable_autodiff_subgraph_inlining, is_lambda
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
clone_input_helper,
first_sample,
IS_SANDCASTLE,
run_tests,
TestCase,
unMarkDynamoStrictTest,
)
from torch.testing._internal.jit_metaprogramming_utils import (
check_alias_annotation,
create_script_fn,
create_traced_fn,
)
from torch.testing._internal.jit_utils import (
disable_autodiff_subgraph_inlining,
is_lambda,
)
# variant testing is only done with torch.float and torch.cfloat to avoid
# excessive test times and maximize signal to noise ratio
_variant_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float, torch.cfloat))
_variant_ops = partial(
ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat)
)
# Tests operators for consistency between JIT and eager, also checks
@@ -37,17 +56,25 @@ class TestJit(JitCommonTestCase):
# TODO WARNING: inplace x {traced, scripted} not currently tested
@_variant_ops(op_db)
def test_variant_consistency_jit(self, device, dtype, op):
_requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
_requires_grad = dtype in op.supported_backward_dtypes(
torch.device(device).type
)
include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex
samples = op.sample_inputs(device, dtype, requires_grad=_requires_grad, include_conjugated_inputs=include_conjugated_inputs)
samples = op.sample_inputs(
device,
dtype,
requires_grad=_requires_grad,
include_conjugated_inputs=include_conjugated_inputs,
)
# Acquires variants to test
func = op.get_op()
method = op.get_method()
variants = {
# TODO: inplace tests currently fail, fix and add inplace variant
'function': func, 'method': method,
"function": func,
"method": method,
}
# scripting strips the torch.ops prefix from these operators
@@ -57,13 +84,12 @@ class TestJit(JitCommonTestCase):
self.skipTest("variant consistency doesn't work on torch.ops")
# TODO: find better way to standardize on op registration itself..
has_fake_function = op.name in ["resize_", 'resize_as_']
has_fake_function = op.name in ["resize_", "resize_as_"]
if has_fake_function:
variants = {'method': getattr(torch.Tensor, op.name)}
variants = {"method": getattr(torch.Tensor, op.name)}
samples = op.sample_inputs(device, dtype, requires_grad=False)
tested = False
for sample in samples:
# Test traced and scripted consistency
@@ -80,22 +106,30 @@ class TestJit(JitCommonTestCase):
tested = True
try:
self.indiv_variant_test_jit(device, dtype, op, sample, func_type, variant, has_fake_function)
self.indiv_variant_test_jit(
device, dtype, op, sample, func_type, variant, has_fake_function
)
except Exception as e:
variant_error_info = dedent(f"""
variant_error_info = dedent(
f"""
Error testing {op.name} {func_type} variant
with dtype: {dtype}
with inputs {sample}:
""")
"""
)
raise Exception(variant_error_info) from e
assert tested, "JIT Test does not execute any logic"
def indiv_variant_test_jit(self, device, dtype, op, sample, func_type, variant, has_fake_function):
_requires_grad = (dtype in op.supported_backward_dtypes(torch.device(device).type))
def indiv_variant_test_jit(
self, device, dtype, op, sample, func_type, variant, has_fake_function
):
_requires_grad = dtype in op.supported_backward_dtypes(
torch.device(device).type
)
support_script = op.supports_scripting
# Create accessor for script function variant
name = op.name + '_' if func_type == 'inplace' else op.name
name = op.name + "_" if func_type == "inplace" else op.name
# run with disable_autodiff_subgraph_inlining(True) to test
# autodiff support. Context manager forces the graph to contain
@@ -112,16 +146,23 @@ class TestJit(JitCommonTestCase):
return output
def get_sample():
return clone_input_helper(sample.input) if op.name[-1] == '_' else sample.input
return (
clone_input_helper(sample.input)
if op.name[-1] == "_"
else sample.input
)
if support_script:
check_against_reference(self,
script_fn,
op.get_op(),
out_fn,
(get_sample(),) + sample.args,
sample.kwargs,
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
check_against_reference(
self,
script_fn,
op.get_op(),
out_fn,
(get_sample(),) + sample.args,
sample.kwargs,
no_grad=not _requires_grad,
no_gradgrad=not op.supports_gradgrad,
)
# Check traced forward, grad, and grad grad
# TODO: fix tracing here
@@ -131,13 +172,16 @@ class TestJit(JitCommonTestCase):
if supports_tracing:
traced_fn = create_traced_fn(self, variant)
check_against_reference(self,
traced_fn,
op.get_op(),
out_fn,
(get_sample(),) + sample.args,
sample.kwargs,
no_grad=not _requires_grad, no_gradgrad=not op.supports_gradgrad)
check_against_reference(
self,
traced_fn,
op.get_op(),
out_fn,
(get_sample(),) + sample.args,
sample.kwargs,
no_grad=not _requires_grad,
no_gradgrad=not op.supports_gradgrad,
)
# Check alias annotation schema for correctness (make
# sure inputs that aren't supposed to be modified aren't)
@@ -146,8 +190,13 @@ class TestJit(JitCommonTestCase):
if dtype == torch.float32:
# TODO: no reason why we cant run this with tracing graph
if support_script and op.name != "rsub":
check_alias_annotation(name, (get_sample(),) + sample.args, sample.kwargs,
func_type=func_type, aten_name=op.aten_name)
check_alias_annotation(
name,
(get_sample(),) + sample.args,
sample.kwargs,
func_type=func_type,
aten_name=op.aten_name,
)
# TODO: use script graph as well
checked_shape_analysis = False
@@ -156,14 +205,18 @@ class TestJit(JitCommonTestCase):
# right now, tuple of outputs and tensor output supported
# TODO: list of tensor outputs
tuple_of_tensors = isinstance(out, tuple) and all(isinstance(elem, torch.Tensor) for elem in out)
tuple_of_tensors = isinstance(out, tuple) and all(
isinstance(elem, torch.Tensor) for elem in out
)
if isinstance(out, torch.Tensor) or tuple_of_tensors:
if tuple_of_tensors:
sizes = [elem.size() for elem in out]
else:
sizes = out.size()
self.checkShapeAnalysis(sizes, traced_fn.graph, op.assert_jit_shape_analysis)
self.checkShapeAnalysis(
sizes, traced_fn.graph, op.assert_jit_shape_analysis
)
checked_shape_analysis = True
if op.assert_jit_shape_analysis:
self.assertTrue(checked_shape_analysis)
@@ -173,20 +226,31 @@ class TestJit(JitCommonTestCase):
# Sandcastle doesn't fuse nodes
if IS_SANDCASTLE:
# fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs
nonfusible_nodes = op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
nonfusible_nodes = (
op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes
)
fusible_nodes = []
else:
nonfusible_nodes = op.autodiff_nonfusible_nodes
fusible_nodes = op.autodiff_fusible_nodes
if supports_tracing:
self.assertAutodiffNode(traced_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
self.assertAutodiffNode(
traced_fn.last_graph,
op.assert_autodiffed,
nonfusible_nodes,
fusible_nodes,
)
if support_script:
self.assertAutodiffNode(script_fn.last_graph, op.assert_autodiffed, nonfusible_nodes, fusible_nodes)
self.assertAutodiffNode(
script_fn.last_graph,
op.assert_autodiffed,
nonfusible_nodes,
fusible_nodes,
)
# alias testing is only done with torch.float for the same reason
_alias_ops = partial(ops, dtypes=OpDTypes.supported,
allowed_dtypes=(torch.float,))
_alias_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float,))
@_alias_ops(op for op in op_db if op.aliases)
def test_jit_alias_remapping(self, device, dtype, op):
@@ -209,16 +273,18 @@ class TestJit(JitCommonTestCase):
return str(v)
args_kw = args + \
[f"{v}" for v in sample.args] + \
[f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
args_kw = (
args
+ [f"{v}" for v in sample.args]
+ [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()]
)
# Prepare data for test tracing
sample_args_kwargs = ()
if len(sample.args) > 0:
sample_args_kwargs += (sample.args, )
sample_args_kwargs += (sample.args,)
if len(sample.kwargs) > 0:
sample_args_kwargs += (sample.kwargs, )
sample_args_kwargs += (sample.kwargs,)
original_name = op.aten_name
original_name_inplace = original_name + "_"
@@ -227,7 +293,11 @@ class TestJit(JitCommonTestCase):
for a_op in op.aliases:
inplace = a_op.inplace_variant
method_or_inplace = [a_op.inplace_variant, a_op.method_variant]
variants = (v for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) if v is not None)
variants = (
v
for v in (a_op.op, a_op.method_variant, a_op.inplace_variant)
if v is not None
)
# Test scripting:
for variant in variants:
@@ -235,10 +305,10 @@ class TestJit(JitCommonTestCase):
op_name = original_name_inplace if variant is inplace else original_name
if variant in method_or_inplace:
fn_template = '''
fn_template = """
def _fn(t0{c}):
return t0.{alias_name}({args_kw})
'''
"""
# remove the first input tensor
script = fn_template.format(
c=", " if len(args_kw[1:]) > 1 else "",
@@ -246,10 +316,10 @@ class TestJit(JitCommonTestCase):
alias_name=variant_name,
)
else:
fn_template = '''
fn_template = """
def _fn({args}):
return variant({args_kw})
'''
"""
script = fn_template.format(
args=", ".join(args),
args_kw=", ".join(args_kw),
@@ -261,13 +331,15 @@ class TestJit(JitCommonTestCase):
scripted = torch.jit.CompilationUnit(script)._fn
if (variant is inplace and not torch.can_cast(expected_dtype, dtype)):
if variant is inplace and not torch.can_cast(expected_dtype, dtype):
try:
inp = clone_input_helper(sample.input)
scripted(inp)
except Exception as e:
continue
self.fail("Inplace operation on integer tensor that should be promoted to float didn't fail!")
self.fail(
"Inplace operation on integer tensor that should be promoted to float didn't fail!"
)
inp = clone_input_helper(sample.input)
scripted(inp)
@@ -294,6 +366,6 @@ class TestJit(JitCommonTestCase):
instantiate_device_type_tests(TestJit, globals())
if __name__ == '__main__':
if __name__ == "__main__":
TestCase._default_dtype_check_enabled = True
run_tests()
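The test_ops_jit.py hunks above reformat, among other things, the fn_template strings used in test_jit_alias_remapping, where a formatted source string is compiled with torch.jit.CompilationUnit and the resulting _fn is fetched by name. A minimal, hedged sketch of that pattern is below; the alias (absolute for abs) and tensor values are illustrative, not taken from this commit, and it assumes a working torch install.

# Hedged sketch of the CompilationUnit pattern used above; alias and inputs
# are illustrative only.
import torch

fn_template = """
def _fn(t0):
    return t0.{alias_name}()
"""
script = fn_template.format(alias_name="absolute")
scripted = torch.jit.CompilationUnit(script)._fn
print(scripted(torch.tensor([-1.0, 2.0])))  # tensor([1., 2.])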