add storage dtype for custom device (#102481)

Fixes #ISSUE_NUMBER
1. add `isinstance` checks against the dtyped storage classes for custom devices
2. add `storage.type()` support for custom devices (both are sketched below)
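
A minimal sketch of what the two items enable, assuming a PrivateUse1 backend that is actually loaded and renamed to the hypothetical name "foo", and whose `torch.foo` module defines the legacy dtyped storage classes:

import torch

torch.utils.rename_privateuse1_backend("foo")  # hypothetical backend name

t = torch.empty(4, dtype=torch.float32, device="foo")
s = t.storage()
isinstance(s, torch.foo.FloatStorage)  # item 1: dtype-aware isinstance now works
s.type()                               # item 2: returns 'torch.foo.FloatStorage'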
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102481
Approved by: https://github.com/albanD
Author: shibo19
Date: 2023-06-01 12:46:19 +00:00
Committed by: PyTorch MergeBot
Parent: e59db08699
Commit: d9c8f9a00d
2 changed files with 48 additions and 8 deletions


@@ -238,8 +238,8 @@ class _StorageBase:
         Returns: self
         """
         from torch.multiprocessing import get_sharing_strategy
-        if self.is_cuda:
-            pass  # CUDA doesn't use POSIX shared memory
+        if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
+            pass  # CUDA or PrivateUse1 doesn't use POSIX shared memory
         elif get_sharing_strategy() == 'file_system':
             self._share_filename_cpu_()
         else:
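
For context, `torch._C._get_privateuse1_backend_name()` returns whatever name an out-of-tree backend registered for the PrivateUse1 dispatch key. A hedged sketch of that registration, using "foo" as a stand-in name:

import torch

torch.utils.rename_privateuse1_backend("foo")  # out-of-tree backend picks its name
torch._C._get_privateuse1_backend_name()       # -> "foo"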
@@ -251,7 +251,7 @@ class _StorageBase:
         """Creates a new storage in shared memory with the same data type"""
         from torch.multiprocessing import get_sharing_strategy
         device = torch.device(device)
-        if device.type == 'cuda':
+        if device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
             return cls(size, device=device)
         elif get_sharing_strategy() == 'file_system':
             return cls._new_using_filename_cpu(size)
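
Roughly how the shared-memory constructor behaves after this hunk (an illustrative sketch, not code from the PR): CPU allocations go through the configured sharing strategy, while CUDA and PrivateUse1 devices just get an ordinary device allocation.

import torch

cpu_shared = torch.UntypedStorage._new_shared(16, device="cpu")
cpu_shared.is_shared()  # True: allocated via the fd / file-system sharing strategy
# For device="cuda" (or the registered PrivateUse1 name), the method simply
# returns cls(size, device=device), i.e. a normal allocation on that device.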
@@ -427,6 +427,12 @@ def _warn_typed_storage_removal(stacklevel=2):
 def _reset_warn_typed_storage_removal():
     _warn_typed_storage_removal.__dict__['has_warned'] = False
 
 
+def _get_device_from_module(module: str):
+    if module.split(".")[-1] in ["cuda", torch._C._get_privateuse1_backend_name()]:
+        return module.split(".")[-1]
+    else:
+        return "cpu"
+
 class TypedStorage:
     is_sparse = False
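
A quick sketch of the helper's mapping (the "foo" backend name is illustrative):

_get_device_from_module("torch.cuda")  # -> "cuda"
_get_device_from_module("torch.foo")   # -> "foo" when "foo" is the registered PrivateUse1 name
_get_device_from_module("torch")       # -> "cpu" (anything else falls through to CPU)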
@@ -484,7 +490,7 @@ class TypedStorage:
                 return TypedStorage(
                     *args,
                     dtype=cls._dtype,
-                    device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
+                    device=_get_device_from_module(cls.__module__),
                     _internal=True)
 
             else:
@@ -499,7 +505,7 @@ class TypedStorage:
                         arg_error_msg +
                         f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
 
-                cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+                cls_device = _get_device_from_module(cls.__module__)
 
                 if wrap_storage.device.type != cls_device:
                     raise RuntimeError(
@@ -1083,10 +1089,10 @@ class TypedStorage:
         storage_name = _dtype_to_storage_type_map()[self.dtype]
 
-        if self.device.type not in ['cpu', 'cuda']:
+        if self.device.type not in ['cpu', 'cuda', torch._C._get_privateuse1_backend_name()]:
             return None
 
-        module = torch if self.device.type == 'cpu' else torch.cuda
+        module = torch if self.device.type == 'cpu' else getattr(torch, self.device.type)
 
         try:
             return getattr(module, storage_name)
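
So the lookup backing `storage.type()` now resolves the backend submodule dynamically instead of hard-coding `torch.cuda`. A sketch of the resolution, again with a hypothetical backend "foo":

import torch

device_type = "foo"  # hypothetical PrivateUse1 backend name
module = torch if device_type == "cpu" else getattr(torch, device_type)
getattr(module, "FloatStorage")  # torch.foo.FloatStorage, if the backend installs
                                 # a torch.foo module defining the storage classes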
@@ -1101,7 +1107,7 @@ class _LegacyStorageMeta(type):
     def __instancecheck__(cls, instance):
         if type(instance) == TypedStorage:
-            cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+            cls_device = _get_device_from_module(cls.__module__)
             return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
         return False
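
With the metaclass change, `isinstance` on a `TypedStorage` matches a backend's legacy storage class only when both device type and dtype agree. An illustrative check, assuming the hypothetical "foo" backend from above:

s = torch.empty(2, dtype=torch.float64, device="foo").storage()
isinstance(s, torch.foo.DoubleStorage)   # True: device "foo" and dtype float64 match
isinstance(s, torch.foo.FloatStorage)    # False: dtype differs
isinstance(s, torch.cuda.DoubleStorage)  # False: device differs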