Commit Graph

100 Commits

Author SHA1 Message Date
c933af2709 Switch to predispatch (#123573)
This PR switches export IR from aot-dispatch to pre-dispatch IR.

**What is pre-dispatch IR and why should you care?**

Currently the default IR returned by torch.export can contain only functional ATen operators after ALL pytorch dispatcher decompositions (for example, CompositeImplicitAutograd) run.

In contrast, pre-dispatch IR refers to an IR that can contain all functional ATen operators (i.e., not just from the core subset), before any decomposition happens, as well as operators that manipulate autograd state. Pre-dispatch IR closely resembles eager PyTorch computation, but is still functional and serializable by torch.export. As a result:
- You can train the pre-dispatch IR in eager mode as the IR contains necessary information for the autograd engine to automatically generate a backward graph.
- You can write sound graph transformations more easily as the IR is functional.
- Since it is an ATen IR, it is still normalized. For example, torch.add has multiple overloads, but aten.add.Tensor is unique in this IR.

If you want to get the core aten IR out of `torch.export`, you will need to:
```
ep = torch.export.export(M(), inputs)
ep_for_core_aten = ep.run_decompositions()
```

Differential Revision: [D56273267](https://our.internmc.facebook.com/intern/diff/D56273267)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123573
Approved by: https://github.com/gmagogsfm
2024-04-24 00:51:09 +00:00
a78450a00b Excise uses of the old custom ops APIs (#124134)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124134
Approved by: https://github.com/albanD
ghstack dependencies: #124180, #124200, #124299
2024-04-19 17:56:26 +00:00
93e249969b [BE] enable ruff rule RSE and remove useless parentheses in raise statements (#124261)
Remove useless parentheses in `raise` statements if the exception type is raised with no argument.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124261
Approved by: https://github.com/albanD
2024-04-17 19:29:34 +00:00
57734202c6 [HSTU][TGIF] Provide a API to check whether running in torch_dispatch mode (#122339)
Summary: We provide a `is_in_torch_dispatch_mode` API returning `bool` to determine whether the program is running in torch dispatch mode or not.

Test Plan:
- OSS CI
- Tested with publish of hstu models with the this diff and following diffs D54964288, D54964702, D54969677, D55025489, runtime errors are not raised anymore in publish

Differential Revision: D55091453

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122339
Approved by: https://github.com/jiayisuse
2024-03-21 01:37:23 +00:00
d813474363 [Pytorch] auto format _python_dispatch file (#122226)
Summary: Auto format the _python_dispatch file, to make D55091453 easier to review

Test Plan: `arc lint`

Differential Revision: D55091454

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122226
Approved by: https://github.com/aakhundov
2024-03-20 19:28:39 +00:00
4f120dc2a6 Clean up mode handling in python dispatcher (#121083)
Things that were bad before this PR:
1. Temporarily unsetting functional tensor mode and proxy mode both had duplicate implementation
2. There are variants of mode handling private utils that has duplicate implementation. (different APIs calling repeated implementation, so i refactored)
3. _push_mode API used to take dispatch key argument which is not necessary.
4. There are unused APIs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121083
Approved by: https://github.com/zou3519
2024-03-08 00:30:34 +00:00
c646030cd2 Support higher order op functionalization in predispatch IR (#115314)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115314
Approved by: https://github.com/bdhirsh
2024-03-01 09:13:47 +00:00
e7eab2f07e Fix to keep stride in return_and_correct_aliasing() (#117860)
Fixes #117794

Fix tripped the assert here: 86dedebeaf/torch/utils/_python_dispatch.py (L216)

From investigation: I found that functionalization of an in-place op (`mul_` in this test case) results in the strides of `TwoTensor`'s `a` / `b` components being mutated to be contiguous. This is not reflected in the outer tensor, causing the assert to be tripped.

After discussion with Brian, I address this in this PR by disallowing input mutations on non-contiguous tensor subclass inputs for now.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117860
Approved by: https://github.com/bdhirsh
2024-02-21 19:15:27 +00:00
3fe437b24b [BE]: Update flake8 to v6.1.0 and fix lints (#116591)
Updates flake8 to v6.1.0 and fixes a few lints using sed and some ruff tooling.
- Replace `assert(0)` with `raise AssertionError()`
- Remove extraneous parenthesis i.e.
  - `assert(a == b)` -> `assert a == b`
  - `if(x > y or y < z):`->`if x > y or y < z:`
  - And `return('...')` -> `return '...'`

Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116591
Approved by: https://github.com/albanD, https://github.com/malfet
2024-01-03 06:04:44 +00:00
76b1d44d57 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-25 04:51:21 +00:00
0567f71ac6 Revert " pre_dispatch aot_export (#115188)"
This reverts commit a267d6735051a4714fa2ac1c163315b650118744.

Reverted https://github.com/pytorch/pytorch/pull/115188 on behalf of https://github.com/jeanschmidt due to sadly, it is required to revert this commit in order to revert https://github.com/pytorch/pytorch/pull/115454 ([comment](https://github.com/pytorch/pytorch/pull/115188#issuecomment-1866310014))
2023-12-21 14:03:18 +00:00
a267d67350 pre_dispatch aot_export (#115188)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115188
Approved by: https://github.com/bdhirsh
2023-12-20 21:36:25 +00:00
d85314c95c Support Predispatch functionalization (#113728)
In this PR, we are implementing Functionalization on pre-dispatch graph. Today, every dispatch key except for Dispatchkey.Python has a dedicated mode stack in python. PreDispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode to Dispatchkey.PreDispatch mode stack and handle the dispatching logic in python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode on DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode. (this is very similar to how post-dispatch tracing work). Here are some design decisions we made for this flow to work:

1. FunctionalTensorMode internally calls C++ functionalize key. Since C++ functionalization goes after PreDispatch, if we are not careful, we will keep re-entering into PreDispatch key. We solve this by directly dispatching to C++ Functionalize key.

2. We delete mode_stack_per_key logic because the only realistic time it is exercised is for PreDispatch and it is in general not safe to have a plain list because FunctionalTensorMode and ProxyTorchDispatchMode ordering matter and it is hard to enforce it on plain list. Instead, now we have a private class that tracks PreDispatch mode stack.

3.  We will still run CompositeImplicitAutograd decomps in this PR, and disable this logic later as a followup.

Some missing bits after this PR:
1. Preserving autograd ops in a functional form. Right now they still show up in the graph but in a "non-functional" way.
2. Turn off CompositeImplicitAutograd decomps
3. Functionalizing HOO

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
2023-12-19 20:28:35 +00:00
22704426c3 Expand dynamic dims support for traceable subclasses (#114311)
Continuation of #112185, following the design in this [doc](https://docs.google.com/document/d/1ipSxcTzEMMOAPvxP-YJlD5JBZZmIGgh8Q34ixtOUCRo).

Summary:
* Introduce `SubclassSymbolicPolicy` containing separate dynamic dim / constraint policies for the outer and inner tensors
    * Expand the automatic dynamic algorithm to recurse into inner tensors and produce one of these for a subclass instance
    * Maintain legacy behavior for subclasses by recursively calling `mark_dynamic()` on inner tensors *of the same dim as outer* when `mark_dynamic(outer, ...)` is called
    * Addresses this: 6a86cf00ad/torch/_dynamo/variables/builder.py (L1750)
* Add `outer_size` and `outer_stride` arguments to `__tensor_unflatten__()` so that you can find out what symbols were allocated for the outer size / stride (you are expected to return a tensor that compares equal to the outer symbols)
    * Signatures now:
    ```python
    # attrs is a list of inner tensor attributes on x; inner_tensor = getattr(x, attr)
    # ctx is anything useful for rebuilding the class we want to guard on
    attrs, ctx = x.__tensor_flatten__()
    ...
    # inner_tensors is a dict of {attr -> tensor}
    # ctx is taken unmodified from flattening and (eventually) guarded on
    # outer_size is the expected size of the output; possibly symbolic
    # outer_stride is the expected strides of the output; possibly symbolic
    y = MySubclass.__tensor_unflatten__(inner_tensors, ctx, outer_size, outer_stride)

    # at the __tensor_unflatten__() call-site in PT2, we assert y.shape == outer_size and y.stride() == outer_stride
    # the assert simplifies symbols when there are relationships between outer and inner symbols
    ```
    * Size info needed for `NestedTensor` at least, stride info needed for `DTensor` at least
    * Punting on `outer_storage_offset` because storage_offset handling is horribly broken in PT2 right now
* ~~Add new `__tensor_mark_dynamic__()` to allow overriding the behavior of mark_dynamic on a per-subclass basis~~ (booted to future work)
* ~~Add guards for tensor subclasses by calling `__tensor_flatten__()` in the guard to test equality on `ctx`~~
    * Now handled in #114469
* Next PR: add TENSOR_MATCH guards on inner tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/114311
Approved by: https://github.com/ezyang, https://github.com/drisspg, https://github.com/voznesenskym, https://github.com/bdhirsh
2023-12-05 21:09:25 +00:00
5e10dd2c78 fix docstring issues in torch.utils (#113335)
Fixes #112634

Fixes all the issues listed except in `torch/utils/_pytree.py` as the file no longer exists.

### Error counts

|File | Count Before | Count now|
|---- | ---- | ---- |
|`torch/utils/collect_env.py` | 39 | 25|
|`torch/utils/cpp_extension.py` | 51 | 13|
|`torch/utils/flop_counter.py` | 25 | 8|
|`torch/utils/_foreach_utils.py.py` | 2 | 0|
|`torch/utils/_python_dispatch.py.py` | 26 | 25|
|`torch/utils/backend_registration.py` | 15 | 4|
|`torch/utils/checkpoint.py` | 29 | 21|

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113335
Approved by: https://github.com/ezyang
2023-11-13 19:37:25 +00:00
bbd5b935e4 Use pytree.tree_leaves everywhere (#112324)
This changes all the instances I could find of `tree_flatten(...)[0]` or
`x, _ = tree_flatten` to use `tree_leaves`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
2023-10-30 03:39:04 +00:00
46b0b7bff7 _return_and_correct_aliasing: fix for schemas with mutable tensor in kwargs (#109662)
I missed a few tests the first time around - this fixes out= op handling for `_return_and_correct_aliasing`, which failed a few tests in the python functionalization <> AOTAutograd PR above.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109662
Approved by: https://github.com/ezyang
ghstack dependencies: #108654
2023-09-22 07:09:04 +00:00
71b4b32014 return_and_correct_aliasing: massage some schemas to work with torchgen (#108897)
This issue is that `str(torch.ops.aten.conv2d.default._schema)` does not return the same schema that is in native_functions.yaml ([link](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/native_functions.yaml#L1654)).

Torchscript appears to change the default arg string `int[2] strides=1` to `int[2] strides=[1, 1]`. If you try to parse that with torchgen, torchgen is unhappy (it tries to split arguments on comma, but now we have a comma inside of the default argument).

Fixing the issue directly in torchgen was a bit more painful, so I opted just to undo the transformation that torchscript made: convert `=[1, 1]` back into `=1`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/108897
Approved by: https://github.com/ezyang
ghstack dependencies: #106404, #107917
2023-09-15 20:19:25 +00:00
f22b303f65 Add TorchDispatch version of functionalization (#106404)
This PR adds a new `FunctionalTensor` subclass, and `FunctionalTensorMode` torch dispatch mode. Together, this class/mode are a lightweight wrapper around our existing C++ functionalization logic.

This idea came from Ed - later in the stack, I want to be able to run functionalization **underneath** torch_dispatch, when performing tracing in AOTAutograd. I can't do this easily with vanilla C++ functionalization, because it has a dedicated dispatch key that always runs before TorchDispatch. However, by adding a torch_dispatch mode shim around functionalization, we can use functionalization as a torch_dispatch mode, which will make it easier to run underneath other modes later.

This PR provides the basic new classes, and some light testing.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106404
Approved by: https://github.com/ezyang
2023-09-15 20:19:25 +00:00
4f34caf164 add return_and_correct_aliasing() util for wrapper subclasses (#107915)
This PR adds a `return_and_correct_aliasing()` utility, that wrapper subclasses can use to get correct aliasing. I updated `TwoTensor` to use it, and added some testing that the aliasing of my `TwoTensor` subclass now matches the aliasing behavior of normal tensors.

Right now my test just uses a few hand-picked opinfos (that have varying aliasing behavior). I thought all op infos might be overkill (does that take a while to run?), but I'm happy to add them all if people prefer.

One more general question about this PR: eventually, proper aliasing will be a **requirement** in order for AOTAutograd to handle aliasing/mutations on subclasses properly during compilation. How can we make sure that wrapper subclasses use this API? A few options (from talking to Richard):

(1) Yolo require subclasses to use the API and hope users do as well (what this PR does)

(2) Yolo require subclasses to use the API, but add a kwarg to `_make_wrapper_subclass`, e.g. `manual_aliasing=True`, that torch.compile checks for before allowing the subclass to be used in compilation

(3) Automatically run this API in our python fallback, for **every** tensor subclass that currently implements `__tensor_flatten__` (aka only the "traceable" subclasses)

(4) Automatically run this API in our python fallback, for **every** tensor subclass. This would be a bit higher blast radius, since it would change the existing aliasing behavior of wrapper subclasses. Maybe.. this is the right thing to do though?

Either way, my tentative plan is to do (1) to unblock, and revisit this later once we want to come up with public docs + a more general "tensor subclass in PT2 requirements" plan

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107915
Approved by: https://github.com/ezyang
2023-08-29 14:27:19 +00:00
da54f3c519 reorder proxy / fake modes so they always run last (#104482)
**Update:** Made refactor of the original PR. See the original description below, but here I'll describe the updates:

(1) TLS changes in `TorchDispatchModeTLS.h/cpp`.

I added a `TorchDispatchModeKey` enum, that (for now) just contains PROXY and FAKE. The ModeTLS used to just contain a `std::vector<std::shared_ptr<c10::SafePyObject>>` corresponding to the mode stack. It now **also** contains a separate array of "infra modes", indexed by mode key (PROXY and FAKE, with a new addition, FUNCTIONAL, coming later in the stack).

`TorchDispatchModeTLS::push_onto_stack` and `TorchDispatchModeTLS::pop_stack` are now a bit more complicated. Pushing accepts an optional mode_key, which if set, tells us to add the given mode directly to our "infra_modes" array. Popping will first check the "user mode" stack, before trying to pop anything from the infra mode stack. It also optionally returns the mode key of the mode we popped if there was one - that way if we push that same mode back onto the TLS later, we know where it goes.

`TorchDispatchModeTLS::dispatch_mode_enabled()` now accepts an optional `skip_infra_modes` param, so you can separately query if there are "any modes at all", or if there are "any user modes".

`TorchDispatchModeTLS::get/set/unset_mode()` all take in a mode key, and get/set/unset the mode at that particular mode key (meaning they are only meant to be used for infra modes).

There were also some mild codegen changes to support the new enum

(2) `fake_tensor.py/proxy_tensor.py/_python_dispatch.py`

The way I tell the infra that certain subclasses/modes are "infra" is through the enum: I gave `FakeTensor` and `FakeTensorMode` a `self._mode_key = torch._C.TorchDispatchModeKey.FAKE`. `TorchDispatchMode.__enter/exit__()` (in `_python_dispatch.py` now check if the current mode has a mode key, and if so they plumb it into any `push_onto_stack()` calls (which eventually instructs `TorchDispatchModeTLS` where to put the mode). Same thing for `ProxyTorchDispatchMode`.

I also had to change both of these mode's enter/exit, to handle the fact that there can no longer be multiple proxy/fake modes on the mode stack at once. I updated them both to have a `self.enter_stack: List[Optional[TorchDispatchMode]]` - whenever we push a given mode in `__enter__`, we remove the current ambient fake/proxy mode from the mode stack, and save it in `enter_stack`, so that on exit we can reset the state properly.

(2) dispatching logic in `python_arg_parser.cpp`

This is where the core dispatching logic changes are. I added two helpers, `dispatch_on_subclass()` and `dispatch_on_mode()`. The overall dispatching order is now:
```
(a) dispatch_on_mode()  # try user modes first (where the mode stack automatically considers infra modes last)
(b) dispatch_on_subclass() # try user subclasses next (skipping infra subclasses)
(c) dispatch_on_subclass() # try infra subclasses next (skipping user subclasses)
```

Note that we still want "user subclasses" to run before "infra modes". As Ed helped me realize, this will work today: If proxy/fake modes in step 1, they'll return NotImplemented if they see a user subclass, allowing us to redispatch to the user subclass.

How do (b) and (c) distinguish between user and infra subclasses? Infra subclasses (FakeTensor, and later FunctionalTensor) are required to have a `_mode_key` hidden on the subclass - so we filter via arguments that do/don't have the _mode_key.

(3) I also changed `DoubleTensor` to `TwoTensor` to minimize confusion (@albanD  pointed out that DoubleTensor would be easily confused with `torch.FloatTensor` and friends).

----- original description below -----

The main purpose of this PR is to fix the "ordering problem" between torch_dispatch modes, where we want to ensure that our Fake and Proxy dispatch modes always run **after** any dispatch modes created by the user, regardless of where they are in the stack. See this doc for more details: https://docs.google.com/document/d/1COQ291nOZvtFnzGTQMJqoYZ3sttEYFw_7HbfSyL8gcA/edit

Full set of changes below. I ended up including a few semi-related changes in this PR that I documented - but if folks would rather I separate them out, happy to try to do that.

**(1) Add dedicated TLS slots for FakeTensorMode and ProxyTensorMode**

This is the main component of this PR. There are two new slots, `TorchDispatchModeTLS.fake_mode_` and `TorchDispatchModeTLS.proxy_mode_`, which correspond to a single "global" fake and proxy mode. There is now an invariant that `torchDispatchModeState.stack_` can never contain either of these modes.

I also added a `TorchDispatchModeTLS::maybe_highest_mode()` helper that consults the `stack_` as well as both the proxy and fake slots, and returns the highest priority mode - this is because there are a few places in the codebase where we legitimately want to get the highest priority mode, *including* fake or proxy, if one is set.

This also made the implementations of the existing `disable_proxy_modes_tracing()` and `get_innermost_proxy_mode()` marginally simpler.

**(2) Updated the dispatching logic in handle_torch_function_no_python_arg_parser()**

This is the function that actually figures out which torch_dispatch implementation to call, given the current mode stack and tensor subclass inputs. This function got marginally more complicated as part of the refactor: First we inspect the mode stack and any non-fake subclass inputs. Then we check for the proxy mode slot. Then we check for the Fake mode slot, before finally checking for any fake subclass inputs.

**(3) new python `_get_fake_tensor_mode()` and `_get_proxy_tensor_mode()` API's**

Before, if you wanted to see if proxy or fake modes were active in python, you would have to consult the mode stack. Since these two modes are no longer part of the actual mode stack, I added two new API's to directly check if either proxy or fake modes are active.

**(4) Allow traceable tensor subclasses to access storages from python**
This is convenient later in the stack, where AOTAutograd needs to detect aliasing of inputs and outputs, where those inputs and outputs might be tensor subclasses. Previously, `x.untyped_storage()` would raise an error if `x` was a subclass. In this PR, I tried to relax this constraint as little as possible: `THPVariable_storage()` will only try to return a storage to python if the tensor subclass that you are passing in is "traceable"

**(5) Fixed subclass fakeification**

@wanchaol recently added support to be able to fakeify tensor subclasses. That fakeification logic works in most cases, but there is one case it doesn't handle: autograd metadata. In particular, since autograd sees our tensor subclasses and not their desugared tensors, we need to make sure that our fakeified subclass has the same autograd metadata as the original subclass. I updated `meta_utils.py` to make sure that the autograd metadata is correct.

**(6) make tensor subclasses resizeable**

Previously we didn't allow tensor subclasses to be resizeable. I ran into an issue where fakeifying a tensor subclass occasionally requires swapping out its storage, which can involve resizing the tensor. Mechanically, this required updating `at::for_blob()` to expose a way to request that the tensor that you create has resizeable storage, and then using this new API in `_make_wrapper_tensor()`.

**(7) Added a basic DoubleTensor subclass for testing**

I use this subclass more later in this stack in my AOTAutograd tests - but it serves as a simple subclass example to test the dispatch ordering in this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104482
Approved by: https://github.com/ezyang
ghstack dependencies: #107415
2023-08-29 02:36:48 +00:00
5efd63b1b8 better support for fakeifying and dynamoing through torch_dispatch subclasses (with dynamic shapes) (#107415)
There is already some support for plumbing `__torch_dispatch__` tensor subclasses through dynamo, but this PR beefs it up a bit and adds a test. In particular:

(1) Fakeifying tensor subclasses didn't properly set autograd metadata (requires_grad, is_leaf) on the newly fakeified wrapper subclass. I don't actually have a test for this in this PR, but it's tested pretty heavily later in my aot autograd tests

(2) Fakeifying tensor subclasses didn't properly track source information for dynamic shapes on the inner tensors. I added a new `WrapperSubclassFieldSource` subclass, that represents a source coming from a tensor field on a wrapper subclass, which I use in the fakeifying logic, and again in symbolic_shapes.py to generate proper guards.

(3) `_make_wrapper_subclass()` marginally updated this code to work better with dynamic shapes. One thing that's a bit weird about `_make_wrapper_subclass`: it has two overloads, and the first explicitly does not support dynamic shapes (and the second.. does not support kwargs). I think that later we probably want to consolidate / at least make the first overload work with dynamic shapes, but I didn't want to handle that in this PR (so these smaller changes seemed like a strict improvement).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107415
Approved by: https://github.com/ezyang
2023-08-29 02:36:48 +00:00
cb23373264 [dynamo] allow tensor subclass fakification in dynamo (#105308)
This PR adds necessary plumbing through torchdynamo to allow tensor
subclasses with certain contract (i.e. with `__tensor_flatten__` and
`__tensor_unflatten__`) to goes through the dynamo fakification pass by
fakifying the tensor subclass internal components.

Some of the tensor subclass contract logic mostly borrowed from
https://github.com/pytorch/pytorch/pull/97540

Added some tests to verify simply passing through a tensor subclass
(i.e. DTensor) through dynamo eager works as expected.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105308
Approved by: https://github.com/ezyang
2023-07-18 17:28:04 +00:00
c3c03e7cb8 Reland of https://github.com/pytorch/pytorch/pull/101818 (#103888)
Original PR broke internal

This reverts commit 5ed618132f466440ad76c884240e07796c7e2c6b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103888
Approved by: https://github.com/albanD
2023-06-21 21:00:56 +00:00
5ed618132f Revert "change pre_autograd to pre_dispatch tracing (#101818)"
This reverts commit b0392de2c39d132b5901fc9a366afc1ddc214f96.

Reverted https://github.com/pytorch/pytorch/pull/101818 on behalf of https://github.com/izaitsevfb due to Breaks internal builds see D46629736 TypeError: wrap_key() got an unexpected keyword argument pre_autograd ([comment](https://github.com/pytorch/pytorch/pull/101818#issuecomment-1587837667))
2023-06-12 18:16:37 +00:00
b0392de2c3 change pre_autograd to pre_dispatch tracing (#101818)
We discussed in a composability meeting a few weeks ago that `pre_autograd` should probably be renamed to `pre_dispatch`.

One question in this PR was: should I re-use a dispatch key? Or should I create a new dispatch key (that yet again corresponds to "top of the dispatcher")?

~~For now, I ended up sticking our proxy mode on the mode stack corresponding to `PythonTLSSnapshot`, because it was simple and it works. It looks like one of the functorch dispatch keys has higher priority though, so it's possible that functorch will end up running first. Open to options, but we can consider adding a new dispatch key later if that becomes a problem~~

Update: I added a dedicated dispatch key, `PreDispatch`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/101818
Approved by: https://github.com/ezyang, https://github.com/Neilblaze, https://github.com/albanD, https://github.com/zou3519
2023-06-09 17:30:15 +00:00
62fad315c1 fix per-dispatchkey-mode caching bug (#98030)
The bug was that: if you want to move a mode to the autograd key, we need to use the "functionality" key for it (AutogradFunctionality). But when we do that, we need to clear any PythonDispatcher caches for every op for **every** autograd key (since you could run autograd ops with both cpu and cuda tensors underneath the mode, which both may have been cached).

I didn't add a test, since this ends up getting indirectly tests by export in the PR. If someone would prefer a direct test I can add one.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98030
Approved by: https://github.com/ezyang
2023-04-25 21:58:14 +00:00
af440c427b [draft for discussion] add per-dispatch key modes (#97052)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97052
Approved by: https://github.com/ezyang, https://github.com/zou3519
2023-03-21 23:45:45 +00:00
db466ae057 Revert "[Modes] Add assert that the mode isn't already on the stack (#90770)"
This reverts commit 702838637d63936460ea2bf00b64ffec86ed6687.

Reverted https://github.com/pytorch/pytorch/pull/90770 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-01-12 16:44:29 +00:00
702838637d [Modes] Add assert that the mode isn't already on the stack (#90770)
Redo of #89726 on a clean PR, thanks @voznesenskym for the first draft!
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90770
Approved by: https://github.com/ezyang
2023-01-11 15:19:43 +00:00
2ce2fc133d Disable Current Modes when printing Tensor (#88344)
Fix for https://github.com/pytorch/pytorch/issues/88087

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88344
Approved by: https://github.com/ezyang, https://github.com/samdow
2022-11-04 00:45:35 +00:00
169ec120ef [Modes] refactor modes to only use a stack in cpp (#86458)
Refactors the mode code to only have the C++ mode stack and not the "C++ mode" like we originally had. This also simplifies the mode logic in a number of places
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86458
Approved by: https://github.com/zou3519
2022-10-21 19:18:23 +00:00
18d8c548f4 [Modes] remove enable and rewrite mode stack (squashed) (#84774)
Based on @ezyang's suggestion, mode stack now has "one true mode" which is the _only_ mode that can ever be active at the C++ level. That mode's torch dispatch is just to take the top mode in the stack, reenable itself (if we aren't at the end of the mode stack), and run the top mode's torch_{dispatch|function}

This maintains that in the middle of a mode's torch dispatch, the mode itself will not be active. It changes the function the user has to call to see what the current mode is (no longer queries the C++, it's python only) but allows the user to also see the entire mode stack easily

Removes `enable_torch_dispatch_mode` and `.restore()` since neither makes sense in this new setup

### Background
Why do we want this? Well, a pretty common pattern that was coming up was that users had to do something like

```python
## PRE-PR UX
def f(mode):
  with mode.restore():  # user needs to understand this restore thing?
    ...

with Mode() as m:
  pass
f(m)
```

Many users were getting error from forgetting to call `.restore` or from forgetting to add the (tbh weird) "mode instantiation"  step where they use the mode as a context manager with an empty body. Really, they wanted to treat modes like context managers and just write
```python
## FROM FEEDBACK, USER DESIRED CODE. POSSIBLE POST-PR
def f(mode):
  with mode:
    ...
f(Mode())
```

** Technical Details **
With the old mode stack, we basically had a linked list so the mode itself could only be used once and had a fixed parent. In this new design, the mode stack is just a python list that we're pushing to and popping from. There's only one mode that's ever active at the C++ level and it runs the next mode in the Python list. The modes don't have state on them anymore
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84774
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-09-27 01:04:35 +00:00
7532d5b125 [Modes] remove inner constructor kwarg (#83925)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/83925
Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-08-31 00:05:56 +00:00
2ac24675cc get rid of push_torch_{dispatch, function}_mode (#78215)
Currently we have 2 ways of doing the same thing for torch dispatch and function modes:
`with push_torch_dispatch_mode(X)` or `with X.push(...)`
is now the equivalent of doing
`with X()`

This removes the first API (which is older and private so we don't need to go through a deprecation cycle)

There is some risk here that this might land race with a PR that uses the old API but in general it seems like most are using the `with X()` API or `enable_torch_dispatch_mode(X())` which isn't getting removed.

EDIT: left the `with X.push(...)` API since there were ~3 land races with that over the past day or so. But made it give a warning and ask users to use the other API
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78215
Approved by: https://github.com/ezyang
2022-07-22 18:56:37 +00:00
d4f065d261 Return mode object from __enter__ (#80998)
This makes `with Mode() as m:` work.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/80998
Approved by: https://github.com/samdow
2022-07-12 23:22:26 +00:00
3734fcc8f8 add ability to push a mode if the current mode is an ancestor
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78822

Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-06-10 18:27:04 +00:00
184e0065b3 add better error message for class method
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78821

Approved by: https://github.com/ezyang
2022-06-06 13:31:32 +00:00
aa06d05297 enable with semantics
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78214

Approved by: https://github.com/ezyang, https://github.com/zou3519
2022-06-01 21:14:45 +00:00
4941e72e40 Revert "Revert "Implement sym_sizes to create proper IR for sym ints representing tensor sizes (#76836)""
This reverts commit c35bd8d423ca53408c3aa39c2280167f3a22cea0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77719

Approved by: https://github.com/Chillee, https://github.com/malfet
2022-05-18 18:40:57 +00:00
48581d74ad Revert "Add dispatch mode testing for meta tensors and other stuff"
This reverts commit c1cdb1216b97970d903a6d6e9e7d0e2b4ffaef46.

Reverted https://github.com/pytorch/pytorch/pull/77477 on behalf of https://github.com/malfet
2022-05-18 02:56:48 +00:00
c1cdb1216b Add dispatch mode testing for meta tensors and other stuff
We don't have any coverage for meta tensor correctness for backwards
because torch function mode can only allow us to interpose on
Python torch API calls, but backwards invocations happen from C++.
To make this possible, I add torch_dispatch_meta test which runs the
tests with __torch_dispatch__

While doing this, I needed to generate fresh expected failure / skip
lists for the new test suite, and I discovered that my original
scaffolding for this purpose was woefully insufficient.  So I rewrote
how the test framework worked, and at the same time rewrote the
__torch_function__ code to also use the new logic.  Here's whats
new:

- Expected failure / skip is now done on a per function call basis,
  rather than the entire test.  This means that separate OpInfo
  samples for a function don't affect each other.

- There are now only two lists: expect failure list (where the test
  consistently fails on all runs) and skip list (where the test
  sometimes passes and fails.

- We explicitly notate the dtype that failed.  I considered detecting
  when something failed on all dtypes, but this was complicated and
  listing everything out seemed to be nice and simple.  To keep the
  dtypes short, I introduce a shorthand notation for dtypes.

- Conversion to meta tensors is factored into its own class
  MetaConverter

- To regenerate the expected failure / skip lists, just run with
  PYTORCH_COLLECT_EXPECT and filter on a specific test type
  (test_meta or test_dispatch_meta) for whichever you want to update.

Other misc fixes:

- Fix max_pool1d to work with BFloat16 in all circumstances, by making
  it dispatch and then fixing a minor compile error (constexpr doesn't
  work with BFloat16)

- Add resolve_name for turning random torch API functions into string
  names

- Add push classmethod to the Mode classes, so that you can more easily
  push a mode onto the mode stack

- Add some more skips for missing LAPACK

- Added an API to let you query if there's already a registration for
  a function, added a test to check that we register_meta for all
  decompositions (except detach, that decomp is wrong lol), and then
  update all the necessary sites to make the test pass.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77477

Approved by: https://github.com/zou3519
2022-05-18 00:18:34 +00:00
6779366f27 add nested mode to python mode
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75965

Approved by: https://github.com/albanD, https://github.com/ezyang, https://github.com/zou3519
2022-05-04 13:01:06 +00:00
598e7e5f19 [Reland] Change 'python mode' to 'torch dispatch mode'
Changes Python Mode name to Torch Dispatch Mode because there is now a Torch Function Mode, so Torch Dispatch Mode and Torch Function Mode are consistent with each other
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519, https://github.com/albanD
2022-05-02 20:06:43 +00:00
395a620a4f Revert "Change 'python mode' to 'torch dispatch mode'"
This reverts commit 7203a73986943c50b8d2f0fdc9d09821705a1602.

Reverted https://github.com/pytorch/pytorch/pull/76562 on behalf of https://github.com/janeyx99
2022-05-02 14:42:11 +00:00
7203a73986 Change 'python mode' to 'torch dispatch mode'
Changes Python Mode name to Torch Dispatch Mode because there is now a Torch Function Mode, so Torch Dispatch Mode and Torch Function Mode are consistent with each other
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76562
Approved by: https://github.com/zou3519
2022-05-02 13:33:58 +00:00
35cfa74f97 Add a default implementation of __torch_dispatch__
I was working on an explanation of how to call into the "super"
implementation of some given ATen operation inside of __torch_dispatch__
(https://github.com/albanD/subclass_zoo/blob/main/trivial_tensors.py)
and I kept thinking to myself "Why doesn't just calling super() on
__torch_dispatch__ work"?  Well, after this patch, it does!  The idea
is if you don't actually unwrap the input tensors, you can call
super().__torch_dispatch__ to get at the original behavior.

Internally, this is implemented by disabling PythonKey and then
redispatching.  This implementation of disabled_torch_dispatch is
not /quite/ right, and some reasons why are commented in the code.
There is then some extra work I have to do to make sure we recognize
disabled_torch_dispatch as the "default" implementation (so we don't
start slapping PythonKey on all tensors, including base Tensors),
which is modeled the same way as how disabled_torch_function is done.

Signed-off-by: Edward Z. Yang <ezyangfb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/73684

Approved by: albanD
2022-03-03 20:19:33 +00:00
67bd2a31b5 [Reland] Add python mode (#64360)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64360

This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.

Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

There are quite a few changes that were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.

To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.

Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.

There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.

Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.

Test Plan: - new tests

Reviewed By: ezyang

Differential Revision: D30698082

Pulled By: zou3519

fbshipit-source-id: 7094a90eee6aa51f8b71bc4d91cfb6f49e9691f8
2021-09-16 09:02:30 -07:00
0457a85d45 Revert D30543236: Add python mode
Test Plan: revert-hammer

Differential Revision:
D30543236 (4bd03b0242)

Original commit changeset: ef5444d96a5a

fbshipit-source-id: b0042ac2c22765fa11d6d00bf751f6a4489eb6d8
2021-08-31 15:28:33 -07:00
4bd03b0242 Add python mode (#63496)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63496

This PR adds a (private) enable_python_mode context manager.
(see torch/utils/_python_dispatch.py).
enable_python_mode accepts the type of a __torch_dispatch__ object
as its argument. Whenever an operator gets called inside of the
context manager, it dispatches to the __torch_dispatch__ of
the passed-in type.

Example usage:
```
with enable_python_mode(LoggingTensor):
    z = torch.empty([])
    assert isinstance(z, LoggingTensor)
```

There are quite a few changes that were made to support this.

First, we added TorchDispatchTypeObject, a C++ struct that represents the
type of a `__torch_dispatch__` object (e.g. LoggingTensor).
It holds both the PyObject* representing the class and a PyInterpreter*
so we know which Python interpreter it came from.

Next, we updated the concrete_dispatch_fn in python_variable.cpp to accept
a `const std::shared_ptr<TorchDispatchTypeObject>&` argument. When this
is null, dispatching happens as usual. When it is non-null, we prepend
the TorchDispatchTypeObject's PyObject* to the overloaded args list so that
it is considered first for dispatch.

To get that to work, we changed how `handle_torch_dispatch_no_python_arg_parser`
works. The "overloaded args list" previously only consisted of Tensor PyObjects,
but now it can have types in addition to Tensors!
- We renamed `append_overloaded_arg` to `append_overloaded_arg`
- We added a new `append_overloaded_type` that appends a type to
overloaded_args
- We added special handling in `handle_torch_dispatch_no_python_arg_parser`
and `append_overloaded_arg` to handle types in addition to Tensors.

Then, there is PythonMode and PythonModeTLS.
- We reuse the DispatchKey::Python dispatch key as a mode key
- We use PythonMode::enter and PythonMode::exit to enable/disable
DispatchKey::Python and set the PythonModeTLS.
- PythonModeTLS stores a TorchDispatchTypeObject as metadata.
- PythonMode is in libtorch_python, and PythonModeTLS is in ATen.
This split is due to the libtorch_python library boundary (because we need
to save TLS in ATen/ThreadLocalState)
- We modify the PythonFallbackKernel to look up
the relevant TorchDispatchTypeObject (if Python Mode is active) and
dispatch using it.

There are two more miscellaneous changes:
- internal_new_from_data (torch/csrc/utils/tensor_new.cpp) gets an
exclude guard. enable_python_mode currently does not handle
torch.tensor and the exclude guard is to prevent a bug.

Future:
- This PR does not allow for the nesting of Python modes. In the future we
should be able to enable this with a more sane no_dispatch API and by changing
the TLS to a stack. For now I did not need this for CompositeImplicitAutograd testing.

Test Plan: - new tests

Reviewed By: malfet, albanD

Differential Revision: D30543236

Pulled By: zou3519

fbshipit-source-id: ef5444d96a5a957d1657b7e37dce80f9a497d452
2021-08-30 18:44:35 -07:00