mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Populate submodules of torch._C to sys.modules recursively (#132216)
See comment:
e9d1c26275/torch/__init__.py (L938-L950)
This PR recursively sets the submodules in the C extension to `sys.modules` (e.g., `_C._dynamo.eval_frame`).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132216
Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
7f71f2a997
commit
24dee99cb7
@ -2,12 +2,10 @@
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
from torch._C._dynamo.eval_frame import set_eval_frame
|
||||
from torch._guards import CompileId
|
||||
|
||||
|
||||
set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401
|
||||
|
||||
|
||||
def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
|
||||
local = 1
|
||||
return {
|
||||
|
||||
@ -9,6 +9,7 @@ import torch
|
||||
import torch.backends.cudnn as cudnn
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch._C._dynamo.guards import assert_size_stride
|
||||
from torch.testing import make_tensor
|
||||
from torch.testing._internal.common_cuda import tf32_is_not_fp32
|
||||
from torch.testing._internal.common_device_type import (
|
||||
@ -30,7 +31,6 @@ from torch.testing._internal.common_utils import (
|
||||
)
|
||||
|
||||
|
||||
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
|
||||
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
|
||||
if TEST_SCIPY:
|
||||
import scipy.ndimage
|
||||
|
||||
@ -65,6 +65,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
|
||||
from . import (
|
||||
_aoti,
|
||||
_cpu,
|
||||
_dynamo,
|
||||
_functorch,
|
||||
_lazy,
|
||||
_lazy_ts_backend,
|
||||
|
||||
@ -943,15 +943,21 @@ if not TYPE_CHECKING:
|
||||
# non-standard, and attributes of those submodules cannot be pickled since
|
||||
# pickle expect to be able to import them as "from _C.sub import attr"
|
||||
# which fails with "_C is not a package
|
||||
__name, __candidate = "", None
|
||||
for __name in dir(_C):
|
||||
__candidate = getattr(_C, __name)
|
||||
if inspect.ismodule(__candidate):
|
||||
# submodule
|
||||
sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate)
|
||||
|
||||
del __name, __candidate
|
||||
def _import_extension_to_sys_modules(module, module_name, memo=None):
|
||||
if memo is None:
|
||||
memo = set()
|
||||
if module in memo:
|
||||
return
|
||||
memo.add(module)
|
||||
for name in dir(module):
|
||||
member = getattr(module, name)
|
||||
if inspect.ismodule(member):
|
||||
sys.modules.setdefault(f"{module_name}.{name}", member)
|
||||
# Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
|
||||
_import_extension_to_sys_modules(member, f"{module_name}.{name}", memo)
|
||||
|
||||
_import_extension_to_sys_modules(_C, f"{__name__}._C")
|
||||
del _import_extension_to_sys_modules
|
||||
|
||||
################################################################################
|
||||
# Define basic utilities
|
||||
@ -2107,11 +2113,14 @@ def compiled_with_cxx11_abi() -> builtins.bool:
|
||||
|
||||
|
||||
from torch import _library as _library, _ops as _ops
|
||||
from torch._classes import classes as classes
|
||||
|
||||
|
||||
# Import the ops "namespace"
|
||||
# Import the ops and classes "namespace"
|
||||
from torch._ops import ops as ops # usort: skip
|
||||
from torch._classes import classes as classes # usort: skip
|
||||
|
||||
sys.modules.setdefault(f"{__name__}.ops", ops)
|
||||
sys.modules.setdefault(f"{__name__}.classes", classes)
|
||||
|
||||
# quantization depends on torch.fx and torch.ops
|
||||
# Import quantization
|
||||
@ -2519,7 +2528,12 @@ if TYPE_CHECKING:
|
||||
# Import the following modules during type checking to enable code intelligence features,
|
||||
# such as auto-completion in tools like pylance, even when these modules are not explicitly
|
||||
# imported in user code.
|
||||
from torch import _dynamo as _dynamo, _inductor as _inductor, onnx as onnx
|
||||
from torch import (
|
||||
_dynamo as _dynamo,
|
||||
_inductor as _inductor,
|
||||
_subclasses as _subclasses,
|
||||
onnx as onnx,
|
||||
)
|
||||
|
||||
else:
|
||||
_lazy_modules = {
|
||||
|
||||
@ -26,6 +26,7 @@ from weakref import ReferenceType
|
||||
import torch
|
||||
import torch._logging
|
||||
import torch.fx.experimental._sym_dispatch_mode
|
||||
from torch._C._dynamo.guards import GlobalStateGuard
|
||||
from torch._dynamo.distributed import get_compile_pg
|
||||
from torch._guards import compile_context, CompileContext, CompileId, tracing
|
||||
from torch._logging import structured
|
||||
@ -125,7 +126,7 @@ if typing.TYPE_CHECKING:
|
||||
log = logging.getLogger(__name__)
|
||||
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
|
||||
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
||||
GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
|
||||
|
||||
|
||||
compile_lock = threading.RLock()
|
||||
|
||||
|
||||
@ -44,6 +44,14 @@ import torch.fx
|
||||
import torch.utils._pytree as pytree
|
||||
import torch.utils.checkpoint
|
||||
from torch import _guards
|
||||
|
||||
# see discussion at https://github.com/pytorch/pytorch/issues/120699
|
||||
from torch._C._dynamo.eval_frame import ( # noqa: F401
|
||||
reset_code,
|
||||
set_guard_error_hook,
|
||||
skip_code,
|
||||
unsupported,
|
||||
)
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._utils_internal import justknobs_check, log_export_usage
|
||||
from torch.export.dynamic_shapes import _process_dynamic_shapes
|
||||
@ -72,14 +80,6 @@ if TYPE_CHECKING:
|
||||
from .types import CacheEntry, DynamoCallback
|
||||
|
||||
|
||||
# see discussion at https://github.com/pytorch/pytorch/issues/120699
|
||||
reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401
|
||||
|
||||
set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa: F401
|
||||
skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401
|
||||
unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@ -100,7 +100,8 @@ unset = Unset.token
|
||||
def _maybe_set_eval_frame(callback: DynamoCallback):
|
||||
# A wrapper on set_eval_frame that is guarded by a Justknob.
|
||||
# Users can disable torchDynamo by setting the JK to False.
|
||||
set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401
|
||||
from torch._C._dynamo.eval_frame import set_eval_frame
|
||||
|
||||
if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
|
||||
log.warning(
|
||||
"Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"
|
||||
|
||||
@ -37,6 +37,16 @@ from weakref import ReferenceType
|
||||
|
||||
import torch
|
||||
import torch.utils._device
|
||||
from torch._C._dynamo.guards import (
|
||||
check_obj_id,
|
||||
check_type_id,
|
||||
dict_version,
|
||||
DictGuardManager,
|
||||
install_no_tensor_aliasing_guard,
|
||||
install_object_aliasing_guard,
|
||||
RootGuardManager,
|
||||
TensorGuards,
|
||||
)
|
||||
from torch._dynamo.source import (
|
||||
is_from_flatten_script_object_source,
|
||||
is_from_local_source,
|
||||
@ -128,18 +138,6 @@ recompiles_verbose_log = torch._logging.getArtifactLogger(
|
||||
)
|
||||
verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
|
||||
|
||||
TensorGuards = torch._C._dynamo.guards.TensorGuards
|
||||
check_obj_id = torch._C._dynamo.guards.check_obj_id
|
||||
check_type_id = torch._C._dynamo.guards.check_type_id
|
||||
dict_version = torch._C._dynamo.guards.dict_version
|
||||
|
||||
RootGuardManager = torch._C._dynamo.guards.RootGuardManager
|
||||
DictGuardManager = torch._C._dynamo.guards.DictGuardManager
|
||||
install_object_aliasing_guard = torch._C._dynamo.guards.install_object_aliasing_guard
|
||||
install_no_tensor_aliasing_guard = (
|
||||
torch._C._dynamo.guards.install_no_tensor_aliasing_guard
|
||||
)
|
||||
|
||||
|
||||
class GuardManager:
|
||||
"""
|
||||
|
||||
@ -2,26 +2,21 @@ import dataclasses
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import torch
|
||||
# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
|
||||
from torch._C._dynamo.eval_frame import (
|
||||
_CacheEntry as CacheEntry,
|
||||
_ExtraState as ExtraState,
|
||||
)
|
||||
from torch._guards import CompileId
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from torch._C._dynamo import eval_frame
|
||||
|
||||
DynamoFrameType: TypeAlias = eval_frame._PyInterpreterFrame
|
||||
from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType
|
||||
else:
|
||||
DynamoFrameType: TypeAlias = types.FrameType
|
||||
from types import FrameType as DynamoFrameType
|
||||
|
||||
|
||||
# This class has a `check_fn` field for the guard,
|
||||
# and a `code` field for the code object.
|
||||
CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
|
||||
|
||||
ExtraState = torch._C._dynamo.eval_frame._ExtraState
|
||||
|
||||
# We use a dict to store additional data per frame.
|
||||
FrameState = Dict[Any, Any]
|
||||
|
||||
|
||||
@ -2211,7 +2211,6 @@ def nn_module_get_all_hooks(
|
||||
check_backward_hooks=False,
|
||||
check_state_dict_hooks=False,
|
||||
):
|
||||
reset_code = torch._C._dynamo.eval_frame.reset_code
|
||||
"""
|
||||
Sometimes its useful to differentiate between types of hooks such as forward/backward/pre
|
||||
hooks executed during module.__call__, and state_dict hooks which are executed separately.
|
||||
|
||||
Reference in New Issue
Block a user