[export] enumerate unsupported sympy.Functions (#134271)

There's 2 concepts of unsupported sympy.Functions in symbolic_shapes:
1) unsupported by the export solver, meaning the solver doesn't know how to provide useful fixes for those functions
2) unsupported by the sympy interpreter - meaning we can't reify them into FX nodes because the functions aren't present in PythonReferenceAnalysis

This splits the current call into a call for each version, with the Export solver the only user of 1). For 1), we enumerate the functions in _sympy/functions.py, and subtract the functions we know we can support. For 2) there's only 3 functions we've seen pop up in test cases.

Differential Revision: D61677956

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134271
Approved by: https://github.com/avikchaudhuri
This commit is contained in:
Pian Pawakapan
2024-08-26 22:44:12 +00:00
committed by PyTorch MergeBot
parent 55236d0cb7
commit ddd71e3479
3 changed files with 37 additions and 7 deletions

View File

@ -1267,13 +1267,14 @@ def _is_supported_equivalence(expr):
)
return isinstance(expr, sympy.Symbol)
def _has_unsupported_sympy_function(expr) -> bool:
def _has_uninterpretable_sympy_function(expr) -> bool:
"""
Add functions that our sympy interpreter can't reify into FX nodes
"""
return expr.has(
torch.utils._sympy.functions.ToFloat,
torch.utils._sympy.functions.TruncToInt,
torch.utils._sympy.functions.CeilToInt,
# add more sympy functions that involve float<->int conversion here
# since our solver does not know what to do with them
)
@dataclass(frozen=True)
@ -1675,6 +1676,14 @@ class DimConstraints:
# symbols that are marked dynamic
self._marked_dynamic = marked_dynamic
# track supported sympy functions and subtract from list of all sympy functions
self._supported_sympy_functions: Set[sympy.Function] = {
Mod,
PythonMod,
FloorDiv,
}
self._enumerate_sympy_functions()
def rewrite_with_congruences(self, s, expr):
"""
Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
@ -1741,6 +1750,20 @@ class DimConstraints:
expr = expr.replace(FloorDiv, floor_div_handler)
return expr
def _enumerate_sympy_functions(self):
module = torch.utils._sympy.functions
all_functions = set()
for attr in dir(module):
if isinstance(func := getattr(module, attr), sympy.FunctionClass):
all_functions.add(func)
self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions)
def _has_unsupported_sympy_function(self, expr) -> bool:
"""
Tracks list of sympy.Functions the export solver doesn't know how to handle.
"""
return expr.has(*self._unsupported_sympy_functions)
def add(self, expr) -> bool:
"""Add an expression to the set of constraints.
@ -1757,7 +1780,7 @@ class DimConstraints:
# a fix for this issue, we delay raising such failures. See solve().
if orig_reduced == sympy.false:
self._inconsistencies.append(f"{orig_expr} is inconsistent!")
if isinstance(expr, sympy.Ne) or _has_unsupported_sympy_function(expr):
if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr):
# we're not going to do anything useful with these, so drop them
return False
free_symbols = expr.free_symbols

View File

@ -93,7 +93,7 @@ def insert_deferred_runtime_asserts(
import sympy
from torch.fx.experimental.symbolic_shapes import (
_has_unsupported_sympy_function,
_has_uninterpretable_sympy_function,
CallMethodKey,
cast_symbool_to_symint_guardless,
ConvertIntKey,
@ -136,7 +136,7 @@ def insert_deferred_runtime_asserts(
(val := _get_sym_val(node)) is not None
and not isinstance(val, sympy.Number)
# this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported
and not _has_unsupported_sympy_function(val)
and not _has_uninterpretable_sympy_function(val)
and any(
isinstance(arg, fx.Node)
and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
@ -195,7 +195,7 @@ def insert_deferred_runtime_asserts(
and _is_bound_expr_for_symbol(ra.expr)
)
# don't try to reify sympy functions we can't turn into FX nodes
or _has_unsupported_sympy_function(ra.expr)
or _has_uninterpretable_sympy_function(ra.expr)
):
continue

View File

@ -53,13 +53,20 @@ from .numbers import int_oo
__all__ = [
"FloorDiv",
"ModularIndexing",
"Where",
"PythonMod",
"Mod",
"CleanDiv",
"CeilToInt",
"FloorToInt",
"CeilDiv",
"IntTrueDiv",
"FloatTrueDiv",
"LShift",
"RShift",
"IsNonOverlappingAndDenseIndicator",
"TruncToFloat",
"TruncToInt",
"RoundToInt",
"RoundDecimal",
"ToFloat",