Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
Summary: When we compute contiguity for a tensor with dynamic shapes, we first: 1) try to compute it without guarding; 2) if all inputs are hinted, compute it, potentially adding guards; 3) if any input is not hinted, compute it symbolically. sym_is_contiguous returns a SymBool that is then either evaluated, or guard_or_false can be called on it to avoid data-dependent errors, e.g.: bool is_contiguous = input.sym_is_contiguous().guard_or_false(__FILE__, __LINE__); is_contiguous_or_false is a helper function that does exactly that. This PR only handles default contiguity; a follow-up will cover other formats such as channels_last. We use this pattern in several locations in this PR to avoid data-dependent errors (DDEs). Test Plan: contbuild & OSS CI. Reviewed By: malfet. Differential Revision: D77639021. Pull Request resolved: https://github.com/pytorch/pytorch/pull/157472. Approved by: https://github.com/aorenste
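A minimal Python-level sketch of the same idea (illustrative only, not part of the test file below), assuming the guard_or_false helper exported by torch.fx.experimental.symbolic_shapes; the C++ sym_is_contiguous()/is_contiguous_or_false helpers described above are the ATen-side counterparts of this pattern:

from torch.fx.experimental.symbolic_shapes import ShapeEnv, guard_or_false

shape_env = ShapeEnv()
u0 = shape_env.create_unbacked_symint()  # unbacked symbol: no hint available

# bool(u0 == 1) would raise GuardOnDataDependentSymNode; guard_or_false instead
# answers False when the expression cannot be decided, so callers can fall back
# to a generic (e.g. non-contiguous) path without guarding on data-dependent values.
assume_false = guard_or_false(u0 == 1)
print(assume_false)  # False: the comparison is data dependent and left undecided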
3538 lines · 129 KiB · Python
# Owner(s): ["oncall: jit"]
# ruff: noqa: F841

import contextlib
import copy
import itertools
import math
import operator
import unittest

import numpy as np
import sympy

import torch
import torch.fx
import torch.nn.functional as F
from torch import sym_int, SymBool, SymFloat, SymInt
from torch._C import _disabled_torch_function_impl
from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend
from torch._inductor.utils import fresh_cache
from torch.fx.experimental import sym_node
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node
from torch.fx.experimental.symbolic_shapes import (
    _constrain_range_for_size,
    DimConstraints,
    DimDynamic,
    expect_true,
    guard_bool,
    guard_float,
    guard_int,
    GuardOnDataDependentSymNode,
    has_free_symbols,
    hint_int,
    is_symbolic,
    ShapeEnv,
    StatelessSymbolicContext,
    statically_known_false,
    statically_known_true,
)
from torch.testing._internal.common_dtype import all_types_and
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)
from torch.testing._internal.logging_utils import logs_to_string
from torch.utils import _pytree as pytree
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._sympy.functions import (
    CleanDiv,
    FloorDiv,
    IsNonOverlappingAndDenseIndicator,
    Mod,
)


aten = torch.ops.aten

meta_funcs = {}


def register_meta(op):
    def decorator(f):
        def add_func(op):
            meta_funcs[op] = f

        pytree.tree_map_(add_func, op)
        return f

    return decorator


@register_meta([aten.add.Tensor, aten.sub.Tensor])
def binary_meta(a, b):
    return a.new_empty(a.shape)


@register_meta(aten.cat.default)
def cat_meta(tensors, dim=0):
    concat_length = 0
    shape = tensors[0].shape
    for tensor in tensors:
        for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)):
            if idx == dim:
                concat_length = concat_length + length
            else:
                assert length == common_length
    new_shape = list(shape)
    new_shape[dim] = concat_length
    return tensors[0].new_empty(new_shape)


@register_meta([aten.narrow_copy.default])
def narrow_copy_symint_meta(a, dim, start, length, **kwargs):
    shape = []
    for i, x in enumerate(a.shape):
        if i == dim:
            shape.append(length)
        else:
            shape.append(x)
    return a.new_empty(tuple(shape))


@register_meta([aten.expand.default])
def expand_symint_meta(a, size, implicit=False):
    return a.new_empty(size)


def create_contiguous(shape):
    strides = [1]
    for dim in reversed(shape[:-1]):
        strides.append(dim * strides[-1])
    return list(reversed(strides))


class FakeSymbolicTensor(torch.Tensor):
    @staticmethod
    def __new__(
        cls,
        sym_shape,
        sym_strides,
        dtype,
        layout,
        requires_grad,
        device,
        storage_offset=0,
    ):
        # TODO: this is wrong in general
        sym_stride = create_contiguous(sym_shape)
        r = torch.Tensor._make_wrapper_subclass(
            cls,
            sym_shape,
            sym_stride,
            storage_offset,
            dtype=dtype,
            layout=layout,
            requires_grad=requires_grad,
            device=device,
        )
        return r

    __torch_function__ = _disabled_torch_function_impl

    def new_empty(self, shape):
        return FakeSymbolicTensor(
            shape, None, self.dtype, self.layout, self.requires_grad, self.device
        )

    @classmethod
    def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None):
        if func_overload in meta_funcs:
            return meta_funcs[func_overload](*args, **kwargs)

        if func_overload == torch.ops.aten.new_empty.default:
            self = args[0]
            shape = args[1]
            return FakeSymbolicTensor(
                shape,
                self.stride(),
                self.dtype,
                self.layout,
                self.requires_grad,
                self.device,
            )

        raise RuntimeError(f"operator {func_overload} not supported")


def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None):
    from torch._dynamo.source import ConstantSource

    if source is None:
        source = ConstantSource(name)
    constraint_dims = [None] * arg.dim()
    if dynamic_dims is None:
        dynamic_dims = [DimDynamic.DUCK] * arg.dim()
    (
        sym_shapes,
        sym_strides,
        sym_storage_offset,
    ) = shape_env.create_symbolic_sizes_strides_storage_offset(
        arg,
        source=source,
        symbolic_context=StatelessSymbolicContext(
            dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims
        ),
    )
    return FakeSymbolicTensor(
        sym_shapes,
        sym_strides,
        arg.dtype,
        arg.layout,
        arg.requires_grad,
        arg.device,
        sym_storage_offset,
    )


def create_fake_tensor_with_dynamic_size(x, shape_env, dynamic_sizes, dynamic_strides):
    from torch._subclasses.fake_tensor import FakeTensorMode

    with FakeTensorMode(shape_env=shape_env) as fake_mode:
        return fake_mode.from_tensor(
            x,
            symbolic_context=StatelessSymbolicContext(
                dynamic_sizes=dynamic_sizes,
                dynamic_strides=dynamic_strides,
            ),
        )


def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs):
    from torch._dynamo.source import ConstantSource

    symbol = shape_env.create_symbol(
        val,
        source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"),
        dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC,
        constraint_dim=None,
        **kwargs,
    )
    return cls(SymNode(symbol, shape_env, pytype, hint=val))


# TODO: default duck to False
def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt:
    return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs)


def create_symbool(shape_env, b: bool) -> SymBool:
    return create_symtype(SymBool, bool, shape_env, b)


def create_symfloat(shape_env, f: float) -> SymFloat:
    return create_symtype(SymFloat, float, shape_env, f)


@skipIfTorchDynamo(
    "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
)
class TestPySymInt(TestCase):
|
|
def test_arith_ops(self):
|
|
shape_env = ShapeEnv()
|
|
symints = []
|
|
for i in range(2, 5):
|
|
symints.append((i, create_symint(shape_env, i)))
|
|
|
|
ops = [
|
|
operator.add,
|
|
operator.sub,
|
|
operator.floordiv,
|
|
operator.mul,
|
|
operator.mod,
|
|
]
|
|
|
|
for op in ops:
|
|
for args in itertools.permutations(symints, 2):
|
|
if not isinstance(args[0][1], int) and (
|
|
(op not in (operator.mod, operator.floordiv) or args[1][0] != 0)
|
|
):
|
|
self.assertTrue(
|
|
op(args[0][1], args[1][1]) == op(args[0][0], args[1][0])
|
|
)
|
|
|
|
def test_reverse_arith_ops(self):
|
|
shape_env = ShapeEnv()
|
|
|
|
a = create_symint(shape_env, 2)
|
|
self.assertTrue(5 // a == 5 // 2)
|
|
|
|
a = create_symint(shape_env, 2)
|
|
self.assertTrue(5 * a == 5 * 2)
|
|
|
|
def test_sympify_symint(self):
|
|
shape_env = ShapeEnv()
|
|
a = create_symint(shape_env, 2)
|
|
self.assertIs(sympy.sympify(a), a.node.expr)
|
|
b = create_symfloat(shape_env, 3.0)
|
|
self.assertIs(sympy.sympify(b), b.node.expr)
|
|
c = create_symbool(shape_env, True)
|
|
self.assertIs(sympy.sympify(c), c.node.expr)
|
|
|
|
def test_roundtrip(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
|
|
self.assertTrue(not isinstance(x.shape[0], SymNode))
|
|
self.assertTrue(isinstance(x.shape[0], SymInt))
|
|
|
|
self.assertTrue(x.shape[0] == 5)
|
|
self.assertTrue(x.shape[1] == 4)
|
|
self.assertTrue(x.shape[2] == 3)
|
|
|
|
self.assertTrue(x.size()[0] == 5)
|
|
self.assertTrue(x.size()[1] == 4)
|
|
# Should be simplifiable to an integer.
|
|
# Ref: https://github.com/pytorch/pytorch/pull/107492
|
|
self.assertTrue(isinstance(x.size()[1], SymInt))
|
|
self.assertTrue(
|
|
isinstance(x.size()[1].node.maybe_as_int(), int)
|
|
) # due to guard above
|
|
self.assertTrue(x.size()[2] == 3)
|
|
|
|
self.assertTrue(x.size(0) == 5)
|
|
self.assertTrue(x.size(1) == 4)
|
|
self.assertTrue(x.size(2) == 3)
|
|
self.assertTrue(isinstance(x.size(2), SymInt))
|
|
self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int))
|
|
|
|
y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env)
|
|
self.assertTrue(isinstance(y.storage_offset(), SymInt))
|
|
self.assertTrue(y.storage_offset() == 12)
|
|
|
|
def test_binary(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env)
|
|
|
|
z = x + y
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
# broadcasting
|
|
y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env)
|
|
z = x + y
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
def test_symint_args(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env)
|
|
LAST_DIM = 2
|
|
z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM])
|
|
self.assertTrue(z.shape[2] == y.shape[2])
|
|
|
|
# arithmetic expr with two symints
|
|
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM])
|
|
self.assertTrue(z.shape[2] == 2)
|
|
|
|
# arithmetic expr with a symint and python int
|
|
z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1)
|
|
self.assertTrue(z.shape[2] == 2)
|
|
|
|
def test_symint_vargs(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env)
|
|
y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env)
|
|
|
|
# varargs
|
|
z = y.expand(x.shape[0], y.shape[1], x.shape[2])
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
# shape list
|
|
z = y.expand((x.shape[0], y.shape[1], x.shape[2]))
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
# mixed python symints and ints
|
|
z = y.expand(x.shape[0], y.shape[1], 3)
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
# mixed python symints and ints in a list
|
|
z = y.expand((x.shape[0], y.shape[1], 3))
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
# mixed python symints and ints
|
|
z = y.expand(5, y.shape[1], x.shape[2])
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
# mixed python ints and symints in a list
|
|
z = y.expand((5, y.shape[1], x.shape[2]))
|
|
self.assertTrue(z.shape[0] == 5)
|
|
self.assertTrue(z.shape[1] == 4)
|
|
self.assertTrue(z.shape[2] == 3)
|
|
|
|
z = y.expand((y.shape[1],))
|
|
z = y.expand(y.shape[1])
|
|
|
|
def test_symint_bitwise_and(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 0b1100)
|
|
b0 = create_symint(shape_env, 0b1010)
|
|
res_and = a0 & b0
|
|
self.assertEqual(res_and, 0b1000)
|
|
self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)"""
|
|
)
|
|
|
|
a1 = create_symint(shape_env, 3)
|
|
b1 = create_symbool(shape_env, True)
|
|
self.assertEqual(a1 & b1, 1)
|
|
|
|
a2 = create_symint(shape_env, 0b1100)
|
|
self.assertEqual(a2 & 0b1010, 0b1000)
|
|
|
|
a3 = create_symbool(shape_env, True)
|
|
b3 = create_symbool(shape_env, True)
|
|
self.assertEqual(a3 & b3, True)
|
|
|
|
def test_symint_bitwise_or(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 0b1100)
|
|
b0 = create_symint(shape_env, 0b1010)
|
|
res_or = a0 | b0
|
|
self.assertEqual(res_or, 0b1110)
|
|
self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)"""
|
|
)
|
|
|
|
def test_stride(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env)
|
|
self.assertIsInstance(x.stride()[0], SymInt)
|
|
|
|
def test_size_expressions(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
|
expand_x = x.expand(x.shape[0], x.shape[0])
|
|
if expand_x.shape[0] > 3:
|
|
result = expand_x + expand_x
|
|
else:
|
|
result = expand_x + expand_x
|
|
|
|
gt_op, _bt, is_size_obv = shape_env.guards[-1]
|
|
self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan))
|
|
self.assertTrue(str(x.shape[0]), str(gt_op.args[0]))
|
|
self.assertTrue(str(expand_x.shape[1]), str(x.shape[0]))
|
|
self.assertTrue(str(expand_x.shape[1]), str(result.shape[0]))
|
|
self.assertFalse(is_size_obv)
|
|
|
|
def test_floordiv_static(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 8)
|
|
# This was extracted from
|
|
# python test/inductor/test_cuda_cpp_wrapper.py -k
|
|
# DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper
|
|
bool(s0 % 2 == 0)
|
|
bool(s0 % (s0 // 2) == 0)
|
|
bool(2 * (s0 // 2) == s0)
|
|
self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2))
|
|
|
|
def test_numel(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
|
self.assertIsInstance(x.numel(), torch.SymInt)
|
|
self.assertIsInstance(torch.numel(x), torch.SymInt)
|
|
|
|
x = torch.rand(3, 3)
|
|
self.assertIsInstance(x.numel(), int)
|
|
self.assertIsInstance(torch.numel(x), int)
|
|
|
|
def test_int_to_float(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
|
r = torch.sym_float(x.shape[0])
|
|
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
|
|
|
|
def test_aten_ops(self):
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x", torch.randn(5), shape_env)
|
|
torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0])
|
|
|
|
shape_env = ShapeEnv()
|
|
x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env)
|
|
torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]])
|
|
|
|
def test_fx_trace_intlist(self):
|
|
class CustomModule(torch.nn.Module):
|
|
def forward(self, x):
|
|
bs, c, h, w = x.shape
|
|
return F.pad(x, (0, w % 2, 0, h % 2, 0, 0))
|
|
|
|
m = CustomModule()
|
|
x = torch.rand(1, 3, 4, 4)
|
|
# should not TypeError: pad(): argument 'pad' (position 2) must be
|
|
# tuple of ints, not tuple
|
|
torch.fx.symbolic_trace(m)
|
|
|
|
def test_meta_symint(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 2)
|
|
r = torch.empty(a0, device="meta")
|
|
self.assertIsInstance(r.shape[0], SymInt)
|
|
|
|
def test_guard_int(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 2)
|
|
self.assertEqual(guard_int(a0), 2)
|
|
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
|
|
|
|
def test_sym_sum(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 2)
|
|
s1 = create_symint(shape_env, 3)
|
|
s2 = create_symint(shape_env, 4)
|
|
self.assertEqual(
|
|
(s0 + s1 + s2).node.expr, torch.sym_sum([s0, s1, s2]).node.expr
|
|
)
|
|
|
|
def test_prefer_deferred_runtime_assertions_over_guards(self):
|
|
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
|
|
s0 = create_symint(shape_env, 2)
|
|
self.assertEqual(guard_int(s0), 2)
|
|
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
|
|
|
|
shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True)
|
|
s0 = create_symint(shape_env, 2)
|
|
self.assertTrue(expect_true(s0 == 2))
|
|
self.assertEqual(len(shape_env.guards), 0)
|
|
self.assertExpectedInline(
|
|
str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]),
|
|
"""[Eq(s97, 2)]""",
|
|
)
|
|
|
|
def test_sym_int(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 5)
|
|
r = sym_int(a0)
|
|
self.assertEqual(r, 5)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""")
|
|
|
|
a1 = create_symint(shape_env, 7)
|
|
r = sym_int(a1 / 2)
|
|
self.assertEqual(guard_int(r), 3)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)"""
|
|
)
|
|
|
|
a3 = create_symint(shape_env, 3)
|
|
r = sym_int(2.0 * torch.sym_float(a3))
|
|
self.assertEqual(guard_int(r), 6)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)"""
|
|
)
|
|
|
|
def test_sym_log2(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 4)
|
|
r = torch._sym_log2(a0)
|
|
self.assertEqual(r, 2.0)
|
|
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)"""
|
|
)
|
|
|
|
def test_sym_sqrt(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 4)
|
|
r = torch._sym_sqrt(a0)
|
|
self.assertEqual(r, 2)
|
|
self.assertIsInstance(r, torch.SymFloat, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)"""
|
|
)
|
|
|
|
def test_sym_floor(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 5)
|
|
r = math.floor(a0 / 2)
|
|
self.assertEqual(r, 2)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]),
|
|
"""Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""",
|
|
)
|
|
r = math.floor(3.0 * a0)
|
|
self.assertEqual(r, 15)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[1][0]),
|
|
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
|
|
)
|
|
|
|
def test_sym_trunc(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 5)
|
|
r = math.trunc(a0 / 2)
|
|
self.assertEqual(r, 2)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)"""
|
|
)
|
|
r = torch.sym_int(torch.sym_sqrt(a0))
|
|
self.assertEqual(r, 2)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[1][0]),
|
|
"""Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""",
|
|
)
|
|
|
|
def test_sym_ceil(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 5)
|
|
r = math.ceil(a0 / 2)
|
|
self.assertEqual(r, 3)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]),
|
|
"""Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""",
|
|
)
|
|
r1 = 3.0 * a0
|
|
r = math.floor(r1)
|
|
self.assertEqual(r, 15)
|
|
self.assertIsInstance(r, torch.SymInt, msg=type(r))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[1][0]),
|
|
"""Eq(FloorToInt(3.0*ToFloat(s97)), 15)""",
|
|
)
|
|
|
|
def test_sym_ite(self):
|
|
shape_env = ShapeEnv()
|
|
t = create_symint(shape_env, 5)
|
|
f = create_symint(shape_env, 4)
|
|
b1 = True
|
|
r1 = torch.sym_ite(b1, t, f)
|
|
self.assertTrue(r1 is t)
|
|
b2 = False
|
|
r2 = torch.sym_ite(b2, t, f)
|
|
self.assertTrue(r2 is f)
|
|
b3 = t == 5
|
|
r3 = torch.sym_ite(b3, t, f)
|
|
self.assertEqual(len(shape_env.guards), 0)
|
|
self.assertEqual(r3, 5)
|
|
self.assertEqual(type(t), type(r3))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[0][0]),
|
|
"""Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""",
|
|
)
|
|
b4 = f == 5
|
|
r4 = torch.sym_ite(b4, t, f)
|
|
self.assertEqual(len(shape_env.guards), 1)
|
|
self.assertEqual(r4, 4)
|
|
self.assertEqual(type(f), type(r4))
|
|
self.assertExpectedInline(
|
|
str(shape_env.guards[1][0]),
|
|
"""Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""",
|
|
)
|
|
|
|
def test_tracing_sym_ite(self):
|
|
def f(x):
|
|
b = x.shape[0] == 5
|
|
ret = torch.sym_ite(b, x.shape[0], x.shape[1])
|
|
return ret
|
|
|
|
gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5))
|
|
self.assertEqual(len(gm.shape_env.guards), 0)
|
|
self.assertExpectedInline(
|
|
gm.code.strip(),
|
|
"""\
|
|
def forward(self, x_1):
|
|
sym_size_int = torch.ops.aten.sym_size.int(x_1, 0)
|
|
eq = sym_size_int == 5
|
|
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None
|
|
sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None
|
|
return sym_ite""",
|
|
)
|
|
r1 = gm(torch.ones(4, 5))
|
|
self.assertIsInstance(r1, int)
|
|
self.assertEqual(r1, 5)
|
|
r2 = gm(torch.ones(5, 4))
|
|
self.assertIsInstance(r2, int)
|
|
self.assertEqual(r2, 5)
|
|
|
|
def test_int_conversion(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 2)
|
|
int(a0)
|
|
self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""")
|
|
|
|
def test_data_dependent_guard(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = shape_env.create_unbacked_symint()
|
|
self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0))
|
|
|
|
def test_data_dependent_guard_propagate_real_tensors(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = shape_env.create_unbacked_symint()
|
|
shape_env.set_unbacked_var_to_val(s0.node.expr, 0)
|
|
self.assertEqual(bool(s0 == 0), True)
|
|
|
|
def test_expect_true_basic(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
i0_sym = i0.node.expr
|
|
# This doesn't error
|
|
self.assertTrue(expect_true(i0 == 0))
|
|
# This generates a deferred runtime assert via replacement
|
|
self.assertEqual(shape_env.replacements[i0_sym], 0)
|
|
# After expecting true, guards now resolve given the runtime assert
|
|
bool(i0 == 0)
|
|
|
|
def test_expect_true_with_s0(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 5)
|
|
i0 = shape_env.create_unbacked_symint()
|
|
self.assertTrue(expect_true(i0 < s0))
|
|
self.assertExpectedInline(
|
|
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]),
|
|
"""[u0 < s97]""",
|
|
)
|
|
self.assertTrue(i0 < s0)
|
|
self.assertTrue(i0 != s0)
|
|
self.assertFalse(i0 > s0)
|
|
self.assertFalse(i0 >= s0)
|
|
|
|
def test_expect_true_prefer_later(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
i1 = shape_env.create_unbacked_symint()
|
|
i1_sym = i1.node.expr
|
|
self.assertTrue(expect_true(i0 + i1 == 10))
|
|
# Importantly, this is put in i1, not i0!
|
|
self.assertExpectedInline(
|
|
str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]),
|
|
"""[Eq(u0 + u1, 10)]""",
|
|
)
|
|
self.assertTrue(i0 + i1 == 10)
|
|
# NB: We currently don't support deriving that we can substitute
|
|
# i0 + i1 with 10; maybe we should, but this means our rewriting
|
|
# system is no longer confluent (it's probably OK though, because
|
|
# you're unlikely to get other equalities like this on the
|
|
# unbacked SymInts.)
|
|
|
|
def test_unbacked_substitution(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
i1 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i0)
|
|
_constrain_range_for_size(i1)
|
|
self.assertTrue(expect_true(i0 == i1 * 4))
|
|
self.assertExpectedInline(str(i0), """u0""")
|
|
|
|
i2 = shape_env.create_unbacked_symint()
|
|
i3 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i2)
|
|
_constrain_range_for_size(i3)
|
|
self.assertTrue(expect_true(i2 * 4 == i3))
|
|
self.assertExpectedInline(str(i3), """u3""")
|
|
|
|
def test_avoid_unbacked_substitution(self):
|
|
shape_env = ShapeEnv()
|
|
i0 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i0)
|
|
i1 = shape_env.create_unbacked_symint()
|
|
_constrain_range_for_size(i1)
|
|
self.assertTrue(expect_true(i0 == 10 - i1))
|
|
self.assertExpectedInline(str(i0), """u0""")
|
|
|
|
def test_expect_true_double_digits(self):
|
|
shape_env = ShapeEnv()
|
|
ia = [shape_env.create_unbacked_symint() for _ in range(11)] # allocate 10
|
|
self.assertEqual(str(ia[-1]), "u10")
|
|
self.assertTrue(expect_true(sum(ia) == 20))
|
|
self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1)
|
|
|
|
def test_expect_true_refine_range(self):
|
|
shape_env = ShapeEnv()
|
|
for i, rel in enumerate(
|
|
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = shape_env.create_unbacked_symint()
|
|
self.assertTrue(expect_true(rel(i0)))
|
|
self.assertTrue(statically_known_true(i0 != 3))
|
|
self.assertTrue(statically_known_true(i0 != 4))
|
|
self.assertFalse(statically_known_true(i0 != 5))
|
|
self.assertFalse(statically_known_true(i0 != 6))
|
|
self.assertTrue(statically_known_true(i0 > 4))
|
|
self.assertTrue(statically_known_true(i0 >= 5))
|
|
|
|
for i, rel in enumerate(
|
|
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = shape_env.create_unbacked_symint()
|
|
self.assertTrue(expect_true(rel(i0)))
|
|
self.assertFalse(statically_known_true(i0 != 2))
|
|
self.assertFalse(statically_known_true(i0 != 3))
|
|
self.assertTrue(statically_known_true(i0 != 4))
|
|
self.assertTrue(statically_known_true(i0 != 5))
|
|
self.assertTrue(statically_known_true(i0 < 4))
|
|
self.assertTrue(statically_known_true(i0 <= 5))
|
|
|
|
def test_guard_refine_range(self):
|
|
shape_env = ShapeEnv()
|
|
for i, rel in enumerate(
|
|
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = create_symint(shape_env, 10, duck=False)
|
|
self.assertTrue(bool(rel(i0)))
|
|
self.assertTrue(statically_known_true(i0 != 3))
|
|
self.assertTrue(statically_known_true(i0 != 4))
|
|
self.assertFalse(statically_known_true(i0 != 5))
|
|
self.assertFalse(statically_known_true(i0 != 6))
|
|
self.assertTrue(statically_known_true(i0 > 4))
|
|
self.assertTrue(statically_known_true(i0 >= 5))
|
|
|
|
for i, rel in enumerate(
|
|
[lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = create_symint(shape_env, 2, duck=False)
|
|
self.assertFalse(bool(rel(i0)))
|
|
self.assertFalse(statically_known_true(i0 != 3))
|
|
self.assertFalse(statically_known_true(i0 != 4))
|
|
self.assertTrue(statically_known_true(i0 != 5))
|
|
self.assertTrue(statically_known_true(i0 != 6))
|
|
self.assertTrue(statically_known_true(i0 <= 4))
|
|
self.assertTrue(statically_known_true(i0 < 5))
|
|
|
|
for i, rel in enumerate(
|
|
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = create_symint(shape_env, 2, duck=False)
|
|
self.assertTrue(bool(rel(i0)))
|
|
self.assertFalse(statically_known_true(i0 != 2))
|
|
self.assertFalse(statically_known_true(i0 != 3))
|
|
self.assertTrue(statically_known_true(i0 != 4))
|
|
self.assertTrue(statically_known_true(i0 != 5))
|
|
self.assertTrue(statically_known_true(i0 < 4))
|
|
self.assertTrue(statically_known_true(i0 <= 3))
|
|
|
|
for i, rel in enumerate(
|
|
[lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x]
|
|
):
|
|
with self.subTest(f"i = {i}"):
|
|
i0 = create_symint(shape_env, 10, duck=False)
|
|
self.assertFalse(bool(rel(i0)))
|
|
self.assertTrue(statically_known_true(i0 != 2))
|
|
self.assertTrue(statically_known_true(i0 != 3))
|
|
self.assertFalse(statically_known_true(i0 != 4))
|
|
self.assertFalse(statically_known_true(i0 != 5))
|
|
self.assertTrue(statically_known_true(i0 >= 4))
|
|
self.assertTrue(statically_known_true(i0 > 3))
|
|
|
|
def test_mul_int_oo_nan(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 5, duck=False)
|
|
s1 = create_symint(shape_env, 6, duck=False)
|
|
s2 = create_symint(shape_env, 5, duck=False)
|
|
bool(s0 * (s1 // s0) == s2)
|
|
|
|
def test_non_overlapping_and_dense(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 5)
|
|
r = torch.empty_strided((a0, 7), (1, a0), device="meta")
|
|
self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r))
|
|
|
|
def test_non_overlapping_and_dense_unbacked(self):
|
|
shape_env = ShapeEnv()
|
|
u0 = shape_env.create_unbacked_symint()
|
|
torch._check_is_size(u0)
|
|
cf = torch.ops.aten.is_non_overlapping_and_dense.default
|
|
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1)
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1)
|
|
self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")))
|
|
self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta")))
|
|
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1)
|
|
self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1)
|
|
self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta")))
|
|
self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta")))
|
|
|
|
Max = torch.sym_max
|
|
# NB: This only works because we're able to determine this tensor is
|
|
# contiguous. transpose(0, 1) makes it stop working
|
|
self.assertTrue(
|
|
cf(
|
|
torch.empty_strided(
|
|
(2, 3, 1, u0),
|
|
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
|
|
device="meta",
|
|
)
|
|
)
|
|
)
|
|
|
|
def test_sympy_optimized_add_binary_search(self):
|
|
import sympy
|
|
|
|
from torch.fx.experimental.sym_node import _binary_search_insert_arg
|
|
|
|
a = sympy.Symbol("a")
|
|
b = sympy.Symbol("b")
|
|
c = sympy.Symbol("c")
|
|
|
|
args = []
|
|
args = _binary_search_insert_arg([], b)
|
|
self.assertEqual(args, [b])
|
|
|
|
self.assertEqual(_binary_search_insert_arg(args, b), None)
|
|
|
|
args = _binary_search_insert_arg(args, a)
|
|
self.assertEqual(args, [a, b])
|
|
|
|
self.assertEqual(_binary_search_insert_arg(args, b), None)
|
|
self.assertEqual(_binary_search_insert_arg(args, a), None)
|
|
|
|
args = _binary_search_insert_arg(args, c)
|
|
self.assertEqual(args, [a, b, c])
|
|
|
|
self.assertEqual(_binary_search_insert_arg(args, a), None)
|
|
self.assertEqual(_binary_search_insert_arg(args, b), None)
|
|
self.assertEqual(_binary_search_insert_arg(args, c), None)
|
|
|
|
a1 = sympy.Symbol("a1")
|
|
a2 = sympy.Symbol("a2")
|
|
|
|
args = _binary_search_insert_arg(args, a1)
|
|
self.assertEqual(args, [a, a1, b, c])
|
|
|
|
args = _binary_search_insert_arg(args, a2)
|
|
self.assertEqual(args, [a, a1, a2, b, c])
|
|
|
|
c1 = sympy.Symbol("c1")
|
|
args = _binary_search_insert_arg(args, c1)
|
|
self.assertEqual(args, [a, a1, a2, b, c, c1])
|
|
|
|
# insert to front
|
|
_a = sympy.Symbol("_a")
|
|
args = _binary_search_insert_arg(args, _a)
|
|
self.assertEqual(args, [_a, a, a1, a2, b, c, c1])
|
|
|
|
def test_floor_clean_div_axioms(self):
|
|
# Test that if we add an axiom that have FloorDiv, after which the
|
|
# shapeEnv changed such that it can be simplified it to CleanDiv, then
|
|
# We still correctly replace CleanDiv with the axiom value of FloorDiv.
|
|
shape_env = ShapeEnv()
|
|
a = shape_env.create_unbacked_symint()
|
|
|
|
shape_env.guard_or_defer_runtime_assert((a // 3 == 1).node.expr, " test")
|
|
|
|
from sympy import Eq
|
|
|
|
test1 = Eq(FloorDiv(a.node.expr, 3), 1)
|
|
test2 = Eq(CleanDiv(a.node.expr, 3), 1)
|
|
|
|
self.assertTrue(shape_env.evaluate_expr(test1))
|
|
self.assertEqual(shape_env._maybe_evaluate_static(test2), None)
|
|
|
|
# After this FloorDiv(a, 3) is simplified to CleanDiv(a, 3)
|
|
shape_env.guard_or_defer_runtime_assert(Eq(Mod(a, 3), 0), " test")
|
|
self.assertEqual(test2, shape_env.simplify(test1))
|
|
|
|
self.assertTrue(shape_env.evaluate_expr(test1))
|
|
self.assertTrue(shape_env.evaluate_expr(test2))
|
|
|
|
def test_sympy_optimized_add(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 2)
|
|
s1 = create_symint(shape_env, 3)
|
|
s2 = create_symint(shape_env, 4)
|
|
sum = s0 + s1
|
|
|
|
self.assertTrue(sum.node._optimized_summation)
|
|
|
|
def assert_optimized(sym):
|
|
self.assertTrue(sym.node._optimized_summation)
|
|
|
|
def assert_not_optimized(sym):
|
|
self.assertFalse(getattr(sym.node, "_optimized_summation", False))
|
|
|
|
assert_optimized(sum)
|
|
|
|
# add duplicate symbol
|
|
assert_not_optimized(sum + s0)
|
|
|
|
# add constant.
|
|
assert_not_optimized(sum + 1)
|
|
|
|
# add new unique symbol, should maintain _optimized_summation property.
|
|
assert_optimized(sum + s2)
|
|
|
|
assert_optimized(s2 + sum)
|
|
|
|
# add x + (a+b) with no _optimized_summation on the rhs sum.
|
|
a = create_symint(shape_env, 10)
|
|
b = create_symint(shape_env, 11)
|
|
two_sum = torch.sym_sum([a, b])
|
|
assert_not_optimized(two_sum)
|
|
assert_optimized(sum + two_sum)
|
|
|
|
# adding two expressions of length >2 that are _optimized_summation.
|
|
a = s0 + s1 + s2
|
|
s3 = create_symint(shape_env, 10)
|
|
s4 = create_symint(shape_env, 20)
|
|
s5 = create_symint(shape_env, 30)
|
|
b = s3 + s4 + s5
|
|
assert_optimized(a)
|
|
assert_optimized(b)
|
|
assert_not_optimized(a + b)
|
|
assert_not_optimized(b + a)
|
|
assert_not_optimized(b + a + b)
|
|
|
|
def test_max_of_unique_summation_opt(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = shape_env.create_unbacked_symint()
|
|
s1 = shape_env.create_unbacked_symint()
|
|
s2 = shape_env.create_unbacked_symint()
|
|
s3 = shape_env.create_unbacked_symint()
|
|
s4 = shape_env.create_unbacked_symint()
|
|
s5 = shape_env.create_unbacked_symint()
|
|
s7 = shape_env.create_unbacked_symint()
|
|
|
|
def assert_optimized(sym):
|
|
self.assertTrue(sym.node.expr.unique_summations_symbols is not None)
|
|
|
|
def assert_not_optimized(sym):
|
|
self.assertTrue(getattr(sym.node.expr, "unique_summations_symbols", None) is None)
|
|
|
|
mx1 = torch.sym_max(s0, s1)
|
|
assert_not_optimized(mx1)
|
|
|
|
mx2 = torch.sym_max(s0 + s1, s2 + s3)
|
|
assert_optimized(mx2)
|
|
|
|
mx3 = torch.sym_max(mx2, s4 + s5)
|
|
assert_optimized(mx3)
|
|
assert_optimized(torch.sym_max(s4 + s5, mx2))
|
|
|
|
assert_not_optimized(torch.sym_max(mx3, s7))
|
|
assert_not_optimized(torch.sym_max(mx3, 10))
|
|
assert_not_optimized(torch.sym_max(mx3, s3 + s7))
|
|
assert_not_optimized(torch.sym_max(mx3, s7 * 2))
|
|
|
|
def test_sym_max_multi_max_simplify(self):
|
|
shape_env = ShapeEnv()
|
|
u0 = shape_env.create_unbacked_symint()
|
|
self.assertTrue(
|
|
statically_known_true(
|
|
torch.sym_max(1, torch.sym_max(257, u0)) == torch.sym_max(257, u0)
|
|
)
|
|
)
|
|
|
|
def test_numpy_sym_max(self):
|
|
self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
|
|
self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
|
|
self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
|
|
self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
|
|
self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
|
|
self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)
|
|
|
|
def test_numpy_sym_min(self):
|
|
self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
|
|
self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
|
|
self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
|
|
self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5)
|
|
self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0)
|
|
self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0)
|
|
|
|
def test_debug_has_internal_overlap_unbacked(self):
|
|
shape_env = ShapeEnv()
|
|
u0 = shape_env.create_unbacked_symint()
|
|
torch._check_is_size(u0)
|
|
cf = torch._debug_has_internal_overlap
|
|
self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0)
|
|
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0)
|
|
self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0)
|
|
self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2)
|
|
Max = torch.sym_max
|
|
self.assertEqual(
|
|
cf(
|
|
torch.empty_strided(
|
|
(2, 3, 1, u0),
|
|
(3 * Max(1, u0), Max(1, u0), Max(1, u0), 1),
|
|
device="meta",
|
|
)
|
|
),
|
|
2,
|
|
)
|
|
|
|
# Wobbling these to zero is OK too
|
|
self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2)
|
|
self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2)
|
|
|
|
def test_specialize_zero_one(self):
|
|
shape_env = ShapeEnv(specialize_zero_one=True)
|
|
a0 = create_symint(shape_env, 5)
|
|
assert a0 != 1
|
|
self.assertEqual(len(shape_env.guards), 0)
|
|
|
|
shape_env = ShapeEnv(specialize_zero_one=False)
|
|
a0 = create_symint(shape_env, 5)
|
|
assert a0 != 1
|
|
self.assertEqual(len(shape_env.guards), 1)
|
|
|
|
def test_duck_shape(self):
|
|
shape_env = ShapeEnv(duck_shape=True)
|
|
a0 = create_symint(shape_env, 5)
|
|
a1 = create_symint(shape_env, 5)
|
|
assert a0 == a1
|
|
self.assertEqual(len(shape_env.guards), 0)
|
|
|
|
shape_env = ShapeEnv(duck_shape=False)
|
|
a0 = create_symint(shape_env, 5)
|
|
a1 = create_symint(shape_env, 5)
|
|
assert a0 == a1
|
|
self.assertEqual(len(shape_env.guards), 1)
|
|
|
|
def test_int_bool(self):
|
|
# See https://github.com/pytorch/pytorch/issues/95981
|
|
shape_env = ShapeEnv(duck_shape=True)
|
|
a0 = create_symint(shape_env, 5)
|
|
assert a0
|
|
self.assertEqual(len(shape_env.guards), 0)
|
|
|
|
def test_symint_as_scalar(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 2)
|
|
|
|
sym_int_encountered = False
|
|
|
|
class TestSymInt(TorchDispatchMode):
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
assert func == torch.ops.aten.add.Tensor
|
|
|
|
nonlocal sym_int_encountered
|
|
# WARNING: do not do identity tests on the outer
|
|
# SymInt/SymFloat, they are NOT STABLE
|
|
sym_int_encountered = kwargs["alpha"].node is a0.node
|
|
kwargs["alpha"] = 0
|
|
return func(*args)
|
|
|
|
x = torch.rand([4, 4])
|
|
with TestSymInt():
|
|
y = torch.add(x, x, alpha=a0)
|
|
|
|
self.assertTrue(sym_int_encountered)
|
|
|
|
def test_deepcopy(self):
|
|
shape_env = ShapeEnv()
|
|
a0 = create_symint(shape_env, 2)
|
|
assert a0 < 4
|
|
new_shape_env = copy.deepcopy(shape_env)
|
|
self.assertEqual(len(new_shape_env.guards), 1)
|
|
|
|
def test_print_readable_with_symints(self):
|
|
def f(a, b):
|
|
dim0 = a.shape[0] + b.shape[0]
|
|
dim1 = a.shape[1] + b.shape[1]
|
|
d = a.new_empty(dim0, dim1)
|
|
d = torch.ops.aten.native_dropout(d, 0.5, train=True)
|
|
return d
|
|
|
|
fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3))
|
|
out = fx_g.print_readable(print_output=False)
|
|
|
|
self.assertExpectedInline(
|
|
out.strip(),
|
|
"""\
|
|
class f(torch.nn.Module):
|
|
def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"):
|
|
# No stacktrace found for following nodes
|
|
sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0)
|
|
sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0)
|
|
add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None
|
|
sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1)
|
|
sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None
|
|
add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = sym_size_int_3 = None
|
|
new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None
|
|
native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None
|
|
getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0]
|
|
getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None
|
|
return (getitem, getitem_1)""", # noqa: B950
|
|
)
|
|
|
|
def test_statically_known_true(self):
|
|
shape_env = ShapeEnv()
|
|
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
|
|
|
|
# Statically known true
|
|
self.assertTrue(statically_known_true(True))
|
|
self.assertTrue(statically_known_true(s2 == s2))
|
|
self.assertTrue(statically_known_true(s2 * s3 > s3))
|
|
self.assertTrue(statically_known_true(s3 * s4 > s4))
|
|
self.assertTrue(statically_known_true((s3 + s3) % 2 == 0))
|
|
|
|
# Statically known false
|
|
self.assertFalse(statically_known_true(False))
|
|
self.assertFalse(statically_known_true(s3 * s4 <= s4))
|
|
self.assertFalse(statically_known_true((s3 + s3) % 2 == 1))
|
|
|
|
# True for hints, but not known statically
|
|
self.assertFalse(statically_known_true(s2 + s2 == s4))
|
|
self.assertFalse(statically_known_true(s4 % s2 == 0))
|
|
self.assertFalse(statically_known_true(s2 != s3))
|
|
self.assertFalse(statically_known_true(s3 * s4 > s2))
|
|
|
|
# False for hints, but not known statically
|
|
self.assertFalse(statically_known_true(s2 == s3))
|
|
self.assertFalse(statically_known_true(s2 > s3))
|
|
self.assertFalse(statically_known_true(s3 + s3 == s4))
|
|
|
|
# No guards should be generated
|
|
self.assertEqual(len(shape_env.guards), 0)
|
|
|
|
def test_statically_known_false(self):
|
|
shape_env = ShapeEnv()
|
|
s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5))
|
|
|
|
# Statically known true
|
|
self.assertFalse(statically_known_false(True))
|
|
self.assertFalse(statically_known_false(s2 == s2))
|
|
self.assertFalse(statically_known_false(s2 * s3 > s3))
|
|
self.assertFalse(statically_known_false(s3 * s4 > s4))
|
|
self.assertFalse(statically_known_false((s3 + s3) % 2 == 0))
|
|
|
|
# Statically known false
|
|
self.assertTrue(statically_known_false(False))
|
|
self.assertTrue(statically_known_false(s3 * s4 <= s4))
|
|
self.assertTrue(statically_known_false((s3 + s3) % 2 == 1))
|
|
|
|
# True for hints, but not known statically
|
|
self.assertFalse(statically_known_false(s2 + s2 == s4))
|
|
self.assertFalse(statically_known_false(s4 % s2 == 0))
|
|
self.assertFalse(statically_known_false(s2 != s3))
|
|
self.assertFalse(statically_known_false(s3 * s4 > s2))
|
|
|
|
# False for hints, but not known statically
|
|
self.assertFalse(statically_known_false(s2 == s3))
|
|
self.assertFalse(statically_known_false(s2 > s3))
|
|
self.assertFalse(statically_known_false(s3 + s3 == s4))
|
|
|
|
# No guards should be generated
|
|
self.assertEqual(len(shape_env.guards), 0)
|
|
|
|
def test_ephemeral_source_simplification(self):
|
|
from torch._dynamo.source import EphemeralSource
|
|
|
|
# For full robustness, ensure the ephemeral source symbols are simplified out regardless
|
|
# of construction order or check order.
|
|
for construct_ephemeral_first, x_first_in_check in itertools.product(
|
|
[False, True], [False, True]
|
|
):
|
|
shape_env = ShapeEnv()
|
|
shape = (5, 10)
|
|
dynamic_dims = [DimDynamic.DYNAMIC for _ in shape]
|
|
x = create_symbolic_tensor(
|
|
"x",
|
|
torch.randn(*shape),
|
|
shape_env,
|
|
source=(EphemeralSource() if construct_ephemeral_first else None),
|
|
dynamic_dims=dynamic_dims,
|
|
)
|
|
y = create_symbolic_tensor(
|
|
"y",
|
|
torch.randn(*shape),
|
|
shape_env,
|
|
source=(EphemeralSource() if not construct_ephemeral_first else None),
|
|
dynamic_dims=dynamic_dims,
|
|
)
|
|
t_with_ephemeral = x if construct_ephemeral_first else y
|
|
|
|
def _get_ephemeral_source_symbols(t):
|
|
return [
|
|
s.node.expr
|
|
for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),))
|
|
if isinstance(s, torch.SymInt)
|
|
and s.node.expr in shape_env.var_to_sources
|
|
and any(
|
|
source.is_ephemeral()
|
|
for source in shape_env.var_to_sources[s.node.expr]
|
|
)
|
|
]
|
|
|
|
# these checks should simplify out the ephemeral symbols, regardless of the
|
|
# ordering x == y or y == x
|
|
self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0)
|
|
if x_first_in_check:
|
|
torch._check(x.size() == y.size())
|
|
torch._check(x.stride() == y.stride())
|
|
torch._check(x.storage_offset() == y.storage_offset())
|
|
else:
|
|
torch._check(y.size() == x.size())
|
|
torch._check(y.stride() == x.stride())
|
|
torch._check(y.storage_offset() == x.storage_offset())
|
|
self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0)
|
|
|
|
def test_ephemeral_source_unified_with_non_ephemeral_source(self):
|
|
from torch._dynamo.source import EphemeralSource
|
|
|
|
for construct_ephemeral_first in (False, True):
|
|
shape_env = ShapeEnv()
|
|
shape = (5, 10)
|
|
# use duck sizing here to ensure symbol reuse across x and y
|
|
duck_dims = [DimDynamic.DUCK for _ in shape]
|
|
x = create_symbolic_tensor(
|
|
"x",
|
|
torch.randn(*shape),
|
|
shape_env,
|
|
source=(EphemeralSource() if construct_ephemeral_first else None),
|
|
dynamic_dims=duck_dims,
|
|
)
|
|
y = create_symbolic_tensor(
|
|
"y",
|
|
torch.randn(*shape),
|
|
shape_env,
|
|
source=(EphemeralSource() if not construct_ephemeral_first else None),
|
|
dynamic_dims=duck_dims,
|
|
)
|
|
|
|
# regardless of construction order, non-ephemeral sources should be preferred
|
|
# first in the var_to_sources list for potential guarding later on
|
|
for source_list in shape_env.var_to_sources.values():
|
|
self.assertFalse(source_list[0].is_ephemeral())
|
|
|
|
self.assertEqual(x.size(), y.size())
|
|
self.assertEqual(x.stride(), y.stride())
|
|
self.assertEqual(x.storage_offset(), y.storage_offset())
|
|
|
|
def test_tensor_factory_with_symint(self):
|
|
args = list(range(0, 3))
|
|
expected = torch.tensor(args)
|
|
|
|
shape_env = ShapeEnv()
|
|
sym_args = [create_symint(shape_env, i) for i in args]
|
|
|
|
# test tensor factories
|
|
for dt in all_types_and(torch.half, torch.bfloat16):
|
|
res = torch.tensor(sym_args, dtype=dt)
|
|
self.assertEqual(res, expected, exact_dtype=False)
|
|
|
|
# test legacy tensor factories
|
|
legacy_ctors = [
|
|
torch.Tensor,
|
|
torch.LongTensor,
|
|
torch.DoubleTensor,
|
|
torch.FloatTensor,
|
|
torch.IntTensor,
|
|
torch.ShortTensor,
|
|
torch.HalfTensor,
|
|
torch.ByteTensor,
|
|
]
|
|
for Tensor in legacy_ctors:
|
|
res = Tensor(sym_args)
|
|
self.assertEqual(res, expected, exact_dtype=False)
|
|
|
|
def test_backed_size_oblivious_01_spec(self):
|
|
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
|
|
|
|
@torch.compile(dynamic=True, fullgraph=True)
|
|
def f(a, b):
|
|
if guard_size_oblivious(a.size(0) == 1):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
with torch.fx.experimental._config.patch(backed_size_oblivious=True):
|
|
# always go to the >= 2 branch.
|
|
self.assertEqual(
|
|
f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
|
|
def test_baddbmm_symint(self):
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
|
|
shape_env = ShapeEnv()
|
|
fake_mode = FakeTensorMode(shape_env=shape_env)
|
|
|
|
B, M, K, N = [shape_env.create_unbacked_symint() for _ in range(4)]
|
|
|
|
with fake_mode:
|
|
A = torch.empty((B, M, K), device="meta")
|
|
Bmat = torch.empty((B, K, N), device="meta")
|
|
bias3 = torch.empty((B, M, N), device="meta")
|
|
|
|
_ = torch.baddbmm(bias3, A, Bmat)
|
|
|
|
|
|
@skipIfTorchDynamo(
|
|
"Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)"
|
|
)
|
|
class TestSymNumberMagicMethods(TestCase):
|
|
def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn):
|
|
with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn):
|
|
return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn)
|
|
|
|
def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn):
|
|
# Helper function
|
|
# NB: don't use one as that will get specialized
|
|
# TODO: We don't have to circuitously create the float, can just
|
|
# create a symfloat directly
|
|
seed_node = (create_symint(shape_env, 2) / 2.0).node
|
|
bool_seed_node = (create_symint(shape_env, 2) == 2).node
|
|
|
|
def get_sym_inp(inp):
|
|
# NB: this must come before int
|
|
if isinstance(inp, bool):
|
|
return torch.SymBool(to_node(bool_seed_node, inp))
|
|
elif isinstance(inp, int):
|
|
return torch.SymInt(to_node(seed_node, inp))
|
|
else:
|
|
return torch.SymFloat(to_node(seed_node, inp))
|
|
|
|
if fn == "float_pow":
|
|
if inp1 < 0:
|
|
return
|
|
|
|
if fn == "pow_by_natural":
|
|
if isinstance(inp1, float) or isinstance(inp2, float):
|
|
return
|
|
if inp2 < 0:
|
|
return
|
|
|
|
def maybe_xfail(inp1, inp2):
|
|
if fn == "sym_sqrt" and inp1 < 0:
|
|
# ValueError: math domain error
|
|
return self.assertRaises((ValueError,))
|
|
elif (
|
|
fn in ("float_truediv", "int_truediv", "int_floordiv", "mod")
|
|
and inp2 == 0
|
|
):
|
|
# ZeroDivisionError: division by zero
|
|
return self.assertRaises((ZeroDivisionError,))
|
|
elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0:
|
|
# ZeroDivisionError: 0.0 cannot be raised to a negative power
|
|
return self.assertRaises((ZeroDivisionError,))
|
|
elif (
|
|
# TODO: dear catastrophe waitress,
|
|
# this doesn't work
|
|
fn in ["float_pow", "pow_by_natural"]
|
|
and inp1 < 0
|
|
and (
|
|
type(inp1) in (SymInt, SymFloat) or type(inp2) in (SymInt, SymFloat)
|
|
)
|
|
and (type(inp1) in (SymFloat, float) or type(inp2) in (SymFloat, float))
|
|
):
|
|
# Complex result, which we do not support:
|
|
# TypeError: Cannot convert complex to float
|
|
return self.assertRaises((RuntimeError,))
|
|
elif fn in ("lshift", "rshift") and not (
|
|
isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int))
|
|
):
|
|
# TypeError: unsupported operand type(s)
|
|
return self.assertRaises((TypeError,))
|
|
elif fn in ("lshift", "rshift") and inp2 < 0:
|
|
# ValueError: math domain error
|
|
return self.assertRaises((ValueError,))
|
|
else:
|
|
return contextlib.nullcontext()
|
|
|
|
lambda_apply = method_to_operator(fn)
|
|
|
|
def guard_fn(v):
|
|
if type(v) in (SymBool, bool):
|
|
return guard_bool(v)
|
|
elif type(v) in (SymFloat, float):
|
|
return guard_float(v)
|
|
else: # SymInt, int
|
|
return guard_int(v)
|
|
|
|
# Get reference result
|
|
with maybe_xfail(inp1, inp2):
|
|
if is_unary_fn:
|
|
ref_out = lambda_apply(inp1)
|
|
else:
|
|
ref_out = lambda_apply(inp1, inp2)
|
|
|
|
# Symified first arg
|
|
sym_inp1 = get_sym_inp(inp1)
|
|
with maybe_xfail(sym_inp1, inp2):
|
|
if is_unary_fn:
|
|
out = lambda_apply(sym_inp1)
|
|
else:
|
|
out = lambda_apply(sym_inp1, inp2)
|
|
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
|
|
out = guard_fn(out)
|
|
self.assertEqual(out, ref_out)
|
|
|
|
if is_unary_fn:
|
|
return
|
|
|
|
# Symified second arg
|
|
sym_inp2 = get_sym_inp(inp2)
|
|
with maybe_xfail(inp1, sym_inp2):
|
|
out = lambda_apply(inp1, sym_inp2)
|
|
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
|
|
out = guard_fn(out)
|
|
self.assertEqual(out, ref_out)
|
|
|
|
# Symified both args
|
|
with maybe_xfail(sym_inp1, sym_inp2):
|
|
out = lambda_apply(sym_inp1, sym_inp2)
|
|
self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool)))
|
|
out = guard_fn(out)
|
|
self.assertEqual(out, ref_out)
|
|
|
|
@parametrize("fn", list(sym_node.magic_methods.keys()))
|
|
def test_bool_method(self, fn):
|
|
# sym_ite has its own tests
|
|
if fn not in sym_node.bool_magic_methods or fn == "sym_ite":
|
|
self.skipTest(f"{fn} is non-bool")
|
|
|
|
is_unary_fn = fn in sym_node.unary_methods
|
|
shape_env = ShapeEnv()
|
|
self._do_test(fn, True, False, shape_env, is_unary_fn)
|
|
|
|
@parametrize("fn", list(sym_node.magic_methods.keys()))
|
|
@parametrize("first_type", ["int", "float"])
|
|
@parametrize("second_type", ["int", "float"])
|
|
def test_method(self, fn, first_type, second_type):
|
|
if first_type == "float":
|
|
# TODO: Hmm, this looks like we skip all floats
|
|
self.skipTest(f"{fn} is not a float magic method")
|
|
|
|
if (
|
|
first_type == "int" or second_type == "int"
|
|
) and fn in sym_node.only_float_magic_methods:
|
|
self.skipTest(f"{fn} is not an int method")
|
|
|
|
if second_type == "float" and fn in ["mod"]:
|
|
self.skipTest(f"{fn} only handles int")
|
|
|
|
if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"):
|
|
self.skipTest(f"{fn} is a bitwise op, only handles int")
|
|
|
|
is_unary_fn = fn in sym_node.unary_methods or fn == "round"
|
|
# Second argument is ignored for unary function. So only run for one type
|
|
if is_unary_fn and second_type == "float":
|
|
self.skipTest(f"{fn} is unary and already tested")
|
|
|
|
if fn in sym_node.bool_magic_methods:
|
|
self.skipTest(f"{fn} is bool")
|
|
|
|
# Only floats here since these will be converted to int if necessary.
|
|
# We also ignore complex and bool.
|
|
values = (
|
|
0.0,
|
|
1.0,
|
|
0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error
|
|
)
|
|
|
|
neg_values = tuple(-x for x in values)
|
|
|
|
for inp1, inp2 in itertools.chain(
|
|
itertools.product(values, values),
|
|
itertools.product(values, neg_values),
|
|
itertools.product(neg_values, values),
|
|
itertools.product(neg_values, neg_values),
|
|
):
|
|
if first_type == "int":
|
|
inp1 = int(inp1)
|
|
if second_type == "int":
|
|
inp2 = int(inp2)
|
|
|
|
shape_env = ShapeEnv()
|
|
|
|
self._do_test(fn, inp1, inp2, shape_env, is_unary_fn)
|
|
|
|
def get_constant_bool(self, val):
|
|
return SymBool(torch._C._get_constant_bool_symnode(val))
|
|
|
|
@unittest.expectedFailure
|
|
def test_symint_hashing(self):
|
|
shape_env = ShapeEnv()
|
|
hash(create_symint(shape_env, 3))
|
|
|
|
def test_symnode_hashing(self):
|
|
shape_env = ShapeEnv()
|
|
|
|
# These all trigger specialization when hashed
|
|
hash(create_symbool(shape_env, True))
|
|
# We should be passing in float here, but create_symbol currently
|
|
# only supports int
|
|
hash(create_symfloat(shape_env, 3.0))
|
|
|
|
# NestedInt (SymInt), constant SymBool, SymNode are hashable
|
|
j1 = torch._C._get_nested_int(1, 1)
|
|
j1_copy = torch._C._get_nested_int(1, 1)
|
|
j2 = torch._C._get_nested_int(2, 1)
|
|
t = self.get_constant_bool(True)
|
|
t_copy = self.get_constant_bool(True)
|
|
f = self.get_constant_bool(False)
|
|
n = create_symint(shape_env, 3).node
|
|
m = self.get_constant_bool(True).node
|
|
|
|
self.assertIs(j1 == j1_copy, True)
|
|
self.assertEqual(hash(j1), hash(j1_copy))
|
|
self.assertIs(j1 == j2, False)
|
|
self.assertNotEqual(hash(j1), hash(j2))
|
|
self.assertIs(t == t_copy, True)
|
|
self.assertEqual(hash(t), hash(t_copy))
|
|
self.assertIs(t == f, False)
|
|
self.assertNotEqual(hash(t), hash(f))
|
|
|
|
hash(n)
|
|
hash(m)
|
|
|
|
def test_symint_deepcopy(self):
|
|
shape_env = ShapeEnv()
|
|
|
|
symnodes = (torch._C._get_nested_int(1, 1),)
|
|
deepcopied_symnodes = copy.deepcopy(symnodes)
|
|
self.assertEqual(symnodes, deepcopied_symnodes)
|
|
|
|
def test_non_symbolic_symnode(self):
|
|
j1 = torch._C._get_nested_int(1, 1)
|
|
j2 = torch._C._get_nested_int(1, 1)
|
|
j3 = torch._C._get_nested_int(3, 1)
|
|
|
|
self.assertIsInstance(j1, torch.SymInt)
|
|
self.assertNotIsInstance(j1, int)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "add not supported by NestedIntSymNode"
|
|
):
|
|
j1 + 3
|
|
|
|
self.assertFalse(j1 == 3)
|
|
with self.assertRaisesRegex(RuntimeError, "indeterminate"):
|
|
self.assertFalse(3 >= j2)
|
|
|
|
self.assertIs(j1 == j1, True)
|
|
self.assertIs(j1 == j2, True)
|
|
self.assertIs(j1 == j3, False)
|
|
self.assertIs(j1 != j3, True)
|
|
self.assertIs(j1 != j2, False)
|
|
|
|
x = self.get_constant_bool(True)
|
|
#
|
|
# Unary
|
|
#
|
|
# op(constant SymBool)
|
|
self.assertIs(x.__sym_not__(), False)
|
|
|
|
#
|
|
# Binary
|
|
#
|
|
# op(constant SymBool, bool)
|
|
# op(constant SymBool, constant SymBool)
|
|
# op(bool, constant SymBool)
|
|
self.assertIs(operator.and_(x, True), True)
|
|
self.assertIs(operator.and_(x, x), True)
|
|
self.assertIs(operator.and_(True, x), True)
|
|
|
|
# op(symbolic SymBool, constant Symbool)
|
|
# op(constant SymBool, symbolic Symbool)
|
|
shape_env = ShapeEnv()
|
|
a = create_symint(shape_env, 2)
|
|
b = create_symint(shape_env, 2)
|
|
c = a == b # symbolic SymBool
|
|
d = self.get_constant_bool(True)
|
|
e = operator.and_(c, d)
|
|
f = operator.and_(d, c)
|
|
self.assertTrue(is_symbolic(e))
|
|
self.assertTrue(is_symbolic(f))
|
|
self.assertIs(e.node.guard_bool("", 0), True)
|
|
self.assertIs(f.node.guard_bool("", 0), True)
|
|
|
|
# Comparing sizes
|
|
sz1 = torch.Size([j1, j1, j1])
|
|
sz2 = torch.Size([j1, j1, j1])
|
|
self.assertIs(sz1 == sz2, True)
|
|
|
|
sz1 = torch.Size([3, j1, 4])
|
|
sz2 = torch.Size([3, j2, 4])
|
|
self.assertIs(sz1 == sz2, True)
|
|
self.assertIs(sz1 != sz2, False)
|
|
|
|
def test_stride_symnode(self):
|
|
shape_env = ShapeEnv()
|
|
|
|
# check everything static
|
|
t = create_fake_tensor_with_dynamic_size(
|
|
torch.ones(3, 6),
|
|
shape_env,
|
|
dynamic_sizes=[
|
|
DimDynamic.STATIC,
|
|
DimDynamic.STATIC,
|
|
],
|
|
dynamic_strides=[
|
|
DimDynamic.INFER_STRIDE,
|
|
DimDynamic.INFER_STRIDE,
|
|
],
|
|
)
|
|
self.assertTrue(all(isinstance(size, int) for size in t.size()))
|
|
self.assertTrue(all(isinstance(stride, int) for stride in t.stride()))
|
|
|
|
# check dynamic size but static dims
|
|
t = create_fake_tensor_with_dynamic_size(
|
|
torch.ones(3, 6),
|
|
shape_env,
|
|
dynamic_sizes=[
|
|
DimDynamic.DYNAMIC,
|
|
DimDynamic.DYNAMIC,
|
|
],
|
|
dynamic_strides=[
|
|
DimDynamic.INFER_STRIDE,
|
|
DimDynamic.INFER_STRIDE,
|
|
],
|
|
)
|
|
# Expect stride to be inferred
|
|
s0, s1 = t.size()
|
|
s2, s3 = t.stride()
|
|
self.assertTrue(isinstance(s0, torch.SymInt))
|
|
self.assertTrue(isinstance(s1, torch.SymInt))
|
|
self.assertTrue(isinstance(s2, torch.SymInt))
|
|
self.assertTrue(s1 == s2)
|
|
self.assertEqual(s3, 1)
|
|
|
|
# Check dynamic stride but static dims
|
|
t = create_fake_tensor_with_dynamic_size(
|
|
torch.ones(3, 6),
|
|
shape_env,
|
|
dynamic_sizes=[
|
|
DimDynamic.STATIC,
|
|
DimDynamic.STATIC,
|
|
],
|
|
dynamic_strides=[
|
|
DimDynamic.DYNAMIC,
|
|
DimDynamic.INFER_STRIDE,
|
|
],
|
|
)
|
|
s0, s1 = t.size()
|
|
s2, s3 = t.stride()
|
|
self.assertTrue(isinstance(s0, int))
|
|
self.assertTrue(isinstance(s1, int))
|
|
self.assertTrue(isinstance(s2, torch.SymInt))
|
|
self.assertTrue(isinstance(s3, int))
|
|
|
|
# Check dynamic sizes and dims, and ensure different symbol
|
|
t = create_fake_tensor_with_dynamic_size(
|
|
torch.ones(3, 6),
|
|
shape_env,
|
|
dynamic_sizes=[
|
|
DimDynamic.DYNAMIC,
|
|
DimDynamic.DYNAMIC,
|
|
],
|
|
dynamic_strides=[
|
|
DimDynamic.DYNAMIC,
|
|
DimDynamic.INFER_STRIDE,
|
|
],
|
|
)
|
|
s0, s1 = t.size()
|
|
s2, s3 = t.stride()
|
|
self.assertTrue(isinstance(s0, torch.SymInt))
|
|
self.assertTrue(isinstance(s1, torch.SymInt))
|
|
self.assertTrue(isinstance(s2, torch.SymInt))
|
|
self.assertTrue(isinstance(s3, int))
|
|
self.assertTrue(str(s1.node.expr) != str(s2.node.expr))
|
|
|
|
|
|
instantiate_parametrized_tests(TestSymNumberMagicMethods)
|
|
|
|
|
|
class TestFloorDiv(TestCase):
|
|
@staticmethod
|
|
def python_floordiv(x, y):
|
|
return x // y
|
|
|
|
@staticmethod
|
|
def torch_floordiv(x, y):
|
|
# Note: we fully evaluate here since FloorDiv might not always do
|
|
# that.
|
|
shape_env = ShapeEnv()
|
|
return shape_env.evaluate_expr(FloorDiv(x, y))
|
|
|
|
@staticmethod
|
|
def yield_test_cases(values, negate=True):
|
|
for x, y in values:
|
|
yield (x, y)
|
|
if negate:
|
|
yield (-x, y)
|
|
yield (x, -y)
|
|
yield (-x, -y)
|
|
|
|
def test_floordiv_float_int(self):
|
|
values = ((7, 2),)
|
|
|
|
for x, y in TestFloorDiv.yield_test_cases(values):
|
|
self.assertEqual(
|
|
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
|
|
)
|
|
|
|
def test_floordiv_div_by_one(self):
|
|
values = ((2, 1),)
|
|
|
|
for x, y in TestFloorDiv.yield_test_cases(values):
|
|
self.assertEqual(
|
|
TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y)
|
|
)
|
|
|
|
def test_floordiv_simplify(self):
|
|
# Tests how we simplify or evaluate FloorDiv without free variables
|
|
shape_env = ShapeEnv()
|
|
result = 21
|
|
exprs = (7 * FloorDiv(6, 2),)
|
|
|
|
for expr in exprs:
|
|
self.assertEqual(expr, result)
|
|
self.assertEqual(expr.doit(deep=False), result)
|
|
self.assertEqual(expr.doit(deep=True), result)
|
|
self.assertEqual(sympy.simplify(expr), result)
|
|
self.assertEqual(shape_env.simplify(expr), result)
|
|
self.assertEqual(shape_env.evaluate_expr(expr), result)
|
|
|
|
def test_floordiv_assumptions(self):
|
|
cases = (
|
|
sympy.Symbol("i1", integer=True),
|
|
sympy.Symbol("i2", integer=True),
|
|
)
|
|
|
|
for base, divisor in itertools.product(cases, repeat=2):
|
|
|
|
def op():
|
|
return FloorDiv(base, divisor)
|
|
|
|
def is_complex(x):
|
|
return x.is_integer is False and x.is_real is False and x.is_complex
|
|
|
|
if is_complex(base) or is_complex(divisor):
|
|
self.assertRaisesRegex(
|
|
TypeError,
|
|
(
|
|
r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol',"
|
|
r" expected integer or real"
|
|
),
|
|
op,
|
|
)
|
|
continue
|
|
|
|
op = op()
|
|
|
|
# In regular Python, x//x == 1.0 if x is a float, but FloorDiv
|
|
# always returns an integer 1 when both args are the same object.
|
|
# This even works for Symbols with no assumptions specified.
|
|
if base is divisor:
|
|
self.assertTrue(op.is_integer)
|
|
self.assertTrue(op.is_real)
|
|
elif base.is_integer and divisor.is_integer:
|
|
self.assertTrue(op.is_integer)
|
|
self.assertTrue(op.is_real)
|
|
else:
|
|
self.assertEqual(op.is_integer, None)
|
|
self.assertTrue(op.is_real)
|
|
|
|
|
|
class TestDimConstraints(TestCase):
|
|
@skipIfTorchDynamo("mark_dynamic not supported")
|
|
def test_simplify_max_1_0(self):
|
|
x = torch.rand(10)
|
|
torch._dynamo.mark_dynamic(x, 0, max=20, min=5)
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def func(x, v):
|
|
            # test that statically_known_true can simplify the max() expression below
|
|
if (v == 0 or v == 1) and not statically_known_true(
|
|
max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2)
|
|
):
|
|
raise AssertionError("error")
|
|
|
|
if max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2):
|
|
return x * 400
|
|
else:
|
|
return (x * 10) * 100
|
|
|
|
# testing that this does not throw constraint violation error.
|
|
self.assertEqual(func(x, 1), x * 400)
|
|
self.assertEqual(func(x, 0), x * 400)
|
|
|
|
def test_dim_constraints_reduce_congruences_simple(self):
|
|
from sympy import Symbol
|
|
|
|
s = Symbol("s", positive=True, integer=True)
|
|
dim_constraints = DimConstraints({}, {}, set(), {})
|
|
dim_constraints._congruences[s] = {
|
|
(s / 2) % 2,
|
|
(s / 2) % 8,
|
|
(s / 2) % 4,
|
|
s % 2,
|
|
((s / 16) + 2) % 4,
|
|
}
|
|
congruences = dim_constraints._reduce_congruences()
|
|
self.assertEqual(congruences[s], {(s + 32) % 64})
|
|
|
|
def test_dim_constraints_reduce_inequalities_simple(self):
|
|
from sympy import Eq, Interval, Ne, Symbol
|
|
from sympy.solvers.inequalities import reduce_inequalities
|
|
|
|
s = Symbol("s", positive=True, integer=True)
|
|
exprs = {
|
|
s >= 2,
|
|
Ne(8 * s, 16),
|
|
Ne(s / 2, 1),
|
|
Ne(16 * s, 32),
|
|
s < 16,
|
|
Ne(s, 2),
|
|
s / 2 < 16,
|
|
s / 2 > 1,
|
|
s / 2 >= 2,
|
|
Ne(3 * s / 2, 3),
|
|
}
|
|
solution = reduce_inequalities(exprs, s).as_set()
|
|
self.assertEqual(solution, Interval.Ropen(4, 16))
|
|
|
|
exprs.add(Eq(s / 2, 4))
|
|
solution = reduce_inequalities(exprs, s).as_set()
|
|
self.assertEqual(solution, {8})
|
|
|
|
def test_dim_constraints_reduce_inequalities_error(self):
|
|
from collections import defaultdict
|
|
|
|
from sympy import Symbol
|
|
from sympy.solvers.inequalities import reduce_inequalities
|
|
|
|
from torch._dynamo.source import (
|
|
LocalSource,
|
|
TensorProperty,
|
|
TensorPropertySource,
|
|
)
|
|
from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter
|
|
|
|
s0 = Symbol("s0", positive=True, integer=True)
|
|
exprs = {
|
|
4 * s0**3 - 4 * s0**2 + s0 <= 2147483647,
|
|
s0 >= 2,
|
|
s0**3 <= 2147483647,
|
|
s0 <= 2147483647,
|
|
}
|
|
answer = reduce_inequalities(exprs, s0)
|
|
|
|
symbol_to_source = defaultdict(list)
|
|
symbol_to_source[s0].append(
|
|
TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
)
|
|
dcp = DynamicDimConstraintPrinter(symbol_to_source, {})
|
|
with self.assertRaisesRegex(
|
|
AssertionError,
|
|
"Unknown symbol.*created by constraints solver",
|
|
):
|
|
dcp.doprint(answer)
|
|
|
|
def test_dim_constraints_solve_full(self):
|
|
from sympy import Eq, Integer, Ne, Symbol
|
|
|
|
from torch._dynamo.source import (
|
|
LocalSource,
|
|
TensorProperty,
|
|
TensorPropertySource,
|
|
)
|
|
|
|
src0 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
src2 = TensorPropertySource(
|
|
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
src3 = TensorPropertySource(
|
|
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
src4 = TensorPropertySource(
|
|
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0
|
|
)
|
|
|
|
src1 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2
|
|
)
|
|
src7 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3
|
|
)
|
|
|
|
src5 = TensorPropertySource(
|
|
base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src8 = TensorPropertySource(
|
|
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
|
|
src6 = TensorPropertySource(
|
|
base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src9 = TensorPropertySource(
|
|
base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src10 = TensorPropertySource(
|
|
base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
|
|
src11 = TensorPropertySource(
|
|
base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1
|
|
)
|
|
src12 = TensorPropertySource(
|
|
base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2
|
|
)
|
|
|
|
s0 = Symbol("s0", positive=True, integer=True)
|
|
s1 = Symbol("s1", positive=True, integer=True)
|
|
s5 = Symbol("s5", positive=True, integer=True)
|
|
s6 = Symbol("s6", positive=True, integer=True)
|
|
symbol_to_source = {
|
|
s0: [src0, src2, src3, src4],
|
|
s1: [src1, src7],
|
|
s5: [src5, src8],
|
|
s6: [src6, src9, src10],
|
|
}
|
|
var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21}
|
|
marked_dynamic = {s0, s1, s5, s6}
|
|
dim_constraints = DimConstraints(
|
|
symbol_to_source, var_to_val, marked_dynamic, {}
|
|
)
|
|
dim_constraints.add_equality(src2, s0)
|
|
dim_constraints.add_equality(src3, s0)
|
|
dim_constraints.add_equality(src4, s0)
|
|
dim_constraints.add_equality(src7, s1)
|
|
dim_constraints.add_equality(src8, s5)
|
|
dim_constraints.add_equality(src9, s6)
|
|
dim_constraints.add_equality(src10, s6)
|
|
dim_constraints.add_equality(src11, Integer(1))
|
|
dim_constraints.add_equality(src12, Integer(3))
|
|
|
|
dim_constraints.add(s1**2 <= 2147483647)
|
|
dim_constraints.add(32 * s1**2 <= 2147483647)
|
|
dim_constraints.add(s0 < 16)
|
|
dim_constraints.add(Eq(Mod(s1, 2), 0))
|
|
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
|
|
dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1))
|
|
dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647)
|
|
dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1)
|
|
dim_constraints.add(Ne(FloorDiv(s1, 2), 1))
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
|
|
+ 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
|
|
+ 64
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2))
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
|
|
+ 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
|
|
+ 128
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4))
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
+ 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
+ 256
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
+ 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * s0,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
>= 0
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0))
|
|
dim_constraints.add(
|
|
1
|
|
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * s0,
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2))
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2)
|
|
+ 60 * (Mod(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60
|
|
* (FloorDiv(s0, 2))
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120
|
|
* FloorDiv(s0, 2)
|
|
* FloorDiv(s0, (FloorDiv(s0, 2)))
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
>= 0
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
)
|
|
dim_constraints.add(Ne(16 * s0, 32))
|
|
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
|
|
dim_constraints.add(Ne(16 * s0, 32))
|
|
dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0))
|
|
dim_constraints.add(FloorDiv(s0, 2) >= 2)
|
|
dim_constraints.add(Ne(FloorDiv(s0, 2), 1))
|
|
dim_constraints.add(1 < FloorDiv(s0, 2))
|
|
dim_constraints.add(Ne(s0, 2))
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
>= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
)
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, 2))
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120
|
|
* FloorDiv(s0, 2)
|
|
* FloorDiv(s0, (FloorDiv(s0, 2)))
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60
|
|
* (FloorDiv(s0, 2))
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120
|
|
* FloorDiv(s0, 2)
|
|
* FloorDiv(s0, (FloorDiv(s0, 2)))
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
|
|
3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20,
|
|
20,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20
|
|
* (
|
|
Mod(
|
|
1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
)
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
20
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
* (
|
|
Mod(
|
|
1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
)
|
|
)
|
|
- 20
|
|
* Mod(
|
|
1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1))
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
>= 1
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 0
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 1
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
>= 2
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 20
|
|
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60,
|
|
60,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8))
|
|
* (
|
|
Mod(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
1,
|
|
)
|
|
)
|
|
- Mod(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
- 2
|
|
* FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
/ (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1)
|
|
+ 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1),
|
|
1,
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(8 * s0, 16))
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
>= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
)
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(FloorDiv(s0, 2) < 16)
|
|
dim_constraints.add(FloorDiv(s0, 2) > 1)
|
|
dim_constraints.add(
|
|
Ne(
|
|
90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
60
|
|
* (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
* (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, (FloorDiv(s0, 2))))
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2)),
|
|
3 * (FloorDiv(s0, 2)),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60 * (FloorDiv(s0, 2))
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
6,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
<= 20480
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 120
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 90
|
|
<= 20480
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 60
|
|
<= 20480
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 240,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Eq(6 * s5, 132))
|
|
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
|
|
dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4))
|
|
dim_constraints.add(
|
|
Ne(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64
|
|
)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 64
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 62
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 62 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 62
|
|
)
|
|
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
|
|
dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3))
|
|
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
|
|
dim_constraints.add(Eq(4, FloorDiv(s0, 2)))
|
|
dim_constraints.add(Eq(FloorDiv(s0, 2), 4))
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0)
|
|
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1)
|
|
dim_constraints.add(
|
|
Ne(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9,
|
|
1,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9,
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
>= 0
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0))
|
|
dim_constraints.add(
|
|
1
|
|
< 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576
|
|
)
|
|
dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1))
|
|
dim_constraints.add(
|
|
Ne(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576 * (FloorDiv(s0, 2)),
|
|
256,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
64
|
|
* (
|
|
Mod(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2)),
|
|
4,
|
|
)
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
FloorDiv(s0, 2),
|
|
FloorDiv(
|
|
(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2))
|
|
),
|
|
4,
|
|
),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
FloorDiv(
|
|
(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2))
|
|
),
|
|
4,
|
|
),
|
|
FloorDiv(s0, 2),
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0)
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
64
|
|
* (
|
|
Mod(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 1,
|
|
4,
|
|
)
|
|
),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576 * (FloorDiv(s0, 2))
|
|
> 0
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
>= 1
|
|
)
|
|
dim_constraints.add(
|
|
Eq(
|
|
64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 576,
|
|
256,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
<= 2147483647
|
|
)
|
|
dim_constraints.add(
|
|
Ne(
|
|
(FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9 * (FloorDiv(s0, 2)),
|
|
0,
|
|
)
|
|
)
|
|
dim_constraints.add(
|
|
1
|
|
< (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
)
|
|
dim_constraints.add(
|
|
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 9
|
|
> 1
|
|
)
|
|
dim_constraints.add(
|
|
60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2
|
|
- 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8)
|
|
+ 540
|
|
> 1
|
|
)
|
|
dim_constraints.add(s0 >= 2)
|
|
dim_constraints.add(s1 >= 2)
|
|
dim_constraints.add(s6 >= 2)
|
|
dim_constraints.add(s5 >= 2)
|
|
|
|
dim_constraints.solve()
|
|
self.assertEqual(
|
|
dim_constraints._static_results,
|
|
{
|
|
"L['c'].size()[0] == 8",
|
|
"L['d'].size()[0] == 8",
|
|
"L['a'].size()[2] == 96",
|
|
"L['f'].size()[1] == 1",
|
|
"L['a'].size()[3] == 96",
|
|
"L['b'].size()[2] == 3",
|
|
"L['b'].size()[1] == 22",
|
|
"L['b'].size()[0] == 8",
|
|
"L['a'].size()[1] == 22",
|
|
"L['a'].size()[0] == 8",
|
|
},
|
|
)
|
|
self.assertEqual(
|
|
dim_constraints._dynamic_results,
|
|
{
|
|
"2 <= L['c'].size()[1]",
|
|
"L['d'].size()[1] == L['c'].size()[1]",
|
|
"L['e'].size()[1] == L['c'].size()[1]",
|
|
},
|
|
)
|
|
|
|
|
|
class TestGuardsExpressions(TestCase):
|
|
"""
|
|
Tests the guards-related methods used by the inductor FX graph cache.
|
|
"""
|
|
|
|
def test_guards_gt_lt(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 6)
|
|
s1 = create_symint(shape_env, 7)
|
|
s2 = create_symint(shape_env, 5)
|
|
|
|
guard_int(sym_int(s0 > 5))
|
|
guard_int(sym_int(s0 < 7))
|
|
|
|
guards = shape_env.produce_guards_expression([s0])
|
|
|
|
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
|
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
|
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)]))
|
|
|
|
def test_guards_float_print(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 3)
|
|
guard_bool(2 / s0 == 2 / 3)
|
|
guards = shape_env.produce_guards_expression([s0])
|
|
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
|
|
|
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_guard_or_true(self):
|
|
from torch.fx.experimental.symbolic_shapes import guard_or_true
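        # guard_or_true(expr): evaluate expr if it can be decided without a
        # data-dependent guard; when it cannot (e.g. unbacked symbols), fall back
        # to True instead of raising. This behavior is exercised below.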
|
|
|
|
def func(a, b):
|
|
x = a.item()
|
|
if guard_or_true(x == 1):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
# eager.
|
|
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
|
|
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
|
|
|
|
# compile with unbacked.
|
|
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func(a, b)
|
|
|
|
# always return b*10
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
|
|
# Test that statically known true works.
|
|
def func2(a, b):
|
|
x = a.item()
|
|
if guard_or_true(x != x):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func2(a, b)
|
|
# always return b*20
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
|
|
# Test backed_size_oblivious
|
|
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
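            # With backed_size_oblivious=True, backed (hinted) sizes are treated
            # size-obliviously as well, so guard_or_true falls back to True here
            # even though the hint says a.size()[0] == 9.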
|
|
|
|
def func3(a, b):
|
|
if guard_or_true(a.size()[0] != 9):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
|
|
a = torch.rand(9, 2)
|
|
b = torch.rand(3, 4)
|
|
|
|
self.assertEqual(func3(a, b), b * 20)
|
|
self.assertEqual(compiled(a, b), b * 10)
|
|
|
|
@skipIfTorchDynamo("Not a TorchDynamo suitable test")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_guard_or_false(self):
|
|
from torch.fx.experimental.symbolic_shapes import guard_or_false
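        # guard_or_false(expr): like guard_or_true, but falls back to False when
        # the expression cannot be decided without a data-dependent guard.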
|
|
|
|
def func(a, b):
|
|
x = a.item()
|
|
if guard_or_false(x == 1):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
# eager.
|
|
self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]))
|
|
self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]))
|
|
|
|
# compile with unbacked.
|
|
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func(a, b)
|
|
|
|
# always return b*20
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])
|
|
)
|
|
|
|
# Test that statically known true works.
|
|
def func2(a, b):
|
|
x = a.item()
|
|
if guard_or_false(x == x):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True)
|
|
a = torch.tensor([1])
|
|
b = torch.tensor([1])
|
|
unbacked_func2(a, b)
|
|
# always return b*10
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
self.assertEqual(
|
|
unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10])
|
|
)
|
|
|
|
# Test backed_size_oblivious
|
|
with torch.fx.experimental._config.patch("backed_size_oblivious", True):
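            # Same as the guard_or_true test above: under backed_size_oblivious the
            # comparison is undecidable, guard_or_false falls back to False, and the
            # compiled path returns b * 20.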
|
|
|
|
def func3(a, b):
|
|
if guard_or_false(a.size()[0] == 9):
|
|
return b * 10
|
|
else:
|
|
return b * 20
|
|
|
|
compiled = torch.compile(func3, dynamic=True, fullgraph=True)
|
|
a = torch.rand(9, 2)
|
|
b = torch.rand(3, 4)
|
|
|
|
self.assertEqual(func3(a, b), b * 10)
|
|
self.assertEqual(compiled(a, b), b * 20)
|
|
|
|
def test_guards_float_div(self):
|
|
shape_env = ShapeEnv()
|
|
s0 = create_symint(shape_env, 8)
|
|
s1 = create_symint(shape_env, 7)
|
|
|
|
guard_int(sym_int(s0 / 2.0))
|
|
guards = shape_env.produce_guards_expression([s0])
|
|
|
|
self.assertIn("math.trunc(", guards)
|
|
self.assertIn("float(", guards)
|
|
self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)]))
|
|
self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)]))
|
|
|
|
def test_remove_symbols_without_guarding(self):
|
|
from torch._functorch.partitioners import _remove_symbols_without_guarding
|
|
|
|
shape_env = ShapeEnv()
|
|
|
|
x = create_fake_tensor_with_dynamic_size(
|
|
torch.randn(5, 8),
|
|
shape_env,
|
|
dynamic_sizes=[
|
|
DimDynamic.DYNAMIC,
|
|
DimDynamic.DYNAMIC,
|
|
],
|
|
dynamic_strides=[
|
|
DimDynamic.INFER_STRIDE,
|
|
DimDynamic.INFER_STRIDE,
|
|
],
|
|
)
|
|
|
|
self.assertEqual(f"{x.stride()}", "(s49, 1)")
|
|
self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])")
|
|
|
|
x_clean = _remove_symbols_without_guarding(x, 4096)
|
|
|
|
self.assertEqual(f"{x_clean.stride()}", "(8, 1)")
|
|
self.assertEqual(f"{x_clean.shape}", "torch.Size([5, 8])")
|
|
|
|
|
|
def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph:
|
|
for node in graph.nodes:
|
|
if node.name == "arg3_1":
|
|
assert node.meta["val"].size()[0] == 2
|
|
return graph
|
|
|
|
|
|
class TestUnbacked(TestCase):
|
|
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
@parametrize("backend", ["inductor", "eager"])
|
|
def test_deferred_neq_assert(self, backend):
|
|
@torch.compile(fullgraph=True, backend=backend)
|
|
def func(a):
|
|
torch._check(a.item() != 5)
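            # a.item() is unbacked, so this check cannot be resolved at trace time;
            # it is deferred to a runtime assert (hence the RuntimeError below).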
return a.item() * 10
|
|
|
|
func(torch.tensor([100]))
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
func(torch.tensor([5]))
|
|
|
|
    # Test a situation where we generate a runtime assert, i.e. u1 == s1, and then
    # specialize s1 to a constant later on.
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
@parametrize("backend", ["inductor", "eager"])
|
|
def test_post_specialize_runtime_assert1(self, backend):
|
|
@torch.compile(dynamic=True, backend=backend)
|
|
def func(x, y):
|
|
u0 = y.item()
|
|
s0 = x.size()[0]
|
|
s1 = x.size()[1]
|
|
torch._check(u0 + s0 + s1 == 102)
|
|
assert s0 == 2
|
|
return x * 10
|
|
|
|
func(torch.rand(2, 50), torch.tensor([50]))
|
|
with self.assertRaises(RuntimeError):
|
|
func(torch.rand(2, 50), torch.tensor([51]))
|
|
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
@torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass)
|
|
@parametrize("backend", ["inductor", "eager"])
|
|
def test_post_specialize_runtime_assert2(self, backend):
|
|
@torch.compile(dynamic=True, backend=backend)
|
|
def func(x, y):
|
|
u0 = y.item()
|
|
s0 = x.size()[0]
|
|
s1 = x.size()[1]
|
|
torch._check(u0 + s0 + s1 == 102)
|
|
return x * 10
|
|
|
|
func(torch.rand(2, 50), torch.tensor([50]))
|
|
with self.assertRaises(RuntimeError):
|
|
func(torch.rand(2, 50), torch.tensor([51]))
|
|
|
|
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
@parametrize("backend", ["inductor", "eager"])
|
|
def test_deferred_sym_or_assert(self, backend):
|
|
@torch.compile(fullgraph=True, backend=backend)
|
|
def func(a, b):
|
|
torch._check(operator.or_(a.item() == 5, b.item() == 5))
|
|
return a.item() * 10
|
|
|
|
func(torch.tensor([5]), torch.tensor([100]))
|
|
func(torch.tensor([100]), torch.tensor([5]))
|
|
|
|
def test_has_free_symbols(self):
|
|
self.assertFalse(has_free_symbols(sympy.S.true))
|
|
self.assertFalse(has_free_symbols(sympy.Max(1, 10, evaluate=False)))
|
|
|
|
self.assertFalse(has_free_symbols(sympy.sympify("1")))
|
|
self.assertFalse(has_free_symbols(sympy.sympify("1.1")))
|
|
self.assertTrue(has_free_symbols(sympy.sympify("a")))
|
|
self.assertTrue(has_free_symbols(sympy.sympify("a*2")))
|
|
self.assertTrue(has_free_symbols(sympy.sympify("a+b")))
|
|
|
|
@skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
@parametrize("backend", ["inductor", "eager"])
|
|
def test_deferred_sym_eq_assert(self, backend):
|
|
@torch.compile(fullgraph=True, backend=backend)
|
|
def func(a, b):
|
|
torch._check(b.item() == 5)
|
|
return a * 10
|
|
|
|
func(torch.tensor([5]), torch.tensor([5]))
|
|
with self.assertRaises(RuntimeError):
|
|
func(torch.tensor([100]), torch.tensor([1]))
|
|
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
@parametrize("backend", ["inductor", "eager"])
|
|
@skipIfTorchDynamo("mark_unbacked is not traceable")
|
|
def test_deferred_with_unbacked_input(self, backend):
|
|
@torch.compile(fullgraph=True, dynamic=True, backend=backend)
|
|
def func(a, b):
|
|
torch._check(a.size()[0] == b.size()[0])
|
|
return a * 10
|
|
|
|
a = torch.rand(1, 1)
|
|
b = torch.rand(1, 1)
|
|
torch._dynamo.decorators.mark_unbacked(a, 0)
|
|
torch._dynamo.decorators.mark_unbacked(b, 0)
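        # mark_unbacked makes dim 0 of both inputs an unbacked SymInt (no hint),
        # so the equality check inside func must become a deferred runtime assert.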
func(a, b)
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
func(a, torch.rand(2, 1))
|
|
|
|
|
|
class TestUbackedOps(TestCase):
|
|
@fresh_cache()
|
|
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_unbacked_reshape1(self):
|
|
cnt = CompileCounterWithBackend("inductor")
|
|
|
|
        # The reshape happens in place (no clone):
        # reshape u1 -> (u0, u0)
|
|
def func(x, y):
|
|
f = y.item()
|
|
t1 = x.view((f, f))
|
|
t2 = x.reshape((f, f))
|
|
# TODO avoid _check_is_size here.
|
|
torch._check_is_size(f)
|
|
return t1 * 10, t2 * 10
|
|
|
|
compiled_func = torch.compile(
|
|
fullgraph=True,
|
|
backend=cnt,
|
|
dynamic=True,
|
|
)(func)
|
|
|
|
        # create a non-contiguous tensor whose data is the even numbers in [0, 2*cnt)
        # and reshape it into sqrt(cnt) x sqrt(cnt)
|
|
def make_non_contiguous_tensor_and_test(cnt):
|
|
            # create a non-contiguous tensor x that skips odd indices.
|
|
x = torch.arange(cnt * 2)
|
|
x = x.as_strided((x.size()[0] // 2,), (2,))
|
|
|
|
torch._dynamo.decorators.mark_unbacked(x, 0)
|
|
sz = torch.tensor([int(math.sqrt(cnt))])
|
|
compiled_result = compiled_func(x, sz)
|
|
eager_result = func(x, sz)
|
|
self.assertEqual(compiled_result, eager_result)
|
|
|
|
log_stream, ctx = logs_to_string(
|
|
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
|
)
|
|
with ctx():
|
|
make_non_contiguous_tensor_and_test(4)
|
|
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
|
self.assertExpectedInline(
|
|
aot_graphs,
|
|
"""\
|
|
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
|
|
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
|
|
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
|
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
|
|
ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
|
|
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
|
|
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
|
|
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
|
|
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
|
|
view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense])
|
|
view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
|
|
mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
|
mul_12: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
|
|
return (mul_9, mul_12)""", # noqa: B950
|
|
ignore_comments=True,
|
|
ignore_empty_lines=True,
|
|
)
|
|
|
|
make_non_contiguous_tensor_and_test(49)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
        # Pass in a contiguous tensor; it will recompile due to the stride being 1 (0/1 specialization).
        # Marking strides unbacked would have avoided the recompilation here.
|
|
x = torch.arange(100)
|
|
torch._dynamo.decorators.mark_unbacked(x, 0)
|
|
|
|
log_stream, ctx = logs_to_string(
|
|
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
|
)
|
|
with ctx():
|
|
compiled_result = compiled_func(x, torch.tensor([10]))
|
|
eager_result = func(x, torch.tensor([10]))
|
|
self.assertEqual(compiled_result, eager_result)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
|
self.assertExpectedInline(
|
|
aot_graphs,
|
|
"""\
|
|
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
|
|
ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
|
|
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
|
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
|
|
ge_3: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
|
|
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
|
|
pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
|
|
eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
|
|
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
|
|
view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense])
|
|
view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
|
|
mul_4: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
|
mul_7: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None
|
|
return (mul_4, mul_7)""", # noqa: B950
|
|
ignore_comments=True,
|
|
ignore_empty_lines=True,
|
|
)
|
|
|
|
x = torch.arange(25)
|
|
compiled_result = compiled_func(x, torch.tensor([5]))
|
|
eager_result = func(x, torch.tensor([5]))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_unbacked_reshape2(self):
|
|
cnt = CompileCounterWithBackend("inductor")
|
|
|
|
        # This reshape requires a clone when the input is not contiguous and we can't compute strides.
        # reshape (u2, u3) -> (u0, u1)
|
|
def func(x, y):
|
|
u0, u1 = y.tolist()
|
|
torch._check_is_size(u0)
|
|
torch._check_is_size(u1)
|
|
|
|
result1 = torch.reshape(x, (u0, u1))
|
|
return result1 * 10
|
|
|
|
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
|
|
|
|
x = torch.randn(10, 10)
|
|
# make x not contiguous.
|
|
x = x.t_()
|
|
torch._dynamo.decorators.mark_unbacked(x, 0)
|
|
torch._dynamo.decorators.mark_unbacked(x, 1)
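        # Both sizes of the non-contiguous input are unbacked, so contiguity of x
        # cannot be established symbolically and the reshape is expected to clone
        # (see the clone node in the aot graph below).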
|
|
|
|
log_stream, ctx = logs_to_string(
|
|
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
|
)
|
|
with ctx():
|
|
result_eager = func(x, torch.tensor([5, 20]))
|
|
result_compiled = compiled_func(x, torch.tensor([5, 20]))
|
|
self.assertEqual(result_compiled, result_eager)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
|
self.assertExpectedInline(
|
|
aot_graphs,
|
|
"""\
|
|
def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"):
|
|
ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0
|
|
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
|
ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0
|
|
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
|
|
select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0)
|
|
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None
|
|
ge_5: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
|
|
_assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_5 = _assert_scalar_2 = None
|
|
select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None
|
|
_local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None
|
|
ge_7: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0
|
|
_assert_scalar_3 = torch.ops.aten._assert_scalar.default(ge_7, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_7 = _assert_scalar_3 = None
|
|
mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None
|
|
mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1
|
|
eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None
|
|
_assert_scalar_4 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_4 = None
|
|
clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None
|
|
view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None
|
|
mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None
|
|
return (mul_21,)""", # noqa: B950
|
|
ignore_comments=True,
|
|
ignore_empty_lines=True,
|
|
)
|
|
|
|
result_eager = func(x, torch.tensor([2, 50]))
|
|
result_compiled = compiled_func(x, torch.tensor([2, 50]))
|
|
self.assertEqual(result_compiled, result_eager)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
x = torch.randn(4, 4).t_()
|
|
result_eager = func(x, torch.tensor([2, 8]))
|
|
result_compiled = compiled_func(x, torch.tensor([2, 8]))
|
|
self.assertEqual(result_compiled, result_eager)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
        # Pass a contiguous tensor. A recompilation will happen due to 0/1 specialization on the stride.
|
|
log_stream, ctx = logs_to_string(
|
|
"torch._functorch._aot_autograd.dispatch_and_compile_graph", "aot_graphs"
|
|
)
|
|
with ctx():
|
|
            # This used to hit "could guard on data-dependent expression Eq(10, u3)" since
            # x.stride()[0] == 10 and x.size() == [u2, u3], but not anymore since we use
            # contiguous_or_false.
            # We need a way to mark strides unbacked to avoid the recompilation here.
|
|
x = torch.randn(10, 10)
|
|
torch._dynamo.decorators.mark_unbacked(x, 0)
|
|
torch._dynamo.decorators.mark_unbacked(x, 1)
|
|
|
|
aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
|
self.assertExpectedInline(
|
|
aot_graphs,
|
|
"""""", # noqa: B950
|
|
ignore_comments=True,
|
|
ignore_empty_lines=True,
|
|
)
|
|
|
|
result_compiled = compiled_func(x, torch.tensor([2, 50]))
|
|
result_eager = func(x, torch.tensor([2, 50]))
|
|
|
|
self.assertEqual(result_compiled, result_eager)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
x = torch.randn(4, 4)
|
|
|
|
result_eager = func(x, torch.tensor([2, 8]))
|
|
result_compiled = compiled_func(x, torch.tensor([2, 8]))
|
|
self.assertEqual(result_compiled, result_eager)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
@unittest.skip("this test fails due to inductor/autograd issue #153041")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_unbacked_non_contigious_reshape_failing(self):
|
|
        # reshape u1 -> (u0, u0)
        # this results in the tensor "i64[u0, u0][s7*u0, s7]".
        # The reshape happens in place (no clone).
|
|
def func(x, y):
|
|
f = y.item()
|
|
t1 = x.view((f, f))
|
|
t2 = x.reshape((f, f))
|
|
return t1, t2
|
|
|
|
        # create a non-contiguous tensor whose data is the even numbers in [0, 2*cnt)
|
|
def make_non_contiguous_tensor(cnt):
|
|
            # create a non-contiguous tensor x that skips odd indices.
|
|
x = torch.arange(cnt * 2)
|
|
x = x.as_strided((x.size()[0] // 2,), (2,))
|
|
return x
|
|
|
|
x = make_non_contiguous_tensor(4)
|
|
torch._dynamo.decorators.mark_unbacked(x, 0)
|
|
compiled_func = torch.compile(
|
|
fullgraph=True,
|
|
backend="inductor",
|
|
)(func)
|
|
|
|
compiled_result = compiled_func(x, torch.tensor([2]))
|
|
eager_result = func(x, torch.tensor([2]))
|
|
self.assertEqual(compiled_result, eager_result)
|
|
|
|
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
|
@torch._dynamo.config.patch("capture_scalar_outputs", True)
|
|
def test_invalid_view_unbacked_view(self):
|
|
cnt = CompileCounterWithBackend("inductor")
|
|
|
|
        # This view (u2, u3) -> (u0, u1) can't happen in general unless we know that the
        # input is contiguous or we have hints to compute strides.
|
|
def func(x, y):
|
|
u0, u1 = y.tolist()
|
|
torch._check_is_size(u0)
|
|
torch._check_is_size(u1)
|
|
|
|
result2 = x.view(u0, u1) * 10
|
|
return result2
|
|
|
|
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
|
|
|
|
x = torch.randn(10, 10)
|
|
# make x not contiguous.
|
|
x = x.t_()
|
|
torch._dynamo.decorators.mark_unbacked(x, 0)
|
|
torch._dynamo.decorators.mark_unbacked(x, 1)
|
|
with self.assertRaises(torch._dynamo.exc.UserError):
|
|
            # throws a data-dependent error.
|
|
compiled_func(x, torch.tensor([5, 20]))
|
|
|
|
@skipIfTorchDynamo()
|
|
def test_unbind_not_dynamic(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(fullgraph=True, dynamic=True, backend=cnt)
|
|
def func(y):
|
|
return y.unbind(dim=2), y * 10
|
|
|
|
func(torch.ones(5, 6, 7, 8))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
        # it can be dynamic in all dimensions except dim=2
|
|
func(torch.ones(4, 9, 7, 10))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
func(torch.ones(5, 6, 8, 8))
|
|
func(torch.ones(5, 6, 9, 8))
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
@skipIfTorchDynamo("not allowed to trace mark_unbacked")
|
|
@fresh_cache()
|
|
def test_unbacked_contiguous(self):
|
|
cnt = CompileCounterWithBackend("inductor")
|
|
|
|
def func(x):
|
|
contig = x.contiguous()
|
|
return (contig + 1) * 100
|
|
|
|
compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func)
|
|
|
|
x = torch.randn(10, 10)
|
|
# make x not contiguous.
|
|
x = x.t_()
|
|
torch._dynamo.decorators.mark_unbacked(x, 0)
|
|
torch._dynamo.decorators.mark_unbacked(x, 1)
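        # With both sizes unbacked, contiguity of x cannot be established
        # symbolically, so contiguous() clones here (see the clone node in the
        # post-grad graph below).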
|
|
log_stream, ctx = logs_to_string(
|
|
"torch._inductor.compile_fx", "post_grad_graphs"
|
|
)
|
|
with ctx():
|
|
compiled_func(x)
|
|
self.assertEqual(compiled_func(x), func(x))
|
|
y = torch.rand(20, 20).t()
|
|
self.assertEqual(compiled_func(y), func(y))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
|
self.assertExpectedInline(
|
|
output,
|
|
"""\
|
|
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
|
|
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
|
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
|
|
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
|
|
clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None
|
|
add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None
|
|
mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None
|
|
return (mul_6,)""", # noqa: B950
|
|
ignore_comments=True,
|
|
ignore_empty_lines=True,
|
|
)
|
|
|
|
log_stream, ctx = logs_to_string(
|
|
"torch._inductor.compile_fx", "post_grad_graphs"
|
|
)
|
|
with ctx():
|
|
# recompilation will happen due to stride specialization.
|
|
y = torch.rand(20, 20)
|
|
torch._dynamo.decorators.mark_unbacked(y, 0)
|
|
torch._dynamo.decorators.mark_unbacked(y, 1)
|
|
self.assertEqual(compiled_func(y), func(y))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
|
|
|
|
# No clone this time since input is contiguous.
|
|
self.assertExpectedInline(
|
|
output,
|
|
"""\
|
|
ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None
|
|
_assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
|
|
ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None
|
|
_assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None
|
|
add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None
|
|
mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None
|
|
return (mul_5,)""", # noqa: B950
|
|
ignore_comments=True,
|
|
ignore_empty_lines=True,
|
|
)
|
|
|
|
|
|
instantiate_parametrized_tests(TestUnbacked)


if __name__ == "__main__":
    run_tests()