mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
# Feature Inductor sometimes uses `Identity` functions to group various terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we're matching an indexing expression to tell if it can be represented as a block pointer, that analysis should be invariant to `Identity` calls. This PR adds a few features to achieve this invariance. - Create a new expansion mode `expr.expand(identity=True)`, which removes all `Identity` functions from the expression. - Preprocess the expression with this expansion prior to pattern matching. - Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, as we need to initialize `V.graph` before we can access `V.graph.sizevars`. # Test plan This PR adds a few new unit tests: - Added a unit test specifically for `expr.expand(identity=True)`. - Added a new unit test module for the block pattern matcher. Tested that we can correctly match some example patterns containing Identity ops. I originally intended to add an end-to-end test compiling pointwise cat, and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000 Approved by: https://github.com/eellison, https://github.com/jansel
978 lines
33 KiB
Python
978 lines
33 KiB
Python
# Owner(s): ["oncall: pt2"]
|
|
|
|
import functools
|
|
import itertools
|
|
import math
|
|
import pickle
|
|
import sys
|
|
from typing import Callable
|
|
|
|
import sympy
|
|
|
|
import torch
|
|
import torch.fx as fx
|
|
from sympy.core.relational import is_ge, is_gt, is_le, is_lt
|
|
from torch.testing._internal.common_device_type import skipIf
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
run_tests,
|
|
TEST_Z3,
|
|
TestCase,
|
|
)
|
|
from torch.utils._sympy.functions import (
|
|
FloorDiv,
|
|
Identity,
|
|
OpaqueUnaryFn_cos,
|
|
simple_floordiv_gcd,
|
|
)
|
|
from torch.utils._sympy.interp import sympy_interp
|
|
from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity
|
|
from torch.utils._sympy.reference import (
|
|
PythonReferenceAnalysis,
|
|
ReferenceAnalysis,
|
|
TensorReferenceAnalysis,
|
|
)
|
|
from torch.utils._sympy.singleton_int import SingletonInt
|
|
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
|
from torch.utils._sympy.value_ranges import ValueRanges
|
|
from torch._inductor.bounds import ValueRangeAnalysis
|
|
|
|
|
|
# Names of the analysis methods (on ReferenceAnalysis / ValueRangeAnalysis /
# TensorReferenceAnalysis) that the tests below look up with getattr.
UNARY_OPS = [
    "reciprocal",
    "square",
    "abs",
    "neg",
    "exp",
    "log",
    "sqrt",
    "floor",
    "ceil",
]

# The bitwise ops are binary ops too; they get their own list because some
# tests treat them specially.
BITWISE_OPS = [
    "bitwise_and",
    "bitwise_or",
]

BINARY_OPS = [
    "truediv",
    "floordiv",
    # "truncdiv",  # TODO
    # NB: pow is float_pow
    "add",
    "mul",
    "sub",
    "pow",
    "pow_by_natural",
    "minimum",
    "maximum",
    "mod",
    # Splice the bitwise ops in last, so the two lists cannot drift apart.
    *BITWISE_OPS,
]

# Boolean-valued ops.
UNARY_BOOL_OPS = ["not_"]
BINARY_BOOL_OPS = ["or_", "and_"]
COMPARE_OPS = ["eq", "ne", "lt", "gt", "le", "ge"]
|
|
|
|
# Sample points for the value-range tests: small integers around zero, powers
# of two, a couple of primes, and large values up to sys.maxsize.
CONSTANTS = [
    # around zero
    -1,
    0,
    1,
    2,
    3,
    4,
    5,
    # powers of two
    8,
    16,
    32,
    64,
    # medium values (101 is prime)
    100,
    101,
    # large values
    2**24,
    2**32,
    2**37 - 1,
    sys.maxsize - 1,
    sys.maxsize,
]

# A smaller sample set for the quadratic (and worse) sweeps over pairs.
LESS_CONSTANTS = [-1, 0, 1, 2, 100]
|
|
# SymPy relational classes; the default parametrization set used by
# parametrize_relational_types.
RELATIONAL_TYPES = [
    sympy.Eq,
    sympy.Ne,
    sympy.Gt,
    sympy.Ge,
    sympy.Lt,
    sympy.Le,
]
|
|
|
|
|
|
def valid_unary(fn, v):
    """Return whether unary op ``fn`` should be evaluated at input ``v``.

    Rejects the inputs at which the op is undefined over the reals:
    ``log`` of a non-positive value, ``reciprocal`` of zero and ``sqrt`` of
    a negative value. Every other combination is accepted.
    """
    undefined = (
        (fn == "log" and v <= 0)
        or (fn == "reciprocal" and v == 0)
        or (fn == "sqrt" and v < 0)
    )
    return not undefined
|
|
|
|
|
|
def valid_binary(fn, a, b):
    """Return whether binary op ``fn`` should be evaluated at ``(a, b)``.

    Rejected combinations:

    * ``pow``: exponents above 4 (sympy would expand the power into a long
      product), non-positive bases (no imaginary numbers) and ``0**0``.
    * ``pow_by_natural``: exponents above 4 or below 0, and ``0**0``.
    * ``mod``: negative dividends or non-positive divisors.
    * ``div`` / ``truediv`` / ``floordiv``: a zero divisor.

    Any other op is always accepted.
    """
    if fn == "pow":
        # `a <= 0` already excludes 0**0; the explicit check is kept anyway.
        return not (b > 4 or a <= 0 or a == b == 0)
    if fn == "pow_by_natural":
        return not (b > 4 or b < 0 or a == b == 0)
    if fn == "mod":
        return not (a < 0 or b <= 0)
    if fn in ("div", "truediv", "floordiv"):
        return b != 0
    return True
|
|
|
|
|
|
def generate_range(vals):
    """Yield ``ValueRanges(lo, hi)`` for every ordered pair drawn from ``vals``.

    A pair is skipped when it does not describe an interesting range:

    * boolean endpoints: the inverted range ``[true, false]``;
    * other endpoints: any pair with ``lo > hi``;
    * any range that only admits infinite values (``lo == oo`` or
      ``hi == -oo``).
    """
    bool_endpoints = [sympy.true, sympy.false]

    def _is_interesting(lo, hi):
        if lo in bool_endpoints:
            if lo == sympy.true and hi == sympy.false:
                return False
        elif lo > hi:
            return False
        # ranges that only admit infinite values are not interesting
        return not (lo == sympy.oo or hi == -sympy.oo)

    for lo, hi in itertools.product(vals, repeat=2):
        if _is_interesting(lo, hi):
            yield ValueRanges(lo, hi)
|
|
|
|
|
|
class TestNumbers(TestCase):
    """Arithmetic and ordering of the integer infinities ``int_oo`` / ``-int_oo``."""

    def test_int_infinity(self):
        self.assertIsInstance(int_oo, IntInfinity)
        self.assertIsInstance(-int_oo, NegativeIntInfinity)
        self.assertTrue(int_oo.is_integer)

        # Each entry is (computed value, expected singleton). `assertIs` checks
        # singleton-ness here; don't use identity to compare against numbers.
        singleton_cases = [
            # addition / subtraction
            (int_oo + int_oo, int_oo),
            (int_oo + 1, int_oo),
            (int_oo - 1, int_oo),
            (-int_oo - 1, -int_oo),
            (-int_oo + 1, -int_oo),
            (-int_oo + (-int_oo), -int_oo),
            (-int_oo - int_oo, -int_oo),
            (1 + int_oo, int_oo),
            (1 - int_oo, -int_oo),
            # multiplication
            (int_oo * int_oo, int_oo),
            (2 * int_oo, int_oo),
            (int_oo * 2, int_oo),
            (-1 * int_oo, -int_oo),
            (-int_oo * int_oo, -int_oo),
            (2 * -int_oo, -int_oo),
            (-int_oo * 2, -int_oo),
            (-1 * -int_oo, int_oo),
            # true division produces sympy's (non-integer) infinity
            (int_oo / 2, sympy.oo),
            # negation / absolute value
            (-(-int_oo), int_oo),  # noqa: B002
            (abs(int_oo), int_oo),
            (abs(-int_oo), int_oo),
            # positive integer powers
            (int_oo**2, int_oo),
            ((-int_oo) ** 2, int_oo),
            ((-int_oo) ** 3, -int_oo),
        ]
        for actual, expected in singleton_cases:
            self.assertIs(actual, expected)

        # Negative powers are compared by value, not identity.
        self.assertEqual(int_oo**-1, 0)
        self.assertEqual((-int_oo) ** -1, 0)
        self.assertIs(int_oo**int_oo, int_oo)

        # Equality.
        self.assertTrue(int_oo == int_oo)
        self.assertFalse(int_oo != int_oo)
        self.assertTrue(-int_oo == -int_oo)
        self.assertFalse(int_oo == 2)
        self.assertTrue(int_oo != 2)
        self.assertFalse(int_oo == sys.maxsize)

        # Ordering.
        for lower in (sys.maxsize, 2, -int_oo):
            self.assertTrue(int_oo >= lower)

    def test_relation(self):
        self.assertIs(sympy.Add(2, int_oo), int_oo)
        self.assertFalse(-int_oo > 2)

    def test_lt_self(self):
        self.assertFalse(int_oo < int_oo)
        self.assertIs(min(-int_oo, -4), -int_oo)
        self.assertIs(min(-int_oo, -int_oo), -int_oo)

    def test_float_cast(self):
        for value, expected in ((int_oo, math.inf), (-int_oo, -math.inf)):
            self.assertEqual(float(value), expected)

    def test_mixed_oo_int_oo(self):
        # Arbitrary choice: int_oo is ordered strictly below sympy.oo, and
        # -int_oo strictly above -sympy.oo.
        self.assertTrue(int_oo < sympy.oo)
        self.assertFalse(int_oo > sympy.oo)
        self.assertTrue(sympy.oo > int_oo)
        self.assertFalse(sympy.oo < int_oo)
        self.assertIs(max(int_oo, sympy.oo), sympy.oo)
        self.assertTrue(-int_oo > -sympy.oo)
        self.assertIs(min(-int_oo, -sympy.oo), -sympy.oo)
|
|
|
|
|
|
class TestValueRanges(TestCase):
    """Cross-check ValueRangeAnalysis against ReferenceAnalysis on concrete values."""

    @parametrize("fn", UNARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_unary_ref(self, fn, dtype):
        # On a single concrete input, the range result must be a singleton
        # equal to the reference result, with matching integer-ness.
        dtype = {"int": sympy.Integer, "float": sympy.Float}[dtype]
        for v in CONSTANTS:
            if not valid_unary(fn, v):
                continue
            with self.subTest(v=v):
                v = dtype(v)
                ref_r = getattr(ReferenceAnalysis, fn)(v)
                r = getattr(ValueRangeAnalysis, fn)(v)
                self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(ref_r, r.lower)

    def test_pow_half(self):
        # Smoke test: this call must not raise.
        ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5))

    @parametrize("fn", BINARY_OPS)
    @parametrize("dtype", ("int", "float"))
    def test_binary_ref(self, fn, dtype):
        to_dtype = {"int": sympy.Integer, "float": sympy.Float}
        # Don't test bitwise methods since value range analysis on a singleton
        # range may not return a singleton result.
        if fn in BITWISE_OPS:
            return
        # Don't test float on int only methods
        if dtype == "float" and fn in ["pow_by_natural", "mod"]:
            return
        dtype = to_dtype[dtype]
        for a, b in itertools.product(CONSTANTS, repeat=2):
            if not valid_binary(fn, a, b):
                continue
            a = dtype(a)
            b = dtype(b)
            with self.subTest(a=a, b=b):
                r = getattr(ValueRangeAnalysis, fn)(a, b)
                if r == ValueRanges.unknown():
                    continue
                ref_r = getattr(ReferenceAnalysis, fn)(a, b)

                self.assertEqual(r.lower.is_integer, r.upper.is_integer)
                self.assertEqual(ref_r.is_integer, r.upper.is_integer)
                self.assertEqual(r.lower, r.upper)
                self.assertEqual(ref_r, r.lower)

    def test_mul_zero_unknown(self):
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0), ValueRanges.unknown()),
            ValueRanges.wrap(0),
        )
        self.assertEqual(
            ValueRangeAnalysis.mul(ValueRanges.wrap(0.0), ValueRanges.unknown()),
            ValueRanges.wrap(0.0),
        )

    @parametrize("fn", UNARY_BOOL_OPS)
    def test_unary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a in generate_range(vals):
            with self.subTest(a=a):
                range_r = getattr(ValueRangeAnalysis, fn)(a)
                unique = set()
                for a0 in vals:
                    if a0 not in a:
                        continue
                    with self.subTest(a0=a0):
                        r = getattr(ReferenceAnalysis, fn)(a0)
                        self.assertIn(r, range_r)
                        unique.add(r)
                # A singleton output range must come from a single concrete
                # output, and a wider range must be hit by both outputs.
                if range_r.lower == range_r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", BINARY_BOOL_OPS + BITWISE_OPS)
    def test_binary_bool_ref_range(self, fn):
        vals = [sympy.false, sympy.true]
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                range_r = getattr(ValueRangeAnalysis, fn)(a, b)
                unique = set()
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        r = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(r, range_r)
                        unique.add(r)
                if range_r.lower == range_r.upper:
                    self.assertEqual(len(unique), 1)
                else:
                    self.assertEqual(len(unique), 2)

    @parametrize("fn", UNARY_OPS)
    def test_unary_ref_range(self, fn):
        # TODO: bring back sympy.oo testing for float unary fns
        vals = CONSTANTS
        for a in generate_range(vals):
            with self.subTest(a=a):
                range_r = getattr(ValueRangeAnalysis, fn)(a)
                for a0 in vals:
                    if a0 not in a:
                        continue
                    if not valid_unary(fn, a0):
                        continue
                    with self.subTest(a0=a0):
                        r = getattr(ReferenceAnalysis, fn)(sympy.Integer(a0))
                        self.assertIn(r, range_r)

    # One of the slower tests here: it sweeps all pairs of ranges built
    # from LESS_CONSTANTS.
    @parametrize("fn", BINARY_OPS + COMPARE_OPS)
    def test_binary_ref_range(self, fn):
        # TODO: bring back sympy.oo testing for float binary fns
        vals = LESS_CONSTANTS
        for a, b in itertools.product(generate_range(vals), repeat=2):
            # don't attempt pow on exponents that are too large (but oo is OK)
            if fn == "pow" and b.upper > 4 and b.upper != sympy.oo:
                continue
            with self.subTest(a=a, b=b):
                # The range result depends only on (a, b). Compute it at most
                # once per pair, and lazily, so that pairs with no valid
                # sample point are never passed to the analysis.
                range_r = None
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    if not valid_binary(fn, a0, b0):
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        if range_r is None:
                            range_r = getattr(ValueRangeAnalysis, fn)(a, b)
                        r = getattr(ReferenceAnalysis, fn)(
                            sympy.Integer(a0), sympy.Integer(b0)
                        )
                        if r.is_finite:
                            self.assertIn(r, range_r)

    # stronger test specially for bitwise ops
    @parametrize("fn", BITWISE_OPS)
    def test_bitwise_ref_range(self, fn):
        # N^4 complexity
        vals = range(-4, 5)
        for a, b in itertools.product(generate_range(vals), repeat=2):
            with self.subTest(a=a, b=b):
                # generate_range only yields ranges whose lower bound is one of
                # `vals`, so there is always a sample point and the range
                # result is always needed; compute it once per (a, b).
                range_r = getattr(ValueRangeAnalysis, fn)(a, b)
                for a0, b0 in itertools.product(vals, repeat=2):
                    if a0 not in a or b0 not in b:
                        continue
                    with self.subTest(a0=a0, b0=b0):
                        r = getattr(ReferenceAnalysis, fn)(a0, b0)
                        self.assertIn(r, range_r)

        # test that bitwise ops can take bool arguments
        bool_vals = [
            (3, sympy.true),
            (3, sympy.false),
            (sympy.true, 3),
            (sympy.false, 3),
            (sympy.true, sympy.true),
            (sympy.true, sympy.false),
            (sympy.false, sympy.true),
            (sympy.false, sympy.false),
        ]
        for a, b in bool_vals:
            with self.subTest(a=a, b=b):
                range_r = getattr(ValueRangeAnalysis, fn)(a, b)
                r = getattr(ReferenceAnalysis, fn)(a, b)
                self.assertIn(r, range_r)
