add storage dtype for custom device (#102481)

Fixes #ISSUE_NUMBER
1、add `isinstance` check with dtyped storage for custom device
2、add `storage.type()` support for custom device
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102481
Approved by: https://github.com/albanD
This commit is contained in:
shibo19
2023-06-01 12:46:19 +00:00
committed by PyTorch MergeBot
parent e59db08699
commit d9c8f9a00d
2 changed files with 48 additions and 8 deletions

View File

@ -340,6 +340,39 @@ class TestCppExtensionOpenRgistration(common.TestCase):
with self.assertRaisesRegex(RuntimeError, 'Overflow'):
foo_storage.resize_(8**29)
def test_open_device_storage_type():
torch.utils.rename_privateuse1_backend('foo')
# test cpu float storage
cpu_tensor = torch.randn([8]).float()
cpu_storage = cpu_tensor.storage()
self.assertEqual(cpu_storage.type(), "torch.FloatStorage")
# test custom float storage before defining FloatStorage
foo_tensor = cpu_tensor.foo()
foo_storage = foo_tensor.storage()
self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")
class CustomFloatStorage():
@property
def __module__(self):
return "torch." + torch._C._get_privateuse1_backend_name()
@property
def __name__(self):
return "FloatStorage"
# test custom float storage after defining FloatStorage
try:
torch.foo.FloatStorage = CustomFloatStorage()
self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")
# test custom int storage after defining FloatStorage
foo_tensor2 = torch.randn([8]).int().foo()
foo_storage2 = foo_tensor2.storage()
self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
finally:
torch.foo.FloatStorage = None
test_base_device_registration()
test_before_common_registration()
test_common_registration()
@ -352,6 +385,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
test_open_device_storage_pin_memory()
test_open_device_serialization()
test_open_device_storage_resize()
test_open_device_storage_type()
if __name__ == "__main__":

View File

@ -238,8 +238,8 @@ class _StorageBase:
Returns: self
"""
from torch.multiprocessing import get_sharing_strategy
if self.is_cuda:
pass # CUDA doesn't use POSIX shared memory
if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
pass # CUDA or PrivateUse1 doesn't use POSIX shared memory
elif get_sharing_strategy() == 'file_system':
self._share_filename_cpu_()
else:
@ -251,7 +251,7 @@ class _StorageBase:
"""Creates a new storage in shared memory with the same data type"""
from torch.multiprocessing import get_sharing_strategy
device = torch.device(device)
if device.type == 'cuda':
if device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
return cls(size, device=device)
elif get_sharing_strategy() == 'file_system':
return cls._new_using_filename_cpu(size)
@ -427,6 +427,12 @@ def _warn_typed_storage_removal(stacklevel=2):
def _reset_warn_typed_storage_removal():
_warn_typed_storage_removal.__dict__['has_warned'] = False
def _get_device_from_module(module: str):
if module.split(".")[-1] in ["cuda", torch._C._get_privateuse1_backend_name()]:
return module.split(".")[-1]
else:
return "cpu"
class TypedStorage:
is_sparse = False
@ -484,7 +490,7 @@ class TypedStorage:
return TypedStorage(
*args,
dtype=cls._dtype,
device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
device=_get_device_from_module(cls.__module__),
_internal=True)
else:
@ -499,7 +505,7 @@ class TypedStorage:
arg_error_msg +
f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
cls_device = _get_device_from_module(cls.__module__)
if wrap_storage.device.type != cls_device:
raise RuntimeError(
@ -1083,10 +1089,10 @@ class TypedStorage:
storage_name = _dtype_to_storage_type_map()[self.dtype]
if self.device.type not in ['cpu', 'cuda']:
if self.device.type not in ['cpu', 'cuda', torch._C._get_privateuse1_backend_name()]:
return None
module = torch if self.device.type == 'cpu' else torch.cuda
module = torch if self.device.type == 'cpu' else getattr(torch, self.device.type)
try:
return getattr(module, storage_name)
@ -1101,7 +1107,7 @@ class _LegacyStorageMeta(type):
def __instancecheck__(cls, instance):
if type(instance) == TypedStorage:
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
cls_device = _get_device_from_module(cls.__module__)
return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
return False