Files
pytorch/test/test_dynamic_shapes.py
Laith Sakka 7cfd054075 [attempt 2] Compute contiguity symbolically to avoid dde, and introduce c++ sym_is_contiguous (#157472)
Summary:
When we compute contiguity for a tensor with dynamic shapes, we:
1) Try to compute it without guarding.
2) If all shapes are hinted, compute it, potentially adding guards.
3) If any input is not hinted, compute it symbolically.

sym_is_contiguous returns a SymBool; it can then either be evaluated, or guard_or_false can be
called on it to avoid data-dependent errors.

Example:
bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__);
is_contiguous_or_false is a helper function that does exactly that.

In this PR I only handle default contiguity; I will follow up with changes for other formats like channels_last.
We use this pattern in several locations in this PR to avoid DDEs.
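
For reference, a minimal Python-side sketch of the same pattern (this assumes the guard_or_false
helper exposed by torch.fx.experimental.symbolic_shapes; the symbol names are illustrative, not
code from this PR):

    import torch
    from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false

    shape_env = ShapeEnv()
    u0 = shape_env.create_unbacked_symint()  # unbacked symint: no hint available

    # bool(u0 == 1) would raise GuardOnDataDependentSymNode here; guard_or_false
    # instead falls back to False when the answer cannot be determined.
    maybe_one = guard_or_false(u0 == 1)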

Test Plan:
contbuild & OSS CI,

Rollback Plan:

Reviewed By: malfet

Differential Revision: D77639021

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157472
Approved by: https://github.com/aorenste
2025-07-02 23:12:29 +00:00

# Owner(s): ["oncall: jit"]
# ruff: noqa: F841
import contextlib
import copy
import itertools
import math
import operator
import unittest
import numpy as np
import sympy
import torch
import torch.fx
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend
from torch._inductor.utils import fresh_cache
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
DimConstraints,
DimDynamic,
expect_true,
guard_bool,
guard_float,
guard_int,
GuardOnDataDependentSymNode,
has_free_symbols,
hint_int,
is_symbolic,
ShapeEnv,
StatelessSymbolicContext,
statically_known_false,
statically_known_true,
)
from torch.testing._internal.common_dtype import all_types_and
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
run_tests,
skipIfTorchDynamo,
TestCase,
)
from torch.testing._internal.logging_utils import logs_to_string
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._sympy.functions import (
CleanDiv,
FloorDiv,
IsNonOverlappingAndDenseIndicator,
Mod,
)
aten = torch.ops.aten
meta_funcs = {}
def register_meta(op):
    def decorator(f):
        def add_func(op):
            meta_funcs[op] = f

        pytree.tree_map_(add_func, op)
        return f

    return decorator


@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    return a.new_empty(a.shape)


@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    concat_length = 0
    shape = tensors[0].shape
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)


@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    shape = []
    for i, x in enumerate(a.shape):
        if i == dim:
            shape.append(length)
        else:
            shape.append(x)
    return a.new_empty(tuple(shape))


@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    return a.new_empty(size)


def create_contiguous(shape):
    strides = [1]
    for dim in reversed(shape[:-1]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
class FakeSymbolicTensor(torch.Tensor):
    @staticmethod
    def __new__(
        cls,
        sym_shape,
        sym_strides,
        dtype,
        layout,
        requires_grad,
        device,
        storage_offset=0,
    ):
        # TODO: this is wrong in general
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            sym_shape,
            sym_stride,
            storage_offset,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        return FakeSymbolicTensor(
            shape, None, self.dtype, self.layout, self.requires_grad, self.device
        )

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(
                shape,
                self.stride(),
                self.dtype,
                self.layout,
                self.requires_grad,
                self.device,
            )

        raise RuntimeError(f"operator {func_overload} not supported")
def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
    from torch._dynamo.source import ConstantSource

    if source is None:
        source = ConstantSource(name)
    constraint_dims = [None] * arg.dim()
    if dynamic_dims is None:
        dynamic_dims = [DimDynamic.DUCK] * arg.dim()
    (
        sym_shapes,
        sym_strides,
        sym_storage_offset,
    ) = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=source,
        symbolic_context=StatelessSymbolicContext(
            dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims
        ),
    )
    return FakeSymbolicTensor(
        sym_shapes,
        sym_strides,
        arg.dtype,
        arg.layout,
        arg.requires_grad,
        arg.device,
        sym_storage_offset,
    )


def create_fake_tensor_with_dynamic_size(x, shape_env, dynamic_sizes, dynamic_strides):
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode(shape_env=shape_env) as fake_mode:
        return fake_mode.from_tensor(
            x,
            symbolic_context=StatelessSymbolicContext(
                dynamic_sizes=dynamic_sizes,
                dynamic_strides=dynamic_strides,
            ),
        )
def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
    from torch._dynamo.source import ConstantSource

    symbol = shape_env.create_symbol(
        val,
        source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
        dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
        constraint_dim=None,
        **kwargs,
    )
    return cls(SymNode(symbol, shape_env, pytype, hint=val))


# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
    return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs)


def create_symbool(shape_env, b: bool) -> SymBool:
    return create_symtype(SymBool, bool, shape_env, b)


def create_symfloat(shape_env, f: float) -> SymFloat:
    return create_symtype(SymFloat, float, shape_env, f)
@skipIfTorchDynamo(
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestPySymInt(TestCase):
def test_arith_ops(self):
shape_env = ShapeEnv()
symints = []
for i in range(2, 5):
symints.append((i, create_symint(shape_env, i)))
ops = [
operator.add,
operator.sub,
operator.floordiv,
operator.mul,
operator.mod,
]
for op in ops:
for args in itertools.permutations(symints, 2):
if not isinstance(args[0][1], int) and (
(op != operator.mod or op != operator.floordiv) and args[1][0] != 0
):
self.assertTrue(
op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])
)
def test_reverse_arith_ops(self):
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
self.assertTrue(5 // a == 5 // 2)
a = create_symint(shape_env, 2)
self.assertTrue(5 * a == 5 * 2)
def test_sympify_symint(self):
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
self.assertIs(sympy.sympify(a), a.node.expr)
b = create_symfloat(shape_env, 3.0)
self.assertIs(sympy.sympify(b), b.node.expr)
c = create_symbool(shape_env, True)
self.assertIs(sympy.sympify(c), c.node.expr)
def test_roundtrip(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
self.assertTrue(not isinstance(x.shape[0], SymNode))
self.assertTrue(isinstance(x.shape[0], SymInt))
self.assertTrue(x.shape[0] == 5)
self.assertTrue(x.shape[1] == 4)
self.assertTrue(x.shape[2], 3)
self.assertTrue(x.size()[0], 5)
self.assertTrue(x.size()[1], 4)
# Should be simplifiable to an integer.
# Ref: https://github.com/pytorch/pytorch/pull/107492
self.assertTrue(isinstance(x.size()[1], SymInt))
self.assertTrue(
isinstance(x.size()[1].node.maybe_as_int(), int)
) # due to guard above
self.assertTrue(x.size()[2] == 3)
self.assertTrue(x.size(0) == 5)
self.assertTrue(x.size(1) == 4)
self.assertTrue(x.size(2) == 3)
self.assertTrue(isinstance(x.size(2), SymInt))
self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))
y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
self.assertTrue(isinstance(y.storage_offset(), SymInt))
self.assertTrue(y.storage_offset() == 12)
def test_binary(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# broadcasting
y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
z = x + y
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
def test_symint_args(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
LAST_DIM = 2
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == y.shape[2])
# arithmetic expr with two symints
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
self.assertTrue(z.shape[2] == 2)
# arithmetic expr with a symint and python int
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
self.assertTrue(z.shape[2] == 2)
def test_symint_vargs(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
# varargs
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# shape list
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(x.shape[0], y.shape[1], 3)
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints in a list
z = y.expand((x.shape[0], y.shape[1], 3))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python symints and ints
z = y.expand(5, y.shape[1], x.shape[2])
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
# mixed python ints and symints in a list
z = y.expand((5, y.shape[1], x.shape[2]))
self.assertTrue(z.shape[0] == 5)
self.assertTrue(z.shape[1] == 4)
self.assertTrue(z.shape[2] == 3)
z = y.expand((y.shape[1],))
z = y.expand(y.shape[1])
def test_symint_bitwise_and(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 0b1100)
b0 = create_symint(shape_env, 0b1010)
res_and = a0 & b0
self.assertEqual(res_and, 0b1000)
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
)
a1 = create_symint(shape_env, 3)
b1 = create_symbool(shape_env, True)
self.assertEqual(a1 & b1, 1)
a2 = create_symint(shape_env, 0b1100)
self.assertEqual(a2 & 0b1010, 0b1000)
a3 = create_symbool(shape_env, True)
b3 = create_symbool(shape_env, True)
self.assertEqual(a3 & b3, True)
def test_symint_bitwise_or(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 0b1100)
b0 = create_symint(shape_env, 0b1010)
res_or = a0 | b0
self.assertEqual(res_or, 0b1110)
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
)
def test_stride(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
self.assertIsInstance(x.stride()[0], SymInt)
def test_size_expressions(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
expand_x = x.expand(x.shape[0], x.shape[0])
if expand_x.shape[0] > 3:
result = expand_x + expand_x
else:
result = expand_x + expand_x
gt_op, _bt, is_size_obv = shape_env.guards[-1]
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
self.assertFalse(is_size_obv)
def test_floordiv_static(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 8)
# This was extracted from
# python test/inductor/test_cuda_cpp_wrapper.py -k
# DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
bool(s0 % 2 == 0)
bool(s0 % (s0 // 2) == 0)
bool(2 * (s0 // 2) == s0)
self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
def test_numel(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
self.assertIsInstance(x.numel(), torch.SymInt)
self.assertIsInstance(torch.numel(x), torch.SymInt)
x = torch.rand(3, 3)
self.assertIsInstance(x.numel(), int)
self.assertIsInstance(torch.numel(x), int)
def test_int_to_float(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
r = torch.sym_float(x.shape[0])
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
def test_aten_ops(self):
shape_env = ShapeEnv()
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
shape_env = ShapeEnv()
x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
def test_fx_trace_intlist(self):
class CustomModule(torch.nn.Module):
def forward(self, x):
bs, c, h, w = x.shape
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
m = CustomModule()
x = torch.rand(1, 3, 4, 4)
# should not TypeError: pad(): argument 'pad' (position 2) must be
# tuple of ints, not tuple
torch.fx.symbolic_trace(m)
def test_meta_symint(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
r = torch.empty(a0, device="meta")
self.assertIsInstance(r.shape[0], SymInt)
def test_guard_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(a0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_sym_sum(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 2)
s1 = create_symint(shape_env, 3)
s2 = create_symint(shape_env, 4)
self.assertEqual(
(s0 + s1 + s2).node.expr, torch.sym_sum([s0, s1, s2]).node.expr
)
def test_prefer_deferred_runtime_assertions_over_guards(self):
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
self.assertEqual(guard_int(s0), 2)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
s0 = create_symint(shape_env, 2)
self.assertTrue(expect_true(s0 == 2))
self.assertEqual(len(shape_env.guards), 0)
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
"""[Eq(s97, 2)]""",
)
def test_sym_int(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = sym_int(a0)
self.assertEqual(r, 5)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")
a1 = create_symint(shape_env, 7)
r = sym_int(a1 / 2)
self.assertEqual(guard_int(r), 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
)
a3 = create_symint(shape_env, 3)
r = sym_int(2.0 * torch.sym_float(a3))
self.assertEqual(guard_int(r), 6)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
)
def test_sym_log2(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 4)
r = torch._sym_log2(a0)
self.assertEqual(r, 2.0)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
)
def test_sym_sqrt(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 4)
r = torch._sym_sqrt(a0)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
)
def test_sym_floor(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = math.floor(a0 / 2)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
)
r = math.floor(3.0 * a0)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
)
def test_sym_trunc(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = math.trunc(a0 / 2)
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
)
r = torch.sym_int(torch.sym_sqrt(a0))
self.assertEqual(r, 2)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
)
def test_sym_ceil(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = math.ceil(a0 / 2)
self.assertEqual(r, 3)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
)
r1 = 3.0 * a0
r = math.floor(r1)
self.assertEqual(r, 15)
self.assertIsInstance(r, torch.SymInt, msg=type(r))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
)
def test_sym_ite(self):
shape_env = ShapeEnv()
t = create_symint(shape_env, 5)
f = create_symint(shape_env, 4)
b1 = True
r1 = torch.sym_ite(b1, t, f)
self.assertTrue(r1 is t)
b2 = False
r2 = torch.sym_ite(b2, t, f)
self.assertTrue(r2 is f)
b3 = t == 5
r3 = torch.sym_ite(b3, t, f)
self.assertEqual(len(shape_env.guards), 0)
self.assertEqual(r3, 5)
self.assertEqual(type(t), type(r3))
self.assertExpectedInline(
str(shape_env.guards[0][0]),
"""Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
)
b4 = f == 5
r4 = torch.sym_ite(b4, t, f)
self.assertEqual(len(shape_env.guards), 1)
self.assertEqual(r4, 4)
self.assertEqual(type(f), type(r4))
self.assertExpectedInline(
str(shape_env.guards[1][0]),
"""Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
)
def test_tracing_sym_ite(self):
def f(x):
b = x.shape[0] == 5
ret = torch.sym_ite(b, x.shape[0], x.shape[1])
return ret
gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
self.assertEqual(len(gm.shape_env.guards), 0)
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x_1):
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
eq = sym_size_int == 5
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
return sym_ite""",
)
r1 = gm(torch.ones(4, 5))
self.assertIsInstance(r1, int)
self.assertEqual(r1, 5)
r2 = gm(torch.ones(5, 4))
self.assertIsInstance(r2, int)
self.assertEqual(r2, 5)
def test_int_conversion(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
int(a0)
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
def test_data_dependent_guard(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
def test_data_dependent_guard_propagate_real_tensors(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
self.assertEqual(bool(s0 == 0), True)
def test_expect_true_basic(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
i0_sym = i0.node.expr
# This doesn't error
self.assertTrue(expect_true(i0 == 0))
# This generates a deferred runtime assert via replacement
self.assertEqual(shape_env.replacements[i0_sym], 0)
# After expecting true, guards now resolve given the runtime assert
bool(i0 == 0)
def test_expect_true_with_s0(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 5)
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(i0 < s0))
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
"""[u0 < s97]""",
)
self.assertTrue(i0 < s0)
self.assertTrue(i0 != s0)
self.assertFalse(i0 > s0)
self.assertFalse(i0 >= s0)
def test_expect_true_prefer_later(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
i1 = shape_env.create_unbacked_symint()
i1_sym = i1.node.expr
self.assertTrue(expect_true(i0 + i1 == 10))
# Importantly, this is put in i1, not i0!
self.assertExpectedInline(
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
"""[Eq(u0 + u1, 10)]""",
)
self.assertTrue(i0 + i1 == 10)
# NB: We currently don't support deriving that we can substitute
# i0 + i1 with 10; maybe we should, but this means our rewriting
# system is no longer confluent (it's probably OK though, because
# you're unlikely to get other equalities like this on the
# unbacked SymInts.)
def test_unbacked_substitution(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
i1 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i0)
_constrain_range_for_size(i1)
self.assertTrue(expect_true(i0 == i1 * 4))
self.assertExpectedInline(str(i0), """u0""")
i2 = shape_env.create_unbacked_symint()
i3 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i2)
_constrain_range_for_size(i3)
self.assertTrue(expect_true(i2 * 4 == i3))
self.assertExpectedInline(str(i3), """u3""")
def test_avoid_unbacked_substitution(self):
shape_env = ShapeEnv()
i0 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i0)
i1 = shape_env.create_unbacked_symint()
_constrain_range_for_size(i1)
self.assertTrue(expect_true(i0 == 10 - i1))
self.assertExpectedInline(str(i0), """u0""")
def test_expect_true_double_digits(self):
shape_env = ShapeEnv()
ia = [shape_env.create_unbacked_symint() for _ in range(11)] # allocate 10
self.assertEqual(str(ia[-1]), "u10")
self.assertTrue(expect_true(sum(ia) == 20))
self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)
def test_expect_true_refine_range(self):
shape_env = ShapeEnv()
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(rel(i0)))
self.assertTrue(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertFalse(statically_known_true(i0 != 5))
self.assertFalse(statically_known_true(i0 != 6))
self.assertTrue(statically_known_true(i0 > 4))
self.assertTrue(statically_known_true(i0 >= 5))
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = shape_env.create_unbacked_symint()
self.assertTrue(expect_true(rel(i0)))
self.assertFalse(statically_known_true(i0 != 2))
self.assertFalse(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertTrue(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 < 4))
self.assertTrue(statically_known_true(i0 <= 5))
def test_guard_refine_range(self):
shape_env = ShapeEnv()
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 10, duck=False)
self.assertTrue(bool(rel(i0)))
self.assertTrue(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertFalse(statically_known_true(i0 != 5))
self.assertFalse(statically_known_true(i0 != 6))
self.assertTrue(statically_known_true(i0 > 4))
self.assertTrue(statically_known_true(i0 >= 5))
for i, rel in enumerate(
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 2, duck=False)
self.assertFalse(bool(rel(i0)))
self.assertFalse(statically_known_true(i0 != 3))
self.assertFalse(statically_known_true(i0 != 4))
self.assertTrue(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 != 6))
self.assertTrue(statically_known_true(i0 <= 4))
self.assertTrue(statically_known_true(i0 < 5))
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 2, duck=False)
self.assertTrue(bool(rel(i0)))
self.assertFalse(statically_known_true(i0 != 2))
self.assertFalse(statically_known_true(i0 != 3))
self.assertTrue(statically_known_true(i0 != 4))
self.assertTrue(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 < 4))
self.assertTrue(statically_known_true(i0 <= 3))
for i, rel in enumerate(
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
):
with self.subTest(f"i = {i}"):
i0 = create_symint(shape_env, 10, duck=False)
self.assertFalse(bool(rel(i0)))
self.assertTrue(statically_known_true(i0 != 2))
self.assertTrue(statically_known_true(i0 != 3))
self.assertFalse(statically_known_true(i0 != 4))
self.assertFalse(statically_known_true(i0 != 5))
self.assertTrue(statically_known_true(i0 >= 4))
self.assertTrue(statically_known_true(i0 > 3))
def test_mul_int_oo_nan(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 5, duck=False)
s1 = create_symint(shape_env, 6, duck=False)
s2 = create_symint(shape_env, 5, duck=False)
bool(s0 * (s1 // s0) == s2)
def test_non_overlapping_and_dense(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 5)
r = torch.empty_strided((a0, 7), (1, a0), device="meta")
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
def test_non_overlapping_and_dense_unbacked(self):
shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
torch._check_is_size(u0)
cf = torch.ops.aten.is_non_overlapping_and_dense.default
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
Max = torch.sym_max
# NB: This only works because we're able to determine this tensor is
# contiguous. transpose(0, 1) makes it stop working
self.assertTrue(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
device="meta",
)
)
)
def test_sympy_optimized_add_binary_search(self):
import sympy
from torch.fx.experimental.sym_node import _binary_search_insert_arg
a = sympy.Symbol("a")
b = sympy.Symbol("b")
c = sympy.Symbol("c")
args = []
args = _binary_search_insert_arg([], b)
self.assertEqual(args, [b])
self.assertEqual(_binary_search_insert_arg(args, b), None)
args = _binary_search_insert_arg(args, a)
self.assertEqual(args, [a, b])
self.assertEqual(_binary_search_insert_arg(args, b), None)
self.assertEqual(_binary_search_insert_arg(args, a), None)
args = _binary_search_insert_arg(args, c)
self.assertEqual(args, [a, b, c])
self.assertEqual(_binary_search_insert_arg(args, a), None)
self.assertEqual(_binary_search_insert_arg(args, b), None)
self.assertEqual(_binary_search_insert_arg(args, c), None)
a1 = sympy.Symbol("a1")
a2 = sympy.Symbol("a2")
args = _binary_search_insert_arg(args, a1)
self.assertEqual(args, [a, a1, b, c])
args = _binary_search_insert_arg(args, a2)
self.assertEqual(args, [a, a1, a2, b, c])
c1 = sympy.Symbol("c1")
args = _binary_search_insert_arg(args, c1)
self.assertEqual(args, [a, a1, a2, b, c, c1])
# insert to front
_a = sympy.Symbol("_a")
args = _binary_search_insert_arg(args, _a)
self.assertEqual(args, [_a, a, a1, a2, b, c, c1])
def test_floor_clean_div_axioms(self):
# Test that if we add an axiom that have FloorDiv, after which the
# shapeEnv changed such that it can be simplified it to CleanDiv, then
# We still correctly replace CleanDiv with the axiom value of FloorDiv.
shape_env = ShapeEnv()
a = shape_env.create_unbacked_symint()
shape_env.guard_or_defer_runtime_assert((a // 3 == 1).node.expr, " test")
from sympy import Eq
test1 = Eq(FloorDiv(a.node.expr, 3), 1)
test2 = Eq(CleanDiv(a.node.expr, 3), 1)
self.assertTrue(shape_env.evaluate_expr(test1))
self.assertEqual(shape_env._maybe_evaluate_static(test2), None)
# After this FloorDiv(a, 3) is simplified to CleanDiv(a, 3)
shape_env.guard_or_defer_runtime_assert(Eq(Mod(a, 3), 0), " test")
self.assertEqual(test2, shape_env.simplify(test1))
self.assertTrue(shape_env.evaluate_expr(test1))
self.assertTrue(shape_env.evaluate_expr(test2))
def test_sympy_optimized_add(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 2)
s1 = create_symint(shape_env, 3)
s2 = create_symint(shape_env, 4)
sum = s0 + s1
self.assertTrue(sum.node._optimized_summation)
def assert_optimized(sym):
self.assertTrue(sym.node._optimized_summation)
def assert_not_optimized(sym):
self.assertFalse(getattr(sym.node, "_optimized_summation", False))
assert_optimized(sum)
# add duplicate symbol
assert_not_optimized(sum + s0)
# add constant.
assert_not_optimized(sum + 1)
# add new unique symbol, should maintain _optimized_summation property.
assert_optimized(sum + s2)
assert_optimized(s2 + sum)
# add x + (a+b) with no _optimized_summation on the rhs sum.
a = create_symint(shape_env, 10)
b = create_symint(shape_env, 11)
two_sum = torch.sym_sum([a, b])
assert_not_optimized(two_sum)
assert_optimized(sum + two_sum)
# adding two expressions of length >2 that are _optimized_summation.
a = s0 + s1 + s2
s3 = create_symint(shape_env, 10)
s4 = create_symint(shape_env, 20)
s5 = create_symint(shape_env, 30)
b = s3 + s4 + s5
assert_optimized(a)
assert_optimized(b)
assert_not_optimized(a + b)
assert_not_optimized(b + a)
assert_not_optimized(b + a + b)
def test_max_of_unique_summation_opt(self):
shape_env = ShapeEnv()
s0 = shape_env.create_unbacked_symint()
s1 = shape_env.create_unbacked_symint()
s2 = shape_env.create_unbacked_symint()
s3 = shape_env.create_unbacked_symint()
s4 = shape_env.create_unbacked_symint()
s5 = shape_env.create_unbacked_symint()
s7 = shape_env.create_unbacked_symint()
def assert_optimized(sym):
self.assertTrue(sym.node.expr.unique_summations_symbols is not None)
def assert_not_optimized(sym):
getattr(sym.node.expr, "unique_summations_symbols", None)
mx1 = torch.sym_max(s0, s1)
assert_not_optimized(mx1)
mx2 = torch.sym_max(s0 + s1, s2 + s3)
assert_optimized(mx2)
mx3 = torch.sym_max(mx2, s4 + s5)
assert_optimized(mx3)
assert_optimized(torch.sym_max(s4 + s5, mx2))
assert_not_optimized(torch.sym_max(mx3, s7))
assert_not_optimized(torch.sym_max(mx3, 10))
assert_not_optimized(torch.sym_max(mx3, s3 + s7))
assert_not_optimized(torch.sym_max(mx3, s7 * 2))
def test_sym_max_multi_max_simplify(self):
shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
self.assertTrue(
statically_known_true(
torch.sym_max(1, torch.sym_max(257, u0)) == torch.sym_max(257, u0)
)
)
def test_numpy_sym_max(self):
self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)
def test_numpy_sym_min(self):
self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5)
self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0)
self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0)
def test_debug_has_internal_overlap_unbacked(self):
shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()
torch._check_is_size(u0)
cf = torch._debug_has_internal_overlap
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2)
Max = torch.sym_max
self.assertEqual(
cf(
torch.empty_strided(
(2, 3, 1, u0),
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
device="meta",
)
),
2,
)
# Wobbling these to zero is OK too
self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2)
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2)
def test_specialize_zero_one(self):
shape_env = ShapeEnv(specialize_zero_one=True)
a0 = create_symint(shape_env, 5)
assert a0 != 1
self.assertEqual(len(shape_env.guards), 0)
shape_env = ShapeEnv(specialize_zero_one=False)
a0 = create_symint(shape_env, 5)
assert a0 != 1
self.assertEqual(len(shape_env.guards), 1)
def test_duck_shape(self):
shape_env = ShapeEnv(duck_shape=True)
a0 = create_symint(shape_env, 5)
a1 = create_symint(shape_env, 5)
assert a0 == a1
self.assertEqual(len(shape_env.guards), 0)
shape_env = ShapeEnv(duck_shape=False)
a0 = create_symint(shape_env, 5)
a1 = create_symint(shape_env, 5)
assert a0 == a1
self.assertEqual(len(shape_env.guards), 1)
def test_int_bool(self):
# See https://github.com/pytorch/pytorch/issues/95981
shape_env = ShapeEnv(duck_shape=True)
a0 = create_symint(shape_env, 5)
assert a0
self.assertEqual(len(shape_env.guards), 0)
def test_symint_as_scalar(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
sym_int_encountered = False
class TestSymInt(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
assert func == torch.ops.aten.add.Tensor
nonlocal sym_int_encountered
# WARNING: do not do identity tests on the outer
# SymInt/SymFloat, they are NOT STABLE
sym_int_encountered = kwargs["alpha"].node is a0.node
kwargs["alpha"] = 0
return func(*args)
x = torch.rand([4, 4])
with TestSymInt():
y = torch.add(x, x, alpha=a0)
self.assertTrue(sym_int_encountered)
def test_deepcopy(self):
shape_env = ShapeEnv()
a0 = create_symint(shape_env, 2)
assert a0 < 4
new_shape_env = copy.deepcopy(shape_env)
self.assertEqual(len(new_shape_env.guards), 1)
def test_print_readable_with_symints(self):
def f(a, b):
dim0 = a.shape[0] + b.shape[0]
dim1 = a.shape[1] + b.shape[1]
d = a.new_empty(dim0, dim1)
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
return d
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
out = fx_g.print_readable(print_output=False)
self.assertExpectedInline(
out.strip(),
"""\
class f(torch.nn.Module):
def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
# No stacktrace found for following nodes
sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None
return (getitem, getitem_1)""", # noqa: B950
)
def test_statically_known_true(self):
shape_env = ShapeEnv()
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
# Statically known true
self.assertTrue(statically_known_true(True))
self.assertTrue(statically_known_true(s2 == s2))
self.assertTrue(statically_known_true(s2 * s3 > s3))
self.assertTrue(statically_known_true(s3 * s4 > s4))
self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))
# Statically known false
self.assertFalse(statically_known_true(False))
self.assertFalse(statically_known_true(s3 * s4 <= s4))
self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))
# True for hints, but not known statically
self.assertFalse(statically_known_true(s2 + s2 == s4))
self.assertFalse(statically_known_true(s4 % s2 == 0))
self.assertFalse(statically_known_true(s2 != s3))
self.assertFalse(statically_known_true(s3 * s4 > s2))
# False for hints, but not known statically
self.assertFalse(statically_known_true(s2 == s3))
self.assertFalse(statically_known_true(s2 > s3))
self.assertFalse(statically_known_true(s3 + s3 == s4))
# No guards should be generated
self.assertEqual(len(shape_env.guards), 0)
def test_statically_known_false(self):
shape_env = ShapeEnv()
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
# Statically known true
self.assertFalse(statically_known_false(True))
self.assertFalse(statically_known_false(s2 == s2))
self.assertFalse(statically_known_false(s2 * s3 > s3))
self.assertFalse(statically_known_false(s3 * s4 > s4))
self.assertFalse(statically_known_false((s3 + s3) % 2 == 0))
# Statically known false
self.assertTrue(statically_known_false(False))
self.assertTrue(statically_known_false(s3 * s4 <= s4))
self.assertTrue(statically_known_false((s3 + s3) % 2 == 1))
# True for hints, but not known statically
self.assertFalse(statically_known_false(s2 + s2 == s4))
self.assertFalse(statically_known_false(s4 % s2 == 0))
self.assertFalse(statically_known_false(s2 != s3))
self.assertFalse(statically_known_false(s3 * s4 > s2))
# False for hints, but not known statically
self.assertFalse(statically_known_false(s2 == s3))
self.assertFalse(statically_known_false(s2 > s3))
self.assertFalse(statically_known_false(s3 + s3 == s4))
# No guards should be generated
self.assertEqual(len(shape_env.guards), 0)
def test_ephemeral_source_simplification(self):
from torch._dynamo.source import EphemeralSource
# For full robustness, ensure the ephemeral source symbols are simplified out regardless
# of construction order or check order.
for construct_ephemeral_first, x_first_in_check in itertools.product(
[False, True], [False, True]
):
shape_env = ShapeEnv()
shape = (5, 10)
dynamic_dims = [DimDynamic.DYNAMIC for _ in shape]
x = create_symbolic_tensor(
"x",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if construct_ephemeral_first else None),
dynamic_dims=dynamic_dims,
)
y = create_symbolic_tensor(
"y",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if not construct_ephemeral_first else None),
dynamic_dims=dynamic_dims,
)
t_with_ephemeral = x if construct_ephemeral_first else y
def _get_ephemeral_source_symbols(t):
return [
s.node.expr
for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
if isinstance(s, torch.SymInt)
and s.node.expr in shape_env.var_to_sources
and any(
source.is_ephemeral()
for source in shape_env.var_to_sources[s.node.expr]
)
]
# these checks should simplify out the ephemeral symbols, regardless of the
# ordering x == y or y == x
self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
if x_first_in_check:
torch._check(x.size() == y.size())
torch._check(x.stride() == y.stride())
torch._check(x.storage_offset() == y.storage_offset())
else:
torch._check(y.size() == x.size())
torch._check(y.stride() == x.stride())
torch._check(y.storage_offset() == x.storage_offset())
self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
def test_ephemeral_source_unified_with_non_ephemeral_source(self):
from torch._dynamo.source import EphemeralSource
for construct_ephemeral_first in (False, True):
shape_env = ShapeEnv()
shape = (5, 10)
# use duck sizing here to ensure symbol reuse across x and y
duck_dims = [DimDynamic.DUCK for _ in shape]
x = create_symbolic_tensor(
"x",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if construct_ephemeral_first else None),
dynamic_dims=duck_dims,
)
y = create_symbolic_tensor(
"y",
torch.randn(*shape),
shape_env,
source=(EphemeralSource() if not construct_ephemeral_first else None),
dynamic_dims=duck_dims,
)
# regardless of construction order, non-ephemeral sources should be preferred
# first in the var_to_sources list for potential guarding later on
for source_list in shape_env.var_to_sources.values():
self.assertFalse(source_list[0].is_ephemeral())
self.assertEqual(x.size(), y.size())
self.assertEqual(x.stride(), y.stride())
self.assertEqual(x.storage_offset(), y.storage_offset())
def test_tensor_factory_with_symint(self):
args = list(range(0, 3))
expected = torch.tensor(args)
shape_env = ShapeEnv()
sym_args = [create_symint(shape_env, i) for i in args]
# test tensor factories
for dt in all_types_and(torch.half, torch.bfloat16):
res = torch.tensor(sym_args, dtype=dt)
self.assertEqual(res, expected, exact_dtype=False)
# test legacy tensor factories
legacy_ctors = [
torch.Tensor,
torch.LongTensor,
torch.DoubleTensor,
torch.FloatTensor,
torch.IntTensor,
torch.ShortTensor,
torch.HalfTensor,
torch.ByteTensor,
]
for Tensor in legacy_ctors:
res = Tensor(sym_args)
self.assertEqual(res, expected, exact_dtype=False)
def test_backed_size_oblivious_01_spec(self):
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
@torch.compile(dynamic=True, fullgraph=True)
def f(a, b):
if guard_size_oblivious(a.size(0) == 1):
return b * 10
else:
return b * 20
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
# always go to the >= 2 branch.
self.assertEqual(
f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
)
def test_baddbmm_symint(self):
from torch._subclasses.fake_tensor import FakeTensorMode
shape_env = ShapeEnv()
fake_mode = FakeTensorMode(shape_env=shape_env)
B, M, K, N = [shape_env.create_unbacked_symint() for _ in range(4)]
with fake_mode:
A = torch.empty((B, M, K), device="meta")
Bmat = torch.empty((B, K, N), device="meta")
bias3 = torch.empty((B, M, N), device="meta")
_ = torch.baddbmm(bias3, A, Bmat)
@skipIfTorchDynamo(
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestSymNumberMagicMethods(TestCase):
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
# Helper function
# NB: don't use one as that will get specialized
# TODO: We don't have to circuitously create the float, can just
# create a symfloat directly
seed_node = (create_symint(shape_env, 2) / 2.0).node
bool_seed_node = (create_symint(shape_env, 2) == 2).node
def get_sym_inp(inp):
# NB: this must come before int
if isinstance(inp, bool):
return torch.SymBool(to_node(bool_seed_node, inp))
elif isinstance(inp, int):
return torch.SymInt(to_node(seed_node, inp))
else:
return torch.SymFloat(to_node(seed_node, inp))
if fn == "float_pow":
if inp1 < 0:
return
if fn == "pow_by_natural":
if isinstance(inp1, float) or isinstance(inp2, float):
return
if inp2 < 0:
return
def maybe_xfail(inp1, inp2):
if fn == "sym_sqrt" and inp1 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
elif (
fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
and inp2 == 0
):
# ZeroDivisionError: division by zero
return self.assertRaises((ZeroDivisionError,))
elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
# ZeroDivisionError: 0.0 cannot be raised to a negative power
return self.assertRaises((ZeroDivisionError,))
elif (
# TODO: dear catastrophe waitress,
# this doesn't work
fn in ["float_pow", "pow_by_natural"]
and inp1 < 0
and (
type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
)
and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
):
# Complex result, which we do not support:
# TypeError: Cannot convert complex to float
return self.assertRaises((RuntimeError,))
elif fn in ("lshift", "rshift") and not (
isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
):
# TypeError: unsupported operand type(s)
return self.assertRaises((TypeError,))
elif fn in ("lshift", "rshift") and inp2 < 0:
# ValueError: math domain error
return self.assertRaises((ValueError,))
else:
return contextlib.nullcontext()
lambda_apply = method_to_operator(fn)
def guard_fn(v):
if type(v) in (SymBool, bool):
return guard_bool(v)
elif type(v) in (SymFloat, float):
return guard_float(v)
else: # SymInt, int
return guard_int(v)
# Get reference result
with maybe_xfail(inp1, inp2):
if is_unary_fn:
ref_out = lambda_apply(inp1)
else:
ref_out = lambda_apply(inp1, inp2)
# Symified first arg
sym_inp1 = get_sym_inp(inp1)
with maybe_xfail(sym_inp1, inp2):
if is_unary_fn:
out = lambda_apply(sym_inp1)
else:
out = lambda_apply(sym_inp1, inp2)
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
if is_unary_fn:
return
# Symified second arg
sym_inp2 = get_sym_inp(inp2)
with maybe_xfail(inp1, sym_inp2):
out = lambda_apply(inp1, sym_inp2)
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
# Symified both args
with maybe_xfail(sym_inp1, sym_inp2):
out = lambda_apply(sym_inp1, sym_inp2)
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
out = guard_fn(out)
self.assertEqual(out, ref_out)
@parametrize("fn", list(sym_node.magic_methods.keys()))
def test_bool_method(self, fn):
# sym_ite has its own tests
if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
self.skipTest(f"{fn} is non-bool")
is_unary_fn = fn in sym_node.unary_methods
shape_env = ShapeEnv()
self._do_test(fn, True, False, shape_env, is_unary_fn)
@parametrize("fn", list(sym_node.magic_methods.keys()))
@parametrize("first_type", ["int", "float"])
@parametrize("second_type", ["int", "float"])
def test_method(self, fn, first_type, second_type):
if first_type == "float":
# TODO: Hmm, this looks like we skip all floats
self.skipTest(f"{fn} is not a float magic method")
if (
first_type == "int" or second_type == "int"
) and fn in sym_node.only_float_magic_methods:
self.skipTest(f"{fn} is not an int method")
if second_type == "float" and fn in ["mod"]:
self.skipTest(f"{fn} only handles int")
if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"):
self.skipTest(f"{fn} is a bitwise op, only handles int")
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
# Second argument is ignored for unary function. So only run for one type
if is_unary_fn and second_type == "float":
self.skipTest(f"{fn} is unary and already tested")
if fn in sym_node.bool_magic_methods:
self.skipTest(f"{fn} is bool")
# Only floats here since these will be converted to int if necessary.
# We also ignore complex and bool.
values = (
0.0,
1.0,
0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error
)
neg_values = tuple(-x for x in values)
for inp1, inp2 in itertools.chain(
itertools.product(values, values),
itertools.product(values, neg_values),
itertools.product(neg_values, values),
itertools.product(neg_values, neg_values),
):
if first_type == "int":
inp1 = int(inp1)
if second_type == "int":
inp2 = int(inp2)
shape_env = ShapeEnv()
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
def get_constant_bool(self, val):
return SymBool(torch._C._get_constant_bool_symnode(val))
@unittest.expectedFailure
def test_symint_hashing(self):
shape_env = ShapeEnv()
hash(create_symint(shape_env, 3))
def test_symnode_hashing(self):
shape_env = ShapeEnv()
# These all trigger specialization when hashed
hash(create_symbool(shape_env, True))
# We should be passing in float here, but create_symbol currently
# only supports int
hash(create_symfloat(shape_env, 3.0))
# NestedInt (SymInt), constant SymBool, SymNode are hashable
j1 = torch._C._get_nested_int(1, 1)
j1_copy = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(2, 1)
t = self.get_constant_bool(True)
t_copy = self.get_constant_bool(True)
f = self.get_constant_bool(False)
n = create_symint(shape_env, 3).node
m = self.get_constant_bool(True).node
self.assertIs(j1 == j1_copy, True)
self.assertEqual(hash(j1), hash(j1_copy))
self.assertIs(j1 == j2, False)
self.assertNotEqual(hash(j1), hash(j2))
self.assertIs(t == t_copy, True)
self.assertEqual(hash(t), hash(t_copy))
self.assertIs(t == f, False)
self.assertNotEqual(hash(t), hash(f))
hash(n)
hash(m)
def test_symint_deepcopy(self):
shape_env = ShapeEnv()
symnodes = (torch._C._get_nested_int(1, 1),)
deepcopied_symnodes = copy.deepcopy(symnodes)
self.assertEqual(symnodes, deepcopied_symnodes)
def test_non_symbolic_symnode(self):
j1 = torch._C._get_nested_int(1, 1)
j2 = torch._C._get_nested_int(1, 1)
j3 = torch._C._get_nested_int(3, 1)
self.assertIsInstance(j1, torch.SymInt)
self.assertNotIsInstance(j1, int)
with self.assertRaisesRegex(
RuntimeError, "add not supported by NestedIntSymNode"
):
j1 + 3
self.assertFalse(j1 == 3)
with self.assertRaisesRegex(RuntimeError, "indeterminate"):
self.assertFalse(3 >= j2)
self.assertIs(j1 == j1, True)
self.assertIs(j1 == j2, True)
self.assertIs(j1 == j3, False)
self.assertIs(j1 != j3, True)
self.assertIs(j1 != j2, False)
x = self.get_constant_bool(True)
#
# Unary
#
# op(constant SymBool)
self.assertIs(x.__sym_not__(), False)
#
# Binary
#
# op(constant SymBool, bool)
# op(constant SymBool, constant SymBool)
# op(bool, constant SymBool)
self.assertIs(operator.and_(x, True), True)
self.assertIs(operator.and_(x, x), True)
self.assertIs(operator.and_(True, x), True)
# op(symbolic SymBool, constant Symbool)
# op(constant SymBool, symbolic Symbool)
shape_env = ShapeEnv()
a = create_symint(shape_env, 2)
b = create_symint(shape_env, 2)
c = a == b # symbolic SymBool
d = self.get_constant_bool(True)
e = operator.and_(c, d)
f = operator.and_(d, c)
self.assertTrue(is_symbolic(e))
self.assertTrue(is_symbolic(f))
self.assertIs(e.node.guard_bool("", 0), True)
self.assertIs(f.node.guard_bool("", 0), True)
# Comparing sizes
sz1 = torch.Size([j1, j1, j1])
sz2 = torch.Size([j1, j1, j1])
self.assertIs(sz1 == sz2, True)
sz1 = torch.Size([3, j1, 4])
sz2 = torch.Size([3, j2, 4])
self.assertIs(sz1 == sz2, True)
self.assertIs(sz1 != sz2, False)
def test_stride_symnode(self):
shape_env = ShapeEnv()
# check everything static
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.STATIC,
DimDynamic.STATIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
self.assertTrue(all(isinstance(size, int) for size in t.size()))
self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))
# check dynamic size but static dims
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
# Expect stride to be inferred
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, torch.SymInt))
self.assertTrue(isinstance(s1, torch.SymInt))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(s1 == s2)
self.assertEqual(s3, 1)
# Check dynamic stride but static dims
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.STATIC,
DimDynamic.STATIC,
],
dynamic_strides=[
DimDynamic.DYNAMIC,
DimDynamic.INFER_STRIDE,
],
)
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, int))
self.assertTrue(isinstance(s1, int))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(isinstance(s3, int))
# Check dynamic sizes and dims, and ensure different symbol
t = create_fake_tensor_with_dynamic_size(
torch.ones(3, 6),
shape_env,
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.DYNAMIC,
DimDynamic.INFER_STRIDE,
],
)
s0, s1 = t.size()
s2, s3 = t.stride()
self.assertTrue(isinstance(s0, torch.SymInt))
self.assertTrue(isinstance(s1, torch.SymInt))
self.assertTrue(isinstance(s2, torch.SymInt))
self.assertTrue(isinstance(s3, int))
self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
instantiate_parametrized_tests(TestSymNumberMagicMethods)
class TestFloorDiv(TestCase):
@staticmethod
def python_floordiv(x, y):
return x // y
@staticmethod
def torch_floordiv(x, y):
# Note: we fully evaluate here since FloorDiv might not always do
# that.
shape_env = ShapeEnv()
return shape_env.evaluate_expr(FloorDiv(x, y))
@staticmethod
def yield_test_cases(values, negate=True):
for x, y in values:
yield (x, y)
if negate:
yield (-x, y)
yield (x, -y)
yield (-x, -y)
def test_floordiv_float_int(self):
values = ((7, 2),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_div_by_one(self):
values = ((2, 1),)
for x, y in TestFloorDiv.yield_test_cases(values):
self.assertEqual(
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
)
def test_floordiv_simplify(self):
# Tests how we simplify or evaluate FloorDiv without free variables
shape_env = ShapeEnv()
result = 21
exprs = (7 * FloorDiv(6, 2),)
for expr in exprs:
self.assertEqual(expr, result)
self.assertEqual(expr.doit(deep=False), result)
self.assertEqual(expr.doit(deep=True), result)
self.assertEqual(sympy.simplify(expr), result)
self.assertEqual(shape_env.simplify(expr), result)
self.assertEqual(shape_env.evaluate_expr(expr), result)
def test_floordiv_assumptions(self):
cases = (
sympy.Symbol("i1", integer=True),
sympy.Symbol("i2", integer=True),
)
for base, divisor in itertools.product(cases, repeat=2):
def op():
return FloorDiv(base, divisor)
def is_complex(x):
return x.is_integer is False and x.is_real is False and x.is_complex
if is_complex(base) or is_complex(divisor):
self.assertRaisesRegex(
TypeError,
(
r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
r" expected integer or real"
),
op,
)
continue
op = op()
# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
# always returns an integer 1 when both args are the same object.
# This even works for Symbols with no assumptions specified.
if base is divisor:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
elif base.is_integer and divisor.is_integer:
self.assertTrue(op.is_integer)
self.assertTrue(op.is_real)
else:
self.assertEqual(op.is_integer, None)
self.assertTrue(op.is_real)
class TestDimConstraints(TestCase):
@skipIfTorchDynamo("mark_dynamic not supported")
def test_simplify_max_1_0(self):
x = torch.rand(10)
torch._dynamo.mark_dynamic(x, 0, max=20, min=5)
@torch.compile(fullgraph=True)
def func(x, v):
# test that statically_known_true
if (v == 0 or v == 1) and not statically_known_true(
max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2)
):
raise AssertionError("error")
if max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2):
return x * 400
else:
return (x * 10) * 100
# testing that this does not throw constraint violation error.
self.assertEqual(func(x, 1), x * 400)
self.assertEqual(func(x, 0), x * 400)
def test_dim_constraints_reduce_congruences_simple(self):
from sympy import Symbol
s = Symbol("s", positive=True, integer=True)
dim_constraints = DimConstraints({}, {}, set(), {})
dim_constraints._congruences[s] = {
(s / 2) % 2,
(s / 2) % 8,
(s / 2) % 4,
s % 2,
((s / 16) + 2) % 4,
}
congruences = dim_constraints._reduce_congruences()
self.assertEqual(congruences[s], {(s + 32) % 64})
def test_dim_constraints_reduce_inequalities_simple(self):
from sympy import Eq, Interval, Ne, Symbol
from sympy.solvers.inequalities import reduce_inequalities
s = Symbol("s", positive=True, integer=True)
exprs = {
s >= 2,
Ne(8 * s, 16),
Ne(s / 2, 1),
Ne(16 * s, 32),
s < 16,
Ne(s, 2),
s / 2 < 16,
s / 2 > 1,
s / 2 >= 2,
Ne(3 * s / 2, 3),
}
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, Interval.Ropen(4, 16))
exprs.add(Eq(s / 2, 4))
solution = reduce_inequalities(exprs, s).as_set()
self.assertEqual(solution, {8})
def test_dim_constraints_reduce_inequalities_error(self):
from collections import defaultdict
from sympy import Symbol
from sympy.solvers.inequalities import reduce_inequalities
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
s0 = Symbol("s0", positive=True, integer=True)
exprs = {
4 * s0**3 - 4 * s0**2 + s0 <= 2147483647,
s0 >= 2,
s0**3 <= 2147483647,
s0 <= 2147483647,
}
answer = reduce_inequalities(exprs, s0)
symbol_to_source = defaultdict(list)
symbol_to_source[s0].append(
TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
)
dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
with self.assertRaisesRegex(
AssertionError,
"Unknown symbol.*created by constraints solver",
):
dcp.doprint(answer)
def test_dim_constraints_solve_full(self):
from sympy import Eq, Integer, Ne, Symbol
from torch._dynamo.source import (
LocalSource,
TensorProperty,
TensorPropertySource,
)
src0 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
)
src2 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
)
src3 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
)
src4 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
)
src1 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
)
src7 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
)
src5 = TensorPropertySource(
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
)
src8 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
)
src6 = TensorPropertySource(
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
)
src9 = TensorPropertySource(
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
)
src10 = TensorPropertySource(
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
)
src11 = TensorPropertySource(
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
)
src12 = TensorPropertySource(
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
)
s0 = Symbol("s0", positive=True, integer=True)
s1 = Symbol("s1", positive=True, integer=True)
s5 = Symbol("s5", positive=True, integer=True)
s6 = Symbol("s6", positive=True, integer=True)
symbol_to_source = {
s0: [src0, src2, src3, src4],
s1: [src1, src7],
s5: [src5, src8],
s6: [src6, src9, src10],
}
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
marked_dynamic = {s0, s1, s5, s6}
dim_constraints = DimConstraints(
symbol_to_source, var_to_val, marked_dynamic, {}
)
dim_constraints.add_equality(src2, s0)
dim_constraints.add_equality(src3, s0)
dim_constraints.add_equality(src4, s0)
dim_constraints.add_equality(src7, s1)
dim_constraints.add_equality(src8, s5)
dim_constraints.add_equality(src9, s6)
dim_constraints.add_equality(src10, s6)
dim_constraints.add_equality(src11, Integer(1))
dim_constraints.add_equality(src12, Integer(3))
dim_constraints.add(s1**2 <= 2147483647)
dim_constraints.add(32 * s1**2 <= 2147483647)
dim_constraints.add(s0 < 16)
dim_constraints.add(Eq(Mod(s1, 2), 0))
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 64
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
+ 1
> 1
)
dim_constraints.add(
128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 128
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
+ 1
> 1
)
dim_constraints.add(
256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 256
<= 2147483647
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1,
1,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
+ 1
> 1
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
dim_constraints.add(
Ne(
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * s0,
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Eq(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
+ 60 * (Mod(s0, 2)),
0,
)
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
0,
)
)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
0,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= 0
)
dim_constraints.add(
1
< 60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
)
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(Ne(16 * s0, 32))
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
dim_constraints.add(FloorDiv(s0, 2) >= 2)
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
dim_constraints.add(1 < FloorDiv(s0, 2))
dim_constraints.add(Ne(s0, 2))
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
>= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 0
)
dim_constraints.add(
Ne(
60
* (FloorDiv(s0, 2))
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120
* FloorDiv(s0, 2)
* FloorDiv(s0, (FloorDiv(s0, 2)))
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
)
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
0,
)
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
Ne(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20,
20,
)
)
dim_constraints.add(
Ne(
20
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
),
0,
)
)
dim_constraints.add(
Ne(
20
* (FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
)
)
- 20
* Mod(
1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 0
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
>= 2
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
> 1
)
dim_constraints.add(
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 20
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
)
dim_constraints.add(
Ne(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60,
60,
)
)
dim_constraints.add(
Ne(
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
)
)
dim_constraints.add(
Eq(
(FloorDiv((FloorDiv(s1, 2) - 1), 8))
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
)
)
- Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
- 2
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
1,
),
0,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
)
)
dim_constraints.add(Ne(8 * s0, 16))
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
>= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
<= 2147483647
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 2147483647
)
dim_constraints.add(FloorDiv(s0, 2) < 16)
dim_constraints.add(FloorDiv(s0, 2) > 1)
dim_constraints.add(
Ne(
90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1
> 1
)
dim_constraints.add(
60
* (FloorDiv(s0, (FloorDiv(s0, 2))))
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 1
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2)),
3 * (FloorDiv(s0, 2)),
)
)
dim_constraints.add(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
1
< 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
6,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 0
)
dim_constraints.add(
Ne(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 2147483647
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
<= 20480
)
dim_constraints.add(
Ne(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90,
0,
)
)
dim_constraints.add(
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 120
> 1
)
dim_constraints.add(
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 90
<= 20480
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 60
<= 20480
)
dim_constraints.add(
Ne(
240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 240,
0,
)
)
dim_constraints.add(Eq(6 * s5, 132))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
<= 2147483647
)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 64
> 1
)
dim_constraints.add(
62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
<= 2147483647
)
dim_constraints.add(
Ne(
62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 62
)
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
dim_constraints.add(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
<= 2147483647
)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
1,
)
)
dim_constraints.add(
Ne(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9,
0,
)
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 0
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
dim_constraints.add(
1
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576
)
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
dim_constraints.add(
Ne(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2)),
256,
)
)
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
4,
)
),
0,
)
)
dim_constraints.add(
Eq(
FloorDiv(s0, 2),
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
)
)
dim_constraints.add(
Eq(
FloorDiv(
(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2))
),
4,
),
FloorDiv(s0, 2),
)
)
dim_constraints.add(
Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
)
dim_constraints.add(
Eq(
64
* (
Mod(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 1,
4,
)
),
0,
)
)
dim_constraints.add(
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576 * (FloorDiv(s0, 2))
> 0
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
>= 1
)
dim_constraints.add(
Eq(
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 576,
256,
)
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
<= 2147483647
)
dim_constraints.add(
Ne(
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
<= 2147483647
)
dim_constraints.add(
Ne(
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9 * (FloorDiv(s0, 2)),
0,
)
)
dim_constraints.add(
1
< (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
)
dim_constraints.add(
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 9
> 1
)
dim_constraints.add(
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
+ 540
> 1
)
dim_constraints.add(s0 >= 2)
dim_constraints.add(s1 >= 2)
dim_constraints.add(s6 >= 2)
dim_constraints.add(s5 >= 2)
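        # Solving under the hints {s0: 8, s1: 96, s5: 22, s6: 21} is expected to specialize
        # s0, s1 and s5 (the static results below) and to keep only s6 dynamic, related
        # across c, d and e by the dynamic results below.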
dim_constraints.solve()
self.assertEqual(
dim_constraints._static_results,
{
"L['c'].size()[0] == 8",
"L['d'].size()[0] == 8",
"L['a'].size()[2] == 96",
"L['f'].size()[1] == 1",
"L['a'].size()[3] == 96",
"L['b'].size()[2] == 3",
"L['b'].size()[1] == 22",
"L['b'].size()[0] == 8",
"L['a'].size()[1] == 22",
"L['a'].size()[0] == 8",
},
)
self.assertEqual(
dim_constraints._dynamic_results,
{
"2 <= L['c'].size()[1]",
"L['d'].size()[1] == L['c'].size()[1]",
"L['e'].size()[1] == L['c'].size()[1]",
},
)
class TestGuardsExpressions(TestCase):
"""
Tests the guards-related methods used by the inductor FX graph cache.
"""
def test_guards_gt_lt(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 6)
s1 = create_symint(shape_env, 7)
s2 = create_symint(shape_env, 5)
guard_int(sym_int(s0 > 5))
guard_int(sym_int(s0 < 7))
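        # The two guards above record 5 < s0 and s0 < 7, so of the hints created above only
        # s0 == 6 should satisfy the produced guard expression evaluated below.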
guards = shape_env.produce_guards_expression([s0])
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)]))
def test_guards_float_print(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 3)
guard_bool(2 / s0 == 2 / 3)
guards = shape_env.produce_guards_expression([s0])
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_guard_or_true(self):
from torch.fx.experimental.symbolic_shapes import guard_or_true
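        # Rough semantics as exercised below (a sketch, not the full contract): for an
        # unbacked u0 = a.item(), guard_or_true(expr) returns the statically known value of
        # expr when there is one, and otherwise assumes True instead of raising a
        # data-dependent error, e.g.:
        #   guard_or_true(u0 == 1)   # -> True (cannot be decided, assume True)
        #   guard_or_true(u0 != u0)  # -> False (statically known to be false)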
def func(a, b):
x = a.item()
if guard_or_true(x == 1):
return b * 10
else:
return b * 20
# eager.
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
# compile with unbacked.
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func(a, b)
# always return b*10
self.assertEqual(
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
)
self.assertEqual(
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
)
        # Test that a statically known false expression works: x != x is always false.
def func2(a, b):
x = a.item()
if guard_or_true(x != x):
return b * 10
else:
return b * 20
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func2(a, b)
# always return b*20
self.assertEqual(
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
)
self.assertEqual(
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
)
# Test backed_size_oblivious
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
def func3(a, b):
if guard_or_true(a.size()[0] != 9):
return b * 10
else:
return b * 20
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
a = torch.rand(9, 2)
b = torch.rand(3, 4)
self.assertEqual(func3(a, b), b * 20)
self.assertEqual(compiled(a, b), b * 10)
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_guard_or_false(self):
from torch.fx.experimental.symbolic_shapes import guard_or_false
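        # Mirror of test_guard_or_true (a sketch): for an unbacked u0 = a.item(),
        # guard_or_false(expr) returns the statically known value of expr when there is one,
        # and otherwise assumes False instead of raising a data-dependent error, e.g.:
        #   guard_or_false(u0 == 1)   # -> False (cannot be decided, assume False)
        #   guard_or_false(u0 == u0)  # -> True (statically known to be true)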
def func(a, b):
x = a.item()
if guard_or_false(x == 1):
return b * 10
else:
return b * 20
# eager.
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
# compile with unbacked.
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func(a, b)
# always return b*20
self.assertEqual(
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
)
self.assertEqual(
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
)
# Test that statically known true works.
def func2(a, b):
x = a.item()
if guard_or_false(x == x):
return b * 10
else:
return b * 20
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
a = torch.tensor([1])
b = torch.tensor([1])
unbacked_func2(a, b)
# always return b*10
self.assertEqual(
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
)
self.assertEqual(
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
)
# Test backed_size_oblivious
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
def func3(a, b):
if guard_or_false(a.size()[0] == 9):
return b * 10
else:
return b * 20
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
a = torch.rand(9, 2)
b = torch.rand(3, 4)
self.assertEqual(func3(a, b), b * 10)
self.assertEqual(compiled(a, b), b * 20)
def test_guards_float_div(self):
shape_env = ShapeEnv()
s0 = create_symint(shape_env, 8)
s1 = create_symint(shape_env, 7)
guard_int(sym_int(s0 / 2.0))
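        # Dividing an int symbol by a Python float goes through a float conversion, so the
        # produced guard is expected to render roughly like math.trunc(float(s0) / 2.0) == 4
        # (the exact text may differ); the checks below only rely on the math.trunc( and
        # float( substrings.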
guards = shape_env.produce_guards_expression([s0])
self.assertIn("math.trunc(", guards)
self.assertIn("float(", guards)
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
def test_remove_symbols_without_guarding(self):
from torch._functorch.partitioners import _remove_symbols_without_guarding
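        # Roughly: _remove_symbols_without_guarding rebuilds the fake tensor with concrete
        # sizes/strides without adding guards on the original symbolic shape; the 4096
        # argument is presumably a fallback for symbols without hints. The assertions below
        # check that (s26, s49)/(s49, 1) become the hinted (5, 8)/(8, 1).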
shape_env = ShapeEnv()
x = create_fake_tensor_with_dynamic_size(
torch.randn(5, 8),
shape_env,
dynamic_sizes=[
DimDynamic.DYNAMIC,
DimDynamic.DYNAMIC,
],
dynamic_strides=[
DimDynamic.INFER_STRIDE,
DimDynamic.INFER_STRIDE,
],
)
self.assertEqual(f"{x.stride()}", "(s49, 1)")
self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])")
x_clean = _remove_symbols_without_guarding(x, 4096)
self.assertEqual(f"{x_clean.stride()}", "(8, 1)")
self.assertEqual(f"{x_clean.shape}", "torch.Size([5, 8])")
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
for node in graph.nodes:
if node.name == "arg3_1":
assert node.meta["val"].size()[0] == 2
return graph
class TestUnbacked(TestCase):
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_neq_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a):
torch._check(a.item() != 5)
return a.item() * 10
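        # The check above is data-dependent, so it cannot be decided at trace time; it is
        # expected to be emitted as a deferred runtime assert in the compiled graph, which is
        # why the second call (a == 5) fails at runtime rather than at compile time.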
func(torch.tensor([100]))
with self.assertRaises(RuntimeError):
func(torch.tensor([5]))
    # Test a situation where we generate a runtime assert, i.e. u1 == s1, and then
    # specialize s1 to a constant later on.
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_post_specialize_runtime_assert1(self, backend):
@torch.compile(dynamic=True, backend=backend)
def func(x, y):
u0 = y.item()
s0 = x.size()[0]
s1 = x.size()[1]
torch._check(u0 + s0 + s1 == 102)
assert s0 == 2
return x * 10
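        # With s0 specialized to 2 and s1 == 50 at runtime, the deferred assert
        # u0 + s0 + s1 == 102 reduces to u0 == 50, so the call with torch.tensor([51])
        # must fail at runtime.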
func(torch.rand(2, 50), torch.tensor([50]))
with self.assertRaises(RuntimeError):
func(torch.rand(2, 50), torch.tensor([51]))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass)
@parametrize("backend", ["inductor", "eager"])
def test_post_specialize_runtime_assert2(self, backend):
@torch.compile(dynamic=True, backend=backend)
def func(x, y):
u0 = y.item()
s0 = x.size()[0]
s1 = x.size()[1]
torch._check(u0 + s0 + s1 == 102)
return x * 10
func(torch.rand(2, 50), torch.tensor([50]))
with self.assertRaises(RuntimeError):
func(torch.rand(2, 50), torch.tensor([51]))
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_or_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a, b):
torch._check(operator.or_(a.item() == 5, b.item() == 5))
return a.item() * 10
func(torch.tensor([5]), torch.tensor([100]))
func(torch.tensor([100]), torch.tensor([5]))
def test_has_free_symbols(self):
self.assertFalse(has_free_symbols(sympy.S.true))
self.assertFalse(has_free_symbols(sympy.Max(1, 10, evaluate=False)))
self.assertFalse(has_free_symbols(sympy.sympify("1")))
self.assertFalse(has_free_symbols(sympy.sympify("1.1")))
self.assertTrue(has_free_symbols(sympy.sympify("a")))
self.assertTrue(has_free_symbols(sympy.sympify("a*2")))
self.assertTrue(has_free_symbols(sympy.sympify("a+b")))
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
def test_deferred_sym_eq_assert(self, backend):
@torch.compile(fullgraph=True, backend=backend)
def func(a, b):
torch._check(b.item() == 5)
return a * 10
func(torch.tensor([5]), torch.tensor([5]))
with self.assertRaises(RuntimeError):
func(torch.tensor([100]), torch.tensor([1]))
@torch._dynamo.config.patch("capture_scalar_outputs", True)
@parametrize("backend", ["inductor", "eager"])
@skipIfTorchDynamo("mark_unbacked is not traceable")
def test_deferred_with_unbacked_input(self, backend):
@torch.compile(fullgraph=True, dynamic=True, backend=backend)
def func(a, b):
torch._check(a.size()[0] == b.size()[0])
return a * 10
a = torch.rand(1, 1)
b = torch.rand(1, 1)
torch._dynamo.decorators.mark_unbacked(a, 0)
torch._dynamo.decorators.mark_unbacked(b, 0)
func(a, b)
with self.assertRaises(RuntimeError):
func(a, torch.rand(2, 1))
class TestUbackedOps(TestCase):
@fresh_cache()
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_reshape1(self):
cnt = CompileCounterWithBackend("inductor")
        # The reshape happens in place (no clone):
        # a 1-D tensor of size u1 is reshaped to (u0, u0), where u1 == u0**2.
def func(x, y):
f = y.item()
t1 = x.view((f, f))
t2 = x.reshape((f, f))
# TODO avoid _check_is_size here.
torch._check_is_size(f)
return t1 * 10, t2 * 10
compiled_func = torch.compile(
fullgraph=True,
backend=cnt,
dynamic=True,
)(func)
        # Create a non-contiguous tensor whose data are the first cnt even numbers
        # (0, 2, ..., 2*cnt - 2) and reshape it into (sqrt(cnt), sqrt(cnt)).
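        # For example, with cnt=4 (a concrete instance of the sketch above):
        #   torch.arange(8).as_strided((4,), (2,))  # data [0, 2, 4, 6], stride 2, non-contiguous
        # is reshaped with f = 2 into a (2, 2) view.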
def make_non_contiguous_tensor_and_test(cnt):
            # Create a non-contiguous tensor x that skips odd indices.
x = torch.arange(cnt * 2)
x = x.as_strided((x.size()[0] // 2,), (2,))
torch._dynamo.decorators.mark_unbacked(x, 0)
sz = torch.tensor([int(math.sqrt(cnt))])
compiled_result = compiled_func(x, sz)
eager_result = func(x, sz)
self.assertEqual(compiled_result, eager_result)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
)
with ctx():
make_non_contiguous_tensor_and_test(4)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
return (mul_9, mul_12)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
make_non_contiguous_tensor_and_test(49)
self.assertEqual(cnt.frame_count, 1)
        # Pass in a contiguous tensor; it will recompile because the stride is 1 (0/1 specialization).
        # Marking strides as unbacked would have avoided the recompilation here.
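        # (0/1 specialization: sizes and strides whose hint is 0 or 1 are baked into the graph
        # as constants, so a stride-1 contiguous input cannot reuse the stride-s7 graph above.)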
x = torch.arange(100)
torch._dynamo.decorators.mark_unbacked(x, 0)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
)
with ctx():
compiled_result = compiled_func(x, torch.tensor([10]))
eager_result = func(x, torch.tensor([10]))
self.assertEqual(compiled_result, eager_result)
self.assertEqual(cnt.frame_count, 2)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
return (mul_4, mul_7)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
x = torch.arange(25)
compiled_result = compiled_func(x, torch.tensor([5]))
eager_result = func(x, torch.tensor([5]))
self.assertEqual(cnt.frame_count, 2)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_reshape2(self):
cnt = CompileCounterWithBackend("inductor")
        # This reshape requires a clone when the input is not contiguous and we can't compute strides:
        # reshape (u2, u3) -> (u0, u1).
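        # Since x below is a transposed (non-contiguous) tensor with unbacked sizes, the
        # compiled graph is expected to clone it to contiguous memory first and then view it as
        # (u0, u1); the clone + view pair shows up in the expected aot graph below.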
def func(x, y):
u0, u1 = y.tolist()
torch._check_is_size(u0)
torch._check_is_size(u1)
result1 = torch.reshape(x, (u0, u1))
return result1 * 10
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.randn(10, 10)
# make x not contiguous.
x = x.t_()
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
)
with ctx():
result_eager = func(x, torch.tensor([5, 20]))
result_compiled = compiled_func(x, torch.tensor([5, 20]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""\
def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
_assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_3 = None
mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
return (mul_21,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
result_eager = func(x, torch.tensor([2, 50]))
result_compiled = compiled_func(x, torch.tensor([2, 50]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 1)
x = torch.randn(4, 4).t_()
result_eager = func(x, torch.tensor([2, 8]))
result_compiled = compiled_func(x, torch.tensor([2, 8]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 1)
        # Pass a contiguous tensor. A recompilation will happen due to 0/1 specialization on the stride.
log_stream, ctx = logs_to_string(
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
)
with ctx():
            # This used to hit "could guard on data-dependent expression Eq(10, u3)" because
            # x.stride()[0] == 10 and x.size() == [u2, u3], but it no longer does since we use
            # contiguous_or_false. We still need a way to mark strides unbacked to avoid the
            # recompilation here.
x = torch.randn(10, 10)
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
aot_graphs,
"""""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
result_compiled = compiled_func(x, torch.tensor([2, 50]))
result_eager = func(x, torch.tensor([2, 50]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
x = torch.randn(4, 4)
result_eager = func(x, torch.tensor([2, 8]))
result_compiled = compiled_func(x, torch.tensor([2, 8]))
self.assertEqual(result_compiled, result_eager)
self.assertEqual(cnt.frame_count, 2)
@unittest.skip("this test fails due to inductor/autograd issue #153041")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_unbacked_non_contigious_reshape_failing(self):
        # reshape a 1-D tensor of size u1 into (u0, u0)
        # This results in the tensor "i64[u0, u0][s7*u0, s7]".
        # The reshape happens in place (no clone).
def func(x, y):
f = y.item()
t1 = x.view((f, f))
t2 = x.reshape((f, f))
return t1, t2
        # Create a non-contiguous tensor whose data are the first cnt even numbers
        # (0, 2, ..., 2*cnt - 2).
def make_non_contiguous_tensor(cnt):
            # Create a non-contiguous tensor x that skips odd indices.
x = torch.arange(cnt * 2)
x = x.as_strided((x.size()[0] // 2,), (2,))
return x
x = make_non_contiguous_tensor(4)
torch._dynamo.decorators.mark_unbacked(x, 0)
compiled_func = torch.compile(
fullgraph=True,
backend="inductor",
)(func)
compiled_result = compiled_func(x, torch.tensor([2]))
eager_result = func(x, torch.tensor([2]))
self.assertEqual(compiled_result, eager_result)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@torch._dynamo.config.patch("capture_scalar_outputs", True)
def test_invalid_view_unbacked_view(self):
cnt = CompileCounterWithBackend("inductor")
        # This view (u2, u3) -> (u0, u1) can't happen in general unless we know that the input
        # is contiguous or we have hints to compute strides.
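        # Unlike reshape, view must alias the input without copying, so with fully unbacked
        # sizes and non-contiguous strides there is no way to pick valid output strides;
        # tracing is expected to surface this as the data-dependent error checked below.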
def func(x, y):
u0, u1 = y.tolist()
torch._check_is_size(u0)
torch._check_is_size(u1)
result2 = x.view(u0, u1) * 10
return result2
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.randn(10, 10)
# make x not contiguous.
x = x.t_()
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
with self.assertRaises(torch._dynamo.exc.UserError):
            # throws a data-dependent error.
compiled_func(x, torch.tensor([5, 20]))
@skipIfTorchDynamo()
def test_unbind_not_dynamic(self):
cnt = CompileCounter()
@torch.compile(fullgraph=True, dynamic=True, backend=cnt)
def func(y):
return y.unbind(dim=2), y * 10
func(torch.ones(5, 6, 7, 8))
self.assertEqual(cnt.frame_count, 1)
        # It can be dynamic in all dimensions except dim=2.
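        # (unbind(dim=2) produces one output per element of dim 2, so the number of graph
        # outputs depends on that size and dim 2 stays specialized; changing it below forces
        # recompilation.)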
func(torch.ones(4, 9, 7, 10))
self.assertEqual(cnt.frame_count, 1)
func(torch.ones(5, 6, 8, 8))
func(torch.ones(5, 6, 9, 8))
self.assertEqual(cnt.frame_count, 3)
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
@fresh_cache()
def test_unbacked_contiguous(self):
cnt = CompileCounterWithBackend("inductor")
def func(x):
contig = x.contiguous()
return (contig + 1) * 100
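        # For an input with unbacked sizes that may or may not be contiguous, x.contiguous()
        # is expected to lower to an unconditional clone (see the first expected graph below)
        # instead of guarding on a data-dependent contiguity check; when the input is known
        # to be contiguous (second graph) no clone is emitted.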
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
x = torch.randn(10, 10)
# make x not contiguous.
x = x.t_()
torch._dynamo.decorators.mark_unbacked(x, 0)
torch._dynamo.decorators.mark_unbacked(x, 1)
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
compiled_func(x)
self.assertEqual(compiled_func(x), func(x))
y = torch.rand(20, 20).t()
self.assertEqual(compiled_func(y), func(y))
self.assertEqual(cnt.frame_count, 1)
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
return (mul_6,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
log_stream, ctx = logs_to_string(
"torch._inductor.compile_fx", "post_grad_graphs"
)
with ctx():
# recompilation will happen due to stride specialization.
y = torch.rand(20, 20)
torch._dynamo.decorators.mark_unbacked(y, 0)
torch._dynamo.decorators.mark_unbacked(y, 1)
self.assertEqual(compiled_func(y), func(y))
self.assertEqual(cnt.frame_count, 2)
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
# No clone this time since input is contiguous.
self.assertExpectedInline(
output,
"""\
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
return (mul_5,)""", # noqa: B950
ignore_comments=True,
ignore_empty_lines=True,
)
instantiate_parametrized_tests(TestUnbacked)
if __name__ == "__main__":
run_tests()