mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Summary: This is very much the same solution proposed by bobrenjc93, except that it restricts it to expressions and axioms that have FloorDiv, since those are the only ones that could have become CleanDiv and the only ones that can change as the shape env changes. This also does not break the torchrec benchmarks; it might be worth knowing why the generalization of this does break the torchrec benchmarks, but we could just be hitting another bug or NYI situation. Overhead? None on ``` buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=1000 ``` Differential Revision: D66307433 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141267 Approved by: https://github.com/ezyang
2828 lines
99 KiB
Python
2828 lines
99 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
|
|
import contextlib
|
|
import copy
|
|
import itertools
|
|
import math
|
|
import operator
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import sympy
|
|
|
|
import torch
|
|
import torch.fx
|
|
import torch.nn.functional as F
|
|
from torch import sym_int, SymBool, SymFloat, SymInt
|
|
from torch._C import _disabled_torch_function_impl
|
|
from torch.fx.experimental import sym_node
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
_constrain_range_for_size,
|
|
DimConstraints,
|
|
DimDynamic,
|
|
expect_true,
|
|
guard_bool,
|
|
guard_float,
|
|
guard_int,
|
|
GuardOnDataDependentSymNode,
|
|
hint_int,
|
|
is_symbolic,
|
|
ShapeEnv,
|
|
StatelessSymbolicContext,
|
|
statically_known_true,
|
|
)
|
|
from torch.testing._internal.common_dtype import all_types_and
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
TestCase,
|
|
)
|
|
from torch.utils import _pytree as pytree
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
from torch.utils._sympy.functions import (
|
|
CleanDiv,
|
|
FloorDiv,
|
|
IsNonOverlappingAndDenseIndicator,
|
|
Mod,
|
|
)
|
|
|
|
|
|
# Shorthand for the ATen operator namespace used throughout this file.
aten = torch.ops.aten

# Registry mapping an operator overload to the Python function that handles
# it; filled by ``register_meta`` and read by
# ``FakeSymbolicTensor.__torch_dispatch__``.
meta_funcs = {}
|
|
|
|
|
|
def register_meta(op):
    """Build a decorator that records the decorated function in ``meta_funcs``.

    ``op`` may be a single operator overload or a pytree (e.g. a list) of
    overloads; the function is stored under every leaf and returned unchanged.
    """

    def decorator(f):
        # tree_map_ is used only for its side effect on each leaf of ``op``.
        pytree.tree_map_(lambda overload: meta_funcs.__setitem__(overload, f), op)
        return f

    return decorator
|
|
|
|
|
|
@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    """Handler for add/sub: the result is a new empty tensor shaped like ``a``."""
    result_shape = a.shape
    return a.new_empty(result_shape)
|
|
|
|
|
|
@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    """Handler for cat: sum the sizes along ``dim``, require all others to match.

    The reference shape is taken from ``tensors[0]``; the result is created
    with ``tensors[0].new_empty``.
    """
    first = tensors[0]
    reference_shape = first.shape
    total_along_dim = 0
    for current in tensors:
        for axis, (expected, actual) in enumerate(
            zip(reference_shape, current.shape)
        ):
            if axis != dim:
                assert actual == expected
                continue
            total_along_dim = total_along_dim + actual
    out_shape = list(reference_shape)
    out_shape[dim] = total_along_dim
    return first.new_empty(out_shape)
|
|
|
|
|
|
@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Handler for narrow_copy: same shape as ``a`` with ``dim`` set to ``length``.

    ``start`` and any extra keyword arguments are accepted but not used.
    """
    out_shape = tuple(
        length if axis == dim else extent for axis, extent in enumerate(a.shape)
    )
    return a.new_empty(out_shape)
|
|
|
|
|
|
@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    """Handler for expand: the result is a new empty tensor of shape ``size``."""
    target_shape = size
    return a.new_empty(target_shape)
|
|
|
|
|
|
def create_contiguous(shape):
    """Return the row-major (C-contiguous) strides for ``shape``.

    The last dimension has stride 1 and every other dimension's stride is the
    product of all sizes that follow it, e.g. ``(5, 4, 3) -> [12, 3, 1]``.
    An empty shape yields an empty stride list.

    Args:
        shape: sequence of sizes (ints or SymInts) supporting ``len`` and
            slicing.

    Returns:
        A list of strides with the same length as ``shape``.
    """
    if len(shape) == 0:
        return []
    strides = [1]
    # Accumulate products of the trailing sizes from right to left.  The
    # first size never contributes to any stride, so skip shape[0] (the
    # previous code skipped shape[-1] instead, giving wrong strides for
    # shapes with 3+ dims or with differing sizes).
    for dim in reversed(shape[1:]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))
|
|
|
|
|
|
class FakeSymbolicTensor(torch.Tensor):
    """Wrapper tensor subclass carrying (possibly symbolic) sizes for tests.

    Operators are dispatched to the handlers registered in ``meta_funcs``;
    ``aten.new_empty`` is handled inline and any other operator raises
    ``RuntimeError``.
    """

    @staticmethod
    def __new__(
        cls,
        sym_shape,
        sym_strides,
        dtype,
        layout,
        requires_grad,
        device,
        storage_offset=0,
    ):
        # TODO: this is wrong in general
        # NOTE: ``sym_strides`` is accepted but not used; strides are always
        # recomputed from ``sym_shape`` via ``create_contiguous``.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            sym_shape,
            sym_stride,
            storage_offset,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Return a new ``FakeSymbolicTensor`` of ``shape`` with this tensor's metadata."""
        return FakeSymbolicTensor(
            shape, None, self.dtype, self.layout, self.requires_grad, self.device
        )

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        """Route ``func_overload`` to its registered handler.

        Raises:
            RuntimeError: if the operator has no handler in ``meta_funcs`` and
                is not ``aten.new_empty.default``.
        """
        # ``kwargs`` defaults to None, and ``**None`` would raise TypeError.
        if kwargs is None:
            kwargs = {}

        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(
                shape,
                self.stride(),
                self.dtype,
                self.layout,
                self.requires_grad,
                self.device,
            )

        raise RuntimeError(f"operator {func_overload} not supported")
|
|
|
|
|
|
def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
    """Wrap ``arg`` in a ``FakeSymbolicTensor`` whose metadata comes from ``shape_env``.

    When ``source`` is omitted a ``ConstantSource(name)`` is used; when
    ``dynamic_dims`` is omitted every dimension is marked ``DimDynamic.DUCK``.
    No per-dimension constraints are passed.
    """
    from torch._dynamo.source import ConstantSource

    ndim = arg.dim()
    source = ConstantSource(name) if source is None else source
    no_constraints = [None] * ndim
    if dynamic_dims is None:
        dynamic_dims = [DimDynamic.DUCK] * ndim

    context = StatelessSymbolicContext(
        dynamic_sizes=dynamic_dims, constraint_sizes=no_constraints
    )
    sizes, strides, offset = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=source,
        symbolic_context=context,
    )
    return FakeSymbolicTensor(
        sizes,
        strides,
        arg.dtype,
        arg.layout,
        arg.requires_grad,
        arg.device,
        offset,
    )
|
|
|
|
|
|
def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
|
|
from torch._dynamo.source import ConstantSource
|
|
|
|
symbol = shape_env.create_symbol(
|
|
val,
|
|
source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
|
|
dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
|
|
constraint_dim=None,
|
|
**kwargs,
|
|
)
|
|
return cls(SymNode(symbol, shape_env, pytype, hint=val))
|
|
|
|
|
|
# TODO: default duck to False
|
|
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
    """Create a ``SymInt`` in ``shape_env`` hinted with ``i`` (see ``create_symtype``)."""
    symint = create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs)
    return symint
|
|
|
|
|
|
def create_symbool(shape_env, b: bool) -> SymBool:
    """Create a ``SymBool`` in ``shape_env`` hinted with ``b`` (see ``create_symtype``)."""
    symbool = create_symtype(SymBool, bool, shape_env, b)
    return symbool
|
|
|
|
|
|
def create_symfloat(shape_env, f: float) -> SymFloat:
    """Create a ``SymFloat`` in ``shape_env`` hinted with ``f`` (see ``create_symtype``)."""
    symfloat = create_symtype(SymFloat, float, shape_env, f)
    return symfloat
|
|
|
|
|
|
@skipIfTorchDynamo(
|
|
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
|
|
)
|
|
class TestPySymInt(TestCase):
|
|
def test_arith_ops(self):
    """Binary arithmetic on SymInts agrees with the same op on plain ints."""
    shape_env = ShapeEnv()
    # (plain int, SymInt hinted with that int) pairs
    symints = [(i, create_symint(shape_env, i)) for i in range(2, 5)]

    ops = [
        operator.add,
        operator.sub,
        operator.floordiv,
        operator.mul,
        operator.mod,
    ]
    division_ops = (operator.floordiv, operator.mod)

    for op in ops:
        for (lhs_int, lhs_sym), (rhs_int, rhs_sym) in itertools.permutations(
            symints, 2
        ):
            if isinstance(lhs_sym, int):
                continue
            # Only division-like ops need a non-zero right-hand side.  The
            # previous condition ``op != mod or op != floordiv`` was always
            # true, so it never expressed this.
            if op in division_ops and rhs_int == 0:
                continue
            self.assertTrue(op(lhs_sym, rhs_sym) == op(lhs_int, rhs_int))
