Enable UFMT on test_decomp.py, test_expanded_weights.py and some files (#125117)

Part of: #123062

Ran lintrunner on:

- test/test_decomp.py
- test/test_deploy.py
- test/test_determination.py
- test/test_dlpack.py
- test/test_dynamic_shapes.py
- test/test_expanded_weights.py

Detail:

```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125117
Approved by: https://github.com/jansel
This commit is contained in:
Yuanhao Ji
2024-05-07 02:36:36 +00:00
committed by PyTorch MergeBot
parent 48b6c8dbc3
commit c165a8e71d
7 changed files with 1211 additions and 511 deletions

View File

@ -1052,12 +1052,6 @@ exclude_patterns = [
'test/quantization/fx/test_quantize_fx.py',
'test/quantization/fx/test_subgraph_rewriter.py',
'test/test_datapipe.py',
'test/test_decomp.py',
'test/test_deploy.py',
'test/test_determination.py',
'test/test_dlpack.py',
'test/test_dynamic_shapes.py',
'test/test_expanded_weights.py',
'test/test_fake_tensor.py',
'test/test_flop_counter.py',
'test/test_function_schema.py',

View File

@ -1,45 +1,51 @@
# Owner(s): ["module: decompositions"]
from collections import defaultdict
from torch import Tensor
import torch.autograd
from torch._decomp import core_aten_decompositions, decomposition_table
from torch.utils._python_dispatch import TorchDispatchMode
import functools
from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils import _pytree as pytree
import itertools
import re
import unittest
from collections import defaultdict
from functools import partial
import torch.autograd
from torch import Tensor
from torch._decomp import core_aten_decompositions, decomposition_table
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_off
from torch.testing._internal.common_utils import unMarkDynamoStrictTest
from torch.testing._internal.common_utils import (
is_iterable_of_tensors,
IS_WINDOWS,
IS_MACOS,
TestCase,
skipIfCrossRef,
suppress_warnings,
TEST_WITH_ASAN,
TEST_WITH_SLOW,
run_tests,
skipIfTorchDynamo,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_device_type import (
onlyNativeDeviceTypes,
ops,
instantiate_device_type_tests,
onlyCPU,
onlyCUDA,
onlyNativeDeviceTypes,
ops,
)
from torch.testing._internal.common_methods_invocations import op_db, skip, skipOps, xfail
from torch._dispatch.python import enable_python_dispatcher
from torch._ops import DispatchKey
from torch.testing._internal.common_methods_invocations import (
op_db,
skip,
skipOps,
xfail,
)
from torch.testing._internal.common_modules import module_db, modules
from torch.testing._internal.common_utils import (
is_iterable_of_tensors,
IS_MACOS,
IS_WINDOWS,
run_tests,
skipIfCrossRef,
skipIfTorchDynamo,
suppress_warnings,
TEST_WITH_ASAN,
TEST_WITH_SLOW,
TestCase,
unMarkDynamoStrictTest,
)
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
import itertools
import functools
from functools import partial
import re
import unittest
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
aten = torch.ops.aten
@ -51,11 +57,13 @@ def overload_to_aten_name(op):
# All operators that can have decomp tests
decomposition_names = {
overload_to_aten_name(k) for k in decomposition_table
overload_to_aten_name(k)
for k in decomposition_table
if isinstance(k, torch._ops.OpOverload)
}
core_decomposition_names = {
overload_to_aten_name(k) for k in core_aten_decompositions()
overload_to_aten_name(k)
for k in core_aten_decompositions()
if isinstance(k, torch._ops.OpOverload)
}
_decomp_test_ops = [
@ -67,12 +75,9 @@ _decomp_test_ops = [
_decomp_test_ops_core_autograd = [
op
for op in op_db
if op.aten_name in core_decomposition_names
and op.supports_autograd
]
_sdpa_op_info = [
op for op in op_db if "scaled_dot_product_attention" in op.aten_name
if op.aten_name in core_decomposition_names and op.supports_autograd
]
_sdpa_op_info = [op for op in op_db if "scaled_dot_product_attention" in op.aten_name]
def diff_arg(arg, requires_grad=True):
@ -144,7 +149,10 @@ def ref_vjp_no_create(f, *primals):
def wrapped(cotangents):
return _autograd_grad(
_as_tuple(result), primals, _as_tuple(cotangents), create_graph=False,
_as_tuple(result),
primals,
_as_tuple(cotangents),
create_graph=False,
retain_graph=True,
)
@ -230,7 +238,10 @@ def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs)
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
test_case.assertEqual(
orig.dtype, decomp.dtype, f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}")
orig.dtype,
decomp.dtype,
f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
)
# Before adding an entry to this table, make sure your decomposition is right :)
tol_table = {
# Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
@ -241,42 +252,48 @@ def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
),
(torch.float64, torch.ops.aten.native_layer_norm.default): (1e-6, 1e-6),
# This exceeds default tolerances only on CPU, on CUDA it's fine
(torch.float32, torch.ops.aten.grid_sampler_2d.default) : (7e-6, 3e-5),
(torch.float32, torch.ops.aten.grid_sampler_2d.default): (7e-6, 3e-5),
# Exceeds tolerances on CUDA, likely due to fma
(torch.float32, torch.ops.aten.mv.default) : (1e-5, 3e-5),
(torch.float32, torch.ops.aten.mv.default): (1e-5, 3e-5),
(torch.complex64, torch.ops.aten.mv.default): (5e-5, 5e-5),
(torch.float64, torch.ops.aten.upsample_bicubic2d.vec) : (1e-5, 5e-4),
(torch.float64, torch.ops.aten.upsample_bicubic2d.default) : (1e-5, 5e-4),
(torch.float64, torch.ops.aten.upsample_bicubic2d.vec): (1e-5, 5e-4),
(torch.float64, torch.ops.aten.upsample_bicubic2d.default): (1e-5, 5e-4),
# The decomposition is TOO correct. It computes everything in int64, so sometimes
# there's an off-by-one error. See
# https://github.com/pytorch/pytorch/issues/81996
# https://github.com/pytorch/pytorch/issues/82230
(torch.int8, torch.ops.aten.linspace.default) : (0, 1),
(torch.uint8, torch.ops.aten.linspace.default) : (0, 1),
(torch.int16, torch.ops.aten.linspace.default) : (0, 1),
(torch.int32, torch.ops.aten.linspace.default) : (0, 1),
(torch.int64, torch.ops.aten.linspace.default) : (0, 1),
(torch.int8, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
(torch.uint8, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
(torch.int16, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
(torch.int32, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
(torch.int64, torch.ops.aten.linspace.Tensor_Tensor) : (0, 1),
(torch.int8, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
(torch.uint8, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
(torch.int16, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
(torch.int32, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
(torch.int64, torch.ops.aten.linspace.Tensor_Scalar) : (0, 1),
(torch.int8, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
(torch.uint8, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
(torch.int16, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
(torch.int32, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
(torch.int64, torch.ops.aten.linspace.Scalar_Tensor) : (0, 1),
(torch.int8, torch.ops.aten.linspace.default): (0, 1),
(torch.uint8, torch.ops.aten.linspace.default): (0, 1),
(torch.int16, torch.ops.aten.linspace.default): (0, 1),
(torch.int32, torch.ops.aten.linspace.default): (0, 1),
(torch.int64, torch.ops.aten.linspace.default): (0, 1),
(torch.int8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
(torch.uint8, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
(torch.int16, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
(torch.int32, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
(torch.int64, torch.ops.aten.linspace.Tensor_Tensor): (0, 1),
(torch.int8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
(torch.uint8, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
(torch.int16, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
(torch.int32, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
(torch.int64, torch.ops.aten.linspace.Tensor_Scalar): (0, 1),
(torch.int8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
(torch.uint8, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
(torch.int16, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
(torch.int32, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
(torch.int64, torch.ops.aten.linspace.Scalar_Tensor): (0, 1),
}
if (decomp.dtype, op) in tol_table:
rtol, atol = tol_table[(decomp.dtype, op)]
else:
rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)
test_case.assertEqual(orig, decomp, rtol=rtol, atol=atol, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}")
test_case.assertEqual(
orig,
decomp,
rtol=rtol,
atol=atol,
msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
)
# Given f, returns an f' such that:
@ -322,8 +339,11 @@ def normalize_op_input_output2(
def upcast_tensor(x, dtype=torch.float32):
if isinstance(x, Tensor) and x.dtype.is_floating_point:
return x.to(dtype=dtype)
elif (isinstance(x, torch.dtype)
and x in [torch.float16, torch.bfloat16, torch.float]):
elif isinstance(x, torch.dtype) and x in [
torch.float16,
torch.bfloat16,
torch.float,
]:
return dtype
else:
return x
@ -352,20 +372,16 @@ CROSS_REF_EXCLUDE_SET = {
(None, None, "new_empty"),
(None, None, "empty_like"),
(None, None, "empty"),
# AssertionError: False is not true : aten.item was not decomposed, saw calls for: aten._local_scalar_dense.default.
(None, None, "item"),
# It's the only in-place op without an out-of-place equivalent in the Python API
# Its OpInfo wrongly registers it as `torch.zero_(x.clone())`.
(None, None, "zero_"),
# No idea what's going on here
# In the recursive test logsumexp.default fails with args = (torch.tensor(-math.inf), [])
# in the test, but it seems to pass when tested locally and in the logsumexp test
(None, torch.float32, "masked.logsumexp"),
(None, torch.float64, "masked.logsumexp"),
# exp_vml_cpu not implemented for Half
(torch.cpu, torch.float16, "signal.windows.exponential"),
(torch.cpu, torch.float16, "signal.windows.gaussian"),
@ -387,9 +403,7 @@ CROSS_REF_EXCLUDE_SET = {
(None, None, "norm"),
# native_batch_norm is only implicit when python dispatcher is on (and noncomposite otherwise)
(None, None, "native_batch_norm"),
(None, None, "_upsample_bilinear2d_aa"),
(None, None, "empty_strided"), # aten.empty_strided was not decomposed
}
@ -432,10 +446,16 @@ def any_unsupported(args, kwargs):
if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
# These are all things that we haven't coded decompositions
# to handle correctly. Maybe they should.
return any([
t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized,
t.is_nested, torch._is_functional_tensor(t),
])
return any(
[
t.is_sparse_csr,
t.is_sparse,
t.is_mkldnn,
t.is_quantized,
t.is_nested,
torch._is_functional_tensor(t),
]
)
elif torch.overrides.is_tensor_like(t):
# Decompositions will generally change the behavior of Tensor-like
# subclasses, so bypass tests in this case too
@ -448,59 +468,68 @@ def any_unsupported(args, kwargs):
core_backward_failures = {
skip('_softmax_backward_data'), # slow: fails with --timeout=360 secs
xfail('addcdiv'),
skip('addcmul'), # slow: fails with --timeout=360 secs
skip('deg2rad'), # slow: fails with --timeout=360 secs
skip('diag_embed'), # slow: fails with --timeout=360 secs
skip('frac'), # slow: fails with --timeout=360 secs
skip('grid_sampler_2d'), # slow: fails with --timeout=360 secs
xfail('lerp'),
skip('logaddexp'), # slow: fails with --timeout=360 secs
skip('native_dropout_backward'), # slow: fails with --timeout=360 secs
xfail('nn.functional.binary_cross_entropy_with_logits'),
skip('nn.functional.glu'), # slow: fails with --timeout=360 secs
xfail('nn.functional.hardshrink'),
xfail('nn.functional.softshrink'),
skip('nn.functional.unfold'), # slow: fails with --timeout=360 secs
xfail('norm'),
xfail('norm', 'fro'),
xfail('norm', 'inf'),
xfail('norm', 'nuc'),
skip('rad2deg'), # slow: fails with --timeout=360 secs
skip('renorm'), # slow: fails with --timeout=360 secs
skip('rot90'), # slow: fails with --timeout=360 secs
skip('rsub'), # slow: fails with --timeout=360 secs
skip('sgn'), # slow: fails with --timeout=360 secs
skip('special.xlog1py'), # slow: fails with --timeout=360 secs
xfail('stack'),
skip('tril'), # slow: fails with --timeout=360 secs
skip('triu'), # slow: fails with --timeout=360 secs
skip('unfold_copy'), # slow: fails with --timeout=360 secs
skip('xlogy'), # slow: fails with --timeout=360 secs
xfail('zero_'),
skip("_softmax_backward_data"), # slow: fails with --timeout=360 secs
xfail("addcdiv"),
skip("addcmul"), # slow: fails with --timeout=360 secs
skip("deg2rad"), # slow: fails with --timeout=360 secs
skip("diag_embed"), # slow: fails with --timeout=360 secs
skip("frac"), # slow: fails with --timeout=360 secs
skip("grid_sampler_2d"), # slow: fails with --timeout=360 secs
xfail("lerp"),
skip("logaddexp"), # slow: fails with --timeout=360 secs
skip("native_dropout_backward"), # slow: fails with --timeout=360 secs
xfail("nn.functional.binary_cross_entropy_with_logits"),
skip("nn.functional.glu"), # slow: fails with --timeout=360 secs
xfail("nn.functional.hardshrink"),
xfail("nn.functional.softshrink"),
skip("nn.functional.unfold"), # slow: fails with --timeout=360 secs
xfail("norm"),
xfail("norm", "fro"),
xfail("norm", "inf"),
xfail("norm", "nuc"),
skip("rad2deg"), # slow: fails with --timeout=360 secs
skip("renorm"), # slow: fails with --timeout=360 secs
skip("rot90"), # slow: fails with --timeout=360 secs
skip("rsub"), # slow: fails with --timeout=360 secs
skip("sgn"), # slow: fails with --timeout=360 secs
skip("special.xlog1py"), # slow: fails with --timeout=360 secs
xfail("stack"),
skip("tril"), # slow: fails with --timeout=360 secs
skip("triu"), # slow: fails with --timeout=360 secs
skip("unfold_copy"), # slow: fails with --timeout=360 secs
skip("xlogy"), # slow: fails with --timeout=360 secs
xfail("zero_"),
}
if not TEST_WITH_SLOW:
core_backward_failures.update({
skip('addr'), # slow: takes 46 sec on A100
skip('baddbmm'), # slow: takes 800+ sec on A100
skip('clamp_min'), # slow: takes 800 sec on A100
skip('clamp_max'), # slow: takes 800 sec on A100
skip('logit'), # slow: takes 44 sec on A100
skip('nn.functional.hardswish'), # slow: takes 60 sec on A100
skip('std_mean'), # slow: takes 170 sec on A100
skip('split', variant_name='list_args'), # slow: takes 118 sec on A100
skip('transpose'), # slow: takes 50 sec on A100
skip('unbind'), # slow: takes 70 sec on A100
skip('unsafe_split'), # slow: takes 49 sec on A100
})
core_backward_failures.update(
{
skip("addr"), # slow: takes 46 sec on A100
skip("baddbmm"), # slow: takes 800+ sec on A100
skip("clamp_min"), # slow: takes 800 sec on A100
skip("clamp_max"), # slow: takes 800 sec on A100
skip("logit"), # slow: takes 44 sec on A100
skip("nn.functional.hardswish"), # slow: takes 60 sec on A100
skip("std_mean"), # slow: takes 170 sec on A100
skip("split", variant_name="list_args"), # slow: takes 118 sec on A100
skip("transpose"), # slow: takes 50 sec on A100
skip("unbind"), # slow: takes 70 sec on A100
skip("unsafe_split"), # slow: takes 49 sec on A100
}
)
comprehensive_failures = {
xfail("nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)), # off by one error
xfail("nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)), # off by one error
xfail("nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)), # off by one error
xfail(
"nn.functional.interpolate", "bilinear", dtypes=(torch.uint8,)
), # off by one error
xfail(
"nn.functional.interpolate", "bicubic", dtypes=(torch.uint8,)
), # off by one error
xfail(
"nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
), # off by one error
}
@unMarkDynamoStrictTest
class TestDecomp(TestCase):
longMessage = True
@ -517,7 +546,7 @@ class TestDecomp(TestCase):
self.do_cross_ref(device, dtype, op, run_all=False)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@skipOps('TestDecomp', 'test_quick_core_backward', core_backward_failures)
@skipOps("TestDecomp", "test_quick_core_backward", core_backward_failures)
@onlyNativeDeviceTypes
@skipIfCrossRef
@suppress_warnings
@ -528,15 +557,16 @@ class TestDecomp(TestCase):
args = [sample_input.input] + list(sample_input.args)
kwargs = sample_input.kwargs
func = partial(op.get_op(), **kwargs)
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all=False)\
as mode, enable_python_dispatcher():
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=False
) as mode, enable_python_dispatcher():
torch.autograd.gradcheck(func, args)
self.check_decomposed(aten_name, mode)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@skipIfCrossRef
@skipOps('TestDecomp', 'test_comprehensive', comprehensive_failures)
@skipOps("TestDecomp", "test_comprehensive", comprehensive_failures)
@suppress_warnings
@ops(op_db)
def test_comprehensive(self, device, dtype, op):
@ -560,7 +590,9 @@ class TestDecomp(TestCase):
xs = torch.ones([2, 10], device=device)
def index_copy(xs, x):
torch._decomp.decompositions.index_copy_(xs, 0, torch.tensor(0).to(device), x)
torch._decomp.decompositions.index_copy_(
xs, 0, torch.tensor(0).to(device), x
)
index_copy(xs, x)
@ -574,7 +606,9 @@ class TestDecomp(TestCase):
# are <= 0, and b) whether we're in training mode. Cover all cases:
dtype = torch.float64
x = torch.tensor(
[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0], dtype=dtype, device=device,
[-3.0, -2.0, -1.0, 0.0, 1.0, 2.0],
dtype=dtype,
device=device,
)
lower = 1.0
upper = 4.0
@ -587,7 +621,11 @@ class TestDecomp(TestCase):
torch.manual_seed(123)
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
res = torch._decomp.decompositions.rrelu_with_noise(
x, noise_res, lower, upper, training,
x,
noise_res,
lower,
upper,
training,
)
self.assertEqual(ref, res)
self.assertEqual(noise_ref, noise_res)
@ -602,30 +640,51 @@ class TestDecomp(TestCase):
torch.manual_seed(123)
noise_res = torch.zeros(x.shape, dtype=dtype, device=device)
res = torch._decomp.decompositions.rrelu_with_noise(
x, noise_res, lower, upper, training,
x,
noise_res,
lower,
upper,
training,
)
self.assertEqual(ref, res)
self.assertEqual(noise_ref, noise_res)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@suppress_warnings
@tf32_off()
# only tests RNNs since we have py dispsatcher decomps for them
@modules(filter(lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db))
@modules(
filter(
lambda m: m.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU),
module_db,
)
)
def test_rnn_decomp_module(self, device, dtype, module_info, training):
module_cls = module_info.module_cls
module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype,
requires_grad=True, training=training)
module_inputs = module_info.module_inputs_func(
module_info,
device=device,
dtype=dtype,
requires_grad=True,
training=training,
)
for module_input in module_inputs:
if module_input.forward_input is None:
continue
args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
args, kwargs = (
module_input.constructor_input.args,
module_input.constructor_input.kwargs,
)
m = module_cls(*args, **kwargs)
m.to(device).to(dtype)
args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all=True), enable_python_dispatcher():
args, kwargs = (
module_input.forward_input.args,
module_input.forward_input.kwargs,
)
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all=True
), enable_python_dispatcher():
decomp_out = m(*args, **kwargs)
non_decomp_out = m(*args, **kwargs)
@ -641,7 +700,9 @@ class TestDecomp(TestCase):
bias = torch.randn(3, device=device)
mean = torch.randn(3, device=device)
var = torch.randn(3, device=device)
res = torch._decomp.decompositions.native_batch_norm(input, weight, bias, mean, var, False, 1, 1e-05)
res = torch._decomp.decompositions.native_batch_norm(
input, weight, bias, mean, var, False, 1, 1e-05
)
self.assertEqual(shape, res[0].shape)
def test_arange_graph(self, device):
@ -662,29 +723,40 @@ class TestDecomp(TestCase):
fx_g_code = fx_g.code.strip()
# Remove device and requires_grad
fx_g_code = re.sub(pattern, "", fx_g_code)
self.assertExpectedInline(fx_g_code, """\
self.assertExpectedInline(
fx_g_code,
"""\
def forward(self, x_1, start_1):
iota = torch.ops.prims.iota.default(10, start = 0, step = 1, dtype = torch.int64)
mul = torch.ops.prims.mul.default(iota, 1); iota = None
add = torch.ops.prims.add.default(mul, 0); mul = None
convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
return convert_element_type""")
return convert_element_type""",
)
fx_g = cfunc(torch.rand(10, device=device), 1)
fx_g_code = fx_g.code.strip()
# Remove device and requires_grad
fx_g_code = re.sub(pattern, "", fx_g_code)
self.assertExpectedInline(fx_g_code, """\
self.assertExpectedInline(
fx_g_code,
"""\
def forward(self, x_1, start_1):
iota = torch.ops.prims.iota.default(9, start = 0, step = 1, dtype = torch.int64)
mul = torch.ops.prims.mul.default(iota, 1); iota = None
add = torch.ops.prims.add.default(mul, 1); mul = None
convert_element_type = torch.ops.prims.convert_element_type.default(add, torch.float32); add = None
return convert_element_type""")
return convert_element_type""",
)
def test_masked_fill(self, device):
from torch.fx.experimental.proxy_tensor import make_fx
if torch.device(device).type not in ["xpu", "cuda", torch._C._get_privateuse1_backend_name()]:
if torch.device(device).type not in [
"xpu",
"cuda",
torch._C._get_privateuse1_backend_name(),
]:
self.skipTest("only runs on XPU and CUDA and PrivateUse1.")
def func(scores, mask, value):
@ -695,10 +767,13 @@ def forward(self, x_1, start_1):
value_t = torch.tensor(0, dtype=scores_t.dtype)
cfunc = make_fx(func, decomposition_table=decomposition_table)
fx_g = cfunc(scores_t, mask_t, value_t)
self.assertExpectedInline(fx_g.code.strip(), """\
self.assertExpectedInline(
fx_g.code.strip(),
"""\
def forward(self, scores_1, mask_1, value_1):
where = torch.ops.prims.where.default(mask_1, value_1, scores_1); mask_1 = value_1 = scores_1 = None
return where""")
return where""",
)
class DecompCrossRefMode(TorchDispatchMode):
def __init__(self, test_case, saved_precision, saved_rel_tol, dtype, run_all):
@ -724,7 +799,7 @@ def forward(self, scores_1, mask_1, value_1):
# Stuff we shouldn't bother testing
# (TODO: remove detach from the decomp table?)
# N.b. Testing in-place ops would need dedicated logic
in_place = func.name()[-1] == '_'
in_place = func.name()[-1] == "_"
ignored_ops = [
torch.ops.aten.detach.default,
# non-deterministic ops
@ -737,11 +812,11 @@ def forward(self, scores_1, mask_1, value_1):
torch.ops.aten.native_dropout.default,
]
if (
func not in decomposition_table or
func in ignored_ops or
torch.Tag.nondeterministic_seeded in func.tags or
any_unsupported(args, kwargs) or
in_place
func not in decomposition_table
or func in ignored_ops
or torch.Tag.nondeterministic_seeded in func.tags
or any_unsupported(args, kwargs)
or in_place
):
return func(*args, **kwargs)
@ -789,29 +864,51 @@ def forward(self, scores_1, mask_1, value_1):
real_out_double, _ = tree_flatten(
func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
)
for i, (orig, decomp, ref) in enumerate(zip(real_out, decomp_out, real_out_double)):
for i, (orig, decomp, ref) in enumerate(
zip(real_out, decomp_out, real_out_double)
):
if not isinstance(orig, torch.Tensor):
assert type(orig) == type(decomp)
assert orig == decomp
continue
op_assert_ref(self.test_case, func, self.test_dtype, i, orig, decomp, ref, args, kwargs)
op_assert_ref(
self.test_case,
func,
self.test_dtype,
i,
orig,
decomp,
ref,
args,
kwargs,
)
else:
for orig, decomp in zip(real_out, decomp_out):
if not isinstance(orig, torch.Tensor):
assert type(orig) == type(decomp)
assert orig == decomp
continue
op_assert_equal(self.test_case, func, self.test_dtype, orig, decomp, args, kwargs)
op_assert_equal(
self.test_case,
func,
self.test_dtype,
orig,
decomp,
args,
kwargs,
)
return real_out_unflat
def check_decomposed(self, aten_name, mode):
self.assertTrue(
any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
msg=(f"aten.{aten_name} was not decomposed, saw calls for: "
msg=(
f"aten.{aten_name} was not decomposed, saw calls for: "
f"{', '.join(map(str, list(mode.called)))}. If your op is "
f"CompositeImplicitAutograd you should skip this test "
f"by updating CROSS_REF_EXCLUDE_SET.")
f"by updating CROSS_REF_EXCLUDE_SET."
),
)
@skipIfTorchDynamo("Test does not work with TorchDynamo")
@ -824,7 +921,9 @@ def forward(self, scores_1, mask_1, value_1):
if any(key in CROSS_REF_EXCLUDE_SET for key in test_keys):
self.skipTest(f"{op.name} in {dtype} not supported")
skip_decomp_vjp = any(key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys)
skip_decomp_vjp = any(
key in CROSS_REF_BACKWARD_EXCLUDE_SET for key in test_keys
)
requires_grad = (
op.supports_autograd
@ -842,9 +941,13 @@ def forward(self, scores_1, mask_1, value_1):
func = op.get_op()
def run_without_python_dispatcher(mode):
return any(isinstance(op, torch._ops.OpOverload) and
op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
for op in mode.decomposed.union([func]))
return any(
isinstance(op, torch._ops.OpOverload)
and op.has_kernel_for_dispatch_key(
DispatchKey.CompositeImplicitAutograd
)
for op in mode.decomposed.union([func])
)
for sample_input in samples:
if requires_grad:
@ -857,29 +960,35 @@ def forward(self, scores_1, mask_1, value_1):
# store the called list on the mode object instance and no
# explicit clearing is necessary as I will create a fresh mode
# for each region
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
as mode, enable_python_dispatcher():
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
# they're checking aten decomps at the torch_dispatch level.
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
as mode:
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode:
decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
if aten_name in decomposition_names:
self.check_decomposed(aten_name, mode)
if not skip_decomp_vjp and (op.aten_backward_name in decomposition_names or run_all):
if not skip_decomp_vjp and (
op.aten_backward_name in decomposition_names or run_all
):
cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
as mode, enable_python_dispatcher():
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
decomp_vjp_fn(cotangents)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
# they're checking aten decomps at the torch_dispatch level.
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
as mode:
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode:
decomp_vjp_fn(cotangents)
if not run_all:
self.check_decomposed(op.aten_backward_name, mode)
@ -889,15 +998,17 @@ def forward(self, scores_1, mask_1, value_1):
kwargs = sample_input.kwargs
# A failure here might be because the decomposition for the op is wrong or because a
# decomposition used by the particular op is wrong.
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
as mode, enable_python_dispatcher():
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode, enable_python_dispatcher():
func(*args, **kwargs)
if run_without_python_dispatcher(mode):
# without this check, incorrect decomps at the python dispatcher level can still pass because
# they're checking aten decomps at the torch_dispatch level.
with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all)\
as mode:
with self.DecompCrossRefMode(
self, self.precision, self.rel_tol, dtype, run_all
) as mode:
func(*args, **kwargs)
if not run_all:
@ -908,6 +1019,7 @@ def forward(self, scores_1, mask_1, value_1):
"only backwards is decomposed, but dtype doesn't support AD"
)
instantiate_device_type_tests(TestDecomp, globals())
@ -964,7 +1076,8 @@ class DecompOneOffTests(TestCase):
mean,
False,
1e-05,
[True, True, True])
[True, True, True],
)
res = torch._decomp.decompositions.native_batch_norm_backward(
grad_out,
x,
@ -975,12 +1088,12 @@ class DecompOneOffTests(TestCase):
mean,
False,
1e-05,
[True, True, True])
for (a, b) in zip(ref, res):
[True, True, True],
)
for a, b in zip(ref, res):
self.assertEqual(a.stride(), b.stride())
self.assertEqual(a.dtype, b.dtype)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyNativeDeviceTypes
@skipIfCrossRef
@ -1022,15 +1135,22 @@ class DecompOneOffTests(TestCase):
self.assertEqual(
torch.ops.aten._weight_norm_interface(inp, inp2),
torch._decomp.decompositions._weight_norm_interface(inp, inp2)
torch._decomp.decompositions._weight_norm_interface(inp, inp2),
)
@unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
@onlyCPU
@skipIfCrossRef
@skipOps('DecompOneOffTests', 'test_sdpa', [
xfail("nn.functional.scaled_dot_product_attention", dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else [])),
])
@skipOps(
"DecompOneOffTests",
"test_sdpa",
[
xfail(
"nn.functional.scaled_dot_product_attention",
dtypes=[torch.half] + ([torch.bfloat16] if IS_MACOS else []),
),
],
)
@ops(_sdpa_op_info)
def test_sdpa(self, device, dtype, op):
# SDPA doesn't support float16, this is aligned with aten/src/ATen/native/transformers/attention.cpp. If we
@ -1040,13 +1160,19 @@ class DecompOneOffTests(TestCase):
def __init__(self):
super().__init__()
def forward(self, query_layer, key_layer, value_layer, mask=None, is_causal=True):
def forward(
self, query_layer, key_layer, value_layer, mask=None, is_causal=True
):
attn_output = op(
query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
query_layer,
key_layer,
value_layer,
attn_mask=mask,
dropout_p=0.0,
is_causal=is_causal,
)
return attn_output
query_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
key_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
value_layer = torch.randn(1, 128, 100, 64, device=device, dtype=dtype)
@ -1057,20 +1183,28 @@ class DecompOneOffTests(TestCase):
for mask in masks:
is_causal = mask is None
attention = ScaledDotProductAttention()
decomposed_res = torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
decomposed_res = (
torch._decomp.decompositions.scaled_dot_product_flash_attention_for_cpu(
query_layer, key_layer, value_layer, 0.0, is_causal, attn_mask=mask
)
)
eager_res = op(
query_layer, key_layer, value_layer, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
query_layer,
key_layer,
value_layer,
attn_mask=mask,
dropout_p=0.0,
is_causal=is_causal,
)
self.assertTrue(torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol))
self.assertTrue(
torch.allclose(decomposed_res[0], eager_res, atol=atol, rtol=rtol)
)
instantiate_device_type_tests(DecompOneOffTests, globals())
class HasDecompTest(TestCase):
def setUp(self):
super().setUp()
@ -1080,22 +1214,24 @@ class HasDecompTest(TestCase):
def _can_appear_in_trace(op: torch._ops.OpOverload) -> bool:
has_tensor_arg = any(
"Tensor" in str(a.type)
for a in itertools.chain(op._schema.arguments, op._schema.returns))
for a in itertools.chain(op._schema.arguments, op._schema.returns)
)
if not has_tensor_arg:
return False
try:
# CompositeImplicitAutograd ops are transparent to the tracer, so don't need decompositions
return not op.has_kernel_for_dispatch_key(DispatchKey.CompositeImplicitAutograd)
return not op.has_kernel_for_dispatch_key(
DispatchKey.CompositeImplicitAutograd
)
except RuntimeError as e:
# has_key fails for some jit-registered ops, which shouldn't be
# relevant here anyway
if 'does not exist' in str(e):
if "does not exist" in str(e):
return False
raise
def test_has_decomposition(self):
def all_aten_overloads():
for name in torch._C._dispatch_get_all_op_names():
if not name.startswith("aten::"):
@ -1116,11 +1252,14 @@ class HasDecompTest(TestCase):
# configurations, so would cause the test to fail
allow_list = {aten.get_gradients.default}
overloads_wanting_decomp = {op for op in all_aten_overloads()
if self._can_appear_in_trace(op)}
overloads_wanting_decomp = {
op for op in all_aten_overloads() if self._can_appear_in_trace(op)
}
ops_missing_decomp = overloads_wanting_decomp - decomposition_table.keys()
ops_missing_decomp -= allow_list
self.assertExpected("".join(sorted(op.name() + "\n" for op in ops_missing_decomp)))
self.assertExpected(
"".join(sorted(op.name() + "\n" for op in ops_missing_decomp))
)
def test_aten_core_operators(self):
# If a decomposition isn't included in the core decompositions,
@ -1136,9 +1275,11 @@ class HasDecompTest(TestCase):
# Some decompositions are registered for CompositeImplicitAutograd
# operators, which never appear in AOTAutograd's graph so are never used.
useful_decomps = {op for op in decomposition_table.keys()
if isinstance(op, torch._ops.OpOverload) and
self._can_appear_in_trace(op)}
useful_decomps = {
op
for op in decomposition_table.keys()
if isinstance(op, torch._ops.OpOverload) and self._can_appear_in_trace(op)
}
core_decomps = torch._decomp.core_aten_decompositions().keys()
core_aten_ops = useful_decomps - core_decomps
self.assertExpected("".join(sorted(op.name() + "\n" for op in core_aten_ops)))

View File

@ -3,9 +3,10 @@
import textwrap
import types
from torch.utils._freeze import Freezer, PATH_MARKER
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._freeze import Freezer, PATH_MARKER
class TestFreezer(TestCase):
"""Tests the freeze.py script"""

View File

@ -3,7 +3,7 @@
import os
import run_test
from torch.testing._internal.common_utils import TestCase, run_tests
from torch.testing._internal.common_utils import run_tests, TestCase
class DummyOptions:
@ -30,7 +30,9 @@ class DeterminationTest(TestCase):
return [
test
for test in cls.TESTS
if run_test.should_run_test(run_test.TARGET_DET_LIST, test, changed_files, DummyOptions())
if run_test.should_run_test(
run_test.TARGET_DET_LIST, test, changed_files, DummyOptions()
)
]
def test_target_det_list_is_sorted(self):
@ -42,9 +44,7 @@ class DeterminationTest(TestCase):
def test_config_change_only(self):
"""CI configs trigger all tests"""
self.assertEqual(
self.determined_tests([".ci/pytorch/test.sh"]), self.TESTS
)
self.assertEqual(self.determined_tests([".ci/pytorch/test.sh"]), self.TESTS)
def test_run_test(self):
"""run_test.py is imported by determination tests"""
@ -68,14 +68,17 @@ class DeterminationTest(TestCase):
def test_test_file(self):
"""Test files trigger themselves and dependent tests"""
self.assertEqual(
self.determined_tests(["test/test_jit.py"]), ["test_jit_profiling", "test_jit"]
self.determined_tests(["test/test_jit.py"]),
["test_jit_profiling", "test_jit"],
)
self.assertEqual(
self.determined_tests(["test/jit/test_custom_operators.py"]),
["test_jit_profiling", "test_jit"],
)
self.assertEqual(
self.determined_tests(["test/quantization/eager/test_quantize_eager_ptq.py"]),
self.determined_tests(
["test/quantization/eager/test_quantize_eager_ptq.py"]
),
["test_quantization"],
)

View File

@ -2,11 +2,16 @@
import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import TestCase, run_tests, IS_JETSON
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCUDA, dtypes, skipMeta, skipCUDAIfRocm,
onlyNativeDeviceTypes)
dtypes,
instantiate_device_type_tests,
onlyCUDA,
onlyNativeDeviceTypes,
skipCUDAIfRocm,
skipMeta,
)
from torch.testing._internal.common_dtype import all_types_and_complex_and
from torch.testing._internal.common_utils import IS_JETSON, run_tests, TestCase
from torch.utils.dlpack import from_dlpack, to_dlpack
@ -15,7 +20,16 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
@dtypes(
*all_types_and_complex_and(
torch.half,
torch.bfloat16,
torch.bool,
torch.uint16,
torch.uint32,
torch.uint64,
)
)
def test_dlpack_capsule_conversion(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
z = from_dlpack(to_dlpack(x))
@ -23,7 +37,16 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
@dtypes(
*all_types_and_complex_and(
torch.half,
torch.bfloat16,
torch.bool,
torch.uint16,
torch.uint32,
torch.uint64,
)
)
def test_dlpack_protocol_conversion(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
z = from_dlpack(x)
@ -62,7 +85,16 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
@dtypes(
*all_types_and_complex_and(
torch.half,
torch.bfloat16,
torch.bool,
torch.uint16,
torch.uint32,
torch.uint64,
)
)
def test_from_dlpack(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
y = torch.from_dlpack(x)
@ -70,7 +102,16 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
@dtypes(
*all_types_and_complex_and(
torch.half,
torch.bfloat16,
torch.bool,
torch.uint16,
torch.uint32,
torch.uint64,
)
)
def test_from_dlpack_noncontinguous(self, device, dtype):
x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)
@ -113,7 +154,16 @@ class TestTorchDlPack(TestCase):
@skipMeta
@onlyNativeDeviceTypes
@dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
@dtypes(
*all_types_and_complex_and(
torch.half,
torch.bfloat16,
torch.bool,
torch.uint16,
torch.uint32,
torch.uint64,
)
)
def test_from_dlpack_dtype(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
y = torch.from_dlpack(x)
@ -204,5 +254,5 @@ class TestTorchDlPack(TestCase):
instantiate_device_type_tests(TestTorchDlPack, globals())
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

View File

@ -2,8 +2,8 @@
import contextlib
import copy
import itertools
import inspect
import itertools
import math
import operator
import re
@ -16,8 +16,9 @@ from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.sym_node import to_node, SymNode, method_to_operator
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimConstraints,
DimDynamic,
expect_true,
@ -25,11 +26,10 @@ from torch.fx.experimental.symbolic_shapes import (
guard_float,
guard_int,
GuardOnDataDependentSymNode,
ShapeEnv,
is_symbolic,
ShapeEnv,
StatelessSymbolicContext,
statically_known_true,
_constrain_range_for_size,
)
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
@ -38,8 +38,8 @@ from torch.testing._internal.common_utils import (
skipIfTorchDynamo,
TestCase,
)
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._sympy.functions import FloorDiv, Mod
aten = torch.ops.aten
@ -51,8 +51,10 @@ def register_meta(op):
def decorator(f):
def add_func(op):
meta_funcs[op] = f
pytree.tree_map_(add_func, op)
return f
return decorator
@ -101,13 +103,26 @@ def create_contiguous(shape):
class FakeSymbolicTensor(torch.Tensor):
@staticmethod
def __new__(cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0):
def __new__(
cls,
sym_shape,
sym_strides,
dtype,
layout,
requires_grad,
device,
storage_offset=0,
):
# TODO: this is wrong in general
sym_stride = create_contiguous(sym_shape)
r = torch.Tensor._make_wrapper_subclass(
cls, sym_shape,
sym_stride, storage_offset,
dtype=dtype, layout=layout, requires_grad=requires_grad,
cls,
sym_shape,
sym_stride,
storage_offset,
dtype=dtype,
layout=layout,
requires_grad=requires_grad,
device=device,
)
return r
@ -115,7 +130,9 @@ class FakeSymbolicTensor(torch.Tensor):
__torch_function__ = _disabled_torch_function_impl
def new_empty(self, shape):
return FakeSymbolicTensor(shape, None, self.dtype, self.layout, self.requires_grad, self.device)
return FakeSymbolicTensor(
shape, None, self.dtype, self.layout, self.requires_grad, self.device
)
@classmethod
def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
@ -125,7 +142,14 @@ class FakeSymbolicTensor(torch.Tensor):
if func_overload == torch.ops.aten.new_empty.default:
self = args[0]
shape = args[1]
return FakeSymbolicTensor(shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device)
return FakeSymbolicTensor(
shape,
self.stride(),
self.dtype,
self.layout,
self.requires_grad,
self.device,
)
raise RuntimeError(f"operator {func_overload} not supported")
@ -138,58 +162,86 @@ def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None)
constraint_dims = [None] * arg.dim()
if dynamic_dims is None:
dynamic_dims = [DimDynamic.DUCK] * arg.dim()
sym_shapes, sym_strides, sym_storage_offset = \
shape_env.create_symbolic_sizes_strides_storage_offset(
(
sym_shapes,
sym_strides,
sym_storage_offset,
) = shape_env.create_symbolic_sizes_strides_storage_offset(
arg,
source=source,
symbolic_context=StatelessSymbolicContext(
dynamic_sizes=dynamic_dims,
constraint_sizes=constraint_dims
dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims
),
)
return FakeSymbolicTensor(sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset)
return FakeSymbolicTensor(
sym_shapes,
sym_strides,
arg.dtype,
arg.layout,
arg.requires_grad,
arg.device,
sym_storage_offset,
)
def create_symtype(cls, pytype, shape_env, val, duck=True):
from torch._dynamo.source import ConstantSource
symbol = shape_env.create_symbol(
val,
source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
constraint_dim=None,
)
return cls(SymNode(
return cls(
SymNode(
symbol,
shape_env,
pytype,
hint=val,
))
)
)
# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True):
return create_symtype(SymInt, int, shape_env, i, duck=duck)
def create_symbool(shape_env, b: bool):
return create_symtype(SymBool, bool, shape_env, b)
def create_symfloat(shape_env, f: float):
return create_symtype(SymFloat, float, shape_env, f)
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
class TestPySymInt(TestCase):
@skipIfTorchDynamo(
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestPySymInt(TestCase):
def test_arith_ops(self):
shape_env = ShapeEnv()
symints = []
for i in range(2, 5):
symints.append((i, create_symint(shape_env, i)))
ops = [operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod]
ops = [
operator.add,
operator.sub,
operator.floordiv,
operator.mul,
operator.mod,
]
for op in ops:
for args in itertools.permutations(symints, 2):
if not isinstance(args[0][1], int) and ((op != operator.mod or op != operator.floordiv) and args[1][0] != 0):
self.assertTrue(op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]))
if not isinstance(args[0][1], int) and (
(op != operator.mod or op != operator.floordiv) and args[1][0] != 0
):
self.assertTrue(
op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])
)
def test_reverse_arith_ops(self):
shape_env = ShapeEnv()
@ -200,7 +252,6 @@ class TestPySymInt(TestCase):
a = create_symint(shape_env, 2)
self.assertTrue(5 * a == 5 * 2)
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
@ -217,7 +268,9 @@ class TestPySymInt(TestCase):
# Should be simplifiable to an integer.
# Ref: https://github.com/pytorch/pytorch/pull/107492
self.assertTrue(isinstance(x.size()[1], SymInt))
self.assertTrue(isinstance(x.size()[1].node.maybe_as_int(), int)) # due to guard above
self.assertTrue(
isinstance(x.size()[1].node.maybe_as_int(), int)
) # due to guard above
self.assertTrue(x.size()[2] == 3)
self.assertTrue(x.size(0) == 5)
@ -344,7 +397,6 @@ class TestPySymInt(TestCase):
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
def test_aten_ops(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
@ -368,7 +420,7 @@ class TestPySymInt(TestCase):
def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device='meta')
r = torch.empty(a0, device="meta")
self.assertIsInstance(r.shape[0], SymInt)
def test_guard_int(self):
@ -389,7 +441,7 @@ class TestPySymInt(TestCase):
self.assertEqual(len(shape_env.guards), 0)
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
"""[Eq(s0, 2)]"""
"""[Eq(s0, 2)]""",
)
def test_sym_int(self):
@ -410,7 +462,9 @@ class TestPySymInt(TestCase):
r = sym_int(2.0 * torch.sym_float(a3))
self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""")
self.assertExpectedInline(
str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)"""
)
def test_sym_sqrt(self):
shape_env = ShapeEnv()
@ -418,7 +472,9 @@ class TestPySymInt(TestCase):
r = torch._sym_sqrt(a0)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""")
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)"""
)
def test_sym_floor(self):
shape_env = ShapeEnv()
@ -442,7 +498,9 @@ class TestPySymInt(TestCase):
r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""")
self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)"""
)
def test_sym_ceil(self):
shape_env = ShapeEnv()
@ -450,7 +508,9 @@ class TestPySymInt(TestCase):
r = math.ceil(a0 / 2)
self.assertEqual(r, 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""")
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)"""
)
r = math.floor(3.0 * a0)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
@ -471,13 +531,19 @@ class TestPySymInt(TestCase):
self.assertEqual(len(shape_env.guards), 0)
self.assertEqual(r3, 5)
self.assertEqual(type(t), type(r3))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""")
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
)
b4 = f == 5
r4 = torch.sym_ite(b4, t, f)
self.assertEqual(len(shape_env.guards), 1)
self.assertEqual(r4, 4)
self.assertEqual(type(f), type(r4))
self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""")
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
)
def test_tracing_sym_ite(self):
def f(x):
@ -487,13 +553,16 @@ class TestPySymInt(TestCase):
gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
self.assertEqual(len(gm.shape_env.guards), 0)
self.assertExpectedInline(gm.code.strip(), """\
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 5
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
return sym_ite""")
return sym_ite""",
)
r1 = gm(torch.ones(4, 5))
self.assertIsInstance(r1, int)
self.assertEqual(r1, 5)
@ -536,7 +605,7 @@ def forward(self, x_1):
self.assertTrue(expect_true(i0 < s0))
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
"""[-s0 + u0 < 0]"""
"""[-s0 + u0 < 0]""",
)
self.assertTrue(i0 < s0)
self.assertTrue(i0 != s0)
@ -552,7 +621,7 @@ def forward(self, x_1):
# Importantly, this is put in i1, not i0!
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
"""[Eq(u0 + u1, 10)]"""
"""[Eq(u0 + u1, 10)]""",
)
self.assertTrue(i0 + i1 == 10)
# NB: We currently don't support deriving that we can substitute
@ -595,7 +664,9 @@ def forward(self, x_1):
def test_expect_true_refine_range(self):
shape_env = ShapeEnv()
for i, rel in enumerate([lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]):
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(rel(i0)))
@ -606,7 +677,9 @@ def forward(self, x_1):
self.assertTrue(statically_known_true(i0 > 4))
self.assertTrue(statically_known_true(i0 >= 5))
for i, rel in enumerate([lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]):
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(rel(i0)))
@ -619,7 +692,9 @@ def forward(self, x_1):
def test_guard_refine_range(self):
shape_env = ShapeEnv()
for i, rel in enumerate([lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]):
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 10, duck=False)
self.assertTrue(bool(rel(i0)))
@ -630,7 +705,9 @@ def forward(self, x_1):
self.assertTrue(statically_known_true(i0 > 4))
self.assertTrue(statically_known_true(i0 >= 5))
for i, rel in enumerate([lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]):
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 2, duck=False)
self.assertFalse(bool(rel(i0)))
@ -641,7 +718,9 @@ def forward(self, x_1):
self.assertTrue(statically_known_true(i0 <= 4))
self.assertTrue(statically_known_true(i0 < 5))
for i, rel in enumerate([lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]):
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 2, duck=False)
self.assertTrue(bool(rel(i0)))
@ -652,7 +731,9 @@ def forward(self, x_1):
self.assertTrue(statically_known_true(i0 < 4))
self.assertTrue(statically_known_true(i0 <= 3))
for i, rel in enumerate([lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]):
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 10, duck=False)
self.assertFalse(bool(rel(i0)))
@ -666,7 +747,7 @@ def forward(self, x_1):
def test_non_overlapping_and_dense(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = torch.empty_strided((a0, 7), (1, a0), device='meta')
r = torch.empty_strided((a0, 7), (1, a0), device="meta")
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
def test_specialize_zero_one(self):
@ -741,7 +822,9 @@ def forward(self, x_1):
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
out = fx_g.print_readable(print_output=False)
self.assertExpectedInline(out.strip(), """\
self.assertExpectedInline(
out.strip(),
"""\
class f(torch.nn.Module):
def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
# No stacktrace found for following nodes
@ -755,7 +838,8 @@ class f(torch.nn.Module):
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""") # noqa: B950
return (getitem, getitem_1)""", # noqa: B950
)
def test_statically_known_true(self):
shape_env = ShapeEnv()
@ -792,8 +876,8 @@ class f(torch.nn.Module):
# For full robustness, ensure the ephemeral source symbols are simplified out regardless
# of construction order or check order.
for construct_ephemeral_first, x_first_in_check in (
itertools.product([False, True], [False, True])
for construct_ephemeral_first, x_first_in_check in itertools.product(
[False, True], [False, True]
):
shape_env = ShapeEnv()
shape = (5, 10)
@ -816,10 +900,13 @@ class f(torch.nn.Module):
def _get_ephemeral_source_symbols(t):
return [
s.node.expr for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
if isinstance(s, torch.SymInt) and s.node.expr in shape_env.var_to_sources
s.node.expr
for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
if isinstance(s, torch.SymInt)
and s.node.expr in shape_env.var_to_sources
and any(
source.is_ephemeral() for source in shape_env.var_to_sources[s.node.expr]
source.is_ephemeral()
for source in shape_env.var_to_sources[s.node.expr]
)
]
@ -869,12 +956,14 @@ class f(torch.nn.Module):
self.assertEqual(x.storage_offset(), y.storage_offset())
@skipIfTorchDynamo("Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)")
@skipIfTorchDynamo(
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
# Helper function
# NB: don't use one as that will get specialized
seed_node = (create_symint(shape_env, 2) / 2.).node
seed_node = (create_symint(shape_env, 2) / 2.0).node
bool_seed_node = (create_symint(shape_env, 2) == 2).node
def get_sym_inp(inp):
@ -896,16 +985,19 @@ class TestSymNumberMagicMethods(TestCase):
elif fn == "pow" and inp1 == 0 and inp2 < 0:
# ZeroDivisionError: 0.0 cannot be raised to a negative power
return self.assertRaises((ZeroDivisionError,))
elif fn == "pow" and inp1 < 0 and inp2 in (2.5, -2.5) and (
type(inp1) in (SymFloat, SymInt) or
type(inp2) in (SymFloat, SymInt)
elif (
fn == "pow"
and inp1 < 0
and inp2 in (2.5, -2.5)
and (
type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt)
)
):
# Complex result, which we do not support:
# TypeError: Cannot convert complex to float
return self.assertRaises((TypeError,))
elif fn in ("lshift", "rshift") and not (
isinstance(inp1, (SymInt, int)) and
isinstance(inp2, (SymInt, int))
isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
):
# TypeError: unsupported operand type(s)
return self.assertRaises((TypeError,))
@ -964,7 +1056,6 @@ class TestSymNumberMagicMethods(TestCase):
out = guard_fn(out)
self.assertEqual(out, ref_out)
@parametrize("fn", list(sym_node.magic_methods.keys()))
def test_bool_method(self, fn):
# sym_ite has its own tests
@ -975,7 +1066,6 @@ class TestSymNumberMagicMethods(TestCase):
shape_env = ShapeEnv()
self._do_test(fn, True, False, shape_env, is_unary_fn)
@parametrize("fn", list(sym_node.magic_methods.keys()))
@parametrize("first_type", ["int", "float"])
@parametrize("second_type", ["int", "float"])
@ -984,7 +1074,9 @@ class TestSymNumberMagicMethods(TestCase):
# TODO: Hmm, this looks like we skip all floats
self.skipTest(f"{fn} is not a float magic method")
if (first_type == "int" or second_type == "int") and fn in sym_node.only_float_magic_methods:
if (
first_type == "int" or second_type == "int"
) and fn in sym_node.only_float_magic_methods:
self.skipTest(f"{fn} is not an int method")
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
@ -1000,7 +1092,7 @@ class TestSymNumberMagicMethods(TestCase):
values = (
0.0,
1.0,
0.5 if fn in ("sym_acos", "sym_asin") else 2.5 # avoid math domain error
0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error
)
neg_values = tuple(-x for x in values)
@ -1076,7 +1168,9 @@ class TestSymNumberMagicMethods(TestCase):
self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)
with self.assertRaisesRegex(RuntimeError, "add not supported by NestedIntSymNode"):
with self.assertRaisesRegex(
RuntimeError, "add not supported by NestedIntSymNode"
):
j1 + 3
self.assertFalse(j1 == 3)
@ -1130,8 +1224,10 @@ class TestSymNumberMagicMethods(TestCase):
self.assertIs(sz1 == sz2, True)
self.assertIs(sz1 != sz2, False)
instantiate_parametrized_tests(TestSymNumberMagicMethods)
class TestFloorDiv(TestCase):
@staticmethod
def python_floordiv(x, y):
@ -1164,7 +1260,9 @@ class TestFloorDiv(TestCase):
)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_bool(self):
values = (
@ -1177,14 +1275,20 @@ class TestFloorDiv(TestCase):
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
# Compares to int since our FloorDiv has no bool support
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(int(x), int(y)))
self.assertEqual(
TestFloorDiv.python_floordiv(x, y),
TestFloorDiv.torch_floordiv(int(x), int(y)),
)
# Tests that our impl throws
self.assertRaisesRegex(
TypeError,
(rf"unsupported operand type\(s\) for //: "
(
rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y))
rf", expected integer or real"
),
lambda: TestFloorDiv.torch_floordiv(x, y),
)
def test_floordiv_complex(self):
values = (
@ -1201,10 +1305,13 @@ class TestFloorDiv(TestCase):
self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y))
self.assertRaisesRegex(
TypeError,
(rf"unsupported operand type\(s\) for //: "
(
rf"unsupported operand type\(s\) for //: "
rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'"
rf", expected integer or real"),
lambda: TestFloorDiv.torch_floordiv(x, y))
rf", expected integer or real"
),
lambda: TestFloorDiv.torch_floordiv(x, y),
)
def test_floordiv_div_by_zero(self):
values = (
@ -1217,11 +1324,14 @@ class TestFloorDiv(TestCase):
# We don't test error messages to avoid depending on Python
# interpreter version
if type(y) is not sympy.Symbol:
self.assertRaises(ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y))
self.assertRaises(
ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y)
)
self.assertRaisesRegex(
ZeroDivisionError,
"division by zero",
lambda: TestFloorDiv.torch_floordiv(x, y))
lambda: TestFloorDiv.torch_floordiv(x, y),
)
def test_floordiv_zero_base(self):
values = (
@ -1232,7 +1342,10 @@ class TestFloorDiv(TestCase):
for x, y in TestFloorDiv.yield_test_cases(values, negate=False):
if type(x) is not sympy.Symbol:
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
self.assertEqual(
TestFloorDiv.python_floordiv(x, y),
TestFloorDiv.torch_floordiv(x, y),
)
else:
self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y))
@ -1245,7 +1358,9 @@ class TestFloorDiv(TestCase):
)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y))
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_simplify(self):
# Tests how we simplify or evaluate FloorDiv without free variables
@ -1296,6 +1411,7 @@ class TestFloorDiv(TestCase):
)
for base, divisor in itertools.product(cases, repeat=2):
def op():
return FloorDiv(base, divisor)
@ -1305,9 +1421,12 @@ class TestFloorDiv(TestCase):
if is_complex(base) or is_complex(divisor):
self.assertRaisesRegex(
TypeError,
(r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"),
op)
(
r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"
),
op,
)
continue
op = op()
@ -1371,7 +1490,11 @@ class TestDimConstraints(TestCase):
from sympy import Symbol
from sympy.solvers.inequalities import reduce_inequalities
from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
s0 = Symbol("s0", positive=True, integer=True)
@ -1398,7 +1521,11 @@ class TestDimConstraints(TestCase):
def test_dim_constraints_solve_full(self):
from sympy import Eq, Integer, Ne, Symbol
from torch._dynamo.source import LocalSource, TensorProperty, TensorPropertySource
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
src0 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
@ -1456,7 +1583,9 @@ class TestDimConstraints(TestCase):
}
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
marked_dynamic = {s0, s1, s5, s6}
dim_constraints = DimConstraints(symbol_to_source, var_to_val, marked_dynamic, {})
dim_constraints = DimConstraints(
symbol_to_source, var_to_val, marked_dynamic, {}
)
dim_constraints.add_equality(src2, s0)
dim_constraints.add_equality(src3, s0)
dim_constraints.add_equality(src4, s0)
@ -2181,7 +2310,9 @@ class TestDimConstraints(TestCase):
FloorDiv(s0, 2),
)
)
dim_constraints.add(Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0))
dim_constraints.add(
Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
)
dim_constraints.add(
Eq(
64
@ -2275,7 +2406,9 @@ class TestDimConstraints(TestCase):
dim_constraints.solve()
dim_constraints.remove_redundant_dynamic_results()
self.assertEqual(dim_constraints._static_results, {
self.assertEqual(
dim_constraints._static_results,
{
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
"L['a'].size()[2] == 96",
@ -2286,18 +2419,22 @@ class TestDimConstraints(TestCase):
"L['b'].size()[0] == 8",
"L['a'].size()[1] == 22",
"L['a'].size()[0] == 8",
})
self.assertEqual(dim_constraints._dynamic_results, {
},
)
self.assertEqual(
dim_constraints._dynamic_results,
{
"dynamic_dim(L['e'], 1) == dynamic_dim(L['c'], 1)",
"dynamic_dim(L['d'], 1) == dynamic_dim(L['c'], 1)",
})
},
)
def dummy_fn(a, b, c, d, e, f):
pass
action_code = dim_constraints.prettify_results(inspect.signature(dummy_fn))
static_code, dynamic_code = re.findall(r"```(.*?)```", action_code, re.DOTALL)
expected_static = '''
expected_static = """
def specializations(a, b, c, d, e, f):
# a:
assert a.size()[0] == 8
@ -2318,8 +2455,8 @@ def specializations(a, b, c, d, e, f):
# f:
assert f.size()[1] == 1
'''
expected_dynamic = '''
"""
expected_dynamic = """
def specify_constraints(a, b, c, d, e, f):
return [
# d:
@ -2328,12 +2465,11 @@ def specify_constraints(a, b, c, d, e, f):
# e:
dynamic_dim(e, 1) == dynamic_dim(c, 1),
]
'''
"""
self.assertEqual(static_code, expected_static)
self.assertEqual(dynamic_code, expected_dynamic)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests()

File diff suppressed because it is too large Load Diff