We don't want to allow scan's combine_fn to mutate its inputs. The semantic of the mutation can be confusing. For example:
```python
def combine_fn(init, x):
```
If combine_fn mutates init, only first iteration mutates init, the rest of the iterations mutates the previous carry, which is an intermediate result. This is kind of a weird semantic because the only observable mutation is for init, which can be done outside of the combine_fn.
If combine_fn mutates x, where x is a slice of scanned inputs (i.e. xs), this pattern is more meaningful but we've not seen any use case yet.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158864
Approved by: https://github.com/zou3519
ghstack dependencies: #154193, #158965, #158863
This PR intends to rework the dispatching of the autograd key.
I.e., currently the DispatchKey.Autograd of the HOPs was triggered, even if non of the operands of the HOP have `requires_grad=True`. With this rework, the autograd is bypassed if non of the operands require gradients and only invoked if any of the operands require gradients.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151107
Approved by: https://github.com/ydwu4
This PR implements the idea of checking input mutations through tensor version and check aliasing via storage from @zou3519. Previously, we rely on whether there's a in place op that takes placeholder input, which doesn't take views into account.
When writing the PR, I also noticed a bug in previous input mutation checking logic: we were checking the whether there are operators functionalized_f where all the mutating ops have been replaced so we won't be able to detect any thing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145298
Approved by: https://github.com/zou3519
There are 4 parts (they are hard to further break into smaller ones cause they're highly coupled) in this PR:
1. **Whenever we call create_graph_input, we try to bind the symbols in the graph input.**
We've enforced the invariant that all create_graph_inputs calls must provide an example value, we could intercept at the create_graph_input calls (This PR only handles free symbols in tensors).
2. **We cache the bound_symbols** to avoid lift the same symbol repeated.
3. For lifted symbols, we re-used **lifted_freevars** i.e. the mapping between symbol proxy in parent graph to the lifted phs in current subgraph, which we handle lifted tensors. In this way, all hops that supports lifted tensors should be able to handle lifted_symints automatically (at least in dynamo part).
4. For **unbacked symbols** created during tracing, we need to also bound these symbols to its proxy. This is to support the tests cases where we want to lift unbacked symbols as input. We need the proxy of the unbacked symbol in parent graph in order to properly create the args to the hop.
5. We change all the tests after free symbols are lifted in subgraphs. And also supports the lifted symbols in existing higher order ops.
**The interaction of nested tracers:**
The previous design for lifting tensor closures is that: suppose we're in nested tracers, whenever we see a new proxy that's not created by create tracer, we recursively look for the proxy in parent tracer until we find the tracer that creates this proxy (either a placeholder or some intermediate results). More detail is in Note [Nested SubgraphTracer and free_variable handling].
Given the above design, the plan for lifting the free symbols is: whenever we lift a free tensor to be the inputs of current subgraph, we'll look at the symbols in it and bind the symbols at the same time.
For example, suppose we have the following function:
```python
def f(x: [s1, s2]):
def true_f():
def true_f_inner():
return x.sin()
```
what will happen in time order:
1. we create a subtracer 1 and start to speculate the outer cond's true_f
2. we create a another subtracer 2 and start to speculate the inner cond's true_f_inner.
3. dynamo realize the tensor input x by calling wrap_tensor in top-level to create graph input x (tracer 0), we bind the symbol s1, s2 after ph for x is created. So the graph now looks like:
```python
def gm(s1, s2, x):
```
4. when seeing TensorVariable.call_method of x, tracer2 wants to create a call_function(sin, proxy_of_x), but it finds that proxy_of_x is not created by current tracer. So it recursively look up its parent tracer1 and find parent tracer1 also doesn't track this proxy_of_x then it finds the root tracer0, who is the creator of it and tracks it as a ph. Then tracer 1 create_graph_input to lift the closure to its input ph1 and add (proxy_of_x: ph1) k-v in **lifted_freevars** of tracer 1.
Now the graph looks like:
```python
def gm(s1, s2, x):
def true_gm(x):
```
5. Since there are free symbols inside this new tensor input, tracer 1 also binds the symbols (maybe_bind_symbol), which calls create_graph_input for s1 and s2. Now the graph looks like
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
```
6. then it goes back to tracer 2, and call create_graph_input for x and get ph2, tracer 2's **lifted_freevars** records (ph1, ph2). and tracer 2 also binds the symbols in this new tensor input. Now the graph looks like:
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
def true_gm_inner(s1, s2, x):
```
7. Finally the sin call_function node is created by tracer 2.
**This PR also handles the following cases:**
- What if we lift two tensors share the same symbol? e.g. x1 [s1, s2], x2 [s2, s3]? Each subtracer maintains bound_symbols as a cache that maps a symbol.expr to its proxy in current tracer. So when we see x1, we'll track s1 and s2 as inputs and bound s1 to ph1, s2 to ph2. So when we try to bind symbols of x2, s2 will already be tracked so no graph input is created.
- what if a subgraph close over a symint? e.g.
```python
def f(x):
def true_f():
c = x.size(0)
def true_fn_inner():
return c
```
When we speculate true_fn_inner, we find proxy_of_c is not tracked by tracer 2, so it recursively looks up its parent. At this point, x and its symbols have been lifted as input of true_f (as a result of lifting x during tracing true_f in tracer 1. Specifically the graph looks like:
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
def true_gm_inner():
```
So tracer 2 is able to find that s1 have been tracked as ph in tracer 1 so it returns back to gm and call create_graph_input on s1. The graph now looks like:
```python
def gm(s1, s2, x):
def true_gm(s1, s2, x):
def true_gm_inner(s1):
return s1
```
- What if subgraph close over an unbacked symint? e.g.
```python
def f(x):
def true_f():
c = x.item()
def true_f_inner():
return c
```
When x.item() is called, proxy_of_c and its symnode variable is created for tracer 1, and we also call track_unbacked_symbols to record this relationship. So when tracer 2 finds proxy_of_c is not created by current tracer, it recursivelly looks up its parent tracer and finds that that expression u0 has been tracked as a result of track_unbacked_symbol in tracer 1. So it will stop the recursion and create_graph_input u0 in tracer 2. Graph looks like:
```python
def f(x):
def true_f(s1, s2, x):
c = x.item()
def true_gm_inner(u0):
return u0
cond(pred, true_gm_inner, false_gm_inner, (c,))
```
- what if subgraph close over a tensor with unbacked symint shape?
```python
def f(x):
def true_f():
c = x.item()
r = torch.randn((c,))
def true_f_inner():
return r + 1
```
This is the same as the case of closing over tensors with backed shapes. where we first lift r, then bind u0 in it, which recursively bind_symint of u0 in its parent and found u0 is tracked in parent tracer as a result of .item() call.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138363
Approved by: https://github.com/zou3519
This pr introduces two changes:
1. Before this pr, the subgraphs output is ([], []), in this pr, we change it to a flattened list for easier codegen and consistency with other control flow operators.
2. Before the PR, the combine_fn of scan takes a sliced input but keep the sliced dimension. For exmaple, suppose xs = torch.randn(3, 4, 5) and we scan over dim 0, the combine_fn looks like:
```
# x.shape = (1, 4, 5) instead of (4, 5)
def combine_fn(carry, x):
...
```
In this PR, we fixed this and also simplify some of the slicing logic.
3. this diff also make sure we always stack ys on fist dimension.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135601
Approved by: https://github.com/zou3519
ghstack dependencies: #135600
For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
Approved by: https://github.com/ydwu4
For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
Approved by: https://github.com/ydwu4
For tracing cond/while in eager, we trace the HOP with the eager backend with metadata torchfunction mode enabled. HOPs disallow the mutation that occurs in this torch function mode, so it is not able to be traced. As a result, we use a custom backend which enters this mode for tracing these HOPs. Thanks to @ydwu4 for the help with implementing this
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134732
Approved by: https://github.com/ydwu4