mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-21 05:34:18 +08:00)
Populate submodules of torch._C to sys.modules recursively (#132216)
See comment: e9d1c26275/torch/__init__.py (L938-L950)
This PR recursively registers the submodules of the C extension (e.g., `_C._dynamo.eval_frame`) in `sys.modules`.
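
In practice this means nested extension submodules become importable by their dotted names once `import torch` has run. A minimal illustration (not part of the commit; assumes a PyTorch build that includes this change):

import sys
import torch

# The recursive registration pass makes these names resolve through the
# normal import machinery:
assert "torch._C._dynamo" in sys.modules
assert "torch._C._dynamo.eval_frame" in sys.modules

# ...which is what lets the call sites in this diff switch from attribute
# aliasing to a plain import:
from torch._C._dynamo.eval_frame import set_eval_frame  # noqa: F401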
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132216
Approved by: https://github.com/ezyang
committed by PyTorch MergeBot
parent 7f71f2a997
commit 24dee99cb7

@@ -2,12 +2,10 @@
 
 import torch
 import torch._dynamo.test_case
+from torch._C._dynamo.eval_frame import set_eval_frame
 from torch._guards import CompileId
 
 
-set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame  # noqa: F401
-
-
 def target_with_varkwargs(arg1, /, positional_only_arg, *, keyword_only_arg, **kwargs):
     local = 1
     return {

@@ -9,6 +9,7 @@ import torch
 import torch.backends.cudnn as cudnn
 import torch.nn as nn
 import torch.nn.functional as F
+from torch._C._dynamo.guards import assert_size_stride
 from torch.testing import make_tensor
 from torch.testing._internal.common_cuda import tf32_is_not_fp32
 from torch.testing._internal.common_device_type import (

@@ -30,7 +31,6 @@ from torch.testing._internal.common_utils import (
 )
 
 
-assert_size_stride = torch._C._dynamo.guards.assert_size_stride
 AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
 if TEST_SCIPY:
     import scipy.ndimage

@@ -65,6 +65,7 @@ from torch.utils._python_dispatch import TorchDispatchMode
 from . import (
     _aoti,
     _cpu,
+    _dynamo,
     _functorch,
     _lazy,
     _lazy_ts_backend,

@@ -943,15 +943,21 @@ if not TYPE_CHECKING:
     # non-standard, and attributes of those submodules cannot be pickled since
     # pickle expect to be able to import them as "from _C.sub import attr"
     # which fails with "_C is not a package
-    __name, __candidate = "", None
-    for __name in dir(_C):
-        __candidate = getattr(_C, __name)
-        if inspect.ismodule(__candidate):
-            # submodule
-            sys.modules.setdefault(f"{__name__}._C.{__name}", __candidate)
-
-    del __name, __candidate
+    def _import_extension_to_sys_modules(module, module_name, memo=None):
+        if memo is None:
+            memo = set()
+        if module in memo:
+            return
+        memo.add(module)
+        for name in dir(module):
+            member = getattr(module, name)
+            if inspect.ismodule(member):
+                sys.modules.setdefault(f"{module_name}.{name}", member)
+                # Recurse for submodules (e.g., `_C._dynamo.eval_frame`)
+                _import_extension_to_sys_modules(member, f"{module_name}.{name}", memo)
+
+    _import_extension_to_sys_modules(_C, f"{__name__}._C")
+    del _import_extension_to_sys_modules
 
 
 ################################################################################
 # Define basic utilities
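
Note: a quick way to inspect the effect of the new `_import_extension_to_sys_modules` pass (not part of the commit; the exact set of registered names depends on the build):

import sys
import torch  # noqa: F401  -- importing torch runs the registration above

# Every nested C-extension submodule now has a sys.modules entry, e.g.
# torch._C._dynamo, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, ...
for name in sorted(n for n in sys.modules if n.startswith("torch._C.")):
    print(name)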

@@ -2107,11 +2113,14 @@ def compiled_with_cxx11_abi() -> builtins.bool:
 
 
 from torch import _library as _library, _ops as _ops
-from torch._classes import classes as classes
 
 
-# Import the ops "namespace"
+# Import the ops and classes "namespace"
 from torch._ops import ops as ops  # usort: skip
+from torch._classes import classes as classes  # usort: skip
+
+sys.modules.setdefault(f"{__name__}.ops", ops)
+sys.modules.setdefault(f"{__name__}.classes", classes)
 
 # quantization depends on torch.fx and torch.ops
 # Import quantization
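
Note: the two `sys.modules.setdefault` calls above make the dynamic `torch.ops` and `torch.classes` namespaces reachable by their dotted names as well. A rough sketch of the intended effect (assuming nothing else has claimed those keys first):

import sys
import torch

# setdefault keeps any pre-existing entry; otherwise the namespace objects
# themselves are registered under their torch.* names.
assert sys.modules["torch.ops"] is torch.ops
assert sys.modules["torch.classes"] is torch.classes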

@@ -2519,7 +2528,12 @@ if TYPE_CHECKING:
     # Import the following modules during type checking to enable code intelligence features,
     # such as auto-completion in tools like pylance, even when these modules are not explicitly
     # imported in user code.
-    from torch import _dynamo as _dynamo, _inductor as _inductor, onnx as onnx
+    from torch import (
+        _dynamo as _dynamo,
+        _inductor as _inductor,
+        _subclasses as _subclasses,
+        onnx as onnx,
+    )
 
 else:
     _lazy_modules = {

@@ -26,6 +26,7 @@ from weakref import ReferenceType
 import torch
 import torch._logging
 import torch.fx.experimental._sym_dispatch_mode
+from torch._C._dynamo.guards import GlobalStateGuard
 from torch._dynamo.distributed import get_compile_pg
 from torch._guards import compile_context, CompileContext, CompileId, tracing
 from torch._logging import structured

@@ -125,7 +126,7 @@ if typing.TYPE_CHECKING:
 log = logging.getLogger(__name__)
 bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
 graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
-GlobalStateGuard = torch._C._dynamo.guards.GlobalStateGuard
+
 compile_lock = threading.RLock()
 
 

@@ -44,6 +44,14 @@ import torch.fx
 import torch.utils._pytree as pytree
 import torch.utils.checkpoint
 from torch import _guards
+
+# see discussion at https://github.com/pytorch/pytorch/issues/120699
+from torch._C._dynamo.eval_frame import (  # noqa: F401
+    reset_code,
+    set_guard_error_hook,
+    skip_code,
+    unsupported,
+)
 from torch._dispatch.python import enable_python_dispatcher
 from torch._utils_internal import justknobs_check, log_export_usage
 from torch.export.dynamic_shapes import _process_dynamic_shapes

@@ -72,14 +80,6 @@ if TYPE_CHECKING:
     from .types import CacheEntry, DynamoCallback
 
 
-# see discussion at https://github.com/pytorch/pytorch/issues/120699
-reset_code = torch._C._dynamo.eval_frame.reset_code  # noqa: F401
-
-set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook  # noqa: F401
-skip_code = torch._C._dynamo.eval_frame.skip_code  # noqa: F401
-unsupported = torch._C._dynamo.eval_frame.unsupported  # noqa: F401
-
-
 log = logging.getLogger(__name__)
 
 

@@ -100,7 +100,8 @@ unset = Unset.token
 def _maybe_set_eval_frame(callback: DynamoCallback):
     # A wrapper on set_eval_frame that is guarded by a Justknob.
     # Users can disable torchDynamo by setting the JK to False.
-    set_eval_frame = torch._C._dynamo.eval_frame.set_eval_frame  # noqa: F401
+    from torch._C._dynamo.eval_frame import set_eval_frame
+
     if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
         log.warning(
             "Dynamo disabled by Justknob: enable_compiler_set_eval_frame, skipping set_eval_frame"
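
Note: a condensed, illustrative sketch of the wrapper pattern this hunk relies on; the early-return behaviour and the helper name below are assumptions for illustration, not a verbatim copy of torch._dynamo:

from torch._utils_internal import justknobs_check


def _maybe_set_eval_frame_sketch(callback):
    # Function-local import: resolves via sys.modules thanks to the recursive
    # registration added in the torch/__init__.py hunk above.
    from torch._C._dynamo.eval_frame import set_eval_frame

    if not justknobs_check("pytorch/compiler:enable_compiler_set_eval_frame"):
        # Kill switch: leave the current eval-frame hook untouched (assumed behaviour).
        return callback
    return set_eval_frame(callback)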

@@ -37,6 +37,16 @@ from weakref import ReferenceType
 
 import torch
 import torch.utils._device
+from torch._C._dynamo.guards import (
+    check_obj_id,
+    check_type_id,
+    dict_version,
+    DictGuardManager,
+    install_no_tensor_aliasing_guard,
+    install_object_aliasing_guard,
+    RootGuardManager,
+    TensorGuards,
+)
 from torch._dynamo.source import (
     is_from_flatten_script_object_source,
     is_from_local_source,

@@ -128,18 +138,6 @@ recompiles_verbose_log = torch._logging.getArtifactLogger(
 )
 verbose_guards_log = torch._logging.getArtifactLogger(__name__, "verbose_guards")
-
-TensorGuards = torch._C._dynamo.guards.TensorGuards
-check_obj_id = torch._C._dynamo.guards.check_obj_id
-check_type_id = torch._C._dynamo.guards.check_type_id
-dict_version = torch._C._dynamo.guards.dict_version
-
-RootGuardManager = torch._C._dynamo.guards.RootGuardManager
-DictGuardManager = torch._C._dynamo.guards.DictGuardManager
-install_object_aliasing_guard = torch._C._dynamo.guards.install_object_aliasing_guard
-install_no_tensor_aliasing_guard = (
-    torch._C._dynamo.guards.install_no_tensor_aliasing_guard
-)
 
 
 class GuardManager:
     """

@@ -2,26 +2,21 @@ import dataclasses
 import sys
 import types
 from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
-from typing_extensions import TypeAlias
 
-import torch
+# CacheEntry has a `check_fn` field for the guard, and a `code` field for the code object.
+from torch._C._dynamo.eval_frame import (
+    _CacheEntry as CacheEntry,
+    _ExtraState as ExtraState,
+)
 from torch._guards import CompileId
 
 
 if sys.version_info >= (3, 11):
-    from torch._C._dynamo import eval_frame
-
-    DynamoFrameType: TypeAlias = eval_frame._PyInterpreterFrame
+    from torch._C._dynamo.eval_frame import _PyInterpreterFrame as DynamoFrameType
 else:
-    DynamoFrameType: TypeAlias = types.FrameType
+    from types import FrameType as DynamoFrameType
 
 
-# This class has a `check_fn` field for the guard,
-# and a `code` field for the code object.
-CacheEntry = torch._C._dynamo.eval_frame._CacheEntry
-
-ExtraState = torch._C._dynamo.eval_frame._ExtraState
-
 # We use a dict to store additional data per frame.
 FrameState = Dict[Any, Any]
 
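
Note: after the hunk above (presumably torch/_dynamo/types.py), `DynamoFrameType` is bound by a conditional import rather than a `TypeAlias`, so downstream annotations are unchanged. A purely illustrative consumer (hypothetical function, not from this commit):

from torch._dynamo.types import DynamoFrameType


def frame_type_name(frame: DynamoFrameType) -> str:
    # The alias resolves to _PyInterpreterFrame on Python 3.11+ and to
    # types.FrameType on older interpreters.
    return type(frame).__name__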

@@ -2211,7 +2211,6 @@ def nn_module_get_all_hooks(
     check_backward_hooks=False,
     check_state_dict_hooks=False,
 ):
-    reset_code = torch._C._dynamo.eval_frame.reset_code
     """
     Sometimes its useful to differentiate between types of hooks such as forward/backward/pre
     hooks executed during module.__call__, and state_dict hooks which are executed separately.