Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 05:30:26 +08:00
add storage dtype for custom device (#102481)
Fixes #ISSUE_NUMBER

1. Add an `isinstance` check against dtyped storage classes for custom devices.
2. Add `storage.type()` support for custom devices.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102481
Approved by: https://github.com/albanD
Committed by: PyTorch MergeBot
Parent: e59db08699
Commit: d9c8f9a00d
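
For context, a minimal sketch of the behaviour this commit enables from the Python side. It is not part of the diff: it assumes a C++ extension has already registered a PrivateUse1 device, that the backend has been renamed to 'foo' as in the new test, and that the extension exposes a torch.foo module whose FloatStorage is a torch.storage._LegacyStorage subclass.

    # Sketch only, not part of this commit; see assumptions above.
    import torch

    torch.utils.rename_privateuse1_backend('foo')

    foo_tensor = torch.randn([8]).float().foo()   # .foo() is assumed to be provided by the extension
    foo_storage = foo_tensor.storage()

    # New in this commit: storage.type() resolves custom-device storages via
    # torch.<backend name> instead of only torch / torch.cuda ...
    assert foo_storage.type() == "torch.foo.FloatStorage"

    # ... and isinstance() against a backend-defined dtyped storage class works,
    # since _LegacyStorageMeta.__instancecheck__ now derives the device from the
    # class's module (cpu, cuda, or the renamed backend).
    assert isinstance(foo_storage, torch.foo.FloatStorage)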
@@ -340,6 +340,39 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             with self.assertRaisesRegex(RuntimeError, 'Overflow'):
                 foo_storage.resize_(8**29)

+        def test_open_device_storage_type():
+            torch.utils.rename_privateuse1_backend('foo')
+            # test cpu float storage
+            cpu_tensor = torch.randn([8]).float()
+            cpu_storage = cpu_tensor.storage()
+            self.assertEqual(cpu_storage.type(), "torch.FloatStorage")
+
+            # test custom float storage before defining FloatStorage
+            foo_tensor = cpu_tensor.foo()
+            foo_storage = foo_tensor.storage()
+            self.assertEqual(foo_storage.type(), "torch.storage.TypedStorage")
+
+            class CustomFloatStorage():
+                @property
+                def __module__(self):
+                    return "torch." + torch._C._get_privateuse1_backend_name()
+
+                @property
+                def __name__(self):
+                    return "FloatStorage"
+
+            # test custom float storage after defining FloatStorage
+            try:
+                torch.foo.FloatStorage = CustomFloatStorage()
+                self.assertEqual(foo_storage.type(), "torch.foo.FloatStorage")
+
+                # test custom int storage after defining FloatStorage
+                foo_tensor2 = torch.randn([8]).int().foo()
+                foo_storage2 = foo_tensor2.storage()
+                self.assertEqual(foo_storage2.type(), "torch.storage.TypedStorage")
+            finally:
+                torch.foo.FloatStorage = None
+
         test_base_device_registration()
         test_before_common_registration()
         test_common_registration()
@@ -352,6 +385,7 @@ class TestCppExtensionOpenRgistration(common.TestCase):
         test_open_device_storage_pin_memory()
         test_open_device_serialization()
         test_open_device_storage_resize()
+        test_open_device_storage_type()


 if __name__ == "__main__":
@@ -238,8 +238,8 @@ class _StorageBase:
         Returns: self
         """
         from torch.multiprocessing import get_sharing_strategy
-        if self.is_cuda:
-            pass  # CUDA doesn't use POSIX shared memory
+        if self.device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
+            pass  # CUDA or PrivateUse1 doesn't use POSIX shared memory
         elif get_sharing_strategy() == 'file_system':
             self._share_filename_cpu_()
         else:
@@ -251,7 +251,7 @@ class _StorageBase:
         """Creates a new storage in shared memory with the same data type"""
         from torch.multiprocessing import get_sharing_strategy
         device = torch.device(device)
-        if device.type == 'cuda':
+        if device.type in ["cuda", torch._C._get_privateuse1_backend_name()]:
             return cls(size, device=device)
         elif get_sharing_strategy() == 'file_system':
             return cls._new_using_filename_cpu(size)
@@ -427,6 +427,12 @@ def _warn_typed_storage_removal(stacklevel=2):
 def _reset_warn_typed_storage_removal():
     _warn_typed_storage_removal.__dict__['has_warned'] = False

+def _get_device_from_module(module: str):
+    if module.split(".")[-1] in ["cuda", torch._C._get_privateuse1_backend_name()]:
+        return module.split(".")[-1]
+    else:
+        return "cpu"
+
 class TypedStorage:
     is_sparse = False

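The new helper above is what the remaining hunks switch to. A quick behaviour sketch, assuming it lands in torch/storage.py alongside the surrounding code and that the PrivateUse1 backend has been renamed to 'foo':

    from torch.storage import _get_device_from_module

    # With torch._C._get_privateuse1_backend_name() returning 'foo':
    assert _get_device_from_module('torch') == 'cpu'        # plain torch -> CPU storages
    assert _get_device_from_module('torch.cuda') == 'cuda'  # torch.cuda -> CUDA storages
    assert _get_device_from_module('torch.foo') == 'foo'    # renamed PrivateUse1 backend
    # Any other module name falls back to 'cpu'.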
@@ -484,7 +490,7 @@ class TypedStorage:
             return TypedStorage(
                 *args,
                 dtype=cls._dtype,
-                device='cuda' if cls.__module__ == 'torch.cuda' else 'cpu',
+                device=_get_device_from_module(cls.__module__),
                 _internal=True)

         else:
@@ -499,7 +505,7 @@ class TypedStorage:
                     arg_error_msg +
                     f"\nArgument 'wrap_storage' must be UntypedStorage, but got {type(wrap_storage)}")

-            cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+            cls_device = _get_device_from_module(cls.__module__)

             if wrap_storage.device.type != cls_device:
                 raise RuntimeError(
@@ -1083,10 +1089,10 @@ class TypedStorage:

         storage_name = _dtype_to_storage_type_map()[self.dtype]

-        if self.device.type not in ['cpu', 'cuda']:
+        if self.device.type not in ['cpu', 'cuda', torch._C._get_privateuse1_backend_name()]:
             return None

-        module = torch if self.device.type == 'cpu' else torch.cuda
+        module = torch if self.device.type == 'cpu' else getattr(torch, self.device.type)

         try:
             return getattr(module, storage_name)
@@ -1101,7 +1107,7 @@ class _LegacyStorageMeta(type):

     def __instancecheck__(cls, instance):
         if type(instance) == TypedStorage:
-            cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
+            cls_device = _get_device_from_module(cls.__module__)
             return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
         return False