Enable UFMT on test_decomp.py, test_expanded_weights.py and some files (#125117)

Part of: #123062

Ran lintrunner on:
- test/test_decomp.py
- test/test_deploy.py
- test/test_determination.py
- test/test_dlpack.py
- test/test_dynamic_shapes.py
- test/test_expanded_weights.py

Detail:
```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125117
Approved by: https://github.com/jansel
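For anyone reproducing the check locally, here is a minimal sketch. It assumes `lintrunner` has been installed and initialized for the repo (`pip install lintrunner && lintrunner init`), and it assumes lintrunner accepts explicit file paths in place of `--all-files`; only the last command is taken verbatim from the commit message above.

```bash
# Sketch: run only the UFMT formatter on the files touched by this PR
# (assumes lintrunner was set up once via `pip install lintrunner && lintrunner init`).
lintrunner --take UFMT \
    test/test_decomp.py \
    test/test_deploy.py \
    test/test_determination.py \
    test/test_dlpack.py \
    test/test_dynamic_shapes.py \
    test/test_expanded_weights.py

# Apply the suggested formatting patches in place, as was done for this PR:
lintrunner -a --take UFMT --all-files
```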
Committed by: PyTorch MergeBot
Parent: 48b6c8dbc3
Commit: c165a8e71d

Diff: test/test_decomp.py
@@ -1,45 +1,51 @@
# Owner(s): ["module: decompositions"]

from collections import defaultdict
from torch import Tensor
import torch.autograd
from torch._decomp import core_aten_decompositions, decomposition_table
from torch.utils._python_dispatch import TorchDispatchMode
import functools

from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
import itertools
import re
import unittest
from collections import defaultdict
from functools import partial

import torch.autograd
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    IS_WINDOWS,
    IS_MACOS,
    TestCase,
    skipIfCrossRef,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_SLOW,
    run_tests,
    skipIfTorchDynamo,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_device_type import (
    onlyNativeDeviceTypes,
    ops,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    onlyNativeDeviceTypes,
    ops,
)
from torch.testing._internal.common_methods_invocations import op_db, skip, skipOps, xfail
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing._internal.common_methods_invocations import (
    op_db,
    skip,
    skipOps,
    xfail,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    IS_MACOS,
    IS_WINDOWS,
    run_tests,
    skipIfCrossRef,
    skipIfTorchDynamo,
    suppress_warnings,
    TEST_WITH_ASAN,
    TEST_WITH_SLOW,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode

import itertools
import functools
from functools import partial
import re
import unittest
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

aten = torch.ops.aten

@@ -51,11 +57,13 @@ def overload_to_aten_name(op):

# All operators that can have decomp tests
decomposition_names = {
    overload_to_aten_name(k) for k in decomposition_table
    overload_to_aten_name(k)
    for k in decomposition_table
    if isinstance(k, torch._ops.OpOverload)
}
core_decomposition_names = {
    overload_to_aten_name(k) for k in core_aten_decompositions()
    overload_to_aten_name(k)
    for k in core_aten_decompositions()
    if isinstance(k, torch._ops.OpOverload)
}
_decomp_test_ops = [
@@ -67,12 +75,9 @@ _decomp_test_ops = [
_decomp_test_ops_core_autograd = [
    op
    for op in op_db
    if op.aten_name in core_decomposition_names
    and op.supports_autograd
]
_sdpa_op_info = [
    op for op in op_db if "scaled_dot_product_attention" in op.aten_name
    if op.aten_name in core_decomposition_names and op.supports_autograd
]
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]


def diff_arg(arg, requires_grad=True):
@@ -144,7 +149,10 @@ def ref_vjp_no_create(f, *primals):

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False,
            _as_tuple(result),
            primals,
            _as_tuple(cotangents),
            create_graph=False,
            retain_graph=True,
        )

@@ -230,7 +238,10 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)

def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype, decomp.dtype, f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}")
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
@@ -241,42 +252,48 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
        ),
        (torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
        # This exceeds default tolerances only on CPU, on CUDA it's fine
        (torch.float32, torch.ops.aten.grid_sampler_2d.default) : (7e-6, 3e-5),
        (torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
        # Exceeds tolerances on CUDA, likely due to fma
        (torch.float32, torch.ops.aten.mv.default) : (1e-5, 3e-5),
        (torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
        (torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec) : (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default) : (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
        (torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
        # The decomposition is TOO correct. It computes everything in int64, so sometimes
        # there's an off-by-one error. See
        # https://github.com/pytorch/pytorch/issues/81996
        # https://github.com/pytorch/pytorch/issues/82230
        (torch.int8, torch.ops.aten.linspace.default) : (0, 1),
        (torch.uint8, torch.ops.aten.linspace.default) : (0, 1),
        (torch.int16, torch.ops.aten.linspace.default) : (0, 1),
        (torch.int32, torch.ops.aten.linspace.default) : (0, 1),
        (torch.int64, torch.ops.aten.linspace.default) : (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
        (torch.int8, torch.ops.aten.linspace.default): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.default): (0, 1),
        (torch.int16, torch.ops.aten.linspace.default): (0, 1),
        (torch.int32, torch.ops.aten.linspace.default): (0, 1),
        (torch.int64, torch.ops.aten.linspace.default): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
        (torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
        (torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
    }
    if (decomp.dtype, op) in tol_table:
        rtol, atol = tol_table[(decomp.dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
    test_case.assertEqual(orig, decomp, rtol=rtol, atol=atol, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}")
    test_case.assertEqual(
        orig,
        decomp,
        rtol=rtol,
        atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )


# Given f, returns an f' such that:
@@ -322,8 +339,11 @@ def normalize_op_input_output2(
def upcast_tensor(x, dtype=torch.float32):
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif (isinstance(x, torch.dtype)
          and x in [torch.float16, torch.bfloat16, torch.float]):
    elif isinstance(x, torch.dtype) and x in [
        torch.float16,
        torch.bfloat16,
        torch.float,
    ]:
        return dtype
    else:
        return x
@@ -352,20 +372,16 @@ CROSS_REF_EXCLUDE_SET = {
    (None, None, "new_empty"),
    (None, None, "empty_like"),
    (None, None, "empty"),

    # AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
    (None, None, "item"),

    # It's the only in-place op without an out-of-place equivalent in the Python API
    # Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
    (None, None, "zero_"),

    # No idea what's going on here
    # In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
    # in the test, but it seems to pass when tested locally and in the logsumexp test
    (None, torch.float32, "masked.logsumexp"),
    (None, torch.float64, "masked.logsumexp"),

    # exp_vml_cpu not implemented for Half
    (torch.cpu, torch.float16, "signal.windows.exponential"),
    (torch.cpu, torch.float16, "signal.windows.gaussian"),
@@ -387,9 +403,7 @@ CROSS_REF_EXCLUDE_SET = {
    (None, None, "norm"),
    # native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
    (None, None, "native_batch_norm"),

    (None, None, "_upsample_bilinear2d_aa"),

    (None, None, "empty_strided"), # aten.empty_strided was not decomposed
}

@@ -432,10 +446,16 @@ def any_unsupported(args, kwargs):
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            # These are all things that we haven't coded decompositions
            # to handle correctly. Maybe they should.
            return any([
                t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized,
                t.is_nested, torch._is_functional_tensor(t),
            ])
            return any(
                [
                    t.is_sparse_csr,
                    t.is_sparse,
                    t.is_mkldnn,
                    t.is_quantized,
                    t.is_nested,
                    torch._is_functional_tensor(t),
                ]
            )
        elif torch.overrides.is_tensor_like(t):
            # Decompositions will generally change the behavior of Tensor-like
            # subclasses, so bypass tests in this case too
@@ -448,59 +468,68 @@ def any_unsupported(args, kwargs):


core_backward_failures = {
    skip('_softmax_backward_data'), # slow: fails with --timeout=360 secs
    xfail('addcdiv'),
    skip('addcmul'), # slow: fails with --timeout=360 secs
    skip('deg2rad'), # slow: fails with --timeout=360 secs
    skip('diag_embed'), # slow: fails with --timeout=360 secs
    skip('frac'), # slow: fails with --timeout=360 secs
    skip('grid_sampler_2d'), # slow: fails with --timeout=360 secs
    xfail('lerp'),
    skip('logaddexp'), # slow: fails with --timeout=360 secs
    skip('native_dropout_backward'), # slow: fails with --timeout=360 secs
    xfail('nn.functional.binary_cross_entropy_with_logits'),
    skip('nn.functional.glu'), # slow: fails with --timeout=360 secs
    xfail('nn.functional.hardshrink'),
    xfail('nn.functional.softshrink'),
    skip('nn.functional.unfold'), # slow: fails with --timeout=360 secs
    xfail('norm'),
    xfail('norm', 'fro'),
    xfail('norm', 'inf'),
    xfail('norm', 'nuc'),
    skip('rad2deg'), # slow: fails with --timeout=360 secs
    skip('renorm'), # slow: fails with --timeout=360 secs
    skip('rot90'), # slow: fails with --timeout=360 secs
    skip('rsub'), # slow: fails with --timeout=360 secs
    skip('sgn'), # slow: fails with --timeout=360 secs
    skip('special.xlog1py'), # slow: fails with --timeout=360 secs
    xfail('stack'),
    skip('tril'), # slow: fails with --timeout=360 secs
    skip('triu'), # slow: fails with --timeout=360 secs
    skip('unfold_copy'), # slow: fails with --timeout=360 secs
    skip('xlogy'), # slow: fails with --timeout=360 secs
    xfail('zero_'),
    skip("_softmax_backward_data"), # slow: fails with --timeout=360 secs
    xfail("addcdiv"),
    skip("addcmul"), # slow: fails with --timeout=360 secs
    skip("deg2rad"), # slow: fails with --timeout=360 secs
    skip("diag_embed"), # slow: fails with --timeout=360 secs
    skip("frac"), # slow: fails with --timeout=360 secs
    skip("grid_sampler_2d"), # slow: fails with --timeout=360 secs
    xfail("lerp"),
    skip("logaddexp"), # slow: fails with --timeout=360 secs
    skip("native_dropout_backward"), # slow: fails with --timeout=360 secs
    xfail("nn.functional.binary_cross_entropy_with_logits"),
    skip("nn.functional.glu"), # slow: fails with --timeout=360 secs
    xfail("nn.functional.hardshrink"),
    xfail("nn.functional.softshrink"),
    skip("nn.functional.unfold"), # slow: fails with --timeout=360 secs
    xfail("norm"),
    xfail("norm", "fro"),
    xfail("norm", "inf"),
    xfail("norm", "nuc"),
    skip("rad2deg"), # slow: fails with --timeout=360 secs
    skip("renorm"), # slow: fails with --timeout=360 secs
    skip("rot90"), # slow: fails with --timeout=360 secs
    skip("rsub"), # slow: fails with --timeout=360 secs
    skip("sgn"), # slow: fails with --timeout=360 secs
    skip("special.xlog1py"), # slow: fails with --timeout=360 secs
    xfail("stack"),
    skip("tril"), # slow: fails with --timeout=360 secs
    skip("triu"), # slow: fails with --timeout=360 secs
    skip("unfold_copy"), # slow: fails with --timeout=360 secs
    skip("xlogy"), # slow: fails with --timeout=360 secs
    xfail("zero_"),
}
if not TEST_WITH_SLOW:
    core_backward_failures.update({
        skip('addr'), # slow: takes 46 sec on A100
        skip('baddbmm'), # slow: takes 800+ sec on A100
        skip('clamp_min'), # slow: takes 800 sec on A100
        skip('clamp_max'), # slow: takes 800 sec on A100
        skip('logit'), # slow: takes 44 sec on A100
        skip('nn.functional.hardswish'), # slow: takes 60 sec on A100
        skip('std_mean'), # slow: takes 170 sec on A100
        skip('split', variant_name='list_args'), # slow: takes 118 sec on A100
        skip('transpose'), # slow: takes 50 sec on A100
        skip('unbind'), # slow: takes 70 sec on A100
        skip('unsafe_split'), # slow: takes 49 sec on A100
    })
    core_backward_failures.update(
        {
            skip("addr"), # slow: takes 46 sec on A100
            skip("baddbmm"), # slow: takes 800+ sec on A100
            skip("clamp_min"), # slow: takes 800 sec on A100
            skip("clamp_max"), # slow: takes 800 sec on A100
            skip("logit"), # slow: takes 44 sec on A100
            skip("nn.functional.hardswish"), # slow: takes 60 sec on A100
            skip("std_mean"), # slow: takes 170 sec on A100
            skip("split", variant_name="list_args"), # slow: takes 118 sec on A100
            skip("transpose"), # slow: takes 50 sec on A100
            skip("unbind"), # slow: takes 70 sec on A100
            skip("unsafe_split"), # slow: takes 49 sec on A100
        }
    )

comprehensive_failures = {
    xfail("nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)), # off by one error
    xfail("nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)), # off by one error
    xfail("nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)), # off by one error
    xfail(
        "nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
    ), # off by one error
    xfail(
        "nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
    ), # off by one error
    xfail(
        "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
    ), # off by one error
}


@unMarkDynamoStrictTest
class TestDecomp(TestCase):
    longMessage = True
@@ -517,7 +546,7 @@ class TestDecomp(TestCase):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @skipOps('TestDecomp', 'test_quick_core_backward', core_backward_failures)
    @skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
@@ -528,15 +557,16 @@ class TestDecomp(TestCase):
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            func = partial(op.get_op(), **kwargs)
            with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all=False)\
                    as mode, enable_python_dispatcher():
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=False
            ) as mode, enable_python_dispatcher():
                torch.autograd.gradcheck(func, args)
            self.check_decomposed(aten_name, mode)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @skipOps('TestDecomp', 'test_comprehensive', comprehensive_failures)
    @skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
@@ -560,7 +590,9 @@ class TestDecomp(TestCase):
        xs = torch.ones([2, 10], device=device)

        def index_copy(xs, x):
            torch._decomp.decompositions.index_copy_(xs, 0, torch.tensor(0).to(device), x)
            torch._decomp.decompositions.index_copy_(
                xs, 0, torch.tensor(0).to(device), x
            )

        index_copy(xs, x)
@@ -574,7 +606,9 @@ class TestDecomp(TestCase):
        # are <= 0, and b) whether we're in training mode. Cover all cases:
        dtype = torch.float64
        x = torch.tensor(
            [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device,
            [-3.0, -2.0, -1.0, 0.0, 1.0, 2.0],
            dtype=dtype,
            device=device,
        )
        lower = 1.0
        upper = 4.0
@@ -587,7 +621,11 @@ class TestDecomp(TestCase):
        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x, noise_res, lower, upper, training,
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)
@@ -602,30 +640,51 @@ class TestDecomp(TestCase):
        torch.manual_seed(123)
        noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
        res = torch._decomp.decompositions.rrelu_with_noise(
            x, noise_res, lower, upper, training,
            x,
            noise_res,
            lower,
            upper,
            training,
        )
        self.assertEqual(ref, res)
        self.assertEqual(noise_ref, noise_res)


    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @suppress_warnings
    @tf32_off()
    # only tests RNNs since we have py dispsatcher decomps for them
    @modules(filter(lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
    @modules(
        filter(
            lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
            module_db,
        )
    )
    def test_rnn_decomp_module(self, device, dtype, module_info, training):
        module_cls = module_info.module_cls
        module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
                                                       requires_grad=True, training=training)
        module_inputs = module_info.module_inputs_func(
            module_info,
            device=device,
            dtype=dtype,
            requires_grad=True,
            training=training,
        )
        for module_input in module_inputs:
            if module_input.forward_input is None:
                continue
            args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
            args, kwargs = (
                module_input.constructor_input.args,
                module_input.constructor_input.kwargs,
            )
            m = module_cls(*args, **kwargs)
            m.to(device).to(dtype)

            args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all=True), enable_python_dispatcher():
            args, kwargs = (
                module_input.forward_input.args,
                module_input.forward_input.kwargs,
            )
            with self.DecompCrossRefMode(
                self, self.precision, self.rel_tol, dtype, run_all=True
            ), enable_python_dispatcher():
                decomp_out = m(*args, **kwargs)

            non_decomp_out = m(*args, **kwargs)
@@ -641,7 +700,9 @@ class TestDecomp(TestCase):
        bias = torch.randn(3, device=device)
        mean = torch.randn(3, device=device)
        var = torch.randn(3, device=device)
        res = torch._decomp.decompositions.native_batch_norm(input, weight, bias, mean, var, False, 1, 1e-05)
        res = torch._decomp.decompositions.native_batch_norm(
            input, weight, bias, mean, var, False, 1, 1e-05
        )
        self.assertEqual(shape, res[0].shape)

    def test_arange_graph(self, device):
@@ -662,29 +723,40 @@ class TestDecomp(TestCase):
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(fx_g_code, """\
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1); iota = None
    add = torch.ops.prims.add.default(mul, 0); mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
    return convert_element_type""")
    return convert_element_type""",
        )

        fx_g = cfunc(torch.rand(10, device=device), 1)
        fx_g_code = fx_g.code.strip()
        # Remove device and requires_grad
        fx_g_code = re.sub(pattern, "", fx_g_code)
        self.assertExpectedInline(fx_g_code, """\
        self.assertExpectedInline(
            fx_g_code,
            """\
def forward(self, x_1, start_1):
    iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
    mul = torch.ops.prims.mul.default(iota, 1); iota = None
    add = torch.ops.prims.add.default(mul, 1); mul = None
    convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
    return convert_element_type""")
    return convert_element_type""",
        )

    def test_masked_fill(self, device):
        from torch.fx.experimental.proxy_tensor import make_fx
        if torch.device(device).type not in ["xpu", "cuda", torch._C._get_privateuse1_backend_name()]:

        if torch.device(device).type not in [
            "xpu",
            "cuda",
            torch._C._get_privateuse1_backend_name(),
        ]:
            self.skipTest("only runs on XPU and CUDA and PrivateUse1.")

        def func(scores, mask, value):
@@ -695,10 +767,13 @@ def forward(self, x_1, start_1):
        value_t = torch.tensor(0, dtype=scores_t.dtype)
        cfunc = make_fx(func, decomposition_table=decomposition_table)
        fx_g = cfunc(scores_t, mask_t, value_t)
        self.assertExpectedInline(fx_g.code.strip(), """\
        self.assertExpectedInline(
            fx_g.code.strip(),
            """\
def forward(self, scores_1, mask_1, value_1):
    where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None
    return where""")
    return where""",
        )

    class DecompCrossRefMode(TorchDispatchMode):
        def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
@@ -724,7 +799,7 @@ def forward(self, scores_1, mask_1, value_1):
            # Stuff we shouldn't bother testing
            # (TODO: remove detach from the decomp table?)
            # N.b. Testing in-place ops would need dedicated logic
            in_place = func.name()[-1] == '_'
            in_place = func.name()[-1] == "_"
            ignored_ops = [
                torch.ops.aten.detach.default,
                # non-deterministic ops
@@ -737,11 +812,11 @@ def forward(self, scores_1, mask_1, value_1):
                torch.ops.aten.native_dropout.default,
            ]
            if (
                func not in decomposition_table or
                func in ignored_ops or
                torch.Tag.nondeterministic_seeded in func.tags or
                any_unsupported(args, kwargs) or
                in_place
                func not in decomposition_table
                or func in ignored_ops
                or torch.Tag.nondeterministic_seeded in func.tags
                or any_unsupported(args, kwargs)
                or in_place
            ):
                return func(*args, **kwargs)

@@ -789,29 +864,51 @@ def forward(self, scores_1, mask_1, value_1):
                real_out_double, _ = tree_flatten(
                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                )
                for i, (orig, decomp, ref) in enumerate(zip(real_out, decomp_out, real_out_double)):
                for i, (orig, decomp, ref) in enumerate(
                    zip(real_out, decomp_out, real_out_double)
                ):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_ref(self.test_case, func, self.test_dtype, i, orig, decomp, ref, args, kwargs)
                    op_assert_ref(
                        self.test_case,
                        func,
                        self.test_dtype,
                        i,
                        orig,
                        decomp,
                        ref,
                        args,
                        kwargs,
                    )
            else:
                for orig, decomp in zip(real_out, decomp_out):
                    if not isinstance(orig, torch.Tensor):
                        assert type(orig) == type(decomp)
                        assert orig == decomp
                        continue
                    op_assert_equal(self.test_case, func, self.test_dtype, orig, decomp, args, kwargs)
                    op_assert_equal(
                        self.test_case,
                        func,
                        self.test_dtype,
                        orig,
                        decomp,
                        args,
                        kwargs,
                    )

            return real_out_unflat

    def check_decomposed(self, aten_name, mode):
        self.assertTrue(
            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
            msg=(f"aten.{aten_name} was not decomposed, saw calls for: "
                 f"{', '.join(map(str, list(mode.called)))}. If your op is "
                 f"CompositeImplicitAutograd you should skip this test "
                 f"by updating CROSS_REF_EXCLUDE_SET.")
            msg=(
                f"aten.{aten_name} was not decomposed, saw calls for: "
                f"{', '.join(map(str, list(mode.called)))}. If your op is "
                f"CompositeImplicitAutograd you should skip this test "
                f"by updating CROSS_REF_EXCLUDE_SET."
            ),
        )

    @skipIfTorchDynamo("Test does not work with TorchDynamo")
@@ -824,7 +921,9 @@ def forward(self, scores_1, mask_1, value_1):
        if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
            self.skipTest(f"{op.name} in {dtype} not supported")

        skip_decomp_vjp = any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys)
        skip_decomp_vjp = any(
            key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
        )

        requires_grad = (
            op.supports_autograd
@@ -842,9 +941,13 @@ def forward(self, scores_1, mask_1, value_1):
        func = op.get_op()

        def run_without_python_dispatcher(mode):
            return any(isinstance(op, torch._ops.OpOverload) and
                       op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
                       for op in mode.decomposed.union([func]))
            return any(
                isinstance(op, torch._ops.OpOverload)
                and op.has_kernel_for_dispatch_key(
                    DispatchKey.CompositeImplicitAutograd
                )
                for op in mode.decomposed.union([func])
            )

        for sample_input in samples:
            if requires_grad:
@@ -857,29 +960,35 @@ def forward(self, scores_1, mask_1, value_1):
                # store the called list on the mode object instance and no
                # explicit clearing is necessary as I will create a fresh mode
                # for each region
                with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                        as mode, enable_python_dispatcher():
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                            as mode:
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    self.check_decomposed(aten_name, mode)

                if not skip_decomp_vjp and (op.aten_backward_name in decomposition_names or run_all):
                if not skip_decomp_vjp and (
                    op.aten_backward_name in decomposition_names or run_all
                ):
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                            as mode, enable_python_dispatcher():
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode, enable_python_dispatcher():
                        decomp_vjp_fn(cotangents)
                    if run_without_python_dispatcher(mode):
                        # without this check, incorrect decomps at the python dispatcher level can still pass because
                        # they're checking aten decomps at the torch_dispatch level.
                        with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                                as mode:
                        with self.DecompCrossRefMode(
                            self, self.precision, self.rel_tol, dtype, run_all
                        ) as mode:
                            decomp_vjp_fn(cotangents)
                    if not run_all:
                        self.check_decomposed(op.aten_backward_name, mode)
@@ -889,15 +998,17 @@ def forward(self, scores_1, mask_1, value_1):
                kwargs = sample_input.kwargs
                # A failure here might be because the decomposition for the op is wrong or because a
                # decomposition used by the particular op is wrong.
                with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                        as mode, enable_python_dispatcher():
                with self.DecompCrossRefMode(
                    self, self.precision, self.rel_tol, dtype, run_all
                ) as mode, enable_python_dispatcher():
                    func(*args, **kwargs)

                if run_without_python_dispatcher(mode):
                    # without this check, incorrect decomps at the python dispatcher level can still pass because
                    # they're checking aten decomps at the torch_dispatch level.
                    with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
                            as mode:
                    with self.DecompCrossRefMode(
                        self, self.precision, self.rel_tol, dtype, run_all
                    ) as mode:
                        func(*args, **kwargs)

                if not run_all:
@@ -908,6 +1019,7 @@ def forward(self, scores_1, mask_1, value_1):
                    "only backwards is decomposed, but dtype doesn't support AD"
                )


instantiate_device_type_tests(TestDecomp, globals())

@@ -964,7 +1076,8 @@ class DecompOneOffTests(TestCase):
            mean,
            False,
            1e-05,
            [True, True, True])
            [True, True, True],
        )
        res = torch._decomp.decompositions.native_batch_norm_backward(
            grad_out,
            x,
@@ -975,12 +1088,12 @@ class DecompOneOffTests(TestCase):
            mean,
            False,
            1e-05,
            [True, True, True])
        for (a, b) in zip(ref, res):
            [True, True, True],
        )
        for a, b in zip(ref, res):
            self.assertEqual(a.stride(), b.stride())
            self.assertEqual(a.dtype, b.dtype)


    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
@@ -1022,15 +1135,22 @@ class DecompOneOffTests(TestCase):

        self.assertEqual(
            torch.ops.aten._weight_norm_interface(inp, inp2),
            torch._decomp.decompositions._weight_norm_interface(inp, inp2)
            torch._decomp.decompositions._weight_norm_interface(inp, inp2),
        )

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCPU
    @skipIfCrossRef
    @skipOps('DecompOneOffTests', 'test_sdpa', [
        xfail("nn.functional.scaled_dot_product_attention", dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else [])),
    ])
    @skipOps(
        "DecompOneOffTests",
        "test_sdpa",
        [
            xfail(
                "nn.functional.scaled_dot_product_attention",
                dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else []),
            ),
        ],
    )
    @ops(_sdpa_op_info)
    def test_sdpa(self, device, dtype, op):
        # SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
@@ -1040,13 +1160,19 @@ class DecompOneOffTests(TestCase):
            def __init__(self):
                super().__init__()

            def forward(self, query_layer, key_layer, value_layer, mask=None, is_causal=True):
            def forward(
                self, query_layer, key_layer, value_layer, mask=None, is_causal=True
            ):
                attn_output = op(
                    query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
                    query_layer,
                    key_layer,
                    value_layer,
                    attn_mask=mask,
                    dropout_p=0.0,
                    is_causal=is_causal,
                )
                return attn_output


        query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
        value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
@@ -1057,20 +1183,28 @@ class DecompOneOffTests(TestCase):
        for mask in masks:
            is_causal = mask is None
            attention = ScaledDotProductAttention()
            decomposed_res = torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
            decomposed_res = (
                torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
                    query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
                )
            )
            eager_res = op(
                query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
                query_layer,
                key_layer,
                value_layer,
                attn_mask=mask,
                dropout_p=0.0,
                is_causal=is_causal,
            )

            self.assertTrue(torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol))
            self.assertTrue(
                torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
            )


instantiate_device_type_tests(DecompOneOffTests, globals())



class HasDecompTest(TestCase):
    def setUp(self):
        super().setUp()
@@ -1080,22 +1214,24 @@ class HasDecompTest(TestCase):
    def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
        has_tensor_arg = any(
            "Tensor" in str(a.type)
            for a in itertools.chain(op._schema.arguments, op._schema.returns))
            for a in itertools.chain(op._schema.arguments, op._schema.returns)
        )
        if not has_tensor_arg:
            return False

        try:
            # CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
            return not op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
            return not op.has_kernel_for_dispatch_key(
                DispatchKey.CompositeImplicitAutograd
            )
        except RuntimeError as e:
            # has_key fails for some jit-registered ops, which shouldn't be
            # relevant here anyway
            if 'does not exist' in str(e):
            if "does not exist" in str(e):
                return False
            raise

    def test_has_decomposition(self):

        def all_aten_overloads():
            for name in torch._C._dispatch_get_all_op_names():
                if not name.startswith("aten::"):
@@ -1116,11 +1252,14 @@ class HasDecompTest(TestCase):
        # configurations, so would cause the test to fail
        allow_list = {aten.get_gradients.default}

        overloads_wanting_decomp = {op for op in all_aten_overloads()
                                    if self._can_appear_in_trace(op)}
        overloads_wanting_decomp = {
            op for op in all_aten_overloads() if self._can_appear_in_trace(op)
        }
        ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
        ops_missing_decomp -= allow_list
        self.assertExpected("".join(sorted(op.name() + "\n" for op in ops_missing_decomp)))
        self.assertExpected(
            "".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
        )

    def test_aten_core_operators(self):
        # If a decomposition isn't included in the core decompositions,
@@ -1136,9 +1275,11 @@ class HasDecompTest(TestCase):

        # Some decompositions are registered for CompositeImplicitAutograd
        # operators, which never appear in AOTAutograd's graph so are never used.
        useful_decomps = {op for op in decomposition_table.keys()
                          if isinstance(op, torch._ops.OpOverload) and
                          self._can_appear_in_trace(op)}
        useful_decomps = {
            op
            for op in decomposition_table.keys()
            if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
        }
        core_decomps = torch._decomp.core_aten_decompositions().keys()
        core_aten_ops = useful_decomps - core_decomps
        self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))