mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Let's first explore a couple of problems related to replacements and runtime assertions. #### Example problem 1 Suppose we have a runtime assertion u0==s0, where u0 is an input coming from mark_unbacked. A replacement u0=s0 will be added and the function f(u0, s0) will become f(s0, s0); this leads to the assert not being inserted during insert_deferred_runtime_asserts. The reason is that the insert_deferred_runtime_asserts logic inserts each assertion once all its inputs are seen, but u0 will never be seen. The same thing can happen when we defer assertions on backed symbols, e.g. s0==s2. #### Example problem 2 Consider u0==s0, where u0 comes from a call to .item(). Imagine that later on s0 is specialized to 2. In that case s0 won't be seen as an input during insert_deferred_runtime_asserts, and the assertion won't be inserted in the graph. Worse, Inductor will generate code that refers to s0 in the cpp wrapper even though s0 does not exist, causing a failure. internal xref: https://fb.workplace.com/groups/1075192433118967/permalink/1669766396994898/ ## The solution: The runtime-assertion insertion loops depend on detecting that the symbols used in the runtime assertions have been seen. Those symbols are either graph inputs or are generated in the graph by data-dependent ops like .item(). The issues above happen when the symbols are graph inputs. To force those symbols to exist in the graph and be seen by the runtime assertions, we do not apply replacements to placeholder expressions during codegen or during runtime-assertion insertion. This should not add performance overhead, since the graph has already been optimized with replacements; the only effect is that graph inputs used in runtime assertions are no longer mistakenly dropped. I added extended tests. One unrelated follow-up I noticed: we might want to rename unbacked symbols in runtime assertions when we do unbacked renaming, but that's a separate issue. 
Other approaches that did not work: #### Ban replacements on unbacked symbols. 1. This does not work when we defer runtime assertions on backed symbols, e.g. s0==s1. We could also ban such replacements, but then problem 2 becomes worse. 2. It degrades the quality of reasoning. #### Apply specialization to runtime assertions before codegen. 1. This can fix some issues, but it may also turn runtime assertions into no-ops. 2. It does not fix the case where runtime assertions are not inserted during insert_deferred_runtime_asserts because an input is not detected. Pull Request resolved: https://github.com/pytorch/pytorch/pull/153661 Approved by: https://github.com/jansel
3426 lines
124 KiB
Python
3426 lines
124 KiB
Python
# Owner(s): ["oncall: jit"]
|
|
# ruff: noqa: F841
|
|
import contextlib
|
|
import copy
|
|
import itertools
|
|
import math
|
|
import operator
|
|
import unittest
|
|
|
|
import numpy as np
|
|
import sympy
|
|
|
|
import torch
|
|
import torch.fx
|
|
import torch.nn.functional as F
|
|
from torch import sym_int, SymBool, SymFloat, SymInt
|
|
from torch._C import _disabled_torch_function_impl
|
|
from torch._dynamo.testing import CompileCounterWithBackend
|
|
from torch._inductor.utils import fresh_inductor_cache
|
|
from torch.fx.experimental import sym_node
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
_constrain_range_for_size,
|
|
DimConstraints,
|
|
DimDynamic,
|
|
expect_true,
|
|
guard_bool,
|
|
guard_float,
|
|
guard_int,
|
|
GuardOnDataDependentSymNode,
|
|
has_free_symbols,
|
|
hint_int,
|
|
is_symbolic,
|
|
ShapeEnv,
|
|
StatelessSymbolicContext,
|
|
statically_known_false,
|
|
statically_known_true,
|
|
)
|
|
from torch.testing._internal.common_dtype import all_types_and
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
skipIfTorchDynamo,
|
|
TestCase,
|
|
)
|
|
from torch.testing._internal.logging_utils import logs_to_string
|
|
from torch.utils import _pytree as pytree
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
from torch.utils._sympy.functions import (
|
|
CleanDiv,
|
|
FloorDiv,
|
|
IsNonOverlappingAndDenseIndicator,
|
|
Mod,
|
|
)
|
|
|
|
|
|
# Shorthand for the aten operator namespace used by the meta registrations below.
aten = torch.ops.aten

# Maps an aten operator overload to the Python meta function registered for it
# by ``register_meta``; looked up by ``FakeSymbolicTensor.__torch_dispatch__``.
meta_funcs = {}
|
|
|
|
|
|
def register_meta(op):
    """Build a decorator that records the decorated function in ``meta_funcs``.

    ``op`` is a single operator overload or any pytree of them (e.g. a list).
    Every leaf of ``op`` is mapped to the decorated function. The function
    itself is returned unchanged, so the decorator can be stacked.
    """

    def decorator(f):
        for leaf in pytree.tree_leaves(op):
            meta_funcs[leaf] = f
        return f

    return decorator
|
|
|
|
|
|
@register_meta([aten.add.Tensor, aten.sub.Tensor])
|
|
def binary_meta(a, b):
    """Meta kernel for elementwise binary ops: the result takes ``a``'s shape.

    ``b`` is not inspected, so no broadcasting is computed here.
    """
    out_shape = a.shape
    return a.new_empty(out_shape)
|
|
|
|
|
|
@register_meta(aten.cat.default)
|
|
def cat_meta(tensors, dim=0):
    """Meta kernel for ``aten.cat``: an empty tensor shaped like the concatenation.

    Sizes along ``dim`` are summed across ``tensors``. Every other size must
    equal the first tensor's, which is checked with ``assert``. ``dim`` may be
    negative, counting back from the last dimension.
    """
    shape = tensors[0].shape
    if dim < 0:
        # Normalize so the ``idx == dim`` test below can match. Before this,
        # a negative dim never matched: every dim was asserted equal and the
        # output size along ``dim`` came out as 0.
        dim += len(shape)
    concat_length = 0
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)
|
|
|
|
|
|
@register_meta([aten.narrow_copy.default])
|
|
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    """Meta kernel for ``aten.narrow_copy``.

    Returns an empty tensor with ``a``'s shape, except that the size at
    ``dim`` is replaced by ``length``. ``start`` and any extra keyword
    arguments do not affect the output shape and are ignored. ``dim`` may be
    negative; before normalization, a negative dim left the shape unchanged.
    """
    if dim < 0:
        dim += a.dim()
    shape = tuple(length if i == dim else size for i, size in enumerate(a.shape))
    return a.new_empty(shape)
|
|
|
|
|
|
@register_meta([aten.expand.default])
|
|
def expand_symint_meta(a, size, implicit=False):
    """Meta kernel for ``aten.expand``: an empty tensor of the requested ``size``.

    ``implicit`` is accepted for signature compatibility and is not used.
    """
    target_shape = size
    return a.new_empty(target_shape)
|
|
|
|
|
|
def create_contiguous(shape):
|
|
strides = [1]
|
|
for dim in reversed(shape[:-1]):
|
|
strides.append(dim * strides[-1])
|
|
return list(reversed(strides))
|
|
|
|
|
|
class FakeSymbolicTensor(torch.Tensor):
    """Wrapper-subclass tensor with (possibly symbolic) sizes, for testing.

    The tensor holds no data. Dispatch is served only by the functions
    registered in ``meta_funcs`` plus ``aten.new_empty``. Any other operator
    raises ``RuntimeError``.
    """

    @staticmethod
    def __new__(
        cls,
        sym_shape,
        sym_strides,
        dtype,
        layout,
        requires_grad,
        device,
        storage_offset=0,
    ):
        # TODO: this is wrong in general
        # ``sym_strides`` is ignored: strides are always recomputed as
        # contiguous for ``sym_shape``.
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            sym_shape,
            sym_stride,
            storage_offset,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        """Return a FakeSymbolicTensor of ``shape`` with this tensor's dtype,
        layout, requires_grad and device."""
        return FakeSymbolicTensor(
            shape, None, self.dtype, self.layout, self.requires_grad, self.device
        )

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        # ``kwargs`` defaults to None; normalize it so ``**kwargs`` below
        # cannot raise TypeError.
        kwargs = kwargs or {}
        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(
                shape,
                self.stride(),
                self.dtype,
                self.layout,
                self.requires_grad,
                self.device,
            )

        raise RuntimeError(f"operator {func_overload} not supported")
|
|
|
|
|
|
def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
    """Build a FakeSymbolicTensor mirroring ``arg`` with symbols from ``shape_env``.

    ``source`` defaults to ``ConstantSource(name)``. ``dynamic_dims`` defaults
    to DUCK for every dimension, and no dimension is constrained. The
    symbolic strides are passed on, but FakeSymbolicTensor recomputes
    contiguous strides itself.
    """
    from torch._dynamo.source import ConstantSource

    ndim = arg.dim()
    src = ConstantSource(name) if source is None else source
    dims = [DimDynamic.DUCK] * ndim if dynamic_dims is None else dynamic_dims
    context = StatelessSymbolicContext(
        dynamic_sizes=dims, constraint_sizes=[None] * ndim
    )
    sizes, strides, offset = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg, source=src, symbolic_context=context
    )
    return FakeSymbolicTensor(
        sizes,
        strides,
        arg.dtype,
        arg.layout,
        arg.requires_grad,
        arg.device,
        offset,
    )
|
|
|
|
|
|
def create_fake_tensor_with_dynamic_size(x, shape_env, dynamic_sizes, dynamic_strides):
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
|
|
with FakeTensorMode(shape_env=shape_env) as fake_mode:
|
|
return fake_mode.from_tensor(
|
|
x,
|
|
symbolic_context=StatelessSymbolicContext(
|
|
dynamic_sizes=dynamic_sizes,
|
|
dynamic_strides=dynamic_strides,
|
|
),
|
|
)
|
|
|
|
|
|
def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
|
|
from torch._dynamo.source import ConstantSource
|
|
|
|
symbol = shape_env.create_symbol(
|
|
val,
|
|
source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
|
|
dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
|
|
constraint_dim=None,
|
|
**kwargs,
|
|
)
|
|
return cls(SymNode(symbol, shape_env, pytype, hint=val))
|
|
|
|
|
|
# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
    """Create a SymInt with hint ``i``; extra kwargs reach ``create_symbol``."""
    return create_symtype(
        cls=SymInt, pytype=int, shape_env=shape_env, val=i, duck=duck, **kwargs
    )
|
|
|
|
|
|
def create_symbool(shape_env, b: bool) -> SymBool:
    """Create a SymBool whose hint is ``b``."""
    return create_symtype(cls=SymBool, pytype=bool, shape_env=shape_env, val=b)
|
|
|
|
|
|
def create_symfloat(shape_env, f: float) -> SymFloat:
    """Create a SymFloat whose hint is ``f``."""
    return create_symtype(cls=SymFloat, pytype=float, shape_env=shape_env, val=f)
|
|
|
|
|
|
@skipIfTorchDynamo(
|
|
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
|
|
)
|
|
class TestPySymInt(TestCase):
|
|
def test_arith_ops(self):
    """Binary arithmetic on SymInts agrees with the same op on plain ints."""
    shape_env = ShapeEnv()
    symints = []
    for i in range(2, 5):
        symints.append((i, create_symint(shape_env, i)))

    ops = [
        operator.add,
        operator.sub,
        operator.floordiv,
        operator.mul,
        operator.mod,
    ]

    for op in ops:
        # Each element of ``args`` is an (int, SymInt) pair.
        for args in itertools.permutations(symints, 2):
            if isinstance(args[0][1], int):
                continue
            # Only division-like ops must avoid a zero divisor. The old
            # condition ``op != mod or op != floordiv`` was always true.
            if op in (operator.mod, operator.floordiv) and args[1][0] == 0:
                continue
            self.assertTrue(
                op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])
            )
