pytorch/torch/fx/experimental/_backward_state.py
Jason Ansel 01ec8df6d8 [Compiled Autograd] Introduce BackwardState capture (#120382)
This adds support for backward hooks that are *both*:
1) Interior to the graph; and
2) Dynamically generated (e.g. lambdas)

We do this by creating a BackwardState object that is used to register the hooks in the forward pass, then populated by Dynamo *after* the forward pass runs (a usage sketch follows the commit metadata below).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
2024-02-28 20:36:47 +00:00
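As a concrete illustration of the pattern this commit enables, here is a minimal sketch. It assumes the torch._dynamo.compiled_autograd.enable context manager and the aot_eager backend as they existed around this commit; the function and variable names (fn, scale) are illustrative, not part of the patch:

import torch
import torch._dynamo.compiled_autograd as compiled_autograd

@torch.compile(backend="aot_eager")
def fn(x, scale):
    y = x * x
    # The hook is dynamically generated (a lambda closing over `scale`)
    # and is registered on `y`, an intermediate interior to the graph.
    y.register_hook(lambda grad: grad * scale)
    return y.sum()

x = torch.randn(4, requires_grad=True)
# BackwardState requires Compiled Autograd, so the backward must run
# under it for the hook to be inlined into the compiled backward graph.
with compiled_autograd.enable(torch.compile(backend="aot_eager")):
    fn(x, 2.0).backward()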


import torch.fx


class BackwardState:
    """
    BackwardState is used to pass Python hooks from the forward pass
    into the backward pass in Dynamo+Compiled Autograd.

    It is created by TorchDynamo and has special handling there.
    Dynamo will pass an empty BackwardState to the forward graph, then
    populate members on it (via setattr) only after the forward graph
    has finished. Later, in Compiled Autograd, we inline it and add the
    needed guards on the BackwardState.

    BackwardState is identified and specially handled in AOTAutograd.
    During AOTAutograd:
        1) BackwardState is an input to the forward graph
        2) It must only be used in the backward
        3) It will be empty in the forward
        4) In the forward we add a wrapper to save it
        5) In the backward it becomes an input
        6) There can only be one per graph

    BackwardState requires Compiled Autograd.
    """

    proxy: torch.fx.Proxy
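
To make the lifecycle described in the docstring concrete, here is a minimal sketch in plain Python. It is not the actual Dynamo/AOTAutograd machinery; the forward/backward functions and the hook attribute are hypothetical stand-ins for steps 1-6:

import torch
from torch.fx.experimental._backward_state import BackwardState

bw_state = BackwardState()               # created empty, an input to the forward (1, 3)

def forward(x, bw_state):
    # The forward must not read bw_state; it only saves it for the
    # backward (4). It is still empty at this point (3).
    return x * 2, bw_state

out, saved = forward(torch.ones(3), bw_state)

# Only after the forward graph is finished are members populated via setattr.
bw_state.hook = lambda grad: grad * 10

def backward(grad, bw_state):
    # In the backward, the (now populated) BackwardState becomes an
    # input (5), and the backward is the only place it may be used (2).
    return bw_state.hook(grad)

print(backward(torch.ones(3), saved))    # tensor([10., 10., 10.])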