diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 4f4131c304e6..885069b3cd9a 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -2243,6 +2243,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch._register_device_module", "torch._running_with_deploy", "torch._sparse_coo_tensor_unsafe", + "torch._utils._dummy_type", "torch._weights_only_unpickler._get_allowed_globals", "torch._weights_only_unpickler.load", "torch.align_tensors", @@ -2389,7 +2390,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys( "torch.cuda._set_stream_by_id", "torch.cuda._sleep", "torch.cuda._transform_uuid_to_ordinals", - "torch.cuda._utils._dummy_type", "torch.cuda._utils._get_device_index", "torch.cuda.amp.autocast_mode._cast", "torch.cuda.amp.autocast_mode.custom_bwd", diff --git a/torch/_utils.py b/torch/_utils.py index e93ec10d424d..5976a906db06 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -891,3 +891,43 @@ def _get_device_module(device_type: str): f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'." ) return device_module + + +def _dummy_type(name: str) -> type: + def get_err_fn(is_init: bool): + def err_fn(obj, *args, **kwargs): + if is_init: + class_name = obj.__class__.__name__ + else: + class_name = obj.__name__ + raise RuntimeError(f"Tried to instantiate dummy base class {class_name}") + + return err_fn + + return type( + name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)} + ) + + +class _LazySeedTracker: + # Since seeding is memory-less, only track the latest seed. + # Note: `manual_seed_all` followed by `manual_seed` overwrites + # the seed on current device. We track the order of **latest** + # calls between these two API. + def __init__(self): + self.manual_seed_all_cb = None + self.manual_seed_cb = None + self.call_order = [] + + def queue_seed_all(self, cb, traceback): + self.manual_seed_all_cb = (cb, traceback) + # update seed_all to be latest + self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] + + def queue_seed(self, cb, traceback): + self.manual_seed_cb = (cb, traceback) + # update seed to be latest + self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] + + def get_calls(self) -> List: + return self.call_order diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 29bc3c2bc54a..b042126d99e2 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -25,8 +25,8 @@ import torch import torch._C from torch.types import Device from .. import device as _device -from .._utils import classproperty -from ._utils import _dummy_type, _get_device_index +from .._utils import _dummy_type, _LazySeedTracker, classproperty +from ._utils import _get_device_index from .graphs import ( CUDAGraph, graph, @@ -59,31 +59,6 @@ try: except ImportError as err: _PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later - -class _LazySeedTracker: - # Since seeding is memory-less, only track the latest seed. - # Note: `manual_seed_all` followed by `manual_seed` overwrites - # the seed on current device. We track the order of **latest** - # calls between these two API. - def __init__(self): - self.manual_seed_all_cb = None - self.manual_seed_cb = None - self.call_order = [] - - def queue_seed_all(self, cb, traceback): - self.manual_seed_all_cb = (cb, traceback) - # update seed_all to be latest - self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] - - def queue_seed(self, cb, traceback): - self.manual_seed_cb = (cb, traceback) - # update seed to be latest - self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] - - def get_calls(self) -> List: - return self.call_order - - _lazy_seed_tracker = _LazySeedTracker() # Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA diff --git a/torch/cuda/_utils.py b/torch/cuda/_utils.py index 1794ca9ddd1f..1d0ee8830bd6 100644 --- a/torch/cuda/_utils.py +++ b/torch/cuda/_utils.py @@ -36,19 +36,3 @@ def _get_device_index( if isinstance(device, torch.cuda.device): return device.idx return _torch_get_device_index(device, optional, allow_cpu) - - -def _dummy_type(name: str) -> type: - def get_err_fn(is_init: bool): - def err_fn(obj, *args, **kwargs): - if is_init: - class_name = obj.__class__.__name__ - else: - class_name = obj.__name__ - raise RuntimeError(f"Tried to instantiate dummy base class {class_name}") - - return err_fn - - return type( - name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)} - ) diff --git a/torch/cuda/graphs.py b/torch/cuda/graphs.py index 563450e58e17..5e98a7a477d4 100644 --- a/torch/cuda/graphs.py +++ b/torch/cuda/graphs.py @@ -3,7 +3,7 @@ from typing import Optional import torch from torch.utils import _pytree -from ._utils import _dummy_type +from .._utils import _dummy_type if not hasattr(torch._C, "_CudaStreamBase"): # Define dummy base classes diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py index 55022ae829a6..60440c58dc1d 100644 --- a/torch/cuda/memory.py +++ b/torch/cuda/memory.py @@ -14,10 +14,10 @@ import torch from torch import _C from torch.types import Device +from .._utils import _dummy_type from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized from ._memory_viz import memory as _memory, segments as _segments -from ._utils import _dummy_type __all__ = [ "caching_allocator_alloc", diff --git a/torch/cuda/streams.py b/torch/cuda/streams.py index 3d417958373e..22d541f4e287 100644 --- a/torch/cuda/streams.py +++ b/torch/cuda/streams.py @@ -2,7 +2,7 @@ import ctypes import torch from torch._streambase import _EventBase, _StreamBase -from ._utils import _dummy_type +from .._utils import _dummy_type if not hasattr(torch._C, "_CudaStreamBase"): diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index 8a7969bfd645..a6fb1a0d573b 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -13,7 +13,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch import torch._C from .. import device as _device -from ._utils import _dummy_type, _get_device_index +from .._utils import _dummy_type, _LazySeedTracker +from ._utils import _get_device_index from .streams import Event, Stream _initialized = False @@ -24,29 +25,6 @@ _queued_calls: List[ ] = [] # don't invoke these until initialization occurs _is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False) _device_t = Union[_device, str, int, None] - - -class _LazySeedTracker: - # Since seeding is memory-less, only track the latest seed. - def __init__(self): - self.manual_seed_all_cb = None - self.manual_seed_cb = None - self.call_order = [] - - def queue_seed_all(self, cb, traceback): - self.manual_seed_all_cb = (cb, traceback) - # update seed_all to be latest - self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] - - def queue_seed(self, cb, traceback): - self.manual_seed_cb = (cb, traceback) - # update seed to be latest - self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] - - def get_calls(self) -> List: - return self.call_order - - _lazy_seed_tracker = _LazySeedTracker() default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment] diff --git a/torch/xpu/_utils.py b/torch/xpu/_utils.py index d34db1772abd..8f738267459a 100644 --- a/torch/xpu/_utils.py +++ b/torch/xpu/_utils.py @@ -37,19 +37,3 @@ def _get_device_index( if isinstance(device, torch.xpu.device): return device.idx return _torch_get_device_index(device, optional, allow_cpu) - - -def _dummy_type(name: str) -> type: - def get_err_fn(is_init: bool): - def err_fn(obj, *args, **kwargs): - if is_init: - class_name = obj.__class__.__name__ - else: - class_name = obj.__name__ - raise RuntimeError(f"Tried to instantiate dummy base class {class_name}") - - return err_fn - - return type( - name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)} - ) diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index 4fa639d6bfcc..2c3c3a63d58b 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -2,7 +2,7 @@ import ctypes import torch from torch._streambase import _EventBase, _StreamBase -from ._utils import _dummy_type +from .._utils import _dummy_type if not hasattr(torch._C, "_XpuStreamBase"):