mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Optimize increment summations [Latest Nov 15] (#140822)
Summary:
**Wins** on the torchrec benchmark: for 2K nodes it saves 40 seconds; with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 seconds (with the max opt on).

```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```

This diff optimizes the construction of expressions of the form a+b+c... (all unique symbols), which are very common in torchrec models.

**How**
Expressions of the form a+b+c are not optimized further by add; the only optimization needed is sorting the terms. If we have a+b+c and we are adding (d) to it, we can do a binary search to find the position of (d), then pass the new order directly and avoid optimizing the new expression.

**Extensions**:
1. Support constant terms.
2. Support 10a+10b+... (this will give even more wins; support will be extended in a second PR).

Differential Revision: D66008482
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140822
Approved by: https://github.com/ezyang
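The binary-search idea described under **How** can be illustrated with a minimal, standard-library-only sketch. It is not the code from this PR: the name `insert_symbol` and the use of plain strings as stand-ins for sympy symbols are assumptions made purely for illustration. The terms of an existing sum are kept as an already sorted tuple, and the position of the new term is found with `bisect_left` instead of re-sorting the whole sum.

```python
from bisect import bisect_left
from typing import Optional, Tuple


def insert_symbol(
    ordered_terms: Tuple[str, ...], new_symbol: str
) -> Optional[Tuple[str, ...]]:
    """Insert new_symbol into an already sorted tuple of unique symbol names.

    Hypothetical helper: strings stand in for sympy symbols. Returns None
    when the symbol is already present, because a+b+a is no longer a sum of
    unique symbols and would need the general (slow) construction path.
    """
    # Binary search for the insertion point: O(log n) comparisons.
    pos = bisect_left(ordered_terms, new_symbol)
    if pos < len(ordered_terms) and ordered_terms[pos] == new_symbol:
        return None
    # Splice the new term in; the result is sorted by construction,
    # so no further sorting pass is needed.
    return ordered_terms[:pos] + (new_symbol,) + ordered_terms[pos:]


print(insert_symbol(("a", "b", "d"), "c"))
print(insert_symbol(("a", "b"), "a"))
```

The `None` return models the fallback case: only when the new term keeps the sum in the "all unique symbols" form can the fast path be used.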
committed by PyTorch MergeBot
parent a440a01832
commit 8d708090c0
@@ -89,6 +89,17 @@ __all__ = [
]


def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
    # No need to check that two args are not the same, since expr is pre-optimized but we do it anyway.
    return (
        expr.is_Add
        and len(expr._args) == 2
        and expr._args[0].is_symbol
        and expr._args[1].is_symbol
        and expr._args[0] != expr._args[1]
    )


def _keep_float(f: Callable[..., _T]) -> Callable[..., Union[_T, sympy.Float]]:
    @functools.wraps(f)
    def inner(*args: Any) -> Union[_T, sympy.Float]: