mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Summary: Fixes https://github.com/pytorch/pytorch/issues/66228 cc ezyang bhosmer smessmer ljk53 bdhirsh Pull Request resolved: https://github.com/pytorch/pytorch/pull/66970 Reviewed By: bdhirsh Differential Revision: D33245612 Pulled By: ezyang fbshipit-source-id: 4c61c2cb029e2b94b0e68927c377d3e1c358dd7c (cherry picked from commit d29fcdfb4bc2cc17b1795d4349e4b56fa0d1cf12)
854 lines
30 KiB
Python
854 lines
30 KiB
Python
import io
|
|
|
|
import torch
|
|
from ._utils import _type, _cuda
|
|
from torch.types import Storage
|
|
from typing import Any, TypeVar, Type, Union, cast
|
|
import copy
|
|
import collections
|
|
from functools import lru_cache
|
|
try:
|
|
import numpy as np
|
|
HAS_NUMPY = True
|
|
except ModuleNotFoundError:
|
|
np = None # type: ignore[assignment]
|
|
|
|
T = TypeVar('T', bound='Union[_StorageBase, _TypedStorage]')
|
|
class _StorageBase(object):
|
|
_cdata: Any
|
|
is_cuda: bool = False
|
|
is_sparse: bool = False
|
|
is_sparse_csr: bool = False
|
|
device: torch.device
|
|
|
|
def __init__(self, *args, **kwargs): ... # noqa: E704
|
|
def __len__(self) -> int: ... # noqa: E704
|
|
def __getitem__(self, idx): ... # noqa: E704
|
|
def copy_(self, source: T) -> T: ... # noqa: E704
|
|
def nbytes(self) -> int: ... # noqa: E704
|
|
|
|
def size(self) -> int:
|
|
return self.nbytes()
|
|
|
|
def type(self, dtype: str = None, non_blocking: bool = False) -> T: ... # noqa: E704
|
|
def cuda(self, device=None, non_blocking=False, **kwargs) -> T: ... # noqa: E704
|
|
def element_size(self) -> int: ... # noqa: E704
|
|
def get_device(self) -> int: ... # noqa: E704
|
|
def data_ptr(self) -> int: ... # noqa: E704
|
|
|
|
# Defined in torch/csrc/generic/StorageSharing.cpp
|
|
def _share_filename_(self): ... # noqa: E704
|
|
def _share_fd_(self): ... # noqa: E704
|
|
@classmethod
|
|
def _new_using_filename(cls: Type[T], size: int) -> T: ... # noqa: E704
|
|
@classmethod
|
|
def _new_using_fd(cls: Type[T], size: int) -> T: ... # noqa: E704
|
|
@classmethod
|
|
def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
|
|
@classmethod
|
|
def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
|
|
@classmethod
|
|
def _release_ipc_counter(cls, *args, **kwargs) -> T: ... # noqa: E704
|
|
@classmethod
|
|
def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
|
|
def _shared_decref(self) -> T: ... # noqa: E704
|
|
|
|
def __str__(self):
|
|
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
|
|
return content + f'\n[{torch.typename(self)} of size {len(self)}]'
|
|
|
|
def __repr__(self):
|
|
return str(self)
|
|
|
|
def __iter__(self):
|
|
return iter(map(lambda i: self[i], range(self.size())))
|
|
|
|
def __copy__(self):
|
|
return self.clone()
|
|
|
|
def __deepcopy__(self, memo):
|
|
memo = memo.setdefault('torch', {})
|
|
if self._cdata in memo:
|
|
return memo[self._cdata]
|
|
new_storage = self.clone()
|
|
memo[self._cdata] = new_storage
|
|
return new_storage
|
|
|
|
def __reduce__(self):
|
|
b = io.BytesIO()
|
|
torch.save(self, b, _use_new_zipfile_serialization=False)
|
|
return (_load_from_bytes, (b.getvalue(),))
|
|
|
|
def __sizeof__(self):
|
|
return super(_StorageBase, self).__sizeof__() + self.size()
|
|
|
|
def clone(self):
|
|
"""Returns a copy of this storage"""
|
|
device = self.get_device() if self.is_cuda else -1
|
|
with torch.cuda.device(device):
|
|
return type(self)(self.nbytes()).copy_(self)
|
|
|
|
def tolist(self):
|
|
"""Returns a list containing the elements of this storage"""
|
|
return list(self)
|
|
|
|
def cpu(self):
|
|
"""Returns a CPU copy of this storage if it's not already on the CPU"""
|
|
return _type(self, getattr(torch, self.__class__.__name__))
|
|
|
|
def _to(self, dtype):
|
|
if not isinstance(dtype, torch.dtype):
|
|
raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
|
|
storage = torch.tensor([], dtype=torch.uint8, device=self.device).set_(cast(Storage, self)).to(dtype).storage()
|
|
if storage.data_ptr() == self.data_ptr():
|
|
storage = storage.clone()
|
|
return storage
|
|
|
|
def double(self):
|
|
"""Casts this storage to double type"""
|
|
return self._to(torch.double)
|
|
|
|
def float(self):
|
|
"""Casts this storage to float type"""
|
|
return self._to(torch.float)
|
|
|
|
def half(self):
|
|
"""Casts this storage to half type"""
|
|
return self._to(torch.half)
|
|
|
|
def long(self):
|
|
"""Casts this storage to long type"""
|
|
return self._to(torch.long)
|
|
|
|
def int(self):
|
|
"""Casts this storage to int type"""
|
|
return self._to(torch.int)
|
|
|
|
def short(self):
|
|
"""Casts this storage to short type"""
|
|
return self._to(torch.short)
|
|
|
|
def char(self):
|
|
"""Casts this storage to char type"""
|
|
return self._to(torch.int8)
|
|
|
|
def byte(self):
|
|
"""Casts this storage to byte type"""
|
|
return self._to(torch.uint8)
|
|
|
|
def bool(self):
|
|
"""Casts this storage to bool type"""
|
|
return self._to(torch.bool)
|
|
|
|
def bfloat16(self):
|
|
"""Casts this storage to bfloat16 type"""
|
|
return self._to(torch.bfloat16)
|
|
|
|
def complex_double(self):
|
|
"""Casts this storage to complex double type"""
|
|
return self._to(torch.cdouble)
|
|
|
|
def complex_float(self):
|
|
"""Casts this storage to complex float type"""
|
|
return self._to(torch.cfloat)
|
|
|
|
def pin_memory(self):
|
|
"""Copies the storage to pinned memory, if it's not already pinned."""
|
|
if self.is_cuda:
|
|
raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
|
|
import torch.cuda
|
|
allocator = torch.cuda._host_allocator() # type: ignore[attr-defined]
|
|
return type(self)(self.size(), allocator=allocator).copy_(self)
|
|
|
|
def share_memory_(self):
|
|
"""Moves the storage to shared memory.
|
|
|
|
This is a no-op for storages already in shared memory and for CUDA
|
|
storages, which do not need to be moved for sharing across processes.
|
|
Storages in shared memory cannot be resized.
|
|
|
|
Returns: self
|
|
"""
|
|
from torch.multiprocessing import get_sharing_strategy
|
|
if self.is_cuda:
|
|
pass # CUDA doesn't use POSIX shared memory
|
|
elif get_sharing_strategy() == 'file_system':
|
|
self._share_filename_()
|
|
else:
|
|
self._share_fd_()
|
|
return self
|
|
|
|
@classmethod
|
|
def _new_shared(cls, size):
|
|
"""Creates a new storage in shared memory with the same data type"""
|
|
from torch.multiprocessing import get_sharing_strategy
|
|
if cls.is_cuda:
|
|
return cls(size)
|
|
elif get_sharing_strategy() == 'file_system':
|
|
return cls._new_using_filename(size)
|
|
else:
|
|
return cls._new_using_fd(size)
|
|
|
|
def _untyped(self):
|
|
return self
|
|
|
|
|
|
def _load_from_bytes(b):
|
|
return torch.load(io.BytesIO(b))
|
|
|
|
|
|
_StorageBase.type = _type # type: ignore[assignment]
|
|
_StorageBase.cuda = _cuda # type: ignore[assignment]
|
|
|
|
|
|
@lru_cache(maxsize=None)
|
|
def _dtype_to_storage_type_map():
|
|
# NOTE: We should no longer add dtypes to this map. This map
|
|
# is only used for BC/FC with older PyTorch versions. Going forward,
|
|
# new dtypes of _TypedStorage should not translate to a legacy
|
|
# <type>Storage class. Instead, new dtypes of _TypedStorage should
|
|
# be serialized as an _UntypedStorage paired with a torch.dtype
|
|
return {
|
|
torch.double: 'DoubleStorage',
|
|
torch.float: 'FloatStorage',
|
|
torch.half: 'HalfStorage',
|
|
torch.long: 'LongStorage',
|
|
torch.int: 'IntStorage',
|
|
torch.int16: 'ShortStorage',
|
|
torch.int8: 'CharStorage',
|
|
torch.uint8: 'ByteStorage',
|
|
torch.bool: 'BoolStorage',
|
|
torch.bfloat16: 'BFloat16Storage',
|
|
torch.cdouble: 'ComplexDoubleStorage',
|
|
torch.cfloat: 'ComplexFloatStorage',
|
|
torch.qint8: 'QInt8Storage',
|
|
torch.qint32: 'QInt32Storage',
|
|
torch.quint8: 'QUInt8Storage',
|
|
torch.quint4x2: 'QUInt4x2Storage',
|
|
torch.quint2x4: 'QUInt2x4Storage',
|
|
}
|
|
|
|
@lru_cache(maxsize=None)
|
|
def _storage_type_to_dtype_map():
|
|
dtype_map = {
|
|
val: key for key, val in _dtype_to_storage_type_map().items()}
|
|
return dtype_map
|
|
|
|
def _get_storage_from_sequence(sequence, dtype, device):
|
|
if dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
|
interpret_dtypes = {
|
|
torch.quint8: torch.uint8,
|
|
torch.quint4x2: torch.uint8,
|
|
torch.quint2x4: torch.uint8,
|
|
torch.qint32: torch.int32,
|
|
torch.qint8: torch.int8
|
|
}
|
|
tmp_tensor = torch.tensor(
|
|
sequence,
|
|
dtype=interpret_dtypes[dtype],
|
|
device=device)
|
|
|
|
else:
|
|
tmp_tensor = torch.tensor(
|
|
sequence,
|
|
dtype=dtype,
|
|
device=device)
|
|
|
|
return tmp_tensor.storage()._untyped()
|
|
|
|
def _isint(x):
|
|
if HAS_NUMPY:
|
|
return isinstance(x, (int, np.integer))
|
|
else:
|
|
return isinstance(x, int)
|
|
|
|
class _TypedStorage:
|
|
is_sparse = False
|
|
|
|
dtype: torch.dtype
|
|
|
|
def fill_(self, value):
|
|
self[0:len(self)] = value
|
|
return self
|
|
|
|
def __new__(cls, *args, wrap_storage=None, dtype=None, device=None):
|
|
if cls == torch.storage._LegacyStorage:
|
|
raise RuntimeError("Only child classes of _LegacyStorage can be instantiated")
|
|
|
|
if cls == _TypedStorage:
|
|
return super().__new__(cls)
|
|
|
|
else:
|
|
arg_error_msg = (
|
|
f'{cls}.__new__ received an invalid combination '
|
|
f'of arguments. Expected one of:\n'
|
|
' * no arguments\n'
|
|
' * (int size)\n'
|
|
' * (Sequence data)\n'
|
|
' * (*, _UntypedStorage wrap_storage)')
|
|
|
|
if device is not None:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nKeyword argument 'device' cannot be specified")
|
|
|
|
if dtype is not None:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nKeyword argument 'dtype' cannot be specified")
|
|
|
|
if wrap_storage is None:
|
|
if len(args) > 1:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nToo many positional arguments")
|
|
|
|
if len(args) == 1 and not _isint(args[0]) and not isinstance(args[0], collections.abc.Sequence):
|
|
raise TypeError(
|
|
arg_error_msg +
|
|
f"\nArgument type not recognized: {type(args[0])}")
|
|
|
|
return _TypedStorage(
|
|
*args,
|
|
dtype=cls.dtype,
|
|
device='cuda' if eval(cls.__module__) is torch.cuda else 'cpu')
|
|
|
|
else:
|
|
if len(args) != 0:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nNo positional arguments should be given when using "
|
|
"'wrap_storage'")
|
|
|
|
if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
|
|
raise TypeError(
|
|
arg_error_msg +
|
|
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
|
|
|
|
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
|
|
|
|
if wrap_storage.device.type != cls_device:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
f"\nDevice of 'wrap_storage' must be {cls_device}"
|
|
f", but got {wrap_storage.device.type}")
|
|
|
|
return _TypedStorage(
|
|
*args,
|
|
wrap_storage=wrap_storage,
|
|
dtype=cls.dtype)
|
|
|
|
def __init__(self, *args, device=None, dtype=None, wrap_storage=None):
|
|
arg_error_msg = (
|
|
'_TypedStorage.__init__ received an invalid combination '
|
|
'of arguments. Expected one of:\n'
|
|
' * (*, torch.device device, torch.dtype dtype)\n'
|
|
' * (int size, *, torch.device device, torch.dtype dtype)\n'
|
|
' * (Sequence data, *, torch.device device, torch.dtype dtype)\n'
|
|
' * (*, _UntypedStorage wrap_storage, torch.dtype dtype)')
|
|
|
|
if wrap_storage is not None:
|
|
if len(args) != 0:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nNo positional arguments should be given when using "
|
|
"'wrap_storage'")
|
|
|
|
if dtype is None:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nArgument 'dtype' must be specified")
|
|
|
|
if not isinstance(dtype, torch.dtype):
|
|
raise TypeError(
|
|
arg_error_msg +
|
|
f"\nArgument 'dtype' must be torch.dtype, not {type(dtype)}")
|
|
|
|
if device is not None:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nArgument 'device' should not be specified when 'wrap_storage' is given")
|
|
|
|
self.dtype = dtype
|
|
|
|
if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
|
|
raise TypeError(
|
|
arg_error_msg +
|
|
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
|
|
|
|
self._storage = wrap_storage
|
|
|
|
else:
|
|
self.dtype = torch.get_default_dtype() if dtype is None else dtype
|
|
device = torch.device('cpu' if device is None else device)
|
|
|
|
if device.type == 'cpu':
|
|
untyped_storage_class = torch._UntypedStorage
|
|
elif device.type == 'cuda':
|
|
untyped_storage_class = torch.cuda._UntypedStorage
|
|
else:
|
|
raise RuntimeError(f"Storage device not recognized: {device}")
|
|
|
|
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
|
if device.type == 'cuda':
|
|
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
|
|
|
|
if len(args) == 0:
|
|
self._storage = untyped_storage_class()
|
|
|
|
elif len(args) == 1:
|
|
if _isint(args[0]):
|
|
self._storage = untyped_storage_class(int(args[0]) * self.element_size())
|
|
elif isinstance(args[0], collections.abc.Sequence):
|
|
self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
|
|
else:
|
|
raise TypeError(
|
|
arg_error_msg +
|
|
f"\nArgument type not recognized: {type(args[0])}")
|
|
|
|
else:
|
|
raise RuntimeError(
|
|
arg_error_msg +
|
|
"\nToo many positional arguments")
|
|
|
|
|
|
@property
|
|
def is_cuda(self):
|
|
return self._storage.device.type == 'cuda'
|
|
|
|
def _untyped(self):
|
|
return self._storage
|
|
|
|
def _new_wrapped_storage(self, untyped_storage):
|
|
module = eval(untyped_storage.__module__)
|
|
assert type(untyped_storage) == module._UntypedStorage
|
|
|
|
if type(self) == _TypedStorage:
|
|
return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
|
|
else:
|
|
# NOTE: We need to use the module of untyped_storage in case self's
|
|
# module is different, e.g. if self is on CPU and untyped_storage
|
|
# is on CUDA, and vice versa
|
|
return getattr(module, type(self).__name__)(wrap_storage=untyped_storage)
|
|
|
|
def __len__(self):
|
|
return self._storage.nbytes() // self.element_size()
|
|
|
|
def _maybe_wrap_index(self, idx, is_stop=False):
|
|
if idx is None:
|
|
if is_stop:
|
|
return self.size()
|
|
else:
|
|
return 0
|
|
|
|
else:
|
|
if type(idx) != int:
|
|
raise TypeError(
|
|
f"can't index a {type(self)} with {type(idx)}")
|
|
if is_stop:
|
|
if (idx > self.size()) or (idx < -self.size()):
|
|
raise IndexError(
|
|
f'index {idx} out of range for storage of size {self.size()}')
|
|
if idx > 0:
|
|
return idx
|
|
else:
|
|
return idx % self.size()
|
|
else:
|
|
if (idx >= self.size()) or (idx < -self.size()):
|
|
raise IndexError(
|
|
f'index {idx} out of range for storage of size {self.size()}')
|
|
return idx % self.size()
|
|
|
|
def __setitem__(self, idx, value):
|
|
if not isinstance(idx, (int, slice)):
|
|
raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
|
|
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
|
interpret_dtypes = {
|
|
torch.quint8: torch.uint8,
|
|
torch.quint4x2: torch.uint8,
|
|
torch.quint2x4: torch.uint8,
|
|
torch.qint32: torch.int32,
|
|
torch.qint8: torch.int8
|
|
}
|
|
tmp_dtype = interpret_dtypes[self.dtype]
|
|
tmp_tensor = torch.tensor([], dtype=tmp_dtype, device=self.device).set_(_TypedStorage(
|
|
wrap_storage=self._storage,
|
|
dtype=tmp_dtype))
|
|
else:
|
|
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
|
|
|
|
tmp_tensor[idx] = value
|
|
|
|
def __getitem__(self, idx):
|
|
# NOTE: Before _TypedStorage existed, indexing with a slice used to be
|
|
# possible for <type>Storage objects. However, it would return
|
|
# a storage view, which would be a hassle to implement in _TypedStorage,
|
|
# so it was disabled
|
|
if isinstance(idx, slice):
|
|
raise RuntimeError('slices are only supported in _UntypedStorage.__getitem__')
|
|
elif not isinstance(idx, int):
|
|
raise RuntimeError(f"can't index a {type(self)} with {type(idx)}")
|
|
|
|
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
|
interpret_dtypes = {
|
|
torch.quint8: torch.uint8,
|
|
torch.quint4x2: torch.uint8,
|
|
torch.quint2x4: torch.uint8,
|
|
torch.qint32: torch.int32,
|
|
torch.qint8: torch.int8
|
|
}
|
|
return _TypedStorage(
|
|
wrap_storage=self._storage,
|
|
dtype=interpret_dtypes[self.dtype])[idx]
|
|
|
|
idx_wrapped = self._maybe_wrap_index(idx)
|
|
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
|
|
return tmp_tensor[idx_wrapped].item()
|
|
|
|
def copy_(self, source: T, non_blocking=None):
|
|
self._storage.copy_(source._untyped(), non_blocking)
|
|
return self
|
|
|
|
def nbytes(self):
|
|
return self._storage.nbytes()
|
|
|
|
def type(self, dtype: str = None, non_blocking: bool = False) -> Union[T, str]:
|
|
if dtype is None:
|
|
legacy_class = self._get_legacy_storage_class()
|
|
|
|
if legacy_class is not None:
|
|
return legacy_class.__module__ + '.' + legacy_class.__name__
|
|
|
|
return '.'.join([self.__module__, type(self).__name__])
|
|
|
|
else:
|
|
return self._storage.type(dtype, non_blocking)
|
|
|
|
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
|
|
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
|
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
|
|
cuda_storage = self._storage.cuda(device, non_blocking, **kwargs)
|
|
return self._new_wrapped_storage(cuda_storage)
|
|
|
|
def element_size(self):
|
|
return torch._utils._element_size(self.dtype)
|
|
|
|
def get_device(self) -> int:
|
|
return self._storage.get_device()
|
|
|
|
def __str__(self):
|
|
data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
|
|
return data_str + (
|
|
f'\n[{torch.typename(self)}(dtype={self.dtype}, '
|
|
f'device={self.device}) of size {len(self)}]')
|
|
|
|
def __repr__(self):
|
|
return str(self)
|
|
|
|
def __iter__(self):
|
|
return iter(map(lambda i: self[i], range(self.size())))
|
|
|
|
def __copy__(self):
|
|
return self._new_wrapped_storage(copy.copy(self._storage))
|
|
|
|
def __deepcopy__(self, memo):
|
|
return self._new_wrapped_storage(copy.deepcopy(self._storage, memo))
|
|
|
|
def __sizeof__(self):
|
|
return super(_TypedStorage, self).__sizeof__() + self.nbytes()
|
|
|
|
def clone(self):
|
|
"""Returns a copy of this storage"""
|
|
return self._new_wrapped_storage(self._storage.clone())
|
|
|
|
def tolist(self):
|
|
"""Returns a list containing the elements of this storage"""
|
|
return list(self)
|
|
|
|
def cpu(self):
|
|
"""Returns a CPU copy of this storage if it's not already on the CPU"""
|
|
return self._new_wrapped_storage(self._storage.cpu())
|
|
|
|
def pin_memory(self):
|
|
"""Coppies the storage to pinned memory, if it's not already pinned."""
|
|
return self._new_wrapped_storage(self._storage.pin_memory())
|
|
|
|
def share_memory_(self):
|
|
"""Moves the storage to shared memory.
|
|
|
|
This is a no-op for storages already in shared memory and for CUDA
|
|
storages, which do not need to be moved for sharing across processes.
|
|
Storages in shared memory cannot be resized.
|
|
|
|
Returns: self
|
|
"""
|
|
self._storage.share_memory_()
|
|
return self
|
|
|
|
def _new_shared(self, size):
|
|
"""Creates a new storage in shared memory with the same data type"""
|
|
if self.is_cuda:
|
|
untyped_cls = torch.cuda._UntypedStorage
|
|
else:
|
|
untyped_cls = torch._UntypedStorage
|
|
untyped_storage = untyped_cls._new_shared(size * self.element_size())
|
|
return _TypedStorage(
|
|
wrap_storage=untyped_storage,
|
|
dtype=self.dtype)
|
|
|
|
@property
|
|
def _cdata(self):
|
|
return self._storage._cdata
|
|
|
|
@property
|
|
def device(self):
|
|
return self._storage.device
|
|
|
|
def size(self):
|
|
return len(self)
|
|
|
|
def pickle_storage_type(self):
|
|
try:
|
|
return _dtype_to_storage_type_map()[self.dtype]
|
|
except KeyError:
|
|
raise KeyError(f'dtype {self.dtype} is not recognized')
|
|
|
|
def __reduce__(self):
|
|
b = io.BytesIO()
|
|
torch.save(self, b, _use_new_zipfile_serialization=False)
|
|
return (_load_from_bytes, (b.getvalue(),))
|
|
|
|
def data_ptr(self):
|
|
return self._storage.data_ptr()
|
|
|
|
def resize_(self, size):
|
|
self._storage.resize_(size * self.element_size())
|
|
|
|
@classmethod
|
|
def _free_weak_ref(cls, *args, **kwargs):
|
|
return eval(cls.__module__)._UntypedStorage._free_weak_ref(*args, **kwargs)
|
|
|
|
def _weak_ref(self, *args, **kwargs):
|
|
return self._storage._weak_ref(*args, **kwargs)
|
|
|
|
@classmethod
|
|
def from_buffer(cls, *args, dtype=None, device=None, **kwargs):
|
|
if cls == _TypedStorage:
|
|
dtype = torch.get_default_dtype() if dtype is None else dtype
|
|
device = torch.device('cpu' if device is None else device)
|
|
|
|
if device.type == 'cpu':
|
|
untyped_cls = torch._UntypedStorage
|
|
elif device.type == 'cuda':
|
|
untyped_cls = torch.cuda._UntypedStorage
|
|
else:
|
|
raise RuntimeError(
|
|
f"_TypedStorage.from_buffer: device '{device}' not recognized")
|
|
untyped_storage: Union[torch._UntypedStorage, torch.cuda._UntypedStorage]
|
|
untyped_storage = untyped_cls.from_buffer(*args, dtype=dtype, **kwargs)
|
|
|
|
else:
|
|
if dtype is not None or len(args) == 5:
|
|
raise RuntimeError((
|
|
"from_buffer: 'dtype' can only be specified in "
|
|
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
|
|
if device is not None:
|
|
raise RuntimeError((
|
|
"from_buffer: 'device' can only be specified in "
|
|
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
|
|
|
|
dtype = cls.dtype
|
|
untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
|
|
|
return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
|
|
|
|
def _to(self, dtype):
|
|
if not isinstance(dtype, torch.dtype):
|
|
raise TypeError(f"Argument 'dtype' must be torch.dtype, not {type(dtype)}")
|
|
storage = torch.tensor([], dtype=self.dtype, device=self.device).set_(self).to(dtype).storage()
|
|
if storage.data_ptr() == self.data_ptr():
|
|
storage = storage.clone()
|
|
return storage
|
|
|
|
def double(self):
|
|
"""Casts this storage to double type"""
|
|
return self._to(torch.double)
|
|
|
|
def float(self):
|
|
"""Casts this storage to float type"""
|
|
return self._to(torch.float)
|
|
|
|
def half(self):
|
|
"""Casts this storage to half type"""
|
|
return self._to(torch.half)
|
|
|
|
def long(self):
|
|
"""Casts this storage to long type"""
|
|
return self._to(torch.long)
|
|
|
|
def int(self):
|
|
"""Casts this storage to int type"""
|
|
return self._to(torch.int)
|
|
|
|
def short(self):
|
|
"""Casts this storage to short type"""
|
|
return self._to(torch.short)
|
|
|
|
def char(self):
|
|
"""Casts this storage to char type"""
|
|
return self._to(torch.int8)
|
|
|
|
def byte(self):
|
|
"""Casts this storage to byte type"""
|
|
return self._to(torch.uint8)
|
|
|
|
def bool(self):
|
|
"""Casts this storage to bool type"""
|
|
return self._to(torch.bool)
|
|
|
|
def bfloat16(self):
|
|
"""Casts this storage to bfloat16 type"""
|
|
return self._to(torch.bfloat16)
|
|
|
|
def complex_double(self):
|
|
"""Casts this storage to complex double type"""
|
|
return self._to(torch.cdouble)
|
|
|
|
def complex_float(self):
|
|
"""Casts this storage to complex float type"""
|
|
return self._to(torch.cfloat)
|
|
|
|
@classmethod
|
|
def from_file(cls, filename, shared, size):
|
|
"""
|
|
from_file(filename, shared=False, size=0) -> Storage
|
|
|
|
If `shared` is `True`, then memory is shared between all processes.
|
|
All changes are written to the file. If `shared` is `False`, then the changes on
|
|
the storage do not affect the file.
|
|
|
|
`size` is the number of elements in the storage. If `shared` is `False`,
|
|
then the file must contain at least `size * sizeof(Type)` bytes
|
|
(`Type` is the type of storage). If `shared` is `True` the file will be
|
|
created if needed.
|
|
|
|
Args:
|
|
filename (str): file name to map
|
|
shared (bool): whether to share memory
|
|
size (int): number of elements in the storage
|
|
"""
|
|
if cls == _TypedStorage:
|
|
raise RuntimeError('from_file can only be called on derived classes')
|
|
untyped_storage = eval(cls.__module__)._UntypedStorage.from_file(
|
|
filename,
|
|
shared,
|
|
size * torch._utils._element_size(cls.dtype))
|
|
storage = cls(wrap_storage=untyped_storage)
|
|
return storage
|
|
|
|
@classmethod
|
|
def _expired(cls, *args, **kwargs):
|
|
return eval(cls.__module__)._UntypedStorage._expired(*args, **kwargs)
|
|
|
|
def is_pinned(self):
|
|
return self._storage.is_pinned()
|
|
|
|
def _write_file(self, *args, **kwargs):
|
|
return self._storage._write_file(*args, **kwargs)
|
|
|
|
def _set_from_file(self, *args, **kwargs):
|
|
return self._storage._set_from_file(*args, **kwargs)
|
|
|
|
def _set_cdata(self, *args, **kwargs):
|
|
return self._storage._set_cdata(*args, **kwargs)
|
|
|
|
def _share_cuda_(self, *args, **kwargs):
|
|
return self._storage._share_cuda_(*args, **kwargs)
|
|
|
|
def is_shared(self):
|
|
return self._storage.is_shared()
|
|
|
|
@classmethod
|
|
def _new_shared_cuda(cls, *args, **kwargs):
|
|
return torch.cuda._UntypedStorage._new_shared_cuda(*args, **kwargs)
|
|
|
|
def _share_filename_(self, *args, **kwargs):
|
|
manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs)
|
|
return manager_handle, storage_handle, size // self.element_size()
|
|
|
|
def _shared_decref(self):
|
|
self._storage._shared_decref()
|
|
return self
|
|
|
|
@classmethod
|
|
def _release_ipc_counter(cls, *args, device=None, **kwargs):
|
|
device = torch.device('cpu' if device is None else device)
|
|
|
|
if device.type == 'cpu':
|
|
untyped_cls = torch._UntypedStorage
|
|
elif device.type == 'cuda':
|
|
untyped_cls = torch.cuda._UntypedStorage
|
|
else:
|
|
raise RuntimeError(f"device {device} not recognized")
|
|
|
|
return untyped_cls._release_ipc_counter(*args, **kwargs)
|
|
|
|
def _shared_incref(self, *args, **kwargs):
|
|
return self._storage._shared_incref(*args, **kwargs)
|
|
|
|
def _share_fd_(self, *args, **kwargs):
|
|
fd, size = self._storage._share_fd_(*args, **kwargs)
|
|
return fd, size // self.element_size()
|
|
|
|
def _get_legacy_storage_class(self):
|
|
if self.dtype not in _dtype_to_storage_type_map():
|
|
return None
|
|
|
|
storage_name = _dtype_to_storage_type_map()[self.dtype]
|
|
|
|
if self.device.type not in ['cpu', 'cuda']:
|
|
return None
|
|
|
|
module = 'torch.' if self.device.type == 'cpu' else 'torch.cuda.'
|
|
|
|
try:
|
|
return eval(module + storage_name)
|
|
except AttributeError:
|
|
return None
|
|
|
|
_TypedStorage.type.__doc__ = _type.__doc__
|
|
_TypedStorage.cuda.__doc__ = _cuda.__doc__
|
|
|
|
class _LegacyStorageMeta(type):
|
|
dtype: torch.dtype
|
|
|
|
def __instancecheck__(cls, instance):
|
|
if type(instance) == _TypedStorage:
|
|
cls_device = 'cuda' if cls.__module__ == 'torch.cuda' else 'cpu'
|
|
return (cls_device == instance.device.type) and (cls.dtype == instance.dtype)
|
|
return False
|
|
|
|
class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
|
|
@classmethod
|
|
def _new_shared(cls, size):
|
|
"""Creates a new storage in shared memory with the same data type"""
|
|
module = eval(cls.__module__)
|
|
untyped_storage = module._UntypedStorage._new_shared(size * cls().element_size())
|
|
return cls(wrap_storage=untyped_storage)
|
|
|
|
@classmethod
|
|
def _release_ipc_counter(cls, *args, **kwargs):
|
|
return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
|
|
|
|
@classmethod
|
|
def _new_shared_filename(cls, manager, obj, size):
|
|
bytes_size = size * torch._utils._element_size(cls.dtype)
|
|
return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
|
|
|
|
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
|
|
try:
|
|
return _storage_type_to_dtype_map()[pickle_storage_type]
|
|
except KeyError:
|
|
raise KeyError(
|
|
f'pickle storage type "{pickle_storage_type}" is not recognized')
|