[BE] Reroute all uses of proxy_tensor.maybe_disable_fake_tensor_mode to fake_tensor.unset_fake_temporarily (#132770)

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132770
Approved by: https://github.com/bdhirsh
Author: Edward Z. Yang
Date: 2024-08-08 12:19:42 -07:00
Committed by: PyTorch MergeBot
Parent: f25df31008
Commit: 1f66487c69
13 changed files with 27 additions and 39 deletions
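
At a glance, the change is mechanical: every caller of the deprecated proxy_tensor shim now imports the underlying helper from torch._subclasses.fake_tensor, and the shim itself is deleted. A minimal before/after sketch of the pattern (the variable name fake_param is hypothetical, not from this commit; the two context managers behave identically, since the old name was a thin wrapper):

# Before: shim re-exported from torch.fx.experimental.proxy_tensor
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

with maybe_disable_fake_tensor_mode():
    real_data = fake_param.data  # fake mode is unset here, so this is real data

# After: the underlying API, same behavior
from torch._subclasses.fake_tensor import unset_fake_temporarily

with unset_fake_temporarily():
    real_data = fake_param.data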

View File

@@ -850,7 +850,6 @@ coverage_ignore_functions = [
     "get_torch_dispatch_modes",
     "has_proxy_slot",
     "is_sym_node",
-    "maybe_disable_fake_tensor_mode",
     "maybe_handle_decomp",
     "proxy_call",
     "set_meta",

View File

@@ -85,8 +85,8 @@ Other useful stuff:
 .. code:: python
 
-    from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
-    with maybe_disable_fake_tensor_mode():
+    from torch._subclasses.fake_tensor import unset_fake_temporarily
+    with unset_fake_temporarily():
         # fake mode is disabled here, you can do real tensor compute
 
 When might you want to disable fake tensor mode? Usually you don't want to do this. One niche case where we've found it useful is to implement constant propagation on fake tensors: in this case, we need to do some actual tensor computation even though we're in a fake tensor mode.
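
As a concrete illustration of that niche case, here is a small sketch (not part of this commit): inside an active FakeTensorMode, temporarily unsetting the mode lets you compute a real constant, e.g. for constant propagation.

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, unset_fake_temporarily

with FakeTensorMode():
    fake = torch.empty(2, 3)  # a FakeTensor: carries shape/dtype, no real storage
    with unset_fake_temporarily():
        # real compute: propagate a small constant while tracing with fake tensors
        const = torch.arange(3) * 2

assert const.sum().item() == 6  # a real value, usable as a burned-in constant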

View File

@@ -1926,7 +1926,6 @@
     "get_torch_dispatch_modes",
     "has_proxy_slot",
     "is_sym_node",
-    "maybe_disable_fake_tensor_mode",
     "maybe_handle_decomp",
     "proxy_call",
     "set_meta",

View File

@@ -53,10 +53,11 @@ from torch._C._dynamo.eval_frame import (  # noqa: F401
     unsupported,
 )
 from torch._dispatch.python import enable_python_dispatcher
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch._utils_internal import justknobs_check, log_export_usage
 from torch.export.dynamic_shapes import _process_dynamic_shapes
 from torch.fx import GraphModule
-from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     DimDynamic,
@@ -1534,9 +1535,7 @@ def export(
             with torch.fx.traceback.preserve_node_meta():
                 return torch.fx.Interpreter(graph).run(*args)
 
-        with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), (
-            fake_mode
-        ):
+        with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
             try:
                 graph = make_fx(
                     graph_with_interpreter,

View File

@@ -40,7 +40,8 @@ from torch.export.graph_signature import (
 )
 from torch.fx import traceback as fx_traceback
 from torch.fx._compatibility import compatibility
-from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
+from torch.fx.experimental.proxy_tensor import make_fx
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
 
 from .wrappers import _wrap_submodules

View File

@@ -7,6 +7,7 @@ import torch
 from torch._export.verifier import SpecViolationError
 from torch._guards import detect_fake_mode
 from torch._library.fake_class_registry import FakeScriptObject
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.export.exported_program import (
     ArgumentSpec,
     CustomObjArgument,
@@ -191,7 +192,7 @@ def lift_constants_pass(
                     f"it's not registered with register_parameter(). export will treat it as a constant tensor"
                 )
             # We get the real data out of the parameter by disabling the surrounding fake mode.
-            with torch.fx.experimental.proxy_tensor.maybe_disable_fake_tensor_mode():
+            with unset_fake_temporarily():
                 constant_val = constant_val.data
         constant_kind = InputKind.CONSTANT_TENSOR
         constant_fqn = _get_first_fqn(constant_attrs, constant_val)

View File

@@ -80,7 +80,8 @@ from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import counters
 from torch._inductor.config import trace as trace_config
 from torch._prims_common import is_integer_dtype
-from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
+from torch._subclasses.fake_tensor import unset_fake_temporarily
+from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver
@@ -1928,9 +1929,7 @@ def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
     def lazy_init() -> Any:
         counters_ref = counters["inductor"].copy()
 
-        with torch._guards.tracing(
-            None
-        ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
+        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
             result = fn()
 
         # clear view matches encountered during tracing

View File

@@ -41,11 +41,11 @@ if TYPE_CHECKING:
 import torch
 import torch.utils._pytree as pytree
 from torch._export.verifier import Verifier
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.export._tree_utils import is_equivalent, reorder_kwargs
 from torch.fx._compatibility import compatibility
 from torch.fx._utils import first_call_function_nn_module_stack
-from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.infra.pass_manager import PassManager
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
@@ -92,7 +92,7 @@ class ModuleCallEntry:
 def _disable_prexisiting_fake_mode(fn):
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        with maybe_disable_fake_tensor_mode():
+        with unset_fake_temporarily():
             return fn(*args, **kwargs)
 
     return wrapper
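
For context, the decorator rewritten above is applied to functions that must see real tensors; a hypothetical usage sketch, not taken from this commit (the function name is illustrative):

@_disable_prexisiting_fake_mode
def _read_real_parameters(module):
    # Any ambient FakeTensorMode is unset inside the wrapper, so .data is real data.
    return {name: p.data for name, p in module.named_parameters()}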

View File

@@ -17,7 +17,7 @@ import typing_extensions
 import warnings
 import weakref
 from collections import defaultdict
-from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
+from contextlib import contextmanager, ExitStack, nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,
@@ -660,12 +660,6 @@ def track_tensor_tree(
     return inner_res
 
-def maybe_disable_fake_tensor_mode() -> AbstractContextManager:
-    # TODO: figure out if this API generally makes sense and bake it into the
-    # library
-    return unset_fake_temporarily()
-
 @dataclass
 class _ProxyTensor:
     proxy: Proxy
@@ -824,7 +818,7 @@ def proxy_call(
             const_args, const_kwargs = pytree.tree_unflatten(
                 const_flat_args_kwargs, spec
             )
-            with maybe_disable_fake_tensor_mode():
+            with unset_fake_temporarily():
                 return func(*const_args, **const_kwargs)
 
     # If any of the Tensor inputs are "real" (not FakeTensor), we may
     # incorrectly burn in constants by allowing this access. Raise
@@ -951,7 +945,7 @@ def proxy_call(
         func is torch.ops.aten.lift_fresh_copy.default
         and out.numel() <= CONSTANT_NUMEL_LIMIT
     ):
-        with maybe_disable_fake_tensor_mode():
+        with unset_fake_temporarily():
             assert isinstance(args[0], (Proxy, Tensor)), type(args[0])
             constant = args[0].clone()
     elif (
@@ -961,7 +955,7 @@ def proxy_call(
         and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out)
     ):
         # NB: do NOT include factories as constants
-        with maybe_disable_fake_tensor_mode():
+        with unset_fake_temporarily():
             const_flat_args_kwargs = [
                 t.constant if isinstance(t, _ProxyTensor) else t
                 for t in f_flat_args_kwargs

View File

@@ -12,7 +12,7 @@ from typing import Any, Callable, TYPE_CHECKING
 import torch
 import torch.fx
-from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
+from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
@@ -220,7 +220,7 @@ class Transform(abc.ABC):
         Scan through all nodes in graph and their meta['val'] to detect fake mode.
         """
         fake_tensors = [node.meta.get("val") for node in self.module.graph.nodes]
-        with maybe_disable_fake_tensor_mode():
+        with unset_fake_temporarily():
             return torch._dynamo.utils.detect_fake_mode(fake_tensors)
 
     def _maybe_fakefy_args(

View File

@@ -7,6 +7,7 @@ from typing import Callable, Mapping, TYPE_CHECKING
 import torch
 import torch._ops
 from torch._dispatch import python as python_dispatch
+from torch._subclasses import fake_tensor
 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal.fx import _pass, diagnostics
 from torch.onnx._internal.fx.passes import _utils
@@ -14,7 +15,6 @@ from torch.onnx._internal.fx.passes import _utils
 if TYPE_CHECKING:
     import torch.fx
-    from torch._subclasses import fake_tensor
 
 class Decompose(_pass.Transform):
@@ -66,7 +66,7 @@ class Decompose(_pass.Transform):
         # Apply decomposition table to the input graph.
         assert fake_mode is not None  # for mypy
-        with proxy_tensor.maybe_disable_fake_tensor_mode(), python_dispatch.enable_python_dispatcher(), (
+        with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), (
             fake_mode
         ):
             decomposed_module = proxy_tensor.make_fx(

View File

@@ -2,22 +2,19 @@
 from __future__ import annotations
 
 import contextlib
-from typing import Callable, TYPE_CHECKING
+from typing import Callable
 
 import torch
 import torch._ops
 import torch.func
 import torch.fx
+from torch._subclasses import fake_tensor
 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal.fx import _pass, diagnostics
 from torch.onnx._internal.fx.passes import _utils
 from torch.utils import _pytree as pytree
 
-if TYPE_CHECKING:
-    from torch._subclasses import fake_tensor
 
 class Functionalize(_pass.Transform):
     """Functionalize a GraphModule.
@@ -119,7 +116,7 @@ class Functionalize(_pass.Transform):
         tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"
 
         assert fake_mode is not None  # for mypy
-        with proxy_tensor.maybe_disable_fake_tensor_mode(), fake_mode:
+        with fake_tensor.unset_fake_temporarily(), fake_mode:
             graph_module = proxy_tensor.make_fx(
                 functionalized_callable,
                 decomposition_table={},

View File

@@ -19,6 +19,7 @@ from torch._prims_common import (
 )
 from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
 from torch._refs.nn import functional as _functional_refs
+from torch._subclasses import fake_tensor
 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
 from torch.utils import _python_dispatch, _pytree
@@ -27,8 +28,6 @@ from torch.utils import _python_dispatch, _pytree
 if TYPE_CHECKING:
     from types import ModuleType
-    from torch._subclasses import fake_tensor
 
 logger = logging.getLogger(__name__)
@@ -1716,7 +1715,7 @@ class InsertTypePromotion(_pass.Transform):
         fake_mode = self.fake_mode
         assert fake_mode is not None, "Cannot detect fake_mode."
 
-        with proxy_tensor.maybe_disable_fake_tensor_mode(), (
+        with fake_tensor.unset_fake_temporarily(), (
             fake_mode
         ), fx_traceback.preserve_node_meta():
             self.interpreter.run(*fake_args)