# Owner(s): ["oncall: jit"] # ruff: noqa: F841 import contextlib import copy import itertools import math import operator import unittest import numpy as np import pytest import sympy import torch import torch.fx import torch.nn.functional as F from torch import sym_int, SymBool, SymFloat, SymInt from torch._C import _disabled_torch_function_impl from torch._dynamo.testing import CompileCounter, CompileCounterWithBackend from torch._inductor.utils import fresh_cache from torch.fx.experimental import sym_node from torch.fx.experimental.proxy_tensor import make_fx from torch.fx.experimental.sym_node import method_to_operator, SymNode, to_node from torch.fx.experimental.symbolic_shapes import ( _constrain_range_for_size, DimConstraints, DimDynamic, expect_true, guard_bool, guard_float, guard_int, GuardOnDataDependentSymNode, has_free_symbols, hint_int, is_symbolic, ShapeEnv, StatelessSymbolicContext, statically_known_false, statically_known_true, ) from torch.testing._internal.common_dtype import all_types_and from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, run_tests, skipIfTorchDynamo, TestCase, ) from torch.testing._internal.logging_utils import logs_to_string from torch.utils import _pytree as pytree from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._sympy.functions import ( CleanDiv, FloorDiv, IsNonOverlappingAndDenseIndicator, Mod, ) aten = torch.ops.aten meta_funcs = {} def register_meta(op): def decorator(f): def add_func(op): meta_funcs[op] = f pytree.tree_map_(add_func, op) return f return decorator @register_meta([aten.add.Tensor, aten.sub.Tensor]) def binary_meta(a, b): return a.new_empty(a.shape) @register_meta(aten.cat.default) def cat_meta(tensors, dim=0): concat_length = 0 shape = tensors[0].shape for tensor in tensors: for idx, (common_length, length) in enumerate(zip(shape, tensor.shape)): if idx == dim: concat_length = concat_length + length else: assert length == common_length new_shape = list(shape) new_shape[dim] = concat_length return tensors[0].new_empty(new_shape) @register_meta([aten.narrow_copy.default]) def narrow_copy_symint_meta(a, dim, start, length, **kwargs): shape = [] for i, x in enumerate(a.shape): if i == dim: shape.append(length) else: shape.append(x) return a.new_empty(tuple(shape)) @register_meta([aten.expand.default]) def expand_symint_meta(a, size, implicit=False): return a.new_empty(size) def create_contiguous(shape): strides = [1] for dim in reversed(shape[:-1]): strides.append(dim * strides[-1]) return list(reversed(strides)) class FakeSymbolicTensor(torch.Tensor): @staticmethod def __new__( cls, sym_shape, sym_strides, dtype, layout, requires_grad, device, storage_offset=0, ): # TODO: this is wrong in general sym_stride = create_contiguous(sym_shape) r = torch.Tensor._make_wrapper_subclass( cls, sym_shape, sym_stride, storage_offset, dtype=dtype, layout=layout, requires_grad=requires_grad, device=device, ) return r __torch_function__ = _disabled_torch_function_impl def new_empty(self, shape): return FakeSymbolicTensor( shape, None, self.dtype, self.layout, self.requires_grad, self.device ) @classmethod def __torch_dispatch__(cls, func_overload, types, args=(), kwargs=None): if func_overload in meta_funcs: return meta_funcs[func_overload](*args, **kwargs) if func_overload == torch.ops.aten.new_empty.default: self = args[0] shape = args[1] return FakeSymbolicTensor( shape, self.stride(), self.dtype, self.layout, self.requires_grad, self.device, ) raise RuntimeError(f"operator 
{func_overload} not supported") def create_symbolic_tensor(name, arg, shape_env, source=None, dynamic_dims=None): from torch._dynamo.source import ConstantSource if source is None: source = ConstantSource(name) constraint_dims = [None] * arg.dim() if dynamic_dims is None: dynamic_dims = [DimDynamic.DUCK] * arg.dim() ( sym_shapes, sym_strides, sym_storage_offset, ) = shape_env.create_symbolic_sizes_strides_storage_offset( arg, source=source, symbolic_context=StatelessSymbolicContext( dynamic_sizes=dynamic_dims, constraint_sizes=constraint_dims ), ) return FakeSymbolicTensor( sym_shapes, sym_strides, arg.dtype, arg.layout, arg.requires_grad, arg.device, sym_storage_offset, ) def create_fake_tensor_with_dynamic_size(x, shape_env, dynamic_sizes, dynamic_strides): from torch._subclasses.fake_tensor import FakeTensorMode with FakeTensorMode(shape_env=shape_env) as fake_mode: return fake_mode.from_tensor( x, symbolic_context=StatelessSymbolicContext( dynamic_sizes=dynamic_sizes, dynamic_strides=dynamic_strides, ), ) def create_symtype(cls, pytype, shape_env, val, duck=True, **kwargs): from torch._dynamo.source import ConstantSource symbol = shape_env.create_symbol( val, source=ConstantSource(f"__testing_only{len(shape_env.var_to_val)}"), dynamic_dim=DimDynamic.DUCK if duck else DimDynamic.DYNAMIC, constraint_dim=None, **kwargs, ) return cls(SymNode(symbol, shape_env, pytype, hint=val)) # TODO: default duck to False def create_symint(shape_env, i: int, duck=True, **kwargs) -> SymInt: return create_symtype(SymInt, int, shape_env, i, duck=duck, **kwargs) def create_symbool(shape_env, b: bool) -> SymBool: return create_symtype(SymBool, bool, shape_env, b) def create_symfloat(shape_env, f: float) -> SymFloat: return create_symtype(SymFloat, float, shape_env, f) @skipIfTorchDynamo( "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" ) class TestPySymInt(TestCase): def test_arith_ops(self): shape_env = ShapeEnv() symints = [] for i in range(2, 5): symints.append((i, create_symint(shape_env, i))) ops = [ operator.add, operator.sub, operator.floordiv, operator.mul, operator.mod, ] for op in ops: for args in itertools.permutations(symints, 2): if not isinstance(args[0][1], int) and ( (op != operator.mod or op != operator.floordiv) and args[1][0] != 0 ): self.assertTrue( op(args[0][1], args[1][1]) == op(args[0][0], args[1][0]) ) def test_reverse_arith_ops(self): shape_env = ShapeEnv() a = create_symint(shape_env, 2) self.assertTrue(5 // a == 5 // 2) a = create_symint(shape_env, 2) self.assertTrue(5 * a == 5 * 2) def test_sympify_symint(self): shape_env = ShapeEnv() a = create_symint(shape_env, 2) self.assertIs(sympy.sympify(a), a.node.expr) b = create_symfloat(shape_env, 3.0) self.assertIs(sympy.sympify(b), b.node.expr) c = create_symbool(shape_env, True) self.assertIs(sympy.sympify(c), c.node.expr) def test_roundtrip(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) self.assertTrue(not isinstance(x.shape[0], SymNode)) self.assertTrue(isinstance(x.shape[0], SymInt)) self.assertTrue(x.shape[0] == 5) self.assertTrue(x.shape[1] == 4) self.assertTrue(x.shape[2], 3) self.assertTrue(x.size()[0], 5) self.assertTrue(x.size()[1], 4) # Should be simplifiable to an integer. 
# Ref: https://github.com/pytorch/pytorch/pull/107492 self.assertTrue(isinstance(x.size()[1], SymInt)) self.assertTrue( isinstance(x.size()[1].node.maybe_as_int(), int) ) # due to guard above self.assertTrue(x.size()[2] == 3) self.assertTrue(x.size(0) == 5) self.assertTrue(x.size(1) == 4) self.assertTrue(x.size(2) == 3) self.assertTrue(isinstance(x.size(2), SymInt)) self.assertTrue(isinstance(x.size(2).node.maybe_as_int(), int)) y = create_symbolic_tensor("y", torch.randn(5, 4, 3)[1:], shape_env) self.assertTrue(isinstance(y.storage_offset(), SymInt)) self.assertTrue(y.storage_offset() == 12) def test_binary(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) y = create_symbolic_tensor("y", torch.randn(5, 4, 3), shape_env) z = x + y self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) # broadcasting y = create_symbolic_tensor("y2", torch.randn(1, 4, 1), shape_env) z = x + y self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) def test_symint_args(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) y = create_symbolic_tensor("y", torch.randn(5, 4, 1), shape_env) LAST_DIM = 2 z = x.narrow_copy(LAST_DIM, 0, y.shape[LAST_DIM]) self.assertTrue(z.shape[2] == y.shape[2]) # arithmetic expr with two symints z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - y.shape[LAST_DIM]) self.assertTrue(z.shape[2] == 2) # arithmetic expr with a symint and python int z = x.narrow_copy(LAST_DIM, 0, x.shape[LAST_DIM] - 1) self.assertTrue(z.shape[2] == 2) def test_symint_vargs(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 4, 3), shape_env) y = create_symbolic_tensor("y", torch.randn(1, 4, 1), shape_env) # varargs z = y.expand(x.shape[0], y.shape[1], x.shape[2]) self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) # shape list z = y.expand((x.shape[0], y.shape[1], x.shape[2])) self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) # mixed python symints and ints z = y.expand(x.shape[0], y.shape[1], 3) self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) # mixed python symints and ints in a list z = y.expand((x.shape[0], y.shape[1], 3)) self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) # mixed python symints and ints z = y.expand(5, y.shape[1], x.shape[2]) self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) # mixed python ints and symints in a list z = y.expand((5, y.shape[1], x.shape[2])) self.assertTrue(z.shape[0] == 5) self.assertTrue(z.shape[1] == 4) self.assertTrue(z.shape[2] == 3) z = y.expand((y.shape[1],)) z = y.expand(y.shape[1]) def test_symint_bitwise_and(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 0b1100) b0 = create_symint(shape_env, 0b1010) res_and = a0 & b0 self.assertEqual(res_and, 0b1000) self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s97, s26), 8)""" ) a1 = create_symint(shape_env, 3) b1 = create_symbool(shape_env, True) self.assertEqual(a1 & b1, 1) a2 = create_symint(shape_env, 0b1100) self.assertEqual(a2 & 0b1010, 0b1000) a3 = create_symbool(shape_env, True) b3 = create_symbool(shape_env, True) self.assertEqual(a3 & b3, True) def 
test_symint_bitwise_or(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 0b1100) b0 = create_symint(shape_env, 0b1010) res_or = a0 | b0 self.assertEqual(res_or, 0b1110) self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s97, s26), 14)""" ) def test_stride(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) self.assertIsInstance(x.stride()[0], SymInt) def test_size_expressions(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) expand_x = x.expand(x.shape[0], x.shape[0]) if expand_x.shape[0] > 3: result = expand_x + expand_x else: result = expand_x + expand_x gt_op, _bt, is_size_obv = shape_env.guards[-1] self.assertTrue(isinstance(gt_op, sympy.core.relational.StrictGreaterThan)) self.assertTrue(str(x.shape[0]), str(gt_op.args[0])) self.assertTrue(str(expand_x.shape[1]), str(x.shape[0])) self.assertTrue(str(expand_x.shape[1]), str(result.shape[0])) self.assertFalse(is_size_obv) def test_floordiv_static(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 8) # This was extracted from # python test/inductor/test_cuda_cpp_wrapper.py -k # DynamicShapesCudaWrapperCudaTests.test_insignificant_strides_cuda_dynamic_shapes_cuda_wrapper bool(s0 % 2 == 0) bool(s0 % (s0 // 2) == 0) bool(2 * (s0 // 2) == s0) self.assertTrue(statically_known_true(s0 // (s0 // 2) == 2)) def test_numel(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) self.assertIsInstance(x.numel(), torch.SymInt) self.assertIsInstance(torch.numel(x), torch.SymInt) x = torch.rand(3, 3) self.assertIsInstance(x.numel(), int) self.assertIsInstance(torch.numel(x), int) def test_int_to_float(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) r = torch.sym_float(x.shape[0]) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) def test_aten_ops(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5), shape_env) torch.ops.aten.narrow_copy.default(x, 0, 0, x.shape[0]) shape_env = ShapeEnv() x = create_symbolic_tensor("x2", torch.randn(5, 4, 3), shape_env) torch.ops.aten.expand.default(x, [x.shape[0], x.shape[1], x.shape[2]]) def test_fx_trace_intlist(self): class CustomModule(torch.nn.Module): def forward(self, x): bs, c, h, w = x.shape return F.pad(x, (0, w % 2, 0, h % 2, 0, 0)) m = CustomModule() x = torch.rand(1, 3, 4, 4) # should not TypeError: pad(): argument 'pad' (position 2) must be # tuple of ints, not tuple torch.fx.symbolic_trace(m) def test_meta_symint(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) r = torch.empty(a0, device="meta") self.assertIsInstance(r.shape[0], SymInt) def test_guard_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) self.assertEqual(guard_int(a0), 2) self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""") def test_sym_sum(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 2) s1 = create_symint(shape_env, 3) s2 = create_symint(shape_env, 4) self.assertEqual( (s0 + s1 + s2).node.expr, torch.sym_sum([s0, s1, s2]).node.expr ) def test_prefer_deferred_runtime_assertions_over_guards(self): shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True) s0 = create_symint(shape_env, 2) self.assertEqual(guard_int(s0), 2) self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""") shape_env = ShapeEnv(prefer_deferred_runtime_asserts_over_guards=True) s0 = 
create_symint(shape_env, 2) self.assertTrue(expect_true(s0 == 2)) self.assertEqual(len(shape_env.guards), 0) self.assertExpectedInline( str([ra.expr for ra in shape_env.deferred_runtime_asserts[None]]), """[Eq(s97, 2)]""", ) def test_sym_int(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) r = sym_int(a0) self.assertEqual(r, 5) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 5)""") a1 = create_symint(shape_env, 7) r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s26, 2)), 3)""" ) a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s57)), 6)""" ) def test_sym_log2(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 4) r = torch._sym_log2(a0) self.assertEqual(r, 2.0) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_log2(ToFloat(s97)), 2.0)""" ) def test_sym_sqrt(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 4) r = torch._sym_sqrt(a0) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(ToFloat(s97)), 2.0)""" ) def test_sym_floor(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(FloorToInt(IntTrueDiv(s97, 2)), 2)""", ) r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[1][0]), """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""", ) def test_sym_trunc(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s97, 2)), 2)""" ) r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(ToFloat(s97))), 2)""", ) def test_sym_ceil(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) r = math.ceil(a0 / 2) self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(CeilToInt(IntTrueDiv(s97, 2)), 3)""", ) r1 = 3.0 * a0 r = math.floor(r1) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( str(shape_env.guards[1][0]), """Eq(FloorToInt(3.0*ToFloat(s97)), 15)""", ) def test_sym_ite(self): shape_env = ShapeEnv() t = create_symint(shape_env, 5) f = create_symint(shape_env, 4) b1 = True r1 = torch.sym_ite(b1, t, f) self.assertTrue(r1 is t) b2 = False r2 = torch.sym_ite(b2, t, f) self.assertTrue(r2 is f) b3 = t == 5 r3 = torch.sym_ite(b3, t, f) self.assertEqual(len(shape_env.guards), 0) self.assertEqual(r3, 5) self.assertEqual(type(t), type(r3)) self.assertExpectedInline( str(shape_env.guards[0][0]), """Eq(Piecewise((s97, Eq(s97, 5)), (s26, True)), 5)""", ) b4 = f == 5 r4 = 
torch.sym_ite(b4, t, f) self.assertEqual(len(shape_env.guards), 1) self.assertEqual(r4, 4) self.assertEqual(type(f), type(r4)) self.assertExpectedInline( str(shape_env.guards[1][0]), """Eq(Piecewise((s97, Eq(s26, 5)), (s26, True)), 4)""", ) def test_tracing_sym_ite(self): def f(x): b = x.shape[0] == 5 ret = torch.sym_ite(b, x.shape[0], x.shape[1]) return ret gm = make_fx(f, tracing_mode="symbolic")(torch.ones(4, 5)) self.assertEqual(len(gm.shape_env.guards), 0) self.assertExpectedInline( gm.code.strip(), """\ def forward(self, x_1): sym_size_int = torch.ops.aten.sym_size.int(x_1, 0) eq = sym_size_int == 5 sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1); x_1 = None sym_ite = torch.sym_ite(eq, sym_size_int, sym_size_int_1); eq = sym_size_int = sym_size_int_1 = None return sym_ite""", ) r1 = gm(torch.ones(4, 5)) self.assertIsInstance(r1, int) self.assertEqual(r1, 5) r2 = gm(torch.ones(5, 4)) self.assertIsInstance(r2, int) self.assertEqual(r2, 5) def test_int_conversion(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) int(a0) self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(s97, 2)""") def test_data_dependent_guard(self): shape_env = ShapeEnv() s0 = shape_env.create_unbacked_symint() self.assertRaises(GuardOnDataDependentSymNode, lambda: bool(s0 == 0)) def test_data_dependent_guard_propagate_real_tensors(self): shape_env = ShapeEnv() s0 = shape_env.create_unbacked_symint() shape_env.set_unbacked_var_to_val(s0.node.expr, 0) self.assertEqual(bool(s0 == 0), True) def test_expect_true_basic(self): shape_env = ShapeEnv() i0 = shape_env.create_unbacked_symint() i0_sym = i0.node.expr # This doesn't error self.assertTrue(expect_true(i0 == 0)) # This generates a deferred runtime assert via replacement self.assertEqual(shape_env.replacements[i0_sym], 0) # After expecting true, guards now resolve given the runtime assert bool(i0 == 0) def test_expect_true_with_s0(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 5) i0 = shape_env.create_unbacked_symint() self.assertTrue(expect_true(i0 < s0)) self.assertExpectedInline( str([ra.expr for ra in shape_env.deferred_runtime_asserts[i0.node.expr]]), """[u0 < s97]""", ) self.assertTrue(i0 < s0) self.assertTrue(i0 != s0) self.assertFalse(i0 > s0) self.assertFalse(i0 >= s0) def test_expect_true_prefer_later(self): shape_env = ShapeEnv() i0 = shape_env.create_unbacked_symint() i1 = shape_env.create_unbacked_symint() i1_sym = i1.node.expr self.assertTrue(expect_true(i0 + i1 == 10)) # Importantly, this is put in i1, not i0! self.assertExpectedInline( str([ra.expr for ra in shape_env.deferred_runtime_asserts[i1_sym]]), """[Eq(u0 + u1, 10)]""", ) self.assertTrue(i0 + i1 == 10) # NB: We currently don't support deriving that we can substitute # i0 + i1 with 10; maybe we should, but this means our rewriting # system is no longer confluent (it's probably OK though, because # you're unlikely to get other equalities like this on the # unbacked SymInts.) 
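        # As a rough sketch of what the assertions above rely on: expect_true on
        # an unproven relation over unbacked symbols files a deferred runtime
        # assert under the later-allocated symbol it mentions (i1/u1 here), e.g.
        #
        #     i0 = shape_env.create_unbacked_symint()
        #     i1 = shape_env.create_unbacked_symint()
        #     expect_true(i0 + i1 == 10)
        #     # deferred_runtime_asserts[i1.node.expr] -> [Eq(u0 + u1, 10)]
        #     bool(i0 + i1 == 10)  # now resolves to True instead of erroring
        #
        # which is exactly the bookkeeping the deferred_runtime_asserts check
        # above verifies.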
def test_unbacked_substitution(self): shape_env = ShapeEnv() i0 = shape_env.create_unbacked_symint() i1 = shape_env.create_unbacked_symint() _constrain_range_for_size(i0) _constrain_range_for_size(i1) self.assertTrue(expect_true(i0 == i1 * 4)) self.assertExpectedInline(str(i0), """u0""") i2 = shape_env.create_unbacked_symint() i3 = shape_env.create_unbacked_symint() _constrain_range_for_size(i2) _constrain_range_for_size(i3) self.assertTrue(expect_true(i2 * 4 == i3)) self.assertExpectedInline(str(i3), """u3""") def test_avoid_unbacked_substitution(self): shape_env = ShapeEnv() i0 = shape_env.create_unbacked_symint() _constrain_range_for_size(i0) i1 = shape_env.create_unbacked_symint() _constrain_range_for_size(i1) self.assertTrue(expect_true(i0 == 10 - i1)) self.assertExpectedInline(str(i0), """u0""") def test_expect_true_double_digits(self): shape_env = ShapeEnv() ia = [shape_env.create_unbacked_symint() for _ in range(11)] # allocate 10 self.assertEqual(str(ia[-1]), "u10") self.assertTrue(expect_true(sum(ia) == 20)) self.assertEqual(len(shape_env.deferred_runtime_asserts[ia[-1].node.expr]), 1) def test_expect_true_refine_range(self): shape_env = ShapeEnv() for i, rel in enumerate( [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] ): with self.subTest(f"i = {i}"): i0 = shape_env.create_unbacked_symint() self.assertTrue(expect_true(rel(i0))) self.assertTrue(statically_known_true(i0 != 3)) self.assertTrue(statically_known_true(i0 != 4)) self.assertFalse(statically_known_true(i0 != 5)) self.assertFalse(statically_known_true(i0 != 6)) self.assertTrue(statically_known_true(i0 > 4)) self.assertTrue(statically_known_true(i0 >= 5)) for i, rel in enumerate( [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] ): with self.subTest(f"i = {i}"): i0 = shape_env.create_unbacked_symint() self.assertTrue(expect_true(rel(i0))) self.assertFalse(statically_known_true(i0 != 2)) self.assertFalse(statically_known_true(i0 != 3)) self.assertTrue(statically_known_true(i0 != 4)) self.assertTrue(statically_known_true(i0 != 5)) self.assertTrue(statically_known_true(i0 < 4)) self.assertTrue(statically_known_true(i0 <= 5)) def test_guard_refine_range(self): shape_env = ShapeEnv() for i, rel in enumerate( [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] ): with self.subTest(f"i = {i}"): i0 = create_symint(shape_env, 10, duck=False) self.assertTrue(bool(rel(i0))) self.assertTrue(statically_known_true(i0 != 3)) self.assertTrue(statically_known_true(i0 != 4)) self.assertFalse(statically_known_true(i0 != 5)) self.assertFalse(statically_known_true(i0 != 6)) self.assertTrue(statically_known_true(i0 > 4)) self.assertTrue(statically_known_true(i0 >= 5)) for i, rel in enumerate( [lambda x: x > 4, lambda x: 4 < x, lambda x: x >= 5, lambda x: 5 <= x] ): with self.subTest(f"i = {i}"): i0 = create_symint(shape_env, 2, duck=False) self.assertFalse(bool(rel(i0))) self.assertFalse(statically_known_true(i0 != 3)) self.assertFalse(statically_known_true(i0 != 4)) self.assertTrue(statically_known_true(i0 != 5)) self.assertTrue(statically_known_true(i0 != 6)) self.assertTrue(statically_known_true(i0 <= 4)) self.assertTrue(statically_known_true(i0 < 5)) for i, rel in enumerate( [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] ): with self.subTest(f"i = {i}"): i0 = create_symint(shape_env, 2, duck=False) self.assertTrue(bool(rel(i0))) self.assertFalse(statically_known_true(i0 != 2)) self.assertFalse(statically_known_true(i0 != 3)) 
self.assertTrue(statically_known_true(i0 != 4)) self.assertTrue(statically_known_true(i0 != 5)) self.assertTrue(statically_known_true(i0 < 4)) self.assertTrue(statically_known_true(i0 <= 3)) for i, rel in enumerate( [lambda x: x < 4, lambda x: 4 > x, lambda x: x <= 3, lambda x: 3 >= x] ): with self.subTest(f"i = {i}"): i0 = create_symint(shape_env, 10, duck=False) self.assertFalse(bool(rel(i0))) self.assertTrue(statically_known_true(i0 != 2)) self.assertTrue(statically_known_true(i0 != 3)) self.assertFalse(statically_known_true(i0 != 4)) self.assertFalse(statically_known_true(i0 != 5)) self.assertTrue(statically_known_true(i0 >= 4)) self.assertTrue(statically_known_true(i0 > 3)) def test_mul_int_oo_nan(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 5, duck=False) s1 = create_symint(shape_env, 6, duck=False) s2 = create_symint(shape_env, 5, duck=False) bool(s0 * (s1 // s0) == s2) def test_non_overlapping_and_dense_backed(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 5) r = torch.empty_strided((a0, 7), (1, a0), device="meta") self.assertTrue(torch.ops.aten.is_non_overlapping_and_dense.default(r)) def test_non_overlapping_and_dense_unbacked(self): shape_env = ShapeEnv() u0 = shape_env.create_unbacked_symint() cf = torch.ops.aten.is_non_overlapping_and_dense.default self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 2, 2, 1), 1) self.assertEqual(IsNonOverlappingAndDenseIndicator(2, u0.node.expr, 1, 2), 1) self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta"))) self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta"))) self.assertEqual(IsNonOverlappingAndDenseIndicator(u0.node.expr, 1), 1) self.assertEqual(IsNonOverlappingAndDenseIndicator(1, u0.node.expr), 1) self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta"))) self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta"))) Max = torch.sym_max # NB: This only works because we're able to determine this tensor is # contiguous. 
transpose(0, 1) makes it stop working self.assertTrue( cf( torch.empty_strided( (2, 3, 1, u0), (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), device="meta", ) ) ) def test_prims_non_overlapping_and_dense(self): shape_env = ShapeEnv() cf = torch._prims_common.is_non_overlapping_and_dense # backed case a0 = create_symint(shape_env, 5) self.assertTrue(cf(torch.empty_strided((a0, 7), (1, a0), device="meta"))) # unbacked u0 = shape_env.create_unbacked_symint() self.assertTrue(cf(torch.empty_strided((u0, 2), (2, 1), device="meta"))) self.assertTrue(cf(torch.empty_strided((2, u0), (1, 2), device="meta"))) self.assertTrue(cf(torch.empty_strided((u0,), (1,), device="meta"))) self.assertTrue(cf(torch.empty_strided((1,), (u0,), device="meta"))) Max = torch.sym_max self.assertTrue( cf( torch.empty_strided( (2, 3, 1, u0), (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), device="meta", ) ) ) self.assertFalse( cf( torch.empty_strided( (2, 3, 1, u0), (Max(1, u0), Max(1, u0), 1, 3 * Max(1, u0)), device="meta", ) ) ) # return False on arbitrary strides u1 = shape_env.create_unbacked_symint() self.assertFalse( cf( torch.empty_strided( (2 * u0, u0, 1), (u1, u0, u0 + u1), device="meta", ) ) ) self.assertFalse( cf( torch.empty_strided( (2, 3, u0), (u1, 3, 1), device="meta", ) ) ) def test_sympy_optimized_add_binary_search(self): import sympy from torch.fx.experimental.sym_node import _binary_search_insert_arg a = sympy.Symbol("a") b = sympy.Symbol("b") c = sympy.Symbol("c") args = [] args = _binary_search_insert_arg([], b) self.assertEqual(args, [b]) self.assertEqual(_binary_search_insert_arg(args, b), None) args = _binary_search_insert_arg(args, a) self.assertEqual(args, [a, b]) self.assertEqual(_binary_search_insert_arg(args, b), None) self.assertEqual(_binary_search_insert_arg(args, a), None) args = _binary_search_insert_arg(args, c) self.assertEqual(args, [a, b, c]) self.assertEqual(_binary_search_insert_arg(args, a), None) self.assertEqual(_binary_search_insert_arg(args, b), None) self.assertEqual(_binary_search_insert_arg(args, c), None) a1 = sympy.Symbol("a1") a2 = sympy.Symbol("a2") args = _binary_search_insert_arg(args, a1) self.assertEqual(args, [a, a1, b, c]) args = _binary_search_insert_arg(args, a2) self.assertEqual(args, [a, a1, a2, b, c]) c1 = sympy.Symbol("c1") args = _binary_search_insert_arg(args, c1) self.assertEqual(args, [a, a1, a2, b, c, c1]) # insert to front _a = sympy.Symbol("_a") args = _binary_search_insert_arg(args, _a) self.assertEqual(args, [_a, a, a1, a2, b, c, c1]) def test_floor_clean_div_axioms(self): # Test that if we add an axiom that have FloorDiv, after which the # shapeEnv changed such that it can be simplified it to CleanDiv, then # We still correctly replace CleanDiv with the axiom value of FloorDiv. 
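        # Sketch of the scenario exercised below (using the unbacked symbol a):
        #   axiom recorded:   FloorDiv(a, 3) == 1
        #   fact added later: Mod(a, 3) == 0
        #   simplification:   FloorDiv(a, 3) -> CleanDiv(a, 3)
        # so a query phrased as CleanDiv(a, 3) == 1 must still pick up the axiom
        # that was stored in terms of FloorDiv.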
        shape_env = ShapeEnv()
        a = shape_env.create_unbacked_symint()
        shape_env.guard_or_defer_runtime_assert((a // 3 == 1).node.expr, " test")

        from sympy import Eq

        test1 = Eq(FloorDiv(a.node.expr, 3), 1)
        test2 = Eq(CleanDiv(a.node.expr, 3), 1)

        self.assertTrue(shape_env.evaluate_expr(test1))
        self.assertEqual(shape_env._maybe_evaluate_static(test2), None)

        # After this FloorDiv(a, 3) is simplified to CleanDiv(a, 3)
        shape_env.guard_or_defer_runtime_assert(Eq(Mod(a, 3), 0), " test")

        self.assertEqual(test2, shape_env.simplify(test1))
        self.assertTrue(shape_env.evaluate_expr(test1))
        self.assertTrue(shape_env.evaluate_expr(test2))

    def test_sympy_optimized_add(self):
        shape_env = ShapeEnv()
        s0 = create_symint(shape_env, 2)
        s1 = create_symint(shape_env, 3)
        s2 = create_symint(shape_env, 4)

        sum = s0 + s1
        self.assertTrue(sum.node._optimized_summation)

        def assert_optimized(sym):
            self.assertTrue(sym.node._optimized_summation)

        def assert_not_optimized(sym):
            self.assertFalse(getattr(sym.node, "_optimized_summation", False))

        assert_optimized(sum)

        # add duplicate symbol
        assert_not_optimized(sum + s0)

        # add constant.
        assert_not_optimized(sum + 1)

        # add new unique symbol, should maintain _optimized_summation property.
        assert_optimized(sum + s2)
        assert_optimized(s2 + sum)

        # add x + (a+b) with no _optimized_summation on the rhs sum.
        a = create_symint(shape_env, 10)
        b = create_symint(shape_env, 11)
        two_sum = torch.sym_sum([a, b])
        assert_not_optimized(two_sum)
        assert_optimized(sum + two_sum)

        # adding two expressions of length >2 that are _optimized_summation.
        a = s0 + s1 + s2
        s3 = create_symint(shape_env, 10)
        s4 = create_symint(shape_env, 20)
        s5 = create_symint(shape_env, 30)
        b = s3 + s4 + s5
        assert_optimized(a)
        assert_optimized(b)
        assert_not_optimized(a + b)
        assert_not_optimized(b + a)
        assert_not_optimized(b + a + b)

    def test_max_of_unique_summation_opt(self):
        shape_env = ShapeEnv()
        s0 = shape_env.create_unbacked_symint()
        s1 = shape_env.create_unbacked_symint()
        s2 = shape_env.create_unbacked_symint()
        s3 = shape_env.create_unbacked_symint()
        s4 = shape_env.create_unbacked_symint()
        s5 = shape_env.create_unbacked_symint()
        s7 = shape_env.create_unbacked_symint()

        def assert_optimized(sym):
            self.assertTrue(sym.node.expr.unique_summations_symbols is not None)

        def assert_not_optimized(sym):
            getattr(sym.node.expr, "unique_summations_symbols", None)

        mx1 = torch.sym_max(s0, s1)
        assert_not_optimized(mx1)

        mx2 = torch.sym_max(s0 + s1, s2 + s3)
        assert_optimized(mx2)

        mx3 = torch.sym_max(mx2, s4 + s5)
        assert_optimized(mx3)
        assert_optimized(torch.sym_max(s4 + s5, mx2))

        assert_not_optimized(torch.sym_max(mx3, s7))
        assert_not_optimized(torch.sym_max(mx3, 10))
        assert_not_optimized(torch.sym_max(mx3, s3 + s7))
        assert_not_optimized(torch.sym_max(mx3, s7 * 2))

    def test_sym_max_multi_max_simplify(self):
        shape_env = ShapeEnv()
        u0 = shape_env.create_unbacked_symint()
        self.assertTrue(
            statically_known_true(
                torch.sym_max(1, torch.sym_max(257, u0)) == torch.sym_max(257, u0)
            )
        )

    def test_numpy_sym_max(self):
        self.assertEqual(torch.sym_max(np.int64(10), 12), 12)
        self.assertEqual(torch.sym_max(np.int64(12), 10), 12)
        self.assertEqual(torch.sym_max(np.int64(10), 12.5), 12.5)
        self.assertEqual(torch.sym_max(np.int64(14), 12.5), 14.0)
        self.assertEqual(torch.sym_max(np.float64(14.0), 12), 14.0)
        self.assertEqual(torch.sym_max(np.float64(14.0), 16), 16.0)

    def test_numpy_sym_min(self):
        self.assertEqual(torch.sym_min(np.int64(10), 12), 10)
        self.assertEqual(torch.sym_min(np.int64(12), 10), 10)
        self.assertEqual(torch.sym_min(np.int64(10), 12.5), 10.0)
self.assertEqual(torch.sym_min(np.int64(14), 12.5), 12.5) self.assertEqual(torch.sym_min(np.float64(14.0), 12), 12.0) self.assertEqual(torch.sym_min(np.float64(14.0), 16), 14.0) def test_debug_has_internal_overlap_unbacked(self): shape_env = ShapeEnv() u0 = shape_env.create_unbacked_symint() cf = torch._debug_has_internal_overlap self.assertEqual(cf(torch.empty_strided((u0, 2), (2, 1), device="meta")), 0) self.assertEqual(cf(torch.empty_strided((2, u0), (1, 2), device="meta")), 0) self.assertEqual(cf(torch.empty_strided((u0,), (1,), device="meta")), 0) self.assertEqual(cf(torch.empty_strided((1,), (u0,), device="meta")), 2) Max = torch.sym_max self.assertEqual( cf( torch.empty_strided( (2, 3, 1, u0), (3 * Max(1, u0), Max(1, u0), Max(1, u0), 1), device="meta", ) ), 2, ) # Wobbling these to zero is OK too self.assertEqual(cf(torch.empty_strided((u0, 2), (3, 1), device="meta")), 2) self.assertEqual(cf(torch.empty_strided((2, u0), (1, 3), device="meta")), 2) def test_specialize_zero_one(self): shape_env = ShapeEnv(specialize_zero_one=True) a0 = create_symint(shape_env, 5) assert a0 != 1 self.assertEqual(len(shape_env.guards), 0) shape_env = ShapeEnv(specialize_zero_one=False) a0 = create_symint(shape_env, 5) assert a0 != 1 self.assertEqual(len(shape_env.guards), 1) def test_duck_shape(self): shape_env = ShapeEnv(duck_shape=True) a0 = create_symint(shape_env, 5) a1 = create_symint(shape_env, 5) assert a0 == a1 self.assertEqual(len(shape_env.guards), 0) shape_env = ShapeEnv(duck_shape=False) a0 = create_symint(shape_env, 5) a1 = create_symint(shape_env, 5) assert a0 == a1 self.assertEqual(len(shape_env.guards), 1) def test_int_bool(self): # See https://github.com/pytorch/pytorch/issues/95981 shape_env = ShapeEnv(duck_shape=True) a0 = create_symint(shape_env, 5) assert a0 self.assertEqual(len(shape_env.guards), 0) def test_symint_as_scalar(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) sym_int_encountered = False class TestSymInt(TorchDispatchMode): def __torch_dispatch__(self, func, types, args=(), kwargs=None): assert func == torch.ops.aten.add.Tensor nonlocal sym_int_encountered # WARNING: do not do identity tests on the outer # SymInt/SymFloat, they are NOT STABLE sym_int_encountered = kwargs["alpha"].node is a0.node kwargs["alpha"] = 0 return func(*args) x = torch.rand([4, 4]) with TestSymInt(): y = torch.add(x, x, alpha=a0) self.assertTrue(sym_int_encountered) def test_deepcopy(self): shape_env = ShapeEnv() a0 = create_symint(shape_env, 2) assert a0 < 4 new_shape_env = copy.deepcopy(shape_env) self.assertEqual(len(new_shape_env.guards), 1) def test_print_readable_with_symints(self): def f(a, b): dim0 = a.shape[0] + b.shape[0] dim1 = a.shape[1] + b.shape[1] d = a.new_empty(dim0, dim1) d = torch.ops.aten.native_dropout(d, 0.5, train=True) return d fx_g = make_fx(f, tracing_mode="symbolic")(torch.randn(5, 3), torch.randn(4, 3)) out = fx_g.print_readable(print_output=False) self.assertExpectedInline( out.strip(), """\ class f(torch.nn.Module): def forward(self, a_1: "f32[s75, s96]", b_1: "f32[s57, s96]"): # No stacktrace found for following nodes sym_size_int: "Sym(s75)" = torch.ops.aten.sym_size.int(a_1, 0) sym_size_int_1: "Sym(s57)" = torch.ops.aten.sym_size.int(b_1, 0) add: "Sym(s57 + s75)" = sym_size_int + sym_size_int_1; sym_size_int = sym_size_int_1 = None sym_size_int_2: "Sym(s96)" = torch.ops.aten.sym_size.int(a_1, 1) sym_size_int_3: "Sym(s96)" = torch.ops.aten.sym_size.int(b_1, 1); b_1 = None add_1: "Sym(2*s96)" = sym_size_int_2 + sym_size_int_3; sym_size_int_2 = 
sym_size_int_3 = None new_empty: "f32[s57 + s75, 2*s96]" = torch.ops.aten.new_empty.default(a_1, [add, add_1], pin_memory = False); a_1 = add = add_1 = None native_dropout = torch.ops.aten.native_dropout.default(new_empty, 0.5, True); new_empty = None getitem: "f32[s57 + s75, 2*s96]" = native_dropout[0] getitem_1: "b8[s57 + s75, 2*s96]" = native_dropout[1]; native_dropout = None return (getitem, getitem_1)""", # noqa: B950 ) def test_statically_known_true(self): shape_env = ShapeEnv() s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5)) # Statically known true self.assertTrue(statically_known_true(True)) self.assertTrue(statically_known_true(s2 == s2)) self.assertTrue(statically_known_true(s2 * s3 > s3)) self.assertTrue(statically_known_true(s3 * s4 > s4)) self.assertTrue(statically_known_true((s3 + s3) % 2 == 0)) # Statically known false self.assertFalse(statically_known_true(False)) self.assertFalse(statically_known_true(s3 * s4 <= s4)) self.assertFalse(statically_known_true((s3 + s3) % 2 == 1)) # True for hints, but not known statically self.assertFalse(statically_known_true(s2 + s2 == s4)) self.assertFalse(statically_known_true(s4 % s2 == 0)) self.assertFalse(statically_known_true(s2 != s3)) self.assertFalse(statically_known_true(s3 * s4 > s2)) # False for hints, but not known statically self.assertFalse(statically_known_true(s2 == s3)) self.assertFalse(statically_known_true(s2 > s3)) self.assertFalse(statically_known_true(s3 + s3 == s4)) # No guards should be generated self.assertEqual(len(shape_env.guards), 0) def test_statically_known_false(self): shape_env = ShapeEnv() s2, s3, s4 = (create_symint(shape_env, i) for i in range(2, 5)) # Statically known true self.assertFalse(statically_known_false(True)) self.assertFalse(statically_known_false(s2 == s2)) self.assertFalse(statically_known_false(s2 * s3 > s3)) self.assertFalse(statically_known_false(s3 * s4 > s4)) self.assertFalse(statically_known_false((s3 + s3) % 2 == 0)) # Statically known false self.assertTrue(statically_known_false(False)) self.assertTrue(statically_known_false(s3 * s4 <= s4)) self.assertTrue(statically_known_false((s3 + s3) % 2 == 1)) # True for hints, but not known statically self.assertFalse(statically_known_false(s2 + s2 == s4)) self.assertFalse(statically_known_false(s4 % s2 == 0)) self.assertFalse(statically_known_false(s2 != s3)) self.assertFalse(statically_known_false(s3 * s4 > s2)) # False for hints, but not known statically self.assertFalse(statically_known_false(s2 == s3)) self.assertFalse(statically_known_false(s2 > s3)) self.assertFalse(statically_known_false(s3 + s3 == s4)) # No guards should be generated self.assertEqual(len(shape_env.guards), 0) def test_ephemeral_source_simplification(self): from torch._dynamo.source import EphemeralSource # For full robustness, ensure the ephemeral source symbols are simplified out regardless # of construction order or check order. 
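        # Concretely: after the torch._check(...) equalities below, none of the
        # size/stride/storage-offset symbols of the ephemeral-sourced tensor
        # should remain backed solely by an ephemeral source, regardless of
        # whether x or y appears on the left-hand side of the check.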
for construct_ephemeral_first, x_first_in_check in itertools.product( [False, True], [False, True] ): shape_env = ShapeEnv() shape = (5, 10) dynamic_dims = [DimDynamic.DYNAMIC for _ in shape] x = create_symbolic_tensor( "x", torch.randn(*shape), shape_env, source=(EphemeralSource() if construct_ephemeral_first else None), dynamic_dims=dynamic_dims, ) y = create_symbolic_tensor( "y", torch.randn(*shape), shape_env, source=(EphemeralSource() if not construct_ephemeral_first else None), dynamic_dims=dynamic_dims, ) t_with_ephemeral = x if construct_ephemeral_first else y def _get_ephemeral_source_symbols(t): return [ s.node.expr for s in itertools.chain(t.shape, t.stride(), (t.storage_offset(),)) if isinstance(s, torch.SymInt) and s.node.expr in shape_env.var_to_sources and any( source.is_ephemeral() for source in shape_env.var_to_sources[s.node.expr] ) ] # these checks should simplify out the ephemeral symbols, regardless of the # ordering x == y or y == x self.assertTrue(len(_get_ephemeral_source_symbols(t_with_ephemeral)) > 0) if x_first_in_check: torch._check(x.size() == y.size()) torch._check(x.stride() == y.stride()) torch._check(x.storage_offset() == y.storage_offset()) else: torch._check(y.size() == x.size()) torch._check(y.stride() == x.stride()) torch._check(y.storage_offset() == x.storage_offset()) self.assertEqual(len(_get_ephemeral_source_symbols(t_with_ephemeral)), 0) def test_ephemeral_source_unified_with_non_ephemeral_source(self): from torch._dynamo.source import EphemeralSource for construct_ephemeral_first in (False, True): shape_env = ShapeEnv() shape = (5, 10) # use duck sizing here to ensure symbol reuse across x and y duck_dims = [DimDynamic.DUCK for _ in shape] x = create_symbolic_tensor( "x", torch.randn(*shape), shape_env, source=(EphemeralSource() if construct_ephemeral_first else None), dynamic_dims=duck_dims, ) y = create_symbolic_tensor( "y", torch.randn(*shape), shape_env, source=(EphemeralSource() if not construct_ephemeral_first else None), dynamic_dims=duck_dims, ) # regardless of construction order, non-ephemeral sources should be preferred # first in the var_to_sources list for potential guarding later on for source_list in shape_env.var_to_sources.values(): self.assertFalse(source_list[0].is_ephemeral()) self.assertEqual(x.size(), y.size()) self.assertEqual(x.stride(), y.stride()) self.assertEqual(x.storage_offset(), y.storage_offset()) def test_tensor_factory_with_symint(self): args = list(range(3)) expected = torch.tensor(args) shape_env = ShapeEnv() sym_args = [create_symint(shape_env, i) for i in args] # test tensor factories for dt in all_types_and(torch.half, torch.bfloat16): res = torch.tensor(sym_args, dtype=dt) self.assertEqual(res, expected, exact_dtype=False) # test legacy tensor factories legacy_ctors = [ torch.Tensor, torch.LongTensor, torch.DoubleTensor, torch.FloatTensor, torch.IntTensor, torch.ShortTensor, torch.HalfTensor, torch.ByteTensor, ] for Tensor in legacy_ctors: res = Tensor(sym_args) self.assertEqual(res, expected, exact_dtype=False) def test_backed_size_oblivious_01_spec(self): from torch.fx.experimental.symbolic_shapes import guard_size_oblivious @torch.compile(dynamic=True, fullgraph=True) def f(a, b): if guard_size_oblivious(a.size(0) == 1): return b * 10 else: return b * 20 with torch.fx.experimental._config.patch(backed_size_oblivious=True): # always go to the >= 2 branch. 
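            # The example input has a.size(0) == 1, but the size-oblivious check
            # treats the backed size as if it were >= 2, so f takes the else
            # branch and returns b * 20.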
self.assertEqual( f(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) ) @fresh_cache() def test_slice_backed_size_oblivious(self): @torch.compile(backend="inductor", fullgraph=True, dynamic=True) def f(x): return x[:5] with torch.fx.experimental._config.patch(backed_size_oblivious=True): f(torch.randn(10, 10)) def test_baddbmm_symint(self): from torch._subclasses.fake_tensor import FakeTensorMode shape_env = ShapeEnv() fake_mode = FakeTensorMode(shape_env=shape_env) B, M, K, N = [shape_env.create_unbacked_symint() for _ in range(4)] with fake_mode: A = torch.empty((B, M, K), device="meta") Bmat = torch.empty((B, K, N), device="meta") bias3 = torch.empty((B, M, N), device="meta") _ = torch.baddbmm(bias3, A, Bmat) @skipIfTorchDynamo( "Creating ShapeEnv fails for confusing reasons (also we never expect dynamo to see code like this)" ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized # TODO: We don't have to circuitously create the float, can just # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node def get_sym_inp(inp): # NB: this must come before int if isinstance(inp, bool): return torch.SymBool(to_node(bool_seed_node, inp)) elif isinstance(inp, int): return torch.SymInt(to_node(seed_node, inp)) else: return torch.SymFloat(to_node(seed_node, inp)) if fn == "float_pow": if inp1 < 0: return if fn == "pow_by_natural": if isinstance(inp1, float) or isinstance(inp2, float): return if inp2 < 0: return def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) elif ( fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") and inp2 == 0 ): # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( # TODO: dear catastrophe waitress, # this doesn't work fn in ["float_pow", "pow_by_natural"] and inp1 < 0 and ( type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) ) and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float return self.assertRaises((RuntimeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): # TypeError: unsupported operand type(s) return self.assertRaises((TypeError,)) elif fn in ("lshift", "rshift") and inp2 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) else: return contextlib.nullcontext() lambda_apply = method_to_operator(fn) def guard_fn(v): if type(v) in (SymBool, bool): return guard_bool(v) elif type(v) in (SymFloat, float): return guard_float(v) else: # SymInt, int return guard_int(v) # Get reference result with maybe_xfail(inp1, inp2): if is_unary_fn: ref_out = lambda_apply(inp1) else: ref_out = lambda_apply(inp1, inp2) # Symified first arg sym_inp1 = get_sym_inp(inp1) with maybe_xfail(sym_inp1, inp2): if is_unary_fn: out = lambda_apply(sym_inp1) else: out = 
lambda_apply(sym_inp1, inp2) self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) out = guard_fn(out) self.assertEqual(out, ref_out) if is_unary_fn: return # Symified second arg sym_inp2 = get_sym_inp(inp2) with maybe_xfail(inp1, sym_inp2): out = lambda_apply(inp1, sym_inp2) self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) out = guard_fn(out) self.assertEqual(out, ref_out) # Symified both args with maybe_xfail(sym_inp1, sym_inp2): out = lambda_apply(sym_inp1, sym_inp2) self.assertTrue(isinstance(out, (SymInt, SymFloat, SymBool))) out = guard_fn(out) self.assertEqual(out, ref_out) @parametrize("fn", list(sym_node.magic_methods.keys())) def test_bool_method(self, fn): # sym_ite has its own tests if fn not in sym_node.bool_magic_methods or fn == "sym_ite": self.skipTest(f"{fn} is non-bool") is_unary_fn = fn in sym_node.unary_methods shape_env = ShapeEnv() self._do_test(fn, True, False, shape_env, is_unary_fn) @parametrize("fn", list(sym_node.magic_methods.keys())) @parametrize("first_type", ["int", "float"]) @parametrize("second_type", ["int", "float"]) def test_method(self, fn, first_type, second_type): if first_type == "float": # TODO: Hmm, this looks like we skip all floats self.skipTest(f"{fn} is not a float magic method") if ( first_type == "int" or second_type == "int" ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") if second_type == "float" and fn in ["mod"]: self.skipTest(f"{fn} only handles int") if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"): self.skipTest(f"{fn} is a bitwise op, only handles int") is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": self.skipTest(f"{fn} is unary and already tested") if fn in sym_node.bool_magic_methods: self.skipTest(f"{fn} is bool") # Only floats here since these will be converted to int if necessary. # We also ignore complex and bool. 
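        # Each (inp1, inp2) pair below is also crossed with its negations and
        # cast to int when the parametrized type is int; _do_test then compares
        # the plain Python result against the op applied with the first
        # argument symified, the second symified, and both symified.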
values = ( 0.0, 1.0, 0.5 if fn in ("sym_acos", "sym_asin") else 2.5, # avoid math domain error ) neg_values = tuple(-x for x in values) for inp1, inp2 in itertools.chain( itertools.product(values, values), itertools.product(values, neg_values), itertools.product(neg_values, values), itertools.product(neg_values, neg_values), ): if first_type == "int": inp1 = int(inp1) if second_type == "int": inp2 = int(inp2) shape_env = ShapeEnv() self._do_test(fn, inp1, inp2, shape_env, is_unary_fn) def get_constant_bool(self, val): return SymBool(torch._C._get_constant_bool_symnode(val)) @unittest.expectedFailure def test_symint_hashing(self): shape_env = ShapeEnv() hash(create_symint(shape_env, 3)) def test_symnode_hashing(self): shape_env = ShapeEnv() # These all trigger specialization when hashed hash(create_symbool(shape_env, True)) # We should be passing in float here, but create_symbol currently # only supports int hash(create_symfloat(shape_env, 3.0)) # NestedInt (SymInt), constant SymBool, SymNode are hashable j1 = torch._C._get_nested_int(1, 1) j1_copy = torch._C._get_nested_int(1, 1) j2 = torch._C._get_nested_int(2, 1) t = self.get_constant_bool(True) t_copy = self.get_constant_bool(True) f = self.get_constant_bool(False) n = create_symint(shape_env, 3).node m = self.get_constant_bool(True).node self.assertIs(j1 == j1_copy, True) self.assertEqual(hash(j1), hash(j1_copy)) self.assertIs(j1 == j2, False) self.assertNotEqual(hash(j1), hash(j2)) self.assertIs(t == t_copy, True) self.assertEqual(hash(t), hash(t_copy)) self.assertIs(t == f, False) self.assertNotEqual(hash(t), hash(f)) hash(n) hash(m) def test_symint_deepcopy(self): shape_env = ShapeEnv() symnodes = (torch._C._get_nested_int(1, 1),) deepcopied_symnodes = copy.deepcopy(symnodes) self.assertEqual(symnodes, deepcopied_symnodes) def test_non_symbolic_symnode(self): j1 = torch._C._get_nested_int(1, 1) j2 = torch._C._get_nested_int(1, 1) j3 = torch._C._get_nested_int(3, 1) self.assertIsInstance(j1, torch.SymInt) self.assertNotIsInstance(j1, int) with self.assertRaisesRegex( RuntimeError, "add not supported by NestedIntSymNode" ): j1 + 3 self.assertFalse(j1 == 3) with self.assertRaisesRegex(RuntimeError, "indeterminate"): self.assertFalse(3 >= j2) self.assertIs(j1 == j1, True) self.assertIs(j1 == j2, True) self.assertIs(j1 == j3, False) self.assertIs(j1 != j3, True) self.assertIs(j1 != j2, False) x = self.get_constant_bool(True) # # Unary # # op(constant SymBool) self.assertIs(x.__sym_not__(), False) # # Binary # # op(constant SymBool, bool) # op(constant SymBool, constant SymBool) # op(bool, constant SymBool) self.assertIs(operator.and_(x, True), True) self.assertIs(operator.and_(x, x), True) self.assertIs(operator.and_(True, x), True) # op(symbolic SymBool, constant Symbool) # op(constant SymBool, symbolic Symbool) shape_env = ShapeEnv() a = create_symint(shape_env, 2) b = create_symint(shape_env, 2) c = a == b # symbolic SymBool d = self.get_constant_bool(True) e = operator.and_(c, d) f = operator.and_(d, c) self.assertTrue(is_symbolic(e)) self.assertTrue(is_symbolic(f)) self.assertIs(e.node.guard_bool("", 0), True) self.assertIs(f.node.guard_bool("", 0), True) # Comparing sizes sz1 = torch.Size([j1, j1, j1]) sz2 = torch.Size([j1, j1, j1]) self.assertIs(sz1 == sz2, True) sz1 = torch.Size([3, j1, 4]) sz2 = torch.Size([3, j2, 4]) self.assertIs(sz1 == sz2, True) self.assertIs(sz1 != sz2, False) def test_stride_symnode(self): shape_env = ShapeEnv() # check everything static t = create_fake_tensor_with_dynamic_size( torch.ones(3, 6), 
shape_env, dynamic_sizes=[ DimDynamic.STATIC, DimDynamic.STATIC, ], dynamic_strides=[ DimDynamic.INFER_STRIDE, DimDynamic.INFER_STRIDE, ], ) self.assertTrue(all(isinstance(size, int) for size in t.size())) self.assertTrue(all(isinstance(stride, int) for stride in t.stride())) # check dynamic size but static dims t = create_fake_tensor_with_dynamic_size( torch.ones(3, 6), shape_env, dynamic_sizes=[ DimDynamic.DYNAMIC, DimDynamic.DYNAMIC, ], dynamic_strides=[ DimDynamic.INFER_STRIDE, DimDynamic.INFER_STRIDE, ], ) # Expect stride to be inferred s0, s1 = t.size() s2, s3 = t.stride() self.assertTrue(isinstance(s0, torch.SymInt)) self.assertTrue(isinstance(s1, torch.SymInt)) self.assertTrue(isinstance(s2, torch.SymInt)) self.assertTrue(s1 == s2) self.assertEqual(s3, 1) # Check dynamic stride but static dims t = create_fake_tensor_with_dynamic_size( torch.ones(3, 6), shape_env, dynamic_sizes=[ DimDynamic.STATIC, DimDynamic.STATIC, ], dynamic_strides=[ DimDynamic.DYNAMIC, DimDynamic.INFER_STRIDE, ], ) s0, s1 = t.size() s2, s3 = t.stride() self.assertTrue(isinstance(s0, int)) self.assertTrue(isinstance(s1, int)) self.assertTrue(isinstance(s2, torch.SymInt)) self.assertTrue(isinstance(s3, int)) # Check dynamic sizes and dims, and ensure different symbol t = create_fake_tensor_with_dynamic_size( torch.ones(3, 6), shape_env, dynamic_sizes=[ DimDynamic.DYNAMIC, DimDynamic.DYNAMIC, ], dynamic_strides=[ DimDynamic.DYNAMIC, DimDynamic.INFER_STRIDE, ], ) s0, s1 = t.size() s2, s3 = t.stride() self.assertTrue(isinstance(s0, torch.SymInt)) self.assertTrue(isinstance(s1, torch.SymInt)) self.assertTrue(isinstance(s2, torch.SymInt)) self.assertTrue(isinstance(s3, int)) self.assertTrue(str(s1.node.expr) != str(s2.node.expr)) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_dynamic_int_basic_compile(self, backend): from torch.fx.experimental.sym_node import DynamicInt cnt = CompileCounterWithBackend(backend) # test scalar inputs to function def f(x, y, z): out = torch.tensor([x + y + z]) out = out + torch.zeros(abs(x) + 2).sum() # test out tensor construction return out fn = torch.compile(f, fullgraph=True, backend=cnt) x = DynamicInt(1) z = DynamicInt(3) self.assertEqual(fn(x, x, z), f(1, 1, 3)) # guard: x == y self.assertEqual(fn(2, 2, 0), f(2, 2, 0)) self.assertEqual(fn(-1, -1, 2), f(-1, -1, 2)) self.assertEqual(cnt.frame_count, 1) # no recompiles self.assertEqual(fn(3, 4, 5), f(3, 4, 5)) # now we recompile self.assertEqual(cnt.frame_count, 2) # test nn module property class Foo(torch.nn.Module): def __init__(self): super().__init__() self.i = DynamicInt(1) def forward(self, x): return torch.tensor([x + self.i]) cnt.clear() m = Foo() mc = torch.compile(m, backend=cnt, fullgraph=True) self.assertEqual(mc(DynamicInt(0)), m(0)) mc.i = -2 # override attribute self.assertEqual(mc(-1), m(-1)) self.assertEqual(cnt.frame_count, 1) def test_dynamic_int_eager_usage(self): from torch.fx.experimental.sym_node import DynamicInt w = DynamicInt(-1) x = DynamicInt(0) y = DynamicInt(1) z = DynamicInt(2) def check(l, r): self.assertTrue(isinstance(l, DynamicInt)) self.assertEqual(l, r) # test arithmetic check(2 * y + z, 4) check((10 - z) // 2, 4) check(1 // z, 0) check(-w + w**2, 2) check(x % z, 0) check(1 << z, 4) check(z | y, 3) check(min(y, z), 1) self.assertTrue(z > -2) with self.assertRaises(ZeroDivisionError): y % x # math, numpy self.assertEqual(math.cos(x), y) self.assertEqual(math.prod([z, z], start=z), 8) self.assertEqual(np.arange(z)[y], 
1) self.assertTrue(np.allclose(np.ones([y, z]).sum(axis=x), np.ones(z))) # test conversions self.assertTrue(isinstance(x + 2, int)) self.assertTrue(isinstance(x + 2, DynamicInt)) self.assertEqual(y / 2.0, 0.5) # this could return DynamicFloat in future self.assertEqual(float(z), 2.0) self.assertFalse(bool(x)) self.assertEqual(DynamicInt(x).real, x.real) # torch functions, scalar inputs self.assertEqual(torch.arange(z)[:w][x], 0) self.assertEqual(torch.add(torch.tensor(w), torch.tensor(w), alpha=z), -3) self.assertEqual( list(torch.nn.Linear(z, y)(torch.randn(z * 2, z)).shape), [4, 1] ) self.assertEqual(z * torch.ones(z).sum(dim=x), 4) instantiate_parametrized_tests(TestSymNumberMagicMethods) class TestFloorDiv(TestCase): @staticmethod def python_floordiv(x, y): return x // y @staticmethod def torch_floordiv(x, y): # Note: we fully evaluate here since FloorDiv might not always do # that. shape_env = ShapeEnv() return shape_env.evaluate_expr(FloorDiv(x, y)) @staticmethod def yield_test_cases(values, negate=True): for x, y in values: yield (x, y) if negate: yield (-x, y) yield (x, -y) yield (-x, -y) def test_floordiv_float_int(self): values = ((7, 2),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) def test_floordiv_div_by_one(self): values = ((2, 1),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) def test_floordiv_div_does_not_generate_non_int_rational(self): s14 = sympy.Symbol("s14", integer=True, positive=True) s37 = sympy.Symbol("s37", integer=True, positive=True) inner_expr = FloorDiv(s14, 2016) middle_expr = (24 * s37 + 672) * inner_expr numerator = middle_expr + 21 denominator = 22 result = FloorDiv(numerator, denominator) rationals = result.atoms(sympy.Rational) all_rationals_ints = all(r.q == 1 for r in rationals) self.assertTrue(all_rationals_ints) def test_floordiv_simplify(self): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 exprs = (7 * FloorDiv(6, 2),) for expr in exprs: self.assertEqual(expr, result) self.assertEqual(expr.doit(deep=False), result) self.assertEqual(expr.doit(deep=True), result) self.assertEqual(sympy.simplify(expr), result) self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) def test_floordiv_assumptions(self): cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), ) for base, divisor in itertools.product(cases, repeat=2): def op(): return FloorDiv(base, divisor) def is_complex(x): return x.is_integer is False and x.is_real is False and x.is_complex if is_complex(base) or is_complex(divisor): self.assertRaisesRegex( TypeError, ( r"unsupported operand type\(s\) for //: 'Symbol' and 'Symbol'," r" expected integer or real" ), op, ) continue op = op() # In regular Python, x//x == 1.0 if x is a float, but FloorDiv # always returns an integer 1 when both args are the same object. # This even works for Symbols with no assumptions specified. 
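            # e.g. FloorDiv(s, s) simplifies to the integer 1 even for a bare
            # sympy.Symbol("s") with no assumptions, whereas plain Python gives
            # 2.5 // 2.5 == 1.0 (a float).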
if base is divisor: self.assertTrue(op.is_integer) self.assertTrue(op.is_real) elif base.is_integer and divisor.is_integer: self.assertTrue(op.is_integer) self.assertTrue(op.is_real) else: self.assertEqual(op.is_integer, None) self.assertTrue(op.is_real) class TestDimConstraints(TestCase): @skipIfTorchDynamo("mark_dynamic not supported") def test_simplify_max_1_0(self): x = torch.rand(10) torch._dynamo.mark_dynamic(x, 0, max=20, min=5) @torch.compile(fullgraph=True) def func(x, v): # test that statically_known_true if (v == 0 or v == 1) and not statically_known_true( max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2) ): raise AssertionError("error") if max(v, (-1 + x.size()[0] // 2)) == (-1 + x.size()[0] // 2): return x * 400 else: return (x * 10) * 100 # testing that this does not throw constraint violation error. self.assertEqual(func(x, 1), x * 400) self.assertEqual(func(x, 0), x * 400) def test_dim_constraints_reduce_congruences_simple(self): from sympy import Symbol s = Symbol("s", positive=True, integer=True) dim_constraints = DimConstraints({}, {}, set(), {}) dim_constraints._congruences[s] = { (s / 2) % 2, (s / 2) % 8, (s / 2) % 4, s % 2, ((s / 16) + 2) % 4, } congruences = dim_constraints._reduce_congruences() self.assertEqual(congruences[s], {(s + 32) % 64}) def test_dim_constraints_reduce_inequalities_simple(self): from sympy import Eq, Interval, Ne, Symbol from sympy.solvers.inequalities import reduce_inequalities s = Symbol("s", positive=True, integer=True) exprs = { s >= 2, Ne(8 * s, 16), Ne(s / 2, 1), Ne(16 * s, 32), s < 16, Ne(s, 2), s / 2 < 16, s / 2 > 1, s / 2 >= 2, Ne(3 * s / 2, 3), } solution = reduce_inequalities(exprs, s).as_set() self.assertEqual(solution, Interval.Ropen(4, 16)) exprs.add(Eq(s / 2, 4)) solution = reduce_inequalities(exprs, s).as_set() self.assertEqual(solution, {8}) def test_dim_constraints_reduce_inequalities_error(self): from collections import defaultdict from sympy import Symbol from sympy.solvers.inequalities import reduce_inequalities from torch._dynamo.source import ( LocalSource, TensorProperty, TensorPropertySource, ) from torch.fx.experimental.symbolic_shapes import DynamicDimConstraintPrinter s0 = Symbol("s0", positive=True, integer=True) exprs = { 4 * s0**3 - 4 * s0**2 + s0 <= 2147483647, s0 >= 2, s0**3 <= 2147483647, s0 <= 2147483647, } answer = reduce_inequalities(exprs, s0) symbol_to_source = defaultdict(list) symbol_to_source[s0].append( TensorPropertySource( base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0 ) ) dcp = DynamicDimConstraintPrinter(symbol_to_source, {}) with self.assertRaisesRegex( AssertionError, "Unknown symbol.*created by constraints solver", ): dcp.doprint(answer) def test_dim_constraints_solve_full(self): from sympy import Eq, Integer, Ne, Symbol from torch._dynamo.source import ( LocalSource, TensorProperty, TensorPropertySource, ) src0 = TensorPropertySource( base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=0 ) src2 = TensorPropertySource( base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=0 ) src3 = TensorPropertySource( base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=0 ) src4 = TensorPropertySource( base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=0 ) src1 = TensorPropertySource( base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=2 ) src7 = TensorPropertySource( base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, idx=3 ) src5 = TensorPropertySource( base=LocalSource(local_name="a"), prop=TensorProperty.SIZE, 
idx=1 ) src8 = TensorPropertySource( base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=1 ) src6 = TensorPropertySource( base=LocalSource(local_name="c"), prop=TensorProperty.SIZE, idx=1 ) src9 = TensorPropertySource( base=LocalSource(local_name="d"), prop=TensorProperty.SIZE, idx=1 ) src10 = TensorPropertySource( base=LocalSource(local_name="e"), prop=TensorProperty.SIZE, idx=1 ) src11 = TensorPropertySource( base=LocalSource(local_name="f"), prop=TensorProperty.SIZE, idx=1 ) src12 = TensorPropertySource( base=LocalSource(local_name="b"), prop=TensorProperty.SIZE, idx=2 ) s0 = Symbol("s0", positive=True, integer=True) s1 = Symbol("s1", positive=True, integer=True) s5 = Symbol("s5", positive=True, integer=True) s6 = Symbol("s6", positive=True, integer=True) symbol_to_source = { s0: [src0, src2, src3, src4], s1: [src1, src7], s5: [src5, src8], s6: [src6, src9, src10], } var_to_val = {s0: 8, s1: 96, s5: 22, s6: 21} marked_dynamic = {s0, s1, s5, s6} dim_constraints = DimConstraints( symbol_to_source, var_to_val, marked_dynamic, {} ) dim_constraints.add_equality(src2, s0) dim_constraints.add_equality(src3, s0) dim_constraints.add_equality(src4, s0) dim_constraints.add_equality(src7, s1) dim_constraints.add_equality(src8, s5) dim_constraints.add_equality(src9, s6) dim_constraints.add_equality(src10, s6) dim_constraints.add_equality(src11, Integer(1)) dim_constraints.add_equality(src12, Integer(3)) dim_constraints.add(s1**2 <= 2147483647) dim_constraints.add(32 * s1**2 <= 2147483647) dim_constraints.add(s0 < 16) dim_constraints.add(Eq(Mod(s1, 2), 0)) dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) dim_constraints.add(Ne((FloorDiv(s1, 2)) ** 2, 1)) dim_constraints.add(32 * (FloorDiv(s1, 2)) ** 2 <= 2147483647) dim_constraints.add((FloorDiv(s1, 2)) ** 2 > 1) dim_constraints.add(Ne(FloorDiv(s1, 2), 1)) dim_constraints.add( 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 + 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) + 64 <= 2147483647 ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) + 1, 1, ) ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 2) + 1, 1)) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 2)) ** 2 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 2)) + 1 > 1 ) dim_constraints.add( 128 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 + 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) + 128 <= 2147483647 ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) + 1, 1, ) ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 4) + 1, 1)) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 4)) ** 2 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 4)) + 1 > 1 ) dim_constraints.add( 256 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 + 512 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) + 256 <= 2147483647 ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) + 1, 1, ) ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1)) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 + 2 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) + 1 > 1 ) dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 3) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 <= 2147483647 ) 
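# (The recurring FloorDiv((FloorDiv(s1, 2) - 1), 8) terms presumably reflect repeated
# downsampling of the s1 dimension in the originally traced model; here they are just
# more inputs for dim_constraints.solve() further below.)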
dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 0) dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 1) dim_constraints.add( Ne( 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * s0, 0, ) ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 1, ) ) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 0, ) ) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 0 ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 0)) dim_constraints.add( 1 < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, -1)) dim_constraints.add( Ne( 60 * s0 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * s0 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * s0, 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120, ) ) dim_constraints.add( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120 > 0 ) dim_constraints.add( Eq( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 * (Mod(s0, 2)) - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) * Mod(s0, 2) + 60 * (Mod(s0, 2)), 0, ) ) dim_constraints.add( Ne( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120, 0, ) ) dim_constraints.add( Ne( 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, 2) * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 0, ) ) dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) dim_constraints.add( Ne( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60, 0, ) ) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 >= 0 ) dim_constraints.add( 1 < 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) ) dim_constraints.add(Ne(16 * s0, 32)) dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) dim_constraints.add(Ne(16 * s0, 32)) dim_constraints.add(Eq(16 * (Mod(s0, 2)), 0)) dim_constraints.add(FloorDiv(s0, 2) >= 2) dim_constraints.add(Ne(FloorDiv(s0, 2), 1)) dim_constraints.add(1 < FloorDiv(s0, 2)) dim_constraints.add(Ne(s0, 2)) dim_constraints.add( 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) >= 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 ) dim_constraints.add( 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, 2) * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))) > 0 ) dim_constraints.add( Ne( 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, 2) * FloorDiv(s0, (FloorDiv(s0, 2))) * 
FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), 3 * (FloorDiv(s0, 2)) * (FloorDiv(s0, (FloorDiv(s0, 2)))), ) ) dim_constraints.add( Ne( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20, 0, ) ) dim_constraints.add( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20 >= 0 ) dim_constraints.add( Ne( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20, 20, ) ) dim_constraints.add( Ne( 20 * ( Mod( 1, (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, ) ), 0, ) ) dim_constraints.add( Ne( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) * ( Mod( 1, (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), ) ) - 20 * Mod( 1, (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), ), 0, ) ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 1)) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 >= 1 ) dim_constraints.add( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20 >= 0 ) dim_constraints.add( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20 >= 1 ) dim_constraints.add( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20 >= 2 ) dim_constraints.add( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20 > 1 ) dim_constraints.add( 20 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 40 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 20 < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 ) dim_constraints.add( Ne( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60, 60, ) ) dim_constraints.add( Ne( FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, ) ) dim_constraints.add( Eq( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) * ( Mod( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 1, ) ) - Mod( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1) + 1 / (FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1), 1, ), 0, ) ) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, ) ) dim_constraints.add(Ne(8 * s0, 16)) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 >= (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 ) dim_constraints.add( 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) <= 
2147483647 ) dim_constraints.add( 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 90 <= 2147483647 ) dim_constraints.add(FloorDiv(s0, 2) < 16) dim_constraints.add(FloorDiv(s0, 2) > 1) dim_constraints.add( Ne( 90 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 180 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 90 * (FloorDiv(s0, 2)), 0, ) ) dim_constraints.add( 1 < 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 90 ) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 2 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1 > 1 ) dim_constraints.add( 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, (FloorDiv(s0, 2))) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, (FloorDiv(s0, 2)))) > 1 ) dim_constraints.add( Ne( 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, 2)), 0, ) ) dim_constraints.add( 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 90 > 1 ) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 > 1 ) dim_constraints.add( Ne( 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, 2)), 3 * (FloorDiv(s0, 2)), ) ) dim_constraints.add( 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 * (FloorDiv(s0, 2)) > 0 ) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 > 0 ) dim_constraints.add( Ne( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120, 0, ) ) dim_constraints.add( 1 < 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120 ) dim_constraints.add( Ne( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120, 6, ) ) dim_constraints.add( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120 > 0 ) dim_constraints.add( Ne( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120, 0, ) ) dim_constraints.add( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120 <= 2147483647 ) dim_constraints.add( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120 <= 20480 ) dim_constraints.add( Ne( 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 90, 0, ) ) dim_constraints.add( 120 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 240 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 120 > 1 ) dim_constraints.add( 90 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 180 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 90 <= 20480 ) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 120 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 60 <= 20480 ) dim_constraints.add( Ne( 240 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 480 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 240, 0, ) ) dim_constraints.add(Eq(6 * s5, 132)) dim_constraints.add(Eq(4, FloorDiv(s0, 2))) dim_constraints.add(Eq(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1, 4)) dim_constraints.add( Ne( 64 * (FloorDiv(s0, 2)) * 
(FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 128 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 64 * (FloorDiv(s0, 2)), 0, ) ) dim_constraints.add( 1 < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 64 ) dim_constraints.add( 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 64 <= 2147483647 ) dim_constraints.add( 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 128 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 64 > 1 ) dim_constraints.add( 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 62 <= 2147483647 ) dim_constraints.add( Ne( 62 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 124 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 62 * (FloorDiv(s0, 2)), 0, ) ) dim_constraints.add( 1 < 62 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 124 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 62 ) dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) dim_constraints.add(Ne(3 * (FloorDiv(s0, 2)), 3)) dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) dim_constraints.add(Eq(4, FloorDiv(s0, 2))) dim_constraints.add(Eq(FloorDiv(s0, 2), 4)) dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 1 >= 3) dim_constraints.add( 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 576 <= 2147483647 ) dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 0) dim_constraints.add(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3 >= 1) dim_constraints.add( Ne( 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 576 * (FloorDiv(s0, 2)), 0, ) ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9, 1, ) ) dim_constraints.add( Ne( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9, 0, ) ) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 >= 0 ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 0)) dim_constraints.add( 1 < 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 576 ) dim_constraints.add(Ne(FloorDiv((FloorDiv(s1, 2) - 1), 8) - 3, 1)) dim_constraints.add( Ne( 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 576 * (FloorDiv(s0, 2)), 256, ) ) dim_constraints.add( Eq( 64 * ( Mod( (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 * (FloorDiv(s0, 2)), 4, ) ), 0, ) ) dim_constraints.add( Eq( FloorDiv(s0, 2), FloorDiv( ( (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 * (FloorDiv(s0, 2)) ), 4, ), ) ) dim_constraints.add( Eq( FloorDiv( ( (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 * (FloorDiv(s0, 2)) ), 4, ), FloorDiv(s0, 2), ) ) dim_constraints.add( Ne(64 * (Mod(FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4)), 0) ) dim_constraints.add( Eq( 64 * ( Mod( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 1, 4, ) ), 0, ) ) dim_constraints.add( 64 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 384 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 
1), 8) + 576 * (FloorDiv(s0, 2)) > 0 ) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 >= 1 ) dim_constraints.add( Eq( 64 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 384 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 576, 256, ) ) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 540 <= 2147483647 ) dim_constraints.add( Ne( 60 * (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 360 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 540 * (FloorDiv(s0, 2)), 0, ) ) dim_constraints.add( 1 < 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 540 ) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 <= 2147483647 ) dim_constraints.add( Ne( (FloorDiv(s0, 2)) * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv(s0, 2) * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 * (FloorDiv(s0, 2)), 0, ) ) dim_constraints.add( 1 < (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 ) dim_constraints.add( (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 6 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 9 > 1 ) dim_constraints.add( 60 * (FloorDiv((FloorDiv(s1, 2) - 1), 8)) ** 2 - 360 * FloorDiv((FloorDiv(s1, 2) - 1), 8) + 540 > 1 ) dim_constraints.add(s0 >= 2) dim_constraints.add(s1 >= 2) dim_constraints.add(s6 >= 2) dim_constraints.add(s5 >= 2) dim_constraints.solve() self.assertEqual( dim_constraints._static_results, { "L['c'].size()[0] == 8", "L['d'].size()[0] == 8", "L['a'].size()[2] == 96", "L['f'].size()[1] == 1", "L['a'].size()[3] == 96", "L['b'].size()[2] == 3", "L['b'].size()[1] == 22", "L['b'].size()[0] == 8", "L['a'].size()[1] == 22", "L['a'].size()[0] == 8", }, ) self.assertEqual( dim_constraints._dynamic_results, { "2 <= L['c'].size()[1]", "L['d'].size()[1] == L['c'].size()[1]", "L['e'].size()[1] == L['c'].size()[1]", }, ) class TestGuardsExpressions(TestCase): """ Tests the guards-related methods used by the inductor FX graph cache. """ def test_guards_gt_lt(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 6) s1 = create_symint(shape_env, 7) s2 = create_symint(shape_env, 5) guard_int(sym_int(s0 > 5)) guard_int(sym_int(s0 < 7)) guards = shape_env.produce_guards_expression([s0]) self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)])) self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s2)])) def test_guards_float_print(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 3) guard_bool(2 / s0 == 2 / 3) guards = shape_env.produce_guards_expression([s0]) self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) @skipIfTorchDynamo("Not a TorchDynamo suitable test") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_guard_or_true(self): from torch.fx.experimental.symbolic_shapes import guard_or_true def func(a, b): x = a.item() if guard_or_true(x == 1): return b * 10 else: return b * 20 # eager. self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])) self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])) # compile with unbacked. 
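# With x unbacked (it comes from .item() under capture_scalar_outputs), x == 1 is
# data-dependent, so guard_or_true is expected to fall back to True and the traced
# graph unconditionally returns b * 10, as the checks below verify.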
unbacked_func = torch.compile(func, dynamic=True, fullgraph=True) a = torch.tensor([1]) b = torch.tensor([1]) unbacked_func(a, b) # always return b*10 self.assertEqual( unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]) ) self.assertEqual( unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([10]) ) # Test that statically known true works. def func2(a, b): x = a.item() if guard_or_true(x != x): return b * 10 else: return b * 20 unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True) a = torch.tensor([1]) b = torch.tensor([1]) unbacked_func2(a, b) # always return b*20 self.assertEqual( unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) ) self.assertEqual( unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]) ) # Test backed_size_oblivious with torch.fx.experimental._config.patch("backed_size_oblivious", True): def func3(a, b): if guard_or_true(a.size()[0] != 9): return b * 10 else: return b * 20 compiled = torch.compile(func3, dynamic=True, fullgraph=True) a = torch.rand(9, 2) b = torch.rand(3, 4) self.assertEqual(func3(a, b), b * 20) self.assertEqual(compiled(a, b), b * 10) @skipIfTorchDynamo("Not a TorchDynamo suitable test") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_guard_or_false(self): from torch.fx.experimental.symbolic_shapes import guard_or_false def func(a, b): x = a.item() if guard_or_false(x == 1): return b * 10 else: return b * 20 # eager. self.assertEqual(func(torch.tensor([1]), torch.tensor([1])), torch.tensor([10])) self.assertEqual(func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20])) # compile with unbacked. unbacked_func = torch.compile(func, dynamic=True, fullgraph=True) a = torch.tensor([1]) b = torch.tensor([1]) unbacked_func(a, b) # always return b*20 self.assertEqual( unbacked_func(torch.tensor([1]), torch.tensor([1])), torch.tensor([20]) ) self.assertEqual( unbacked_func(torch.tensor([2]), torch.tensor([1])), torch.tensor([20]) ) # Test that statically known true works. def func2(a, b): x = a.item() if guard_or_false(x == x): return b * 10 else: return b * 20 unbacked_func2 = torch.compile(func2, dynamic=True, fullgraph=True) a = torch.tensor([1]) b = torch.tensor([1]) unbacked_func2(a, b) # always return b*10 self.assertEqual( unbacked_func2(torch.tensor([1]), torch.tensor([1])), torch.tensor([10]) ) self.assertEqual( unbacked_func2(torch.tensor([2]), torch.tensor([1])), torch.tensor([10]) ) # Test backed_size_oblivious with torch.fx.experimental._config.patch("backed_size_oblivious", True): def func3(a, b): if guard_or_false(a.size()[0] == 9): return b * 10 else: return b * 20 compiled = torch.compile(func3, dynamic=True, fullgraph=True) a = torch.rand(9, 2) b = torch.rand(3, 4) self.assertEqual(func3(a, b), b * 10) self.assertEqual(compiled(a, b), b * 20) def test_guards_float_div(self): shape_env = ShapeEnv() s0 = create_symint(shape_env, 8) s1 = create_symint(shape_env, 7) guard_int(sym_int(s0 / 2.0)) guards = shape_env.produce_guards_expression([s0]) self.assertIn("math.trunc(", guards) self.assertIn("float(", guards) self.assertTrue(shape_env.evaluate_guards_expression(guards, [hint_int(s0)])) self.assertFalse(shape_env.evaluate_guards_expression(guards, [hint_int(s1)])) @skipIfTorchDynamo("Attempt to trace generator") @torch.fx.experimental._config.patch("use_duck_shape", False) def test_size_comparison_no_recompile(self): """ Test that size comparisons don't cause recompilation. 
When comparing x.size() == b.size() with different sizes, the compiled function should only compile once. We should not guard in sizes of the inner elements. """ cnt = CompileCounter() @torch.compile(fullgraph=True, dynamic=True, backend=cnt) def f(x, b): if x.size() == b.size(): return x return x * 2 # First call: shapes differ (1, 2) vs (2, 4, 9), so if branch is False f(torch.rand(10, 2), torch.rand(20, 4, 9)) # Second call: shapes differ again (1, 2) vs (1, 4, 9), so if branch is False f(torch.rand(10, 2), torch.rand(10, 4, 9)) # Should only compile once despite different input shapes self.assertEqual( cnt.frame_count, 1, f"Expected 1 compilation, got {cnt.frame_count}. " f"Size comparison should not cause recompilation.", ) def test_remove_symbols_without_guarding(self): from torch._functorch.partitioners import _remove_symbols_without_guarding shape_env = ShapeEnv() x = create_fake_tensor_with_dynamic_size( torch.randn(5, 8), shape_env, dynamic_sizes=[ DimDynamic.DYNAMIC, DimDynamic.DYNAMIC, ], dynamic_strides=[ DimDynamic.INFER_STRIDE, DimDynamic.INFER_STRIDE, ], ) self.assertEqual(f"{x.stride()}", "(s49, 1)") self.assertEqual(f"{x.shape}", "torch.Size([s26, s49])") x_clean = _remove_symbols_without_guarding(x, 4096) self.assertEqual(f"{x_clean.stride()}", "(8, 1)") self.assertEqual(f"{x_clean.shape}", "torch.Size([5, 8])") def custom_pass(graph: torch.fx.Graph) -> torch.fx.Graph: for node in graph.nodes: if node.name == "arg3_1": assert node.meta["val"].size()[0] == 2 return graph class TestUnbacked(TestCase): @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_neq_assert(self, backend): @torch.compile(fullgraph=True, backend=backend) def func(a): torch._check(a.item() != 5) return a.item() * 10 func(torch.tensor([100])) with self.assertRaises(RuntimeError): func(torch.tensor([5])) # Test a situation where we generate a runtime assert i.e: u1==s1, then we specialize s1 # later on to a constant. 
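# Rough sketch of the first call below: with x of shape (2, 50) and y holding 50,
# u0 = y.item() is unbacked and torch._check(u0 + s0 + s1 == 102) records a runtime
# assert; the later `assert s0 == 2` specializes s0, and that assert must still fire
# when y holds 51.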
@torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_post_specialize_runtime_assert1(self, backend): @torch.compile(dynamic=True, backend=backend) def func(x, y): u0 = y.item() s0 = x.size()[0] s1 = x.size()[1] torch._check(u0 + s0 + s1 == 102) assert s0 == 2 return x * 10 func(torch.rand(2, 50), torch.tensor([50])) with self.assertRaises(RuntimeError): func(torch.rand(2, 50), torch.tensor([51])) @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._inductor.config.patch(post_grad_custom_pre_pass=custom_pass) @parametrize("backend", ["inductor", "eager"]) def test_post_specialize_runtime_assert2(self, backend): @torch.compile(dynamic=True, backend=backend) def func(x, y): u0 = y.item() s0 = x.size()[0] s1 = x.size()[1] torch._check(u0 + s0 + s1 == 102) return x * 10 func(torch.rand(2, 50), torch.tensor([50])) with self.assertRaises(RuntimeError): func(torch.rand(2, 50), torch.tensor([51])) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_sym_or_assert(self, backend): @torch.compile(fullgraph=True, backend=backend) def func(a, b): torch._check(operator.or_(a.item() == 5, b.item() == 5)) return a.item() * 10 func(torch.tensor([5]), torch.tensor([100])) func(torch.tensor([100]), torch.tensor([5])) def test_has_free_symbols(self): self.assertFalse(has_free_symbols(sympy.S.true)) self.assertFalse(has_free_symbols(sympy.Max(1, 10, evaluate=False))) self.assertFalse(has_free_symbols(sympy.sympify("1"))) self.assertFalse(has_free_symbols(sympy.sympify("1.1"))) self.assertTrue(has_free_symbols(sympy.sympify("a"))) self.assertTrue(has_free_symbols(sympy.sympify("a*2"))) self.assertTrue(has_free_symbols(sympy.sympify("a+b"))) @skipIfTorchDynamo("https://github.com/pytorch/pytorch/issues/156135") @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) def test_deferred_sym_eq_assert(self, backend): @torch.compile(fullgraph=True, backend=backend) def func(a, b): torch._check(b.item() == 5) return a * 10 func(torch.tensor([5]), torch.tensor([5])) with self.assertRaises(RuntimeError): func(torch.tensor([100]), torch.tensor([1])) @torch._dynamo.config.patch("capture_scalar_outputs", True) @parametrize("backend", ["inductor", "eager"]) @skipIfTorchDynamo("mark_unbacked is not traceable") def test_deferred_with_unbacked_input(self, backend): @torch.compile(fullgraph=True, dynamic=True, backend=backend) def func(a, b): torch._check(a.size()[0] == b.size()[0]) return a * 10 a = torch.rand(1, 1) b = torch.rand(1, 1) torch._dynamo.decorators.mark_unbacked(a, 0) torch._dynamo.decorators.mark_unbacked(b, 0) func(a, b) # inductor adds the check sometimes itself so it will be reflected # as AssertionError. 
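# i.e. depending on the backend, the failed size equality below may surface either as
# the torch._check RuntimeError or as an inductor-generated AssertionError.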
with self.assertRaises((AssertionError, RuntimeError)): func(a, torch.rand(2, 1)) @pytest.mark.xfail(reason="https://github.com/pytorch/pytorch/issues/163785") @skipIfTorchDynamo("mark_unbacked is not traceable") def test_do_not_guard_unbacked_inputs(self): @torch.compile(fullgraph=True, dynamic=True, backend="inductor") def func(a, b): a.expand(b.shape) return a * 10 a = torch.rand(1, 1) b = torch.rand(1, 1) torch._dynamo.decorators.mark_unbacked(a, 0) torch._dynamo.decorators.mark_unbacked(a, 1) torch._dynamo.decorators.mark_unbacked(b, 0) torch._dynamo.decorators.mark_unbacked(b, 1) log_stream, ctx = logs_to_string("torch._dynamo.guards", "guards") with ctx(): func(a, b) func(torch.rand(4, 5), torch.rand(4, 5)) guards = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertFalse("SYMBOLIC_SHAPE_GUARD" in guards) @skipIfTorchDynamo("mark_unbacked is not traceable") def test_div_unbacked_eq_input_tensors(self): @torch.compile(fullgraph=True) def func(a, b): x = a.size()[0] y = b.size()[0] torch._check(x == y) if x // y == 1: a = a * 10 if 2 * x // y == 2: a = a * 20 return a a = torch.randn(10, 10) b = torch.randn(10, 20) torch._dynamo.decorators.mark_unbacked(a, 0) torch._dynamo.decorators.mark_unbacked(b, 0) func(a, b) @torch.compiler.config.patch(unbacked_sources="L['x'],L['y']") def test_div_unbacked_eq_input_ints(self): @torch.compile(fullgraph=True) def func(x, y): a = torch.rand(1) torch._check(x == y) if x // y == 1: a = a * 10 if 2 * x // y == 2: a = a * 20 return a func(10, 10) @skipIfTorchDynamo("mark_unbacked is not traceable") @torch.compiler.config.patch(unbacked_sources="L['y']") def test_div_unbacked_eq_globals(self): tensor = torch.rand(10, 44) y = 10 @torch.compile(fullgraph=True, dynamic=True) def func(): a = torch.rand(1) x = tensor.size()[0] torch._check(x == y) if x // y == 1: a = a * 10 if 2 * x // y == 2: a = a * 20 return a torch._dynamo.decorators.mark_unbacked(tensor, 0) func() @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_div_unbacked_eq_item(self): @torch.compile(fullgraph=True) def func(a, b): x = a.item() y = b.item() torch._check(x == y) # TODO we should not need those torch checks. torch._check(x // y == 1) torch._check(2 * x // y == 2) if x // y == 1: a = a * 10 if 2 * x // y == 2: a = a * 20 return a a = torch.tensor([1]) b = torch.tensor([1]) func(a, b) class TestUbackedOps(TestCase): @fresh_cache() @skipIfTorchDynamo("not allowed to trace mark_unbacked") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_reshape1(self): cnt = CompileCounterWithBackend("inductor") # Reshape happens in place reshape (no-clone) # reshape u1 -> (u0*u0) def func(x, y): f = y.item() t1 = x.view((f, f)) t2 = x.reshape((f, f)) t3 = torch._ops.ops.aten.view_copy(x, (f, f)) return t1 * 10, t2 * 10, t3 compiled_func = torch.compile( fullgraph=True, backend=cnt, dynamic=True, )(func) # create a non-contiguous with data being even numbers in [0:cnt-1] # and reshape it into sqrt(cnt)*sqrt(cnt) def make_non_contiguous_tensor_and_test(cnt): # create a non-contiguous tensor x that is skipping odd indices. 
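# For example, with cnt=4: torch.arange(8).as_strided((4,), (2,)) views the storage as
# tensor([0, 2, 4, 6]) with stride 2, so the reshape cannot assume a contiguous input.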
x = torch.arange(cnt * 2) x = x.as_strided((x.size()[0] // 2,), (2,)) torch._dynamo.decorators.mark_unbacked(x, 0) sz = torch.tensor([int(math.sqrt(cnt))]) compiled_result = compiled_func(x, sz) eager_result = func(x, sz) self.assertEqual(compiled_result, eager_result) log_stream, ctx = logs_to_string( "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): make_non_contiguous_tensor_and_test(4) aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertExpectedInline( aot_graphs, """\ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"): ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None view: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]) view_1: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]) view_2: "i64[u0, u0][s7*u0, s7]cpu" = torch.ops.aten.view.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None mul_11: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None mul_14: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None return (mul_11, mul_14, clone)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) make_non_contiguous_tensor_and_test(49) self.assertEqual(cnt.frame_count, 1) # Pass in a contiguous tensor, it will recompile due to stride being 1 (0/1 specialization). # marking strides unbacked would have avoided the recompilation here. 
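# Roughly: the first graph was traced with a symbolic stride s7, while the contiguous
# torch.arange(100) below has stride 1, which gets 0/1-specialized, hence the
# frame_count bump to 2.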
x = torch.arange(100) torch._dynamo.decorators.mark_unbacked(x, 0) log_stream, ctx = logs_to_string( "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): compiled_result = compiled_func(x, torch.tensor([10])) eager_result = func(x, torch.tensor([10])) self.assertEqual(compiled_result, eager_result) self.assertEqual(cnt.frame_count, 2) aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertExpectedInline( aot_graphs, """\ def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"): ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2 eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None view: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]) view_1: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]) view_2: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.view.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None clone: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.clone.default(view_2); view_2 = None mul_6: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None mul_9: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten.mul.Tensor(view_1, 10); view_1 = None return (mul_6, mul_9, clone)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) x = torch.arange(25) compiled_result = compiled_func(x, torch.tensor([5])) eager_result = func(x, torch.tensor([5])) self.assertEqual(cnt.frame_count, 2) @skipIfTorchDynamo("not allowed to trace mark_unbacked") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_reshape2(self): cnt = CompileCounterWithBackend("inductor") # This reshape requires a clone when the input is not contiguous and we can't compute strides. # reshape (u2, u3) -> (u0, u1) def func(x, y): u0, u1 = y.tolist() result1 = torch.reshape(x, (u0, u1)) return result1 * 10 compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) x = torch.randn(10, 10) # make x not contiguous. 
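# After the in-place transpose below, x has size (10, 10) but strides (1, 10), so
# reshaping to (u0, u1) generally requires a clone (visible in the expected graph as
# aten.clone with contiguous_format).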
x = x.t_() torch._dynamo.decorators.mark_unbacked(x, 0) torch._dynamo.decorators.mark_unbacked(x, 1) log_stream, ctx = logs_to_string( "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): result_eager = func(x, torch.tensor([5, 20])) result_compiled = compiled_func(x, torch.tensor([5, 20])) self.assertEqual(result_compiled, result_eager) self.assertEqual(cnt.frame_count, 1) aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertExpectedInline( aot_graphs, """\ def forward(self, arg0_1: "i64[2][1]cpu", arg1_1: "Sym(u2)", arg2_1: "Sym(u3)", arg3_1: "f32[u2, u3][1, u2]cpu"): ge_1: "Sym(u2 >= 0)" = arg1_1 >= 0 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None ge_3: "Sym(u3 >= 0)" = arg2_1 >= 0 _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None ge_4: "Sym(u0 >= 0)" = _local_scalar_dense >= 0 _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_4, "Runtime assertion failed for expression u0 >= 0 on node 'ge_2'"); ge_4 = _assert_scalar_2 = None sym_sum: "Sym(u0 + 1)" = torch.sym_sum((1, _local_scalar_dense)) gt: "Sym(u0 + 1 > 0)" = sym_sum > 0; sym_sum = None _assert_scalar_3 = torch.ops.aten._assert_scalar.default(gt, "Runtime assertion failed for expression 0 < u0 + 1 on node 'gt'"); gt = _assert_scalar_3 = None select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None _local_scalar_dense_1: "Sym(u1)" = torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None ge_5: "Sym(u1 >= 0)" = _local_scalar_dense_1 >= 0 _assert_scalar_4 = torch.ops.aten._assert_scalar.default(ge_5, "Runtime assertion failed for expression u1 >= 0 on node 'ge_3'"); ge_5 = _assert_scalar_4 = None sym_sum_1: "Sym(u1 + 1)" = torch.sym_sum((1, _local_scalar_dense_1)) gt_1: "Sym(u1 + 1 > 0)" = sym_sum_1 > 0; sym_sum_1 = None _assert_scalar_5 = torch.ops.aten._assert_scalar.default(gt_1, "Runtime assertion failed for expression 0 < u1 + 1 on node 'gt_1'"); gt_1 = _assert_scalar_5 = None mul: "Sym(u2*u3)" = arg1_1 * arg2_1; arg1_1 = arg2_1 = None mul_1: "Sym(u0*u1)" = _local_scalar_dense * _local_scalar_dense_1 eq: "Sym(Eq(u2*u3, u0*u1))" = mul == mul_1; mul = mul_1 = None _assert_scalar_6 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u2*u3, u0*u1) on node 'eq'"); eq = _assert_scalar_6 = None clone: "f32[u2, u3][Max(1, u3), 1]cpu" = torch.ops.aten.clone.default(arg3_1, memory_format = torch.contiguous_format); arg3_1 = None view: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.view.default(clone, [_local_scalar_dense, _local_scalar_dense_1]); clone = _local_scalar_dense = _local_scalar_dense_1 = None mul_21: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(view, 10); view = None return (mul_21,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) result_eager = func(x, torch.tensor([2, 50])) result_compiled = compiled_func(x, torch.tensor([2, 50])) self.assertEqual(result_compiled, result_eager) self.assertEqual(cnt.frame_count, 1) x = torch.randn(4, 4).t_() result_eager = func(x, torch.tensor([2, 8])) result_compiled = compiled_func(x, torch.tensor([2, 8])) self.assertEqual(result_compiled, 
result_eager) self.assertEqual(cnt.frame_count, 1) # Pass a contiguous tensor. A recompilation will happen due to 0/1 speciialization on stride. log_stream, ctx = logs_to_string( "torch._functorch._aot_autograd.graph_capture", "aot_graphs" ) with ctx(): # This used to hit could guard on data-dependent expression Eq(10, u3) x.stride[0]==10. and x.size()=[u2, u3]. # but not anymore since we use contiguous_or_false . # We need a way to mark strides unbacked to avoid the recompilation here. x = torch.randn(10, 10) torch._dynamo.decorators.mark_unbacked(x, 0) torch._dynamo.decorators.mark_unbacked(x, 1) aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertExpectedInline( aot_graphs, """""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) result_compiled = compiled_func(x, torch.tensor([2, 50])) result_eager = func(x, torch.tensor([2, 50])) self.assertEqual(result_compiled, result_eager) self.assertEqual(cnt.frame_count, 2) x = torch.randn(4, 4) result_eager = func(x, torch.tensor([2, 8])) result_compiled = compiled_func(x, torch.tensor([2, 8])) self.assertEqual(result_compiled, result_eager) self.assertEqual(cnt.frame_count, 2) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_slice(self): from torch.fx.experimental.symbolic_shapes import statically_known_true # standard slice def f1(x, xs): u0, u1 = xs.tolist() # in this test we add the torch checks not to avoid DDE but to ensure # that we pick specific path during compilation. torch._check(u0 >= 0) torch._check(u0 <= x.size(0)) torch._check(u1 >= 0) torch._check(u1 <= x.size(0)) torch._check(u0 <= u1) out = x[u0:u1] assert statically_known_true(out.size(0) == (u1 - u0)) return out x, xs = torch.randn(10), torch.tensor([3, 6]) fn1 = torch.compile(f1, fullgraph=True, backend="inductor") self.assertEqual(fn1(x, xs).size(0), 3) self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs))) with self.assertRaises(RuntimeError): fn1(x, torch.tensor([-1, 5])) # known negative slice def f2(x, n): u0 = n.item() torch._check(u0 > 1) torch._check(u0 <= x.size(0)) out = x[-u0:] assert statically_known_true(out.size(0) == u0) return out x, n = torch.randn(10), torch.tensor([5]) fn2 = torch.compile(f2, fullgraph=True, backend="inductor") self.assertEqual(fn2(x, n).size(0), 5) self.assertTrue(torch.allclose(fn2(x, n), f2(x, n))) with self.assertRaises(RuntimeError): fn2(x, torch.tensor([-5])) # general case: no known info def f3(x, xs): u0, u1 = xs.tolist() return x[u0:u1] log_stream, ctx = logs_to_string( "torch._inductor.compile_fx", "post_grad_graphs" ) cnts = CompileCounterWithBackend("inductor") x, xs = torch.randn(10), torch.tensor([3, 6]) with ctx(): fn3 = torch.compile(f3, fullgraph=True, backend=cnts) xs = torch.tensor([-9, -1]) # negative case self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs))) xs = torch.tensor([-1000, 1000]) # out of bounds self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs))) xs = torch.tensor([2, -2]) # mixed self.assertTrue(torch.allclose(fn3(x, xs), f3(x, xs))) self.assertEqual(cnts.frame_count, 1) aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertExpectedInline( aot_graphs, """\ select: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 0) _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(select); select = None select_1: "i64[][]cpu" = torch.ops.aten.select.int(arg0_1, 0, 1); arg0_1 = None _local_scalar_dense_1: "Sym(u1)" = 
torch.ops.aten._local_scalar_dense.default(select_1); select_1 = None slice_1: "f32[u2][1]cpu" = torch.ops.aten.slice.Tensor(arg1_1, 0, _local_scalar_dense, _local_scalar_dense_1); arg1_1 = _local_scalar_dense = _local_scalar_dense_1 = None sym_size_int: "Sym(u2)" = torch.ops.aten.sym_size.int(slice_1, 0) ge_1: "Sym(u2 >= 0)" = sym_size_int >= 0 _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u2 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None le: "Sym(u2 <= 10)" = sym_size_int <= 10; sym_size_int = None _assert_scalar_1 = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression u2 <= 10 on node 'le'"); le = _assert_scalar_1 = None sym_storage_offset_default: "Sym(u3)" = torch.ops.aten.sym_storage_offset.default(slice_1) ge_2: "Sym(u3 >= 0)" = sym_storage_offset_default >= 0; sym_storage_offset_default = None _assert_scalar_2 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u3 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_2 = None return (slice_1,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._inductor.config.patch("cpp_wrapper", True) def test_unbacked_slice_cpp_wrapper(self): self.test_unbacked_slice() @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_slice_with_step(self): def f1(x, xs): u0, u1 = xs.tolist() out = x[u0:u1:5] return out x, xs = torch.randn(10), torch.tensor([2, -2]) fn1 = torch.compile(f1, fullgraph=True, backend="inductor") self.assertTrue(torch.allclose(fn1(x, xs), f1(x, xs))) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._inductor.config.patch("cpp_wrapper", True) def test_unbacked_slice_with_step_cpp_wrapper(self): self.test_unbacked_slice_with_step() @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_tensor_split(self): def f1(x, xs): xs = torch.tensor(xs.tolist()) return torch.tensor_split(x, xs) x = torch.randn(20) xs = torch.tensor([5, 10, 15]) fn = torch.compile(f1, fullgraph=True, backend="inductor") def compare(x, xs): for i, j in zip(f1(x, xs), fn(x, xs)): self.assertTrue(torch.allclose(i, j)) compare(x, xs) xs = torch.tensor([-15, 9, 10, 11]) compare(x, xs) xs = torch.tensor([-15, -10, -5, -2]) compare(x, xs) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._inductor.config.patch("cpp_wrapper", True) def test_tensor_split_cpp_wrapper(self): self.test_tensor_split() @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) def test_nonzero_slice(self): def f(x): nz = x.nonzero() return nz[:-1] x = torch.randn(3, 4) fn = torch.compile(f, fullgraph=True, backend="inductor") self.assertTrue(torch.allclose(f(x), fn(x))) y = torch.zeros(3, 4) self.assertTrue(torch.allclose(f(y), fn(y))) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) @torch._inductor.config.patch("cpp_wrapper", True) def test_nonzero_slice_cpp_wrapper(self): self.test_nonzero_slice() @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) def test_nonzero_select(self): def f(x): nz = x.nonzero() return nz[-1] + nz[0] x = torch.randn(3, 4) fn = torch.compile(f, 
fullgraph=True, backend="inductor") self.assertTrue(torch.allclose(f(x), fn(x))) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) @torch._inductor.config.patch("cpp_wrapper", True) def test_nonzero_select_cpp_wrapper(self): self.test_nonzero_select() @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) @torch._dynamo.config.patch("capture_dynamic_output_shape_ops", True) def test_padnd(self): import torch.nn.functional as F def f(x, xs, y): u0, u1 = xs.tolist() for u in [u0, u1]: torch._check(u >= 0) z = F.pad(x, (u0, u1, u0, u1)) return z @ y x = torch.randn(8, 8) xs = torch.tensor([2, 2]) y = torch.randn(12, 4) fn = torch.compile(f, fullgraph=True, backend="inductor") fn(x, xs, y) @unittest.skip("this test fails due to inductor/autograd issue #153041") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_non_contigious_reshape_failing(self): # reshape u1 -> (u0*u0) # this result in the tensor "i64[u0, u0][s7*u0, s7]. # reshape happens in place reshape (no-clone) def func(x, y): f = y.item() t1 = x.view((f, f)) t2 = x.reshape((f, f)) return t1, t2 # create a non-contiguous with data being even numbers in [0:cnt-1] def make_non_contiguous_tensor(cnt): # create a non-contiguous tensor x that is skipping odd indices. x = torch.arange(cnt * 2) x = x.as_strided((x.size()[0] // 2,), (2,)) return x x = make_non_contiguous_tensor(4) torch._dynamo.decorators.mark_unbacked(x, 0) compiled_func = torch.compile( fullgraph=True, backend="inductor", )(func) compiled_result = compiled_func(x, torch.tensor([2])) eager_result = func(x, torch.tensor([2])) self.assertEqual(compiled_result, eager_result) @skipIfTorchDynamo("not allowed to trace mark_unbacked") @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_invalid_view_unbacked_view(self): cnt = CompileCounterWithBackend("inductor") # This view (u2, u3) -> (u0, u1) can't happen in general unless we know that input is contiguous or we have # hints to to compute strides. def func(x, y): u0, u1 = y.tolist() result2 = x.view(u0, u1) * 10 return result2 compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) x = torch.randn(10, 10) # make x not contiguous. x = x.t_() torch._dynamo.decorators.mark_unbacked(x, 0) torch._dynamo.decorators.mark_unbacked(x, 1) with self.assertRaises(torch._dynamo.exc.UserError): # throws a data dependent error. compiled_func(x, torch.tensor([5, 20])) @skipIfTorchDynamo() def test_unbind_not_dynamic(self): cnt = CompileCounter() @torch.compile(fullgraph=True, dynamic=True, backend=cnt) def func(y): return y.unbind(dim=2), y * 10 func(torch.ones(5, 6, 7, 8)) self.assertEqual(cnt.frame_count, 1) # it can be dynamic in all dimensions except dim=2 func(torch.ones(4, 9, 7, 10)) self.assertEqual(cnt.frame_count, 1) func(torch.ones(5, 6, 8, 8)) func(torch.ones(5, 6, 9, 8)) self.assertEqual(cnt.frame_count, 3) @skipIfTorchDynamo("not allowed to trace mark_unbacked") @fresh_cache() def test_unbacked_contiguous(self): cnt = CompileCounterWithBackend("inductor") def func(x): contig = x.contiguous() return (contig + 1) * 100 compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) x = torch.randn(10, 10) # make x not contiguous. 
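# As in the reshape tests above, t_() leaves x with strides (1, 10); calling
# .contiguous() on such an input shows up as an explicit aten.clone in the post-grad
# graph checked below.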
x = x.t_() torch._dynamo.decorators.mark_unbacked(x, 0) torch._dynamo.decorators.mark_unbacked(x, 1) log_stream, ctx = logs_to_string( "torch._inductor.compile_fx", "post_grad_graphs" ) with ctx(): compiled_func(x) self.assertEqual(compiled_func(x), func(x)) y = torch.rand(20, 20).t() self.assertEqual(compiled_func(y), func(y)) self.assertEqual(cnt.frame_count, 1) output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() self.assertExpectedInline( output, """\ ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None clone: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.clone.default(arg2_1, memory_format = torch.contiguous_format); arg2_1 = None add_3: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(clone, 1); clone = None mul_6: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add_3, 100); add_3 = None return (mul_6,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) log_stream, ctx = logs_to_string( "torch._inductor.compile_fx", "post_grad_graphs" ) with ctx(): # recompilation will happen due to stride specialization. y = torch.rand(20, 20) torch._dynamo.decorators.mark_unbacked(y, 0) torch._dynamo.decorators.mark_unbacked(y, 1) self.assertEqual(compiled_func(y), func(y)) self.assertEqual(cnt.frame_count, 2) output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip() # No clone this time since input is contiguous. self.assertExpectedInline( output, """\ ge_1: "Sym(u0 >= 0)" = arg0_1 >= 0; arg0_1 = None _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u0 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None ge_3: "Sym(u1 >= 0)" = arg1_1 >= 0; arg1_1 = None _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_3, "Runtime assertion failed for expression u1 >= 0 on node 'ge_1'"); ge_3 = _assert_scalar_1 = None add: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.add.Tensor(arg2_1, 1); arg2_1 = None mul_5: "f32[u0, u1][Max(1, u1), 1]cpu" = torch.ops.aten.mul.Tensor(add, 100); add = None return (mul_5,)""", # noqa: B950 ignore_comments=True, ignore_empty_lines=True, ) @fresh_cache() @torch._dynamo.config.patch("capture_scalar_outputs", True) def test_unbacked_select_index(self): cnt = CompileCounterWithBackend("inductor") def func(x, y): u0 = y.item() return ( torch.select(x, 0, u0), torch.select(x, 1, u0), torch.select(x, 2, u0), ) compiled_func = torch.compile(fullgraph=True, backend=cnt, dynamic=True)(func) x = torch.rand(3, 3, 3) zero = torch.tensor([0]) pos = torch.tensor([1]) # code can handle both negative and positive indices. 
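# e.g. torch.select(x, 0, -1) picks the last slice just like x[-1], so the same
# compiled graph is expected to serve both pos and neg without recompiling
# (frame_count stays at 1 below).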
        neg = torch.tensor([-1])

        log_stream, ctx = logs_to_string(
            "torch._inductor.compile_fx", "post_grad_graphs"
        )
        with ctx():
            self.assertEqual(compiled_func(x, zero), func(x, zero))
        output = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            output,
            """\
_local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
select: "f32[s77, s77][s77, 1]cpu" = torch.ops.aten.select.int(arg2_1, 0, _local_scalar_dense)
select_1: "f32[s77, s77][s77**2, 1]cpu" = torch.ops.aten.select.int(arg2_1, 1, _local_scalar_dense)
select_2: "f32[s77, s77][s77**2, s77]cpu" = torch.ops.aten.select.int(arg2_1, 2, _local_scalar_dense); arg2_1 = _local_scalar_dense = None
return (select, select_1, select_2)""",  # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        self.assertEqual(compiled_func(x, pos), func(x, pos))
        self.assertEqual(compiled_func(x, neg), func(x, neg))
        self.assertEqual(cnt.frame_count, 1)

        def func2(x, y):
            u0, u1 = y.tolist()
            return torch.select(x, 0, u0 + u1)

        compiled_func2 = torch.compile(fullgraph=True, backend=cnt, dynamic=False)(
            func2
        )
        zero = torch.tensor([0, 0])
        pos = torch.tensor([1, 1])
        neg = torch.tensor([-1, -1])
        self.assertEqual(compiled_func2(x, pos), func2(x, pos))
        self.assertEqual(compiled_func2(x, neg), func2(x, neg))
        self.assertEqual(compiled_func2(x, zero), func2(x, zero))
        self.assertEqual(cnt.frame_count, 2)

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_select_2(self):
        class M(torch.nn.Module):
            def forward(self, x):
                nz = x.nonzero()
                return nz[-1]

        mod = M()
        x = torch.randn(4)
        self.assertEqual(torch.compile(mod)(x), mod(x))

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_select_index_with_check(self):
        def func3(x, y):
            u0 = y.item()
            # Test that taking the non-unbacked path also works fine.
            torch._check(u0 >= 0)
            return (torch.select(x, 1, u0),)

        compiled_func3 = torch.compile(
            fullgraph=True, backend="inductor", dynamic=True
        )(func3)
        x = torch.rand(3, 3, 3)
        zero = torch.tensor([0])
        pos = torch.tensor([1])
        print(compiled_func3(x, pos))
        self.assertEqual(compiled_func3(x, pos), func3(x, pos))
        self.assertEqual(compiled_func3(x, zero), func3(x, zero))

    @fresh_cache()
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    @torch._inductor.config.patch("cpp_wrapper", True)
    def test_unbacked_select_index_cpp_wrapper(self):
        self.test_unbacked_select_index()

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_select2(self):
        def f(idx, x):
            x = x.select(0, idx.item())
            return x @ x

        x = torch.randn(3, 3, 3)
        idx = torch.tensor(1, dtype=torch.int64)
        out = torch.compile(f)(idx, x)
        self.assertEqual(out, f(idx, x))

    def test_trunc_int_div_true(self):
        @torch.compile(backend="inductor", dynamic=True, fullgraph=True)
        def f(x, s13, s57, s77):
            torch._check(s13 >= 0)
            torch._check(s57 >= 0)
            torch._check(s77 >= 0)
            if int(s13 * ((s57 // s13) + (s77 // s13)) / s13) >= 1:
                return x * 2
            else:
                return x * 100

        # Ensure we compile this with no errors.
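        # For the concrete sizes used below: 4096 // 4 == 1024 and 3920 // 4 == 980,
        # so int(4 * (1024 + 980) / 4) == 2004 >= 1 and the x * 2 branch is taken.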
        x = torch.rand(10)
        f(x, 4, 4096, 3920)

    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_reshape3(self):
        def func(x):
            x = x.as_strided([x.size()[0], 1536], [2048, 1])
            result1 = x.view(x.size()[0], -1, 128)
            return result1 * 10

        compiled = torch.compile(fullgraph=True, backend="inductor")(func)
        x = torch.randn(10, 2048)
        torch._dynamo.decorators.mark_unbacked(x, 0)
        self.assertEqual(func(x), compiled(x))

    @fresh_cache()
    @skipIfTorchDynamo("not allowed to trace mark_unbacked")
    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_reshape_copy(self):
        cnt = CompileCounterWithBackend("inductor")

        # Reshape u1 -> (u0*u0). The reshape happens in place (no clone).
        def func(x, y):
            f = y.item()
            t3 = torch._ops.ops.aten._reshape_copy(x, (f, f))
            return t3

        compiled_func = torch.compile(
            fullgraph=True,
            backend=cnt,
            dynamic=True,
        )(func)

        # Create a non-contiguous tensor whose data is the even numbers
        # 0, 2, ..., 2 * (cnt - 1) and reshape it into sqrt(cnt) x sqrt(cnt).
        def make_non_contiguous_tensor_and_test(cnt):
            # Create a non-contiguous tensor x that skips the odd indices.
            x = torch.arange(cnt * 2)
            x = x.as_strided((x.size()[0] // 2,), (2,))
            torch._dynamo.decorators.mark_unbacked(x, 0)
            sz = torch.tensor([int(math.sqrt(cnt))])
            compiled_result = compiled_func(x, sz)
            eager_result = func(x, sz)
            self.assertEqual(compiled_result, eager_result)

        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.graph_capture", "aot_graphs"
        )
        with ctx():
            make_non_contiguous_tensor_and_test(4)
        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "Sym(s7)", arg3_1: "i64[u1][s7]cpu"):
        ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
        ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
        pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
        eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
        _reshape_copy: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten._reshape_copy.default(arg3_1, [_local_scalar_dense, _local_scalar_dense]); arg3_1 = _local_scalar_dense = None
        return (_reshape_copy,)""",  # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        make_non_contiguous_tensor_and_test(49)
        self.assertEqual(cnt.frame_count, 1)

        # Pass in a contiguous tensor; it will recompile due to the stride being 1
        # (0/1 specialization). Marking the strides unbacked would have avoided the
        # recompilation here.
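        # torch.arange(100) is contiguous (stride 1), unlike the strided inputs
        # above (stride 2), and sqrt(100) == 10, so the call below reshapes the
        # 100-element tensor to (10, 10).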
        x = torch.arange(100)
        torch._dynamo.decorators.mark_unbacked(x, 0)
        log_stream, ctx = logs_to_string(
            "torch._functorch._aot_autograd.graph_capture", "aot_graphs"
        )
        with ctx():
            compiled_result = compiled_func(x, torch.tensor([10]))
            eager_result = func(x, torch.tensor([10]))
            self.assertEqual(compiled_result, eager_result)
            self.assertEqual(cnt.frame_count, 2)
        aot_graphs = "\n".join(log_stream.getvalue().strip().split("\n")[4:]).strip()
        self.assertExpectedInline(
            aot_graphs,
            """\
def forward(self, arg0_1: "i64[1][1]cpu", arg1_1: "Sym(u1)", arg2_1: "i64[u1][1]cpu"):
        ge_1: "Sym(u1 >= 0)" = arg1_1 >= 0
        _assert_scalar = torch.ops.aten._assert_scalar.default(ge_1, "Runtime assertion failed for expression u1 >= 0 on node 'ge'"); ge_1 = _assert_scalar = None
        _local_scalar_dense: "Sym(u0)" = torch.ops.aten._local_scalar_dense.default(arg0_1); arg0_1 = None
        ge_2: "Sym(u0 >= 0)" = _local_scalar_dense >= 0
        _assert_scalar_1 = torch.ops.aten._assert_scalar.default(ge_2, "Runtime assertion failed for expression u0 >= 0 on node 'ge_1'"); ge_2 = _assert_scalar_1 = None
        pow_1: "Sym(u0**2)" = _local_scalar_dense ** 2
        eq: "Sym(Eq(u1, u0**2))" = arg1_1 == pow_1; arg1_1 = pow_1 = None
        _assert_scalar_2 = torch.ops.aten._assert_scalar.default(eq, "Runtime assertion failed for expression Eq(u1, u0**2) on node 'eq'"); eq = _assert_scalar_2 = None
        _reshape_copy: "i64[u0, u0][Max(1, u0), 1]cpu" = torch.ops.aten._reshape_copy.default(arg2_1, [_local_scalar_dense, _local_scalar_dense]); arg2_1 = _local_scalar_dense = None
        return (_reshape_copy,)""",  # noqa: B950
            ignore_comments=True,
            ignore_empty_lines=True,
        )

        x = torch.arange(25)
        compiled_result = compiled_func(x, torch.tensor([5]))
        eager_result = func(x, torch.tensor([5]))
        self.assertEqual(cnt.frame_count, 2)

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_item(self):
        def func():
            _x_ms = torch.tensor([True, False], dtype=torch.int64)
            _mask_ms = torch.zeros_like(_x_ms, dtype=torch.bool)
            _mask_ms[:1] = True
            var_node_2 = torch.masked_select(_x_ms, _mask_ms)
            var_node_0 = var_node_2.item()
            return var_node_0

        result_original = func()
        compiled_program = torch.compile(func, fullgraph=True, dynamic=True)
        result_compiled = compiled_program()
        self.assertEqual(result_original, result_compiled)

    def test_unbacked_item_set_item(self):
        def my_arithmetic(a, b):
            wrk = torch.zeros(a.size(0))
            for i in range(a.size(0)):
                idx = b[i].item()
                wrk[idx] += 1
            return wrk

        compiled = torch.compile(my_arithmetic, fullgraph=True, disable=False)
        a = torch.randn([9])
        b = torch.ones(9, dtype=torch.int32)
        compiled(a, b)
        self.assertEqual(compiled(a, b), my_arithmetic(a, b))

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_item_set_item2(self):
        def accumulate(X0, start):
            start = start.item()
            N = 3
            result = X0[start]
            for i in range(N):
                result += X0[start + 1 + i]
            return result

        compiled = torch.compile(accumulate, fullgraph=True)
        X0 = torch.randn(10, 10)
        self.assertEqual(
            accumulate(X0, torch.tensor([1])), compiled(X0, torch.tensor([1]))
        )

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_unbacked_item_set_item3(self):
        def func(x, y):
            u0 = y.item()
            x[u0] = 0
            return x

        compiled = torch.compile(func, fullgraph=True, disable=False)
        b = torch.tensor([0])
        a = torch.ones(9, dtype=torch.int32)
        compiled(a, b)
        self.assertEqual(compiled(a, b), func(a, b))

    @torch._dynamo.config.patch("capture_scalar_outputs", True)
    def test_select_scatter_unbacked_index(self):
        def func(x, y):
            u0 = y.item()
            # Create a scalar tensor to scatter into the selected index.
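            # select_scatter expects `src` to match the shape of the selected
            # slice; x is 1-D here, so a 0-d scalar tensor is the right source.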
            scalar_src = torch.tensor(42, dtype=x.dtype)
            return x.select_scatter(scalar_src, 0, u0)

        compiled = torch.compile(
            func, fullgraph=True, dynamic=True, backend="inductor"
        )
        b = torch.tensor([0])
        a = torch.ones(9, dtype=torch.int32)
        self.assertEqual(compiled(a, b), func(a, b))


instantiate_parametrized_tests(TestUnbacked)

if __name__ == "__main__":
    run_tests()