|
|
|
|
def test_reverse_arith_ops(self):
    """Reflected ops (int on the left, SymInt on the right) give the int result."""
    shape_env = ShapeEnv()

    divisor = create_symint(shape_env, 2)
    quotient = 5 // divisor
    self.assertTrue(quotient == 5 // 2)

    factor = create_symint(shape_env, 2)
    product = 5 * factor
    self.assertTrue(product == 5 * 2)
|
|
|
|
def test_sympify_symint(self):
    """``sympy.sympify`` on a Sym{Int,Float,Bool} returns its node's expression."""
    shape_env = ShapeEnv()
    cases = (
        (create_symint, 2),
        (create_symfloat, 3.0),
        (create_symbool, True),
    )
    for factory, value in cases:
        sym = factory(shape_env, value)
        self.assertIs(sympy.sympify(sym), sym.node.expr)
|
|
|
|
def test_roundtrip(self):
    """Sizes and storage offset of a symbolic tensor come back as SymInts
    that compare equal to the concrete values."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

    self.assertTrue(not isinstance(x.shape[0], SymNode))
    self.assertTrue(isinstance(x.shape[0], SymInt))

    self.assertTrue(x.shape[0] == 5)
    self.assertTrue(x.shape[1] == 4)
    # Compare explicitly: ``assertTrue(value, 3)`` treats 3 as the failure
    # message and only checks truthiness of ``value``.
    self.assertTrue(x.shape[2] == 3)

    self.assertTrue(x.size()[0] == 5)
    self.assertTrue(x.size()[1] == 4)
    # Should be simplifiable to an integer.
    # Ref: https://github.com/pytorch/pytorch/pull/107492
    self.assertTrue(isinstance(x.size()[1], SymInt))
    self.assertTrue(
        isinstance(x.size()[1].node.maybe_as_int(), int)
    )  # due to guard above
    self.assertTrue(x.size()[2] == 3)

    self.assertTrue(x.size(0) == 5)
    self.assertTrue(x.size(1) == 4)
    self.assertTrue(x.size(2) == 3)
    self.assertTrue(isinstance(x.size(2), SymInt))
    self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))

    # Slicing off the first row leaves a storage offset of 4 * 3 = 12.
    y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
    self.assertTrue(isinstance(y.storage_offset(), SymInt))
    self.assertTrue(y.storage_offset() == 12)
|
|
|
|
def test_binary(self):
    """Adding symbolic tensors yields a result with the left operand's sizes."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)

    def check_is_5_4_3(t):
        for axis, extent in enumerate((5, 4, 3)):
            self.assertTrue(t.shape[axis] == extent)

    check_is_5_4_3(x + y)

    # broadcasting
    y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
    check_is_5_4_3(x + y)
|
|
|
|
def test_symint_args(self):
    """``narrow_copy`` accepts SymInts and SymInt arithmetic as the length."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
    LAST_DIM = 2

    narrowed = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
    self.assertTrue(narrowed.shape[2] == y.shape[2])

    # arithmetic expr with two symints
    length = x.shape[LAST_DIM] - y.shape[LAST_DIM]
    narrowed = x.narrow_copy(LAST_DIM, 0, length)
    self.assertTrue(narrowed.shape[2] == 2)

    # arithmetic expr with a symint and python int
    length = x.shape[LAST_DIM] - 1
    narrowed = x.narrow_copy(LAST_DIM, 0, length)
    self.assertTrue(narrowed.shape[2] == 2)
|
|
|
|
def test_symint_vargs(self):
    """``expand`` accepts SymInts as varargs or as a list, mixed with ints."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
    y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)

    def check_is_5_4_3(t):
        for axis, extent in enumerate((5, 4, 3)):
            self.assertTrue(t.shape[axis] == extent)

    # varargs
    check_is_5_4_3(y.expand(x.shape[0], y.shape[1], x.shape[2]))

    # shape list
    check_is_5_4_3(y.expand((x.shape[0], y.shape[1], x.shape[2])))

    # mixed python symints and ints
    check_is_5_4_3(y.expand(x.shape[0], y.shape[1], 3))

    # mixed python symints and ints in a list
    check_is_5_4_3(y.expand((x.shape[0], y.shape[1], 3)))

    # mixed python symints and ints
    check_is_5_4_3(y.expand(5, y.shape[1], x.shape[2]))

    # mixed python ints and symints in a list
    check_is_5_4_3(y.expand((5, y.shape[1], x.shape[2])))

    # A single SymInt, as a tuple and as a bare vararg; only checks that the
    # calls go through.
    y.expand((y.shape[1],))
    y.expand(y.shape[1])
|
|
|
|
def test_symint_bitwise_and(self):
    """``&`` works between SymInt/SymBool/int operands and records a guard."""
    shape_env = ShapeEnv()
    lhs = create_symint(shape_env, 0b1100)
    rhs = create_symint(shape_env, 0b1010)
    anded = lhs & rhs
    self.assertEqual(anded, 0b1000)
    self.assertIsInstance(anded, torch.SymInt, msg=type(anded))
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(
        first_guard, """Eq(BitwiseFn_bitwise_and(s0, s1), 8)"""
    )

    # SymInt & SymBool
    sym_three = create_symint(shape_env, 3)
    sym_true = create_symbool(shape_env, True)
    self.assertEqual(sym_three & sym_true, 1)

    # SymInt & plain int
    sym_twelve = create_symint(shape_env, 0b1100)
    self.assertEqual(sym_twelve & 0b1010, 0b1000)

    # SymBool & SymBool
    left_true = create_symbool(shape_env, True)
    right_true = create_symbool(shape_env, True)
    self.assertEqual(left_true & right_true, True)
|
|
|
|
def test_symint_bitwise_or(self):
    """``|`` between two SymInts yields a SymInt and records a guard."""
    shape_env = ShapeEnv()
    lhs = create_symint(shape_env, 0b1100)
    rhs = create_symint(shape_env, 0b1010)
    ored = lhs | rhs
    self.assertEqual(ored, 0b1110)
    self.assertIsInstance(ored, torch.SymInt, msg=type(ored))
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(
        first_guard, """Eq(BitwiseFn_bitwise_or(s0, s1), 14)"""
    )
|
|
|
|
def test_stride(self):
    """The leading stride of a symbolic 5x5 tensor is a SymInt."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
    leading_stride = x.stride()[0]
    self.assertIsInstance(leading_stride, SymInt)
|
|
|
|
def test_size_expressions(self):
    """Branching on a symbolic size records a ``>`` guard on that size."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5), shape_env)
    expand_x = x.expand(x.shape[0], x.shape[0])
    # Both branches do the same thing on purpose: the point is evaluating
    # the symbolic condition, which installs the guard checked below.
    if expand_x.shape[0] > 3:
        result = expand_x + expand_x
    else:
        result = expand_x + expand_x

    gt_op, _bt = shape_env.guards[-1]
    self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
    # assertEqual, not assertTrue: with assertTrue the second argument is
    # only a message, so the strings were never actually compared.
    self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
    self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
    self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))
|
|
|
|
def test_floordiv_static(self):
    """After divisibility guards, ``s0 // (s0 // 2) == 2`` is statically known."""
    shape_env = ShapeEnv()
    s0 = create_symint(shape_env, 8)
    half = s0 // 2
    # This was extracted from
    # python test/inductor/test_cuda_cpp_wrapper.py -k
    # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
    # The bool() results are discarded; each call is made for the guard it
    # installs.
    bool(s0 % 2 == 0)
    bool(s0 % (s0 // 2) == 0)
    bool(2 * (s0 // 2) == s0)
    self.assertTrue(statically_known_true(s0 // half == 2))
|
|
|
|
def test_numel(self):
    """``numel`` is a SymInt for a symbolic tensor and an int for a regular one."""
    shape_env = ShapeEnv()
    symbolic = create_symbolic_tensor("x", torch.randn(5), shape_env)
    self.assertIsInstance(symbolic.numel(), torch.SymInt)
    self.assertIsInstance(torch.numel(symbolic), torch.SymInt)

    concrete = torch.rand(3, 3)
    self.assertIsInstance(concrete.numel(), int)
    self.assertIsInstance(torch.numel(concrete), int)
|
|
|
|
def test_int_to_float(self):
    """``torch.sym_float`` of a symbolic size returns a SymFloat."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5), shape_env)
    as_float = torch.sym_float(x.shape[0])
    self.assertIsInstance(as_float, torch.SymFloat, msg=type(as_float))
|
|
|
|
def test_aten_ops(self):
    """ATen overloads can be called directly with SymInt arguments."""
    narrow_env = ShapeEnv()
    vec = create_symbolic_tensor("x", torch.randn(5), narrow_env)
    torch.ops.aten.narrow_copy.default(vec, 0, 0, vec.shape[0])

    expand_env = ShapeEnv()
    cube = create_symbolic_tensor("x2", torch.randn(5, 4, 3), expand_env)
    sizes = [cube.shape[0], cube.shape[1], cube.shape[2]]
    torch.ops.aten.expand.default(cube, sizes)
|
|
|
|
def test_fx_trace_intlist(self):
|
|
class CustomModule(torch.nn.Module):
|
|
def forward(self, x):
|
|
bs, c, h, w = x.shape
|
|
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
|
|
|
|
m = CustomModule()
|
|
x = torch.rand(1, 3, 4, 4)
|
|
# should not TypeError: pad(): argument 'pad' (position 2) must be
|
|
# tuple of ints, not tuple
|
|
torch.fx.symbolic_trace(m)
|
|
|
|
def test_meta_symint(self):
    """A meta tensor created with a SymInt size reports a SymInt shape."""
    shape_env = ShapeEnv()
    length = create_symint(shape_env, 2)
    meta_tensor = torch.empty(length, device="meta")
    self.assertIsInstance(meta_tensor.shape[0], SymInt)
|
|
|
|
def test_guard_int(self):
    """``guard_int`` returns the hint and records an equality guard."""
    shape_env = ShapeEnv()
    sym = create_symint(shape_env, 2)
    self.assertEqual(guard_int(sym), 2)
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(first_guard, """Eq(s0, 2)""")
|
|
|
|
def test_sym_sum(self):
    """``torch.sym_sum`` builds the same expression as chained ``+``."""
    shape_env = ShapeEnv()
    s0 = create_symint(shape_env, 2)
    s1 = create_symint(shape_env, 3)
    s2 = create_symint(shape_env, 4)
    chained_expr = (s0 + s1 + s2).node.expr
    summed_expr = torch.sym_sum([s0, s1, s2]).node.expr
    self.assertEqual(chained_expr, summed_expr)
|
|
|
|
def test_prefer_deferred_runtime_assertions_over_guards(self):
    """With the flag set, ``guard_int`` still guards while ``expect_true``
    records a deferred runtime assert and no guard."""
    guarding_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
    s0 = create_symint(guarding_env, 2)
    self.assertEqual(guard_int(s0), 2)
    self.assertExpectedInline(str(guarding_env.guards[0][0]), """Eq(s0, 2)""")

    deferring_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
    s0 = create_symint(deferring_env, 2)
    self.assertTrue(expect_true(s0 == 2))
    self.assertEqual(len(deferring_env.guards), 0)
    deferred = [ra.expr for ra in deferring_env.deferred_runtime_asserts[None]]
    self.assertExpectedInline(
        str(deferred),
        """[Eq(s0, 2)]""",
    )
|
|
|
|
def test_sym_int(self):
    """``sym_int`` of a SymInt, a true division and a float product all
    return SymInts, each recording the expected guard."""
    shape_env = ShapeEnv()

    five = create_symint(shape_env, 5)
    r = sym_int(five)
    self.assertEqual(r, 5)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s0, 5)""")

    seven = create_symint(shape_env, 7)
    r = sym_int(seven / 2)
    self.assertEqual(guard_int(r), 3)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    second_guard = str(shape_env.guards[1][0])
    self.assertExpectedInline(
        second_guard, """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)"""
    )

    three = create_symint(shape_env, 3)
    r = sym_int(2.0 * torch.sym_float(three))
    self.assertEqual(guard_int(r), 6)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    third_guard = str(shape_env.guards[2][0])
    self.assertExpectedInline(
        third_guard, """Eq(TruncToInt(2.0*ToFloat(s2)), 6)"""
    )
|
|
|
|
def test_sym_log2(self):
    """``torch._sym_log2`` of a SymInt returns a SymFloat and records a guard."""
    shape_env = ShapeEnv()
    four = create_symint(shape_env, 4)
    r = torch._sym_log2(four)
    self.assertEqual(r, 2.0)
    self.assertIsInstance(r, torch.SymFloat, msg=type(r))
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(
        first_guard, """Eq(OpaqueUnaryFn_log2(ToFloat(s0)), 2.0)"""
    )
|
|
|
|
def test_sym_sqrt(self):
    """``torch._sym_sqrt`` of a SymInt returns a SymFloat and records a guard."""
    shape_env = ShapeEnv()
    four = create_symint(shape_env, 4)
    r = torch._sym_sqrt(four)
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymFloat, msg=type(r))
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(
        first_guard, """Eq(OpaqueUnaryFn_sqrt(ToFloat(s0)), 2.0)"""
    )
|
|
|
|
def test_sym_floor(self):
    """``math.floor`` on symbolic floats returns SymInts with FloorToInt guards."""
    shape_env = ShapeEnv()
    five = create_symint(shape_env, 5)

    r = math.floor(five / 2)
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(
        first_guard,
        """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""",
    )

    r = math.floor(3.0 * five)
    self.assertEqual(r, 15)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    second_guard = str(shape_env.guards[1][0])
    self.assertExpectedInline(
        second_guard,
        """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
    )
|
|
|
|
def test_sym_trunc(self):
    """``math.trunc`` and ``torch.sym_int`` on symbolic floats return SymInts
    with TruncToInt guards."""
    shape_env = ShapeEnv()
    five = create_symint(shape_env, 5)

    r = math.trunc(five / 2)
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(
        first_guard, """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)"""
    )

    r = torch.sym_int(torch.sym_sqrt(five))
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    second_guard = str(shape_env.guards[1][0])
    self.assertExpectedInline(
        second_guard,
        """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s0))), 2)""",
    )
|
|
|
|
def test_sym_ceil(self):
    """``math.ceil`` on a symbolic float returns a SymInt with a CeilToInt guard."""
    shape_env = ShapeEnv()
    five = create_symint(shape_env, 5)

    r = math.ceil(five / 2)
    self.assertEqual(r, 3)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(
        first_guard,
        """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""",
    )

    # NOTE(review): this second half exercises math.floor, not math.ceil --
    # confirm whether that is intended.
    tripled = 3.0 * five
    r = math.floor(tripled)
    self.assertEqual(r, 15)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    second_guard = str(shape_env.guards[1][0])
    self.assertExpectedInline(
        second_guard,
        """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""",
    )
|
|
|
|
def test_sym_ite(self):
    """``torch.sym_ite`` returns an operand directly for plain bools and a
    Piecewise expression, guarded only when inspected, for SymBools."""
    shape_env = ShapeEnv()
    t = create_symint(shape_env, 5)
    f = create_symint(shape_env, 4)

    # Plain Python bools select one of the operands by identity.
    picked_true = torch.sym_ite(True, t, f)
    self.assertTrue(picked_true is t)
    picked_false = torch.sym_ite(False, t, f)
    self.assertTrue(picked_false is f)

    # Symbolic condition on t: no guard until the result is compared.
    cond_on_t = t == 5
    r3 = torch.sym_ite(cond_on_t, t, f)
    self.assertEqual(len(shape_env.guards), 0)
    self.assertEqual(r3, 5)
    self.assertEqual(type(t), type(r3))
    self.assertExpectedInline(
        str(shape_env.guards[0][0]),
        """Eq(Piecewise((s0, Eq(s0, 5)), (s1, True)), 5)""",
    )

    # Symbolic condition on f: building the result adds no further guard.
    cond_on_f = f == 5
    r4 = torch.sym_ite(cond_on_f, t, f)
    self.assertEqual(len(shape_env.guards), 1)
    self.assertEqual(r4, 4)
    self.assertEqual(type(f), type(r4))
    self.assertExpectedInline(
        str(shape_env.guards[1][0]),
        """Eq(Piecewise((s0, Eq(s1, 5)), (s1, True)), 4)""",
    )
|
|
|
|
def test_tracing_sym_ite(self):
|
|
def f(x):
|
|
b = x.shape[0] == 5
|
|
ret = torch.sym_ite(b, x.shape[0], x.shape[1])
|
|
return ret
|
|
|
|
gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
|
|
self.assertEqual(len(gm.shape_env.guards), 0)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
eq = sym_size_int == 5
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
|
|
sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
|
|
return sym_ite""",
|
|
)
|
|
r1 = gm(torch.ones(4, 5))
|
|
self.assertIsInstance(r1, int)
|
|
self.assertEqual(r1, 5)
|
|
r2 = gm(torch.ones(5, 4))
|
|
self.assertIsInstance(r2, int)
|
|
self.assertEqual(r2, 5)
|
|
|
|
def test_int_conversion(self):
    """Calling ``int()`` on a SymInt records an equality guard on its hint."""
    shape_env = ShapeEnv()
    sym = create_symint(shape_env, 2)
    int(sym)
    first_guard = str(shape_env.guards[0][0])
    self.assertExpectedInline(first_guard, """Eq(s0, 2)""")
