mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Avoid some max constructor optimizations when known not needed. (#139741)
Summary: around 10% improvement with 1K nodes, and more than that with 2K features: 414.5735 -> 333 (20%). This targets optimizing patterns like this: ``` sym_max: "Sym(Max(u31 + u32, u33 + u34))" = torch.sym_max(sym_sum_6, sym_sum_7); sym_sum_6 = sym_sum_7 = None sym_max_1: "Sym(Max(u31 + u32, u33 + u34, u35 + u36))" = torch.sym_max(sym_max, sym_sum_8); sym_max = sym_sum_8 = None sym_max_2: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38))" = torch.sym_max(sym_max_1, sym_sum_9); sym_max_1 = sym_sum_9 = None sym_max_3: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40))" = torch.sym_max(sym_max_2, sym_sum_10); sym_max_2 = sym_sum_10 = None sym_max_4: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42))" = torch.sym_max(sym_max_3, sym_sum_11); sym_max_3 = sym_sum_11 = None sym_max_5: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44))" = torch.sym_max(sym_max_4, sym_sum_12); sym_max_4 = sym_sum_12 = None sym_max_6: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46))" = torch.sym_max(sym_max_5, sym_sum_13); sym_max_5 = sym_sum_13 = None sym_max_7: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46, u47 + u48))" = torch.sym_max(sym_max_6, sym_sum_14); sym_max_6 = sym_sum_14 = None sym_max_8: "Sym(Max(u31 + u32, u33 + u34, u35 + u36, u37 + u38, u39 + u40, u41 + u42, u43 + u44, u45 + u46, u47 + u48, u49 + u50))" = torch.sym_max(sym_max_7, sym_sum_15); sym_max_7 = sym_sum_15 = sym_max_8 = None ``` <img width="496" alt="Screenshot 2024-11-05 at 11 00 35 AM" src="https://github.com/user-attachments/assets/455c06a3-e1bf-43cb-b880-9470ae6fb07f"> <img width="511" alt="Screenshot 2024-11-05 at 11 00 57 AM" src="https://github.com/user-attachments/assets/ff0d4236-9b5c-4a9a-8520-47b005bb3cb0"> Differential Revision: D65354971 Pull Request resolved: https://github.com/pytorch/pytorch/pull/139741 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
75bbad4768
commit
e39955e82f
@ -96,7 +96,7 @@ def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
|
||||
and len(expr._args) == 2
|
||||
and expr._args[0].is_symbol
|
||||
and expr._args[1].is_symbol
|
||||
and expr._args[0] != expr._args[1]
|
||||
and expr._args[0] is not expr._args[1]
|
||||
)
|
||||
|
||||
|
||||
@ -570,24 +570,35 @@ class RShift(sympy.Function):
|
||||
|
||||
|
||||
class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
||||
def __new__(cls, *args, **assumptions):
|
||||
def __new__(cls, *original_args, **assumptions):
|
||||
from sympy.core.parameters import global_parameters
|
||||
|
||||
evaluate = assumptions.pop("evaluate", global_parameters.evaluate)
|
||||
args = (sympify(arg) for arg in args)
|
||||
args = (sympify(arg) for arg in original_args)
|
||||
|
||||
# first standard filter, for cls.zero and cls.identity
|
||||
# also reshape Max(a, Max(b, c)) to Max(a, b, c)
|
||||
# See the comment in _satisfy_unique_summations_symbols.
|
||||
unique_summations_symbols = (
|
||||
None
|
||||
if not evaluate
|
||||
else cls._satisfy_unique_summations_symbols(original_args)
|
||||
)
|
||||
|
||||
if evaluate:
|
||||
try:
|
||||
# first standard filter, for cls.zero and cls.identity
|
||||
# also reshape Max(a, Max(b, c)) to Max(a, b, c)
|
||||
args = frozenset(cls._new_args_filter(args)) # type: ignore[assignment]
|
||||
except ShortCircuit:
|
||||
return cls.zero # type: ignore[attr-defined]
|
||||
# remove redundant args that are easily identified
|
||||
args = cls._collapse_arguments(args, **assumptions)
|
||||
# find local zeros
|
||||
args = cls._find_localzeros(args, **assumptions)
|
||||
|
||||
# No need to run _collapse_arguments and _find_localzeros, see the comment
|
||||
# in _satisfy_unique_summations_symbols.
|
||||
if unique_summations_symbols is None:
|
||||
# remove redundant args that are easily identified
|
||||
args = cls._collapse_arguments(args, **assumptions)
|
||||
|
||||
# find local zeros
|
||||
args = cls._find_localzeros(args, **assumptions)
|
||||
|
||||
args = frozenset(args)
|
||||
|
||||
@ -600,8 +611,85 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
|
||||
# base creation
|
||||
obj = Expr.__new__(cls, *ordered(args), **assumptions)
|
||||
obj._argset = args
|
||||
|
||||
obj.unique_summations_symbols = unique_summations_symbols
|
||||
return obj
|
||||
|
||||
@classmethod
def _satisfy_unique_summations_symbols(
    cls, args
) -> Optional[set[sympy.core.symbol.Symbol]]:
    """
    Decide whether a two-argument min/max over symbol sums uses unique symbols.

    Some models build expressions of the form
    max(max(max(a+b, ...), c+d), e+f), which flatten to max(a+b, c+d, e+f, ...).
    The constructor is invoked once per nesting level, and each invocation
    would normally run _collapse_arguments and _find_localzeros. Those two
    passes are unnecessary when every term is a sum of two symbols and no
    symbol is shared between terms.

    To make that check cheap, a min/max expression that has the property
    carries the set of all symbols used by its terms in the attribute
    "unique_summations_symbols". Adding one more term then only requires
    checking the new term's symbols against that memoized set.

    Example:
        t = Max(a+b, c+d)  ==> satisfies the property
        Max(t, h+j)        ==> h, j not in {a, b, c, d} => satisfies the property

    Args:
        args: the original (un-flattened) constructor arguments.

    Returns:
        The set of symbols used by all terms when the property holds,
        otherwise None.
    """
    # Only the binary form max(x, y) is recognized.
    if len(args) != 2:
        return None

    first, second = args[0], args[1]

    # Normalize so that a nested min/max, if given as the second argument,
    # becomes the "accumulated" side and the other argument the new term.
    if isinstance(second, MinMaxBase):
        accumulated, new_term = second, first
    else:
        accumulated, new_term = first, second

    # The new term must itself be a sum of two symbols.
    if not _is_symbols_binary_summation(new_term):
        return None

    # Base case: max(a+b, c+d) qualifies when both sums use distinct symbols.
    if _is_symbols_binary_summation(accumulated):
        return cls._unique_symbols(args)

    # Inductive case: max(t, h+j) qualifies when t already has the property
    # and h, j are not among t's memoized symbols.
    if not isinstance(accumulated, MinMaxBase):
        return None

    memoized_symbols = getattr(accumulated, "unique_summations_symbols", None)
    if memoized_symbols is None:
        return None

    return cls._unique_symbols([new_term], memoized_symbols)
