Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
Revert "[BE] Reroute all uses of proxy_tensor.maybe_disable_fake_tensor_mode to fake_tensor.unset_fake_temporarily (#132770)"
This reverts commit 902c6f3a191fb2ecb1976895b3e9eaae4b257b89. Reverted https://github.com/pytorch/pytorch/pull/132770 on behalf of https://github.com/ezyang due to Removed API was recommitted ([comment](https://github.com/pytorch/pytorch/pull/132770#issuecomment-2275749689))
@@ -850,6 +850,7 @@ coverage_ignore_functions = [
     "get_torch_dispatch_modes",
     "has_proxy_slot",
     "is_sym_node",
+    "maybe_disable_fake_tensor_mode",
     "maybe_handle_decomp",
     "proxy_call",
     "set_meta",

@@ -85,8 +85,8 @@ Other useful stuff:

 .. code:: python

-    from torch._subclasses.fake_tensor import unset_fake_temporarily
-    with unset_fake_temporarily():
+    from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
+    with maybe_disable_fake_tensor_mode():
         # fake mode is disabled here, you can do real tensor compute

 When might you want to disable fake tensor mode? Usually you don't want to do this. One niche case where we've found it useful is to implement constant propagation on fake tensors: in this case, we need to do some actual tensor computation even though we're in a fake tensor mode.

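To make the restored doc snippet concrete, here is a minimal sketch of the constant-propagation use case described above. It is not part of the diff; only `FakeTensorMode` and `maybe_disable_fake_tensor_mode` come from the code touched by this commit, and the specific tensors and shapes are illustrative.

```python
# Sketch only: real tensor compute inside an active FakeTensorMode, using the
# API this commit restores. The tensors below are made up for illustration.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

with FakeTensorMode():
    fake = torch.empty(2, 3)  # a FakeTensor: shape/dtype only, no real storage

    with maybe_disable_fake_tensor_mode():
        # Fake mode is disabled here, so this is real tensor compute,
        # e.g. materializing a small constant to propagate during tracing.
        const = torch.arange(6.0).reshape(2, 3)
        total = const.sum().item()

print(fake.shape, total)  # torch.Size([2, 3]) 15.0
```
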
@@ -1926,6 +1926,7 @@
     "get_torch_dispatch_modes",
     "has_proxy_slot",
     "is_sym_node",
+    "maybe_disable_fake_tensor_mode",
     "maybe_handle_decomp",
     "proxy_call",
     "set_meta",

@@ -53,11 +53,10 @@ from torch._C._dynamo.eval_frame import (  # noqa: F401
     unsupported,
 )
 from torch._dispatch.python import enable_python_dispatcher
-from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch._utils_internal import justknobs_check, log_export_usage
 from torch.export.dynamic_shapes import _process_dynamic_shapes
 from torch.fx import GraphModule
-from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
 from torch.fx.experimental.symbolic_shapes import (
     ConstraintViolationError,
     DimDynamic,

@@ -1535,7 +1534,9 @@ def export(
             with torch.fx.traceback.preserve_node_meta():
                 return torch.fx.Interpreter(graph).run(*args)

-        with unset_fake_temporarily(), enable_python_dispatcher(), fake_mode:
+        with maybe_disable_fake_tensor_mode(), enable_python_dispatcher(), (
+            fake_mode
+        ):
             try:
                 graph = make_fx(
                     graph_with_interpreter,

@@ -40,8 +40,7 @@ from torch.export.graph_signature import (
 )
 from torch.fx import traceback as fx_traceback
 from torch.fx._compatibility import compatibility
-from torch.fx.experimental.proxy_tensor import make_fx
-from torch._subclasses.fake_tensor import unset_fake_temporarily
+from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
 from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo

 from .wrappers import _wrap_submodules

@@ -80,8 +80,7 @@ from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.utils import counters
 from torch._inductor.config import trace as trace_config
 from torch._prims_common import is_integer_dtype
-from torch._subclasses.fake_tensor import unset_fake_temporarily
-from torch.fx.experimental.proxy_tensor import make_fx
+from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
 from torch.fx.experimental.symbolic_shapes import guard_size_oblivious
 from torch.fx.immutable_collections import immutable_dict, immutable_list
 from torch.fx.passes.graph_transform_observer import GraphTransformObserver

@@ -1929,7 +1928,9 @@ def init_once_fakemode(fn: Callable[..., Any]) -> Callable[[], Any]:
     def lazy_init() -> Any:
         counters_ref = counters["inductor"].copy()

-        with torch._guards.tracing(None), unset_fake_temporarily(), FakeTensorMode():
+        with torch._guards.tracing(
+            None
+        ), maybe_disable_fake_tensor_mode(), FakeTensorMode():
             result = fn()

         # clear view matches encountered during tracing

@@ -41,11 +41,11 @@ if TYPE_CHECKING:
 import torch
 import torch.utils._pytree as pytree
 from torch._export.verifier import Verifier
-from torch._subclasses.fake_tensor import unset_fake_temporarily
 from torch._subclasses.functional_tensor import FunctionalTensor
 from torch.export._tree_utils import is_equivalent, reorder_kwargs
 from torch.fx._compatibility import compatibility
 from torch.fx._utils import first_call_function_nn_module_stack
+from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from torch.fx.passes.infra.pass_base import PassResult
 from torch.fx.passes.infra.pass_manager import PassManager
 from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts

@@ -92,7 +92,7 @@ class ModuleCallEntry:
 def _disable_prexisiting_fake_mode(fn):
     @functools.wraps(fn)
     def wrapper(*args, **kwargs):
-        with unset_fake_temporarily():
+        with maybe_disable_fake_tensor_mode():
             return fn(*args, **kwargs)

     return wrapper

@@ -17,7 +17,7 @@ import typing_extensions
 import warnings
 import weakref
 from collections import defaultdict
-from contextlib import contextmanager, ExitStack, nullcontext
+from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
 from dataclasses import dataclass
 from typing import (
     Any,

@@ -660,6 +660,12 @@ def track_tensor_tree(
     return inner_res


+def maybe_disable_fake_tensor_mode() -> AbstractContextManager:
+    # TODO: figure out if this API generally makes sense and bake it into the
+    # library
+    return unset_fake_temporarily()
+
+
 @dataclass
 class _ProxyTensor:
     proxy: Proxy

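As a quick illustration of what this restored shim means for callers (a sketch, not part of the diff): the wrapper and `unset_fake_temporarily` are interchangeable, since the former simply returns the latter.

```python
# Sketch only: the restored wrapper delegates to unset_fake_temporarily, so
# both context managers disable an active fake tensor mode the same way.
import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode, unset_fake_temporarily
from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode

with FakeTensorMode():
    with maybe_disable_fake_tensor_mode():  # the shim defined above
        a = torch.ones(2)  # real tensor: fake mode is temporarily unset
    with unset_fake_temporarily():  # the underlying context manager
        b = torch.ones(2)

assert not isinstance(a, FakeTensor)
assert torch.equal(a, b)
```
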
@@ -818,7 +824,7 @@ def proxy_call(
             const_args, const_kwargs = pytree.tree_unflatten(
                 const_flat_args_kwargs, spec
             )
-            with unset_fake_temporarily():
+            with maybe_disable_fake_tensor_mode():
                 return func(*const_args, **const_kwargs)
         # If any of the Tensor inputs are "real" (not FakeTensor), we may
         # incorrectly burn in constants by allowing this access. Raise

@@ -945,7 +951,7 @@ def proxy_call(
         func is torch.ops.aten.lift_fresh_copy.default
         and out.numel() <= CONSTANT_NUMEL_LIMIT
     ):
-        with unset_fake_temporarily():
+        with maybe_disable_fake_tensor_mode():
             assert isinstance(args[0], (Proxy, Tensor)), type(args[0])
             constant = args[0].clone()
     elif (

@@ -955,7 +961,7 @@ def proxy_call(
         and pytree.tree_all_only(Tensor, tensor_numel_in_limit, out)
     ):
         # NB: do NOT include factories as constants
-        with unset_fake_temporarily():
+        with maybe_disable_fake_tensor_mode():
             const_flat_args_kwargs = [
                 t.constant if isinstance(t, _ProxyTensor) else t
                 for t in f_flat_args_kwargs

@@ -12,7 +12,7 @@ from typing import Any, Callable, TYPE_CHECKING

 import torch
 import torch.fx
-from torch._subclasses.fake_tensor import unset_fake_temporarily
+from torch.fx.experimental.proxy_tensor import maybe_disable_fake_tensor_mode
 from torch.onnx._internal.fx import diagnostics, onnxfunction_dispatcher


@@ -220,7 +220,7 @@ class Transform(abc.ABC):
         Scan through all nodes in graph and their meta['val'] to detect fake mode.
         """
         fake_tensors = [node.meta.get("val") for node in self.module.graph.nodes]
-        with unset_fake_temporarily():
+        with maybe_disable_fake_tensor_mode():
             return torch._dynamo.utils.detect_fake_mode(fake_tensors)

     def _maybe_fakefy_args(

@@ -7,7 +7,6 @@ from typing import Callable, Mapping, TYPE_CHECKING
 import torch
 import torch._ops
 from torch._dispatch import python as python_dispatch
-from torch._subclasses import fake_tensor
 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal.fx import _pass, diagnostics
 from torch.onnx._internal.fx.passes import _utils

@@ -15,6 +14,7 @@ from torch.onnx._internal.fx.passes import _utils

 if TYPE_CHECKING:
     import torch.fx
+    from torch._subclasses import fake_tensor


 class Decompose(_pass.Transform):

@@ -66,7 +66,7 @@ class Decompose(_pass.Transform):

         # Apply decomposition table to the input graph.
         assert fake_mode is not None  # for mypy
-        with fake_tensor.unset_fake_temporarily(), python_dispatch.enable_python_dispatcher(), (
+        with proxy_tensor.maybe_disable_fake_tensor_mode(), python_dispatch.enable_python_dispatcher(), (
             fake_mode
         ):
             decomposed_module = proxy_tensor.make_fx(

@@ -2,19 +2,22 @@
 from __future__ import annotations

 import contextlib
-from typing import Callable
+from typing import Callable, TYPE_CHECKING

 import torch
 import torch._ops
 import torch.func
 import torch.fx
-from torch._subclasses import fake_tensor
 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal.fx import _pass, diagnostics
 from torch.onnx._internal.fx.passes import _utils
 from torch.utils import _pytree as pytree


+if TYPE_CHECKING:
+    from torch._subclasses import fake_tensor
+
+
 class Functionalize(_pass.Transform):
     """Functionalize a GraphModule.


@@ -116,7 +119,7 @@ class Functionalize(_pass.Transform):
         tracing_mode = "symbolic" if self.enable_dynamic_axes else "fake"

         assert fake_mode is not None  # for mypy
-        with fake_tensor.unset_fake_temporarily(), fake_mode:
+        with proxy_tensor.maybe_disable_fake_tensor_mode(), fake_mode:
             graph_module = proxy_tensor.make_fx(
                 functionalized_callable,
                 decomposition_table={},

@@ -19,7 +19,6 @@ from torch._prims_common import (
 )
 from torch._refs import linalg as _linalg_refs, nn as _nn_refs, special as _special_refs
 from torch._refs.nn import functional as _functional_refs
-from torch._subclasses import fake_tensor
 from torch.fx.experimental import proxy_tensor
 from torch.onnx._internal.fx import _pass, diagnostics, type_utils as fx_type_utils
 from torch.utils import _python_dispatch, _pytree

@@ -28,6 +27,8 @@ from torch.utils import _python_dispatch, _pytree

 if TYPE_CHECKING:
     from types import ModuleType
+
+    from torch._subclasses import fake_tensor


 logger = logging.getLogger(__name__)

@@ -1715,7 +1716,7 @@ class InsertTypePromotion(_pass.Transform):
         fake_mode = self.fake_mode
         assert fake_mode is not None, "Cannot detect fake_mode."

-        with fake_tensor.unset_fake_temporarily(), (
+        with proxy_tensor.maybe_disable_fake_tensor_mode(), (
             fake_mode
         ), fx_traceback.preserve_node_meta():
             self.interpreter.run(*fake_args)