It turns out that we *do* need to update `*_scatter` ops to return the exact same strides as their inputs. I added a test to `test/test_functionalization.py`, which now trips thanks to Ed's functionalization stride-debugging check. The bug only actually manifests as silent incorrectness if you try to call `.backward()` on the affected function.
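To make the invariant concrete, here is a minimal sketch I wrote to illustrate it (not the test that was added), using `slice_scatter` on a non-contiguous base tensor:
```
import torch

# A non-contiguous base tensor: shape (4, 4), strides (1, 4).
base = torch.ones(4, 4).transpose(0, 1)
src = torch.zeros(1, 4)

out = torch.slice_scatter(base, src, dim=0, start=0, end=1)

# The scatter op should produce an output with the same strides as its
# input base tensor, rather than fresh contiguous strides.
assert out.stride() == base.stride()
```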
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91029
Approved by: https://github.com/ezyang
Move functorch/functorch into `functorch` folder
- Add functorch/CMakeLists.txt that builds the `functorch` native Python extension
- Modify `setup.py` to package pytorch and functorch together into a single wheel
- Modify `functorch.__version__` so that it is equal to `torch.__version__`
- Add a dummy `functorch/setup.py` file for projects that still want to build it
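For projects that pinned a standalone functorch build step, the dummy file only needs to present a package that defers to the main torch wheel. A hypothetical sketch (the actual file may differ):
```
# functorch/setup.py (hypothetical sketch)
from setuptools import setup

# functorch now ships inside the torch wheel; a stub like this exists only
# so that `pip install .`-style workflows in this directory keep working.
setup(
    name="functorch",
    install_requires=["torch"],
)
```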
Differential Revision: [D39058811](https://our.internmc.facebook.com/intern/diff/D39058811)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83464
Approved by: https://github.com/zou3519
There is already special handling in the reinplacing pass for removing `{view}_scatter` ops, but there is another case that needs special handling. In this code:
```
def f():
    a = torch.zeros(4, 4, 4)
    a[:, 2:] = torch.ones(4, 2, 4)
    return a
```
Tracing normally with `make_fx()` gives you:
```
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_1, ones); slice_tensor_1 = ones = None
    return zeros
```
Functionalizing it gives you:
```
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, 9223372036854775807); slice_tensor_2 = ones = None
    slice_scatter_default_1 = torch.ops.aten.slice_scatter.default(zeros, slice_scatter_default, 0, 0, 9223372036854775807); zeros = slice_scatter_default = None
    return slice_scatter_default_1
```
Notice that there are not any functional ops to directly re-inplace! What actually happened is that functionalization turned the `copy_()` into a `copy()`, but the out-of-place `copy()` operator gets optimized away because it's a no-op (when the input and output metadata are the same, `out = copy(a, b)` just returns `b`).
What we actually want is to replace this line:
```
slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, ...);
```
with this:
```
new_slice = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, ...);
_ = torch.ops.aten.copy_.default(new_slice, ones)
```
In the above, we're taking a fresh slice of the "base" tensor, and performing a `copy_()` on the slice, adding back what functionalization removed.
We actually need to create a fresh "slice" node, because we're not guaranteed that one already exists in the graph (technically there should be one, but it might have been DCE'd by the time we hit re-inplacing). A simplified sketch of this rewrite follows.
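Here is a minimal sketch of that rewrite using the FX graph API. This is my own simplification for illustration, not the actual reinplacing pass, which also has to prove the rewrite is safe before applying it:
```
import torch

def reinplace_slice_scatter(graph: torch.fx.Graph):
    # Rewrite  out = aten.slice_scatter(base, src, dim, start, end)
    # into     view = aten.slice(base, dim, start, end); aten.copy_(view, src)
    for node in list(graph.nodes):
        if node.target is torch.ops.aten.slice_scatter.default:
            base, src, *slice_args = node.args
            with graph.inserting_before(node):
                # Take a fresh slice of the base tensor...
                view = graph.call_function(
                    torch.ops.aten.slice.Tensor, (base, *slice_args))
                # ...and mutate the slice in place, restoring the copy_()
                # that functionalization removed.
                graph.call_function(
                    torch.ops.aten.copy_.default, (view, src))
            # The scatter's result is now just the (mutated) base tensor.
            node.replace_all_uses_with(base)
            graph.erase_node(node)
    graph.lint()
```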
I also updated the docs for re-inplacing to more closely match the order of the logic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83846
Approved by: https://github.com/ezyang
Cleaned up some of the arg-replacement logic to use `tree_map`, so it handles FX nodes that have nested containers.
See the added test: when you write a function that returns a list, the `output` node in the FX graph shows up as having `node.args = tuple(immutable_list(...))`.
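A minimal sketch of the idea (my illustration, not necessarily the exact helper in the pass), using `tree_map` to swap one node for another inside arbitrarily nested args:
```
from torch.utils._pytree import tree_map

def replace_node_in_args(node, old, new):
    # tree_map recurses through tuples, lists, dicts, and FX's registered
    # immutable_list/immutable_dict containers, treating Nodes as leaves,
    # so nested structures like tuple(immutable_list(...)) are handled
    # uniformly.
    node.args = tree_map(lambda a: new if a is old else a, node.args)
    node.kwargs = tree_map(lambda a: new if a is old else a, node.kwargs)
```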
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83845
Approved by: https://github.com/ezyang
I'm testing out turning on re-inplacing + functionalization by default with the AOTAutograd + eager backend on torchbench + huggingface models. This PR contains a few bug fixes from turning re-inplacing on:
(1) Handle more gracefully the case where `FakeTensorMode` is already turned on when you call `reinplace`.
(2) More robust detection of whether an inplace variant of an op exists. (The dumb bug was that `pow.Scalar` doesn't have an inplace variant, even though there are several overloads of `pow_`; none of them are eligible, though.)
(3) Avoid re-inplacing when it would require resizing the input buffer. This isn't allowed, because inplace ops aren't allowed to resize their inputs (a sketch of this check appears after item (4) below).
For the last one, I gave the two main examples in more detail in the comments. Important cases are:
```
# This should not be re-inplaced at all: the op broadcasts here, so
# re-inplacing would require resizing the self tensor.
torch.add(a, b)  # a has shape [1, 4], b has shape [4, 4]
# This should not be re-inplaced either, because the inplace and
# out-of-place variants of the op return different dtypes (ge returns bool).
torch.ge(a, b)
# However, this means that today, when functionalization functionalizes a
# `torch.ge_(a, b)` call, reinplacing won't properly de-functionalize it.
# I mentioned in the comments that this optimization is worth adding later.
```
(4) There's some logic around keeping `storage_to_nodes` up to date when we see a view op: if we re-inplace `out = a.add(...)`, and later in the program we encounter a node `out.view(...)` that needs to be replaced with `a.view(...)`, then we need to update some metadata structures. I had to fix that logic: specifically, when the later node isn't a dispatcher op (e.g. when it's an FX output node), I wasn't properly handling the case where the node's fake-tensor metadata was not a tensor.
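As promised above, here is a minimal sketch of the eligibility check from (3). This is a hypothetical helper I wrote for illustration, operating on fake-tensor metadata; the actual pass is more involved:
```
def reinplace_is_eligible(self_meta, out_meta):
    # Inplace ops may not resize their inputs, so broadcasting ops are
    # rejected: add([1, 4], [4, 4]) produces a [4, 4] output, which would
    # not fit in the [1, 4] self tensor.
    if self_meta.shape != out_meta.shape:
        return False
    # Inplace ops also may not change dtype: ge(a, b) returns a bool
    # tensor, while ge_(a, b) would have to write into a's dtype.
    if self_meta.dtype != out_meta.dtype:
        return False
    return True
```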
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83626
Approved by: https://github.com/ezyang
Adds a "reinplacing" FX transform that goes through an FX graph and tries to convert out-of-place op calls into inplace calls whenever possible.
Followups from this PR include:
- Set up torchbench, and run the whole suite using the AOTAutograd + functionalize + re-inplacing transforms to surface any issues (this is what I'm currently working on). Right now, I have some basic unit tests just to sanity-check that the general logic makes sense.
- Add any missing inplace ops. This is mostly the `*_scatter*` ops, e.g. `diagonal_scatter_`, because these ops will commonly show up in an FX graph after running functionalization.
The criteria for when you can swap an op `b = a.add(...)` with `a.add_(...)` are:
(1) An inplace variant of the operator with the same schema needs to exist (`aten.add` -> `aten.add_`); a sketch of this lookup appears below
(2) `a` (**or any of its aliases**) can't be used as an input to any other operators later on in the graph
(3) `a` can't be one of the inputs to the entire graph. It also can't be an **alias** of any of the inputs ***
*** One thing to note: (3) means that we can't technically guarantee that we'll get back **all** memory usage that we lost from functionalization. Functionalization converts input mutations into out-of-place calls, and then adds a `copy_()` to the end of the graph to preserve semantics.
I added logic to handle `copy_()` in this PR because it's a pretty important optimization in the context of functionalization: any program that performs input mutations will have a `copy_()` in it after running functionalization.
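A minimal sketch of the inplace-variant lookup from criterion (1). This is a hypothetical helper for illustration; the actual check must also compare schemas, which is exactly where overloads like `pow.Scalar` fall out:
```
import torch

def maybe_inplace_variant(op):
    # e.g. aten::add.Tensor -> look for aten.add_.Tensor
    name = op._schema.name.split("::")[-1]
    packet = getattr(torch.ops.aten, name + "_", None)
    if packet is None:
        return None
    overload = op._schema.overload_name or "default"
    # NOTE: a real check must also compare the two schemas. pow.Scalar
    # (Scalar self) finds a pow_.Scalar overload here, but pow_ takes a
    # Tensor self, so the swap would be invalid.
    return getattr(packet, overload, None)
```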
There are some examples in the test file, but I think staring at an example of where re-inplacing is/isn't allowed to run is helpful:
```
# Before functionalization
def foo(a):
    tmp1 = a.add_(1)
    tmp2 = a.add(2)

# After functionalization
def foo(a):
    tmp1 = a.add(1)
    tmp2 = a.add(2)
    ...
    a.copy_(tmp1)

# After re-inplacing
def foo(a):
    # The first add() is safe to re-inplace even though a is a program input,
    # because a's data is overwritten later by a copy_().
    tmp1 = a.add_(1)
    # The second add() is NOT safe to re-inplace, because:
    # (1) a and tmp1 are aliased. Note that they weren't aliased in the
    #     original program, but they are now that we've done some re-inplacing.
    # (2) tmp1 is used as an input later in the program.
    tmp2 = a.add(2)
    ...
    a.copy_(tmp1)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80897
Approved by: https://github.com/ezyang