Fix some fake mode confusion between inner/outer fake mode in export (#106515)

Fixes https://github.com/pytorch/pytorch/issues/106412

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106515
Approved by: https://github.com/voznesenskym, https://github.com/BowenBao, https://github.com/thiagocrepaldi
This commit is contained in:
Edward Z. Yang
2023-08-04 05:45:48 -07:00
committed by PyTorch MergeBot
parent 5b13c779d4
commit 91afefb55b
14 changed files with 153 additions and 50 deletions

View File

@@ -6,6 +6,7 @@ from typing import Iterator
import torch
import torch._C
import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
@@ -132,7 +133,9 @@ def make_crossref_functionalize(op, final_key):
else:
return t
with suspend_functionalization():
# TODO: This probably does the wrong thing if you're running other
# substantive modes with the normal op outside here
with torch.utils._python_dispatch._disable_current_modes(), suspend_functionalization():
f_args, f_kwargs = pytree.tree_map(fakeify_defun, (args, kwargs))
orig_f_args, orig_f_kwargs = pytree.tree_map(
maybe_detach, (f_args, f_kwargs)