|
|
|
|
def test_reverse_arith_ops(self):
    """Reflected operators (``int <op> SymInt``) give the concrete result."""
    env = ShapeEnv()

    divisor = create_symint(env, 2)
    self.assertTrue(5 // divisor == 5 // 2)

    factor = create_symint(env, 2)
    self.assertTrue(5 * factor == 5 * 2)
|
|
|
|
def test_sympify_symint(self):
    """sympy.sympify of a Sym* value returns its node's sympy expression."""
    env = ShapeEnv()
    for factory, value in (
        (create_symint, 2),
        (create_symfloat, 3.0),
        (create_symbool, True),
    ):
        sym = factory(env, value)
        self.assertIs(sympy.sympify(sym), sym.node.expr)
|
|
|
|
def test_roundtrip(self):
    """Sizes and storage offset of a symbolic tensor are SymInts that compare
    equal to the concrete values they were created from."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)

    self.assertTrue(not isinstance(x.shape[0], SymNode))
    self.assertTrue(isinstance(x.shape[0], SymInt))

    self.assertTrue(x.shape[0] == 5)
    self.assertTrue(x.shape[1] == 4)
    # This check and the two size() checks below used to be
    # ``assertTrue(value, expected)``. assertTrue's second argument is only
    # the failure message, so those calls merely tested truthiness.
    self.assertTrue(x.shape[2] == 3)

    self.assertTrue(x.size()[0] == 5)
    self.assertTrue(x.size()[1] == 4)
    # Should be simplifiable to an integer.
    # Ref: https://github.com/pytorch/pytorch/pull/107492
    self.assertTrue(isinstance(x.size()[1], SymInt))
    self.assertTrue(
        isinstance(x.size()[1].node.maybe_as_int(), int)
    )  # due to guard above
    self.assertTrue(x.size()[2] == 3)

    self.assertTrue(x.size(0) == 5)
    self.assertTrue(x.size(1) == 4)
    self.assertTrue(x.size(2) == 3)
    self.assertTrue(isinstance(x.size(2), SymInt))
    self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))

    y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
    self.assertTrue(isinstance(y.storage_offset(), SymInt))
    self.assertTrue(y.storage_offset() == 12)
|
|
|
|
def test_binary(self):
    """Adding symbolic tensors keeps the expected sizes, with broadcasting."""
    env = ShapeEnv()
    lhs = create_symbolic_tensor("x", torch.randn(5, 4, 3), env)
    rhs = create_symbolic_tensor("y", torch.randn(5, 4, 3), env)

    out = lhs + rhs
    for axis, expected in enumerate((5, 4, 3)):
        self.assertTrue(out.shape[axis] == expected)

    # broadcasting
    rhs_bcast = create_symbolic_tensor("y2", torch.randn(1, 4, 1), env)
    out = lhs + rhs_bcast
    for axis, expected in enumerate((5, 4, 3)):
        self.assertTrue(out.shape[axis] == expected)
|
|
|
|
def test_symint_args(self):
    """narrow_copy accepts SymInt lengths, including SymInt arithmetic."""
    env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), env)
    y = create_symbolic_tensor("y", torch.randn(5, 4, 1), env)
    last = 2

    narrowed = x.narrow_copy(last, 0, y.shape[last])
    self.assertTrue(narrowed.shape[2] == y.shape[2])

    # arithmetic expr with two symints
    diff = x.shape[last] - y.shape[last]
    narrowed = x.narrow_copy(last, 0, diff)
    self.assertTrue(narrowed.shape[2] == 2)

    # arithmetic expr with a symint and python int
    narrowed = x.narrow_copy(last, 0, x.shape[last] - 1)
    self.assertTrue(narrowed.shape[2] == 2)
|
|
|
|
def test_symint_vargs(self):
    """expand accepts SymInts as varargs or in a sequence, mixed with ints."""
    env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5, 4, 3), env)
    y = create_symbolic_tensor("y", torch.randn(1, 4, 1), env)

    def check_543(t):
        self.assertTrue(t.shape[0] == 5)
        self.assertTrue(t.shape[1] == 4)
        self.assertTrue(t.shape[2] == 3)

    # varargs
    check_543(y.expand(x.shape[0], y.shape[1], x.shape[2]))
    # shape list
    check_543(y.expand((x.shape[0], y.shape[1], x.shape[2])))
    # mixed python symints and ints
    check_543(y.expand(x.shape[0], y.shape[1], 3))
    # mixed python symints and ints in a list
    check_543(y.expand((x.shape[0], y.shape[1], 3)))
    # mixed python symints and ints
    check_543(y.expand(5, y.shape[1], x.shape[2]))
    # mixed python ints and symints in a list
    check_543(y.expand((5, y.shape[1], x.shape[2])))

    # A single SymInt, as a 1-tuple and as a bare vararg; must not raise.
    y.expand((y.shape[1],))
    y.expand(y.shape[1])
|
|
|
|
def test_symint_bitwise_and(self):
    """``&`` across SymInt/SymInt, SymInt/SymBool, SymInt/int and SymBool/SymBool."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 0b1100)
    b0 = create_symint(shape_env, 0b1010)
    res_and = a0 & b0
    self.assertEqual(res_and, 0b1000)
    self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
    # Symbol names (s97, s26) depend on creation order above; keep it stable.
    self.assertExpectedInline(
        str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
    )

    # SymInt & SymBool
    a1 = create_symint(shape_env, 3)
    b1 = create_symbool(shape_env, True)
    self.assertEqual(a1 & b1, 1)

    # SymInt & plain int
    a2 = create_symint(shape_env, 0b1100)
    self.assertEqual(a2 & 0b1010, 0b1000)

    # SymBool & SymBool
    a3 = create_symbool(shape_env, True)
    b3 = create_symbool(shape_env, True)
    self.assertEqual(a3 & b3, True)
|
|
|
|
def test_symint_bitwise_or(self):
    """``|`` on two SymInts yields a SymInt and records a BitwiseFn guard."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 0b1100)
    b0 = create_symint(shape_env, 0b1010)
    res_or = a0 | b0
    self.assertEqual(res_or, 0b1110)
    self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
    # Symbol names (s97, s26) depend on creation order above.
    self.assertExpectedInline(
        str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
    )
|
|
|
|
def test_stride(self):
    """Strides of a symbolic tensor are SymInts."""
    env = ShapeEnv()
    square = create_symbolic_tensor("x", torch.randn(5, 5), env)
    leading_stride = square.stride()[0]
    self.assertIsInstance(leading_stride, SymInt)
|
|
|
|
def test_size_expressions(self):
    """Branching on a symbolic size records a ``>`` guard on that size, and
    the same size symbol flows through expand and add."""
    shape_env = ShapeEnv()
    x = create_symbolic_tensor("x", torch.randn(5), shape_env)
    expand_x = x.expand(x.shape[0], x.shape[0])
    # Both branches are identical; the point is the guard produced by the
    # ``if`` on a symbolic size.
    if expand_x.shape[0] > 3:
        result = expand_x + expand_x
    else:
        result = expand_x + expand_x

    gt_op, _bt, is_size_obv = shape_env.guards[-1]
    self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
    # These were ``assertTrue(a, b)``. That form treats ``b`` as the failure
    # message and only checks that ``a`` is truthy, so compare explicitly.
    self.assertEqual(str(x.shape[0]), str(gt_op.args[0]))
    self.assertEqual(str(expand_x.shape[1]), str(x.shape[0]))
    self.assertEqual(str(expand_x.shape[1]), str(result.shape[0]))
    self.assertFalse(is_size_obv)
|
|
|
|
def test_floordiv_static(self):
    """After guarding on divisibility facts, s0 // (s0 // 2) == 2 is static."""
    env = ShapeEnv()
    s0 = create_symint(env, 8)
    # This was extracted from
    # python test/inductor/test_cuda_cpp_wrapper.py -k
    # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
    # Build each condition only after the previous guard, preserving order.
    for build_condition in (
        lambda: s0 % 2 == 0,
        lambda: s0 % (s0 // 2) == 0,
        lambda: 2 * (s0 // 2) == s0,
    ):
        bool(build_condition())
    self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
|
|
|
|
def test_numel(self):
    """numel is a SymInt for symbolic tensors and an int for real ones."""
    env = ShapeEnv()
    symbolic = create_symbolic_tensor("x", torch.randn(5), env)
    for count in (symbolic.numel(), torch.numel(symbolic)):
        self.assertIsInstance(count, torch.SymInt)

    concrete = torch.rand(3, 3)
    for count in (concrete.numel(), torch.numel(concrete)):
        self.assertIsInstance(count, int)
|
|
|
|
def test_int_to_float(self):
    """torch.sym_float of a SymInt size yields a SymFloat."""
    env = ShapeEnv()
    vec = create_symbolic_tensor("x", torch.randn(5), env)
    as_float = torch.sym_float(vec.shape[0])
    self.assertIsInstance(as_float, torch.SymFloat, msg=type(as_float))
|
|
|
|
def test_aten_ops(self):
    """aten ops called directly accept SymInt arguments."""
    env = ShapeEnv()
    vec = create_symbolic_tensor("x", torch.randn(5), env)
    torch.ops.aten.narrow_copy.default(vec, 0, 0, vec.shape[0])

    env = ShapeEnv()
    cube = create_symbolic_tensor("x2", torch.randn(5, 4, 3), env)
    torch.ops.aten.expand.default(cube, list(cube.shape))
|
|
|
|
def test_fx_trace_intlist(self):
    """symbolic_trace handles F.pad with proxy-derived ``% 2`` pad amounts."""

    class CustomModule(torch.nn.Module):
        def forward(self, x):
            _, _, height, width = x.shape
            return F.pad(x, (0, width % 2, 0, height % 2, 0, 0))

    module = CustomModule()
    _sample = torch.rand(1, 3, 4, 4)
    # should not TypeError: pad(): argument 'pad' (position 2) must be
    # tuple of ints, not tuple
    torch.fx.symbolic_trace(module)
|
|
|
|
def test_meta_symint(self):
    """torch.empty on meta with a SymInt size keeps the size symbolic."""
    env = ShapeEnv()
    length = create_symint(env, 2)
    meta_tensor = torch.empty(length, device="meta")
    self.assertIsInstance(meta_tensor.shape[0], SymInt)
|
|
|
|
def test_guard_int(self):
    """Checks guard_int returns the hint and records an equality guard."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 2)
    self.assertEqual(guard_int(a0), 2)
    self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
|
|
|
|
def test_sym_sum(self):
    """torch.sym_sum builds the same expression as chained ``+``."""
    env = ShapeEnv()
    terms = [create_symint(env, v) for v in (2, 3, 4)]
    chained = (terms[0] + terms[1] + terms[2]).node.expr
    self.assertEqual(chained, torch.sym_sum(terms).node.expr)
|
|
|
|
def test_prefer_deferred_runtime_assertions_over_guards(self):
    """With prefer_deferred_runtime_asserts_over_guards=True, checks that
    guard_int still records a guard, while expect_true records a deferred
    runtime assert and no guard."""
    shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
    s0 = create_symint(shape_env, 2)
    self.assertEqual(guard_int(s0), 2)
    self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")

    # Fresh env so the second half starts with no guards.
    shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
    s0 = create_symint(shape_env, 2)
    self.assertTrue(expect_true(s0 == 2))
    self.assertEqual(len(shape_env.guards), 0)
    # The assert is stored under the ``None`` key of deferred_runtime_asserts.
    self.assertExpectedInline(
        str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
        """[Eq(s97, 2)]""",
    )
|
|
|
|
def test_sym_int(self):
    """Checks sym_int on a SymInt and on SymFloat expressions: it returns a
    SymInt, and guarding records the expected TruncToInt guard."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 5)
    r = sym_int(a0)
    self.assertEqual(r, 5)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")

    # SymInt / int is a SymFloat; sym_int truncates it.
    a1 = create_symint(shape_env, 7)
    r = sym_int(a1 / 2)
    self.assertEqual(guard_int(r), 3)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
    )

    a3 = create_symint(shape_env, 3)
    r = sym_int(2.0 * torch.sym_float(a3))
    self.assertEqual(guard_int(r), 6)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
    )
|
|
|
|
def test_sym_log2(self):
    """Checks torch._sym_log2 on a SymInt yields a SymFloat and the guard
    recorded by comparing it."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 4)
    r = torch._sym_log2(a0)
    self.assertEqual(r, 2.0)
    self.assertIsInstance(r, torch.SymFloat, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
    )
|
|
|
|
def test_sym_sqrt(self):
    """Checks torch._sym_sqrt on a SymInt yields a SymFloat and the guard
    recorded by comparing it."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 4)
    r = torch._sym_sqrt(a0)
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymFloat, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
    )
