Fixes#155688
Root Cause:
in [`torch/_inductor/index_propagation.py`](f151b20123/torch/_inductor/index_propagation.py (L57-L68))
When creating a `TypedExpr` from an `Identity` (a `torch.utils._sympy.functions.Identity`, not a `sympy.matrices.expressions.Identity `) and the inner value of the identity, `Identity.args[0]`, is any torch int type, the `TypedExpr.__post_init__` method tries to cast the Identity object to a python `int`. This is where to `TypeError` from the issue was raised, because Identity does not know how to cast to an `int`.
Fix:
Define `__int__` method for `torch.utils._sympy.functions.Identity`.
wlog for `float`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155873
Approved by: https://github.com/williamwen42
Make Functorch interpreters serializable most of the time, so that we can save the guards on functorch states.
## Test Cases:
0. torch.compile() without functorch layers present. Guard should fail with any layer being pushed.
1. torch.compile() nested in vmap.
2. torch.compile() nested in grad.
3. torch.compile() nested in jvp + vmap
4. torch.compile() nested functionalize
5. torch.compile() nested in vmap + grad
Differential Revision: [D74008787](https://our.internmc.facebook.com/intern/diff/D74008787/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152616
Approved by: https://github.com/zou3519
ghstack dependencies: #152615
# Feature
Inductor sometimes uses `Identity` functions to group various terms of an expression. While this is convenient in some scenarios, it can frustrate pattern matching. For example, when we're matching an indexing expression to tell if it can be represented as a block pointer, that analysis should be invariant to `Identity`'s.
This PR adds a few features to achieve this invariance.
- Create a new expansion mode `expr.expand(identity=True)`, which removes all `Identity` functions from the expression.
- Preprocess the expression with this expansion prior to pattern matching.
- Bonus: create a new test utility function called `dummy_graph()`, which creates a simple `GraphLowering`. This is useful for testing the pattern matcher, as we need to initialize `V.graph` before we can access `V.graph.sizevars`.
# Test plan
This PR adds a few new unit tests:
- Added a unit test specifically for `expr.expand(identity=True)`.
- Added a new unit test module for the block pattern matcher. Tested that we can correctly match some example patterns containing Identity ops.
I originally intended to add an end to end test compiling pointwise cat, and mapping the corresponding memory accesses to block pointers. However, it looks like that will take more work, since the [relevant code path](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/codegen/triton.py#L1306) disables block pointer analysis. It might be better to defer that to a future PR.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146000
Approved by: https://github.com/eellison, https://github.com/jansel
* Let's say x is an integer beyond 2^53 where Python floats lose precision i.e. can't increment by 1.
* Therefore, float(x) will lose precision and won't retain the exact value of x even though it's an integer.
* That means `FloorToInt(very_large_number)` will lose precision if we cast it to float
```
>>> int(float(1000000007999999992))
1000000008000000000
```
This means when we try to do this in set_replacement():
32bb6f83d5/torch/fx/experimental/symbolic_shapes.py (L6011-L6019)
We run into this:
```
TORCH_LOGS="+torch.fx.experimental.symbolic_shapes" pytest -s test_export.py -k test_replace_unbacked_with_very_large_upperbound
File "/data/users/colinpeppler/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6258, in _maybe_guard_rel
self._set_replacement(rhs, self._find(lhs), "trivial_rhs")
File "/data/users/colinpeppler/pytorch/torch/fx/experimental/symbolic_shapes.py", line 6039, in _set_replacement
assert tgt_bound.issubset(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function add>(*(FakeTensor(..., size=(2*s0,)), FakeTensor(..., size=(u0,))), **{}):
tgt_bound=VR[4, 1000000008000000000] not a subset of src_bound=VR[4, 1000000007999999992]
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146001
Approved by: https://github.com/bobrenjc93
ghstack dependencies: #145898
I encountered this C++ compilation error.
```
579 | int64_t var_6 = (static_cast<int64_t>(std::floor((1.0/2.0)*u0)) | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))) | std::floor((1.0/16.0)*(static_cast<int64_t>(std::floor((1.0/2.0)*u0)) | static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))));
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ ^ ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| | |
| int64_t {aka long int} double
```
Then, I figured out where this std::floor came from with the help of Bob's guard provenance tool. It comes from RShift which is used in `triton.next_power_of_2`.
---
Before, we used `std::floor`
```
int64_t var_6 = (
static_cast<int64_t>(std::floor((1.0/2.0)*u0)) |
static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0)))))
| std::floor((1.0/16.0)*(static_cast<int64_t>(std::floor((1.0/2.0)*u0)) # no cast to int here.
| static_cast<int64_t>(std::floor((1.0/4.0)*static_cast<int64_t>(std::floor((1.0/2.0)*u0))))));
```
Now, we use `c10::div_floor_integer` instead
```
int64_t var_6 = (
(c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(2L))) |
(c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(8L)))) |
(c10::div_floor_integer(static_cast<int64_t>((c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(2L)))
| (c10::div_floor_integer(static_cast<int64_t>(u0), static_cast<int64_t>(8L)))), static_cast<int64_t>(16L)));
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145898
Approved by: https://github.com/desertfire, https://github.com/bobrenjc93
ghstack dependencies: #145802
The codebase has a few locations where callable parameter type information is lost when the unpackings *args and **kwargs are typed as Any. Refactor these instances to retain type information using typing_extensions.ParamSpec.
Also, in these functions, enforce return type with TypeVar.
Addresses #142306
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143797
Approved by: https://github.com/Skylion007
Co-authored-by: Aaron Gokaslan <aaronGokaslan@gmail.com>
Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
Summary:
**wins**
on torchrec benchmark, for 2K nodes it save 40seconds
with the recent sympy changes (https://www.internalfb.com/diff/D65883538) we save around 13 second ( with the max opt on).
```
buck2 run fbcode//mode/opt fbcode//torchrec/distributed/tests:pt2_compile_benchmark -- --num-features=200
```
This diff optimizes construction expressions of the form
a+b+c... (all unique symbols).
which are very common in torchrec models.
**How**
Expressions of the form a+b+c are not optimized by add, the only needed optimization is sorting them.
If we have a+b+c and we are adding (d) to it, we can do a binary search to know
the position of (d) and avoid optimizing the new expression by passing the new order.
**Extensions**:
1. support constant terms.
2. support 10a+10b+.. (this will give even more wins will extend the support in second PR)
Differential Revision: D66008482
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140822
Approved by: https://github.com/ezyang