|
|
|
|
|
|
class TestSympyInterp(TestCase):
    """Check that sympy_interp agrees with direct evaluation of each op."""

    @parametrize(
        "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS
    )
    def test_interp(self, fn):
        # SymPy does not implement truncation for Expressions
        if fn in ("div", "truncdiv", "minimum", "maximum", "mod"):
            return

        is_integer = None
        if fn == "pow_by_natural":
            is_integer = True

        x = sympy.Dummy("x", integer=is_integer)
        y = sympy.Dummy("y", integer=is_integer)

        vals = CONSTANTS
        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            vals = [True, False]
        elif fn in BITWISE_OPS:
            vals = vals + [True, False]
        arity = 1
        if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
            arity = 2
        symbols = [x]
        if arity == 2:
            symbols = [x, y]
        for args in itertools.product(vals, repeat=arity):
            if arity == 1 and not valid_unary(fn, *args):
                continue
            elif arity == 2 and not valid_binary(fn, *args):
                continue
            with self.subTest(args=args):
                sargs = [sympy.sympify(a) for a in args]
                sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
                ref_r = getattr(ReferenceAnalysis, fn)(*sargs)
                # Yes, I know this is a longwinded way of saying xreplace; the
                # point is to test sympy_interp
                r = sympy_interp(
                    ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr
                )
                self.assertEqual(ref_r, r)

    @parametrize(
        "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS
    )
    def test_python_interp_fx(self, fn):
        # These never show up from symbolic_shapes
        if fn in ("log", "exp"):
            return

        # Sympy does not support truncation on symbolic shapes
        if fn in ("truncdiv", "mod"):
            return

        vals = CONSTANTS
        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            vals = [True, False]
        elif fn in BITWISE_OPS:
            vals = vals + [True, False]

        arity = 1
        if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
            arity = 2

        is_integer = None
        if fn == "pow_by_natural":
            is_integer = True

        x = sympy.Dummy("x", integer=is_integer)
        y = sympy.Dummy("y", integer=is_integer)

        symbols = [x]
        if arity == 2:
            symbols = [x, y]

        for args in itertools.product(vals, repeat=arity):
            # valid_unary / valid_binary already reject the undefined inputs
            # (e.g. floordiv by zero, 0**0 and non-positive pow bases), so no
            # extra per-op guards are needed here.
            if arity == 1 and not valid_unary(fn, *args):
                continue
            elif arity == 2 and not valid_binary(fn, *args):
                continue
            with self.subTest(args=args):
                # Workaround mpf from symbol error
                if fn == "minimum":
                    sympy_expr = sympy.Min(x, y)
                elif fn == "maximum":
                    sympy_expr = sympy.Max(x, y)
                else:
                    sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)

                if arity == 1:

                    def trace_f(px):
                        return sympy_interp(
                            PythonReferenceAnalysis, {x: px}, sympy_expr
                        )

                else:

                    def trace_f(px, py):
                        return sympy_interp(
                            PythonReferenceAnalysis, {x: px, y: py}, sympy_expr
                        )

                gm = fx.symbolic_trace(trace_f)

                self.assertEqual(
                    sympy_interp(
                        PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr
                    ),
                    gm(*args),
                )

    @parametrize(
        "fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS
    )
    def test_tensor_interp(self, fn):
        # Skip operations not implemented or not applicable for tensors
        if fn in ("div", "truncdiv", "int_truediv", "mod", "round_decimal"):
            return

        is_integer = None
        if fn == "pow_by_natural":
            is_integer = True

        x = sympy.Symbol("x", integer=is_integer)
        y = sympy.Symbol("y", integer=is_integer)

        vals = CONSTANTS
        if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
            vals = [True, False]
        elif fn in BITWISE_OPS:
            vals = vals + [True, False]

        arity = 1
        if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
            arity = 2

        symbols = [x]
        if arity == 2:
            symbols = [x, y]

        for args in itertools.product(vals, repeat=arity):
            if arity == 1 and not valid_unary(fn, *args):
                continue
            elif arity == 2 and not valid_binary(fn, *args):
                continue

            with self.subTest(args=args):
                tensor_args = [
                    torch.tensor(
                        a, dtype=torch.double if isinstance(a, float) else torch.int64
                    )
                    for a in args
                ]

                try:
                    tensor_fn = getattr(TensorReferenceAnalysis, fn)
                    sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
                    direct_result = tensor_fn(*tensor_args)
                    interp_result = sympy_interp(
                        TensorReferenceAnalysis,
                        dict(zip(symbols, tensor_args)),
                        sympy_expr,
                    )

                    # Ensure both results are of the same dtype for comparison
                    if direct_result.dtype != interp_result.dtype:
                        if (
                            direct_result.dtype == torch.bool
                            or interp_result.dtype == torch.bool
                        ):
                            direct_result = direct_result.to(torch.bool)
                            interp_result = interp_result.to(torch.bool)
                        else:
                            direct_result = direct_result.to(torch.double)
                            interp_result = interp_result.to(torch.double)

                    self.assertTrue(
                        torch.allclose(
                            direct_result, interp_result, rtol=1e-5, atol=1e-8
                        ),
                        f"Mismatch for {fn}{args}: direct={direct_result}, interp={interp_result}",
                    )

                    if fn in UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS:
                        self.assertEqual(direct_result.dtype, torch.bool)
                        self.assertEqual(interp_result.dtype, torch.bool)

                    if fn in (
                        "floor_to_int",
                        "ceil_to_int",
                        "round_to_int",
                        "trunc_to_int",
                    ):
                        self.assertEqual(direct_result.dtype, torch.int64)
                        self.assertEqual(interp_result.dtype, torch.int64)

                except NotImplementedError:
                    # Deliberate best-effort: missing tensor implementations
                    # are reported, not failed.
                    print(f"Operation {fn} not implemented for TensorReferenceAnalysis")
                except AssertionError:
                    # Let genuine assertion failures (including the mismatch
                    # message above) surface as-is instead of being relabelled
                    # as "Unexpected error" by the handler below.
                    raise
                except Exception as e:
                    self.fail(f"Unexpected error for {fn}{args}: {str(e)}")
