refactor code to share across different devices (#120602)

# Motivation
Refactor utility code so that it can be shared across CUDA, XPU, and other backends.

# Solution
Move `_dummy_type` and `_LazySeedTracker` to `torch._utils`.
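
As an illustration (not part of this diff), here is a minimal sketch of how a backend module could reuse the shared helpers after the move; the `foo` backend name and `_FooStreamBase` symbol below are hypothetical:

```python
import torch

# After the refactor both helpers live in torch._utils and can be imported by
# any backend module (torch.cuda, torch.xpu, or a hypothetical torch.foo).
from torch._utils import _LazySeedTracker, _dummy_type

# Remember only the latest manual_seed / manual_seed_all requests until the
# backend is lazily initialized.
_lazy_seed_tracker = _LazySeedTracker()

# If PyTorch was built without this (hypothetical) backend, expose a dummy base
# class that raises on instantiation, mirroring the torch.cuda/torch.xpu pattern.
if not hasattr(torch._C, "_FooStreamBase"):
    torch._C.__dict__["_FooStreamBase"] = _dummy_type("_FooStreamBase")
```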

# Additional Context
When upstreaming, these code changes are isolated into a separate PR to minimize their impact on the CUDA code.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120602
Approved by: https://github.com/albanD, https://github.com/jgong5, https://github.com/gujinghui, https://github.com/EikanWang
Author: Yu, Guangye
Date: 2024-02-28 06:26:05 +00:00
Committed by: PyTorch MergeBot
Parent: a11a49af58
Commit: 46e3f670b4
10 changed files with 49 additions and 88 deletions


@@ -2243,6 +2243,7 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch._register_device_module",
"torch._running_with_deploy",
"torch._sparse_coo_tensor_unsafe",
"torch._utils._dummy_type",
"torch._weights_only_unpickler._get_allowed_globals",
"torch._weights_only_unpickler.load",
"torch.align_tensors",
@@ -2389,7 +2390,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
"torch.cuda._set_stream_by_id",
"torch.cuda._sleep",
"torch.cuda._transform_uuid_to_ordinals",
"torch.cuda._utils._dummy_type",
"torch.cuda._utils._get_device_index",
"torch.cuda.amp.autocast_mode._cast",
"torch.cuda.amp.autocast_mode.custom_bwd",


@@ -891,3 +891,43 @@ def _get_device_module(device_type: str):
f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
)
return device_module
def _dummy_type(name: str) -> type:
def get_err_fn(is_init: bool):
def err_fn(obj, *args, **kwargs):
if is_init:
class_name = obj.__class__.__name__
else:
class_name = obj.__name__
raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
return err_fn
return type(
name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
)
class _LazySeedTracker:
# Since seeding is memory-less, only track the latest seed.
# Note: `manual_seed_all` followed by `manual_seed` overwrites
# the seed on current device. We track the order of **latest**
# calls between these two API.
def __init__(self):
self.manual_seed_all_cb = None
self.manual_seed_cb = None
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
def get_calls(self) -> List:
return self.call_order
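
For context, a short usage sketch (not part of the diff) of how a backend typically drives `_LazySeedTracker` during lazy initialization; the `manual_seed`/`manual_seed_all` wrappers below are simplified stand-ins for the real torch.cuda/torch.xpu implementations:

```python
import traceback
from torch._utils import _LazySeedTracker

tracker = _LazySeedTracker()

def manual_seed(seed: int) -> None:
    # Before lazy init, only the latest per-device seed request is remembered.
    tracker.queue_seed(lambda: print(f"seed({seed})"), traceback.format_stack())

def manual_seed_all(seed: int) -> None:
    tracker.queue_seed_all(lambda: print(f"seed_all({seed})"), traceback.format_stack())

manual_seed_all(1)
manual_seed(2)  # the later call is ordered last in get_calls()

# At lazy-init time the backend replays the recorded callbacks in call order;
# entries are (callback, traceback) tuples, or None if never queued.
for entry in tracker.get_calls():
    if entry:
        callback, _origin = entry
        callback()
```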


@@ -25,8 +25,8 @@ import torch
import torch._C
from torch.types import Device
from .. import device as _device
from .._utils import classproperty
from ._utils import _dummy_type, _get_device_index
from .._utils import _dummy_type, _LazySeedTracker, classproperty
from ._utils import _get_device_index
from .graphs import (
CUDAGraph,
graph,
@@ -59,31 +59,6 @@ try:
except ImportError as err:
_PYNVML_ERR = err # sometimes a lib is installed but the import fails for some other reason, so we log the error for later
class _LazySeedTracker:
# Since seeding is memory-less, only track the latest seed.
# Note: `manual_seed_all` followed by `manual_seed` overwrites
# the seed on current device. We track the order of **latest**
# calls between these two API.
def __init__(self):
self.manual_seed_all_cb = None
self.manual_seed_cb = None
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
def get_calls(self) -> List:
return self.call_order
_lazy_seed_tracker = _LazySeedTracker()
# Define dummy _CudaDeviceProperties type if PyTorch was compiled without CUDA


@@ -36,19 +36,3 @@ def _get_device_index(
if isinstance(device, torch.cuda.device):
return device.idx
return _torch_get_device_index(device, optional, allow_cpu)
def _dummy_type(name: str) -> type:
def get_err_fn(is_init: bool):
def err_fn(obj, *args, **kwargs):
if is_init:
class_name = obj.__class__.__name__
else:
class_name = obj.__name__
raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
return err_fn
return type(
name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
)


@@ -3,7 +3,7 @@ from typing import Optional
import torch
from torch.utils import _pytree
from ._utils import _dummy_type
from .._utils import _dummy_type
if not hasattr(torch._C, "_CudaStreamBase"):
# Define dummy base classes


@@ -14,10 +14,10 @@ import torch
from torch import _C
from torch.types import Device
from .._utils import _dummy_type
from . import _get_device_index, _get_nvml_device_index, _lazy_init, is_initialized
from ._memory_viz import memory as _memory, segments as _segments
from ._utils import _dummy_type
__all__ = [
"caching_allocator_alloc",


@@ -2,7 +2,7 @@ import ctypes
import torch
from torch._streambase import _EventBase, _StreamBase
from ._utils import _dummy_type
from .._utils import _dummy_type
if not hasattr(torch._C, "_CudaStreamBase"):


@@ -13,7 +13,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch._C
from .. import device as _device
from ._utils import _dummy_type, _get_device_index
from .._utils import _dummy_type, _LazySeedTracker
from ._utils import _get_device_index
from .streams import Event, Stream
_initialized = False
@@ -24,29 +25,6 @@ _queued_calls: List[
] = [] # don't invoke these until initialization occurs
_is_in_bad_fork = getattr(torch._C, "_xpu_isInBadFork", lambda: False)
_device_t = Union[_device, str, int, None]
class _LazySeedTracker:
# Since seeding is memory-less, only track the latest seed.
def __init__(self):
self.manual_seed_all_cb = None
self.manual_seed_cb = None
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback)
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback)
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
def get_calls(self) -> List:
return self.call_order
_lazy_seed_tracker = _LazySeedTracker()
default_generators: Tuple[torch._C.Generator] = () # type: ignore[assignment]


@@ -37,19 +37,3 @@ def _get_device_index(
if isinstance(device, torch.xpu.device):
return device.idx
return _torch_get_device_index(device, optional, allow_cpu)
def _dummy_type(name: str) -> type:
def get_err_fn(is_init: bool):
def err_fn(obj, *args, **kwargs):
if is_init:
class_name = obj.__class__.__name__
else:
class_name = obj.__name__
raise RuntimeError(f"Tried to instantiate dummy base class {class_name}")
return err_fn
return type(
name, (object,), {"__init__": get_err_fn(True), "__new__": get_err_fn(False)}
)


@@ -2,7 +2,7 @@ import ctypes
import torch
from torch._streambase import _EventBase, _StreamBase
from ._utils import _dummy_type
from .._utils import _dummy_type
if not hasattr(torch._C, "_XpuStreamBase"):