mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
add storage dtype for custom device (#102481)
Fixes #ISSUE_NUMBER 1、add `isinstance` check with dtyped storage for custom device 2、add `storage.type()` support for custom device Pull Request resolved: https://github.com/pytorch/pytorch/pull/102481 Approved by: https://github.com/albanD
This commit is contained in:
committed by
PyTorch MergeBot
parent
e59db08699
commit
d9c8f9a00d
@ -238,8 +238,8 @@ class _StorageBase:
|
||||
Returns: self
|
||||
"""
|
||||
from torch.multiprocessing import get_sharing_strategy
|
||||
if self.is_cuda:
|
||||
pass # CUDA doesn't use POSIX shared memory
|
||||
if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
|
||||
pass # CUDA or PrivateUse1 doesn't use POSIX shared memory
|
||||
elif get_sharing_strategy() == 'file_system':
|
||||
self._share_filename_cpu_()
|
||||
else:
|
||||
@ -251,7 +251,7 @@ class _StorageBase:
|
||||
"""Creates a new storage in shared memory with the same data type"""
|
||||
from torch.multiprocessing import get_sharing_strategy
|
||||
device = torch.device(device)
|
||||
if device.type == 'cuda':
|
||||
if device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
|
||||
return cls(size, device=device)
|
||||
elif get_sharing_strategy() == 'file_system':
|
||||
return cls._new_using_filename_cpu(size)
|
||||
@ -427,6 +427,12 @@ def _warn_typed_storage_removal(stacklevel=2):
|
||||
def _reset_warn_typed_storage_removal():
|
||||
_warn_typed_storage_removal.__dict__['has_warned'] = False
|
||||
|
||||
def _get_device_from_module(module: str):
|
||||
if module.split(".")[-1] in ["cuda", torch._C._get_privateuse1_backend_name()]:
|
||||
return module.split(".")[-1]
|
||||
else:
|
||||
return "cpu"
|
||||
|
||||
class TypedStorage:
|
||||
is_sparse = False
|
||||
|
||||
@ -484,7 +490,7 @@ class TypedStorage:
|
||||
return TypedStorage(
|
||||
*args,
|
||||
dtype=cls._dtype,
|
||||
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
|
||||
device=_get_device_from_module(cls.__module__),
|
||||
_internal=True)
|
||||
|
||||
else:
|
||||
@ -499,7 +505,7 @@ class TypedStorage:
|
||||
arg_error_msg +
|
||||
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
|
||||
|
||||
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
|
||||
cls_device = _get_device_from_module(cls.__module__)
|
||||
|
||||
if wrap_storage.device.type != cls_device:
|
||||
raise RuntimeError(
|
||||
@ -1083,10 +1089,10 @@ class TypedStorage:
|
||||
|
||||
storage_name = _dtype_to_storage_type_map()[self.dtype]
|
||||
|
||||
if self.device.type not in ['cpu', 'cuda']:
|
||||
if self.device.type not in ['cpu', 'cuda', torch._C._get_privateuse1_backend_name()]:
|
||||
return None
|
||||
|
||||
module = torch if self.device.type == 'cpu' else torch.cuda
|
||||
module = torch if self.device.type == 'cpu' else getattr(torch, self.device.type)
|
||||
|
||||
try:
|
||||
return getattr(module, storage_name)
|
||||
@ -1101,7 +1107,7 @@ class _LegacyStorageMeta(type):
|
||||
|
||||
def __instancecheck__(cls, instance):
|
||||
if type(instance) == TypedStorage:
|
||||
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
|
||||
cls_device = _get_device_from_module(cls.__module__)
|
||||
return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
|
||||
return False
|
||||
|
||||
|
Reference in New Issue
Block a user