|
|
|
|
def test_data_dependent_guard(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = shape_env.create_unbacked_symint()
|
|
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
|
|
|
|
def test_data_dependent_guard_propagate_real_tensors(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = shape_env.create_unbacked_symint()
|
|
shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
|
|
self.assertEqual(bool(s0 == 0), True)
|
|
|
|
def test_expect_true_basic(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
i0_sym = i0.node.expr
|
|
# This doesn't error
|
|
self.assertTrue(expect_true(i0 == 0))
|
|
# This generates a deferred runtime assert via replacement
|
|
self.assertEqual(shape_env.replacements[i0_sym], 0)
|
|
# After expecting true, guards now resolve given the runtime assert
|
|
bool(i0 == 0)
|
|
|
|
def test_expect_true_with_s0(self):
    """A deferred ``u0 < s0`` assert lets related comparisons resolve."""
    shape_env = ShapeEnv()
    backed = create_symint(shape_env, 5)
    unbacked = shape_env.create_unbacked_symint()
    self.assertTrue(expect_true(unbacked < backed))
    deferred = [
        ra.expr for ra in shape_env.deferred_runtime_asserts[unbacked.node.expr]
    ]
    self.assertExpectedInline(
        str(deferred),
        """[u0 < s0]""",
    )
    self.assertTrue(unbacked < backed)
    self.assertTrue(unbacked != backed)
    self.assertFalse(unbacked > backed)
    self.assertFalse(unbacked >= backed)
|
|
|
|
def test_expect_true_prefer_later(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
i1 = shape_env.create_unbacked_symint()
|
|
i1_sym = i1.node.expr
|
|
self.assertTrue(expect_true(i0 + i1 == 10))
|
|
# Importantly, this is put in i1, not i0!
|
|
self.assertExpectedInline(
|
|
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
|
|
"""[Eq(u0 + u1, 10)]""",
|
|
)
|
|
self.assertTrue(i0 + i1 == 10)
|
|
# NB: We currently don't support deriving that we can substitute
|
|
# i0 + i1 with 10; maybe we should, but this means our rewriting
|
|
# system is no longer confluent (it's probably OK though, because
|
|
# you're unlikely to get other equalities like this on the
|
|
# unbacked SymInts.)
|
|
|
|
def test_unbacked_substitution(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
i1 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i0)
|
|
_constrain_range_for_size(i1)
|
|
self.assertTrue(expect_true(i0 == i1 * 4))
|
|
self.assertExpectedInline(str(i0), """u0""")
|
|
|
|
i2 = shape_env.create_unbacked_symint()
|
|
i3 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i2)
|
|
_constrain_range_for_size(i3)
|
|
self.assertTrue(expect_true(i2 * 4 == i3))
|
|
self.assertExpectedInline(str(i3), """u3""")
|
|
|
|
def test_avoid_unbacked_substitution(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i0)
|
|
i1 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i1)
|
|
self.assertTrue(expect_true(i0 == 10 - i1))
|
|
self.assertExpectedInline(str(i0), """u0""")
|
|
|
|
def test_expect_true_double_digits(self):
|
|
shape_env = ShapeEnv()
|
|
ia = [shape_env.create_unbacked_symint() for _ in range(11)] # allocate 10
|
|
self.assertEqual(str(ia[-1]), "u10")
|
|
self.assertTrue(expect_true(sum(ia) == 20))
|
|
self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)
|
|
|
|
def test_expect_true_refine_range(self):
|
|
shape_env = ShapeEnv()
|
|
for i, rel in enumerate(
|
|
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = shape_env.create_unbacked_symint()
|
|
self.assertTrue(expect_true(rel(i0)))
|
|
self.assertTrue(statically_known_true(i0 != 3))
|
|
self.assertTrue(statically_known_true(i0 != 4))
|
|
self.assertFalse(statically_known_true(i0 != 5))
|
|
self.assertFalse(statically_known_true(i0 != 6))
|
|
self.assertTrue(statically_known_true(i0 > 4))
|
|
self.assertTrue(statically_known_true(i0 >= 5))
|
|
|
|
for i, rel in enumerate(
|
|
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = shape_env.create_unbacked_symint()
|
|
self.assertTrue(expect_true(rel(i0)))
|
|
self.assertFalse(statically_known_true(i0 != 2))
|
|
self.assertFalse(statically_known_true(i0 != 3))
|
|
self.assertTrue(statically_known_true(i0 != 4))
|
|
self.assertTrue(statically_known_true(i0 != 5))
|
|
self.assertTrue(statically_known_true(i0 < 4))
|
|
self.assertTrue(statically_known_true(i0 <= 5))
|
|
|
|
def test_guard_refine_range(self):
    """Guarding on an inequality, whether it holds or not, refines what is
    statically known about a backed symbol."""
    shape_env = ShapeEnv()

    def at_least_five():
        # Four spellings of "x is at least 5".
        return [
            lambda x: x > 4,
            lambda x: 4 < x,
            lambda x: x >= 5,
            lambda x: 5 <= x,
        ]

    def at_most_three():
        # Four spellings of "x is at most 3".
        return [
            lambda x: x < 4,
            lambda x: 4 > x,
            lambda x: x <= 3,
            lambda x: 3 >= x,
        ]

    # hint 10, relation holds -> range becomes [5, ...)
    for idx, relation in enumerate(at_least_five()):
        with self.subTest(f"i = {idx}"):
            s = create_symint(shape_env, 10, duck=False)
            self.assertTrue(bool(relation(s)))
            self.assertTrue(statically_known_true(s != 3))
            self.assertTrue(statically_known_true(s != 4))
            self.assertFalse(statically_known_true(s != 5))
            self.assertFalse(statically_known_true(s != 6))
            self.assertTrue(statically_known_true(s > 4))
            self.assertTrue(statically_known_true(s >= 5))

    # hint 2, relation fails -> range becomes (..., 4]
    for idx, relation in enumerate(at_least_five()):
        with self.subTest(f"i = {idx}"):
            s = create_symint(shape_env, 2, duck=False)
            self.assertFalse(bool(relation(s)))
            self.assertFalse(statically_known_true(s != 3))
            self.assertFalse(statically_known_true(s != 4))
            self.assertTrue(statically_known_true(s != 5))
            self.assertTrue(statically_known_true(s != 6))
            self.assertTrue(statically_known_true(s <= 4))
            self.assertTrue(statically_known_true(s < 5))

    # hint 2, relation holds -> range becomes (..., 3]
    for idx, relation in enumerate(at_most_three()):
        with self.subTest(f"i = {idx}"):
            s = create_symint(shape_env, 2, duck=False)
            self.assertTrue(bool(relation(s)))
            self.assertFalse(statically_known_true(s != 2))
            self.assertFalse(statically_known_true(s != 3))
            self.assertTrue(statically_known_true(s != 4))
            self.assertTrue(statically_known_true(s != 5))
            self.assertTrue(statically_known_true(s < 4))
            self.assertTrue(statically_known_true(s <= 3))

    # hint 10, relation fails -> range becomes [4, ...)
    for idx, relation in enumerate(at_most_three()):
        with self.subTest(f"i = {idx}"):
            s = create_symint(shape_env, 10, duck=False)
            self.assertFalse(bool(relation(s)))
            self.assertTrue(statically_known_true(s != 2))
            self.assertTrue(statically_known_true(s != 3))
            self.assertFalse(statically_known_true(s != 4))
            self.assertFalse(statically_known_true(s != 5))
            self.assertTrue(statically_known_true(s >= 4))
            self.assertTrue(statically_known_true(s > 3))
|
|
|
|
def test_mul_int_oo_nan(self):
    """Evaluating ``s0 * (s1 // s0) == s2`` on non-duck symbols must not raise."""
    shape_env = ShapeEnv()
    s0 = create_symint(shape_env, 5, duck=False)
    s1 = create_symint(shape_env, 6, duck=False)
    s2 = create_symint(shape_env, 5, duck=False)
    condition = s0 * (s1 // s0) == s2
    # Result intentionally unused; the evaluation itself is the test.
    bool(condition)
|
|
|
|
def test_non_overlapping_and_dense(self):
    """A column-major meta tensor with a symbolic size is non-overlapping and dense."""
    shape_env = ShapeEnv()
    rows = create_symint(shape_env, 5)
    column_major = torch.empty_strided((rows, 7), (1, rows), device="meta")
    is_dense = torch.ops.aten.is_non_overlapping_and_dense.default(column_major)
    self.assertTrue(is_dense)
|
|
|
|
def test_non_overlapping_and_dense_unbacked(self):
|
|
shape_env = ShapeEnv()
|
|
u0 = shape_env.create_unbacked_symint()
|
|
torch._check_is_size(u0)
|
|
cf = torch.ops.aten.is_non_overlapping_and_dense.default
|
|
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
|
|
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
|
|
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
|
|
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
|
|
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
|
|
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
|
|
|
|
Max = torch.sym_max
|
|
# NB: This only works because we're able to determine this tensor is
|
|
# contiguous. transpose(0, 1) makes it stop working
|
|
self.assertTrue(
|
|
cf(
|
|
torch.empty_strided(
|
|
(2, 3, 1, u0),
|
|
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
|
|
device="meta",
|
|
)
|
|
)
|
|
)
|
|
|
|
def test_sympy_optimized_add_binary_search(self):
|
|
import sympy
|
|
|
|
from torch.fx.experimental.sym_node import _binary_search_insert_arg
|
|
|
|
a = sympy.Symbol("a")
|
|
b = sympy.Symbol("b")
|
|
c = sympy.Symbol("c")
|
|
|
|
args = []
|
|
args = _binary_search_insert_arg([], b)
|
|
self.assertEqual(args, [b])
|
|
|
|
self.assertEqual(_binary_search_insert_arg(args, b), None)
|
|
|
|
args = _binary_search_insert_arg(args, a)
|
|
self.assertEqual(args, [a, b])
|
|
|
|
self.assertEqual(_binary_search_insert_arg(args, b), None)
|
|
self.assertEqual(_binary_search_insert_arg(args, a), None)
|
|
|
|
args = _binary_search_insert_arg(args, c)
|
|
self.assertEqual(args, [a, b, c])
|
|
|
|
self.assertEqual(_binary_search_insert_arg(args, a), None)
|
|
self.assertEqual(_binary_search_insert_arg(args, b), None)
|
|
self.assertEqual(_binary_search_insert_arg(args, c), None)
|
|
|
|
a1 = sympy.Symbol("a1")
|
|
a2 = sympy.Symbol("a2")
|
|
|
|
args = _binary_search_insert_arg(args, a1)
|
|
self.assertEqual(args, [a, a1, b, c])
|
|
|
|
args = _binary_search_insert_arg(args, a2)
|
|
self.assertEqual(args, [a, a1, a2, b, c])
|
|
|
|
c1 = sympy.Symbol("c1")
|
|
args = _binary_search_insert_arg(args, c1)
|
|
self.assertEqual(args, [a, a1, a2, b, c, c1])
|
|
|
|
# insert to front
|
|
_a = sympy.Symbol("_a")
|
|
args = _binary_search_insert_arg(args, _a)
|
|
self.assertEqual(args, [_a, a, a1, a2, b, c, c1])
|
|
|
|
def test_floor_clean_div_axioms(self):
|
|
# Test that if we add an axiom that have FloorDiv, after which the
|
|
# shapeEnv changed such that it can be simplified it to CleanDiv, then
|
|
# We still correctly replace CleanDiv with the axiom value of FloorDiv.
|
|
shape_env = ShapeEnv()
|
|
a = shape_env.create_unbacked_symint()
|
|
|
|
shape_env.defer_runtime_assert((a // 3 == 1).node.expr, " test")
|
|
|
|
from sympy import Eq
|
|
|
|
test1 = Eq(FloorDiv(a.node.expr, 3), 1)
|
|
test2 = Eq(CleanDiv(a.node.expr, 3), 1)
|
|
|
|
self.assertTrue(shape_env.evaluate_expr(test1))
|
|
self.assertEqual(shape_env._maybe_evaluate_static(test2), None)
|
|
|
|
# After this FloorDiv(a, 3) is simplified to CleanDiv(a, 3)
|
|
shape_env.defer_runtime_assert(Eq(Mod(a, 3), 0), " test")
|
|
self.assertEqual(test2, shape_env.simplify(test1))
|
|
|
|
self.assertTrue(shape_env.evaluate_expr(test1))
|
|
self.assertTrue(shape_env.evaluate_expr(test2))
|
|
|
|
def test_sympy_optimized_add(self):
    """Track when SymInt addition keeps the `_optimized_summation` flag."""
    shape_env = ShapeEnv()
    s0 = create_symint(shape_env, 2)
    s1 = create_symint(shape_env, 3)
    s2 = create_symint(shape_env, 4)
    pair = s0 + s1

    def assert_optimized(sym):
        self.assertTrue(sym.node._optimized_summation)

    def assert_not_optimized(sym):
        # The flag may be missing entirely, so default it to False.
        self.assertFalse(getattr(sym.node, "_optimized_summation", False))

    self.assertTrue(pair.node._optimized_summation)
    assert_optimized(pair)

    # add duplicate symbol
    assert_not_optimized(pair + s0)

    # add constant.
    assert_not_optimized(pair + 1)

    # add new unique symbol, should maintain _optimized_summation property.
    assert_optimized(pair + s2)
    assert_optimized(s2 + pair)

    # add x + (a+b) with no _optimized_summation on the rhs sum.
    a = create_symint(shape_env, 10)
    b = create_symint(shape_env, 11)
    two_sum = torch.sym_sum([a, b])
    assert_not_optimized(two_sum)
    assert_optimized(pair + two_sum)

    # adding two expressions of length >2 that are _optimized_summation.
    a = s0 + s1 + s2
    s3 = create_symint(shape_env, 10)
    s4 = create_symint(shape_env, 20)
    s5 = create_symint(shape_env, 30)
    b = s3 + s4 + s5
    for term in (a, b, a + b, b + a):
        assert_optimized(term)

    # same as above but b does not have ordered_summation_of_unique_symbols.
    s6 = create_symint(shape_env, 11)
    s7 = create_symint(shape_env, 21)
    s8 = create_symint(shape_env, 31)
    b = torch.sym_sum([s6, s7, s8])
    assert_optimized(a)
    assert_not_optimized(b)
    assert_not_optimized(a + b)
|
|
|
|
def test_max_of_unique_summation_opt(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = shape_env.create_unbacked_symint()
|
|
s1 = shape_env.create_unbacked_symint()
|
|
s2 = shape_env.create_unbacked_symint()
|
|
s3 = shape_env.create_unbacked_symint()
|
|
s4 = shape_env.create_unbacked_symint()
|
|
s5 = shape_env.create_unbacked_symint()
|
|
s7 = shape_env.create_unbacked_symint()
|
|
|
|
def assert_optimized(sym):
|
|
self.assertTrue(sym.node.expr.unique_summations_symbols is not None)
|
|
|
|
def assert_not_optimized(sym):
|
|
getattr(sym.node.expr, "unique_summations_symbols", None)
|
|
|
|
mx1 = torch.sym_max(s0, s1)
|
|
assert_not_optimized(mx1)
|
|
|
|
mx2 = torch.sym_max(s0 + s1, s2 + s3)
|
|
assert_optimized(mx2)
|
|
|
|
mx3 = torch.sym_max(mx2, s4 + s5)
|
|
assert_optimized(mx3)
|
|
assert_optimized(torch.sym_max(s4 + s5, mx2))
|
|
|
|
assert_not_optimized(torch.sym_max(mx3, s7))
|
|
assert_not_optimized(torch.sym_max(mx3, 10))
|
|
assert_not_optimized(torch.sym_max(mx3, s3 + s7))
|
|
assert_not_optimized(torch.sym_max(mx3, s7 * 2))
|
|
|
|
def test_sym_max_multi_max_simplify(self):
|
|
shape_env = ShapeEnv()
|
|
u0 = shape_env.create_unbacked_symint()
|
|
self.assertTrue(
|
|
statically_known_true(
|
|
torch.sym_max(1, torch.sym_max(257, u0)) == torch.sym_max(257, u0)
|
|
)
|
|
)
|
|
|
|
def test_numpy_sym_max(self):
|
|
self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
|
|
self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
|
|
self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
|
|
self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
|
|
self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
|
|
self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)
|
|
|
|
def test_numpy_sym_min(self):
|
|
self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
|
|
self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
|
|
self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
|
|
self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5)
|
|
self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0)
|
|
self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0)
|
|
|
|
def test_debug_has_internal_overlap_unbacked(self):
|
|
shape_env = ShapeEnv()
|
|
u0 = shape_env.create_unbacked_symint()
|
|
torch._check_is_size(u0)
|
|
cf = torch._debug_has_internal_overlap
|
|
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
|
|
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
|
|
self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
|
|
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 0)
|
|
Max = torch.sym_max
|
|
self.assertEqual(
|
|
cf(
|
|
torch.empty_strided(
|
|
(2, 3, 1, u0),
|
|
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
|
|
device="meta",
|
|
)
|
|
),
|
|
0,
|
|
)
|
|
|
|
# Wobbling these to zero is OK too
|
|
self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2)
|
|
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2)
|
|
|
|
def test_specialize_zero_one(self):
    """Comparing against 1 adds a guard only when 0/1 are not specialized."""
    for specialize, expected_guards in ((True, 0), (False, 1)):
        shape_env = ShapeEnv(specialize_zero_one=specialize)
        a0 = create_symint(shape_env, 5)
        # Bare assert on purpose: truth-testing the comparison is what may
        # install the guard.
        assert a0 != 1
        self.assertEqual(len(shape_env.guards), expected_guards)
