mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Run sympy expressions with Python values / FX tracing (#113978)
To codegen deferred runtime asserts, I need to be able to convert sympy expressions back into regular Python expressions that I can put in FX graphs. This PR adds some of the machinery to do this: it adds a new sympy analysis that runs operations on all FX traceable operations that can also be run with plain Python int/float/bool/etc. It's tested by symbolic tracing through the analysis, and then testing that this traced graph gives the same result as running the Python analysis directly. Signed-off-by: Edward Z. Yang <ezyang@meta.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/113978 Approved by: https://github.com/aakhundov, https://github.com/lezcano
This commit is contained in:
committed by
PyTorch MergeBot
parent
cd2798943d
commit
473b17c4c1
@ -16,11 +16,12 @@ from torch.testing._internal.common_utils import (
|
||||
from torch.utils._sympy.functions import FloorDiv
|
||||
from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve
|
||||
from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges
|
||||
from torch.utils._sympy.reference import ReferenceAnalysis
|
||||
from torch.utils._sympy.reference import ReferenceAnalysis, PythonReferenceAnalysis
|
||||
from torch.utils._sympy.interp import sympy_interp
|
||||
from torch.utils._sympy.singleton_int import SingletonInt
|
||||
from sympy.core.relational import is_ge, is_le, is_gt, is_lt
|
||||
import functools
|
||||
import torch.fx as fx
|
||||
|
||||
|
||||
|
||||
@ -281,6 +282,60 @@ class TestSympyInterp(TestCase):
|
||||
r = sympy_interp(ReferenceAnalysis, dict(zip(symbols, sargs)), sympy_expr)
|
||||
self.assertEqual(ref_r, r)
|
||||
|
||||
@parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS)
|
||||
def test_python_interp_fx(self, fn):
|
||||
# These never show up from symbolic_shapes
|
||||
if fn in ("log", "exp"):
|
||||
return
|
||||
|
||||
vals = CONSTANTS
|
||||
if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}:
|
||||
vals = [True, False]
|
||||
|
||||
arity = 1
|
||||
if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}:
|
||||
arity = 2
|
||||
|
||||
from sympy.abc import x, y
|
||||
|
||||
symbols = [x]
|
||||
if arity == 2:
|
||||
symbols = [x, y]
|
||||
|
||||
# Workaround mpf from symbol error
|
||||
if fn == "minimum":
|
||||
sympy_expr = sympy.Min(x, y)
|
||||
elif fn == "maximum":
|
||||
sympy_expr = sympy.Max(x, y)
|
||||
else:
|
||||
sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols)
|
||||
|
||||
if arity == 1:
|
||||
def trace_f(px):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr)
|
||||
else:
|
||||
def trace_f(px, py):
|
||||
return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr)
|
||||
|
||||
gm = fx.symbolic_trace(trace_f)
|
||||
|
||||
for args in itertools.product(vals, repeat=arity):
|
||||
if arity == 1 and not valid_unary(fn, *args):
|
||||
continue
|
||||
elif arity == 2 and not valid_binary(fn, *args):
|
||||
continue
|
||||
if fn == "truncdiv" and args[1] == 0:
|
||||
continue
|
||||
elif fn == "pow" and (args[0] == 0 and args[1] <= 0):
|
||||
continue
|
||||
elif fn == "floordiv" and args[1] == 0:
|
||||
continue
|
||||
with self.subTest(args=args):
|
||||
self.assertEqual(
|
||||
sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr),
|
||||
gm(*args)
|
||||
)
|
||||
|
||||
|
||||
def type_name_fn(type: Type) -> str:
|
||||
return type.__name__
|
||||
|
||||
@ -56,6 +56,7 @@ __all__ = [
|
||||
'set_warn_always', 'is_warn_always_enabled', 'SymInt', 'SymFloat',
|
||||
'SymBool', 'sym_not', 'unravel_index',
|
||||
'sym_int', 'sym_float', 'sym_max', 'sym_min', 'sym_ite', 'compile', 'vmap',
|
||||
'sym_sqrt',
|
||||
'export', 'autocast', 'cond',
|
||||
]
|
||||
|
||||
@ -414,8 +415,15 @@ def sym_not(a):
|
||||
Args:
|
||||
a (SymBool or bool): Object to negate
|
||||
"""
|
||||
import sympy
|
||||
from .overrides import has_torch_function_unary, handle_torch_function
|
||||
|
||||
if has_torch_function_unary(a):
|
||||
return handle_torch_function(sym_not, (a,), a)
|
||||
if hasattr(a, '__sym_not__'):
|
||||
return a.__sym_not__()
|
||||
if isinstance(a, sympy.Basic):
|
||||
return ~a # type: ignore[operator]
|
||||
return not a
|
||||
|
||||
def sym_float(a):
|
||||
@ -424,6 +432,10 @@ def sym_float(a):
|
||||
Args:
|
||||
a (SymInt, SymFloat, or object): Object to cast
|
||||
"""
|
||||
from .overrides import has_torch_function_unary, handle_torch_function
|
||||
|
||||
if has_torch_function_unary(a):
|
||||
return handle_torch_function(sym_float, (a,), a)
|
||||
if isinstance(a, SymFloat):
|
||||
return a
|
||||
elif hasattr(a, '__sym_float__'):
|
||||
@ -437,6 +449,10 @@ def sym_int(a):
|
||||
Args:
|
||||
a (SymInt, SymFloat, or object): Object to cast
|
||||
"""
|
||||
from .overrides import has_torch_function_unary, handle_torch_function
|
||||
|
||||
if has_torch_function_unary(a):
|
||||
return handle_torch_function(sym_int, (a,), a)
|
||||
if isinstance(a, SymInt):
|
||||
return a
|
||||
elif isinstance(a, SymFloat):
|
||||
@ -445,6 +461,10 @@ def sym_int(a):
|
||||
|
||||
def sym_max(a, b):
|
||||
""" SymInt-aware utility for max()."""
|
||||
from .overrides import has_torch_function, handle_torch_function
|
||||
|
||||
if has_torch_function((a, b)):
|
||||
return handle_torch_function(sym_max, (a, b), a, b)
|
||||
if isinstance(a, (SymInt, SymFloat)):
|
||||
return a.__sym_max__(b)
|
||||
elif isinstance(b, (SymInt, SymFloat)):
|
||||
@ -456,13 +476,31 @@ def sym_max(a, b):
|
||||
|
||||
def sym_min(a, b):
|
||||
""" SymInt-aware utility for max()."""
|
||||
from .overrides import has_torch_function, handle_torch_function
|
||||
|
||||
if has_torch_function((a, b)):
|
||||
return handle_torch_function(sym_min, (a, b), a, b)
|
||||
if isinstance(a, (SymInt, SymFloat)):
|
||||
return a.__sym_min__(b)
|
||||
elif isinstance(b, (SymInt, SymFloat)):
|
||||
return b.__sym_min__(a)
|
||||
return builtins.min(a, b) # type: ignore[operator]
|
||||
|
||||
# Drop in replacement for math.sqrt
|
||||
def sym_sqrt(a):
|
||||
from .overrides import has_torch_function_unary, handle_torch_function
|
||||
|
||||
if has_torch_function_unary(a):
|
||||
return handle_torch_function(sym_sqrt, (a,), a)
|
||||
if hasattr(a, "__sym_sqrt__"):
|
||||
return a.__sym_sqrt__()
|
||||
return math.sqrt(a)
|
||||
|
||||
def sym_ite(b, t, f):
|
||||
from .overrides import has_torch_function, handle_torch_function
|
||||
|
||||
if has_torch_function((b, t, f)):
|
||||
return handle_torch_function(sym_ite, (b, t, f), b, t, f)
|
||||
assert isinstance(b, (SymBool, builtins.bool)) and type(t) == type(f)
|
||||
if isinstance(b, SymBool):
|
||||
return b.__sym_ite__(t, f)
|
||||
|
||||
@ -58,6 +58,13 @@ Tensor = torch.Tensor
|
||||
|
||||
torch_function_passthrough = {
|
||||
torch.device,
|
||||
torch.sym_not,
|
||||
torch.sym_float,
|
||||
torch.sym_int,
|
||||
torch.sym_max,
|
||||
torch.sym_min,
|
||||
torch.sym_sqrt,
|
||||
torch.sym_ite,
|
||||
torch.Tensor.dim,
|
||||
torch.Tensor.ndim.__get__, # type: ignore[attr-defined]
|
||||
torch.Tensor.numel,
|
||||
|
||||
@ -24,6 +24,7 @@ from torch import ( # noqa: F401
|
||||
sym_max,
|
||||
sym_min,
|
||||
sym_not,
|
||||
sym_sqrt,
|
||||
SymBool,
|
||||
SymFloat,
|
||||
SymInt,
|
||||
@ -398,13 +399,6 @@ class SymNode:
|
||||
return False
|
||||
|
||||
|
||||
# Drop in replacement for math.sqrt
|
||||
def sym_sqrt(a):
|
||||
if hasattr(a, "__sym_sqrt__"):
|
||||
return a.__sym_sqrt__()
|
||||
return math.sqrt(a)
|
||||
|
||||
|
||||
# TODO: this probably needs the sizes-strides eval functions
|
||||
METHOD_TO_OPERATOR = {
|
||||
"abs": operator.abs,
|
||||
|
||||
@ -409,6 +409,9 @@ class Proxy:
|
||||
|
||||
return self.tracer.iter(self)
|
||||
|
||||
def __abs__(self):
|
||||
return self.tracer.create_proxy('call_function', operator.abs, (self,), {})
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
if self.tracer.trace_asserts:
|
||||
# check if this boolean is used in an assertion, bytecode pattern for assertions
|
||||
|
||||
@ -216,12 +216,6 @@ def get_ignored_functions() -> Set[Callable]:
|
||||
torch.sparse_csc_tensor,
|
||||
torch.sparse_bsr_tensor,
|
||||
torch.sparse_bsc_tensor,
|
||||
torch.sym_float,
|
||||
torch.sym_int,
|
||||
torch.sym_max,
|
||||
torch.sym_min,
|
||||
torch.sym_not,
|
||||
torch.sym_ite,
|
||||
torch.sym_constrain_range,
|
||||
torch.sym_constrain_range_for_size,
|
||||
torch.tril_indices,
|
||||
@ -1061,6 +1055,13 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
||||
torch.sub: lambda input, other, out=None: -1,
|
||||
torch.subtract: lambda input, other, out=None: -1,
|
||||
torch.sum: lambda input, dim=None: -1,
|
||||
torch.sym_float: lambda input: -1,
|
||||
torch.sym_int: lambda input: -1,
|
||||
torch.sym_max: lambda a, b: -1,
|
||||
torch.sym_min: lambda a, b: -1,
|
||||
torch.sym_not: lambda input: -1,
|
||||
torch.sym_ite: lambda a, b, c: -1,
|
||||
torch.sym_sqrt: lambda input: -1,
|
||||
torch.nansum: lambda input, dim=None: -1,
|
||||
torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,
|
||||
torch.svd_lowrank: lambda input, q=6, niter=2, M=None: -1,
|
||||
|
||||
@ -1,7 +1,14 @@
|
||||
import math
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
|
||||
# The normal Python interpretation of the operators
|
||||
|
||||
# The sympy interpretation of operators. It will also sometimes work with
|
||||
# plain int/float, but if you do certain operations you will get out a
|
||||
# sympy.Basic in the end. If you want the Python/FX traceable interpretation,
|
||||
# check PythonReferenceAnalysis.
|
||||
# NB: For magic methods this needs to use normal magic methods
|
||||
# so that test_magic_methods works
|
||||
class ReferenceAnalysis:
|
||||
@ -11,12 +18,10 @@ class ReferenceAnalysis:
|
||||
|
||||
@staticmethod
|
||||
def or_(a, b):
|
||||
assert not isinstance(a, bool) and not isinstance(b, bool)
|
||||
return a | b
|
||||
|
||||
@staticmethod
|
||||
def and_(a, b):
|
||||
assert not isinstance(a, bool) and not isinstance(b, bool)
|
||||
return a & b
|
||||
|
||||
@staticmethod
|
||||
@ -151,3 +156,59 @@ class ReferenceAnalysis:
|
||||
@staticmethod
|
||||
def ceil(x):
|
||||
return sympy.ceiling(x)
|
||||
|
||||
|
||||
# Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain
|
||||
# Python types and is FX traceable. Inheritance here is purely for code
|
||||
# sharing (TODO: considering splitting out a BaseReferenceAnalysis).
|
||||
class PythonReferenceAnalysis(ReferenceAnalysis):
|
||||
@staticmethod
|
||||
def constant(c, dtype):
|
||||
if dtype is torch.int64:
|
||||
return int(c)
|
||||
elif dtype is torch.double:
|
||||
return float(c)
|
||||
elif dtype is torch.bool:
|
||||
return bool(c)
|
||||
else:
|
||||
raise AssertionError(f"unrecognized dtype {dtype}")
|
||||
|
||||
@staticmethod
|
||||
def not_(a):
|
||||
return torch.sym_not(a)
|
||||
|
||||
@staticmethod
|
||||
def floordiv(a, b):
|
||||
return a // b
|
||||
|
||||
@staticmethod
|
||||
def truncdiv(a, b):
|
||||
return a / b
|
||||
|
||||
@staticmethod
|
||||
def exp(x):
|
||||
raise AssertionError("exp is not valid shape sympy expr")
|
||||
|
||||
@staticmethod
|
||||
def log(x):
|
||||
raise AssertionError("log is not valid shape sympy expr")
|
||||
|
||||
@staticmethod
|
||||
def sqrt(x):
|
||||
return torch.sym_sqrt(x)
|
||||
|
||||
@staticmethod
|
||||
def minimum(a, b):
|
||||
return torch.sym_min(a, b)
|
||||
|
||||
@staticmethod
|
||||
def maximum(a, b):
|
||||
return torch.sym_max(a, b)
|
||||
|
||||
@staticmethod
|
||||
def floor(x):
|
||||
return math.floor(x)
|
||||
|
||||
@staticmethod
|
||||
def ceil(x):
|
||||
return math.ceil(x)
|
||||
|
||||
Reference in New Issue
Block a user