Populate submodules of torch._C to sys.modules recursively (#132216)

See comment:

e9d1c26275/torch/__init__.py (L938-L950)

This PR recursively sets the submodules in the C extension to `sys.modules` (e.g., `_C._dynamo.eval_frame`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132216
Approved by: https://github.com/ezyang
This commit is contained in:
Xuehai Pan
2024-08-08 15:10:39 +08:00
committed by PyTorch MergeBot
parent 7f71f2a997
commit 24dee99cb7
9 changed files with 57 additions and 50 deletions

View File

@ -2,12 +2,10 @@
import torch
import torch._dynamo.test_case
from torch._C._dynamo.eval_frame import set_eval_frame
from torch._guards import CompileId
set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401
def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
local = 1
return {

View File

@ -9,6 +9,7 @@ import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
from torch._C._dynamo.guards import assert_size_stride
from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_device_type import (
@ -30,7 +31,6 @@ from torch.testing._internal.common_utils import (
)
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
if TEST_SCIPY:
import scipy.ndimage

View File

@ -65,6 +65,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
from . import (
_aoti,
_cpu,
_dynamo,
_functorch,
_lazy,
_lazy_ts_backend,

View File

@ -943,15 +943,21 @@ if not TYPE_CHECKING:
# non-standard, and attributes of those submodules cannot be pickled since
# pickle expect to be able to import them as "from _C.sub import attr"
# which fails with "_C is not a package
__name, __candidate = "", None
for __name in dir(_C):
__candidate = getattr(_C, __name)
if inspect.ismodule(__candidate):
# submodule
sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate)
del __name, __candidate
def _import_extension_to_sys_modules(module, module_name, memo=None):
if memo is None:
memo = set()
if module in memo:
return
memo.add(module)
for name in dir(module):
member = getattr(module, name)
if inspect.ismodule(member):
sys.modules.setdefault(f"{module_name}.{name}", member)
# Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
_import_extension_to_sys_modules(member, f"{module_name}.{name}", memo)
_import_extension_to_sys_modules(_C, f"{__name__}._C")
del _import_extension_to_sys_modules
################################################################################
# Define basic utilities
@ -2107,11 +2113,14 @@ def compiled_with_cxx11_abi() -> builtins.bool:
from torch import _library as _library, _ops as _ops
from torch._classes import classes as classes
# Import the ops "namespace"
# Import the ops and classes "namespace"
from torch._ops import ops as ops # usort: skip
from torch._classes import classes as classes # usort: skip
sys.modules.setdefault(f"{__name__}.ops", ops)
sys.modules.setdefault(f"{__name__}.classes", classes)
# quantization depends on torch.fx and torch.ops
# Import quantization
@ -2519,7 +2528,12 @@ if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code.
from torch import _dynamo as _dynamo, _inductor as _inductor, onnx as onnx
from torch import (
_dynamo as _dynamo,
_inductor as _inductor,
_subclasses as _subclasses,
onnx as onnx,
)
else:
_lazy_modules = {

View File

@ -26,6 +26,7 @@ from weakref import ReferenceType
import torch
import torch._logging
import torch.fx.experimental._sym_dispatch_mode
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
@ -125,7 +126,7 @@ if typing.TYPE_CHECKING:
log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
compile_lock = threading.RLock()

View File

@ -44,6 +44,14 @@ import torch.fx
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch import _guards
# see discussion at https://github.com/pytorch/pytorch/issues/120699
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_guard_error_hook,
skip_code,
unsupported,
)
from torch._dispatch.python import enable_python_dispatcher
from torch._utils_internal import justknobs_check, log_export_usage
from torch.export.dynamic_shapes import _process_dynamic_shapes
@ -72,14 +80,6 @@ if TYPE_CHECKING:
from .types import CacheEntry, DynamoCallback
# see discussion at https://github.com/pytorch/pytorch/issues/120699
reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401
set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa: F401
skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401
unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401
log = logging.getLogger(__name__)
@ -100,7 +100,8 @@ unset = Unset.token
def _maybe_set_eval_frame(callback: DynamoCallback):
# A wrapper on set_eval_frame that is guarded by a Justknob.
# Users can disable torchDynamo by setting the JK to False.
set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401
from torch._C._dynamo.eval_frame import set_eval_frame
if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
log.warning(
"Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"

View File

@ -37,6 +37,16 @@ from weakref import ReferenceType
import torch
import torch.utils._device
from torch._C._dynamo.guards import (
check_obj_id,
check_type_id,
dict_version,
DictGuardManager,
install_no_tensor_aliasing_guard,
install_object_aliasing_guard,
RootGuardManager,
TensorGuards,
)
from torch._dynamo.source import (
is_from_flatten_script_object_source,
is_from_local_source,
@ -128,18 +138,6 @@ recompiles_verbose_log = torch._logging.getArtifactLogger(
)
verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
TensorGuards = torch._C._dynamo.guards.TensorGuards
check_obj_id = torch._C._dynamo.guards.check_obj_id
check_type_id = torch._C._dynamo.guards.check_type_id
dict_version = torch._C._dynamo.guards.dict_version
RootGuardManager = torch._C._dynamo.guards.RootGuardManager
DictGuardManager = torch._C._dynamo.guards.DictGuardManager
install_object_aliasing_guard = torch._C._dynamo.guards.install_object_aliasing_guard
install_no_tensor_aliasing_guard = (
torch._C._dynamo.guards.install_no_tensor_aliasing_guard
)
class GuardManager:
"""

View File

@ -2,26 +2,21 @@ import dataclasses
import sys
import types
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
from typing_extensions import TypeAlias
import torch
# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
from torch._C._dynamo.eval_frame import (
_CacheEntry as CacheEntry,
_ExtraState as ExtraState,
)
from torch._guards import CompileId
if sys.version_info >= (3, 11):
from torch._C._dynamo import eval_frame
DynamoFrameType: TypeAlias = eval_frame._PyInterpreterFrame
from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType
else:
DynamoFrameType: TypeAlias = types.FrameType
from types import FrameType as DynamoFrameType
# This class has a `check_fn` field for the guard,
# and a `code` field for the code object.
CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
ExtraState = torch._C._dynamo.eval_frame._ExtraState
# We use a dict to store additional data per frame.
FrameState = Dict[Any, Any]

View File

@ -2211,7 +2211,6 @@ def nn_module_get_all_hooks(
check_backward_hooks=False,
check_state_dict_hooks=False,
):
reset_code = torch._C._dynamo.eval_frame.reset_code
"""
Sometimes its useful to differentiate between types of hooks such as forward/backward/pre
hooks executed during module.__call__, and state_dict hooks which are executed separately.