|
|
|
|
def test_duck_shape(self):
    """Equal-hint symints compare equal without a guard only under duck shaping."""
    for duck, expected_guards in ((True, 0), (False, 1)):
        shape_env = ShapeEnv(duck_shape=duck)
        a0 = create_symint(shape_env, 5)
        a1 = create_symint(shape_env, 5)
        # Bare assert on purpose: truth-testing the comparison is what may
        # install the guard.
        assert a0 == a1
        self.assertEqual(len(shape_env.guards), expected_guards)
|
|
|
|
def test_int_bool(self):
    """Truth-testing a SymInt must not record a guard."""
    # See https://github.com/pytorch/pytorch/issues/95981
    env = ShapeEnv(duck_shape=True)
    sym = create_symint(env, 5)
    assert sym
    self.assertEqual(len(env.guards), 0)
|
|
|
|
def test_symint_as_scalar(self):
    """A SymInt passed as `alpha` must reach `__torch_dispatch__` intact."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 2)

    sym_int_encountered = False

    class TestSymInt(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            assert func == torch.ops.aten.add.Tensor

            nonlocal sym_int_encountered
            # WARNING: do not do identity tests on the outer
            # SymInt/SymFloat, they are NOT STABLE
            alpha = kwargs["alpha"]
            sym_int_encountered = alpha.node is a0.node
            kwargs["alpha"] = 0
            return func(*args)

    operand = torch.rand([4, 4])
    with TestSymInt():
        torch.add(operand, operand, alpha=a0)

    self.assertTrue(sym_int_encountered)
|
|
|
|
def test_deepcopy(self):
    """A deep-copied ShapeEnv carries over the recorded guard."""
    original_env = ShapeEnv()
    sym = create_symint(original_env, 2)
    # Truth-testing the comparison records one guard on the original env.
    assert sym < 4
    cloned_env = copy.deepcopy(original_env)
    self.assertEqual(len(cloned_env.guards), 1)
|
|
|
|
def test_print_readable_with_symints(self):
# Traces `f` symbolically with make_fx and compares the graph's
# print_readable() text against the inline expected string below.
# NOTE(review): the expected string's original line breaks/indentation are
# not recoverable from this rendering, so it is left untouched.
|
|
def f(a, b):
|
|
dim0 = a.shape[0] + b.shape[0]
|
|
dim1 = a.shape[1] + b.shape[1]
|
|
d = a.new_empty(dim0, dim1)
|
|
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
|
|
return d
|
|
|
|
# Inputs share dim 1 (size 3) and differ in dim 0 (5 vs 4).
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
|
|
out = fx_g.print_readable(print_output=False)
|
|
|
|
self.assertExpectedInline(
|
|
out.strip(),
|
|
"""\
|
|
class f(torch.nn.Module):
|
|
def forward(self, a_1: "f32[s0, s1]", b_1: "f32[s2, s1]"):
|
|
# No stacktrace found for following nodes
|
|
sym_size_int: "Sym(s0)" = torch.ops.aten.sym_size.int(a_1, 0)
|
|
sym_size_int_1: "Sym(s2)" = torch.ops.aten.sym_size.int(b_1, 0)
|
|
add: "Sym(s0 + s2)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
|
|
sym_size_int_2: "Sym(s1)" = torch.ops.aten.sym_size.int(a_1, 1)
|
|
sym_size_int_3: "Sym(s1)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
|
|
add_1: "Sym(2*s1)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
|
|
new_empty: "f32[s0 + s2, 2*s1]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
|
|
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
|
|
getitem: "f32[s0 + s2, 2*s1]" = native_dropout[0]
|
|
getitem_1: "b8[s0 + s2, 2*s1]" = native_dropout[1]; native_dropout = None
|
|
return (getitem, getitem_1)""", # noqa: B950
|
|
)
|
|
|
|
def test_statically_known_true(self):
    """`statically_known_true` answers only from static facts and adds no guards."""
    shape_env = ShapeEnv()
    s2, s3, s4 = (create_symint(shape_env, hint) for hint in range(2, 5))
    known = statically_known_true

    # Statically known true
    self.assertTrue(known(True))
    self.assertTrue(known(s2 == s2))
    self.assertTrue(known(s2 * s3 > s3))
    self.assertTrue(known(s3 * s4 > s4))
    self.assertTrue(known((s3 + s3) % 2 == 0))

    # Statically known false
    self.assertFalse(known(False))
    self.assertFalse(known(s3 * s4 <= s4))
    self.assertFalse(known((s3 + s3) % 2 == 1))

    # True for hints, but not known statically
    self.assertFalse(known(s2 + s2 == s4))
    self.assertFalse(known(s4 % s2 == 0))
    self.assertFalse(known(s2 != s3))
    self.assertFalse(known(s3 * s4 > s2))

    # False for hints, but not known statically
    self.assertFalse(known(s2 == s3))
    self.assertFalse(known(s2 > s3))
    self.assertFalse(known(s3 + s3 == s4))

    # No guards should be generated
    self.assertEqual(len(shape_env.guards), 0)
|
|
|
|
def test_ephemeral_source_simplification(self):
    """Symbols with ephemeral sources are simplified away by equality checks."""
    from torch._dynamo.source import EphemeralSource

    # For full robustness, ensure the ephemeral source symbols are simplified out regardless
    # of construction order or check order.
    for construct_ephemeral_first, x_first_in_check in itertools.product(
        [False, True], repeat=2
    ):
        shape_env = ShapeEnv()
        shape = (5, 10)
        dynamic_dims = [DimDynamic.DYNAMIC] * len(shape)
        x = create_symbolic_tensor(
            "x",
            torch.randn(*shape),
            shape_env,
            source=(EphemeralSource() if construct_ephemeral_first else None),
            dynamic_dims=dynamic_dims,
        )
        y = create_symbolic_tensor(
            "y",
            torch.randn(*shape),
            shape_env,
            source=(EphemeralSource() if not construct_ephemeral_first else None),
            dynamic_dims=dynamic_dims,
        )
        t_with_ephemeral = x if construct_ephemeral_first else y

        def _get_ephemeral_source_symbols(t):
            found = []
            for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),)):
                if not isinstance(s, torch.SymInt):
                    continue
                expr = s.node.expr
                if expr not in shape_env.var_to_sources:
                    continue
                if any(
                    source.is_ephemeral()
                    for source in shape_env.var_to_sources[s.node.expr]
                ):
                    found.append(expr)
            return found

        # these checks should simplify out the ephemeral symbols, regardless of the
        # ordering x == y or y == x
        self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
        lhs, rhs = (x, y) if x_first_in_check else (y, x)
        torch._check(lhs.size() == rhs.size())
        torch._check(lhs.stride() == rhs.stride())
        torch._check(lhs.storage_offset() == rhs.storage_offset())
        self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
|
|
|
|
def test_ephemeral_source_unified_with_non_ephemeral_source(self):
    """With duck sizing, each symbol's first source must be non-ephemeral."""
    from torch._dynamo.source import EphemeralSource

    for construct_ephemeral_first in (False, True):
        shape_env = ShapeEnv()
        shape = (5, 10)
        # use duck sizing here to ensure symbol reuse across x and y
        duck_dims = [DimDynamic.DUCK] * len(shape)
        x = create_symbolic_tensor(
            "x",
            torch.randn(*shape),
            shape_env,
            source=(EphemeralSource() if construct_ephemeral_first else None),
            dynamic_dims=duck_dims,
        )
        y = create_symbolic_tensor(
            "y",
            torch.randn(*shape),
            shape_env,
            source=(EphemeralSource() if not construct_ephemeral_first else None),
            dynamic_dims=duck_dims,
        )

        # regardless of construction order, non-ephemeral sources should be preferred
        # first in the var_to_sources list for potential guarding later on
        for source_list in shape_env.var_to_sources.values():
            preferred = source_list[0]
            self.assertFalse(preferred.is_ephemeral())

        self.assertEqual(x.size(), y.size())
        self.assertEqual(x.stride(), y.stride())
        self.assertEqual(x.storage_offset(), y.storage_offset())
