mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
refactor code to share across different devices (#120602)
# Motivation Refactor utils code to make it possible to share across CUDA, XPU, and other backends. # Solution Move `_dummy_type` and `_LazySeedTracker` to torch._utils; # Additional Context When upstreaming, refactor these code changes by isolating them into an additional PR to minimize their impact on the CUDA code. Pull Request resolved: https://github.com/pytorch/pytorch/pull/120602 Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang
This commit is contained in:
committed by
PyTorch MergeBot
parent
a11a49af58
commit
46e3f670b4
@ -2243,6 +2243,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch._register_device_module",
|
||||
"torch._running_with_deploy",
|
||||
"torch._sparse_coo_tensor_unsafe",
|
||||
"torch._utils._dummy_type",
|
||||
"torch._weights_only_unpickler._get_allowed_globals",
|
||||
"torch._weights_only_unpickler.load",
|
||||
"torch.align_tensors",
|
||||
@ -2389,7 +2390,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
|
||||
"torch.cuda._set_stream_by_id",
|
||||
"torch.cuda._sleep",
|
||||
"torch.cuda._transform_uuid_to_ordinals",
|
||||
"torch.cuda._utils._dummy_type",
|
||||
"torch.cuda._utils._get_device_index",
|
||||
"torch.cuda.amp.autocast_mode._cast",
|
||||
"torch.cuda.amp.autocast_mode.custom_bwd",
|
||||
|
@ -891,3 +891,43 @@ def _get_device_module(device_type: str):
|
||||
f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
|
||||
)
|
||||
return device_module
|
||||
|
||||
|
||||
def _dummy_type(name: str) -> type:
|
||||
def get_err_fn(is_init: bool):
|
||||
def err_fn(obj, *args, **kwargs):
|
||||
if is_init:
|
||||
class_name = obj.__class__.__name__
|
||||
else:
|
||||
class_name = obj.__name__
|
||||
raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
|
||||
|
||||
return err_fn
|
||||
|
||||
return type(
|
||||
name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
|
||||
)
|
||||
|
||||
|
||||
class _LazySeedTracker:
|
||||
# Since seeding is memory-less, only track the latest seed.
|
||||
# Note: `manual_seed_all` followed by `manual_seed` overwrites
|
||||
# the seed on current device. We track the order of **latest**
|
||||
# calls between these two API.
|
||||
def __init__(self):
|
||||
self.manual_seed_all_cb = None
|
||||
self.manual_seed_cb = None
|
||||
self.call_order = []
|
||||
|
||||
def queue_seed_all(self, cb, traceback):
|
||||
self.manual_seed_all_cb = (cb, traceback)
|
||||
# update seed_all to be latest
|
||||
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
|
||||
|
||||
def queue_seed(self, cb, traceback):
|
||||
self.manual_seed_cb = (cb, traceback)
|
||||
# update seed to be latest
|
||||
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
||||
|
||||
def get_calls(self) -> List:
|
||||
return self.call_order
|
||||
|
@ -25,8 +25,8 @@ import torch
|
||||
import torch._C
|
||||
from torch.types import Device
|
||||
from .. import device as _device
|
||||
from .._utils import classproperty
|
||||
from ._utils import _dummy_type, _get_device_index
|
||||
from .._utils import _dummy_type, _LazySeedTracker, classproperty
|
||||
from ._utils import _get_device_index
|
||||
from .graphs import (
|
||||
CUDAGraph,
|
||||
graph,
|
||||
@ -59,31 +59,6 @@ try:
|
||||
except ImportError as err:
|
||||
_PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
|
||||
|
||||
|
||||
class _LazySeedTracker:
    """Track only the most recent manual seed requests and their call order."""

    # Since seeding is memory-less, only track the latest seed.
    # Note: `manual_seed_all` followed by `manual_seed` overwrites
    # the seed on current device. We track the order of **latest**
    # calls between these two APIs.
    def __init__(self):
        # (callback, traceback) for the latest call of each API, or None if
        # that API has not been queued yet.
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        # Replay order of the queued callbacks; most recent request is last.
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        """Record a seed-all request as the most recent seeding call."""
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        """Record a current-device seed request as the most recent call."""
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        """Return the queued (callback, traceback) pairs in replay order."""
        return self.call_order
|
||||
|
||||
|
||||
_lazy_seed_tracker = _LazySeedTracker()
|
||||
|
||||
# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA
|
||||
|
@ -36,19 +36,3 @@ def _get_device_index(
|
||||
if isinstance(device, torch.cuda.device):
|
||||
return device.idx
|
||||
return _torch_get_device_index(device, optional, allow_cpu)
|
||||
|
||||
|
||||
def _dummy_type(name: str) -> type:
    """Create a placeholder class *name* that raises if instantiated.

    Both ``__new__`` and ``__init__`` raise ``RuntimeError``, so the class can
    serve as a stand-in base class while making construction fail loudly.
    """

    def get_err_fn(is_init: bool):
        # is_init selects how the class name is recovered: __init__ receives
        # an instance, while __new__ receives the class object itself.
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )
|
||||
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
from torch.utils import _pytree
|
||||
from ._utils import _dummy_type
|
||||
from .._utils import _dummy_type
|
||||
|
||||
if not hasattr(torch._C, "_CudaStreamBase"):
|
||||
# Define dummy base classes
|
||||
|
@ -14,10 +14,10 @@ import torch
|
||||
from torch import _C
|
||||
|
||||
from torch.types import Device
|
||||
from .._utils import _dummy_type
|
||||
from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized
|
||||
|
||||
from ._memory_viz import memory as _memory, segments as _segments
|
||||
from ._utils import _dummy_type
|
||||
|
||||
__all__ = [
|
||||
"caching_allocator_alloc",
|
||||
|
@ -2,7 +2,7 @@ import ctypes
|
||||
|
||||
import torch
|
||||
from torch._streambase import _EventBase, _StreamBase
|
||||
from ._utils import _dummy_type
|
||||
from .._utils import _dummy_type
|
||||
|
||||
|
||||
if not hasattr(torch._C, "_CudaStreamBase"):
|
||||
|
@ -13,7 +13,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch._C
|
||||
from .. import device as _device
|
||||
from ._utils import _dummy_type, _get_device_index
|
||||
from .._utils import _dummy_type, _LazySeedTracker
|
||||
from ._utils import _get_device_index
|
||||
from .streams import Event, Stream
|
||||
|
||||
_initialized = False
|
||||
@ -24,29 +25,6 @@ _queued_calls: List[
|
||||
] = [] # don't invoke these until initialization occurs
|
||||
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
|
||||
_device_t = Union[_device, str, int, None]
|
||||
|
||||
|
||||
class _LazySeedTracker:
    """Track only the most recent manual seed requests and their call order."""

    # Since seeding is memory-less, only track the latest seed.
    def __init__(self):
        # (callback, traceback) for the latest call of each API, or None if
        # that API has not been queued yet.
        self.manual_seed_all_cb = None
        self.manual_seed_cb = None
        # Replay order of the queued callbacks; most recent request is last.
        self.call_order = []

    def queue_seed_all(self, cb, traceback):
        """Record a seed-all request as the most recent seeding call."""
        self.manual_seed_all_cb = (cb, traceback)
        # update seed_all to be latest
        self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

    def queue_seed(self, cb, traceback):
        """Record a current-device seed request as the most recent call."""
        self.manual_seed_cb = (cb, traceback)
        # update seed to be latest
        self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

    def get_calls(self) -> List:
        """Return the queued (callback, traceback) pairs in replay order."""
        return self.call_order
|
||||
|
||||
|
||||
_lazy_seed_tracker = _LazySeedTracker()
|
||||
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]
|
||||
|
||||
|
@ -37,19 +37,3 @@ def _get_device_index(
|
||||
if isinstance(device, torch.xpu.device):
|
||||
return device.idx
|
||||
return _torch_get_device_index(device, optional, allow_cpu)
|
||||
|
||||
|
||||
def _dummy_type(name: str) -> type:
    """Create a placeholder class *name* that raises if instantiated.

    Both ``__new__`` and ``__init__`` raise ``RuntimeError``, so the class can
    serve as a stand-in base class while making construction fail loudly.
    """

    def get_err_fn(is_init: bool):
        # is_init selects how the class name is recovered: __init__ receives
        # an instance, while __new__ receives the class object itself.
        def err_fn(obj, *args, **kwargs):
            if is_init:
                class_name = obj.__class__.__name__
            else:
                class_name = obj.__name__
            raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")

        return err_fn

    return type(
        name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
    )
|
||||
|
@ -2,7 +2,7 @@ import ctypes
|
||||
|
||||
import torch
|
||||
from torch._streambase import _EventBase, _StreamBase
|
||||
from ._utils import _dummy_type
|
||||
from .._utils import _dummy_type
|
||||
|
||||
|
||||
if not hasattr(torch._C, "_XpuStreamBase"):
|
||||
|
Reference in New Issue
Block a user