mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[export] enumerate unsupported sympy.Functions (#134271)
There's 2 concepts of unsupported sympy.Functions in symbolic_shapes: 1) unsupported by the export solver, meaning the solver doesn't know how to provide useful fixes for those functions 2) unsupported by the sympy interpreter - meaning we can't reify them into FX nodes because the functions aren't present in PythonReferenceAnalysis This splits the current call into a call for each version, with the Export solver the only user of 1). For 1), we enumerate the functions in _sympy/functions.py, and subtract the functions we know we can support. For 2) there's only 3 functions we've seen pop up in test cases. Differential Revision: D61677956 Pull Request resolved: https://github.com/pytorch/pytorch/pull/134271 Approved by: https://github.com/avikchaudhuri
This commit is contained in:
committed by
PyTorch MergeBot
parent
55236d0cb7
commit
ddd71e3479
@ -1267,13 +1267,14 @@ def _is_supported_equivalence(expr):
|
||||
)
|
||||
return isinstance(expr, sympy.Symbol)
|
||||
|
||||
def _has_unsupported_sympy_function(expr) -> bool:
|
||||
def _has_uninterpretable_sympy_function(expr) -> bool:
|
||||
"""
|
||||
Add functions that our sympy interpreter can't reify into FX nodes
|
||||
"""
|
||||
return expr.has(
|
||||
torch.utils._sympy.functions.ToFloat,
|
||||
torch.utils._sympy.functions.TruncToInt,
|
||||
torch.utils._sympy.functions.CeilToInt,
|
||||
# add more sympy functions that involve float<->int conversion here
|
||||
# since our solver does not know what to do with them
|
||||
)
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@ -1675,6 +1676,14 @@ class DimConstraints:
|
||||
# symbols that are marked dynamic
|
||||
self._marked_dynamic = marked_dynamic
|
||||
|
||||
# track supported sympy functions and subtract from list of all sympy functions
|
||||
self._supported_sympy_functions: Set[sympy.Function] = {
|
||||
Mod,
|
||||
PythonMod,
|
||||
FloorDiv,
|
||||
}
|
||||
self._enumerate_sympy_functions()
|
||||
|
||||
def rewrite_with_congruences(self, s, expr):
|
||||
"""
|
||||
Eliminate expressions of the form b // d and b % d while adding congruences of the form b % d == k.
|
||||
@ -1741,6 +1750,20 @@ class DimConstraints:
|
||||
expr = expr.replace(FloorDiv, floor_div_handler)
|
||||
return expr
|
||||
|
||||
def _enumerate_sympy_functions(self):
|
||||
module = torch.utils._sympy.functions
|
||||
all_functions = set()
|
||||
for attr in dir(module):
|
||||
if isinstance(func := getattr(module, attr), sympy.FunctionClass):
|
||||
all_functions.add(func)
|
||||
self._unsupported_sympy_functions = all_functions.difference(self._supported_sympy_functions)
|
||||
|
||||
def _has_unsupported_sympy_function(self, expr) -> bool:
|
||||
"""
|
||||
Tracks list of sympy.Functions the export solver doesn't know how to handle.
|
||||
"""
|
||||
return expr.has(*self._unsupported_sympy_functions)
|
||||
|
||||
def add(self, expr) -> bool:
|
||||
"""Add an expression to the set of constraints.
|
||||
|
||||
@ -1757,7 +1780,7 @@ class DimConstraints:
|
||||
# a fix for this issue, we delay raising such failures. See solve().
|
||||
if orig_reduced == sympy.false:
|
||||
self._inconsistencies.append(f"{orig_expr} is inconsistent!")
|
||||
if isinstance(expr, sympy.Ne) or _has_unsupported_sympy_function(expr):
|
||||
if isinstance(expr, sympy.Ne) or self._has_unsupported_sympy_function(expr):
|
||||
# we're not going to do anything useful with these, so drop them
|
||||
return False
|
||||
free_symbols = expr.free_symbols
|
||||
|
@ -93,7 +93,7 @@ def insert_deferred_runtime_asserts(
|
||||
import sympy
|
||||
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
_has_unsupported_sympy_function,
|
||||
_has_uninterpretable_sympy_function,
|
||||
CallMethodKey,
|
||||
cast_symbool_to_symint_guardless,
|
||||
ConvertIntKey,
|
||||
@ -136,7 +136,7 @@ def insert_deferred_runtime_asserts(
|
||||
(val := _get_sym_val(node)) is not None
|
||||
and not isinstance(val, sympy.Number)
|
||||
# this holds back from reifying anything in torch.utils._sympy.functions.py that's unsupported
|
||||
and not _has_unsupported_sympy_function(val)
|
||||
and not _has_uninterpretable_sympy_function(val)
|
||||
and any(
|
||||
isinstance(arg, fx.Node)
|
||||
and isinstance(_get_example_value(arg), (torch.Tensor, torch.Size))
|
||||
@ -195,7 +195,7 @@ def insert_deferred_runtime_asserts(
|
||||
and _is_bound_expr_for_symbol(ra.expr)
|
||||
)
|
||||
# don't try to reify sympy functions we can't turn into FX nodes
|
||||
or _has_unsupported_sympy_function(ra.expr)
|
||||
or _has_uninterpretable_sympy_function(ra.expr)
|
||||
):
|
||||
continue
|
||||
|
||||
|
@ -53,13 +53,20 @@ from .numbers import int_oo
|
||||
__all__ = [
|
||||
"FloorDiv",
|
||||
"ModularIndexing",
|
||||
"Where",
|
||||
"PythonMod",
|
||||
"Mod",
|
||||
"CleanDiv",
|
||||
"CeilToInt",
|
||||
"FloorToInt",
|
||||
"CeilDiv",
|
||||
"IntTrueDiv",
|
||||
"FloatTrueDiv",
|
||||
"LShift",
|
||||
"RShift",
|
||||
"IsNonOverlappingAndDenseIndicator",
|
||||
"TruncToFloat",
|
||||
"TruncToInt",
|
||||
"RoundToInt",
|
||||
"RoundDecimal",
|
||||
"ToFloat",
|
||||
|
Reference in New Issue
Block a user