|
|
|
|
def test_tensor_factory_with_symint(self):
    """Tensor factories accept lists of SymInts like lists of plain ints."""
    plain_values = list(range(0, 3))
    expected = torch.tensor(plain_values)

    shape_env = ShapeEnv()
    sym_args = [create_symint(shape_env, value) for value in plain_values]

    # test tensor factories
    for dt in all_types_and(torch.half, torch.bfloat16):
        self.assertEqual(
            torch.tensor(sym_args, dtype=dt), expected, exact_dtype=False
        )

    # test legacy tensor factories
    legacy_ctors = (
        torch.Tensor,
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.ByteTensor,
    )
    for ctor in legacy_ctors:
        self.assertEqual(ctor(sym_args), expected, exact_dtype=False)
|
|
|
|
|
|
@skipIfTorchDynamo(
    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestSymNumberMagicMethods(TestCase):
    """Compare SymInt/SymFloat/SymBool magic methods with plain Python results."""

    def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
        """Run `_do_test2` inside a subTest labelled with its arguments."""
        with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
            return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)

    def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
        """Apply `fn` to plain and symbolic inputs and compare the results.

        The plain-Python result is the reference.  The operator is then
        applied with the first argument symbolic, the second symbolic, and
        both symbolic; each guarded result must equal the reference.
        """
        # Helper function
        # NB: don't use one as that will get specialized
        # TODO: We don't have to circuitously create the float, can just
        # create a symfloat directly
        seed_node = (create_symint(shape_env, 2) / 2.0).node
        bool_seed_node = (create_symint(shape_env, 2) == 2).node

        def get_sym_inp(inp):
            # NB: the bool check must come before the int check
            if isinstance(inp, bool):
                return torch.SymBool(to_node(bool_seed_node, inp))
            if isinstance(inp, int):
                return torch.SymInt(to_node(seed_node, inp))
            return torch.SymFloat(to_node(seed_node, inp))

        if fn == "float_pow" and inp1 < 0:
            return

        if fn == "pow_by_natural" and (
            isinstance(inp1, float) or isinstance(inp2, float) or inp2 < 0
        ):
            return

        def maybe_xfail(inp1, inp2):
            """Return a context manager expecting the error `fn` should raise."""
            if fn == "sym_sqrt" and inp1 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            if (
                fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
                and inp2 == 0
            ):
                # ZeroDivisionError: division by zero
                return self.assertRaises((ZeroDivisionError,))
            if fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
                # ZeroDivisionError: 0.0 cannot be raised to a negative power
                return self.assertRaises((ZeroDivisionError,))
            if (
                # TODO: dear catastrophe waitress,
                # this doesn't work
                # NOTE(review): `type(x) is (A, B)` compares a type against a
                # tuple and is always False, so this branch never fires.
                fn in ["float_pow", "pow_by_natural"]
                and inp1 < 0
                and (
                    type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat)
                )
                and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float))
            ):
                # Complex result, which we do not support:
                # TypeError: Cannot convert complex to float
                return self.assertRaises((RuntimeError,))
            if fn in ("lshift", "rshift") and not (
                isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
            ):
                # TypeError: unsupported operand type(s)
                return self.assertRaises((TypeError,))
            if fn in ("lshift", "rshift") and inp2 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            return contextlib.nullcontext()

        lambda_apply = method_to_operator(fn)

        def guard_fn(v):
            if type(v) in (SymBool, bool):
                return guard_bool(v)
            if type(v) in (SymFloat, float):
                return guard_float(v)
            # SymInt, int
            return guard_int(v)

        def check_against_reference(out):
            # Only runs when the operator did not raise, so `ref_out` is bound.
            if fn not in sym_node.alternate_impl_if_hinted_methods:
                self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            self.assertEqual(guard_fn(out), ref_out)

        # Get reference result
        with maybe_xfail(inp1, inp2):
            if is_unary_fn:
                ref_out = lambda_apply(inp1)
            else:
                ref_out = lambda_apply(inp1, inp2)

        # Symified first arg
        sym_inp1 = get_sym_inp(inp1)
        with maybe_xfail(sym_inp1, inp2):
            if is_unary_fn:
                out = lambda_apply(sym_inp1)
            else:
                out = lambda_apply(sym_inp1, inp2)
            check_against_reference(out)

        if is_unary_fn:
            return

        # Symified second arg
        sym_inp2 = get_sym_inp(inp2)
        with maybe_xfail(inp1, sym_inp2):
            check_against_reference(lambda_apply(inp1, sym_inp2))

        # Symified both args
        with maybe_xfail(sym_inp1, sym_inp2):
            check_against_reference(lambda_apply(sym_inp1, sym_inp2))

    @parametrize("fn", list(sym_node.magic_methods.keys()))
    def test_bool_method(self, fn):
        """Run every bool magic method (except sym_ite) on True/False."""
        # sym_ite has its own tests
        if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
            self.skipTest(f"{fn} is non-bool")

        is_unary_fn = fn in sym_node.unary_methods
        self._do_test(fn, True, False, ShapeEnv(), is_unary_fn)

    @parametrize("fn", list(sym_node.magic_methods.keys()))
    @parametrize("first_type", ["int", "float"])
    @parametrize("second_type", ["int", "float"])
    def test_method(self, fn, first_type, second_type):
        """Run numeric magic methods over positive/negative sample values."""
        if first_type == "float":
            # TODO: Hmm, this looks like we skip all floats
            self.skipTest(f"{fn} is not a float magic method")

        if (
            first_type == "int" or second_type == "int"
        ) and fn in sym_node.only_float_magic_methods:
            self.skipTest(f"{fn} is not an int method")

        if second_type == "float" and fn in ["mod"]:
            self.skipTest(f"{fn} only handles int")

        if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"):
            self.skipTest(f"{fn} is a bitwise op, only handles int")

        is_unary_fn = fn in sym_node.unary_methods or fn == "round"
        # Second argument is ignored for unary function. So only run for one type
        if is_unary_fn and second_type == "float":
            self.skipTest(f"{fn} is unary and already tested")

        if fn in sym_node.bool_magic_methods:
            self.skipTest(f"{fn} is bool")

        # Only floats here since these will be converted to int if necessary.
        # We also ignore complex and bool.
        values = (
            0.0,
            1.0,
            0.5 if fn in ("sym_acos", "sym_asin") else 2.5,  # avoid math domain error
        )
        neg_values = tuple(-x for x in values)

        # Same visiting order as chaining the four products one after another.
        sign_combinations = (
            (values, values),
            (values, neg_values),
            (neg_values, values),
            (neg_values, neg_values),
        )
        for firsts, seconds in sign_combinations:
            for inp1, inp2 in itertools.product(firsts, seconds):
                if first_type == "int":
                    inp1 = int(inp1)
                if second_type == "int":
                    inp2 = int(inp2)

                shape_env = ShapeEnv()

                self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)

    def get_constant_bool(self, val):
        """Wrap a constant bool SymNode for `val` in a SymBool."""
        node = torch._C._get_constant_bool_symnode(val)
        return SymBool(node)

    @unittest.expectedFailure
    def test_symint_hashing(self):
        """Hashing a symbolic SymInt is expected to fail."""
        hash(create_symint(ShapeEnv(), 3))

    def test_symnode_hashing(self):
        """Check which symbolic values are hashable and hash consistently."""
        shape_env = ShapeEnv()

        # These all trigger specialization when hashed
        hash(create_symbool(shape_env, True))
        # We should be passing in float here, but create_symbol currently
        # only supports int
        hash(create_symfloat(shape_env, 3.0))

        # NestedInt (SymInt), constant SymBool, SymNode are hashable
        j1 = torch._C._get_nested_int(1, 1)
        j1_copy = torch._C._get_nested_int(1, 1)
        j2 = torch._C._get_nested_int(2, 1)
        t = self.get_constant_bool(True)
        t_copy = self.get_constant_bool(True)
        f = self.get_constant_bool(False)
        n = create_symint(shape_env, 3).node
        m = self.get_constant_bool(True).node

        for value, twin, other in ((j1, j1_copy, j2), (t, t_copy, f)):
            self.assertIs(value == twin, True)
            self.assertEqual(hash(value), hash(twin))
            self.assertIs(value == other, False)
            self.assertNotEqual(hash(value), hash(other))

        hash(n)
        hash(m)

    def test_symint_deepcopy(self):
        """A tuple holding a nested int compares equal to its deep copy."""
        shape_env = ShapeEnv()

        symnodes = (torch._C._get_nested_int(1, 1),)
        self.assertEqual(symnodes, copy.deepcopy(symnodes))

    def test_non_symbolic_symnode(self):
        """Exercise nested ints and constant SymBools, which are not symbolic."""
        j1 = torch._C._get_nested_int(1, 1)
        j2 = torch._C._get_nested_int(1, 1)
        j3 = torch._C._get_nested_int(3, 1)

        self.assertIsInstance(j1, torch.SymInt)
        self.assertNotIsInstance(j1, int)

        with self.assertRaisesRegex(
            RuntimeError, "add not supported by NestedIntSymNode"
        ):
            j1 + 3

        self.assertFalse(j1 == 3)
        with self.assertRaisesRegex(RuntimeError, "indeterminate"):
            self.assertFalse(3 >= j2)

        self.assertIs(j1 == j1, True)
        self.assertIs(j1 == j2, True)
        self.assertIs(j1 == j3, False)
        self.assertIs(j1 != j3, True)
        self.assertIs(j1 != j2, False)

        x = self.get_constant_bool(True)
        #
        # Unary
        #
        # op(constant SymBool)
        self.assertIs(x.__sym_not__(), False)

        #
        # Binary
        #
        # op(constant SymBool, bool)
        # op(constant SymBool, constant SymBool)
        # op(bool, constant SymBool)
        self.assertIs(operator.and_(x, True), True)
        self.assertIs(operator.and_(x, x), True)
        self.assertIs(operator.and_(True, x), True)

        # op(symbolic SymBool, constant Symbool)
        # op(constant SymBool, symbolic Symbool)
        shape_env = ShapeEnv()
        a = create_symint(shape_env, 2)
        b = create_symint(shape_env, 2)
        c = a == b  # symbolic SymBool
        d = self.get_constant_bool(True)
        e = operator.and_(c, d)
        f = operator.and_(d, c)
        self.assertTrue(is_symbolic(e))
        self.assertTrue(is_symbolic(f))
        self.assertIs(e.node.guard_bool("", 0), True)
        self.assertIs(f.node.guard_bool("", 0), True)

        # Comparing sizes
        sz1 = torch.Size([j1, j1, j1])
        sz2 = torch.Size([j1, j1, j1])
        self.assertIs(sz1 == sz2, True)

        sz1 = torch.Size([3, j1, 4])
        sz2 = torch.Size([3, j2, 4])
        self.assertIs(sz1 == sz2, True)
        self.assertIs(sz1 != sz2, False)

    def test_stride_symnode(self):
        """Check which sizes/strides become SymInts for each DimDynamic mix."""
        from torch._subclasses.fake_tensor import FakeTensorMode

        shape_env = ShapeEnv()

        def _create_symbolic_tensor(x, dynamic_sizes, dynamic_strides):
            with FakeTensorMode(shape_env=shape_env) as fake_mode:
                context = StatelessSymbolicContext(
                    dynamic_sizes=dynamic_sizes,
                    dynamic_strides=dynamic_strides,
                )
                return fake_mode.from_tensor(x, symbolic_context=context)

        static_sizes = [DimDynamic.STATIC, DimDynamic.STATIC]
        dynamic_sizes = [DimDynamic.DYNAMIC, DimDynamic.DYNAMIC]
        inferred_strides = [DimDynamic.INFER_STRIDE, DimDynamic.INFER_STRIDE]
        outer_dynamic_strides = [DimDynamic.DYNAMIC, DimDynamic.INFER_STRIDE]

        # check everything static
        t = _create_symbolic_tensor(
            x=torch.ones(3, 6),
            dynamic_sizes=static_sizes,
            dynamic_strides=inferred_strides,
        )
        self.assertTrue(all(isinstance(size, int) for size in t.size()))
        self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))

        # check dynamic size but static dims
        t = _create_symbolic_tensor(
            x=torch.ones(3, 6),
            dynamic_sizes=dynamic_sizes,
            dynamic_strides=inferred_strides,
        )
        # Expect stride to be inferred
        s0, s1 = t.size()
        s2, s3 = t.stride()
        self.assertTrue(isinstance(s0, torch.SymInt))
        self.assertTrue(isinstance(s1, torch.SymInt))
        self.assertTrue(isinstance(s2, torch.SymInt))
        self.assertTrue(s1 == s2)
        self.assertEqual(s3, 1)

        # Check dynamic stride but static dims
        t = _create_symbolic_tensor(
            x=torch.ones(3, 6),
            dynamic_sizes=static_sizes,
            dynamic_strides=outer_dynamic_strides,
        )
        s0, s1 = t.size()
        s2, s3 = t.stride()
        self.assertTrue(isinstance(s0, int))
        self.assertTrue(isinstance(s1, int))
        self.assertTrue(isinstance(s2, torch.SymInt))
        self.assertTrue(isinstance(s3, int))

        # Check dynamic sizes and dims, and ensure different symbol
        t = _create_symbolic_tensor(
            x=torch.ones(3, 6),
            dynamic_sizes=dynamic_sizes,
            dynamic_strides=outer_dynamic_strides,
        )
        s0, s1 = t.size()
        s2, s3 = t.stride()
        self.assertTrue(isinstance(s0, torch.SymInt))
        self.assertTrue(isinstance(s1, torch.SymInt))
        self.assertTrue(isinstance(s2, torch.SymInt))
        self.assertTrue(isinstance(s3, int))
        self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
