Back out "Switch to predispatch" (#124860)

Summary:
Original commit changeset: 1f155b3a0bfc

Original Phabricator Diff: D56273267

Test Plan: CI

Differential Revision: D56526505

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124860
Approved by: https://github.com/angelayi
This commit is contained in:
Tugsbayasgalan (Tugsuu) Manlaibaatar
2024-04-24 17:28:31 +00:00
committed by PyTorch MergeBot
parent 9888d7495e
commit 674e15ae07
14 changed files with 226 additions and 74 deletions

View File

@ -104,25 +104,24 @@ def _get_current_dispatch_mode():
return None
def _detect_infra_mode(key):
assert key in [torch._C._TorchDispatchModeKey.FUNCTIONAL, torch._C._TorchDispatchModeKey.PROXY]
def _detect_functional_mode():
from torch._ops import _get_dispatch_mode_pre_dispatch
pre_dispatch_mode = _get_dispatch_mode_pre_dispatch(
key
pre_dispatch_functional_mode = _get_dispatch_mode_pre_dispatch(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
post_dispatch_mode = torch._C._get_dispatch_mode(
key
post_dispatch_functional_mode = torch._C._get_dispatch_mode(
torch._C._TorchDispatchModeKey.FUNCTIONAL
)
assert (pre_dispatch_mode is None) or (
post_dispatch_mode is None
assert (pre_dispatch_functional_mode is None) or (
post_dispatch_functional_mode is None
)
if pre_dispatch_mode is None:
return post_dispatch_mode
if pre_dispatch_functional_mode is None:
return post_dispatch_functional_mode
return pre_dispatch_mode
return pre_dispatch_functional_mode
def _unset_infra_mode(key):