|
||||
|
||||
@classmethod
def _unique_symbols(
    cls, args, initial_set: Optional[set[sympy.core.symbol.Symbol]] = None
) -> Optional[set[sympy.core.symbol.Symbol]]:
    """
    Collect the symbols of ``args``, requiring every atom to be a fresh symbol.

    Args:
        args: iterable of expressions; each must provide ``atoms()``.
        initial_set: symbols that count as already seen. It is only read,
            never modified.

    Returns:
        A new set holding ``initial_set`` plus every atom of ``args`` when
        all atoms are symbols and none occurs twice (or in ``initial_set``);
        otherwise None.
    """
    # Work on a copy: the caller may pass the set memoized on an existing
    # expression, and adding to it in place would make that expression claim
    # symbols it does not contain (including leftovers from a failed check).
    seen_symbols = set() if initial_set is None else set(initial_set)
    for arg in args:
        for element in arg.atoms():
            # Any non-symbol atom (e.g. a numeric constant) disqualifies.
            if not isinstance(element, sympy.core.symbol.Symbol):
                return None
            # A repeated symbol means the terms are not symbol-disjoint.
            if element in seen_symbols:
                return None
            seen_symbols.add(element)
    return seen_symbols
|
||||
|
||||
@classmethod
|
||||
def _collapse_arguments(cls, args, **assumptions):
|
||||
"""Remove redundant args.
|
||||
|
Reference in New Issue
Block a user