|
|
|
|
|
|
# TestSymNumberMagicMethods uses @parametrize on test_bool_method and
# test_method; presumably this call expands them into concrete test cases
# (helper is imported from torch.testing._internal.common_utils).
instantiate_parametrized_tests(TestSymNumberMagicMethods)
|
|
|
|
|
|
class TestFloorDiv(TestCase):
    """Compare the sympy `FloorDiv` function against Python's `//`."""

    @staticmethod
    def python_floordiv(x, y):
        """Reference result using Python's floor division."""
        return x // y

    @staticmethod
    def torch_floordiv(x, y):
        """Evaluate `FloorDiv(x, y)` through a fresh ShapeEnv."""
        # Note: we fully evaluate here since FloorDiv might not always do
        # that.
        return ShapeEnv().evaluate_expr(FloorDiv(x, y))

    @staticmethod
    def yield_test_cases(values, negate=True):
        """Yield each (x, y) pair and, if `negate`, its three sign variants."""
        for x, y in values:
            yield (x, y)
            if negate:
                yield from ((-x, y), (x, -y), (-x, -y))

    def test_floordiv_float_int(self):
        for x, y in TestFloorDiv.yield_test_cases(((7, 2),)):
            expected = TestFloorDiv.python_floordiv(x, y)
            self.assertEqual(expected, TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_div_by_one(self):
        for x, y in TestFloorDiv.yield_test_cases(((2, 1),)):
            expected = TestFloorDiv.python_floordiv(x, y)
            self.assertEqual(expected, TestFloorDiv.torch_floordiv(x, y))

    def test_floordiv_simplify(self):
        # Tests how we simplify or evaluate FloorDiv without free variables
        shape_env = ShapeEnv()
        result = 21

        for expr in (7 * FloorDiv(6, 2),):
            self.assertEqual(expr, result)
            for deep in (False, True):
                self.assertEqual(expr.doit(deep=deep), result)
            self.assertEqual(sympy.simplify(expr), result)
            self.assertEqual(shape_env.simplify(expr), result)
            self.assertEqual(shape_env.evaluate_expr(expr), result)

    def test_floordiv_assumptions(self):
        """Check the integer/real assumptions reported by FloorDiv results."""
        symbols = (
            sympy.Symbol("i1", integer=True),
            sympy.Symbol("i2", integer=True),
        )

        def is_complex(x):
            return x.is_integer is False and x.is_real is False and x.is_complex

        for base, divisor in itertools.product(symbols, repeat=2):
            if is_complex(base) or is_complex(divisor):
                self.assertRaisesRegex(
                    TypeError,
                    (
                        r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
                        r" expected integer or real"
                    ),
                    lambda: FloorDiv(base, divisor),
                )
                continue

            quotient = FloorDiv(base, divisor)

            # In regular Python, x//x == 1.0 if x is a float, but FloorDiv
            # always returns an integer 1 when both args are the same object.
            # This even works for Symbols with no assumptions specified.
            if base is divisor or (base.is_integer and divisor.is_integer):
                self.assertTrue(quotient.is_integer)
                self.assertTrue(quotient.is_real)
            else:
                self.assertEqual(quotient.is_integer, None)
                self.assertTrue(quotient.is_real)
|
|
|
|
|
|
class TestDimConstraints(TestCase):
|
|
def test_dim_constraints_reduce_congruences_simple(self):
    """Several congruences on one symbol reduce to a single congruence."""
    from sympy import Symbol

    s = Symbol("s", positive=True, integer=True)
    half = s / 2
    dim_constraints = DimConstraints({}, {}, set(), {})
    dim_constraints._congruences[s] = {
        half % 2,
        half % 8,
        half % 4,
        s % 2,
        ((s / 16) + 2) % 4,
    }
    reduced = dim_constraints._reduce_congruences()
    self.assertEqual(reduced[s], {(s + 32) % 64})
|
|
|
|
def test_dim_constraints_reduce_inequalities_simple(self):
    """sympy reduces the constraint set to [4, 16), then to {8} with an Eq."""
    from sympy import Eq, Interval, Ne, Symbol
    from sympy.solvers.inequalities import reduce_inequalities

    s = Symbol("s", positive=True, integer=True)
    half = s / 2
    exprs = {
        s >= 2,
        Ne(8 * s, 16),
        Ne(half, 1),
        Ne(16 * s, 32),
        s < 16,
        Ne(s, 2),
        half < 16,
        half > 1,
        half >= 2,
        Ne(3 * s / 2, 3),
    }

    def solve():
        return reduce_inequalities(exprs, s).as_set()

    self.assertEqual(solve(), Interval.Ropen(4, 16))

    # Pinning s / 2 to 4 collapses the interval to a single value.
    exprs.add(Eq(half, 4))
    self.assertEqual(solve(), {8})
|
|
|
|
def test_dim_constraints_reduce_inequalities_error(self):
    """Printing the solver's answer for cubic constraints must raise.

    The test expects an AssertionError about an unknown symbol created by
    the constraints solver when the reduced answer is printed.
    """
    from collections import defaultdict

    from sympy import Symbol
    from sympy.solvers.inequalities import reduce_inequalities

    from torch._dynamo.source import (
        LocalSource,
        TensorProperty,
        TensorPropertySource,
    )
    from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter

    int32_max = 2147483647
    s0 = Symbol("s0", positive=True, integer=True)
    exprs = {
        4 * s0**3 - 4 * s0**2 + s0 <= int32_max,
        s0 >= 2,
        s0**3 <= int32_max,
        s0 <= int32_max,
    }
    answer = reduce_inequalities(exprs, s0)

    size_source = TensorPropertySource(
        base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
    )
    symbol_to_source = defaultdict(list)
    symbol_to_source[s0].append(size_source)

    dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
    with self.assertRaisesRegex(
        AssertionError,
        "Unknown symbol.*created by constraints solver",
    ):
        dcp.doprint(answer)
|
|
|
|
def test_dim_constraints_solve_full(self):
    """Solve a large constraint system with ``DimConstraints`` end to end.

    Four positive integer symbols (``s0``, ``s1``, ``s5``, ``s6``, with
    example values 8, 96, 22 and 21) are mapped to size dimensions of the
    locals ``a`` .. ``f`` and are all marked dynamic. A long list of
    relations over them is added, then ``solve()`` is called.

    Expected outcome: every dimension tied to ``s0``, ``s1`` or ``s5``
    (plus the two constant dimensions) ends up in ``_static_results``;
    only the dimensions tied to ``s6`` remain in ``_dynamic_results``.

    The relations are kept in their original order, orientation and
    multiplicity (several are repeated). Only the recurring ``FloorDiv`` /
    ``Mod`` sub-terms are bound to short local names.
    """
    from sympy import Eq, Integer, Ne, Symbol

    from torch._dynamo.source import (
        LocalSource,
        TensorProperty,
        TensorPropertySource,
    )

    def size_of(name, dim):
        # Source for size dimension ``dim`` of the local variable ``name``.
        return TensorPropertySource(
            base=LocalSource(local_name=name), prop=TensorProperty.SIZE, idx=dim
        )

    a0, a1, a2, a3 = (size_of("a", dim) for dim in range(4))
    b0, b1, b2 = (size_of("b", dim) for dim in range(3))
    c0, c1 = size_of("c", 0), size_of("c", 1)
    d0, d1 = size_of("d", 0), size_of("d", 1)
    e1 = size_of("e", 1)
    f1 = size_of("f", 1)

    s0 = Symbol("s0", positive=True, integer=True)
    s1 = Symbol("s1", positive=True, integer=True)
    s5 = Symbol("s5", positive=True, integer=True)
    s6 = Symbol("s6", positive=True, integer=True)

    dim_constraints = DimConstraints(
        {
            s0: [a0, b0, c0, d0],
            s1: [a2, a3],
            s5: [a1, b1],
            s6: [c1, d1, e1],
        },
        {s0: 8, s1: 96, s5: 22, s6: 21},
        {s0, s1, s5, s6},
        {},
    )

    for source, value in (
        (b0, s0),
        (c0, s0),
        (d0, s0),
        (a3, s1),
        (b1, s5),
        (d1, s6),
        (e1, s6),
        (f1, Integer(1)),
        (b2, Integer(3)),
    ):
        dim_constraints.add_equality(source, value)

    int_max = 2147483647  # 2**31 - 1

    # Recurring sub-terms of the relations below.
    h0 = FloorDiv(s0, 2)
    r0 = FloorDiv(s0, h0)
    m0 = Mod(s0, 2)
    h1 = FloorDiv(s1, 2)
    q2 = FloorDiv(h1 - 1, 2)
    q4 = FloorDiv(h1 - 1, 4)
    q8 = FloorDiv(h1 - 1, 8)
    frac = q8**2 / (q8 - 1) - 2 * q8 / (q8 - 1) + 1 / (q8 - 1)

    add = dim_constraints.add

    add(s1**2 <= int_max)
    add(32 * s1**2 <= int_max)
    add(s0 < 16)
    add(Eq(Mod(s1, 2), 0))
    add(Ne(h1, 1))
    add(Ne(h1**2, 1))
    add(32 * h1**2 <= int_max)
    add(h1**2 > 1)
    add(Ne(h1, 1))

    add(64 * q2**2 + 128 * q2 + 64 <= int_max)
    add(Ne(q2 + 1, 1))
    add(Ne(q2**2 + 2 * q2 + 1, 1))
    add(Ne(q2 + 1, 1))
    add(q2**2 + 2 * q2 + 1 > 1)

    add(128 * q4**2 + 256 * q4 + 128 <= int_max)
    add(Ne(q4 + 1, 1))
    add(Ne(q4**2 + 2 * q4 + 1, 1))
    add(Ne(q4 + 1, 1))
    add(q4**2 + 2 * q4 + 1 > 1)

    add(256 * q8**2 + 512 * q8 + 256 <= int_max)
    add(Ne(q8 + 1, 1))
    add(Ne(q8**2 + 2 * q8 + 1, 1))
    add(Ne(q8 + 1, 1))
    add(q8**2 + 2 * q8 + 1 > 1)
    add(q8 + 1 >= 3)

    add(60 * q8**2 - 120 * q8 + 60 <= int_max)
    add(q8 - 1 >= 0)
    add(q8 - 1 >= 1)
    add(Ne(60 * s0 * q8**2 - 120 * s0 * q8 + 60 * s0, 0))
    add(Ne(q8 - 1, 1))
    add(Ne(q8**2 - 2 * q8 + 1, 1))
    add(Ne(q8**2 - 2 * q8 + 1, 0))
    add(q8**2 - 2 * q8 + 1 >= 0)
    add(Ne(q8 - 1, 0))
    add(1 < 60 * q8**2 - 120 * q8 + 60)
    add(Ne(q8 - 1, -1))
    add(
        Ne(
            60 * s0 * q8**2 - 120 * s0 * q8 + 60 * s0,
            120 * q8**2 - 240 * q8 + 120,
        )
    )
    add(120 * q8**2 - 240 * q8 + 120 > 0)
    add(Eq(60 * q8**2 * m0 - 120 * q8 * m0 + 60 * m0, 0))
    add(Ne(120 * q8**2 - 240 * q8 + 120, 0))
    add(Ne(60 * h0 * r0 * q8**2 - 120 * h0 * r0 * q8 + 60 * h0 * r0, 0))
    add(Ne(h0, 1))
    add(Ne(60 * q8**2 - 120 * q8 + 60, 0))
    add(60 * q8**2 - 120 * q8 + 60 >= 0)
    add(1 < 60 * r0 * q8**2 - 120 * r0 * q8 + 60 * r0)

    add(Ne(16 * s0, 32))
    add(Eq(16 * m0, 0))
    add(Ne(16 * s0, 32))
    add(Eq(16 * m0, 0))
    add(h0 >= 2)
    add(Ne(h0, 1))
    add(1 < h0)
    add(Ne(s0, 2))

    add(
        60 * r0 * q8**2 - 120 * r0 * q8 + 60 * r0
        >= 60 * q8**2 - 120 * q8 + 60
    )
    add(60 * h0 * r0 * q8**2 - 120 * h0 * r0 * q8 + 60 * h0 * r0 > 0)
    add(
        Ne(
            60 * h0 * r0 * q8**2 - 120 * h0 * r0 * q8 + 60 * h0 * r0,
            3 * h0 * r0,
        )
    )

    add(Ne(20 * q8**2 - 40 * q8 + 20, 0))
    add(20 * q8**2 - 40 * q8 + 20 >= 0)
    add(Ne(20 * q8**2 - 40 * q8 + 20, 20))
    add(Ne(20 * Mod(1, q8**2 - 2 * q8 + 1), 0))
    add(Ne(20 * q8 * Mod(1, frac) - 20 * Mod(1, frac), 0))
    add(Ne(q8 - 1, 1))
    add(q8**2 - 2 * q8 + 1 >= 1)
    add(20 * q8**2 - 40 * q8 + 20 >= 0)
    add(20 * q8**2 - 40 * q8 + 20 >= 1)
    add(20 * q8**2 - 40 * q8 + 20 >= 2)
    add(20 * q8**2 - 40 * q8 + 20 > 1)
    add(20 * q8**2 - 40 * q8 + 20 < 60 * q8**2 - 120 * q8 + 60)
    add(Ne(60 * q8**2 - 120 * q8 + 60, 60))
    add(Ne(q8 - 1, q8**2 - 2 * q8 + 1))
    add(Eq(q8 * Mod(frac, 1) - Mod(frac, 1), 0))
    add(Ne(q8**2 - 2 * q8 + 1, q8 - 1))
    add(Ne(8 * s0, 16))
    add(60 * q8**2 - 120 * q8 + 60 >= q8**2 - 2 * q8 + 1)
    add(60 * r0 * q8**2 - 120 * r0 * q8 + 60 * r0 <= int_max)
    add(90 * q8**2 - 180 * q8 + 90 <= int_max)
    add(h0 < 16)
    add(h0 > 1)
    add(Ne(90 * h0 * q8**2 - 180 * h0 * q8 + 90 * h0, 0))
    add(1 < 90 * q8**2 - 180 * q8 + 90)
    add(q8**2 - 2 * q8 + 1 > 1)
    add(60 * r0 * q8**2 - 120 * r0 * q8 + 60 * r0 > 1)
    add(Ne(60 * h0 * q8**2 - 120 * h0 * q8 + 60 * h0, 0))
    add(90 * q8**2 - 180 * q8 + 90 > 1)
    add(60 * q8**2 - 120 * q8 + 60 > 1)
    add(Ne(60 * h0 * q8**2 - 120 * h0 * q8 + 60 * h0, 3 * h0))
    add(60 * h0 * q8**2 - 120 * h0 * q8 + 60 * h0 > 0)
    add(60 * q8**2 - 120 * q8 + 60 > 0)

    add(Ne(120 * q8**2 - 240 * q8 + 120, 0))
    add(1 < 120 * q8**2 - 240 * q8 + 120)
    add(Ne(120 * q8**2 - 240 * q8 + 120, 6))
    add(120 * q8**2 - 240 * q8 + 120 > 0)
    add(Ne(120 * q8**2 - 240 * q8 + 120, 0))
    add(120 * q8**2 - 240 * q8 + 120 <= int_max)
    add(120 * q8**2 - 240 * q8 + 120 <= 20480)
    add(Ne(90 * q8**2 - 180 * q8 + 90, 0))
    add(120 * q8**2 - 240 * q8 + 120 > 1)
    add(90 * q8**2 - 180 * q8 + 90 <= 20480)
    add(60 * q8**2 - 120 * q8 + 60 <= 20480)
    add(Ne(240 * q8**2 - 480 * q8 + 240, 0))

    add(Eq(6 * s5, 132))
    add(Eq(4, h0))
    add(Eq(q8 - 1, 4))
    add(Ne(64 * h0 * q8**2 - 128 * h0 * q8 + 64 * h0, 0))
    add(1 < 64 * q8**2 - 128 * q8 + 64)
    add(64 * q8**2 - 128 * q8 + 64 <= int_max)
    add(64 * q8**2 - 128 * q8 + 64 > 1)
    add(62 * q8**2 - 124 * q8 + 62 <= int_max)
    add(Ne(62 * h0 * q8**2 - 124 * h0 * q8 + 62 * h0, 0))
    add(1 < 62 * q8**2 - 124 * q8 + 62)
    add(Ne(3 * h0, 3))
    add(Ne(3 * h0, 3))
    add(Eq(h0, 4))
    add(Eq(4, h0))
    add(Eq(h0, 4))

    add(q8 - 1 >= 3)
    add(64 * q8**2 - 384 * q8 + 576 <= int_max)
    add(q8 - 3 >= 0)
    add(q8 - 3 >= 1)
    add(Ne(64 * h0 * q8**2 - 384 * h0 * q8 + 576 * h0, 0))
    add(Ne(q8 - 3, 1))
    add(Ne(q8**2 - 6 * q8 + 9, 1))
    add(Ne(q8**2 - 6 * q8 + 9, 0))
    add(q8**2 - 6 * q8 + 9 >= 0)
    add(Ne(q8 - 3, 0))
    add(1 < 64 * q8**2 - 384 * q8 + 576)
    add(Ne(q8 - 3, 1))
    add(Ne(64 * h0 * q8**2 - 384 * h0 * q8 + 576 * h0, 256))
    add(Eq(64 * Mod(h0 * q8**2 - 6 * h0 * q8 + 9 * h0, 4), 0))
    add(Eq(h0, FloorDiv(h0 * q8**2 - 6 * h0 * q8 + 9 * h0, 4)))
    add(Eq(FloorDiv(h0 * q8**2 - 6 * h0 * q8 + 9 * h0, 4), h0))
    add(Ne(64 * Mod(q8 + 1, 4), 0))
    add(Eq(64 * Mod(q8**2 - 6 * q8 + 1, 4), 0))
    add(64 * h0 * q8**2 - 384 * h0 * q8 + 576 * h0 > 0)
    add(q8**2 - 6 * q8 + 9 >= 1)
    add(Eq(64 * q8**2 - 384 * q8 + 576, 256))
    add(60 * q8**2 - 360 * q8 + 540 <= int_max)
    add(Ne(60 * h0 * q8**2 - 360 * h0 * q8 + 540 * h0, 0))
    add(1 < 60 * q8**2 - 360 * q8 + 540)
    add(q8**2 - 6 * q8 + 9 <= int_max)
    add(Ne(h0 * q8**2 - 6 * h0 * q8 + 9 * h0, 0))
    add(1 < q8**2 - 6 * q8 + 9)
    add(q8**2 - 6 * q8 + 9 > 1)
    add(60 * q8**2 - 360 * q8 + 540 > 1)

    add(s0 >= 2)
    add(s1 >= 2)
    add(s6 >= 2)
    add(s5 >= 2)

    dim_constraints.solve()

    expected_static = {
        "L['c'].size()[0] == 8",
        "L['d'].size()[0] == 8",
        "L['a'].size()[2] == 96",
        "L['f'].size()[1] == 1",
        "L['a'].size()[3] == 96",
        "L['b'].size()[2] == 3",
        "L['b'].size()[1] == 22",
        "L['b'].size()[0] == 8",
        "L['a'].size()[1] == 22",
        "L['a'].size()[0] == 8",
    }
    expected_dynamic = {
        "2 <= L['c'].size()[1]",
        "L['d'].size()[1] == L['c'].size()[1]",
        "L['e'].size()[1] == L['c'].size()[1]",
    }
    self.assertEqual(dim_constraints._static_results, expected_static)
    self.assertEqual(dim_constraints._dynamic_results, expected_dynamic)
|
|
|
|
|
|
class TestGuardsExpressions(TestCase):
    """
    Tests the guards-related methods used by the inductor FX graph cache:
    guards are recorded on a ShapeEnv, rendered with
    ``produce_guards_expression`` and re-checked against concrete hints with
    ``evaluate_guards_expression``.
    """

    def test_guards_gt_lt(self):
        """Guards ``s0 > 5`` and ``s0 < 7`` accept hint 6, reject 7 and 5."""
        env = ShapeEnv()
        # Created in the order 6, 7, 5, as in the original sequence.
        inside, above, below = (create_symint(env, hint) for hint in (6, 7, 5))

        guard_int(sym_int(inside > 5))
        guard_int(sym_int(inside < 7))

        guards = env.produce_guards_expression([inside])

        def holds_for(sym):
            return env.evaluate_guards_expression(guards, [hint_int(sym)])

        self.assertTrue(holds_for(inside))
        self.assertFalse(holds_for(above))
        self.assertFalse(holds_for(below))

    def test_guards_float_print(self):
        """A guard comparing ``2 / s0`` with a float still evaluates for s0."""
        env = ShapeEnv()
        sym = create_symint(env, 3)
        guard_bool(2 / sym == 2 / 3)
        guards = env.produce_guards_expression([sym])
        accepted = env.evaluate_guards_expression(guards, [hint_int(sym)])
        self.assertTrue(accepted)

    def test_guards_float_div(self):
        """A guard on ``sym_int(s0 / 2.0)`` prints trunc/float and is value-specific."""
        env = ShapeEnv()
        eight = create_symint(env, 8)
        seven = create_symint(env, 7)

        guard_int(sym_int(eight / 2.0))
        guards = env.produce_guards_expression([eight])

        # The rendered expression is expected to contain both calls.
        for fragment in ("math.trunc(", "float("):
            self.assertIn(fragment, guards)

        self.assertTrue(env.evaluate_guards_expression(guards, [hint_int(eight)]))
        self.assertFalse(env.evaluate_guards_expression(guards, [hint_int(seven)]))
|
|
|
|
|
|
# When this file is executed as a script, hand control to run_tests
# (imported from torch.testing._internal.common_utils at the top of the file).
if __name__ == "__main__":
    run_tests()
|