Avoid some max constructor optimizations when known not needed. (#139741)

Summary:
Improvement of around 10% with 1K nodes,
and more than that with 2K features: 414.5735 -> 333 (20%).

This targets patterns like the following:
```
 sym_max: "Sym(Max(u31 + u32, u33 + u34))" = torch.sym_max(sym_sum_6, sym_sum_7);  sym_sum_6 = sym_sum_7 = None
        sym_max_1: "Sym(Max(u31 + u32, u33 + u34, u35 + u36))" = torch.sym_max(sym_max, sym_sum_8);  sym_max = sym_sum_8 = None
        sym_max_2: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38))" = torch.sym_max(sym_max_1, sym_sum_9);  sym_max_1 = sym_sum_9 = None
        sym_max_3: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40))" = torch.sym_max(sym_max_2, sym_sum_10);  sym_max_2 = sym_sum_10 = None
        sym_max_4: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42))" = torch.sym_max(sym_max_3, sym_sum_11);  sym_max_3 = sym_sum_11 = None
        sym_max_5: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44))" = torch.sym_max(sym_max_4, sym_sum_12);  sym_max_4 = sym_sum_12 = None
        sym_max_6: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46))" = torch.sym_max(sym_max_5, sym_sum_13);  sym_max_5 = sym_sum_13 = None
        sym_max_7: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46, u47 + u48))" = torch.sym_max(sym_max_6, sym_sum_14);  sym_max_6 = sym_sum_14 = None
        sym_max_8: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46, u47 + u48, u49 + u50))" = torch.sym_max(sym_max_7, sym_sum_15);  sym_max_7 = sym_sum_15 = sym_max_8 = None
```

<img width="496" alt="Screenshot 2024-11-05 at 11 00 35 AM" src="https://github.com/user-attachments/assets/455c06a3-e1bf-43cb-b880-9470ae6fb07f">
<img width="511" alt="Screenshot 2024-11-05 at 11 00 57 AM" src="https://github.com/user-attachments/assets/ff0d4236-9b5c-4a9a-8520-47b005bb3cb0">

Differential Revision: D65354971

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139741
Approved by: https://github.com/ezyang
This commit is contained in:
Laith Sakka
2024-11-21 16:50:52 +00:00
committed by PyTorch MergeBot
parent 75bbad4768
commit e39955e82f
2 changed files with 128 additions and 9 deletions

View File

@ -96,7 +96,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
and len(expr._args) == 2
and expr._args[0].is_symbol
and expr._args[1].is_symbol
and expr._args[0] != expr._args[1]
and expr._args[0] is not expr._args[1]
)
@ -570,24 +570,35 @@ class RShift(sympy.Function):
class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
def __new__(cls, *args, **assumptions):
def __new__(cls, *original_args, **assumptions):
from sympy.core.parameters import global_parameters
evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
args = (sympify(arg) for arg in args)
args = (sympify(arg) for arg in original_args)
# first standard filter, for cls.zero and cls.identity
# also reshape Max(a, Max(b, c)) to Max(a, b, c)
# See the comment in _satisfy_unique_summations_symbols.
unique_summations_symbols = (
None
if not evaluate
else cls._satisfy_unique_summations_symbols(original_args)
)
if evaluate:
try:
# first standard filter, for cls.zero and cls.identity
# also reshape Max(a, Max(b, c)) to Max(a, b, c)
args = frozenset(cls._new_args_filter(args)) # type: ignore[assignment]
except ShortCircuit:
return cls.zero # type: ignore[attr-defined]
# remove redundant args that are easily identified
args = cls._collapse_arguments(args, **assumptions)
# find local zeros
args = cls._find_localzeros(args, **assumptions)
# No need to run _collapse_arguments and _find_localzeros, see the comment
# in _satisfy_unique_summations_symbols.
if unique_summations_symbols is None:
# remove redundant args that are easily identified
args = cls._collapse_arguments(args, **assumptions)
# find local zeros
args = cls._find_localzeros(args, **assumptions)
args = frozenset(args)
@ -600,8 +611,85 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
# base creation
obj = Expr.__new__(cls, *ordered(args), **assumptions)
obj._argset = args
obj.unique_summations_symbols = unique_summations_symbols
return obj
@classmethod
def _satisfy_unique_summations_symbols(
    cls, args
) -> Optional[set[sympy.core.symbol.Symbol]]:
    """
    Decide whether the min/max being built is over "unique symbol summations".

    A common pattern in some models is a chain such as
    max(max(max(a+b, c+d), e+f), ...), which flattens to max(a+b, c+d, e+f, ...).
    The constructor is invoked once per nesting level, and running
    _collapse_arguments and _find_localzeros each time is expensive. Those
    passes are unnecessary when every term is a two-symbol sum (a+b, c+d, ...)
    and no symbol appears in more than one term.

    The property is tracked inductively through the
    "unique_summations_symbols" attribute of an existing min/max expression,
    which memoizes the symbols used by its terms. Adding a new term then only
    requires checking the new term's symbols against that memoized set.

    Example:
        t = Max(a+b, c+d)  ==> satisfies the property
        Max(t, h+j)        ==> h, j not in {a, b, c, d} => satisfies the property

    Returns:
        The set of symbols used by all terms when the property holds,
        otherwise None.
    """
    # Only the binary form (existing expression, new term) is recognized.
    if len(args) != 2:
        return None

    first, second = args[0], args[1]
    # Normalize operand order: a nested min/max in the second slot becomes
    # the accumulated side, and the other operand is the new term.
    if isinstance(second, MinMaxBase):
        accumulated, new_term = second, first
    else:
        accumulated, new_term = first, second

    # The new term must itself be a sum of two symbols.
    if not _is_symbols_binary_summation(new_term):
        return None

    # Base case: max(a+b, c+d) qualifies when the two sums share no symbol.
    if _is_symbols_binary_summation(accumulated):
        return cls._unique_symbols(args)

    # Inductive case: max(t, h+j) qualifies when t already carries the
    # property and h, j are absent from t's memoized symbols.
    if not isinstance(accumulated, MinMaxBase):
        return None
    known_symbols = getattr(accumulated, "unique_summations_symbols", None)
    if known_symbols is None:
        return None
    return cls._unique_symbols([new_term], known_symbols)
@classmethod
def _unique_symbols(
    cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None
) -> Optional[set[sympy.core.symbol.Symbol]]:
    """
    Collect the symbols used by args, requiring each one to be unique.

    Args:
        args: iterable of expressions; every atom of every expression is
            inspected via ``arg.atoms()``.
        initial_set: symbols that count as already seen. It is never
            modified; the result is built on a copy.

    Returns:
        A new set holding initial_set plus every atom found in args, or None
        if any atom is not a Symbol or was already seen (in initial_set or
        earlier in args).
    """
    # Work on a copy: initial_set is typically the memoized symbol set of an
    # existing expression. Mutating it in place would leak symbols into that
    # expression's set, including on the failure path where some symbols
    # were added before a duplicate was found.
    seen_symbols = set() if initial_set is None else set(initial_set)
    for arg in args:
        for element in arg.atoms():
            if not isinstance(element, sympy.core.symbol.Symbol):
                return None
            if element in seen_symbols:
                return None
            seen_symbols.add(element)
    return seen_symbols
@classmethod
def _collapse_arguments(cls, args, **assumptions):
"""Remove redundant args.