Run sympy expressions with Python values / FX tracing (#113978)

To codegen deferred runtime asserts, I need to be able to convert sympy expressions back into regular Python expressions that I can put in FX graphs. This PR adds some of the machinery to do this: it adds a new sympy analysis whose operations are all FX traceable and can also be run with plain Python int/float/bool values. It is tested by symbolically tracing through the analysis and then checking that the traced graph gives the same result as running the Python analysis directly.
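
As a concrete illustration of the round trip (a minimal sketch using the APIs in this PR; the expression and values are made up):

    import sympy
    import torch.fx as fx
    from torch.utils._sympy.interp import sympy_interp
    from torch.utils._sympy.reference import PythonReferenceAnalysis

    x, y = sympy.symbols("x y")
    expr = sympy.Max(x, y)  # stand-in for a deferred runtime assert expression

    def f(px, py):
        # interpret the sympy expression with plain Python / torch.sym_* ops
        return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, expr)

    gm = fx.symbolic_trace(f)        # the interpretation is FX traceable
    assert f(3, 5) == gm(3, 5) == 5  # traced graph matches direct evaluation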

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113978
Approved by: https://github.com/aakhundov, https://github.com/lezcano
Edward Z. Yang
2023-11-20 10:01:27 -08:00
committed by PyTorch MergeBot
parent cd2798943d
commit 473b17c4c1
7 changed files with 176 additions and 17 deletions


@@ -16,11 +16,12 @@ from torch.testing._internal.common_utils import (
from torch.utils._sympy.functions import FloorDiv
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
from torch.utils._sympy.reference import ReferenceAnalysis
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
from torch.utils._sympy.interp import sympy_interp
from torch.utils._sympy.singleton_int import SingletonInt
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
import functools
import torch.fx as fx
@@ -281,6 +282,60 @@ class TestSympyInterp(TestCase):
r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr)
self.assertEqual(ref_r, r)
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
def test_python_interp_fx(self, fn):
# These never show up from symbolic_shapes
if fn in ("log", "exp"):
return
vals = CONSTANTS
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
vals = [True, False]
arity = 1
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
arity = 2
from sympy.abc import x, y
symbols = [x]
if arity == 2:
symbols = [x, y]
# Workaround mpf from symbol error
if fn == "minimum":
sympy_expr = sympy.Min(x, y)
elif fn == "maximum":
sympy_expr = sympy.Max(x, y)
else:
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
if arity == 1:
def trace_f(px):
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
else:
def trace_f(px, py):
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
gm = fx.symbolic_trace(trace_f)
for args in itertools.product(vals, repeat=arity):
if arity == 1 and not valid_unary(fn, *args):
continue
elif arity == 2 and not valid_binary(fn, *args):
continue
if fn == "truncdiv" and args[1] == 0:
continue
elif fn == "pow" and (args[0] == 0 and args[1] <= 0):
continue
elif fn == "floordiv" and args[1] == 0:
continue
with self.subTest(args=args):
self.assertEqual(
sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
gm(*args)
)
def type_name_fn(type: Type) -> str:
return type.__name__


@@ -56,6 +56,7 @@ __all__ = [
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
'SymBool', 'sym_not', 'unravel_index',
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
'sym_sqrt',
'export', 'autocast', 'cond',
]
@@ -414,8 +415,15 @@ def sym_not(a):
Args:
a (SymBool or bool): Object to negate
"""
import sympy
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(sym_not, (a,), a)
if hasattr(a, '__sym_not__'):
return a.__sym_not__()
if isinstance(a, sympy.Basic):
return ~a # type: ignore[operator]
return not a
def sym_float(a):
@@ -424,6 +432,10 @@ def sym_float(a):
Args:
a (SymInt, SymFloat, or object): Object to cast
"""
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(sym_float, (a,), a)
if isinstance(a, SymFloat):
return a
elif hasattr(a, '__sym_float__'):
@@ -437,6 +449,10 @@ def sym_int(a):
Args:
a (SymInt, SymFloat, or object): Object to cast
"""
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(sym_int, (a,), a)
if isinstance(a, SymInt):
return a
elif isinstance(a, SymFloat):
@@ -445,6 +461,10 @@ def sym_int(a):
def sym_max(a, b):
""" SymInt-aware utility for max()."""
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((a, b)):
return handle_torch_function(sym_max, (a, b), a, b)
if isinstance(a, (SymInt, SymFloat)):
return a.__sym_max__(b)
elif isinstance(b, (SymInt, SymFloat)):
@@ -456,13 +476,31 @@ def sym_max(a, b):
def sym_min(a, b):
""" SymInt-aware utility for max()."""
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((a, b)):
return handle_torch_function(sym_min, (a, b), a, b)
if isinstance(a, (SymInt, SymFloat)):
return a.__sym_min__(b)
elif isinstance(b, (SymInt, SymFloat)):
return b.__sym_min__(a)
return builtins.min(a, b) # type: ignore[operator]
# Drop-in replacement for math.sqrt
def sym_sqrt(a):
from .overrides import has_torch_function_unary, handle_torch_function
if has_torch_function_unary(a):
return handle_torch_function(sym_sqrt, (a,), a)
if hasattr(a, "__sym_sqrt__"):
return a.__sym_sqrt__()
return math.sqrt(a)
def sym_ite(b, t, f):
from .overrides import has_torch_function, handle_torch_function
if has_torch_function((b, t, f)):
return handle_torch_function(sym_ite, (b, t, f), b, t, f)
assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
if isinstance(b, SymBool):
return b.__sym_ite__(t, f)
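
In effect, each sym_* helper now consults the __torch_function__ machinery before falling back to plain Python, which is what lets FX proxies intercept these calls. A small sketch (values are illustrative):

    import torch
    import torch.fx as fx

    def g(a, b):
        return torch.sym_max(a, b)  # Proxy.__torch_function__ records this call

    gm = fx.symbolic_trace(g)
    print(gm.code)   # the graph contains a call_function node for torch.sym_max
    print(gm(2, 7))  # on plain ints, sym_max falls through to builtins.max -> 7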


@@ -58,6 +58,13 @@ Tensor = torch.Tensor
torch_function_passthrough = {
torch.device,
torch.sym_not,
torch.sym_float,
torch.sym_int,
torch.sym_max,
torch.sym_min,
torch.sym_sqrt,
torch.sym_ite,
torch.Tensor.dim,
torch.Tensor.ndim.__get__, # type: ignore[attr-defined]
torch.Tensor.numel,


@@ -24,6 +24,7 @@ from torch import ( # noqa: F401
sym_max,
sym_min,
sym_not,
sym_sqrt,
SymBool,
SymFloat,
SymInt,
@@ -398,13 +399,6 @@ class SymNode:
return False
# Drop in replacement for math.sqrt
def sym_sqrt(a):
if hasattr(a, "__sym_sqrt__"):
return a.__sym_sqrt__()
return math.sqrt(a)
# TODO: this probably needs the sizes-strides eval functions
METHOD_TO_OPERATOR = {
"abs": operator.abs,


@@ -409,6 +409,9 @@ class Proxy:
return self.tracer.iter(self)
def __abs__(self):
return self.tracer.create_proxy('call_function', operator.abs, (self,), {})
def __bool__(self) -> bool:
if self.tracer.trace_asserts:
# check if this boolean is used in an assertion, bytecode pattern for assertions
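
With __abs__ defined on Proxy, a plain abs() call on a traced value now records an operator.abs node instead of failing; a minimal sketch:

    import torch.fx as fx

    def h(a):
        return abs(a)  # dispatches to Proxy.__abs__ -> operator.abs node

    gm = fx.symbolic_trace(h)
    print(gm(-4))  # 4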


@@ -216,12 +216,6 @@ def get_ignored_functions() -> Set[Callable]:
torch.sparse_csc_tensor,
torch.sparse_bsr_tensor,
torch.sparse_bsc_tensor,
torch.sym_float,
torch.sym_int,
torch.sym_max,
torch.sym_min,
torch.sym_not,
torch.sym_ite,
torch.sym_constrain_range,
torch.sym_constrain_range_for_size,
torch.tril_indices,
@@ -1061,6 +1055,13 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,
torch.sum: lambda input, dim=None: -1,
torch.sym_float: lambda input: -1,
torch.sym_int: lambda input: -1,
torch.sym_max: lambda a, b: -1,
torch.sym_min: lambda a, b: -1,
torch.sym_not: lambda input: -1,
torch.sym_ite: lambda a, b, c: -1,
torch.sym_sqrt: lambda input: -1,
torch.nansum: lambda input, dim=None: -1,
torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
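
Moving the sym_* functions out of get_ignored_functions and into get_testing_overrides means they are now overridable via __torch_function__. A sketch with a hypothetical wrapper class (MyBool is made up for illustration):

    import torch

    class MyBool:
        # hypothetical wrapper; torch.sym_not dispatches to __torch_function__
        def __init__(self, v):
            self.v = v

        @classmethod
        def __torch_function__(cls, func, types, args=(), kwargs=None):
            if func is torch.sym_not:
                return MyBool(not args[0].v)
            return NotImplemented

    print(torch.sym_not(MyBool(True)).v)  # False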


@@ -1,7 +1,14 @@
import math
import sympy
import torch
# The normal Python interpretation of the operators
# The sympy interpretation of operators. It will also sometimes work with
# plain int/float, but if you do certain operations you will get out a
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
# check PythonReferenceAnalysis.
# NB: For magic methods this needs to use normal magic methods
# so that test_magic_methods works
class ReferenceAnalysis:
@@ -11,12 +18,10 @@ class ReferenceAnalysis:
@staticmethod
def or_(a, b):
assert not isinstance(a, bool) and not isinstance(b, bool)
return a | b
@staticmethod
def and_(a, b):
assert not isinstance(a, bool) and not isinstance(b, bool)
return a & b
@staticmethod
@@ -151,3 +156,59 @@ class ReferenceAnalysis:
@staticmethod
def ceil(x):
return sympy.ceiling(x)
# Unlike ReferenceAnalysis, this does NOT sympyify; instead, it works with
# plain Python types and is FX traceable. Inheritance here is purely for
# code sharing (TODO: consider splitting out a BaseReferenceAnalysis).
class PythonReferenceAnalysis(ReferenceAnalysis):
@staticmethod
def constant(c, dtype):
if dtype is torch.int64:
return int(c)
elif dtype is torch.double:
return float(c)
elif dtype is torch.bool:
return bool(c)
else:
raise AssertionError(f"unrecognized dtype {dtype}")
@staticmethod
def not_(a):
return torch.sym_not(a)
@staticmethod
def floordiv(a, b):
return a // b
@staticmethod
def truncdiv(a, b):
return a / b
@staticmethod
def exp(x):
raise AssertionError("exp is not valid shape sympy expr")
@staticmethod
def log(x):
raise AssertionError("log is not valid shape sympy expr")
@staticmethod
def sqrt(x):
return torch.sym_sqrt(x)
@staticmethod
def minimum(a, b):
return torch.sym_min(a, b)
@staticmethod
def maximum(a, b):
return torch.sym_max(a, b)
@staticmethod
def floor(x):
return math.floor(x)
@staticmethod
def ceil(x):
return math.ceil(x)
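
To make the comments above concrete, the two analyses differ in what they return for the same operation (a small sketch):

    import sympy
    from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis

    r = ReferenceAnalysis.maximum(2, 3)        # sympy.Max -> sympy.Integer(3)
    p = PythonReferenceAnalysis.maximum(2, 3)  # torch.sym_max -> plain int 3
    print(isinstance(r, sympy.Basic), type(p)) # True <class 'int'>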