|
|
|
|
def test_sym_floor(self):
    """Checks math.floor on SymFloat expressions yields a SymInt and records
    a FloorToInt guard."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 5)
    r = math.floor(a0 / 2)
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[0][0]),
        """Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
    )
    # float * SymInt produces a SymFloat built from ToFloat(s97).
    r = math.floor(3.0 * a0)
    self.assertEqual(r, 15)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[1][0]),
        """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
    )
|
|
|
|
def test_sym_trunc(self):
    """Checks math.trunc and torch.sym_int (truncation) on SymFloats yield
    SymInts and record TruncToInt guards."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 5)
    r = math.trunc(a0 / 2)
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
    )
    r = torch.sym_int(torch.sym_sqrt(a0))
    self.assertEqual(r, 2)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[1][0]),
        """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
    )
|
|
|
|
def test_sym_ceil(self):
    """math.ceil on SymFloat expressions yields a SymInt and records a
    CeilToInt guard."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 5)
    r = math.ceil(a0 / 2)
    self.assertEqual(r, 3)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[0][0]),
        """Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
    )
    # This half previously called math.floor, duplicating test_sym_floor.
    # Exercise ceil on a float-scaled SymInt instead.
    r1 = 3.0 * a0
    r = math.ceil(r1)
    self.assertEqual(r, 15)
    self.assertIsInstance(r, torch.SymInt, msg=type(r))
    self.assertExpectedInline(
        str(shape_env.guards[1][0]),
        """Eq(CeilToInt(3.0*ToFloat(s97)), 15)""",
    )
|
|
|
|
def test_sym_ite(self):
    """Checks torch.sym_ite with Python-bool and SymBool conditions."""
    shape_env = ShapeEnv()
    t = create_symint(shape_env, 5)
    f = create_symint(shape_env, 4)
    # A plain Python bool condition selects the operand object itself.
    b1 = True
    r1 = torch.sym_ite(b1, t, f)
    self.assertTrue(r1 is t)
    b2 = False
    r2 = torch.sym_ite(b2, t, f)
    self.assertTrue(r2 is f)
    # A SymBool condition builds a Piecewise expression without guarding;
    # the guard below is only recorded once assertEqual compares r3.
    b3 = t == 5
    r3 = torch.sym_ite(b3, t, f)
    self.assertEqual(len(shape_env.guards), 0)
    self.assertEqual(r3, 5)
    self.assertEqual(type(t), type(r3))
    self.assertExpectedInline(
        str(shape_env.guards[0][0]),
        """Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
    )
    b4 = f == 5
    r4 = torch.sym_ite(b4, t, f)
    self.assertEqual(len(shape_env.guards), 1)
    self.assertEqual(r4, 4)
    self.assertEqual(type(f), type(r4))
    self.assertExpectedInline(
        str(shape_env.guards[1][0]),
        """Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
    )
|
|
|
|
def test_tracing_sym_ite(self):
    """Checks that sym_ite survives make_fx symbolic tracing as a graph node
    without adding guards, and that the traced graph picks the branch at
    runtime."""

    def f(x):
        b = x.shape[0] == 5
        ret = torch.sym_ite(b, x.shape[0], x.shape[1])
        return ret

    gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
    self.assertEqual(len(gm.shape_env.guards), 0)
    self.assertExpectedInline(
        gm.code.strip(),
        """\
def forward(self, x_1):
    sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
    eq = sym_size_int == 5
    sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
    sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
    return sym_ite""",
    )
    # shape (4, 5): condition false -> shape[1] == 5
    r1 = gm(torch.ones(4, 5))
    self.assertIsInstance(r1, int)
    self.assertEqual(r1, 5)
    # shape (5, 4): condition true -> shape[0] == 5
    r2 = gm(torch.ones(5, 4))
    self.assertIsInstance(r2, int)
    self.assertEqual(r2, 5)
|
|
|
|
def test_int_conversion(self):
    """Checks that int() on a SymInt records an equality guard."""
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 2)
    int(a0)
    self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
|
|
|
|
def test_data_dependent_guard(self):
    """Guarding on an unbacked SymInt raises GuardOnDataDependentSymNode."""
    env = ShapeEnv()
    unbacked = env.create_unbacked_symint()
    with self.assertRaises(GuardOnDataDependentSymNode):
        bool(unbacked == 0)
|
|
|
|
def test_data_dependent_guard_propagate_real_tensors(self):
    """An unbacked SymInt with a recorded real value can be guarded on."""
    env = ShapeEnv()
    unbacked = env.create_unbacked_symint()
    env.set_unbacked_var_to_val(unbacked.node.expr, 0)
    self.assertTrue(bool(unbacked == 0))
|
|
|
|
def test_expect_true_basic(self):
    """expect_true on an unbacked equality records a replacement."""
    env = ShapeEnv()
    unbacked = env.create_unbacked_symint()
    # Capture the raw symbol before any replacement is installed.
    symbol = unbacked.node.expr
    # This doesn't error
    self.assertTrue(expect_true(unbacked == 0))
    # This generates a deferred runtime assert via replacement
    self.assertEqual(env.replacements[symbol], 0)
    # After expecting true, guards now resolve given the runtime assert
    bool(unbacked == 0)
|
|
|
|
def test_expect_true_with_s0(self):
    """Checks expect_true(u0 < s0) is deferred under u0 and afterwards lets
    related comparisons between u0 and s0 resolve."""
    shape_env = ShapeEnv()
    s0 = create_symint(shape_env, 5)
    i0 = shape_env.create_unbacked_symint()
    self.assertTrue(expect_true(i0 < s0))
    self.assertExpectedInline(
        str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
        """[u0 < s97]""",
    )
    self.assertTrue(i0 < s0)
    self.assertTrue(i0 != s0)
    self.assertFalse(i0 > s0)
    self.assertFalse(i0 >= s0)
|
|
|
|
def test_expect_true_prefer_later(self):
    """Checks that a deferred assert involving two unbacked symbols is keyed
    by the later-created one."""
    shape_env = ShapeEnv()
    i0 = shape_env.create_unbacked_symint()
    i1 = shape_env.create_unbacked_symint()
    i1_sym = i1.node.expr
    self.assertTrue(expect_true(i0 + i1 == 10))
    # Importantly, this is put in i1, not i0!
    self.assertExpectedInline(
        str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
        """[Eq(u0 + u1, 10)]""",
    )
    self.assertTrue(i0 + i1 == 10)
    # NB: We currently don't support deriving that we can substitute
    # i0 + i1 with 10; maybe we should, but this means our rewriting
    # system is no longer confluent (it's probably OK though, because
    # you're unlikely to get other equalities like this on the
    # unbacked SymInts.)
|
|
|
|
def test_unbacked_substitution(self):
    """Checks how size-like unbacked symbols print after expect_true of a
    linear relation between them."""
    shape_env = ShapeEnv()
    i0 = shape_env.create_unbacked_symint()
    i1 = shape_env.create_unbacked_symint()
    _constrain_range_for_size(i0)
    _constrain_range_for_size(i1)
    self.assertTrue(expect_true(i0 == i1 * 4))
    # Expected: i0 still prints as itself, not rewritten in terms of u1.
    self.assertExpectedInline(str(i0), """u0""")

    i2 = shape_env.create_unbacked_symint()
    i3 = shape_env.create_unbacked_symint()
    _constrain_range_for_size(i2)
    _constrain_range_for_size(i3)
    self.assertTrue(expect_true(i2 * 4 == i3))
    # Expected: i3 still prints as itself.
    self.assertExpectedInline(str(i3), """u3""")
|
|
|
|
def test_avoid_unbacked_substitution(self):
    """Checks that after expect_true(i0 == 10 - i1) on size-like unbacked
    symbols, i0 still prints as u0."""
    shape_env = ShapeEnv()
    i0 = shape_env.create_unbacked_symint()
    _constrain_range_for_size(i0)
    i1 = shape_env.create_unbacked_symint()
    _constrain_range_for_size(i1)
    self.assertTrue(expect_true(i0 == 10 - i1))
    self.assertExpectedInline(str(i0), """u0""")
|
|
|
|
def test_expect_true_double_digits(self):
    """Deferred asserts are keyed correctly for two-digit unbacked names."""
    env = ShapeEnv()
    unbacked = []
    for _ in range(11):  # allocate u0 .. u10
        unbacked.append(env.create_unbacked_symint())
    self.assertEqual(str(unbacked[-1]), "u10")
    self.assertTrue(expect_true(sum(unbacked) == 20))
    last_symbol = unbacked[-1].node.expr
    self.assertEqual(len(env.deferred_runtime_asserts[last_symbol]), 1)
|
|
|
|
def test_expect_true_refine_range(self):
    """expect_true of a bound on an unbacked SymInt refines its value range,
    so the implied comparisons become statically known."""
    shape_env = ShapeEnv()
    for i, rel in enumerate(
        [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
    ):
        with self.subTest(f"i = {i}"):
            i0 = shape_env.create_unbacked_symint()
            self.assertTrue(expect_true(rel(i0)))
            self.assertTrue(statically_known_true(i0 != 3))
            self.assertTrue(statically_known_true(i0 != 4))
            self.assertFalse(statically_known_true(i0 != 5))
            self.assertFalse(statically_known_true(i0 != 6))
            self.assertTrue(statically_known_true(i0 > 4))
            self.assertTrue(statically_known_true(i0 >= 5))

    for i, rel in enumerate(
        [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
    ):
        with self.subTest(f"i = {i}"):
            i0 = shape_env.create_unbacked_symint()
            self.assertTrue(expect_true(rel(i0)))
            self.assertFalse(statically_known_true(i0 != 2))
            self.assertFalse(statically_known_true(i0 != 3))
            self.assertTrue(statically_known_true(i0 != 4))
            self.assertTrue(statically_known_true(i0 != 5))
            self.assertTrue(statically_known_true(i0 < 4))
            # Check the tight bound, as test_guard_refine_range does. This
            # previously checked ``i0 <= 5``, which is implied by a much
            # weaker range.
            self.assertTrue(statically_known_true(i0 <= 3))
|
|
|
|
def test_guard_refine_range(self):
    """Guarding a relation on a backed SymInt refines its value range in the
    direction the guard resolved."""
    shape_env = ShapeEnv()
    greater = [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
    less = [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]

    # (relations, hint, guard outcome, [(predicate, statically known?), ...])
    scenarios = [
        (
            greater,
            10,
            True,
            [
                (lambda x: x != 3, True),
                (lambda x: x != 4, True),
                (lambda x: x != 5, False),
                (lambda x: x != 6, False),
                (lambda x: x > 4, True),
                (lambda x: x >= 5, True),
            ],
        ),
        (
            greater,
            2,
            False,
            [
                (lambda x: x != 3, False),
                (lambda x: x != 4, False),
                (lambda x: x != 5, True),
                (lambda x: x != 6, True),
                (lambda x: x <= 4, True),
                (lambda x: x < 5, True),
            ],
        ),
        (
            less,
            2,
            True,
            [
                (lambda x: x != 2, False),
                (lambda x: x != 3, False),
                (lambda x: x != 4, True),
                (lambda x: x != 5, True),
                (lambda x: x < 4, True),
                (lambda x: x <= 3, True),
            ],
        ),
        (
            less,
            10,
            False,
            [
                (lambda x: x != 2, True),
                (lambda x: x != 3, True),
                (lambda x: x != 4, False),
                (lambda x: x != 5, False),
                (lambda x: x >= 4, True),
                (lambda x: x > 3, True),
            ],
        ),
    ]

    for relations, hint, outcome, expectations in scenarios:
        for i, rel in enumerate(relations):
            with self.subTest(f"i = {i}"):
                i0 = create_symint(shape_env, hint, duck=False)
                self.assertEqual(bool(rel(i0)), outcome)
                for predicate, known in expectations:
                    self.assertEqual(statically_known_true(predicate(i0)), known)
|
|
|
|
def test_mul_int_oo_nan(self):
    """Guarding s0 * (s1 // s0) == s2 on non-duck symints must not fail."""
    env = ShapeEnv()
    a, b, c = (create_symint(env, v, duck=False) for v in (5, 6, 5))
    bool(a * (b // a) == c)
|
|
|
|
def test_non_overlapping_and_dense(self):
    """A column-major meta tensor with a SymInt size is non-overlapping and dense."""
    env = ShapeEnv()
    rows = create_symint(env, 5)
    col_major = torch.empty_strided((rows, 7), (1, rows), device="meta")
    is_nod = torch.ops.aten.is_non_overlapping_and_dense.default
    self.assertTrue(is_nod(col_major))
|
|
|
|
def test_non_overlapping_and_dense_unbacked(self):
    env = ShapeEnv()
    u0 = env.create_unbacked_symint()
    torch._check_is_size(u0)
    is_dense = torch.ops.aten.is_non_overlapping_and_dense.default

    def meta(size, stride):
        return torch.empty_strided(size, stride, device="meta")

    # 2-D layouts: one unbacked dim next to a static dim of size 2, checked
    # both through the sympy indicator and through the aten op.
    self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
    self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
    self.assertTrue(is_dense(meta((u0, 2), (2, 1))))
    self.assertTrue(is_dense(meta((2, u0), (1, 2))))

    # 1-D layouts: unbacked size with unit stride, and size 1 with an
    # unbacked stride.
    self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
    self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
    self.assertTrue(is_dense(meta((u0,), (1,))))
    self.assertTrue(is_dense(meta((1,), (u0,))))

    # NB: This only works because we're able to determine this tensor is
    # contiguous. transpose(0, 1) makes it stop working
    contiguous_4d = meta(
        (2, 3, 1, u0),
        (
            3 * torch.sym_max(1, u0),
            torch.sym_max(1, u0),
            torch.sym_max(1, u0),
            1,
        ),
    )
    self.assertTrue(is_dense(contiguous_4d))
|
|
|
|
def test_sympy_optimized_add_binary_search(self):
    """Check ``_binary_search_insert_arg`` keeps symbol lists sorted.

    As exercised below, the helper returns the list with the new symbol in
    sorted position, and returns ``None`` when the symbol is already present.
    """
    import sympy

    from torch.fx.experimental.sym_node import _binary_search_insert_arg

    a = sympy.Symbol("a")
    b = sympy.Symbol("b")
    c = sympy.Symbol("c")

    # Insert into an empty list. (A dead `args = []` store was dropped.)
    args = _binary_search_insert_arg([], b)
    self.assertEqual(args, [b])
    self.assertIsNone(_binary_search_insert_arg(args, b))

    # Insert in front of the existing element.
    args = _binary_search_insert_arg(args, a)
    self.assertEqual(args, [a, b])
    self.assertIsNone(_binary_search_insert_arg(args, b))
    self.assertIsNone(_binary_search_insert_arg(args, a))

    # Insert after the existing elements.
    args = _binary_search_insert_arg(args, c)
    self.assertEqual(args, [a, b, c])
    self.assertIsNone(_binary_search_insert_arg(args, a))
    self.assertIsNone(_binary_search_insert_arg(args, b))
    self.assertIsNone(_binary_search_insert_arg(args, c))

    # Names sharing a prefix land between their neighbours.
    a1 = sympy.Symbol("a1")
    a2 = sympy.Symbol("a2")
    args = _binary_search_insert_arg(args, a1)
    self.assertEqual(args, [a, a1, b, c])
    args = _binary_search_insert_arg(args, a2)
    self.assertEqual(args, [a, a1, a2, b, c])

    # Insert at the back.
    c1 = sympy.Symbol("c1")
    args = _binary_search_insert_arg(args, c1)
    self.assertEqual(args, [a, a1, a2, b, c, c1])

    # insert to front
    _a = sympy.Symbol("_a")
    args = _binary_search_insert_arg(args, _a)
    self.assertEqual(args, [_a, a, a1, a2, b, c, c1])
|
|
|
|
def test_floor_clean_div_axioms(self):
    # An axiom registered in terms of FloorDiv must stay usable after the
    # ShapeEnv learns enough to rewrite that FloorDiv as CleanDiv: the
    # CleanDiv form has to resolve to the value recorded for FloorDiv.
    from sympy import Eq

    env = ShapeEnv()
    a = env.create_unbacked_symint()
    env.defer_runtime_assert((a // 3 == 1).node.expr, " test")

    floor_eq = Eq(FloorDiv(a.node.expr, 3), 1)
    clean_eq = Eq(CleanDiv(a.node.expr, 3), 1)

    self.assertTrue(env.evaluate_expr(floor_eq))
    # Before divisibility is known, the CleanDiv form is not statically decided.
    self.assertEqual(env._maybe_evaluate_static(clean_eq), None)

    # After this FloorDiv(a, 3) is simplified to CleanDiv(a, 3)
    env.defer_runtime_assert(Eq(Mod(a, 3), 0), " test")
    self.assertEqual(clean_eq, env.simplify(floor_eq))

    for eq in (floor_eq, clean_eq):
        self.assertTrue(env.evaluate_expr(eq))
|
|
|
|
def test_sympy_optimized_add(self):
    """Pin down which SymInt additions keep the ``_optimized_summation`` flag.

    Each case adds symbolic ints and checks whether the resulting node is
    marked as an optimized summation.
    """
    shape_env = ShapeEnv()
    s0 = create_symint(shape_env, 2)
    s1 = create_symint(shape_env, 3)
    s2 = create_symint(shape_env, 4)
    # Named `total` (not `sum`) so the builtin is not shadowed.
    total = s0 + s1

    def assert_optimized(sym):
        self.assertTrue(sym.node._optimized_summation)

    def assert_not_optimized(sym):
        # getattr default: the attribute may be absent on nodes that never
        # took the optimized path.
        self.assertFalse(getattr(sym.node, "_optimized_summation", False))

    assert_optimized(total)

    # add duplicate symbol
    assert_not_optimized(total + s0)

    # add constant.
    assert_not_optimized(total + 1)

    # add new unique symbol, should maintain _optimized_summation property.
    assert_optimized(total + s2)
    assert_optimized(s2 + total)

    # add x + (p+q) with no _optimized_summation on the rhs sum.
    p = create_symint(shape_env, 10)
    q = create_symint(shape_env, 11)
    two_sum = torch.sym_sum([p, q])
    assert_not_optimized(two_sum)
    assert_optimized(total + two_sum)

    # adding two expressions of length >2 that are _optimized_summation.
    lhs = s0 + s1 + s2
    s3 = create_symint(shape_env, 10)
    s4 = create_symint(shape_env, 20)
    s5 = create_symint(shape_env, 30)
    rhs = s3 + s4 + s5
    assert_optimized(lhs)
    assert_optimized(rhs)
    assert_not_optimized(lhs + rhs)
    assert_not_optimized(rhs + lhs)
    assert_not_optimized(rhs + lhs + rhs)
|
|
|
|
def test_max_of_unique_summation_opt(self):
    """Check when a Max expression records ``unique_summations_symbols``.

    Max over sums of pairwise-distinct symbols (and Max of such Maxes) is
    expected to carry the attribute. Mixing in bare symbols, constants,
    products or overlapping symbols is expected to leave it unset.
    """
    shape_env = ShapeEnv()
    s0 = shape_env.create_unbacked_symint()
    s1 = shape_env.create_unbacked_symint()
    s2 = shape_env.create_unbacked_symint()
    s3 = shape_env.create_unbacked_symint()
    s4 = shape_env.create_unbacked_symint()
    s5 = shape_env.create_unbacked_symint()
    s7 = shape_env.create_unbacked_symint()

    def assert_optimized(sym):
        self.assertTrue(sym.node.expr.unique_summations_symbols is not None)

    def assert_not_optimized(sym):
        # The getattr result must actually be checked; evaluating it and
        # discarding the value made this helper a silent no-op.
        self.assertIsNone(getattr(sym.node.expr, "unique_summations_symbols", None))

    mx1 = torch.sym_max(s0, s1)
    assert_not_optimized(mx1)

    mx2 = torch.sym_max(s0 + s1, s2 + s3)
    assert_optimized(mx2)

    mx3 = torch.sym_max(mx2, s4 + s5)
    assert_optimized(mx3)
    assert_optimized(torch.sym_max(s4 + s5, mx2))

    assert_not_optimized(torch.sym_max(mx3, s7))
    assert_not_optimized(torch.sym_max(mx3, 10))
    assert_not_optimized(torch.sym_max(mx3, s3 + s7))
    assert_not_optimized(torch.sym_max(mx3, s7 * 2))
|
|
|
|
def test_sym_max_multi_max_simplify(self):
    # max(1, max(257, u0)) must be statically provable to equal
    # max(257, u0): the constant 1 is dominated by 257.
    env = ShapeEnv()
    u0 = env.create_unbacked_symint()
    nested = torch.sym_max(1, torch.sym_max(257, u0))
    flat = torch.sym_max(257, u0)
    self.assertTrue(statically_known_true(nested == flat))
|
|
|
|
def test_numpy_sym_max(self):
    # torch.sym_max must accept NumPy scalar operands, mixed with Python
    # ints and floats.
    cases = (
        (np.int64(10), 12, 12),
        (np.int64(12), 10, 12),
        (np.int64(10), 12.5, 12.5),
        (np.int64(14), 12.5, 14.0),
        (np.float64(14.0), 12, 14.0),
        (np.float64(14.0), 16, 16.0),
    )
    for lhs, rhs, expected in cases:
        self.assertEqual(torch.sym_max(lhs, rhs), expected)
|
|
|
|
def test_numpy_sym_min(self):
    # torch.sym_min must accept NumPy scalar operands, mixed with Python
    # ints and floats.
    cases = (
        (np.int64(10), 12, 10),
        (np.int64(12), 10, 10),
        (np.int64(10), 12.5, 10.0),
        (np.int64(14), 12.5, 12.5),
        (np.float64(14.0), 12, 12.0),
        (np.float64(14.0), 16, 14.0),
    )
    for lhs, rhs, expected in cases:
        self.assertEqual(torch.sym_min(lhs, rhs), expected)
|
|
|
|
def test_debug_has_internal_overlap_unbacked(self):
    env = ShapeEnv()
    u0 = env.create_unbacked_symint()
    torch._check_is_size(u0)
    has_overlap = torch._debug_has_internal_overlap

    def meta(size, stride):
        return torch.empty_strided(size, stride, device="meta")

    # Expected 0 (no internal overlap) for these layouts.
    self.assertEqual(has_overlap(meta((u0, 2), (2, 1))), 0)
    self.assertEqual(has_overlap(meta((2, u0), (1, 2))), 0)
    self.assertEqual(has_overlap(meta((u0,), (1,))), 0)

    # Expected 2 for these; presumably the "too hard to decide" result of
    # ATen's overlap check -- confirm against MemoryOverlap.h.
    self.assertEqual(has_overlap(meta((1,), (u0,))), 2)
    self.assertEqual(
        has_overlap(
            meta(
                (2, 3, 1, u0),
                (
                    3 * torch.sym_max(1, u0),
                    torch.sym_max(1, u0),
                    torch.sym_max(1, u0),
                    1,
                ),
            )
        ),
        2,
    )

    # Wobbling these to zero is OK too
    self.assertEqual(has_overlap(meta((u0, 2), (3, 1))), 2)
    self.assertEqual(has_overlap(meta((2, u0), (1, 3))), 2)
|
|
|
|
def test_specialize_zero_one(self):
    """Comparing a size against 1 installs a guard only without specialization.

    ``self.assertTrue`` replaces a bare ``assert``: the comparison is what
    installs (or skips) the guard being counted, and a bare ``assert`` is
    removed entirely under ``python -O``.
    """
    shape_env = ShapeEnv(specialize_zero_one=True)
    a0 = create_symint(shape_env, 5)
    self.assertTrue(a0 != 1)
    self.assertEqual(len(shape_env.guards), 0)

    shape_env = ShapeEnv(specialize_zero_one=False)
    a0 = create_symint(shape_env, 5)
    self.assertTrue(a0 != 1)
    self.assertEqual(len(shape_env.guards), 1)
|
|
|
|
def test_duck_shape(self):
    """Equal-sized dims need no guard under duck sizing, one guard without.

    ``self.assertTrue`` replaces a bare ``assert`` so the guarded comparison
    still runs under ``python -O``.
    """
    shape_env = ShapeEnv(duck_shape=True)
    a0 = create_symint(shape_env, 5)
    a1 = create_symint(shape_env, 5)
    self.assertTrue(a0 == a1)
    self.assertEqual(len(shape_env.guards), 0)

    shape_env = ShapeEnv(duck_shape=False)
    a0 = create_symint(shape_env, 5)
    a1 = create_symint(shape_env, 5)
    self.assertTrue(a0 == a1)
    self.assertEqual(len(shape_env.guards), 1)
|
|
|
|
def test_int_bool(self):
    """Truth-testing a SymInt with hint 5 must not install a guard.

    See https://github.com/pytorch/pytorch/issues/95981. ``self.assertTrue``
    is used so the truth test still runs under ``python -O``.
    """
    shape_env = ShapeEnv(duck_shape=True)
    a0 = create_symint(shape_env, 5)
    self.assertTrue(a0)
    self.assertEqual(len(shape_env.guards), 0)
|
|
|
|
def test_symint_as_scalar(self):
    # A SymInt passed as a scalar kwarg (alpha) must reach the dispatcher as
    # the same symbolic node, not a specialized int.
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 2)
    seen = {"symint": False}

    class _RecordAlpha(TorchDispatchMode):
        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            assert func == torch.ops.aten.add.Tensor
            # WARNING: do not do identity tests on the outer
            # SymInt/SymFloat, they are NOT STABLE
            seen["symint"] = kwargs["alpha"].node is a0.node
            kwargs["alpha"] = 0
            return func(*args)

    x = torch.rand([4, 4])
    with _RecordAlpha():
        torch.add(x, x, alpha=a0)

    self.assertTrue(seen["symint"])
|
|
|
|
def test_deepcopy(self):
    """A deep-copied ShapeEnv must keep the guards recorded on the original.

    ``self.assertTrue`` replaces a bare ``assert``: the comparison is what
    records the guard, and a bare ``assert`` is removed under ``python -O``.
    """
    shape_env = ShapeEnv()
    a0 = create_symint(shape_env, 2)
    self.assertTrue(a0 < 4)
    new_shape_env = copy.deepcopy(shape_env)
    self.assertEqual(len(new_shape_env.guards), 1)
|
|
|
|
def test_print_readable_with_symints(self):
    # Trace f with symbolic shapes and compare print_readable() output
    # against an exact expected transcript: every node is annotated with its
    # symbolic int ("Sym(...)") or tensor type/shape ("f32[...]").
    def f(a, b):
        dim0 = a.shape[0] + b.shape[0]
        dim1 = a.shape[1] + b.shape[1]
        d = a.new_empty(dim0, dim1)
        d = torch.ops.aten.native_dropout(d, 0.5, train=True)
        return d

    fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
    out = fx_g.print_readable(print_output=False)

    # NB: the transcript embeds generated symbol names (s75, s57, s96). Both
    # inputs have a second dim of 3 and it appears as the same symbol s96 in
    # both, hence "2*s96" for dim1.
    self.assertExpectedInline(
        out.strip(),
        """\
class f(torch.nn.Module):
    def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
        # No stacktrace found for following nodes
        sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
        sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
        add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1;  sym_size_int = sym_size_int_1 = None
        sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
        sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1);  b_1 = None
        add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3;  sym_size_int_2 = sym_size_int_3 = None
        new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False);  a_1 = add = add_1 = None
        native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True);  new_empty = None
        getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
        getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1];  native_dropout = None
        return (getitem, getitem_1)""",  # noqa: B950
    )
|
|
|
|
def test_statically_known_true(self):
    env = ShapeEnv()
    s2, s3, s4 = (create_symint(env, hint) for hint in range(2, 5))

    # Provably true without consulting hints.
    for cond in (
        True,
        s2 == s2,
        s2 * s3 > s3,
        s3 * s4 > s4,
        (s3 + s3) % 2 == 0,
    ):
        self.assertTrue(statically_known_true(cond))

    # Provably false, hint-dependent in either direction, or undecidable:
    # none of these may report True.
    for cond in (
        # Statically known false
        False,
        s3 * s4 <= s4,
        (s3 + s3) % 2 == 1,
        # True for hints, but not known statically
        s2 + s2 == s4,
        s4 % s2 == 0,
        s2 != s3,
        s3 * s4 > s2,
        # False for hints, but not known statically
        s2 == s3,
        s2 > s3,
        s3 + s3 == s4,
    ):
        self.assertFalse(statically_known_true(cond))

    # No guards should be generated
    self.assertEqual(len(env.guards), 0)
|
|
|
|
def test_statically_known_false(self):
    env = ShapeEnv()
    s2, s3, s4 = (create_symint(env, hint) for hint in range(2, 5))

    # Statically known true: never "known false".
    for cond in (
        True,
        s2 == s2,
        s2 * s3 > s3,
        s3 * s4 > s4,
        (s3 + s3) % 2 == 0,
    ):
        self.assertFalse(statically_known_false(cond))

    # Statically known false.
    for cond in (
        False,
        s3 * s4 <= s4,
        (s3 + s3) % 2 == 1,
    ):
        self.assertTrue(statically_known_false(cond))

    # Hint-dependent only (true or false for the hints), so not provable.
    for cond in (
        s2 + s2 == s4,
        s4 % s2 == 0,
        s2 != s3,
        s3 * s4 > s2,
        s2 == s3,
        s2 > s3,
        s3 + s3 == s4,
    ):
        self.assertFalse(statically_known_false(cond))

    # No guards should be generated
    self.assertEqual(len(env.guards), 0)
|
|
|
|
def test_ephemeral_source_simplification(self):
    from torch._dynamo.source import EphemeralSource

    def ephemeral_symbols(env, t):
        """Exprs of t's sizes/strides/offset that have an ephemeral source."""
        found = []
        for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),)):
            if not isinstance(s, torch.SymInt):
                continue
            if s.node.expr not in env.var_to_sources:
                continue
            if any(src.is_ephemeral() for src in env.var_to_sources[s.node.expr]):
                found.append(s.node.expr)
        return found

    # For full robustness, ensure the ephemeral source symbols are simplified
    # out regardless of construction order or check order.
    for ephemeral_on_x, x_first in itertools.product([False, True], [False, True]):
        env = ShapeEnv()
        shape = (5, 10)
        dims = [DimDynamic.DYNAMIC for _ in shape]
        x = create_symbolic_tensor(
            "x",
            torch.randn(*shape),
            env,
            source=(EphemeralSource() if ephemeral_on_x else None),
            dynamic_dims=dims,
        )
        y = create_symbolic_tensor(
            "y",
            torch.randn(*shape),
            env,
            source=(EphemeralSource() if not ephemeral_on_x else None),
            dynamic_dims=dims,
        )
        t_with_ephemeral = x if ephemeral_on_x else y

        # these checks should simplify out the ephemeral symbols, regardless
        # of the ordering x == y or y == x
        self.assertTrue(len(ephemeral_symbols(env, t_with_ephemeral)) > 0)
        lhs, rhs = (x, y) if x_first else (y, x)
        torch._check(lhs.size() == rhs.size())
        torch._check(lhs.stride() == rhs.stride())
        torch._check(lhs.storage_offset() == rhs.storage_offset())
        self.assertEqual(len(ephemeral_symbols(env, t_with_ephemeral)), 0)
|
|
|
|
def test_ephemeral_source_unified_with_non_ephemeral_source(self):
    from torch._dynamo.source import EphemeralSource

    def make(name, env, dims, ephemeral):
        return create_symbolic_tensor(
            name,
            torch.randn(*(5, 10)),
            env,
            source=(EphemeralSource() if ephemeral else None),
            dynamic_dims=dims,
        )

    for ephemeral_on_x in (False, True):
        env = ShapeEnv()
        # use duck sizing here to ensure symbol reuse across x and y
        duck_dims = [DimDynamic.DUCK for _ in (5, 10)]
        x = make("x", env, duck_dims, ephemeral_on_x)
        y = make("y", env, duck_dims, not ephemeral_on_x)

        # regardless of construction order, non-ephemeral sources should be
        # preferred first in the var_to_sources list for potential guarding
        # later on
        for sources in env.var_to_sources.values():
            self.assertFalse(sources[0].is_ephemeral())

        self.assertEqual(x.size(), y.size())
        self.assertEqual(x.stride(), y.stride())
        self.assertEqual(x.storage_offset(), y.storage_offset())
|
|
|
|
def test_tensor_factory_with_symint(self):
    values = list(range(3))
    expected = torch.tensor(values)

    env = ShapeEnv()
    sym_values = [create_symint(env, v) for v in values]

    # torch.tensor across the dtypes from all_types_and(half, bfloat16).
    for dtype in all_types_and(torch.half, torch.bfloat16):
        built = torch.tensor(sym_values, dtype=dtype)
        self.assertEqual(built, expected, exact_dtype=False)

    # Legacy typed constructors must accept SymInt lists too.
    for ctor in (
        torch.Tensor,
        torch.LongTensor,
        torch.DoubleTensor,
        torch.FloatTensor,
        torch.IntTensor,
        torch.ShortTensor,
        torch.HalfTensor,
        torch.ByteTensor,
    ):
        self.assertEqual(ctor(sym_values), expected, exact_dtype=False)
|
|
|
|
def test_backed_size_oblivious_01_spec(self):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    @torch.compile(dynamic=True, fullgraph=True)
    def f(a, b):
        if not guard_size_oblivious(a.size(0) == 1):
            return b * 20
        return b * 10

    with torch.fx.experimental._config.patch(backed_size_oblivious=True):
        # The size-oblivious check always takes the >= 2 branch, even though
        # a.size(0) is actually 1 here.
        result = f(torch.tensor([1]), torch.tensor([1]))
        self.assertEqual(result, torch.tensor([20]))
|
|
|
|
def test_baddbmm_symint(self):
    # Smoke test: baddbmm over meta tensors whose every dim is unbacked must
    # run under FakeTensorMode without raising.
    from torch._subclasses.fake_tensor import FakeTensorMode

    env = ShapeEnv()
    fake_mode = FakeTensorMode(shape_env=env)
    batch, m, k, n = (env.create_unbacked_symint() for _ in range(4))

    with fake_mode:
        lhs = torch.empty((batch, m, k), device="meta")
        rhs = torch.empty((batch, k, n), device="meta")
        bias = torch.empty((batch, m, n), device="meta")
        torch.baddbmm(bias, lhs, rhs)
|
|
|
|
|
|
@skipIfTorchDynamo(
    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestSymNumberMagicMethods(TestCase):
    """Cross-check SymNode magic methods against plain-Python results.

    ``_do_test2`` evaluates an operation on plain values. It then re-evaluates
    it with the first, the second, and both operands replaced by symbolic
    values, guards each symbolic result back to a concrete value, and
    compares it with the reference.
    """

    def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
        with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
            return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)

    def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
        # Helper function
        # NB: don't use one as that will get specialized
        # TODO: We don't have to circuitously create the float, can just
        # create a symfloat directly
        seed_node = (create_symint(shape_env, 2) / 2.0).node
        bool_seed_node = (create_symint(shape_env, 2) == 2).node

        def get_sym_inp(inp):
            # NB: this must come before int
            if isinstance(inp, bool):
                return torch.SymBool(to_node(bool_seed_node, inp))
            elif isinstance(inp, int):
                return torch.SymInt(to_node(seed_node, inp))
            else:
                return torch.SymFloat(to_node(seed_node, inp))

        if fn == "float_pow":
            if inp1 < 0:
                return

        if fn == "pow_by_natural":
            if isinstance(inp1, float) or isinstance(inp2, float):
                return
            if inp2 < 0:
                return

        def maybe_xfail(inp1, inp2):
            if fn == "sym_sqrt" and inp1 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            elif (
                fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
                and inp2 == 0
            ):
                # ZeroDivisionError: division by zero
                return self.assertRaises((ZeroDivisionError,))
            elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
                # ZeroDivisionError: 0.0 cannot be raised to a negative power
                return self.assertRaises((ZeroDivisionError,))
            elif (
                # NB: these type checks used to read `type(x) is (A, B)`,
                # which compares a type to a tuple and is always False; they
                # now use isinstance. Given the early returns above
                # (float_pow with inp1 < 0, pow_by_natural with floats), this
                # branch is still not reached by the current inputs.
                fn in ["float_pow", "pow_by_natural"]
                and inp1 < 0
                and (
                    isinstance(inp1, (SymInt, SymFloat))
                    or isinstance(inp2, (SymInt, SymFloat))
                )
                and (
                    isinstance(inp1, (SymFloat, float))
                    or isinstance(inp2, (SymFloat, float))
                )
            ):
                # Complex result, which we do not support:
                # TypeError: Cannot convert complex to float
                return self.assertRaises((RuntimeError,))
            elif fn in ("lshift", "rshift") and not (
                isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
            ):
                # TypeError: unsupported operand type(s)
                return self.assertRaises((TypeError,))
            elif fn in ("lshift", "rshift") and inp2 < 0:
                # ValueError: math domain error
                return self.assertRaises((ValueError,))
            else:
                return contextlib.nullcontext()

        lambda_apply = method_to_operator(fn)

        def guard_fn(v):
            if type(v) in (SymBool, bool):
                return guard_bool(v)
            elif type(v) in (SymFloat, float):
                return guard_float(v)
            else:  # SymInt, int
                return guard_int(v)

        # Get reference result
        with maybe_xfail(inp1, inp2):
            if is_unary_fn:
                ref_out = lambda_apply(inp1)
            else:
                ref_out = lambda_apply(inp1, inp2)

        # Symified first arg
        sym_inp1 = get_sym_inp(inp1)
        with maybe_xfail(sym_inp1, inp2):
            if is_unary_fn:
                out = lambda_apply(sym_inp1)
            else:
                out = lambda_apply(sym_inp1, inp2)
            self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        if is_unary_fn:
            return

        # Symified second arg
        sym_inp2 = get_sym_inp(inp2)
        with maybe_xfail(inp1, sym_inp2):
            out = lambda_apply(inp1, sym_inp2)
            self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

        # Symified both args
        with maybe_xfail(sym_inp1, sym_inp2):
            out = lambda_apply(sym_inp1, sym_inp2)
            self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
            out = guard_fn(out)
            self.assertEqual(out, ref_out)

    @parametrize("fn", list(sym_node.magic_methods.keys()))
    def test_bool_method(self, fn):
        # sym_ite has its own tests
        if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
            self.skipTest(f"{fn} is non-bool")

        is_unary_fn = fn in sym_node.unary_methods
        shape_env = ShapeEnv()
        self._do_test(fn, True, False, shape_env, is_unary_fn)

    @parametrize("fn", list(sym_node.magic_methods.keys()))
    @parametrize("first_type", ["int", "float"])
    @parametrize("second_type", ["int", "float"])
    def test_method(self, fn, first_type, second_type):
        if first_type == "float":
            # TODO: Hmm, this looks like we skip all floats
            self.skipTest(f"{fn} is not a float magic method")

        if (
            first_type == "int" or second_type == "int"
        ) and fn in sym_node.only_float_magic_methods:
            self.skipTest(f"{fn} is not an int method")

        if second_type == "float" and fn in ["mod"]:
            self.skipTest(f"{fn} only handles int")

        if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"):
            self.skipTest(f"{fn} is a bitwise op, only handles int")

        is_unary_fn = fn in sym_node.unary_methods or fn == "round"
        # Second argument is ignored for unary function. So only run for one type
        if is_unary_fn and second_type == "float":
            self.skipTest(f"{fn} is unary and already tested")

        if fn in sym_node.bool_magic_methods:
            self.skipTest(f"{fn} is bool")

        # Only floats here since these will be converted to int if necessary.
        # We also ignore complex and bool.
        values = (
            0.0,
            1.0,
            0.5 if fn in ("sym_acos", "sym_asin") else 2.5,  # avoid math domain error
        )

        neg_values = tuple(-x for x in values)

        for inp1, inp2 in itertools.chain(
            itertools.product(values, values),
            itertools.product(values, neg_values),
            itertools.product(neg_values, values),
            itertools.product(neg_values, neg_values),
        ):
            if first_type == "int":
                inp1 = int(inp1)
            if second_type == "int":
                inp2 = int(inp2)

            shape_env = ShapeEnv()

            self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)

    def get_constant_bool(self, val):
        return SymBool(torch._C._get_constant_bool_symnode(val))

    @unittest.expectedFailure
    def test_symint_hashing(self):
        shape_env = ShapeEnv()
        hash(create_symint(shape_env, 3))

    def test_symnode_hashing(self):
        shape_env = ShapeEnv()

        # These all trigger specialization when hashed
        hash(create_symbool(shape_env, True))
        # NOTE(review): the comment below looks stale -- a float (3.0) is
        # already passed here. We should be passing in float here, but
        # create_symbol currently only supports int
        hash(create_symfloat(shape_env, 3.0))

        # NestedInt (SymInt), constant SymBool, SymNode are hashable
        j1 = torch._C._get_nested_int(1, 1)
        j1_copy = torch._C._get_nested_int(1, 1)
        j2 = torch._C._get_nested_int(2, 1)
        t = self.get_constant_bool(True)
        t_copy = self.get_constant_bool(True)
        f = self.get_constant_bool(False)
        n = create_symint(shape_env, 3).node
        m = self.get_constant_bool(True).node

        self.assertIs(j1 == j1_copy, True)
        self.assertEqual(hash(j1), hash(j1_copy))
        self.assertIs(j1 == j2, False)
        self.assertNotEqual(hash(j1), hash(j2))
        self.assertIs(t == t_copy, True)
        self.assertEqual(hash(t), hash(t_copy))
        self.assertIs(t == f, False)
        self.assertNotEqual(hash(t), hash(f))

        hash(n)
        hash(m)

    def test_symint_deepcopy(self):
        # Nested ints need no ShapeEnv (the unused one was dropped).
        symnodes = (torch._C._get_nested_int(1, 1),)
        deepcopied_symnodes = copy.deepcopy(symnodes)
        self.assertEqual(symnodes, deepcopied_symnodes)

    def test_non_symbolic_symnode(self):
        j1 = torch._C._get_nested_int(1, 1)
        j2 = torch._C._get_nested_int(1, 1)
        j3 = torch._C._get_nested_int(3, 1)

        self.assertIsInstance(j1, torch.SymInt)
        self.assertNotIsInstance(j1, int)

        with self.assertRaisesRegex(
            RuntimeError, "add not supported by NestedIntSymNode"
        ):
            j1 + 3

        self.assertFalse(j1 == 3)
        with self.assertRaisesRegex(RuntimeError, "indeterminate"):
            self.assertFalse(3 >= j2)

        self.assertIs(j1 == j1, True)
        self.assertIs(j1 == j2, True)
        self.assertIs(j1 == j3, False)
        self.assertIs(j1 != j3, True)
        self.assertIs(j1 != j2, False)

        x = self.get_constant_bool(True)
        #
        # Unary
        #
        # op(constant SymBool)
        self.assertIs(x.__sym_not__(), False)

        #
        # Binary
        #
        # op(constant SymBool, bool)
        # op(constant SymBool, constant SymBool)
        # op(bool, constant SymBool)
        self.assertIs(operator.and_(x, True), True)
        self.assertIs(operator.and_(x, x), True)
        self.assertIs(operator.and_(True, x), True)

        # op(symbolic SymBool, constant Symbool)
        # op(constant SymBool, symbolic Symbool)
        shape_env = ShapeEnv()
        a = create_symint(shape_env, 2)
        b = create_symint(shape_env, 2)
        c = a == b  # symbolic SymBool
        d = self.get_constant_bool(True)
        e = operator.and_(c, d)
        f = operator.and_(d, c)
        self.assertTrue(is_symbolic(e))
        self.assertTrue(is_symbolic(f))
        self.assertIs(e.node.guard_bool("", 0), True)
        self.assertIs(f.node.guard_bool("", 0), True)

        # Comparing sizes
        sz1 = torch.Size([j1, j1, j1])
        sz2 = torch.Size([j1, j1, j1])
        self.assertIs(sz1 == sz2, True)

        sz1 = torch.Size([3, j1, 4])
        sz2 = torch.Size([3, j2, 4])
        self.assertIs(sz1 == sz2, True)
        self.assertIs(sz1 != sz2, False)

    def test_stride_symnode(self):
        shape_env = ShapeEnv()

        # check everything static
        t = create_fake_tensor_with_dynamic_size(
            torch.ones(3, 6),
            shape_env,
            dynamic_sizes=[
                DimDynamic.STATIC,
                DimDynamic.STATIC,
            ],
            dynamic_strides=[
                DimDynamic.INFER_STRIDE,
                DimDynamic.INFER_STRIDE,
            ],
        )
        self.assertTrue(all(isinstance(size, int) for size in t.size()))
        self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))

        # check dynamic size but static dims
        t = create_fake_tensor_with_dynamic_size(
            torch.ones(3, 6),
            shape_env,
            dynamic_sizes=[
                DimDynamic.DYNAMIC,
                DimDynamic.DYNAMIC,
            ],
            dynamic_strides=[
                DimDynamic.INFER_STRIDE,
                DimDynamic.INFER_STRIDE,
            ],
        )
        # Expect stride to be inferred
        s0, s1 = t.size()
        s2, s3 = t.stride()
        self.assertTrue(isinstance(s0, torch.SymInt))
        self.assertTrue(isinstance(s1, torch.SymInt))
        self.assertTrue(isinstance(s2, torch.SymInt))
        self.assertTrue(s1 == s2)
        self.assertEqual(s3, 1)

        # Check dynamic stride but static dims
        t = create_fake_tensor_with_dynamic_size(
            torch.ones(3, 6),
            shape_env,
            dynamic_sizes=[
                DimDynamic.STATIC,
                DimDynamic.STATIC,
            ],
            dynamic_strides=[
                DimDynamic.DYNAMIC,
                DimDynamic.INFER_STRIDE,
            ],
        )
        s0, s1 = t.size()
        s2, s3 = t.stride()
        self.assertTrue(isinstance(s0, int))
        self.assertTrue(isinstance(s1, int))
        self.assertTrue(isinstance(s2, torch.SymInt))
        self.assertTrue(isinstance(s3, int))

        # Check dynamic sizes and dims, and ensure different symbol
        t = create_fake_tensor_with_dynamic_size(
            torch.ones(3, 6),
            shape_env,
            dynamic_sizes=[
                DimDynamic.DYNAMIC,
                DimDynamic.DYNAMIC,
            ],
            dynamic_strides=[
                DimDynamic.DYNAMIC,
                DimDynamic.INFER_STRIDE,
            ],
        )
        s0, s1 = t.size()
        s2, s3 = t.stride()
        self.assertTrue(isinstance(s0, torch.SymInt))
        self.assertTrue(isinstance(s1, torch.SymInt))
        self.assertTrue(isinstance(s2, torch.SymInt))
        self.assertTrue(isinstance(s3, int))
        self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
