``torch.compile`` has different autograd semantics
==================================================

When you apply ``torch.compile`` to a function in your model's forward pass,
it will automatically generate a backward pass for the compiled function.
During compilation, it will trace out a graph for the backward pass that
is used whenever autograd is invoked. We refer to the component inside
``torch.compile`` that is responsible for this as ``AOTDispatcher``
(sometimes known as ``AOTAutograd``).

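A minimal sketch of this (``region`` is an arbitrary illustrative function):

```py
import torch

@torch.compile
def region(x):
    return torch.sin(x)

x = torch.randn(3, requires_grad=True)
y = region(x)       # compiling the forward also traces out the backward graph
y.sum().backward()  # autograd runs the pre-traced backward graph
```
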
As such, ``torch.compile`` bakes details of the computation into the
traced-out backward graph during compilation of the function in the
forward pass. However, in eager-mode PyTorch, the backward computation
is dynamic: outside of the forward pass, you can wrap the call to
``tensor.backward()`` or ``torch.autograd.grad(...)`` in a context
manager that may change its behavior.

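For example, a minimal eager-mode sketch of this dynamism (the autocast
arguments here are illustrative):

```py
import torch

x = torch.randn(3, requires_grad=True)
y = torch.sin(x).sum()

# In eager mode, the ambient context at backward time matters: this
# backward call runs under bfloat16 CPU autocast even though the
# forward did not.
with torch.amp.autocast("cpu", dtype=torch.bfloat16):
    y.backward()
```
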
This page documents how ``torch.compile``'s autograd semantics differ from
eager-mode PyTorch and how to work around the differences.

``Autocast`` behavior
---------------------

``torch.compile`` bakes in an assumption about whether the backward pass
will be run under an ambient autocast context manager. Use
``torch._functorch.config.backward_pass_autocast`` to control that
assumption; an incorrect assumption may lead to silent incorrectness.

The options are:

- `"same_as_forward"` (default). We assume that the backward of the
  ``torch.compile``'ed region will be run under the same autocast context
  manager that the region was run under (if any). Use this if your code
  looks like the following:

  ```py
  with torch.amp.autocast(...):
      y = torch.compile(region)(x)
      ...
      # backward pass runs under the same autocast context as the compiled region
      z.backward()
  ```

- `"off"`. We assume that the backward of the torch.compile'd region will
|
|
not be run under any autocast context managers.
|
|
Use this if your code looks like the following:
|
|
```py
|
|
with torch.amp.autocast(...):
|
|
y = torch.compile(region)(x)
|
|
...
|
|
# Backward pass runs under no autocast.
|
|
z.backward()
|
|
```
|
|
- There is a third option. If you set ``torch._functorch.config.backward_pass_autocast``
  to a list of kwargs, we will assume the backward pass runs under an autocast
  context constructed from those kwargs. For example, if your code looks like
  the following:

  ```py
  y = torch.compile(region)(x)
  ...
  # Backward pass runs under a special autocast context manager.
  with torch.amp.autocast(**kwargs):
      z.backward()
  ```

  then set ``torch._functorch.config.backward_pass_autocast = kwargs``
  (see the sketch below).
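
Concretely, a sketch of this pattern, following the snippet above (the
kwargs values are illustrative; mirror whatever you actually pass to
``torch.amp.autocast`` around your ``backward()`` call):

```py
import torch
import torch._functorch.config as functorch_config

# Illustrative autocast kwargs; use the same ones you pass around backward().
kwargs = {"device_type": "cpu", "dtype": torch.bfloat16}
functorch_config.backward_pass_autocast = kwargs

@torch.compile
def region(x):
    return torch.sin(x)

x = torch.randn(3, requires_grad=True)
z = region(x).sum()

# The backward pass runs under the matching autocast context.
with torch.amp.autocast(**kwargs):
    z.backward()
```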
Use ``torch._functorch.config.patch`` to apply an option to a specific
``torch.compile`` call:

```py
with torch.amp.autocast(...):
    with torch._functorch.config.patch(backward_pass_autocast="same_as_forward"):
        y = torch.compile(region)(x)
    ...
    # backward pass runs under the same autocast context as the compiled region
    z.backward()
```