|
|
|
|
|
|
def type_name_fn(type: type) -> str:
    """Return the bare class name of ``type``; used to label parametrized tests."""
    name = type.__name__
    return name
|
|
|
|
|
|
def parametrize_relational_types(*types):
    """Build a decorator that parametrizes a test's ``op`` argument.

    With no arguments, the test is parametrized over ``RELATIONAL_TYPES``;
    otherwise it is parametrized over exactly the given ``types``. Each
    generated variant is named after the type's ``__name__``.
    """

    def wrapper(f: Callable):
        chosen = types if types else RELATIONAL_TYPES
        decorate = parametrize("op", chosen, name_fn=type_name_fn)
        return decorate(f)

    return wrapper
|
|
|
|
|
|
class TestSympySolve(TestCase):
    """Tests for try_solve: isolating a symbol on the left of a relational."""

    def _create_integer_symbols(self) -> tuple[sympy.Symbol, ...]:
        # sympy.symbols returns a tuple with one Symbol per name.
        return sympy.symbols("a b c", integer=True)

    def test_give_up(self):
        from sympy import Eq, Ne

        a, b, c = self._create_integer_symbols()

        cases = [
            # Not a relational operation.
            a + b,
            # 'a' appears on both sides.
            Eq(a, a + 1),
            # 'a' appears on neither side.
            Eq(b, c + 1),
            # Result is a 'sympy.And'.
            Eq(FloorDiv(a, b), c),
            # Result is a 'sympy.Or'.
            Ne(FloorDiv(a, b), c),
        ]

        for case in cases:
            e = try_solve(case, a)
            self.assertIsNone(e)

    @parametrize_relational_types()
    def test_noop(self, op):
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = a, 42 * b
        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, expr)
        self.assertEqual(r_rhs, rhs)

    @parametrize_relational_types()
    def test_noop_rhs(self, op):
        a, b, _ = self._create_integer_symbols()

        lhs, rhs = 42 * b, a

        mirror = mirror_rel_op(op)
        self.assertIsNotNone(mirror)

        expr = op(lhs, rhs)

        r = try_solve(expr, a)
        self.assertIsNotNone(r)

        r_expr, r_rhs = r
        self.assertEqual(r_expr, mirror(rhs, lhs))
        self.assertEqual(r_rhs, lhs)

    def _test_cases(
        self,
        cases: list[tuple[sympy.Basic, sympy.Basic]],
        thing: sympy.Basic,
        op: type[sympy.Rel],
        **kwargs,
    ):
        """Run ``try_solve(source, thing, **kwargs)`` for each ``(source, expected)``.

        An ``expected`` of ``None`` means the solver must give up (return
        ``None``). Otherwise it must return ``(op(thing, expected), expected)``.
        """
        for source, expected in cases:
            r = try_solve(source, thing, **kwargs)

            if expected is None:
                self.assertIsNone(r)
                continue

            self.assertIsNotNone(r)
            r_expr, r_rhs = r
            self.assertEqual(r_rhs, expected)
            self.assertEqual(r_expr, op(thing, expected))

    def test_addition(self):
        from sympy import Eq

        a, b, c = self._create_integer_symbols()

        cases = [
            (Eq(a + b, 0), -b),
            (Eq(a + 5, b - 5), b - 10),
            (Eq(a + c * b, 1), 1 - c * b),
        ]

        self._test_cases(cases, a, Eq)

    @parametrize_relational_types(sympy.Eq, sympy.Ne)
    def test_multiplication_division(self, op):
        a, b, c = self._create_integer_symbols()

        cases = [
            (op(a * b, 1), 1 / b),
            (op(a * 5, b - 5), (b - 5) / 5),
            (op(a * b, c), c / b),
        ]

        self._test_cases(cases, a, op)

    @parametrize_relational_types(*INEQUALITY_TYPES)
    def test_multiplication_division_inequality(self, op):
        a, b, _ = self._create_integer_symbols()
        intneg = sympy.Symbol("neg", integer=True, negative=True)
        intpos = sympy.Symbol("pos", integer=True, positive=True)

        cases = [
            # Divide/multiply both sides by positive number.
            (op(a * intpos, 1), 1 / intpos),
            (op(a / (5 * intpos), 1), 5 * intpos),
            (op(a * 5, b - 5), (b - 5) / 5),
            # 'b' is not strictly positive nor negative, so we can't
            # divide/multiply both sides by 'b'.
            (op(a * b, 1), None),
            (op(a / b, 1), None),
            (op(a * b * intpos, 1), None),
        ]

        mirror_cases = [
            # Divide/multiply both sides by negative number.
            (op(a * intneg, 1), 1 / intneg),
            (op(a / (5 * intneg), 1), 5 * intneg),
            (op(a * -5, b - 5), -(b - 5) / 5),
        ]
        mirror_op = mirror_rel_op(op)
        # Plain assert kept for type narrowing of the Optional result.
        assert mirror_op is not None

        self._test_cases(cases, a, op)
        self._test_cases(mirror_cases, a, mirror_op)

    @parametrize_relational_types()
    def test_floordiv(self, op):
        from sympy import Eq, Ge, Gt, Le, Lt, Ne

        a, b, c = sympy.symbols("a b c")
        pos = sympy.Symbol("pos", positive=True)
        integer = sympy.Symbol("integer", integer=True)

        # (Eq(FloorDiv(a, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Eq(FloorDiv(a + 5, pos), integer), And(Ge(a, integer * pos), Lt(a, (integer + 1) * pos))),
        # (Ne(FloorDiv(a, pos), integer), Or(Lt(a, integer * pos), Ge(a, (integer + 1) * pos))),

        special_case = {
            # 'FloorDiv' turns into 'And', which can't be simplified any further.
            Eq: (Eq(FloorDiv(a, pos), integer), None),
            # 'FloorDiv' turns into 'Or', which can't be simplified any further.
            Ne: (Ne(FloorDiv(a, pos), integer), None),
            Gt: (Gt(FloorDiv(a, pos), integer), (integer + 1) * pos),
            Ge: (Ge(FloorDiv(a, pos), integer), integer * pos),
            Lt: (Lt(FloorDiv(a, pos), integer), integer * pos),
            Le: (Le(FloorDiv(a, pos), integer), (integer + 1) * pos),
        }[op]

        cases: list[tuple[sympy.Basic, sympy.Basic]] = [
            # 'b' is not strictly positive
            (op(FloorDiv(a, b), integer), None),
            # 'c' is not strictly positive
            (op(FloorDiv(a, pos), c), None),
        ]

        # The result might change after 'FloorDiv' transformation.
        # Specifically:
        #   - [Ge, Gt] => Ge
        #   - [Le, Lt] => Lt
        if op in (sympy.Gt, sympy.Ge):
            r_op = sympy.Ge
        elif op in (sympy.Lt, sympy.Le):
            r_op = sympy.Lt
        else:
            r_op = op

        self._test_cases([special_case, *cases], a, r_op)
        # With floordiv_inequality=False the solver must give up on the
        # special case as well.
        self._test_cases(
            [(special_case[0], None), *cases], a, r_op, floordiv_inequality=False
        )

    def test_floordiv_eq_simplify(self):
        from sympy import Eq, Le, Lt

        a = sympy.Symbol("a", positive=True, integer=True)

        def check(expr, expected):
            r = try_solve(expr, a)
            self.assertIsNotNone(r)
            r_expr, _ = r
            self.assertEqual(r_expr, expected)

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10 (always true)
        # a + 10 < 4 * 3 (not sure)
        check(Eq(FloorDiv(a + 10, 3), 3), Lt(a, (3 + 1) * 3 - 10))

        # (10 - a) // 2 == 4
        # =====================================
        # 4 * 2 <= 10 - a (not sure)
        # 10 - a < 5 * 2 (always true)
        check(Eq(FloorDiv(10 - a, 2), 4), Le(a, -(4 * 2 - 10)))

    @skipIf(not TEST_Z3, "Z3 not installed")
    def test_z3_proof_floordiv_eq_simplify(self):
        import z3
        from sympy import Eq, Lt

        a = sympy.Symbol("a", positive=True, integer=True)
        a_ = z3.Int("a")

        # (a + 10) // 3 == 3
        # =====================================
        # 3 * 3 <= a + 10 (always true)
        # a + 10 < 4 * 3 (not sure)
        solver = z3.SolverFor("QF_NRA")

        # Add assertions for 'a_'.
        solver.add(a_ > 0)

        expr = Eq(FloorDiv(a + 10, 3), 3)
        r_expr, _ = try_solve(expr, a)

        # Check 'try_solve' really returns the 'expected' below.
        expected = Lt(a, (3 + 1) * 3 - 10)
        self.assertEqual(r_expr, expected)

        # Check whether there is an integer 'a_' such that the
        # equation below is satisfied.
        solver.add(
            # expr
            (z3.ToInt((a_ + 10) / 3.0) == 3)
            !=
            # expected
            (a_ < (3 + 1) * 3 - 10)
        )

        # Assert that there's no such an integer.
        # i.e. the transformation is sound.
        r = solver.check()
        self.assertEqual(r, z3.unsat)

    def test_simple_floordiv_gcd(self):
        x, y, z = sympy.symbols("x y z")

        # positive tests
        self.assertEqual(simple_floordiv_gcd(x, x), x)
        self.assertEqual(simple_floordiv_gcd(128 * x, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y, 2304), 128)
        self.assertEqual(simple_floordiv_gcd(128 * x + 128 * y + 8192 * z, 9216), 128)
        self.assertEqual(simple_floordiv_gcd(49152 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(96 * x, 96 * x), 96 * x)
        self.assertEqual(simple_floordiv_gcd(x * y, x), x)
        self.assertEqual(simple_floordiv_gcd(384 * x * y, x * y), x * y)
        self.assertEqual(simple_floordiv_gcd(256 * x * y, 8 * x), 8 * x)

        # negative tests
        self.assertEqual(simple_floordiv_gcd(x * y + x + y + 1, x + 1), 1)
|
|
|
|
|
|
class TestSympyFunctions(TestCase):
    def test_pickle(self):
        # An opaque unary function application must survive a pickle round trip.
        original = OpaqueUnaryFn_cos(sympy.Symbol("a"))
        restored = pickle.loads(pickle.dumps(original))
        self.assertEqual(original, restored)
|
|
|
|
|
|
class TestSingletonInt(TestCase):
    def test_basic(self):
        j1 = SingletonInt(1, coeff=1)
        j1_copy = SingletonInt(1, coeff=1)
        j2 = SingletonInt(2, coeff=1)
        j1x2 = SingletonInt(1, coeff=2)

        def check_eq(a, b, expected):
            # Eq(a, b) must match `expected`, and Ne(b, a) its negation.
            self.assertEqual(sympy.Eq(a, b), expected)
            self.assertEqual(sympy.Ne(b, a), not expected)

        # eq, ne
        check_eq(j1, j1, True)
        check_eq(j1, j1_copy, True)
        check_eq(j1, j2, False)
        check_eq(j1, j1x2, False)
        check_eq(j1, sympy.Integer(1), False)
        check_eq(j1, sympy.Integer(3), False)

        def check_ineq(a, b, expected, *, strict=True):
            # `expected` is either a bool (the truth value of a > b, or of
            # a >= b when not strict) or a regex that the ValueError raised
            # by an indeterminate comparison must match.
            greater = (sympy.Gt, is_gt) if strict else (sympy.Ge, is_ge)
            less = (sympy.Lt, is_lt) if strict else (sympy.Le, is_le)

            if isinstance(expected, bool):
                # Every bool caller in this test passes True (a is the larger).
                for fn in greater:
                    self.assertEqual(fn(a, b), expected)
                    self.assertEqual(fn(b, a), not expected)
                for fn in less:
                    self.assertEqual(fn(b, a), expected)
                    self.assertEqual(fn(a, b), not expected)
            else:
                for fn in greater:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(a, b)
                for fn in less:
                    with self.assertRaisesRegex(ValueError, expected):
                        fn(b, a)

        # ge, le, gt, lt
        for strict in (True, False):
            _check_ineq = functools.partial(check_ineq, strict=strict)
            _check_ineq(j1, sympy.Integer(0), True)
            _check_ineq(j1, sympy.Integer(3), "indeterminate")
            _check_ineq(j1, j2, "indeterminate")
            _check_ineq(j1x2, j1, True)

        # Special cases for ge, le, gt, lt:
        for ge in (sympy.Ge, is_ge):
            self.assertTrue(ge(j1, j1))
            self.assertTrue(ge(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                ge(sympy.Integer(2), j1)
        for le in (sympy.Le, is_le):
            self.assertTrue(le(j1, j1))
            self.assertTrue(le(sympy.Integer(2), j1))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                le(j1, sympy.Integer(2))

        for gt in (sympy.Gt, is_gt):
            self.assertFalse(gt(j1, j1))
            self.assertFalse(gt(sympy.Integer(2), j1))
            # it is only known that j1 >= 2, so j1 > 2 is indeterminate
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                gt(j1, sympy.Integer(2))

        for lt in (sympy.Lt, is_lt):
            self.assertFalse(lt(j1, j1))
            self.assertFalse(lt(j1, sympy.Integer(2)))
            with self.assertRaisesRegex(ValueError, "indeterminate"):
                lt(sympy.Integer(2), j1)

        # mul
        self.assertEqual(j1 * 2, j1x2)
        # Unfortunately, this doesn't automatically simplify to 2*j1
        # since sympy.Mul doesn't trigger __mul__ unlike the above.
        self.assertIsInstance(sympy.Mul(j1, 2), sympy.core.mul.Mul)

        with self.assertRaisesRegex(ValueError, "cannot be multiplied"):
            j1 * j2

        self.assertEqual(j1.free_symbols, set())
|
|
|
|
class TestIdentity(TestCase):
    def test_expand_identity(self):
        """
        Test that expand(identity=True) strips an Identity wrapper.
        """
        x = sympy.Symbol("x")
        inner = x + sympy.S.One
        wrapped = Identity(inner)

        # After expansion no Identity node may remain, and the result must be
        # exactly the wrapped argument.
        stripped = wrapped.expand(identity=True)
        self.assertEqual(stripped.count(Identity), 0)
        self.assertEqual(stripped, inner)
|
|
|
|
# Generate the concrete test methods for the classes whose tests use
# @parametrize. TestNumbers, TestSympyFunctions, TestSingletonInt and
# TestIdentity use no @parametrize, so they are not listed here.
# Kept as explicit calls: a module-level `for cls in (...)` loop would leave
# `cls` bound to a TestCase class, and unittest would collect it twice.
instantiate_parametrized_tests(TestValueRanges)
instantiate_parametrized_tests(TestSympyInterp)
instantiate_parametrized_tests(TestSympySolve)


if __name__ == "__main__":
    run_tests()
|