|
|
|
|
|
# Expand the @parametrize-decorated test methods of TestSymNumberMagicMethods
# (test_bool_method, test_method) into individually runnable tests.
instantiate_parametrized_tests(TestSymNumberMagicMethods)
|
|
|
|
|
|
class TestFloorDiv(TestCase):
    """Compare the sympy-level ``FloorDiv`` with Python's ``//`` operator."""

    @staticmethod
    def python_floordiv(x, y):
        return x // y

    @staticmethod
    def torch_floordiv(x, y):
        # Note: we fully evaluate here since FloorDiv might not always do
        # that.
        return ShapeEnv().evaluate_expr(FloorDiv(x, y))

    @staticmethod
    def yield_test_cases(values, negate=True):
        """Yield each (x, y) pair, plus its three sign flips when ``negate``."""
        for x, y in values:
            yield (x, y)
            if not negate:
                continue
            yield (-x, y)
            yield (x, -y)
            yield (-x, -y)

    def _assert_matches_python(self, values):
        # Shared body of the value-comparison tests below.
        for x, y in TestFloorDiv.yield_test_cases(values):
            self.assertEqual(
                TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
            )

    def test_floordiv_float_int(self):
        # NOTE(review): despite the name, both operands here are ints.
        self._assert_matches_python(((7, 2),))

    def test_floordiv_div_by_one(self):
        self._assert_matches_python(((2, 1),))

    def test_floordiv_simplify(self):
        # Tests how we simplify or evaluate FloorDiv without free variables
        shape_env = ShapeEnv()
        expected = 21
        evaluators = (
            lambda e: e,
            lambda e: e.doit(deep=False),
            lambda e: e.doit(deep=True),
            lambda e: sympy.simplify(e),
            lambda e: shape_env.simplify(e),
            lambda e: shape_env.evaluate_expr(e),
        )

        for expr in (7 * FloorDiv(6, 2),):
            for evaluate in evaluators:
                self.assertEqual(evaluate(expr), expected)

    def test_floordiv_assumptions(self):
        symbols = (
            sympy.Symbol("i1", integer=True),
            sympy.Symbol("i2", integer=True),
        )

        def is_complex(x):
            return x.is_integer is False and x.is_real is False and x.is_complex

        for base, divisor in itertools.product(symbols, repeat=2):
            if is_complex(base) or is_complex(divisor):
                self.assertRaisesRegex(
                    TypeError,
                    (
                        r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
                        r" expected integer or real"
                    ),
                    FloorDiv,
                    base,
                    divisor,
                )
                continue

            quotient = FloorDiv(base, divisor)

            # In regular Python, x//x == 1.0 if x is a float, but FloorDiv
            # always returns an integer 1 when both args are the same object.
            # This even works for Symbols with no assumptions specified.
            if base is divisor or (base.is_integer and divisor.is_integer):
                self.assertTrue(quotient.is_integer)
            else:
                self.assertEqual(quotient.is_integer, None)
            self.assertTrue(quotient.is_real)
