Optimize increment summations [Latest Nov 15] (#140822)

Summary:
**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on).
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form
a+b+c...  (all unique symbols).
which are very common in torchrec models.

**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have  a+b+c and we are adding (d) to it, we can do a binary search to know
the position of (d) and avoid optimizing the new expression by passing the new order.

**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)

Differential Revision: D66008482

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140822
Approved by: https://github.com/ezyang
This commit is contained in:
Laith Sakka
2024-11-20 16:48:20 +00:00
committed by PyTorch MergeBot
parent a440a01832
commit 8d708090c0
4 changed files with 245 additions and 6 deletions

View File

@ -89,6 +89,17 @@ __all__ = [
]
def _is_symbols_binary_summation(expr: sympy.Expr) -> bool:
# No need to check that two args are not the same, since expr is pr-optimized but we do it anyway.
return (
expr.is_Add
and len(expr._args) == 2
and expr._args[0].is_symbol
and expr._args[1].is_symbol
and expr._args[0] != expr._args[1]
)
def _keep_float(f: Callable[..., _T]) -> Callable[..., Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Any) -> Union[_T, sympy.Float]: