Support Predispatch functionalization (#113728)
In this PR, we implement functionalization on the pre-dispatch graph.

Today, every dispatch key except DispatchKey.Python has a dedicated mode stack in Python. Pre-dispatch tracing relies on this behaviour by pushing ProxyTorchDispatchMode onto the DispatchKey.PreDispatch mode stack and handling the dispatch logic in Python. To make pre-dispatch functionalization work, we now need to push FunctionalTensorMode onto the DispatchKey.PreDispatch mode stack and make sure it runs before ProxyTorchDispatchMode (this is very similar to how post-dispatch tracing works).

Here are some design decisions we made for this flow to work:

1. FunctionalTensorMode internally calls the C++ Functionalize key. Since C++ functionalization runs after PreDispatch, if we are not careful we will keep re-entering the PreDispatch key. We solve this by dispatching directly to the C++ Functionalize key.
2. We delete the mode_stack_per_key logic because the only realistic case it is exercised for is PreDispatch, and a plain list is in general not safe here: the ordering of FunctionalTensorMode and ProxyTorchDispatchMode matters and is hard to enforce on a plain list. Instead, we now have a private class that tracks the PreDispatch mode stack (see the sketch below).
3. We still run CompositeImplicitAutograd decompositions in this PR; disabling them is left as a follow-up.

Some missing bits after this PR:

1. Preserving autograd ops in a functional form. Right now they still show up in the graph, but in a "non-functional" way.
2. Turning off CompositeImplicitAutograd decompositions.
3. Functionalizing higher-order ops (HOO).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113728
Approved by: https://github.com/bdhirsh
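Below is a minimal sketch of what a private class tracking the PreDispatch mode stack could look like. All names here are hypothetical illustrations, not the code from this PR; the point is that dedicated slots, unlike a plain list, make the FunctionalTensorMode-before-ProxyTorchDispatchMode ordering structurally impossible to violate.

```python
from typing import Optional


class _PreDispatchModeStackSketch:
    """Hypothetical stand-in for the private PreDispatch mode stack.

    Holds at most one functional mode and one proxy mode in fixed
    slots, so the functional mode always sits below the proxy mode
    and runs first, mirroring post-dispatch tracing.
    """

    def __init__(self) -> None:
        self._functional_mode: Optional[object] = None  # FunctionalTensorMode slot
        self._proxy_mode: Optional[object] = None       # ProxyTorchDispatchMode slot

    def set_functional_mode(self, mode: object) -> None:
        assert self._functional_mode is None, "functional mode already set"
        self._functional_mode = mode

    def set_proxy_mode(self, mode: object) -> None:
        assert self._proxy_mode is None, "proxy mode already set"
        self._proxy_mode = mode

    def pop(self) -> object:
        # LIFO: the proxy mode is "on top", so it is removed first.
        if self._proxy_mode is not None:
            mode, self._proxy_mode = self._proxy_mode, None
            return mode
        if self._functional_mode is not None:
            mode, self._functional_mode = self._functional_mode, None
            return mode
        raise AssertionError("PreDispatch mode stack is empty")

    def count(self) -> int:
        return int(self._functional_mode is not None) + int(self._proxy_mode is not None)
```

With a plain list, nothing stops a caller from pushing the two modes in the wrong order; with fixed slots, the relative ordering is decided by the class rather than by its callers.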
Committed by: PyTorch MergeBot
Parent: 1474eb5f29
Commit: d85314c95c
@@ -87,6 +87,7 @@ class DispatchKey(Enum):
     NestedTensor = auto()
     Dense = auto()

+    PreDispatch = auto()
     Python = auto()
     FuncTorchDynamicLayerBackMode = auto()
     ZeroTensor = auto()
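For illustration, here is a sketch of the enum after this hunk is applied, limited to the members visible above (the full DispatchKey enum in PyTorch has many more members, and the file name is not shown in this hunk). It only demonstrates where the new PreDispatch member sits relative to Python:

```python
from enum import Enum, auto


class DispatchKey(Enum):
    # Only the members visible in the hunk above are reproduced here.
    NestedTensor = auto()
    Dense = auto()

    PreDispatch = auto()  # added by this commit
    Python = auto()
    FuncTorchDynamicLayerBackMode = auto()
    ZeroTensor = auto()


# PreDispatch is enumerated immediately before Python, matching the design
# in which PreDispatch modes are handled before the Python dispatch key.
assert DispatchKey.PreDispatch.value < DispatchKey.Python.value
```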