|
|
|
|
|
|
class TestDimConstraints(TestCase):
|
|
def test_dim_constraints_reduce_congruences_simple(self):
|
|
from sympy import Symbol
|
|
|
|
s = Symbol("s", positive=True, integer=True)
|
|
dim_constraints = DimConstraints({}, {}, set(), {})
|
|
dim_constraints._congruences[s] = {
|
|
(s / 2) % 2,
|
|
(s / 2) % 8,
|
|
(s / 2) % 4,
|
|
s % 2,
|
|
((s / 16) + 2) % 4,
|
|
}
|
|
congruences = dim_constraints._reduce_congruences()
|
|
self.assertEqual(congruences[s], {(s + 32) % 64})
|
|
|
|
def test_dim_constraints_reduce_inequalities_simple(self):
|
|
from sympy import Eq, Interval, Ne, Symbol
|
|
from sympy.solvers.inequalities import reduce_inequalities
|
|
|
|
s = Symbol("s", positive=True, integer=True)
|
|
exprs = {
|
|
s >= 2,
|
|
Ne(8 * s, 16),
|
|
Ne(s / 2, 1),
|
|
Ne(16 * s, 32),
|
|
s < 16,
|
|
Ne(s, 2),
|
|
s / 2 < 16,
|
|
s / 2 > 1,
|
|
s / 2 >= 2,
|
|
Ne(3 * s / 2, 3),
|
|
}
|
|
solution = reduce_inequalities(exprs, s).as_set()
|
|
self.assertEqual(solution, Interval.Ropen(4, 16))
|
|
|
|
exprs.add(Eq(s / 2, 4))
|
|
solution = reduce_inequalities(exprs, s).as_set()
|
|
self.assertEqual(solution, {8})
|
|
|
|
def test_dim_constraints_reduce_inequalities_error(self):
|
|
from collections import defaultdict
|
|
|
|
from sympy import Symbol
|
|
from sympy.solvers.inequalities import reduce_inequalities
|
|
|
|
from torch._dynamo.source import (
|
|
LocalSource,
|
|
TensorProperty,
|
|
TensorPropertySource,
|
|
)
|
|
from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
|
|
|
|
s0 = Symbol("s0", positive=True, integer=True)
|
|
exprs = {
|
|
4 * s0**3 - 4 * s0**2 + s0 <= 2147483647,
|
|
s0 >= 2,
|
|
s0**3 <= 2147483647,
|
|
s0 <= 2147483647,
|
|
}
|
|
answer = reduce_inequalities(exprs, s0)
|
|
|
|
symbol_to_source = defaultdict(list)
|
|
symbol_to_source[s0].append(
|
|
TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
)
|
|
dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
|
|
with self.assertRaisesRegex(
|
|
AssertionError,
|
|
"Unknown symbol.*created by constraints solver",
|
|
):
|
|
dcp.doprint(answer)
|
|
|
|
def test_dim_constraints_solve_full(self):
|
|
from sympy import Eq, Integer, Ne, Symbol
|
|
|
|
from torch._dynamo.source import (
|
|
LocalSource,
|
|
TensorProperty,
|
|
TensorPropertySource,
|
|
)
|
|
|
|
src0 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
src2 = TensorPropertySource(
|
|
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
src3 = TensorPropertySource(
|
|
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
src4 = TensorPropertySource(
|
|
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
|
|
src1 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
|
|
)
|
|
src7 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
|
|
)
|
|
|
|
src5 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src8 = TensorPropertySource(
|
|
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
|
|
src6 = TensorPropertySource(
|
|
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src9 = TensorPropertySource(
|
|
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src10 = TensorPropertySource(
|
|
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
|
|
src11 = TensorPropertySource(
|
|
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src12 = TensorPropertySource(
|
|
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
|
|
)
|
|
|
|
s0 = Symbol("s0", positive=True, integer=True)
|
|
s1 = Symbol("s1", positive=True, integer=True)
|
|
s5 = Symbol("s5", positive=True, integer=True)
|
|
s6 = Symbol("s6", positive=True, integer=True)
|
|
symbol_to_source = {
|
|
s0: [src0, src2, src3, src4],
|
|
s1: [src1, src7],
|
|
s5: [src5, src8],
|
|
s6: [src6, src9, src10],
|
|
}
|
|
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
|
|
marked_dynamic = {s0, s1, s5, s6}
|
|
dim_constraints = DimConstraints(
|
|
symbol_to_source, var_to_val, marked_dynamic, {}
|
|
)
|
|
dim_constraints.add_equality(src2, s0)
|
|
dim_constraints.add_equality(src3, s0)
|
|
dim_constraints.add_equality(src4, s0)
|
|
dim_constraints.add_equality(src7, s1)
|
|
dim_constraints.add_equality(src8, s5)
|
|
dim_constraints.add_equality(src9, s6)
|
|
dim_constraints.add_equality(src10, s6)
|
|
dim_constraints.add_equality(src11, Integer(1))
|
|
dim_constraints.add_equality(src12, Integer(3))
|
|
|
|
dim_constraints.add(s1**2 <= 2147483647)
|
|
dim_constraints.add(32 * s1**2 <= 2147483647)
|
|
dim_constraints.add(s0 < 16)
|
|
dim_constraints.add(Eq(Mod(s1, 2), 0))
|
|
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
|
|
dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
|
|
dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
|
|
dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
|
|
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
|
|
+ 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
|
|
+ 64
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
|
|
+ 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
|
|
+ 128
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
+ 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
+ 256
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * s0,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
>= 0
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
|
|
dim_constraints.add(
|
|
1
|
|
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * s0,
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
|
|
+ 60 * (Mod(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60
|
|
* (FloorDiv(s0, 2))
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120
|
|
* FloorDiv(s0, 2)
|
|
* FloorDiv(s0, (FloorDiv(s0, 2)))
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
>= 0
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
)
|
|
dim_constraints.add(Ne(16 * s0, 32))
|
|
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
|
|
dim_constraints.add(Ne(16 * s0, 32))
|
|
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
|
|
dim_constraints.add(FloorDiv(s0, 2) >= 2)
|
|
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
|
|
dim_constraints.add(1 < FloorDiv(s0, 2))
|
|
dim_constraints.add(Ne(s0, 2))
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
>= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
)
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, 2))
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120
|
|
* FloorDiv(s0, 2)
|
|
* FloorDiv(s0, (FloorDiv(s0, 2)))
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60
|
|
* (FloorDiv(s0, 2))
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120
|
|
* FloorDiv(s0, 2)
|
|
* FloorDiv(s0, (FloorDiv(s0, 2)))
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
|
|
3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20,
|
|
20,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20
|
|
* (
|
|
Mod(
|
|
1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
)
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
* (
|
|
Mod(
|
|
1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
)
|
|
)
|
|
- 20
|
|
* Mod(
|
|
1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
>= 1
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 0
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 1
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 2
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60,
|
|
60,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
* (
|
|
Mod(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
1,
|
|
)
|
|
)
|
|
- Mod(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
1,
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(8 * s0, 16))
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
>= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
)
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(FloorDiv(s0, 2) < 16)
|
|
dim_constraints.add(FloorDiv(s0, 2) > 1)
|
|
dim_constraints.add(
|
|
Ne(
|
|
90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)),
|
|
3 * (FloorDiv(s0, 2)),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2))
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
6,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
<= 20480
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
<= 20480
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
<= 20480
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 240,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Eq(6 * s5, 132))
|
|
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
|
|
dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
|
|
dim_constraints.add(
|
|
Ne(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64
|
|
)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 62
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 62 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 62
|
|
)
|
|
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
|
|
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
|
|
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
|
|
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
|
|
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
|
|
dim_constraints.add(
|
|
Ne(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
>= 0
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
|
|
dim_constraints.add(
|
|
1
|
|
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576 * (FloorDiv(s0, 2)),
|
|
256,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
64
|
|
* (
|
|
Mod(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2)),
|
|
4,
|
|
)
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
FloorDiv(s0, 2),
|
|
FloorDiv(
|
|
(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2))
|
|
),
|
|
4,
|
|
),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
FloorDiv(
|
|
(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2))
|
|
),
|
|
4,
|
|
),
|
|
FloorDiv(s0, 2),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
64
|
|
* (
|
|
Mod(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
4,
|
|
)
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576 * (FloorDiv(s0, 2))
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
>= 1
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576,
|
|
256,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540
|
|
> 1
|
|
)
|
|
dim_constraints.add(s0 >= 2)
|
|
dim_constraints.add(s1 >= 2)
|
|
dim_constraints.add(s6 >= 2)
|
|
dim_constraints.add(s5 >= 2)
|
|
|
|
dim_constraints.solve()
|
|
self.assertEqual(
|
|
dim_constraints._static_results,
|
|
{
|
|
"L['c'].size()[0] == 8",
|
|
"L['d'].size()[0] == 8",
|
|
"L['a'].size()[2] == 96",
|
|
"L['f'].size()[1] == 1",
|
|
"L['a'].size()[3] == 96",
|
|
"L['b'].size()[2] == 3",
|
|
"L['b'].size()[1] == 22",
|
|
"L['b'].size()[0] == 8",
|
|
"L['a'].size()[1] == 22",
|
|
"L['a'].size()[0] == 8",
|
|
},
|
|
)
|
|
self.assertEqual(
|
|
dim_constraints._dynamic_results,
|
|
{
|
|
"2 <= L['c'].size()[1]",
|
|
"L['d'].size()[1] == L['c'].size()[1]",
|
|
"L['e'].size()[1] == L['c'].size()[1]",
|
|
},
|
|
)
|
|
|
|
|
|
class TestGuardsExpressions(TestCase):
|
|
"""
|
|
Tests the guards-related methods used by the inductor FX graph cache.
|
|
"""
|
|
|
|
def test_guards_gt_lt(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 6)
|
|
s1 = create_symint(shape_env, 7)
|
|
s2 = create_symint(shape_env, 5)
|
|
|
|
guard_int(sym_int(s0 > 5))
|
|
guard_int(sym_int(s0 < 7))
|
|
|
|
guards = shape_env.produce_guards_expression([s0])
|
|
|
|
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
|
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
|
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)]))
|
|
|
|
def test_guards_float_print(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 3)
|
|
guard_bool(2 / s0 == 2 / 3)
|
|
guards = shape_env.produce_guards_expression([s0])
|
|
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
|
|
|
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_guard_or_true(self):
|
|
from torch.fx.experimental.symbolic_shapes import guard_or_true
|
|
|
|
def func(a, b):
|
|
x = a.item()
|
|
if guard_or_true(x == 1):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
# eager.
|
|
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
|
|
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
|
|
|
|
# compile with unbacked.
|
|
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func(a, b)
|
|
|
|
# always return b*10
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
|
|
# Test that statically known true works.
|
|
def func2(a, b):
|
|
x = a.item()
|
|
if guard_or_true(x != x):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func2(a, b)
|
|
# always return b*20
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
|
|
# Test backed_size_oblivious
|
|
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
|
|
|
|
def func3(a, b):
|
|
if guard_or_true(a.size()[0] != 9):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
|
|
a = torch.rand(9, 2)
|
|
b = torch.rand(3, 4)
|
|
|
|
self.assertEqual(func3(a, b), b * 20)
|
|
self.assertEqual(compiled(a, b), b * 10)
|
|
|
|
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_guard_or_false(self):
|
|
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
|
|
|
def func(a, b):
|
|
x = a.item()
|
|
if guard_or_false(x == 1):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
# eager.
|
|
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
|
|
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
|
|
|
|
# compile with unbacked.
|
|
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func(a, b)
|
|
|
|
# always return b*20
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
|
|
# Test that statically known true works.
|
|
def func2(a, b):
|
|
x = a.item()
|
|
if guard_or_false(x == x):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func2(a, b)
|
|
# always return b*10
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
|
|
# Test backed_size_oblivious
|
|
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
|
|
|
|
def func3(a, b):
|
|
if guard_or_false(a.size()[0] == 9):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
|
|
a = torch.rand(9, 2)
|
|
b = torch.rand(3, 4)
|
|
|
|
self.assertEqual(func3(a, b), b * 10)
|
|
self.assertEqual(compiled(a, b), b * 20)
|
|
|
|
def test_guards_float_div(self):
    """A guard on ``int(s0 / 2.0)`` is rendered with float division and truncation.

    The produced guard expression must accept the hint of the symbol it was
    built from and reject the hint of a different symbol (created with 7).
    """
    env = ShapeEnv()
    sym_eight = create_symint(env, 8)
    sym_seven = create_symint(env, 7)

    # Guarding on the truncated float quotient records a guard on sym_eight.
    guard_int(sym_int(sym_eight / 2.0))
    guard_src = env.produce_guards_expression([sym_eight])

    for fragment in ("math.trunc(", "float("):
        self.assertIn(fragment, guard_src)

    self.assertTrue(env.evaluate_guards_expression(guard_src, [hint_int(sym_eight)]))
    self.assertFalse(env.evaluate_guards_expression(guard_src, [hint_int(sym_seven)]))
