Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
This adds support for backward hooks that are *both*:
1) interior to the graph, and
2) dynamically generated (e.g. lambdas).

We do this by creating a BackwardState object that is used to register the hooks in the forward pass and is then populated by Dynamo *after* the forward pass runs.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120382
Approved by: https://github.com/xmfan
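The deferred-population idea in this commit can be illustrated in plain eager PyTorch, without Dynamo. This is a minimal sketch under stated assumptions, not the Dynamo implementation: `HookHolder` is a hypothetical stand-in for BackwardState, and the hook is registered during the forward on an interior tensor with a lambda that only looks up `holder.hook` when the backward actually runs. That delay is what allows the real hook to be attached with `setattr` after the forward has finished.

```python
import torch


class HookHolder:
    """Hypothetical stand-in for BackwardState: starts out as an empty namespace."""
    pass


def forward(x: torch.Tensor, holder: HookHolder) -> torch.Tensor:
    y = x * 2  # interior (non-leaf) tensor of the graph
    # Register the hook now, but defer the attribute lookup until backward runs.
    y.register_hook(lambda grad: holder.hook(grad))
    return y.sum()


x = torch.ones(3, requires_grad=True)
holder = HookHolder()  # empty when passed to the forward
loss = forward(x, holder)

# Populate the holder only after the forward has finished.
setattr(holder, "hook", lambda grad: grad * 10)

loss.backward()
print(x.grad)  # tensor([20., 20., 20.])
```

The gradient reaching `y` is all ones; the late-bound hook scales it by 10, and the `x * 2` step doubles it again, giving 20 per element.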
28 lines
967 B
Python
import torch.fx


class BackwardState:
    """
    BackwardState is used to pass Python hooks from the forward pass
    into the backward pass in Dynamo + Compiled Autograd.

    It is created by TorchDynamo and has special handling there.
    Dynamo will pass an empty BackwardState to the forward, then populate
    members on it (via setattr) only after the forward graph is finished.
    Later on, in CompiledAutograd we will inline the hooks and add the
    needed guards on the BackwardState.

    BackwardState is identified and has special handling in AOTAutograd.
    During AOTAutograd:
        1) BackwardState is an input to the forward graph
        2) It must only be used in the backward
        3) It will be empty in the forward
        4) In the forward, we add a wrapper to save it
        5) In the backward, it becomes an input
        6) There can only be one per graph

    BackwardState requires CompiledAutograd.
    """

    proxy: torch.fx.Proxy
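The AOTAutograd contract listed in the docstring (the state is an input to the forward, unused and empty there, saved by the forward, and handed to the backward as an input) can be mimicked with a custom `torch.autograd.Function`. This is a hedged sketch, not AOTAutograd's code: `HookScaleFn` and the `SimpleNamespace` state object are hypothetical illustrations. The state is stored on `ctx` as a plain attribute because `save_for_backward` only accepts tensors.

```python
from types import SimpleNamespace

import torch


class HookScaleFn(torch.autograd.Function):
    """Hypothetical op that receives a BackwardState-like object as an input."""

    @staticmethod
    def forward(ctx, x, bw_state):
        # The state is an input to the forward but is not read here (it is still empty).
        ctx.bw_state = bw_state  # keep a reference so the backward can use it
        return x.clone()

    @staticmethod
    def backward(ctx, grad_out):
        # In the backward the saved state acts as an input; its hook exists by now.
        grad_x = ctx.bw_state.hook(grad_out)
        return grad_x, None  # no gradient flows to the state object itself


x = torch.tensor([1.0, 2.0], requires_grad=True)
state = SimpleNamespace()  # a single, initially empty state object for this graph
out = HookScaleFn.apply(x, state)
state.hook = lambda g: g + 1.0  # populated only after the forward finished
out.sum().backward()
print(x.grad)  # tensor([2., 2.])
```

Keeping the state as a non-tensor Python object is what lets its members be filled in lazily. The sketch does not model Dynamo's guards or Compiled Autograd's inlining of the hooks.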