Revert "[BE] Reroute all uses of proxy_tensor.maybe_disable_fake_tensor_mode to fake_tensor.unset_fake_temporarily (#132770)"

This reverts commit 902c6f3a191fb2ecb1976895b3e9eaae4b257b89.

Reverted https://github.com/pytorch/pytorch/pull/132770 on behalf of https://github.com/ezyang due to Removed API was recommitted ([comment](https://github.com/pytorch/pytorch/pull/132770#issuecomment-2275749689))
This commit is contained in:
PyTorch MergeBot
2024-08-08 12:54:34 +00:00
parent 902c6f3a19
commit d1f73fd844
12 changed files with 38 additions and 25 deletions

View File

@ -850,6 +850,7 @@ coverage_ignore_functions = [
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"maybe_disable_fake_tensor_mode",
"maybe_handle_decomp",
"proxy_call",
"set_meta",

View File

@ -85,8 +85,8 @@ Other useful stuff:
.. code:: python
from torch._subclasses.fake_tensor import unset_fake_temporarily
with unset_fake_temporarily():
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
with maybe_disable_fake_tensor_mode():
# fake mode is disabled here, you can do real tensor compute
When might you want to disable fake tensor mode? Usually you don't want to do this. One niche case where we've found it useful is to implement constant propagation on fake tensors: in this case, we need to do some actual tensor computation even though we're in a fake tensor mode.

View File

@ -1926,6 +1926,7 @@
"get_torch_dispatch_modes",
"has_proxy_slot",
"is_sym_node",
"maybe_disable_fake_tensor_mode",
"maybe_handle_decomp",
"proxy_call",
"set_meta",

View File

@ -53,11 +53,10 @@ from torch._C._dynamo.eval_frame import ( # noqa: F401
unsupported,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._utils_internal import justknobs_check, log_export_usage
from torch.export.dynamic_shapes import _process_dynamic_shapes
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
DimDynamic,
@ -1535,7 +1534,9 @@ def export(
with torch.fx.traceback.preserve_node_meta():
return torch.fx.Interpreter(graph).run(*args)
with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), (
fake_mode
):
try:
graph = make_fx(
graph_with_interpreter,

View File

@ -40,8 +40,7 @@ from torch.export.graph_signature import (
)
from torch.fx import traceback as fx_traceback
from torch.fx._compatibility import compatibility
from torch.fx.experimental.proxy_tensor import make_fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from .wrappers import _wrap_submodules

View File

@ -80,8 +80,7 @@ from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo.utils import counters
from torch._inductor.config import trace as trace_config
from torch._prims_common import is_integer_dtype
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
@ -1929,7 +1928,9 @@ def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
def lazy_init() -> Any:
counters_ref = counters["inductor"].copy()
with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
with torch._guards.tracing(
None
), maybe_disable_fake_tensor_mode(), FakeTensorMode():
result = fn()
# clear view matches encountered during tracing

View File

@ -41,11 +41,11 @@ if TYPE_CHECKING:
import torch
import torch.utils._pytree as pytree
from torch._export.verifier import Verifier
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch._subclasses.functional_tensor import FunctionalTensor
from torch.export._tree_utils import is_equivalent, reorder_kwargs
from torch.fx._compatibility import compatibility
from torch.fx._utils import first_call_function_nn_module_stack
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.fx.passes.infra.pass_base import PassResult
from torch.fx.passes.infra.pass_manager import PassManager
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
@ -92,7 +92,7 @@ class ModuleCallEntry:
def _disable_prexisiting_fake_mode(fn):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
with unset_fake_temporarily():
with maybe_disable_fake_tensor_mode():
return fn(*args, **kwargs)
return wrapper

View File

@ -17,7 +17,7 @@ import typing_extensions
import warnings
import weakref
from collections import defaultdict
from contextlib import contextmanager, ExitStack, nullcontext
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass
from typing import (
Any,
@ -660,6 +660,12 @@ def track_tensor_tree(
return inner_res
def maybe_disable_fake_tensor_mode() -> AbstractContextManager:
# TODO: figure out if this API generally makes sense and bake it into the
# library
return unset_fake_temporarily()
@dataclass
class _ProxyTensor:
proxy: Proxy
@ -818,7 +824,7 @@ def proxy_call(
const_args, const_kwargs = pytree.tree_unflatten(
const_flat_args_kwargs, spec
)
with unset_fake_temporarily():
with maybe_disable_fake_tensor_mode():
return func(*const_args, **const_kwargs)
# If any of the Tensor inputs are "real" (not FakeTensor), we may
# incorrectly burn in constants by allowing this access. Raise
@ -945,7 +951,7 @@ def proxy_call(
func is torch.ops.aten.lift_fresh_copy.default
and out.numel() <= CONSTANT_NUMEL_LIMIT
):
with unset_fake_temporarily():
with maybe_disable_fake_tensor_mode():
assert isinstance(args[0], (Proxy, Tensor)), type(args[0])
constant = args[0].clone()
elif (
@ -955,7 +961,7 @@ def proxy_call(
and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out)
):
# NB: do NOT include factories as constants
with unset_fake_temporarily():
with maybe_disable_fake_tensor_mode():
const_flat_args_kwargs = [
t.constant if isinstance(t, _ProxyTensor) else t
for t in f_flat_args_kwargs

View File

@ -12,7 +12,7 @@ from typing import Any, Callable, TYPE_CHECKING
import torch
import torch.fx
from torch._subclasses.fake_tensor import unset_fake_temporarily
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher
@ -220,7 +220,7 @@ class Transform(abc.ABC):
Scan through all nodes in graph and their meta['val'] to detect fake mode.
"""
fake_tensors = [node.meta.get("val") for node in self.module.graph.nodes]
with unset_fake_temporarily():
with maybe_disable_fake_tensor_mode():
return torch._dynamo.utils.detect_fake_mode(fake_tensors)
def _maybe_fakefy_args(

View File

@ -7,7 +7,6 @@ from typing import Callable, Mapping, TYPE_CHECKING
import torch
import torch._ops
from torch._dispatch import python as python_dispatch
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
@ -15,6 +14,7 @@ from torch.onnx._internal.fx.passes import _utils
if TYPE_CHECKING:
import torch.fx
from torch._subclasses import fake_tensor
class Decompose(_pass.Transform):
@ -66,7 +66,7 @@ class Decompose(_pass.Transform):
# Apply decomposition table to the input graph.
assert fake_mode is not None # for mypy
with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), (
with proxy_tensor.maybe_disable_fake_tensor_mode(), python_dispatch.enable_python_dispatcher(), (
fake_mode
):
decomposed_module = proxy_tensor.make_fx(

View File

@ -2,19 +2,22 @@
from __future__ import annotations
import contextlib
from typing import Callable
from typing import Callable, TYPE_CHECKING
import torch
import torch._ops
import torch.func
import torch.fx
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics
from torch.onnx._internal.fx.passes import _utils
from torch.utils import _pytree as pytree
if TYPE_CHECKING:
from torch._subclasses import fake_tensor
class Functionalize(_pass.Transform):
"""Functionalize a GraphModule.
@ -116,7 +119,7 @@ class Functionalize(_pass.Transform):
tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"
assert fake_mode is not None # for mypy
with fake_tensor.unset_fake_temporarily(), fake_mode:
with proxy_tensor.maybe_disable_fake_tensor_mode(), fake_mode:
graph_module = proxy_tensor.make_fx(
functionalized_callable,
decomposition_table={},

View File

@ -19,7 +19,6 @@ from torch._prims_common import (
)
from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
from torch._refs.nn import functional as _functional_refs
from torch._subclasses import fake_tensor
from torch.fx.experimental import proxy_tensor
from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
from torch.utils import _python_dispatch, _pytree
@ -28,6 +27,8 @@ from torch.utils import _python_dispatch, _pytree
if TYPE_CHECKING:
from types import ModuleType
from torch._subclasses import fake_tensor
logger = logging.getLogger(__name__)
@ -1715,7 +1716,7 @@ class InsertTypePromotion(_pass.Transform):
fake_mode = self.fake_mode
assert fake_mode is not None, "Cannot detect fake_mode."
with fake_tensor.unset_fake_temporarily(), (
with proxy_tensor.maybe_disable_fake_tensor_mode(), (
fake_mode
), fx_traceback.preserve_node_meta():
self.interpreter.run(*fake_args)