|
|
|
|
def test_remove_symbols_without_guarding(self):
    """``_remove_symbols_without_guarding`` replaces symbolic sizes/strides with concrete ones.

    A fake tensor built from a (5, 8) tensor with both dims dynamic starts out
    with symbolic shape/stride; after cleaning they print as the concrete
    (5, 8) shape and (8, 1) stride.
    """
    from torch._functorch.partitioners import _remove_symbols_without_guarding

    env = ShapeEnv()
    fake = create_fake_tensor_with_dynamic_size(
        torch.randn(5, 8),
        env,
        dynamic_sizes=[DimDynamic.DYNAMIC] * 2,
        dynamic_strides=[DimDynamic.INFER_STRIDE] * 2,
    )

    # Before cleaning: sizes and the outer stride are symbols.
    self.assertEqual(f"{fake.stride()}", "(s49, 1)")
    self.assertEqual(f"{fake.shape}", "torch.Size([s26, s49])")

    cleaned = _remove_symbols_without_guarding(fake, 4096)

    # After cleaning: everything is concrete.
    self.assertEqual(f"{cleaned.stride()}", "(8, 1)")
    self.assertEqual(f"{cleaned.shape}", "torch.Size([5, 8])")
|
|
|
|
|
|
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
    """Graph pass that asserts the ``arg3_1`` node's first dimension equals 2.

    Installed as inductor's ``post_grad_custom_pre_pass`` by
    ``test_post_specialize_runtime_assert2``. The graph is returned unchanged.
    NOTE(review): when the size is symbolic, evaluating ``== 2`` inside the
    ``assert`` presumably installs a guard that specializes that dimension to
    2, which is what the test relies on -- confirm.
    """
    targets = [node for node in graph.nodes if node.name == "arg3_1"]
    for target in targets:
        assert target.meta["val"].size()[0] == 2
    return graph
