Fixes https://github.com/pytorch/pytorch/issues/141305.
```python
class M(torch.nn.Module):
def forward(self, x, y, z):
a = y.shape[0]
b = z.shape[0]
def true_fn(x):
return x + a
def false_fn(x):
return x + b * z
# When exporting with non-strict: a and b are symints,
# so torch.compile need to wrap and trace symint inputs.
return torch.cond(x.shape[0] > 5, true_fn, false_fn, (x,))
```
In non-strict export, when inputs are annotated with dynamic shape, the a, and b in above example are torch.SymInt type. true_fn and false_fn will have closure that're of torch.SymInt types. The error is triggered because we didn't handle SymInt inputs in dynamo and ends up using a UserDefinedObjectVariable for it, which doesn't have a proxy. We added support by following how we handle SymBool input previously.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141524
Approved by: https://github.com/zou3519
ghstack dependencies: #141610, #142185
This PR fixes the shape checks that are done in the associative_scan operation.
Before all shapes of the input leaves were required to be the same. With this PR only the shapes of the output of the combine_fn and the input leaves need to be the same, but not among the input leaves.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141698
Approved by: https://github.com/ydwu4
There are 4 parts (they are hard to further break into smaller ones cause they're highly coupled) in this PR:
1. **Whenever we call create_graph_input, we try to bind the symbols in the graph input.**
We've enforced the invariant that all create_graph_inputs calls must provide an example value, we could intercept at the create_graph_input calls (This PR only handles free symbols in tensors).
2. **We cache the bound_symbols** to avoid lift the same symbol repeated.
3. For lifted symbols, we re-used **lifted_freevars** i.e. the mapping between symbol proxy in parent graph to the lifted phs in current subgraph, which we handle lifted tensors. In this way, all hops that supports lifted tensors should be able to handle lifted_symints automatically (at least in dynamo part).
4. For **unbacked symbols** created during tracing, we need to also bound these symbols to its proxy. This is to support the tests cases where we want to lift unbacked symbols as input. We need the proxy of the unbacked symbol in parent graph in order to properly create the args to the hop.
5. We change all the tests after free symbols are lifted in subgraphs. And also supports the lifted symbols in existing higher order ops.
**The interaction of nested tracers:**
The previous design for lifting tensor closures is that: suppose we're in nested tracers, whenever we see a new proxy that's not created by create tracer, we recursively look for the proxy in parent tracer until we find the tracer that creates this proxy (either a placeholder or some intermediate results). More detail is in Note [Nested SubgraphTracer and free_variable handling].
Given the above design, the plan for lifting the free symbols is: whenever we lift a free tensor to be the inputs of current subgraph, we'll look at the symbols in it and bind the symbols at the same time.
For example, suppose we have the following function:
```python
def f(x: [s1, s2]):
def true_f():
def true_f_inner():
return x.sin()
```
what will happen in time order:
1. we create a subtracer 1 and start to speculate the outer cond's true_f
2. we create a another subtracer 2 and start to speculate the inner cond's true_f_inner.
3. dynamo realize the tensor input x by calling wrap_tensor in top-level to create graph input x (tracer 0), we bind the symbol s1, s2 after ph for x is created. So the graph now looks like:
```python
def gm(s1, s2, x):
```
4. when seeing TensorVariable.call_method of x, tracer2 wants to create a call_function(sin, proxy_of_x), but it finds that proxy_of_x is not created by current tracer. So it recursively look up its parent tracer1 and find parent tracer1 also doesn't track this proxy_of_x then it finds the root tracer0, who is the creator of it and tracks it as a ph. Then tracer 1 create_graph_input to lift the closure to its input ph1 and add (proxy_of_x: ph1) k-v in **lifted_freevars** of tracer 1.
Now the graph looks like:
```python
def gm(s1, s2, x):
def true_gm(x):
```
5. Since there are free symbols inside this new tensor input, tracer 1 also binds the symbols (maybe_bind_symbol), which calls create_graph_input for s1 and s2. Now the graph looks like
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
```
6. then it goes back to tracer 2, and call create_graph_input for x and get ph2, tracer 2's **lifted_freevars** records (ph1, ph2). and tracer 2 also binds the symbols in this new tensor input. Now the graph looks like:
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
def true_gm_inner(s1, s2, x):
```
7. Finally the sin call_function node is created by tracer 2.
**This PR also handles the following cases:**
- What if we lift two tensors share the same symbol? e.g. x1 [s1, s2], x2 [s2, s3]? Each subtracer maintains bound_symbols as a cache that maps a symbol.expr to its proxy in current tracer. So when we see x1, we'll track s1 and s2 as inputs and bound s1 to ph1, s2 to ph2. So when we try to bind symbols of x2, s2 will already be tracked so no graph input is created.
- what if a subgraph close over a symint? e.g.
```python
def f(x):
def true_f():
c = x.size(0)
def true_fn_inner():
return c
```
When we speculate true_fn_inner, we find proxy_of_c is not tracked by tracer 2, so it recursively looks up its parent. At this point, x and its symbols have been lifted as input of true_f (as a result of lifting x during tracing true_f in tracer 1. Specifically the graph looks like:
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
def true_gm_inner():
```
So tracer 2 is able to find that s1 have been tracked as ph in tracer 1 so it returns back to gm and call create_graph_input on s1. The graph now looks like:
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
def true_gm_inner(s1):
return s1
```
- What if subgraph close over an unbacked symint? e.g.
```python
def f(x):
def true_f():
c = x.item()
def true_f_inner():
return c
```
When x.item() is called, proxy_of_c and its symnode variable is created for tracer 1, and we also call track_unbacked_symbols to record this relationship. So when tracer 2 finds proxy_of_c is not created by current tracer, it recursivelly looks up its parent tracer and finds that that expression u0 has been tracked as a result of track_unbacked_symbol in tracer 1. So it will stop the recursion and create_graph_input u0 in tracer 2. Graph looks like:
```python
def f(x):
def true_f(s1, s2, x):
c = x.item()
def true_gm_inner(u0):
return u0
cond(pred, true_gm_inner, false_gm_inner, (c,))
```
- what if subgraph close over a tensor with unbacked symint shape?
```python
def f(x):
def true_f():
c = x.item()
r = torch.randn((c,))
def true_f_inner():
return r + 1
```
This is the same as the case of closing over tensors with backed shapes. where we first lift r, then bind u0 in it, which recursively bind_symint of u0 in its parent and found u0 is tracked in parent tracer as a result of .item() call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138363
Approved by: https://github.com/zou3519
Code refactoring only. We move the wrap_to_fake_tensor_logic out of wrap_fx_proxy for placeholders to provide the invariant that **all graph inputs must set their example values when creating the inputs**. This invariant helps us to identify all the free symbols in the graph in top-level and sub-graphs.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138428
Approved by: https://github.com/ezyang, https://github.com/zou3519
ghstack dependencies: #138345
Fixes https://github.com/pytorch/pytorch/issues/136640
Today, inductor has some logic to figure out when it needs to do broadcasting during lowering, which just checks if any of the input shapes have sizes equal to 1.
In particular: we should already have this information by the time we get to inductor, because our FakeTensor compute will have branched/guarded on whether any ops performed broadcasting, appropriately.
In particular, if we have a tensor with a size value of `(64//((2048//(s3*((s2//s3)))))))`, and it happens to be equal to one (and it is used in an op that requires this dim to be broadcasted), FakeTensorProp will have generated a guard:
```
Eq((64//((2048//(s3*((s2//s3))))))), 1)
```
I chose the simplest possible way to beef up inductor's checks to know when a given size is equal to 1: loop over the existing shape env guards, and if our current size is a sympy expression on the LHS of one of our `Eq(LHS, 1)` guards, then return True.
I'm hoping for feedback on whether or not this approach is reasonable. One better option I could imagine is that our symbolic reasoning should have automatically simplified the size of our tensor down to a constant as part of evaluating that guard. I was originally going to try to do this directly in the shape env, but I ran into a few issues:
(1) I wanted to call some version of `set_replacement(expr, 1)`. But `set_replacement()` only accepts plain symbols on the LHS, not expressions
(2) in theory I could get this to work if I could rework the above expression to move everything that is not a free variable to the RHS, e.g. `Eq(s2, 32)`. It looks like our existing `try_solve()` logic is... [not quite able](https://github.com/pytorch/pytorch/blob/main/torch/utils/_sympy/solve.py#L27) to do this generally though.
Checking the guards feels pretty simple-and-easy. Are we worried that it is too slow to iterate over all the guards? I could also cache the lookup so we only need to iterate over guards that are of the form `Eq(LHS, 1)`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136670
Approved by: https://github.com/ezyang
This pr introduces two changes:
1. Before this pr, the subgraphs output is ([], []), in this pr, we change it to a flattened list for easier codegen and consistency with other control flow operators.
2. Before the PR, the combine_fn of scan takes a sliced input but keep the sliced dimension. For exmaple, suppose xs = torch.randn(3, 4, 5) and we scan over dim 0, the combine_fn looks like:
```
# x.shape = (1, 4, 5) instead of (4, 5)
def combine_fn(carry, x):
...
```
In this PR, we fixed this and also simplify some of the slicing logic.
3. this diff also make sure we always stack ys on fist dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135601
Approved by: https://github.com/zou3519
ghstack dependencies: #135600
https://github.com/pytorch/pytorch/pull/133012 caused a regression on ROCm causing pointwise scan tests to fail
```
ERROR: test_pointwise_associative_scan_tuple_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_tuple_reverse_False_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_complex_pytree_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_complex_pytree_reverse_False_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_binary_operator_reverse_True_combine_mode_pointwise_cuda
ERROR: test_pointwise_associative_scan_binary_operator_reverse_False_combine_mode_pointwise_cuda
```
Skipping temporarily while triage is underway.
Full log: https://ossci-raw-job-status.s3.amazonaws.com/log/30067645445
```
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/graph.py", line 1020, in call_function
out = lowerings[target](*args, **kwargs) # type: ignore[index]
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 363, in wrapped
out = decomp_fn(*args, **kwargs)
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 6245, in associative_scan
raise RuntimeError("Unable to generate code for associative_scan op")
torch._inductor.exc.LoweringException: RuntimeError: Unable to generate code for associative_scan op
```
NOTE: even "eager" backend fails
```
File "/opt/conda/envs/py_3.8/lib/python3.8/site-packages/torch/_higher_order_ops/associative_scan.py", line 338, in associative_scan_op_dense
raise NotImplementedError("associative_scan is not implemented for eager")
NotImplementedError: associative_scan is not implemented for eager
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135995
Approved by: https://github.com/malfet
Notable changes:
1. Enable CudaGraph related tests
2. Fix UT problems
3. EXPERIMENTAL Navi31 support. User should enable Navi31 support with Env Var `TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1`
Know Problem:
1. `test/test_transformers.py` will massive failures and/or NaN outputs with `--use-pytest`
+ Update: Confirmed skip `class TestSDPAPrivateUse1Only` can fix the problem with `--use-pytest`
Note:
AOTriton 0.7b adds support to nestedtenosrs+SDPA but need more work (and consequently a separate PR) to enable it.
Fixes#133540
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134498
Approved by: https://github.com/pruthvistony, https://github.com/jeffdaily, https://github.com/malfet
This is part of a series of PRs to improve the functionality of the `associatve_scan` functionality. This specific PR introduces a `combine_mode`, which can be either `pointwise` (default) or `generic`. In case of `generic`, the `associative_scan` is more flexible and allows also to perform non-pointwise functions. This PR has been derived from https://github.com/pytorch/pytorch/pull/129307.
@ydwu4 @Chillee @zou3519
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133012
Approved by: https://github.com/ydwu4
Add a way of generating a FunctionSchema from example values because hop's schema varies even for the same hop.
We didn't use torch._C.FunctionSchema because we cannot construct the classes directly (e.g. "__init__" cannot be used for torch._C.FunctionSchema). Also extending the Basic types in c++ seems not that easy.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133521
Approved by: https://github.com/zou3519
Fixes code object sharing issue in https://github.com/pytorch/pytorch/issues/132417.
Before this Pr, compiled hops such as cond and flex_attenion are wrapped by _dynamo/external_utils.py:wrap_inline. This causes them to share the same code object. There is a condition surrounding the warp_inline call and currently is passing.
We make hops fail the check so that they don't share code objects by adding them to LEGACY_MOD_INLINELIST. Adding them to MOD_INLINELIST doesn't work because trace_rules.check(fn) doesn't check for MOD_INLINLIST by default.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132427
Approved by: https://github.com/jansel, https://github.com/anijain2305
Add similar semantics for creating a buffer object similar to creating a parameter. This is done by introducing a new Buffer class that can be used for type disambiguation. The underlying functionality of registering a buffer remains the same as the register_buffer method has not been changed. The persistent parameter in the Buffer type is to indicate whether a buffer object should be persistent or not. Other non-test changes have to do with getting the new Buffer type recognized by inductor and dynamo. Remaining changes are test changes to make sure that the Buffer type can be used as a drop in replacement for register_buffer as it just leads to register_buffer being called. The addition of this new functionality still allows for normal tensors to be used as buffers so these changes are intended to be backwards compatible.
Fixes#35735
Co-authored-by: Mikayla Gawarecki <mikaylagawarecki@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125971
Approved by: https://github.com/albanD, https://github.com/anijain2305, https://github.com/mlazos
This is an updated PR to equip cond with the autograd feature and replaces the old [PR](https://github.com/pytorch/pytorch/pull/126007)
@ydwu4 I tried to incorporate your requests already.
Currently there are two problems that I struggle with solving:
1. There seems to be an import issue when trying to import cond in `torch/__init__.py`, see [here](8a704035c9/torch/__init__.py (L1914-L1916)). Therefore, I had to comment those lines, which resolved the import issues, but I believe cond is not proberly exposed as torch.cond.
2. I am not entirely sure how to deal with the opinfo test in `hop_db.py`
Co-authored-by: Yidi Wu <yidi@meta.com>
Co-authored-by: Xuehai Pan <XuehaiPan@outlook.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126911
Approved by: https://github.com/ydwu4
This is an updated PR to equip cond with the autograd feature and replaces the old [PR](https://github.com/pytorch/pytorch/pull/126007)
@ydwu4 I tried to incorporate your requests already.
Currently there are two problems that I struggle with solving:
1. There seems to be an import issue when trying to import cond in `torch/__init__.py`, see [here](8a704035c9/torch/__init__.py (L1914-L1916)). Therefore, I had to comment those lines, which resolved the import issues, but I believe cond is not proberly exposed as torch.cond.
2. I am not entirely sure how to deal with the opinfo test in `hop_db.py`
Co-authored-by: Yidi Wu <yidi@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126911
Approved by: https://github.com/ydwu4