Populate submodules of torch._C to sys.modules recursively (#132216)

See comment:

e9d1c26275/torch/__init__.py (L938-L950)

This PR recursively sets the submodules in the C extension to `sys.modules` (e.g., `_C._dynamo.eval_frame`).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132216
Approved by: https://github.com/ezyang
This commit is contained in:
Xuehai Pan
2024-08-08 15:10:39 +08:00
committed by PyTorch MergeBot
parent 7f71f2a997
commit 24dee99cb7
9 changed files with 57 additions and 50 deletions

View File

@ -2,12 +2,10 @@
import torch import torch
import torch._dynamo.test_case import torch._dynamo.test_case
from torch._C._dynamo.eval_frame import set_eval_frame
from torch._guards import CompileId from torch._guards import CompileId
set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401
def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs): def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
local = 1 local = 1
return { return {

View File

@ -9,6 +9,7 @@ import torch
import torch.backends.cudnn as cudnn import torch.backends.cudnn as cudnn
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from torch._C._dynamo.guards import assert_size_stride
from torch.testing import make_tensor from torch.testing import make_tensor
from torch.testing._internal.common_cuda import tf32_is_not_fp32 from torch.testing._internal.common_cuda import tf32_is_not_fp32
from torch.testing._internal.common_device_type import ( from torch.testing._internal.common_device_type import (
@ -30,7 +31,6 @@ from torch.testing._internal.common_utils import (
) )
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32() AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
if TEST_SCIPY: if TEST_SCIPY:
import scipy.ndimage import scipy.ndimage

View File

@ -65,6 +65,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
from . import ( from . import (
_aoti, _aoti,
_cpu, _cpu,
_dynamo,
_functorch, _functorch,
_lazy, _lazy,
_lazy_ts_backend, _lazy_ts_backend,

View File

@ -943,15 +943,21 @@ if not TYPE_CHECKING:
# non-standard, and attributes of those submodules cannot be pickled since # non-standard, and attributes of those submodules cannot be pickled since
# pickle expect to be able to import them as "from _C.sub import attr" # pickle expect to be able to import them as "from _C.sub import attr"
# which fails with "_C is not a package # which fails with "_C is not a package
__name, __candidate = "", None def _import_extension_to_sys_modules(module, module_name, memo=None):
for __name in dir(_C): if memo is None:
__candidate = getattr(_C, __name) memo = set()
if inspect.ismodule(__candidate): if module in memo:
# submodule return
sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate) memo.add(module)
for name in dir(module):
del __name, __candidate member = getattr(module, name)
if inspect.ismodule(member):
sys.modules.setdefault(f"{module_name}.{name}", member)
# Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
_import_extension_to_sys_modules(member, f"{module_name}.{name}", memo)
_import_extension_to_sys_modules(_C, f"{__name__}._C")
del _import_extension_to_sys_modules
################################################################################ ################################################################################
# Define basic utilities # Define basic utilities
@ -2107,11 +2113,14 @@ def compiled_with_cxx11_abi() -> builtins.bool:
from torch import _library as _library, _ops as _ops from torch import _library as _library, _ops as _ops
from torch._classes import classes as classes
# Import the ops "namespace" # Import the ops and classes "namespace"
from torch._ops import ops as ops # usort: skip from torch._ops import ops as ops # usort: skip
from torch._classes import classes as classes # usort: skip
sys.modules.setdefault(f"{__name__}.ops", ops)
sys.modules.setdefault(f"{__name__}.classes", classes)
# quantization depends on torch.fx and torch.ops # quantization depends on torch.fx and torch.ops
# Import quantization # Import quantization
@ -2519,7 +2528,12 @@ if TYPE_CHECKING:
# Import the following modules during type checking to enable code intelligence features, # Import the following modules during type checking to enable code intelligence features,
# such as auto-completion in tools like pylance, even when these modules are not explicitly # such as auto-completion in tools like pylance, even when these modules are not explicitly
# imported in user code. # imported in user code.
from torch import _dynamo as _dynamo, _inductor as _inductor, onnx as onnx from torch import (
_dynamo as _dynamo,
_inductor as _inductor,
_subclasses as _subclasses,
onnx as onnx,
)
else: else:
_lazy_modules = { _lazy_modules = {

View File

@ -26,6 +26,7 @@ from weakref import ReferenceType
import torch import torch
import torch._logging import torch._logging
import torch.fx.experimental._sym_dispatch_mode import torch.fx.experimental._sym_dispatch_mode
from torch._C._dynamo.guards import GlobalStateGuard
from torch._dynamo.distributed import get_compile_pg from torch._dynamo.distributed import get_compile_pg
from torch._guards import compile_context, CompileContext, CompileId, tracing from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured from torch._logging import structured
@ -125,7 +126,7 @@ if typing.TYPE_CHECKING:
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode") bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks") graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
compile_lock = threading.RLock() compile_lock = threading.RLock()

View File

@ -44,6 +44,14 @@ import torch.fx
import torch.utils._pytree as pytree import torch.utils._pytree as pytree
import torch.utils.checkpoint import torch.utils.checkpoint
from torch import _guards from torch import _guards
# see discussion at https://github.com/pytorch/pytorch/issues/120699
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,
set_guard_error_hook,
skip_code,
unsupported,
)
from torch._dispatch.python import enable_python_dispatcher from torch._dispatch.python import enable_python_dispatcher
from torch._utils_internal import justknobs_check, log_export_usage from torch._utils_internal import justknobs_check, log_export_usage
from torch.export.dynamic_shapes import _process_dynamic_shapes from torch.export.dynamic_shapes import _process_dynamic_shapes
@ -72,14 +80,6 @@ if TYPE_CHECKING:
from .types import CacheEntry, DynamoCallback from .types import CacheEntry, DynamoCallback
# see discussion at https://github.com/pytorch/pytorch/issues/120699
reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401
set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa: F401
skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401
unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -100,7 +100,8 @@ unset = Unset.token
def _maybe_set_eval_frame(callback: DynamoCallback): def _maybe_set_eval_frame(callback: DynamoCallback):
# A wrapper on set_eval_frame that is guarded by a Justknob. # A wrapper on set_eval_frame that is guarded by a Justknob.
# Users can disable torchDynamo by setting the JK to False. # Users can disable torchDynamo by setting the JK to False.
set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame # noqa: F401 from torch._C._dynamo.eval_frame import set_eval_frame
if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"): if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
log.warning( log.warning(
"Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame" "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"

View File

@ -37,6 +37,16 @@ from weakref import ReferenceType
import torch import torch
import torch.utils._device import torch.utils._device
from torch._C._dynamo.guards import (
check_obj_id,
check_type_id,
dict_version,
DictGuardManager,
install_no_tensor_aliasing_guard,
install_object_aliasing_guard,
RootGuardManager,
TensorGuards,
)
from torch._dynamo.source import ( from torch._dynamo.source import (
is_from_flatten_script_object_source, is_from_flatten_script_object_source,
is_from_local_source, is_from_local_source,
@ -128,18 +138,6 @@ recompiles_verbose_log = torch._logging.getArtifactLogger(
) )
verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards") verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
TensorGuards = torch._C._dynamo.guards.TensorGuards
check_obj_id = torch._C._dynamo.guards.check_obj_id
check_type_id = torch._C._dynamo.guards.check_type_id
dict_version = torch._C._dynamo.guards.dict_version
RootGuardManager = torch._C._dynamo.guards.RootGuardManager
DictGuardManager = torch._C._dynamo.guards.DictGuardManager
install_object_aliasing_guard = torch._C._dynamo.guards.install_object_aliasing_guard
install_no_tensor_aliasing_guard = (
torch._C._dynamo.guards.install_no_tensor_aliasing_guard
)
class GuardManager: class GuardManager:
""" """

View File

@ -2,26 +2,21 @@ import dataclasses
import sys import sys
import types import types
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
from typing_extensions import TypeAlias
import torch # CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
from torch._C._dynamo.eval_frame import (
_CacheEntry as CacheEntry,
_ExtraState as ExtraState,
)
from torch._guards import CompileId from torch._guards import CompileId
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
from torch._C._dynamo import eval_frame from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType
DynamoFrameType: TypeAlias = eval_frame._PyInterpreterFrame
else: else:
DynamoFrameType: TypeAlias = types.FrameType from types import FrameType as DynamoFrameType
# This class has a `check_fn` field for the guard,
# and a `code` field for the code object.
CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
ExtraState = torch._C._dynamo.eval_frame._ExtraState
# We use a dict to store additional data per frame. # We use a dict to store additional data per frame.
FrameState = Dict[Any, Any] FrameState = Dict[Any, Any]

View File

@ -2211,7 +2211,6 @@ def nn_module_get_all_hooks(
check_backward_hooks=False, check_backward_hooks=False,
check_state_dict_hooks=False, check_state_dict_hooks=False,
): ):
reset_code = torch._C._dynamo.eval_frame.reset_code
""" """
Sometimes its useful to differentiate between types of hooks such as forward/backward/pre Sometimes its useful to differentiate between types of hooks such as forward/backward/pre
hooks executed during module.__call__, and state_dict hooks which are executed separately. hooks executed during module.__call__, and state_dict hooks which are executed separately.