|
|
|
|
|
|
class TestUnbacked(TestCase):
    """Deferred runtime assertions involving unbacked symbols under torch.compile.

    Most tests record a runtime assertion on symbols coming from ``.item()`` or
    ``mark_unbacked`` and then check that violating it at runtime raises
    ``RuntimeError`` under both the inductor and eager backends.
    """

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @parametrize("backend", ["inductor", "eager"])
    def test_deferred_neq_assert(self, backend):
        @torch.compile(fullgraph=True, backend=backend)
        def func(a):
            torch._check(a.item() != 5)
            return a.item() * 10

        func(torch.tensor([100]))

        with self.assertRaises(RuntimeError):
            func(torch.tensor([5]))

    # Test a situation where we generate a runtime assert (here u0 + s0 + s1 == 102)
    # and then specialize one of its input symbols (s0 == 2) to a constant later on.
    # The assert must still be enforced after the specialization.
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @parametrize("backend", ["inductor", "eager"])
    def test_post_specialize_runtime_assert1(self, backend):
        @torch.compile(dynamic=True, backend=backend)
        def func(x, y):
            u0 = y.item()
            s0 = x.size()[0]
            s1 = x.size()[1]
            torch._check(u0 + s0 + s1 == 102)
            assert s0 == 2
            return x * 10

        func(torch.rand(2, 50), torch.tensor([50]))
        with self.assertRaises(RuntimeError):
            func(torch.rand(2, 50), torch.tensor([51]))

    # Same scenario as above, but without the in-function specialization. With the
    # inductor backend, custom_pass (the post-grad pre-pass) compares arg3_1's first
    # dim with 2, which presumably specializes s0 after the assert was recorded.
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass)
    @parametrize("backend", ["inductor", "eager"])
    def test_post_specialize_runtime_assert2(self, backend):
        @torch.compile(dynamic=True, backend=backend)
        def func(x, y):
            u0 = y.item()
            s0 = x.size()[0]
            s1 = x.size()[1]
            torch._check(u0 + s0 + s1 == 102)
            return x * 10

        func(torch.rand(2, 50), torch.tensor([50]))
        with self.assertRaises(RuntimeError):
            func(torch.rand(2, 50), torch.tensor([51]))

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @parametrize("backend", ["inductor", "eager"])
    def test_deferred_sym_or_assert(self, backend):
        @torch.compile(fullgraph=True, backend=backend)
        def func(a, b):
            torch._check(operator.or_(a.item() == 5, b.item() == 5))
            return a.item() * 10

        # Either operand satisfying the disjunction is enough.
        func(torch.tensor([5]), torch.tensor([100]))
        func(torch.tensor([100]), torch.tensor([5]))

        # Neither operand satisfies it, so the deferred assert must fire. Without
        # this check the test would also pass if the assert were silently dropped.
        with self.assertRaises(RuntimeError):
            func(torch.tensor([100]), torch.tensor([100]))

    def test_has_free_symbols(self):
        # Constants and unevaluated constant expressions have no free symbols.
        self.assertFalse(has_free_symbols(sympy.S.true))
        self.assertFalse(has_free_symbols(sympy.Max(1, 10, evaluate=False)))

        self.assertFalse(has_free_symbols(sympy.sympify("1")))
        self.assertFalse(has_free_symbols(sympy.sympify("1.1")))
        self.assertTrue(has_free_symbols(sympy.sympify("a")))
        self.assertTrue(has_free_symbols(sympy.sympify("a*2")))
        self.assertTrue(has_free_symbols(sympy.sympify("a+b")))

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @parametrize("backend", ["inductor", "eager"])
    def test_deferred_sym_eq_assert(self, backend):
        @torch.compile(fullgraph=True, backend=backend)
        def func(a, b):
            torch._check(b.item() == 5)
            return a * 10

        func(torch.tensor([5]), torch.tensor([5]))
        with self.assertRaises(RuntimeError):
            func(torch.tensor([100]), torch.tensor([1]))

    # Both leading dims are graph inputs marked unbacked, so the equality check is a
    # deferred assert between two input symbols; it must not be lost if one symbol
    # gets replaced by the other.
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @parametrize("backend", ["inductor", "eager"])
    @skipIfTorchDynamo("mark_unbacked is not traceable")
    def test_deferred_with_unbacked_input(self, backend):
        @torch.compile(fullgraph=True, dynamic=True, backend=backend)
        def func(a, b):
            torch._check(a.size()[0] == b.size()[0])
            return a * 10

        a = torch.rand(1, 1)
        b = torch.rand(1, 1)
        torch._dynamo.decorators.mark_unbacked(a, 0)
        torch._dynamo.decorators.mark_unbacked(b, 0)
        func(a, b)

        with self.assertRaises(RuntimeError):
            func(a, torch.rand(2, 1))
