Summary:
Cloned https://github.com/pytorch/pytorch/pull/153558 from benjaminglass1 and fixed internal typing errors.
Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.
Decisions made along the way:
1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it was unnecessarily generalized. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.
The remainder of this PR fixes up all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.
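A minimal sketch of the `__getattr__` pattern described above, using hypothetical stand-in classes (`_OpOverloadPacket` and the `namespace` class variable are illustrative names, not the actual `torch.ops` implementation):
```
class _OpOverloadPacket:
    """Stand-in for the single type __getattr__ returns for dynamic lookups."""
    def __init__(self, name: str) -> None:
        self.name = name


class _OpNamespace:
    # Constant special cases live as class variables, so ordinary attribute
    # lookup resolves them and __getattr__ is never consulted for them.
    namespace = "aten"

    def __getattr__(self, name: str) -> _OpOverloadPacket:
        # Only reached after all other lookup methods have failed, so it can
        # promise type checkers a single return type.
        return _OpOverloadPacket(name)


ops = _OpNamespace()
print(ops.namespace)  # resolved as a class variable; __getattr__ not called
print(ops.add.name)   # dynamic lookup goes through __getattr__ -> one type
```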
Test Plan: CI
Differential Revision: D75497142
Co-authored-by: Benjamin Glass <bglass@quansight.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/154555
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/zou3519, https://github.com/benjaminglass1
Fixes a longstanding issue where direct references to aten operations are seen as untyped by type checkers. This is accomplished by setting attributes on several classes more consistently, so that `__getattr__` can return a single type in all other cases.
Decisions made along the way:
1. `torch.ops.higher_order` is now implemented by a single-purpose class. This was effectively true before, but the class implementing it was unnecessarily generalized. Fixing this simplified typing for the `_Ops` class.
2. `__getattr__` is only called when all other lookup methods have failed, so several constant special-cases in the function could be implemented as class variables.
The remainder of this PR fixes up all the bugs exposed by the updated typing, as well as all the nitpicky typing issues.
Test Plan: CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153558
Approved by: https://github.com/rec, https://github.com/Skylion007, https://github.com/cyyever
Applies the remaining flake8-comprehensions fixes and checks. This change replaces all remaining unnecessary generator expressions with list/dict/set comprehensions, which are more succinct, performant, and better supported by our torch.jit compiler. It also removes useless generators such as `set(a for a in b)`, resolving them into a plain `set` call.
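Hedged before/after illustrations of the kinds of rewrites described above (standalone snippets, not lines taken from the diff):
```
items = ["a", "b", "a"]

# Before: a generator expression fed to set() (flake8-comprehensions C401).
unique = set(x for x in items)
# After: a set comprehension.
unique = {x for x in items}

# Before: a useless generator that only re-yields its input (C400).
copy = list(x for x in items)
# After: the constructor call alone suffices.
copy = list(items)
```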
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94676
Approved by: https://github.com/ezyang
Two small changes that I'm bundling together because one of them needs to touch fbcode, and I'm not sure how to do stacked diffs + internal changes + land before the release cut.
Remove `allow_meta` from the constructor, and allow meta by default: we should be able to trace through meta with fake tensors, so in some sense it's a bit weird to expose an option for users to disallow this. However, erroring is still useful for debugging from time to time, so I've added a config option that restores the previous behavior.
Remove `throw_on_data_dependent_ops=True`: this was intended as temporary behavior while we smoothed the transition to erroring. I couldn't find any uses of `throw_on_data_dependent_ops=False`.
These are technically backward-incompatible, but fake tensor is new since the last release / in a private namespace, and I don't want to release it with baggage that would be hard to remove later.
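A hedged sketch of the constructor change (the old `allow_meta` argument comes from this PR's description; the config attribute named in the last comment is hypothetical, for illustration only):
```
from torch._subclasses.fake_tensor import FakeTensorMode

# Before this PR, callers could opt out of meta tensors at construction time:
#   mode = FakeTensorMode(allow_meta=False)

# After: meta is allowed by default, and the constructor no longer takes the flag.
mode = FakeTensorMode()

# Debug-time erroring is restored via a config option instead (attribute name
# assumed here for illustration):
#   fake_tensor_config.allow_meta = False
```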
Fix for https://github.com/pytorch/pytorch/issues/92877.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93993
Approved by: https://github.com/bdhirsh, https://github.com/ezyang
The reinplacing pass already has special handling for removing `{view}_scatter` ops, but there is another case that needs it. In this code:
```
def f():
    a = torch.zeros(4, 4, 4)
    a[:, 2:] = torch.ones(4, 2, 4)
    return a
```
Tracing normally with `make_fx()` gives you:
```
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    copy__default = torch.ops.aten.copy_.default(slice_tensor_1, ones); slice_tensor_1 = ones = None
    return zeros
```
Functionalizing it gives you:
```
def forward(self):
    zeros = torch.ops.aten.zeros.default([4, 4, 4], device = device(type='cpu'), pin_memory = False)
    ones = torch.ops.aten.ones.default([4, 2, 4], device = device(type='cpu'), pin_memory = False)
    slice_tensor = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_tensor_1 = torch.ops.aten.slice.Tensor(slice_tensor, 1, 2, 9223372036854775807); slice_tensor = None
    slice_tensor_2 = torch.ops.aten.slice.Tensor(zeros, 0, 0, 9223372036854775807)
    slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, 9223372036854775807); slice_tensor_2 = ones = None
    slice_scatter_default_1 = torch.ops.aten.slice_scatter.default(zeros, slice_scatter_default, 0, 0, 9223372036854775807); zeros = slice_scatter_default = None
    return slice_scatter_default_1
```
Notice that there are not any functional ops to directly re-inplace! What actually happened is that functionalization turned the `copy_()` into a `copy()`, but the out-of-place `copy()` operator gets optimized away because it's a no-op (when the input and output metadata are the same, `out = copy(a, b)` just returns `b`).
What we actually want is to replace this line:
```
slice_scatter_default = torch.ops.aten.slice_scatter.default(slice_tensor_2, ones, 1, 2, ...);
```
with this:
```
new_slice = torch.ops.aten.slice.Tensor(slice_tensor_2, 1, 2, ...);
_ = torch.ops.aten.copy_.default(new_slice, ones)
```
In the above, we're taking a fresh slice of the "base" tensor, and performing a `copy_()` on the slice, adding back what functionalization removed.
We actually need to create a fresh "slice" node, because we're not guaranteed that one already exists in the graph (technically there should be one, but it might have been DCE'd by the time we hit re-inplacing).
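A minimal eager-mode illustration of the equivalence described above (a standalone sketch, not the pass itself): `slice_scatter(base, src, ...)` produces the same result as taking a fresh slice of the base and `copy_`-ing the source into it.
```
import torch

base = torch.zeros(4, 4, 4)
src = torch.ones(4, 2, 4)

# Out-of-place form left behind by functionalization:
out = torch.slice_scatter(base, src, dim=1, start=2)

# In-place form the re-inplacing pass wants to emit:
new_slice = base[:, 2:]   # fresh slice of the "base" tensor
new_slice.copy_(src)      # mutate the base through the view

assert torch.equal(out, base)
```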
I also updated the docs for re-inplacing to more closely match the order of the logic.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83846
Approved by: https://github.com/ezyang
Cleaned up some of the arg replacement logic to use tree_map, so it handles FX nodes that have nested containers.
See the added test: when you write a function that returns a list, the `output` node in the FX graph shows up as having `node.args = tuple(immutable_list(...))`.
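A hedged sketch of the arg-replacement pattern described above, using pytree's `tree_map` so nested containers are handled uniformly (`replace_arg` is a hypothetical helper, not the PR's actual code):
```
from torch.utils._pytree import tree_map

def replace_arg(old, new, args):
    # Swap every occurrence of `old` for `new`, however deeply it is nested
    # inside tuples/lists/dicts.
    return tree_map(lambda a: new if a is old else a, args)

args = (["x", "y"], {"k": "x"})
print(replace_arg("x", "z", args))  # (['z', 'y'], {'k': 'z'})
```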
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83845
Approved by: https://github.com/ezyang
I'm testing out turning on re-inplacing + functionalization by default with the AOTAutograd + eager backend on torchbench + huggingface models. This PR contains a few bug fixes from turning re-inplacing on:
(1) Handle more gracefully the case where FakeTensorMode is already turned on when you call reinplace
(2) More robust detection of when an inplace variant of an op exists (the dumb bug was that `pow.Scalar` doesn't have an inplace variant, even though there are several overloads of `pow_`, none of which are eligible)
(3) Avoid re-inplacing when it would require resizing the input buffer. This isn't allowed, because inplace ops aren't allowed to resize their inputs.
For the last one, I explained the two main examples in more detail in the comments. The important cases are below; a runnable illustration of the `ge` dtype mismatch follows this list:
```
# This should not be re-inplaced at all; the op broadcasts, so re-inplacing
# would require resizing the self tensor.
torch.add(tensor[1, 4], tensor[4, 4])

# This should not be re-inplaced, because the inplace and out-of-place
# variants of the op return different dtypes.
torch.ge(a, b)
# However, this means that today, when functionalization functionalizes a
# `torch.ge_(a, b)` call, re-inplacing won't properly de-functionalize it.
# I mentioned in the comments that this optimization is worth adding later.
```
(4) There's some logic around keeping `storage_to_nodes` up to date when we see a view op: if we re-inplace `out = a.add(...)`, and later in the program we encounter a node `out.view(...)` that needs to be replaced with `a.view(...)`, then we need to update some metadata structures. I had to fix that logic: specifically, if the later node isn't a dispatcher op (e.g. if it's an FX output node), I wasn't properly handling the case where the node's fake_meta info was not a tensor.
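A runnable sketch of the `ge` dtype mismatch mentioned above (a standalone illustration, not code from the pass):
```
import torch

a = torch.tensor([1.0, 2.0])
b = torch.tensor([2.0, 1.0])

# The out-of-place comparison returns a bool tensor...
print(torch.ge(a, b).dtype)  # torch.bool

# ...while the in-place variant writes 0./1. into self and keeps its dtype,
# so blindly re-inplacing would change the result's dtype.
print(a.ge_(b).dtype)        # torch.float32
```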
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83626
Approved by: https://github.com/ezyang
Adds a "reinplacing" FX transform that goes through an FX graph and tries to convert out-of-place op calls into inplace calls whenever possible.
Followups from this PR include:
- Set up torchbench, and run the whole torchbench suite using the AOTAutograd + functionalize + re-inplacing transforms to surface any issues (this is what I'm currently working on). Right now, I have some basic unit tests just to sanity-check that the general logic makes sense.
- Add any missing inplace ops. This is mostly the `*_scatter*` ops, e.g. `diagonal_scatter_`, because these ops will commonly show up in an FX graph after running functionalization.
The criteria for when you can swap an op `b = a.add(...)` with `a.add_(...)` are:
(1) An inplace variant of the operator with the same schema needs to exist (`aten.add` -> `aten.add_`)
(2) `a` (**or any of its aliases**) can't be used as an input to any other operators later on in the graph
(3) `a` can't be one of the inputs to the entire graph. It also can't be an **alias** of any of the inputs ***
*** One thing to note: (3) means that we can't technically guarantee that we'll get back **all** memory usage that we lost from functionalization. Functionalization converts input mutations into out-of-place calls, and then adds a `copy_()` to the end of the graph to preserve semantics.
I added logic to handle `copy_()` in this PR because it's a pretty important optimization in the context of functionalization: any program that performs input mutations will have a `copy_()` in it after running functionalization.
There are some examples in the test file, but I think staring at an example of where re-inplacing is/isn't allowed to run is helpful:
```
# Before functionalization
def foo(a):
    tmp1 = a.add_(1)
    tmp2 = a.add(2)

# After functionalization
def foo(a):
    tmp1 = a.add(1)
    tmp2 = a.add(2)
    ...
    a.copy_(tmp1)

# After re-inplacing
def foo(a):
    # The first add() is safe to re-inplace even though a is a program input,
    # because a's data is overwritten later by a copy_().
    tmp1 = a.add_(1)
    # The second add() is NOT safe to re-inplace, because:
    # (1) a and tmp1 are aliased. They weren't aliased in the original
    #     program, but they are now that we've done some re-inplacing.
    # (2) tmp1 is used as an input later in the program.
    tmp2 = a.add(2)
    ...
    a.copy_(tmp1)
```
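A hedged sketch of criterion (1), checking whether an op has a same-named in-place variant with a matching overload (`has_inplace_variant` is a hypothetical helper, and this is simplified: the real pass also compares schemas, which is exactly why the name-based check alone misfired for `pow.Scalar`):
```
import torch

def has_inplace_variant(op_name: str, overload: str) -> bool:
    # aten.add -> aten.add_; getattr returns None if no such packet exists.
    packet = getattr(torch.ops.aten, op_name + "_", None)
    return packet is not None and overload in packet.overloads()

print(has_inplace_variant("add", "Tensor"))     # True: aten.add_.Tensor exists
print(has_inplace_variant("zeros", "default"))  # False: there is no aten.zeros_
```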
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80897
Approved by: https://github.com/ezyang