|
|
|
|
|
|
class TestUbackedOps(TestCase):
    """Reshape/view behavior on tensors whose sizes are unbacked symbols."""

    @fresh_inductor_cache()
    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_reshape1(self):
        cnt = CompileCounterWithBackend("inductor")

        # The reshape happens in place (no clone): reshape u1 -> (u0, u0).
        def func(x, y):
            f = y.item()
            t1 = x.view((f, f))
            t2 = x.reshape((f, f))
            # TODO avoid _check_is_size here.
            torch._check_is_size(f)
            return t1 * 10, t2 * 10

        compiled_func = torch.compile(
            fullgraph=True,
            backend=cnt,
            dynamic=True,
        )(func)

        # Create a non-contiguous 1-D tensor of `numel` elements (the even numbers in
        # [0, 2*numel)) and reshape it into sqrt(numel) x sqrt(numel), comparing the
        # compiled result against eager.
        def make_non_contiguous_tensor_and_test(numel):
            # create a non-contiguous tensor x that is skipping odd indices.
            x = torch.arange(numel * 2)
            x = x.as_strided((x.size()[0] // 2,), (2,))

            torch._dynamo.decorators.mark_unbacked(x, 0)
            sz = torch.tensor([int(math.sqrt(numel))])
            compiled_result = compiled_func(x, sz)
            eager_result = func(x, sz)
            self.assertEqual(compiled_result, eager_result)

        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
        )
        with ctx():
            make_non_contiguous_tensor_and_test(4)
        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
        ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
        ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
        pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
        eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
        view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
        view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
        mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
        mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
        return (mul_9, mul_12)""",  # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        make_non_contiguous_tensor_and_test(49)
        self.assertEqual(cnt.frame_count, 1)

        # Pass in a contiguous tensor, it will recompile due to stride being 1 (0/1 specialization).
        # marking strides unbacked would have avoided the recompilation here.
        x = torch.arange(100)
        torch._dynamo.decorators.mark_unbacked(x, 0)

        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
        )
        with ctx():
            compiled_result = compiled_func(x, torch.tensor([10]))
            eager_result = func(x, torch.tensor([10]))
            self.assertEqual(compiled_result, eager_result)
            self.assertEqual(cnt.frame_count, 2)

        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
        ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
        ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
        pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
        eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
        view: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
        view_1: "i64[u0, u0][u0, 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
        mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
        mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
        return (mul_4, mul_7)""",  # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        # Reuse the contiguous-input graph; the results were previously computed
        # but never compared.
        x = torch.arange(25)
        compiled_result = compiled_func(x, torch.tensor([5]))
        eager_result = func(x, torch.tensor([5]))
        self.assertEqual(compiled_result, eager_result)
        self.assertEqual(cnt.frame_count, 2)

    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_reshape2(self):
        cnt = CompileCounterWithBackend("inductor")

        # This reshape requires a clone when the input is not contiguous and we can't
        # compute strides.
        # reshape (u2, u3) -> (u0, u1)
        def func(x, y):
            u0, u1 = y.tolist()
            torch._check_is_size(u0)
            torch._check_is_size(u1)

            result1 = torch.reshape(x, (u0, u1))
            return result1 * 10

        compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)

        x = torch.randn(10, 10)
        # make x not contiguous.
        x = x.t_()
        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(x, 1)

        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
        )
        with ctx():
            result_eager = func(x, torch.tensor([5, 20]))
            result_compiled = compiled_func(x, torch.tensor([5, 20]))
            self.assertEqual(result_compiled, result_eager)
            self.assertEqual(cnt.frame_count, 1)

        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
        ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
        select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
        _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
        ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
        select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
        _local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
        ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
        _assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_3 = None
        mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
        mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
        eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
        _assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
        clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
        view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
        mul_18: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
        return (mul_18,)""",  # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        result_eager = func(x, torch.tensor([2, 50]))
        result_compiled = compiled_func(x, torch.tensor([2, 50]))
        self.assertEqual(result_compiled, result_eager)
        self.assertEqual(cnt.frame_count, 1)

        x = torch.randn(4, 4).t_()
        result_eager = func(x, torch.tensor([2, 8]))
        result_compiled = compiled_func(x, torch.tensor([2, 8]))
        self.assertEqual(result_compiled, result_eager)
        self.assertEqual(cnt.frame_count, 1)

        # Pass a contiguous tensor. A recompilation will happen due to 0/1 specialization on stride.
        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
        )
        with ctx():
            # This used to hit "could guard on data-dependent expression Eq(10, u3)"
            # (x.stride[0]==10 and x.size()=[u2, u3]), but not anymore since we use
            # definitely_contiguous.
            # We need a way to mark strides unbacked to avoid the recompilation here.
            x = torch.randn(10, 10)
            torch._dynamo.decorators.mark_unbacked(x, 0)
            torch._dynamo.decorators.mark_unbacked(x, 1)

        # NOTE(review): compiled_func is not called inside the ctx() above (only
        # below), so nothing is logged and this empty expectation does not check the
        # recompiled graph. TODO: move the compiled call into ctx() and record the
        # real graph.
        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            aot_graphs,
            """""",  # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        result_compiled = compiled_func(x, torch.tensor([2, 50]))
        result_eager = func(x, torch.tensor([2, 50]))

        self.assertEqual(result_compiled, result_eager)
        self.assertEqual(cnt.frame_count, 2)

        x = torch.randn(4, 4)

        result_eager = func(x, torch.tensor([2, 8]))
        result_compiled = compiled_func(x, torch.tensor([2, 8]))
        self.assertEqual(result_compiled, result_eager)
        self.assertEqual(cnt.frame_count, 2)

    @unittest.skip("this test fails due to inductor/autograd issue #153041")
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_non_contigious_reshape_failing(self):
        # reshape u1 -> (u0*u0)
        # this results in the tensor "i64[u0, u0][s7*u0, s7]".
        # reshape happens in place reshape (no-clone)
        def func(x, y):
            f = y.item()
            t1 = x.view((f, f))
            t2 = x.reshape((f, f))
            return t1, t2

        # Create a non-contiguous 1-D tensor of `numel` elements: the even numbers
        # in [0, 2*numel).
        def make_non_contiguous_tensor(numel):
            # create a non-contiguous tensor x that is skipping odd indices.
            x = torch.arange(numel * 2)
            x = x.as_strided((x.size()[0] // 2,), (2,))
            return x

        x = make_non_contiguous_tensor(4)
        torch._dynamo.decorators.mark_unbacked(x, 0)
        compiled_func = torch.compile(
            fullgraph=True,
            backend="inductor",
        )(func)

        compiled_result = compiled_func(x, torch.tensor([2]))
        eager_result = func(x, torch.tensor([2]))
        self.assertEqual(compiled_result, eager_result)

    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_invalid_view_unbacked_view(self):
        cnt = CompileCounterWithBackend("inductor")

        # This view (u2, u3) -> (u0, u1) can't happen in general unless we know that
        # the input is contiguous or we have hints to compute strides.
        def func(x, y):
            u0, u1 = y.tolist()
            torch._check_is_size(u0)
            torch._check_is_size(u1)

            result2 = x.view(u0, u1) * 10
            return result2

        compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)

        x = torch.randn(10, 10)
        # make x not contiguous.
        x = x.t_()
        torch._dynamo.decorators.mark_unbacked(x, 0)
        torch._dynamo.decorators.mark_unbacked(x, 1)
        with self.assertRaises(torch._dynamo.exc.UserError):
            # throws a data dependent error.
            compiled_func(x, torch.tensor([5, 20]))
|
|
|
|
|
|
# Expand the @parametrize("backend", ...) tests in TestUnbacked into concrete,
# per-backend test methods.
instantiate_parametrized_tests(TestUnbacked)


# Allow running this file directly as a script.
if __name__ == "__main__":
    run_tests()
|