mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Merge torch.cuda._UntypedStorage into torch._UntypedStorage (#75459)
Fixes #74933 Pull Request resolved: https://github.com/pytorch/pytorch/pull/75459 Approved by: https://github.com/ezyang
This commit is contained in:
committed by
PyTorch MergeBot
parent
ac1837ddd3
commit
aea6e2c396
@ -570,7 +570,7 @@ class TestCuda(TestCase):
|
||||
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
|
||||
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
|
||||
self.assertTrue(isinstance(q_copy[3], torch.storage._TypedStorage))
|
||||
self.assertTrue(isinstance(q_copy[3]._storage, torch.cuda._UntypedStorage))
|
||||
self.assertTrue(isinstance(q_copy[3]._storage, torch._UntypedStorage))
|
||||
q_copy[1].fill_(10)
|
||||
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))
|
||||
|
||||
|
@ -41,59 +41,42 @@ class TestPublicBindings(TestCase):
|
||||
"AVG",
|
||||
"BenchmarkConfig",
|
||||
"BenchmarkExecutionStats",
|
||||
"BFloat16StorageBase",
|
||||
"Block",
|
||||
"BoolStorageBase",
|
||||
"BoolType",
|
||||
"BufferDict",
|
||||
"ByteStorageBase",
|
||||
"StorageBase",
|
||||
"CallStack",
|
||||
"Capsule",
|
||||
"CharStorageBase",
|
||||
"ClassType",
|
||||
"clear_autocast_cache",
|
||||
"Code",
|
||||
"CompilationUnit",
|
||||
"CompleteArgumentSpec",
|
||||
"ComplexDoubleStorageBase",
|
||||
"ComplexFloatStorageBase",
|
||||
"ComplexType",
|
||||
"ConcreteModuleType",
|
||||
"ConcreteModuleTypeBuilder",
|
||||
"CONV_BN_FUSION",
|
||||
"cpp",
|
||||
"CudaBFloat16StorageBase",
|
||||
"CudaBFloat16TensorBase",
|
||||
"CudaBFloat16TensorBase",
|
||||
"CudaBoolStorageBase",
|
||||
"CudaBoolTensorBase",
|
||||
"CudaBoolTensorBase",
|
||||
"CudaByteStorageBase",
|
||||
"CudaByteTensorBase",
|
||||
"CudaByteTensorBase",
|
||||
"CudaCharStorageBase",
|
||||
"CudaCharTensorBase",
|
||||
"CudaCharTensorBase",
|
||||
"CudaComplexDoubleStorageBase",
|
||||
"CudaComplexDoubleTensorBase",
|
||||
"CudaComplexDoubleTensorBase",
|
||||
"CudaComplexFloatStorageBase",
|
||||
"CudaComplexFloatTensorBase",
|
||||
"CudaComplexFloatTensorBase",
|
||||
"CudaDoubleStorageBase",
|
||||
"CudaDoubleTensorBase",
|
||||
"CudaDoubleTensorBase",
|
||||
"CudaFloatStorageBase",
|
||||
"CudaFloatTensorBase",
|
||||
"CudaHalfStorageBase",
|
||||
"CudaHalfTensorBase",
|
||||
"CudaIntStorageBase",
|
||||
"CudaIntTensorBase",
|
||||
"CudaIntTensorBase",
|
||||
"CudaLongStorageBase",
|
||||
"CudaLongTensorBase",
|
||||
"CudaLongTensorBase",
|
||||
"CudaShortStorageBase",
|
||||
"CudaShortTensorBase",
|
||||
"CudaShortTensorBase",
|
||||
"DeepCopyMemoTable",
|
||||
@ -103,7 +86,6 @@ class TestPublicBindings(TestCase):
|
||||
"DeviceObjType",
|
||||
"DictType",
|
||||
"DisableTorchFunction",
|
||||
"DoubleStorageBase",
|
||||
"dtype",
|
||||
"EnumType",
|
||||
"ErrorReport",
|
||||
@ -111,7 +93,6 @@ class TestPublicBindings(TestCase):
|
||||
"FatalError",
|
||||
"FileCheck",
|
||||
"finfo",
|
||||
"FloatStorageBase",
|
||||
"FloatType",
|
||||
"fork",
|
||||
"FunctionSchema",
|
||||
@ -126,7 +107,6 @@ class TestPublicBindings(TestCase):
|
||||
"Gradient",
|
||||
"Graph",
|
||||
"GraphExecutorState",
|
||||
"HalfStorageBase",
|
||||
"has_cuda",
|
||||
"has_cudnn",
|
||||
"has_lapack",
|
||||
@ -143,7 +123,6 @@ class TestPublicBindings(TestCase):
|
||||
"init_num_threads",
|
||||
"INSERT_FOLD_PREPACK_OPS",
|
||||
"InterfaceType",
|
||||
"IntStorageBase",
|
||||
"IntType",
|
||||
"SymIntType",
|
||||
"IODescriptor",
|
||||
@ -159,7 +138,6 @@ class TestPublicBindings(TestCase):
|
||||
"LiteScriptModule",
|
||||
"LockingLogger",
|
||||
"LoggerBase",
|
||||
"LongStorageBase",
|
||||
"memory_format",
|
||||
"merge_type_from_type_comment",
|
||||
"MobileOptimizerType",
|
||||
@ -177,12 +155,7 @@ class TestPublicBindings(TestCase):
|
||||
"PyObjectType",
|
||||
"PyTorchFileReader",
|
||||
"PyTorchFileWriter",
|
||||
"QInt32StorageBase",
|
||||
"QInt8StorageBase",
|
||||
"qscheme",
|
||||
"QUInt4x2StorageBase",
|
||||
"QUInt2x4StorageBase",
|
||||
"QUInt8StorageBase",
|
||||
"read_vitals",
|
||||
"REMOVE_DROPOUT",
|
||||
"RRefType",
|
||||
@ -209,7 +182,6 @@ class TestPublicBindings(TestCase):
|
||||
"set_num_interop_threads",
|
||||
"set_num_threads",
|
||||
"set_vital",
|
||||
"ShortStorageBase",
|
||||
"Size",
|
||||
"StaticModule",
|
||||
"Stream",
|
||||
|
@ -6286,7 +6286,7 @@ class TestTorch(TestCase):
|
||||
torch.storage._LegacyStorage()
|
||||
|
||||
for storage_class in torch._storage_classes:
|
||||
if storage_class in [torch._UntypedStorage, torch.cuda._UntypedStorage, torch._TypedStorage]:
|
||||
if storage_class in [torch._UntypedStorage, torch._TypedStorage]:
|
||||
continue
|
||||
|
||||
device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
|
||||
@ -6371,22 +6371,16 @@ class TestTorch(TestCase):
|
||||
storage_classes = [
|
||||
torch.cuda.ByteStorage,
|
||||
torch.cuda.FloatStorage,
|
||||
torch.cuda._UntypedStorage,
|
||||
]
|
||||
for storage_class in storage_classes:
|
||||
with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
|
||||
storage_class.from_buffer()
|
||||
|
||||
if storage_class == torch.cuda._UntypedStorage:
|
||||
with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
|
||||
storage_class._new_with_weak_ptr()
|
||||
|
||||
else:
|
||||
with self.assertRaisesRegex(AttributeError, r'has no attribute'):
|
||||
storage_class._new_with_weak_ptr()
|
||||
|
||||
with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
|
||||
storage_class._new_shared_filename(0, 0, 0)
|
||||
storage_class._new_shared_filename_cpu(0, 0, 0)
|
||||
|
||||
def test_storage_casts(self):
|
||||
storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])
|
||||
|
@ -832,10 +832,8 @@ libtorch_python_cuda_core_sources = [
|
||||
"torch/csrc/cuda/Event.cpp",
|
||||
"torch/csrc/cuda/Module.cpp",
|
||||
"torch/csrc/cuda/python_comm.cpp",
|
||||
"torch/csrc/cuda/Storage.cpp",
|
||||
"torch/csrc/cuda/Stream.cpp",
|
||||
"torch/csrc/cuda/Graph.cpp",
|
||||
"torch/csrc/cuda/serialization.cpp",
|
||||
"torch/csrc/cuda/shared/cudart.cpp",
|
||||
"torch/csrc/cuda/shared/nvtx.cpp",
|
||||
"torch/csrc/cuda/utils.cpp",
|
||||
|
@ -786,33 +786,7 @@ def gen_pyi(
|
||||
# Generate type signatures for legacy classes
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
# TODO: These are deprecated, maybe we shouldn't type hint them
|
||||
legacy_storage_base_hints = []
|
||||
dt = (
|
||||
"Double",
|
||||
"Float",
|
||||
"Long",
|
||||
"Int",
|
||||
"Short",
|
||||
"Char",
|
||||
"Byte",
|
||||
"Bool",
|
||||
"Half",
|
||||
"BFloat16",
|
||||
"ComplexDouble",
|
||||
"ComplexFloat",
|
||||
"QUInt8",
|
||||
"QInt8",
|
||||
"QInt32",
|
||||
"QUInt4x2",
|
||||
"QUInt2x4",
|
||||
)
|
||||
for c in dt:
|
||||
legacy_storage_base_hints.append("class {}StorageBase(object): ...".format(c))
|
||||
for c in dt:
|
||||
legacy_storage_base_hints.append(
|
||||
"class Cuda{}StorageBase(object): ...".format(c)
|
||||
)
|
||||
legacy_storage_base_hints = ["class StorageBase(object): ..."]
|
||||
|
||||
legacy_class_hints = []
|
||||
for c in (
|
||||
|
@ -619,14 +619,11 @@ __all__.extend(['e', 'pi', 'nan', 'inf'])
|
||||
################################################################################
|
||||
|
||||
from ._tensor import Tensor
|
||||
from .storage import _StorageBase, _TypedStorage, _LegacyStorage
|
||||
from .storage import _StorageBase, _TypedStorage, _LegacyStorage, _UntypedStorage
|
||||
|
||||
# NOTE: New <type>Storage classes should never be added. When adding a new
|
||||
# dtype, use torch.storage._TypedStorage directly.
|
||||
|
||||
class _UntypedStorage(_C.ByteStorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
class ByteStorage(_LegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
@ -784,7 +781,8 @@ from .functional import * # noqa: F403
|
||||
# Remove unnecessary members
|
||||
################################################################################
|
||||
|
||||
del ByteStorageBase
|
||||
del _StorageBase
|
||||
del _LegacyStorage
|
||||
|
||||
################################################################################
|
||||
# Define _assert
|
||||
|
@ -5,7 +5,7 @@ from torch._C import _add_docstr as add_docstr
|
||||
|
||||
|
||||
storage_classes = [
|
||||
'ByteStorageBase',
|
||||
'StorageBase',
|
||||
]
|
||||
|
||||
|
||||
|
@ -75,8 +75,7 @@ def _cuda(self, device=None, non_blocking=False, **kwargs):
|
||||
values = torch.Tensor._values(self).cuda(device, non_blocking)
|
||||
return new_type(indices, values, self.size())
|
||||
else:
|
||||
new_type = getattr(torch.cuda, self.__class__.__name__)
|
||||
return new_type(self.size()).copy_(self, non_blocking)
|
||||
return torch._UntypedStorage(self.size(), device=torch.device('cuda')).copy_(self, non_blocking)
|
||||
|
||||
|
||||
def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
|
||||
|
@ -9,6 +9,8 @@
|
||||
#include <torch/csrc/utils/cuda_enabled.h>
|
||||
#include <torch/csrc/utils/cuda_lazy_init.h>
|
||||
#include <torch/csrc/utils/object_ptr.h>
|
||||
#include <torch/csrc/Storage.h>
|
||||
#include <torch/csrc/Device.h>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
|
||||
@ -22,9 +24,6 @@
|
||||
|
||||
namespace torch {
|
||||
namespace {
|
||||
std::unordered_map<at::DeprecatedTypeProperties*, PyTypeObject*> attype_to_py_storage_type;
|
||||
std::unordered_map<PyTypeObject*, at::DeprecatedTypeProperties*> py_storage_type_to_attype;
|
||||
|
||||
std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)> dtype_registry = {};
|
||||
|
||||
std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)> layout_registry = {};
|
||||
@ -51,34 +50,8 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala
|
||||
}
|
||||
return &at::getDeprecatedTypeProperties(backend, scalarType);
|
||||
}
|
||||
|
||||
PyTypeObject* getPyTypeObject(const at::Storage& storage) {
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/47442
|
||||
if (storage.device_type() == at::DeviceType::Meta) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
|
||||
}
|
||||
if (storage.data() == nullptr && storage.nbytes() != 0) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details");
|
||||
}
|
||||
at::ScalarType scalarType = at::ScalarType::Byte;
|
||||
auto attype = &at::getDeprecatedTypeProperties(
|
||||
at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())),
|
||||
scalarType);
|
||||
auto it = attype_to_py_storage_type.find(attype);
|
||||
TORCH_INTERNAL_ASSERT(it != attype_to_py_storage_type.end(),
|
||||
"Failed to get the Python type of `_UntypedStorage`.");
|
||||
return it->second;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void registerStoragePyTypeObject(PyTypeObject *pytype, at::Backend backend, at::ScalarType scalarType) {
|
||||
auto attype = get_type(backend, scalarType);
|
||||
if (attype) {
|
||||
attype_to_py_storage_type[attype] = pytype;
|
||||
py_storage_type_to_attype[pytype] = attype;
|
||||
}
|
||||
}
|
||||
|
||||
void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType) {
|
||||
dtype_registry[static_cast<int>(scalarType)] = dtype;
|
||||
}
|
||||
@ -104,7 +77,14 @@ THPLayout* getTHPLayout(at::Layout layout) {
|
||||
}
|
||||
|
||||
PyObject* createPyObject(const at::Storage& storage) {
|
||||
auto type = getPyTypeObject(storage);
|
||||
// TODO: https://github.com/pytorch/pytorch/issues/47442
|
||||
if (storage.device_type() == at::DeviceType::Meta) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
|
||||
}
|
||||
if (storage.data() == nullptr && storage.nbytes() != 0) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details");
|
||||
}
|
||||
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
|
||||
auto obj = THPObjectPtr(type->tp_alloc(type, 0));
|
||||
if (!obj) throw python_error();
|
||||
((THPVoidStorage*)obj.get())->cdata = at::Storage(/* copy */ storage).unsafeReleaseStorageImpl();
|
||||
@ -133,50 +113,56 @@ bool isStorage(PyObject* obj)
|
||||
return true;
|
||||
}
|
||||
auto obj_type = Py_TYPE(obj);
|
||||
for (auto const& item : py_storage_type_to_attype) {
|
||||
auto const& storage_type = item.first;
|
||||
if (obj_type == storage_type) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
||||
return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass);
|
||||
}
|
||||
|
||||
at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, bool& is_typed_storage)
|
||||
{
|
||||
is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
|
||||
THPObjectPtr maybe_untyped_storage;
|
||||
if (is_typed_storage) {
|
||||
PyObject* maybe_untyped_storage_obj = PyObject_GetAttrString(obj, "_storage");
|
||||
TORCH_INTERNAL_ASSERT(maybe_untyped_storage_obj);
|
||||
maybe_untyped_storage = maybe_untyped_storage_obj;
|
||||
}
|
||||
PyObject* untyped_storage_obj;
|
||||
|
||||
auto obj_type = Py_TYPE(obj);
|
||||
for (auto const& item : py_storage_type_to_attype) {
|
||||
auto const& storage_type = item.first;
|
||||
if (is_typed_storage) {
|
||||
if (Py_TYPE(maybe_untyped_storage.get()) == storage_type) {
|
||||
auto& type = *item.second;
|
||||
auto ret = type.unsafeStorageFromTH(
|
||||
((THPVoidStorage*)maybe_untyped_storage.get())->cdata,
|
||||
true);
|
||||
PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
|
||||
TORCH_INTERNAL_ASSERT(dtype_obj && THPDtype_Check(dtype_obj));
|
||||
scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
if (obj_type == storage_type) {
|
||||
auto& type = *item.second;
|
||||
// _UntypedStorage should always be interpreted with byte dtype
|
||||
|
||||
untyped_storage_obj = PyObject_GetAttrString(obj, "_storage");
|
||||
TORCH_INTERNAL_ASSERT(untyped_storage_obj);
|
||||
|
||||
// `PyObject_GetAttrString` increments the refcount to the
|
||||
// `_UntypedStorage`, so we must decrement it. The refcount will stay
|
||||
// positive for as long as we need it, since the `_TypedStorage` will
|
||||
// retain its reference to the `_UntypedStorage`, so decrementing it
|
||||
// immediately like this is fine.
|
||||
Py_DECREF(untyped_storage_obj);
|
||||
|
||||
} else {
|
||||
scalar_type = at::kByte;
|
||||
return type.unsafeStorageFromTH(((THPVoidStorage*)obj)->cdata, true);
|
||||
}
|
||||
untyped_storage_obj = obj;
|
||||
}
|
||||
|
||||
if (Py_TYPE(untyped_storage_obj) != reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
|
||||
throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
|
||||
}
|
||||
|
||||
c10::StorageImpl* impl = static_cast<c10::StorageImpl*>(((THPVoidStorage*)untyped_storage_obj)->cdata);
|
||||
c10::DeviceType device_type = impl->device().type();
|
||||
|
||||
at::Backend backend;
|
||||
if (device_type == at::kCPU) {
|
||||
backend = at::Backend::CPU;
|
||||
} else if (device_type == at::kCUDA) {
|
||||
backend = at::Backend::CUDA;
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid device for storage: ", device_type);
|
||||
}
|
||||
|
||||
auto type_properties = get_type(backend, at::kByte);
|
||||
|
||||
return type_properties->unsafeStorageFromTH(((THPVoidStorage*)untyped_storage_obj)->cdata, true);
|
||||
}
|
||||
|
||||
at::Storage createStorage(PyObject* obj) {
|
||||
at::ScalarType scalar_type;
|
||||
bool is_typed_storage = false;
|
||||
|
@ -21,10 +21,6 @@ struct Storage;
|
||||
}
|
||||
|
||||
namespace torch {
|
||||
// Register a PyTypeObject* with the given attributes
|
||||
void registerStoragePyTypeObject(
|
||||
PyTypeObject *pytype, at::Backend backend, at::ScalarType scalarType);
|
||||
|
||||
void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType);
|
||||
void registerLayoutObject(THPLayout *thp_layout, at::Layout layout);
|
||||
|
||||
|
@ -133,7 +133,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
|
||||
auto module = THPObjectPtr(PyImport_ImportModule("torch"));
|
||||
if (!module) throw python_error();
|
||||
|
||||
THPByteStorage_postInit(module);
|
||||
THPStorage_postInit(module);
|
||||
THPAutograd_initFunctions();
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -735,7 +735,6 @@ static PyMethodDef TorchMethods[] = {
|
||||
{nullptr, nullptr, 0, nullptr}
|
||||
};
|
||||
|
||||
bool THCPByteStorage_init(PyObject *module);
|
||||
void THCPStream_init(PyObject *module);
|
||||
void THCPEvent_init(PyObject *module);
|
||||
void THCPGraph_init(PyObject *module);
|
||||
@ -749,8 +748,6 @@ void initModule(PyObject *module);
|
||||
}} // namespace torch::cuda
|
||||
#endif
|
||||
|
||||
bool THDPByteStorage_init(PyObject *module);
|
||||
|
||||
static std::vector<PyMethodDef> methods;
|
||||
|
||||
// In Python we can't use the trick of C10_LOG_API_USAGE_ONCE
|
||||
@ -853,14 +850,13 @@ PyObject* initModule() {
|
||||
#ifdef USE_CUDA
|
||||
torch::cuda::initModule(module);
|
||||
#endif
|
||||
ASSERT_TRUE(THPByteStorage_init(module));
|
||||
ASSERT_TRUE(THPStorage_init(module));
|
||||
|
||||
#ifdef USE_CUDA
|
||||
// This will only initialise base classes and attach them to library namespace
|
||||
// They won't be ready for real usage until importing cuda module, that will
|
||||
// complete the process (but it defines Python classes before calling back into
|
||||
// C, so these lines have to execute first)..
|
||||
ASSERT_TRUE(THCPByteStorage_init(module));
|
||||
THCPStream_init(module);
|
||||
THCPEvent_init(module);
|
||||
THCPGraph_init(module);
|
||||
|
@ -2,17 +2,15 @@
|
||||
#define THP_STORAGE_INC
|
||||
#include <torch/csrc/THConcat.h>
|
||||
|
||||
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
|
||||
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
|
||||
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
|
||||
#define THPStorageStr "torch._UntypedStorage"
|
||||
#define THPStorage_(NAME) TH_CONCAT_2(THPStorage_,NAME)
|
||||
|
||||
#define THPByteStorage_Check(obj) \
|
||||
PyObject_IsInstance(obj, THPByteStorageClass)
|
||||
#define THPStorage_Check(obj) \
|
||||
PyObject_IsInstance(obj, THPStorageClass)
|
||||
|
||||
#define THPByteStorage_CData(obj) (obj)->cdata
|
||||
#define THPStorage_CData(obj) (obj)->cdata
|
||||
|
||||
#define THPStorageType TH_CONCAT_3(THP,Real,StorageType)
|
||||
#define THPStorageBaseStr TH_CONCAT_STRING_2(Real,StorageBase)
|
||||
#define THPStorageBaseStr "StorageBase"
|
||||
|
||||
#include <torch/csrc/generic/Storage.h>
|
||||
#include <torch/csrc/THGenerateByteType.h>
|
||||
|
@ -591,9 +591,6 @@ static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
|
||||
auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
|
||||
if (!m) throw python_error();
|
||||
|
||||
// Register Storage Python objects with DynamicTypes.cpp
|
||||
THCPByteStorage_postInit(m);
|
||||
|
||||
bool has_half = true;
|
||||
|
||||
auto set_module_attr = [&](const char* name, PyObject* v) {
|
||||
|
@ -1,17 +0,0 @@
|
||||
#define __STDC_FORMAT_MACROS
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <structmember.h>
|
||||
#include <fmt/format.h>
|
||||
|
||||
#include <torch/csrc/cuda/THCP.h>
|
||||
|
||||
#include <torch/csrc/cuda/override_macros.h>
|
||||
#include <torch/csrc/copy_utils.h>
|
||||
#include <torch/csrc/DynamicTypes.h>
|
||||
#include <torch/csrc/CudaIPCTypes.h>
|
||||
#include <torch/csrc/Device.h>
|
||||
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
|
||||
#include <torch/csrc/THCGenerateByteType.h>
|
@ -1,21 +0,0 @@
|
||||
#ifndef THCP_STORAGE_INC
|
||||
#define THCP_STORAGE_INC
|
||||
|
||||
#define THCPStorageStr TH_CONCAT_STRING_3(torch.cuda.,Real,Storage)
|
||||
#define THCPStorageClass TH_CONCAT_3(THCP,Real,StorageClass)
|
||||
#define THCPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
|
||||
|
||||
#define THCPByteStorage_Check(obj) \
|
||||
PyObject_IsInstance(obj, THCPByteStorageClass)
|
||||
|
||||
#define THCPByteStorage_CData(obj) (obj)->cdata
|
||||
|
||||
#define THCPStorageType TH_CONCAT_3(THCP,Real,StorageType)
|
||||
#define THCPStorageBaseStr TH_CONCAT_STRING_3(Cuda,Real,StorageBase)
|
||||
|
||||
#include <torch/csrc/cuda/override_macros.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
|
||||
#include <torch/csrc/THCGenerateByteType.h>
|
||||
|
||||
#endif
|
@ -3,9 +3,7 @@
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/THP.h>
|
||||
#include <torch/csrc/cuda/serialization.h>
|
||||
#include <torch/csrc/cuda/Module.h>
|
||||
#include <torch/csrc/cuda/Storage.h>
|
||||
#include <torch/csrc/cuda/Stream.h>
|
||||
#include <torch/csrc/cuda/Event.h>
|
||||
#include <torch/csrc/cuda/utils.h>
|
||||
|
@ -1,17 +1,10 @@
|
||||
#include <torch/csrc/cuda/undef_macros.h>
|
||||
|
||||
#define THPStoragePtr THCPStoragePtr
|
||||
#define THPTensorPtr THCPTensorPtr
|
||||
|
||||
#define THWTensor THCTensor
|
||||
#define THWTensor_(NAME) THCTensor_(NAME)
|
||||
|
||||
#define THPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
|
||||
#define THPStorageBaseStr THCPStorageBaseStr
|
||||
#define THPStorageStr THCPStorageStr
|
||||
#define THPStorageClass THCPStorageClass
|
||||
#define THPStorageType THCPStorageType
|
||||
|
||||
#define THPTensor_(NAME) TH_CONCAT_4(THCP,Real,Tensor_,NAME)
|
||||
#define THPTensor_stateless_(NAME) TH_CONCAT_4(THCP,Real,Tensor_stateless_,NAME)
|
||||
#define THPTensor THCPTensor
|
||||
|
@ -8,10 +8,5 @@
|
||||
#define THPTensorClass TH_CONCAT_3(THP,Real,TensorClass)
|
||||
#define THPTensor_(NAME) TH_CONCAT_4(THP,Real,Tensor_,NAME)
|
||||
|
||||
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
|
||||
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
|
||||
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
|
||||
|
||||
#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr)
|
||||
#define THPStoragePtr TH_CONCAT_3(THP,Real,StoragePtr)
|
||||
#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr)
|
||||
|
@ -1,11 +0,0 @@
|
||||
#include <torch/csrc/python_headers.h>
|
||||
|
||||
#include <torch/csrc/cuda/THCP.h>
|
||||
|
||||
#include <torch/csrc/cuda/override_macros.h>
|
||||
|
||||
#include <system_error>
|
||||
#include <memory>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
|
||||
#include <torch/csrc/THCGenerateByteType.h>
|
@ -1,9 +0,0 @@
|
||||
#ifndef THCP_SERIALIZATION_INC
|
||||
#define THCP_SERIALIZATION_INC
|
||||
|
||||
#include <torch/csrc/cuda/override_macros.h>
|
||||
|
||||
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
|
||||
#include <torch/csrc/THCGenerateByteType.h>
|
||||
|
||||
#endif
|
@ -11,13 +11,6 @@
|
||||
#undef THPTensorStateless
|
||||
#undef THPTensorType
|
||||
|
||||
#undef THPStorage_
|
||||
#undef THPStorageBaseStr
|
||||
#undef THPStorageStr
|
||||
#undef THPStorageClass
|
||||
#undef THPStorageType
|
||||
|
||||
#undef THPStoragePtr
|
||||
#undef THPTensorPtr
|
||||
|
||||
#undef THWTensor
|
||||
|
@ -1,6 +1,7 @@
|
||||
#ifndef TH_GENERIC_FILE
|
||||
#define TH_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
|
||||
#else
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
|
||||
PyObject *THPStorageClass = nullptr;
|
||||
|
||||
@ -26,68 +27,84 @@ static void THPStorage_(dealloc)(THPStorage* self)
|
||||
static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
Py_ssize_t num_args = args ? PyTuple_Size(args) : 0;
|
||||
|
||||
static torch::PythonArgParser parser({
|
||||
THPStorageStr "(*, int64_t allocator=None, Device device=None)",
|
||||
THPStorageStr "(int64_t size, *, int64_t allocator=None, Device device=None)",
|
||||
THPStorageStr "(PyObject* sequence, *, int64_t allocator=None, Device device=None)",
|
||||
});
|
||||
torch::ParsedArgs<3> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
int64_t allocator_arg_idx = 0;
|
||||
int64_t device_arg_idx = 1;
|
||||
|
||||
if (r.idx > 0) {
|
||||
allocator_arg_idx = 1;
|
||||
device_arg_idx = 2;
|
||||
}
|
||||
|
||||
c10::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx);
|
||||
c10::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx);
|
||||
|
||||
TORCH_CHECK(!allocator_opt.has_value() || !device_opt.has_value(),
|
||||
THPStorageStr, "(): only one or neither of 'allocator' or 'device' can ",
|
||||
"be given, but not both");
|
||||
|
||||
THPStoragePtr self((THPStorage *)type->tp_alloc(type, 0));
|
||||
THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
|
||||
c10::Allocator* allocator = nullptr;
|
||||
at::OptionalDeviceGuard device_guard;
|
||||
|
||||
// Internally we allow constructing with a keywoard only argument cdata
|
||||
if (kwargs != nullptr) {
|
||||
PyObject *allocator_ptr = PyDict_GetItemString(kwargs, "allocator");
|
||||
if (allocator_ptr) {
|
||||
THPUtils_assert(THPUtils_checkLong(allocator_ptr), "invalid allocator");
|
||||
allocator = static_cast<c10::Allocator*>(PyLong_AsVoidPtr(allocator_ptr));
|
||||
PyDict_DelItemString(kwargs, "allocator");
|
||||
}
|
||||
|
||||
Py_ssize_t num_kwargs = PyDict_Size(kwargs);
|
||||
if (num_args == 0) {
|
||||
PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
|
||||
if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) {
|
||||
c10::StorageImpl *ptr = (c10::StorageImpl*)PyLong_AsVoidPtr(cdata_ptr);
|
||||
self->cdata = ptr;
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
}
|
||||
THPUtils_assert(num_kwargs == 0, THPStorageStr "(): invalid keyword arguments");
|
||||
}
|
||||
if (allocator == nullptr) {
|
||||
#if defined(THC_GENERIC_FILE)
|
||||
allocator = c10::cuda::CUDACachingAllocator::get();
|
||||
#else
|
||||
if (allocator_opt.has_value()) {
|
||||
allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
|
||||
} else if (device_opt.has_value()) {
|
||||
at::Device device = device_opt.value();
|
||||
if (device.type() == at::kCPU) {
|
||||
allocator = c10::GetDefaultCPUAllocator();
|
||||
#ifdef USE_CUDA
|
||||
} else if (device.type() == at::kCUDA) {
|
||||
at::globalContext().lazyInitCUDA();
|
||||
allocator = c10::cuda::CUDACachingAllocator::get();
|
||||
#endif
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
THPStorageStr, "(): Storage device not recognized: ", device.type());
|
||||
}
|
||||
device_guard.reset_device(device);
|
||||
} else {
|
||||
allocator = c10::GetDefaultCPUAllocator();
|
||||
}
|
||||
|
||||
// torch.Storage()
|
||||
if (num_args == 0) {
|
||||
// torch.Storage(*, ...)
|
||||
if (r.idx == 0) {
|
||||
self->cdata = c10::make_intrusive<at::StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
0,
|
||||
allocator,
|
||||
/*resizable=*/true).release();
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
|
||||
|
||||
// torch.Storage(size)
|
||||
if (num_args == 1 && THPUtils_checkLong(first_arg)) {
|
||||
int64_t size = THPUtils_unpackLong(first_arg);
|
||||
// torch.Storage(size, *, ...)
|
||||
} else if (r.idx == 1) {
|
||||
int64_t size = r.toInt64(0);
|
||||
self->cdata = c10::make_intrusive<at::StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
size,
|
||||
allocator,
|
||||
/*resizable=*/true).release();
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
// torch.Storage(sequence)
|
||||
if (num_args == 1 && PySequence_Check(first_arg)) {
|
||||
Py_ssize_t length = PySequence_Length(first_arg);
|
||||
THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
|
||||
THPUtils_typename(first_arg));
|
||||
// torch.Storage(sequence, *, ...)
|
||||
} else if (r.idx == 2) {
|
||||
PyObject *sequence = r.pyobject(0);
|
||||
Py_ssize_t length = PySequence_Length(sequence);
|
||||
TORCH_CHECK(PySequence_Check(sequence),
|
||||
THPStorageStr, "(): Expected a sequence type, but got ",
|
||||
THPUtils_typename(sequence));
|
||||
TORCH_CHECK(length >= 0,
|
||||
THPStorageStr, "(): Could not obtain the length of sequence of type ",
|
||||
THPUtils_typename(sequence));
|
||||
self->cdata = c10::make_intrusive<at::StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
length,
|
||||
@ -97,35 +114,31 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
|
||||
THPObjectPtr item;
|
||||
try {
|
||||
for (Py_ssize_t i = 0; i < length; i++) {
|
||||
item = PySequence_GetItem(first_arg, i);
|
||||
item = PySequence_GetItem(sequence, i);
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
scalar_t value = THPUtils_(unpackReal)(item.get());
|
||||
#if !defined(THC_GENERIC_FILE)
|
||||
if (allocator == c10::GetDefaultCPUAllocator()) {
|
||||
self->cdata->unsafe_data<scalar_t>()[i] = value;
|
||||
#else
|
||||
} else {
|
||||
// TODO: this might be slow - consider batched updates?
|
||||
storage_set(
|
||||
at::unsafeStorageFromTH(self->cdata, /*retain=*/true),
|
||||
i,
|
||||
value);
|
||||
#endif
|
||||
}
|
||||
}
|
||||
} catch (const std::exception &e) {
|
||||
THPUtils_setError("tried to construct a storage from a sequence (%s), "
|
||||
THPUtils_setError(THPStorageStr
|
||||
"(): tried to construct a storage from a sequence (%s), "
|
||||
"but one of the items was of type %s instead of %s",
|
||||
THPUtils_typename(first_arg),
|
||||
THPUtils_typename(sequence),
|
||||
THPUtils_typename(item.get()),
|
||||
THPUtils_typeTraits<scalar_t>::python_type_str);
|
||||
return nullptr;
|
||||
}
|
||||
return (PyObject*)self.release();
|
||||
}
|
||||
|
||||
THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
|
||||
"no arguments",
|
||||
"(int size)",
|
||||
"(Sequence data)");
|
||||
return nullptr;
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -346,17 +359,6 @@ void THPStorage_(postInit)(PyObject *module)
|
||||
{
|
||||
THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage");
|
||||
if (!THPStorageClass) throw python_error();
|
||||
|
||||
at::Backend backend = at::Backend::CPU;
|
||||
#ifdef THC_GENERIC_FILE
|
||||
backend = at::Backend::CUDA;
|
||||
#endif
|
||||
|
||||
#ifdef THQUANTIZED
|
||||
backend = at::Backend::QuantizedCPU;
|
||||
#endif
|
||||
|
||||
torch::registerStoragePyTypeObject((PyTypeObject*)THPStorageClass, backend, TH_CONCAT_2(at::k, Real));
|
||||
}
|
||||
|
||||
#endif
|
||||
|
@ -6,15 +6,12 @@
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
||||
#if !defined(THC_GENERIC_FILE)
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
#else
|
||||
#include <ATen/native/cuda/Resize.h>
|
||||
#endif
|
||||
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#include <ATen/native/Resize.h>
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#define LSEEK _lseeki64
|
||||
#else
|
||||
@ -86,14 +83,11 @@ static PyObject * THPStorage_(new)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
auto self = (THPStorage*)_self;
|
||||
c10::Allocator* allocator = self->cdata->allocator();
|
||||
auto new_storage = c10::make_intrusive<at::StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
0,
|
||||
#if defined(THC_GENERIC_FILE)
|
||||
c10::cuda::CUDACachingAllocator::get(),
|
||||
#else
|
||||
c10::GetDefaultCPUAllocator(),
|
||||
#endif
|
||||
allocator,
|
||||
/*resizable=*/true);
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -108,16 +102,22 @@ static PyObject * THPStorage_(resize_)(PyObject *_self, PyObject *number_arg)
|
||||
THPUtils_assert(THPUtils_checkLong(number_arg), "resize_ expects an int, "
|
||||
"but got %s", THPUtils_typename(number_arg));
|
||||
int64_t newsize = THPUtils_unpackLong(number_arg);
|
||||
#if defined(THC_GENERIC_FILE)
|
||||
c10::DeviceType device_type = self->cdata->device_type();
|
||||
if (device_type == at::kCPU) {
|
||||
at::native::resize_bytes_cpu(self->cdata, newsize);
|
||||
#ifdef USE_CUDA
|
||||
} else if (device_type == at::kCUDA) {
|
||||
ptrdiff_t size_bytes_i = newsize;
|
||||
TORCH_CHECK(!c10::overflows<size_t>(size_bytes_i),
|
||||
"Requested storage size (", size_bytes_i,
|
||||
") cannot be represented as a size_t");
|
||||
const auto size_bytes = static_cast<size_t>(size_bytes_i);
|
||||
at::native::resize_bytes_cuda(self->cdata, size_bytes);
|
||||
#else
|
||||
at::native::resize_bytes_cpu(self->cdata, newsize);
|
||||
#endif
|
||||
} else {
|
||||
TORCH_CHECK(false,
|
||||
"_UntypedStorage.resize_: got unexpected device type ", device_type);
|
||||
}
|
||||
Py_INCREF(self);
|
||||
return (PyObject*)self;
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -138,7 +138,6 @@ static PyObject * THPStorage_(fill_)(PyObject *_self, PyObject *number_arg)
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#if !defined(THC_GENERIC_FILE)
|
||||
static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyObject *keywds)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
@ -224,11 +223,7 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
|
||||
auto storage = c10::make_intrusive<at::StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
size_bytes,
|
||||
#if defined(THC_GENERIC_FILE)
|
||||
c10::cuda::CUDACachingAllocator::get(),
|
||||
#else
|
||||
c10::GetDefaultCPUAllocator(),
|
||||
#endif
|
||||
/*resizable=*/true);
|
||||
|
||||
if (scalar_type == at::kByte || scalar_type == at::kChar) {
|
||||
@ -284,7 +279,6 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
|
||||
return (PyObject*)THPStorage_(New)(storage);
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
#endif
|
||||
|
||||
static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObject *keywds)
|
||||
{
|
||||
@ -302,10 +296,6 @@ static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObj
|
||||
if (shared)
|
||||
shared = at::ALLOCATOR_MAPPED_SHARED;
|
||||
|
||||
#ifdef THC_GENERIC_FILE
|
||||
TORCH_CHECK(false, "not available yet for CUDA");
|
||||
return nullptr;
|
||||
#else
|
||||
size_t actual_nbytes = -1;
|
||||
auto storage = c10::make_intrusive<at::StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
@ -320,7 +310,6 @@ static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObj
|
||||
}
|
||||
|
||||
return (PyObject*)THPStorage_(New)(std::move(storage));
|
||||
#endif
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
@ -429,16 +418,6 @@ static PyObject *THPStorage_(setFromFile)(PyObject *_self, PyObject *args)
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifdef THC_GENERIC_FILE
|
||||
PyObject * THPStorage_(getDevice)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
auto self = (THPStorage*)_self;
|
||||
return THPUtils_packInt32(self->cdata->device().index());
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
#endif
|
||||
|
||||
PyObject * THPStorage_(_setCdata)(PyObject *_self, PyObject *new_cdata)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
@ -473,15 +452,10 @@ static PyMethodDef THPStorage_(methods)[] = {
|
||||
{"_write_file", THPStorage_(writeFile), METH_VARARGS, nullptr},
|
||||
{"_new_with_file", THPStorage_(newWithFile), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_set_from_file", THPStorage_(setFromFile), METH_VARARGS, nullptr},
|
||||
#if !defined(THC_GENERIC_FILE)
|
||||
{"from_buffer", castPyCFunctionWithKeywords(THPStorage_(fromBuffer)),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
#endif
|
||||
{"from_file", castPyCFunctionWithKeywords(THPStorage_(fromFile)),
|
||||
METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
|
||||
#ifdef THC_GENERIC_FILE
|
||||
{"get_device", THPStorage_(getDevice), METH_NOARGS, nullptr},
|
||||
#endif
|
||||
{"_set_cdata", THPStorage_(_setCdata), METH_O, nullptr},
|
||||
{nullptr}
|
||||
};
|
||||
|
@ -13,13 +13,14 @@ static PyObject * THPStorage_(sharedDecref)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
auto self = (THPStorage*)_self;
|
||||
#ifndef THC_GENERIC_FILE
|
||||
c10::DeviceType device_type = self->cdata->device_type();
|
||||
if (device_type == at::kCPU) {
|
||||
c10::StorageImpl *storage = self->cdata;
|
||||
THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr());
|
||||
if (ctx) {
|
||||
ctx->decref();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
Py_INCREF(self);
|
||||
return (PyObject *)self;
|
||||
END_HANDLE_TH_ERRORS
|
||||
@ -29,19 +30,18 @@ static PyObject * THPStorage_(sharedIncref)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
auto self = (THPStorage*)_self;
|
||||
#ifndef THC_GENERIC_FILE
|
||||
c10::DeviceType device_type = self->cdata->device_type();
|
||||
if (device_type == at::kCPU) {
|
||||
c10::StorageImpl *storage = self->cdata;
|
||||
THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr());
|
||||
if (ctx) {
|
||||
ctx->incref();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifndef THC_GENERIC_FILE
|
||||
|
||||
static PyObject * THPStorage_(pyNewFilenameStorage)(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
@ -65,6 +65,8 @@ static PyObject * THPStorage_(pyNewFilenameStorage)(PyObject *_unused, PyObject
|
||||
static PyObject * THPStorage_(shareFilename)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCPU,
|
||||
"_share_filename_: only available on CPU");
|
||||
auto self = (THPStorage*)_self;
|
||||
c10::StorageImpl *storage = self->cdata;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -166,6 +168,8 @@ static PyObject * THPStorage_(pyNewFdStorage)(PyObject *_unused, PyObject *args)
|
||||
static PyObject * THPStorage_(shareFd)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
TORCH_CHECK(reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCPU,
|
||||
"_share_fd_: only available on CPU");
|
||||
auto self = (THPStorage*)_self;
|
||||
c10::StorageImpl *storage = self->cdata;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -230,11 +234,12 @@ static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args)
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#else // THC_GENERIC_FILE
|
||||
|
||||
static PyObject * THPStorage_(shareCuda)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#ifdef USE_CUDA
|
||||
TORCH_CHECK(reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCUDA,
|
||||
"_share_cuda_: only available on CUDA");
|
||||
auto self = (THPStorage*)_self;
|
||||
c10::StorageImpl *storage = self->cdata;
|
||||
|
||||
@ -309,12 +314,16 @@ static PyObject * THPStorage_(shareCuda)(PyObject *_self, PyObject *noargs)
|
||||
PyTuple_SET_ITEM(tuple.get(), 6, _event_handle.release());
|
||||
PyTuple_SET_ITEM(tuple.get(), 7, _event_sync_required.release());
|
||||
return tuple.release();
|
||||
#else
|
||||
TORCH_CHECK(false, "CUDA is not available");
|
||||
#endif
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
static PyObject * THPStorage_(releaseIPCCounter)(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#ifdef USE_CUDA
|
||||
THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
|
||||
PyObject *_ref_counter = PyTuple_GET_ITEM(args, 0);
|
||||
PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 1);
|
||||
@ -347,9 +356,13 @@ static PyObject * THPStorage_(releaseIPCCounter)(PyObject *_unused, PyObject *ar
|
||||
// Already warned inside of producer process
|
||||
}
|
||||
Py_RETURN_NONE;
|
||||
#else
|
||||
TORCH_CHECK(false, "CUDA is not available");
|
||||
#endif
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
#ifdef USE_CUDA
|
||||
static std::string THPStorage_(bytesAsHandleString)(PyObject *handle) {
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
char* buffer;
|
||||
@ -364,10 +377,12 @@ static std::string THPStorage_(bytesAsHandleString)(PyObject *handle) {
|
||||
handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
|
||||
return std::string(buffer, handle_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
|
||||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#ifdef USE_CUDA
|
||||
THPUtils_assert(PyTuple_GET_SIZE(args) == 8, "tuple of 8 items expected");
|
||||
PyObject *_device = PyTuple_GET_ITEM(args, 0);
|
||||
PyObject *_handle = PyTuple_GET_ITEM(args, 1);
|
||||
@ -405,7 +420,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
cudaEvent_t event;
|
||||
cudaIpcOpenEventHandle(&event, *ipc_event_handle);
|
||||
AT_CUDA_CHECK(
|
||||
C10_CUDA_CHECK(
|
||||
cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(device), event, 0));
|
||||
}
|
||||
|
||||
@ -485,9 +500,11 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
|
||||
base->set_received_cuda(true);
|
||||
|
||||
return THPStorage_(New)(std::move(base));
|
||||
#else
|
||||
TORCH_CHECK(false, "CUDA is not available");
|
||||
#endif
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
#endif
|
||||
|
||||
// Returns an object that holds a "weak" pointer to the c10::StorageImpl. This
|
||||
// pointer keeps the c10::StorageImpl struct live, but does not retain the data
|
||||
@ -544,10 +561,10 @@ PyObject * THPStorage_(sharedFd)(PyObject *_self, PyObject *noargs)
|
||||
HANDLE_TH_ERRORS
|
||||
auto self = (THPStorage*)_self;
|
||||
at::MapAllocator *ctx = nullptr;
|
||||
#ifndef THC_GENERIC_FILE
|
||||
if (self->cdata->device_type() == at::kCPU) {
|
||||
c10::StorageImpl *storage = self->cdata;
|
||||
ctx = at::MapAllocator::fromDataPtr(storage->data_ptr());
|
||||
#endif
|
||||
}
|
||||
|
||||
THPUtils_assert(ctx, "couldn't retrieve a shared file descriptor");
|
||||
return THPUtils_packInt32(ctx->fd());
|
||||
@ -557,33 +574,29 @@ PyObject * THPStorage_(sharedFd)(PyObject *_self, PyObject *noargs)
|
||||
PyObject * THPStorage_(isShared)(PyObject *_self, PyObject *noargs)
|
||||
{
|
||||
auto self = (THPStorage*)_self;
|
||||
#ifdef THC_GENERIC_FILE
|
||||
if (self->cdata->device_type() == at::kCUDA) {
|
||||
Py_RETURN_TRUE;
|
||||
#else
|
||||
}
|
||||
if (at::MapAllocator::fromDataPtr(self->cdata->data_ptr()) ||
|
||||
THManagedMapAllocator::fromDataPtr(self->cdata->data_ptr())) {
|
||||
Py_RETURN_TRUE;
|
||||
} else {
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
|
||||
static PyMethodDef THPStorage_(sharingMethods)[] = {
|
||||
{"_new_with_weak_ptr", THPStorage_(newWithWeakPtr), METH_O | METH_CLASS, nullptr},
|
||||
#ifdef THC_GENERIC_FILE
|
||||
{"_share_cuda_", THPStorage_(shareCuda), METH_NOARGS, nullptr},
|
||||
{"_new_shared_cuda", THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_release_ipc_counter", THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr},
|
||||
#else
|
||||
{"_share_fd_", THPStorage_(shareFd), METH_NOARGS, nullptr},
|
||||
{"_new_shared_fd", THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_new_using_fd", THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_share_filename_", THPStorage_(shareFilename), METH_NOARGS, nullptr},
|
||||
{"_new_shared_filename", THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_new_using_filename", THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr},
|
||||
#endif
|
||||
{"_release_ipc_counter_cuda", THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_share_fd_cpu_", THPStorage_(shareFd), METH_NOARGS, nullptr},
|
||||
{"_new_shared_fd_cpu", THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_new_using_fd_cpu", THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_share_filename_cpu_", THPStorage_(shareFilename), METH_NOARGS, nullptr},
|
||||
{"_new_shared_filename_cpu", THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_new_using_filename_cpu", THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr},
|
||||
{"_weak_ref", THPStorage_(weakRef), METH_NOARGS, nullptr},
|
||||
{"_free_weak_ref", THPStorage_(freeWeakRef), METH_O | METH_STATIC, nullptr},
|
||||
{"_expired", THPStorage_(expired), METH_O | METH_STATIC, nullptr},
|
||||
|
@ -2,11 +2,7 @@
|
||||
#define TH_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
|
||||
#else
|
||||
|
||||
#ifdef THC_GENERIC_FILE
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#else
|
||||
#include <c10/core/CPUAllocator.h>
|
||||
#endif
|
||||
|
||||
// save_save is necessary since the old eager format saved storages as
|
||||
// [size + data], but the v1.5 eager format removes this since size is saved in
|
||||
@ -14,19 +10,18 @@
|
||||
template <class io>
|
||||
void THPStorage_(writeFileRaw)(c10::StorageImpl *self, io fd, bool save_size, uint64_t element_size)
|
||||
{
|
||||
#ifdef THC_GENERIC_FILE
|
||||
c10::cuda::CUDAGuard guard(self->device());
|
||||
#endif
|
||||
|
||||
c10::DeviceGuard guard(self->device());
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
scalar_t *data;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
std::unique_ptr<char[]> cpu_data;
|
||||
int64_t size_bytes = self->nbytes();
|
||||
int64_t numel = size_bytes / element_size;
|
||||
#ifndef THC_GENERIC_FILE
|
||||
if (self->device_type() == at::kCPU) {
|
||||
data = self->data<scalar_t>();
|
||||
#else
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
std::unique_ptr<char[]> cpu_data(new char[size_bytes]);
|
||||
#ifdef USE_CUDA
|
||||
} else if (self->device_type() == at::kCUDA) {
|
||||
cpu_data = std::unique_ptr<char[]>(new char[size_bytes]);
|
||||
data = (scalar_t*)cpu_data.get();
|
||||
C10_CUDA_CHECK(cudaMemcpy(
|
||||
data,
|
||||
@ -34,6 +29,9 @@ void THPStorage_(writeFileRaw)(c10::StorageImpl *self, io fd, bool save_size, ui
|
||||
size_bytes,
|
||||
cudaMemcpyDeviceToHost));
|
||||
#endif
|
||||
} else {
|
||||
TORCH_CHECK(false, "writeFileRaw: Device not recognized: ", self->device_type());
|
||||
}
|
||||
if (save_size) {
|
||||
if (torch::utils::THP_nativeByteOrder() ==
|
||||
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
|
||||
@ -93,13 +91,10 @@ template <class io>
|
||||
c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
|
||||
io file, c10::intrusive_ptr<c10::StorageImpl> storage, uint64_t element_size)
|
||||
{
|
||||
#ifdef THC_GENERIC_FILE
|
||||
c10::cuda::OptionalCUDAGuard guard;
|
||||
c10::OptionalDeviceGuard guard;
|
||||
if (storage.defined()) {
|
||||
guard.set_device(storage->device());
|
||||
guard.reset_device(storage->device());
|
||||
}
|
||||
#endif
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
scalar_t *data;
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
@ -118,11 +113,7 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
|
||||
storage = c10::make_intrusive<at::StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
nbytes,
|
||||
#if defined(THC_GENERIC_FILE)
|
||||
c10::cuda::CUDACachingAllocator::get(),
|
||||
#else
|
||||
c10::GetDefaultCPUAllocator(),
|
||||
#endif
|
||||
/*resizable=*/true);
|
||||
} else {
|
||||
int64_t _storage_nbytes = storage->nbytes();
|
||||
@ -133,13 +124,15 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
|
||||
_storage_nbytes);
|
||||
}
|
||||
|
||||
#ifndef THC_GENERIC_FILE
|
||||
data = storage->data<scalar_t>();
|
||||
#else
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
|
||||
std::unique_ptr<char[]> cpu_data(new char[nbytes]);
|
||||
std::unique_ptr<char[]> cpu_data;
|
||||
|
||||
if (storage->device_type() == at::kCPU) {
|
||||
data = storage->data<scalar_t>();
|
||||
} else {
|
||||
cpu_data = std::unique_ptr<char[]>(new char[nbytes]);
|
||||
data = (scalar_t*)cpu_data.get();
|
||||
#endif
|
||||
}
|
||||
|
||||
// fast track for bytes and little endian
|
||||
if (element_size == 1 ||
|
||||
@ -179,8 +172,10 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef THC_GENERIC_FILE
|
||||
#ifdef USE_CUDA
|
||||
if (storage->device_type() == at::kCUDA) {
|
||||
C10_CUDA_CHECK(cudaMemcpy(storage->data<scalar_t>(), data, nbytes, cudaMemcpyHostToDevice));
|
||||
}
|
||||
#endif
|
||||
return storage;
|
||||
}
|
||||
|
@ -139,7 +139,7 @@ void THPUtils_addPyMethodDefs(std::vector<PyMethodDef>& vector, PyMethodDef* met
|
||||
int THPUtils_getCallable(PyObject *arg, PyObject **result);
|
||||
|
||||
#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr)
|
||||
#define THPStoragePtr TH_CONCAT_3(THP,Real,StoragePtr)
|
||||
#define THPStoragePtr TH_CONCAT_2(THP,StoragePtr)
|
||||
#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr)
|
||||
#define THSPTensorPtr TH_CONCAT_3(THSP,Real,TensorPtr)
|
||||
|
||||
|
@ -635,25 +635,6 @@ from .random import * # noqa: F403
|
||||
# Define Storage and Tensor classes
|
||||
################################################################################
|
||||
|
||||
|
||||
from ..storage import _StorageBase
|
||||
|
||||
|
||||
if not hasattr(torch._C, 'CudaByteStorageBase'):
|
||||
# Define dummy base classes
|
||||
for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool', 'BFloat16',
|
||||
'ComplexDouble', 'ComplexFloat']:
|
||||
tensor_name = 'Cuda{0}TensorBase'.format(t)
|
||||
|
||||
torch._C.__dict__[tensor_name] = _dummy_type(tensor_name)
|
||||
|
||||
storage_name = 'CudaByteStorageBase'
|
||||
torch._C.__dict__[storage_name] = _dummy_type(storage_name)
|
||||
|
||||
torch._C.__dict__['_CudaStreamBase'] = _dummy_type('CudaStreamBase')
|
||||
torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
|
||||
|
||||
|
||||
@staticmethod # type: ignore[misc]
|
||||
def _lazy_new(cls, *args, **kwargs):
|
||||
_lazy_init()
|
||||
@ -675,9 +656,9 @@ class _CudaBase(object):
|
||||
|
||||
__new__ = _lazy_new
|
||||
|
||||
from torch.storage import _TypedStorage, _LegacyStorage
|
||||
from torch.storage import _LegacyStorage
|
||||
|
||||
class _UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
|
||||
class _CudaLegacyStorage(_LegacyStorage):
|
||||
@classmethod
|
||||
def from_buffer(cls, *args, **kwargs):
|
||||
raise RuntimeError('from_buffer: Not available for CUDA storage')
|
||||
@ -687,70 +668,72 @@ class _UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
|
||||
raise RuntimeError('_new_with_weak_ptr: Not available for CUDA storage')
|
||||
|
||||
@classmethod
|
||||
def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
|
||||
raise RuntimeError('_new_shared_filename: Not available for CUDA storage')
|
||||
def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None):
|
||||
raise RuntimeError('_new_shared_filename_cpu: Not available for CUDA storage')
|
||||
|
||||
class ByteStorage(_LegacyStorage):
|
||||
class ByteStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.uint8
|
||||
|
||||
class DoubleStorage(_LegacyStorage):
|
||||
class DoubleStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.double
|
||||
|
||||
class FloatStorage(_LegacyStorage):
|
||||
class FloatStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.float
|
||||
|
||||
class HalfStorage(_LegacyStorage):
|
||||
class HalfStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.half
|
||||
|
||||
class LongStorage(_LegacyStorage):
|
||||
class LongStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.long
|
||||
|
||||
class IntStorage(_LegacyStorage):
|
||||
class IntStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.int
|
||||
|
||||
class ShortStorage(_LegacyStorage):
|
||||
class ShortStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.short
|
||||
|
||||
class CharStorage(_LegacyStorage):
|
||||
class CharStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.int8
|
||||
|
||||
class BoolStorage(_LegacyStorage):
|
||||
class BoolStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.bool
|
||||
|
||||
class BFloat16Storage(_LegacyStorage):
|
||||
class BFloat16Storage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.bfloat16
|
||||
|
||||
class ComplexDoubleStorage(_LegacyStorage):
|
||||
class ComplexDoubleStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.cdouble
|
||||
|
||||
class ComplexFloatStorage(_LegacyStorage):
|
||||
class ComplexFloatStorage(_CudaLegacyStorage):
|
||||
@classproperty
|
||||
def dtype(self):
|
||||
return torch.cfloat
|
||||
|
||||
torch._storage_classes.add(_UntypedStorage)
|
||||
del _LegacyStorage
|
||||
del _CudaLegacyStorage
|
||||
|
||||
torch._storage_classes.add(DoubleStorage)
|
||||
torch._storage_classes.add(FloatStorage)
|
||||
torch._storage_classes.add(LongStorage)
|
||||
|
@ -122,7 +122,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
|
||||
shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
|
||||
else:
|
||||
# We already ref counting this Storage, but producer needs new ref-counters to be released.
|
||||
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
|
||||
storage_cls._release_ipc_counter_cuda(ref_counter_handle, ref_counter_offset, device=storage_device)
|
||||
|
||||
t = torch._utils._rebuild_tensor(
|
||||
torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
|
||||
@ -299,7 +299,7 @@ def rebuild_storage_fd(cls, df, size):
|
||||
storage = storage_from_cache(cls, fd_id(fd))
|
||||
if storage is not None:
|
||||
return storage
|
||||
storage = cls._new_shared_fd(fd, size)
|
||||
storage = cls._new_shared_fd_cpu(fd, size)
|
||||
shared_cache[fd_id(fd)] = StorageWeakRef(storage)
|
||||
return storage
|
||||
finally:
|
||||
@ -311,10 +311,10 @@ def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
|
||||
if storage is not None:
|
||||
return storage._shared_decref()
|
||||
if dtype is None:
|
||||
storage = torch._UntypedStorage._new_shared_filename(manager, handle, size)
|
||||
storage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, size)
|
||||
else:
|
||||
byte_size = size * torch._utils._element_size(dtype)
|
||||
untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename(manager, handle, byte_size)
|
||||
untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
|
||||
storage = torch._TypedStorage(
|
||||
wrap_storage=untyped_storage,
|
||||
dtype=dtype)
|
||||
@ -344,7 +344,7 @@ def reduce_storage(storage):
|
||||
if storage.is_cuda:
|
||||
raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead")
|
||||
elif get_sharing_strategy() == 'file_system':
|
||||
metadata = storage._share_filename_()
|
||||
metadata = storage._share_filename_cpu_()
|
||||
cache_key = metadata[1]
|
||||
rebuild = rebuild_storage_filename
|
||||
if isinstance(storage, torch._TypedStorage):
|
||||
@ -355,7 +355,7 @@ def reduce_storage(storage):
|
||||
# (with size 0) cannot be mmapped.
|
||||
return (rebuild_storage_empty, (type(storage),))
|
||||
else:
|
||||
fd, size = storage._share_fd_()
|
||||
fd, size = storage._share_fd_cpu_()
|
||||
df = multiprocessing.reduction.DupFd(fd)
|
||||
cache_key = fd_id(fd)
|
||||
metadata = (df, size)
|
||||
|
@ -880,19 +880,21 @@ class PackageExporter:
|
||||
if isinstance(obj, torch.storage._TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, we can
|
||||
# remove this case
|
||||
storage = obj._storage
|
||||
untyped_storage = obj._storage
|
||||
storage_type_str = obj.pickle_storage_type()
|
||||
storage_type = getattr(torch, storage_type_str)
|
||||
dtype = obj.dtype
|
||||
storage_numel = obj.size()
|
||||
|
||||
else:
|
||||
storage = obj
|
||||
elif isinstance(obj, torch._UntypedStorage):
|
||||
untyped_storage = obj
|
||||
storage_type = normalize_storage_type(type(storage))
|
||||
dtype = torch.uint8
|
||||
storage_numel = storage.nbytes()
|
||||
else:
|
||||
raise RuntimeError(f'storage type not recognized: {type(obj)}')
|
||||
|
||||
storage = cast(Storage, storage)
|
||||
storage: Storage = cast(Storage, untyped_storage)
|
||||
location = location_tag(storage)
|
||||
|
||||
# serialize storage if not already written
|
||||
|
@ -115,13 +115,13 @@ def check_module_version_greater_or_equal(module, req_version_tuple, error_if_ma
|
||||
|
||||
|
||||
def _cpu_tag(obj):
|
||||
if type(obj).__module__ == 'torch':
|
||||
if obj.device.type == 'cpu':
|
||||
return 'cpu'
|
||||
|
||||
|
||||
def _cuda_tag(obj):
|
||||
if type(obj).__module__ == 'torch.cuda':
|
||||
return 'cuda:' + str(obj.get_device())
|
||||
if obj.device.type == 'cuda':
|
||||
return 'cuda:' + str(obj.device.index)
|
||||
|
||||
|
||||
def _cpu_deserialize(obj, location):
|
||||
@ -151,9 +151,8 @@ def _cuda_deserialize(obj, location):
|
||||
if location.startswith('cuda'):
|
||||
device = validate_cuda_device(location)
|
||||
if getattr(obj, "_torch_load_uninitialized", False):
|
||||
storage_type = getattr(torch.cuda, type(obj).__name__)
|
||||
with torch.cuda.device(device):
|
||||
return storage_type(obj.nbytes())
|
||||
return torch._UntypedStorage(obj.nbytes(), device=torch.device(location))
|
||||
else:
|
||||
return obj.cuda(device)
|
||||
|
||||
@ -162,7 +161,7 @@ register_package(10, _cpu_tag, _cpu_deserialize)
|
||||
register_package(20, _cuda_tag, _cuda_deserialize)
|
||||
|
||||
|
||||
def location_tag(storage: Union[Storage, torch.storage._TypedStorage]):
|
||||
def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
|
||||
for _, tagger, _ in _package_registry:
|
||||
location = tagger(storage)
|
||||
if location:
|
||||
@ -414,6 +413,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
return ('module', obj, source_file, source)
|
||||
|
||||
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
|
||||
storage: torch._UntypedStorage
|
||||
|
||||
if isinstance(obj, torch.storage._TypedStorage):
|
||||
# TODO: Once we decide to break serialization FC, this case
|
||||
# can be deleted
|
||||
@ -424,12 +425,14 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
dtype = obj.dtype
|
||||
storage_numel = obj.size()
|
||||
|
||||
else:
|
||||
elif isinstance(obj, torch._UntypedStorage):
|
||||
storage = obj
|
||||
storage_dtype = storage.dtype
|
||||
storage_dtype = torch.uint8
|
||||
storage_type = normalize_storage_type(type(obj))
|
||||
dtype = torch.uint8
|
||||
storage_numel = cast(Storage, storage).nbytes()
|
||||
storage_numel = storage.nbytes()
|
||||
else:
|
||||
raise TypeError(f'type not recognized: {type(obj)}')
|
||||
|
||||
# If storage is allocated, ensure that any other saved storages
|
||||
# pointing to the same data all have the same dtype. If storage is
|
||||
@ -444,7 +447,6 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
|
||||
storage_dtypes[storage.data_ptr()] = storage_dtype
|
||||
|
||||
view_metadata: Optional[Tuple[str, int, int]]
|
||||
storage = cast(Storage, storage)
|
||||
|
||||
# Offset is always 0, but we keep it for backwards compatibility
|
||||
# with the old serialization format (which supported storage views)
|
||||
@ -552,12 +554,10 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
|
||||
|
||||
else:
|
||||
storage = obj
|
||||
storage_dtype = storage.dtype
|
||||
storage_dtype = torch.uint8
|
||||
storage_type = normalize_storage_type(type(obj))
|
||||
storage_numel = storage.nbytes()
|
||||
|
||||
storage = cast(Storage, storage)
|
||||
|
||||
# If storage is allocated, ensure that any other saved storages
|
||||
# pointing to the same data all have the same dtype. If storage is
|
||||
# not allocated, don't perform this check
|
||||
@ -1009,6 +1009,9 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
|
||||
assert typename == 'storage', \
|
||||
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
|
||||
storage_type, key, location, numel = data
|
||||
if storage_type is torch._UntypedStorage:
|
||||
dtype = torch.uint8
|
||||
else:
|
||||
dtype = storage_type.dtype
|
||||
|
||||
if key not in loaded_storages:
|
||||
|
145
torch/storage.py
145
torch/storage.py
@ -24,7 +24,7 @@ class _StorageBase(object):
|
||||
def __init__(self, *args, **kwargs): ... # noqa: E704
|
||||
def __len__(self) -> int: ... # noqa: E704
|
||||
def __getitem__(self, idx): ... # noqa: E704
|
||||
def copy_(self, source: T) -> T: ... # noqa: E704
|
||||
def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
|
||||
def nbytes(self) -> int: ... # noqa: E704
|
||||
|
||||
def size(self) -> int:
|
||||
@ -37,25 +37,38 @@ class _StorageBase(object):
|
||||
def data_ptr(self) -> int: ... # noqa: E704
|
||||
|
||||
# Defined in torch/csrc/generic/StorageSharing.cpp
|
||||
def _share_filename_(self): ... # noqa: E704
|
||||
def _share_fd_(self): ... # noqa: E704
|
||||
def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
|
||||
def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
|
||||
@classmethod
|
||||
def _new_using_filename(cls: Type[T], size: int) -> T: ... # noqa: E704
|
||||
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
|
||||
@classmethod
|
||||
def _new_using_fd(cls: Type[T], size: int) -> T: ... # noqa: E704
|
||||
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
|
||||
@classmethod
|
||||
def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
|
||||
@classmethod
|
||||
def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
|
||||
def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
|
||||
@classmethod
|
||||
def _release_ipc_counter(cls, *args, **kwargs) -> T: ... # noqa: E704
|
||||
def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
|
||||
@classmethod
|
||||
def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
|
||||
def _shared_decref(self) -> T: ... # noqa: E704
|
||||
def _write_file(self, *args, **kwargs): ... # noqa: E704
|
||||
def resize_(self, size: int): ... # noqa: E704
|
||||
def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
|
||||
def is_pinned(self) -> bool: ... # noqa: E704
|
||||
def _set_from_file(self, *args, **kwargs): ... # noqa: E704
|
||||
def _set_cdata(self, *args, **kwargs): ... # noqa: E704
|
||||
def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
|
||||
def is_shared(self) -> bool: ... # noqa: E704
|
||||
@classmethod
|
||||
def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
|
||||
def _shared_incref(self, *args, **kwargs): ... # noqa: E704
|
||||
|
||||
def __str__(self):
|
||||
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
|
||||
return content + f'\n[{torch.typename(self)} of size {len(self)}]'
|
||||
data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
|
||||
return data_str + (
|
||||
f'\n[{torch.typename(self)}(device={self.device}) '
|
||||
f'of size {len(self)}]')
|
||||
|
||||
def __repr__(self):
|
||||
return str(self)
|
||||
@ -84,9 +97,7 @@ class _StorageBase(object):
|
||||
|
||||
def clone(self):
|
||||
"""Returns a copy of this storage"""
|
||||
device = self.get_device() if self.is_cuda else -1
|
||||
with torch.cuda.device(device):
|
||||
return type(self)(self.nbytes()).copy_(self)
|
||||
return type(self)(self.nbytes(), device=self.device).copy_(self)
|
||||
|
||||
def tolist(self):
|
||||
"""Returns a list containing the elements of this storage"""
|
||||
@ -94,7 +105,10 @@ class _StorageBase(object):
|
||||
|
||||
def cpu(self):
|
||||
"""Returns a CPU copy of this storage if it's not already on the CPU"""
|
||||
return _type(self, getattr(torch, self.__class__.__name__))
|
||||
if self.device.type != 'cpu':
|
||||
return torch._UntypedStorage(self.size()).copy_(self, False)
|
||||
else:
|
||||
return self
|
||||
|
||||
def _to(self, dtype):
|
||||
if not isinstance(dtype, torch.dtype):
|
||||
@ -157,7 +171,7 @@ class _StorageBase(object):
|
||||
if self.is_cuda:
|
||||
raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
|
||||
import torch.cuda
|
||||
allocator = torch.cuda._host_allocator() # type: ignore[attr-defined]
|
||||
allocator = torch.cuda.memory._host_allocator() # type: ignore[attr-defined]
|
||||
return type(self)(self.size(), allocator=allocator).copy_(self)
|
||||
|
||||
def share_memory_(self):
|
||||
@ -173,26 +187,31 @@ class _StorageBase(object):
|
||||
if self.is_cuda:
|
||||
pass # CUDA doesn't use POSIX shared memory
|
||||
elif get_sharing_strategy() == 'file_system':
|
||||
self._share_filename_()
|
||||
self._share_filename_cpu_()
|
||||
else:
|
||||
self._share_fd_()
|
||||
self._share_fd_cpu_()
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def _new_shared(cls, size):
|
||||
def _new_shared(cls, size, *, device='cpu'):
|
||||
"""Creates a new storage in shared memory with the same data type"""
|
||||
from torch.multiprocessing import get_sharing_strategy
|
||||
if cls.is_cuda:
|
||||
return cls(size)
|
||||
device = torch.device(device)
|
||||
if device.type == 'cuda':
|
||||
return cls(size, device=device)
|
||||
elif get_sharing_strategy() == 'file_system':
|
||||
return cls._new_using_filename(size)
|
||||
return cls._new_using_filename_cpu(size)
|
||||
else:
|
||||
return cls._new_using_fd(size)
|
||||
return cls._new_using_fd_cpu(size)
|
||||
|
||||
def _untyped(self):
|
||||
return self
|
||||
|
||||
|
||||
class _UntypedStorage(torch._C.StorageBase, _StorageBase):
|
||||
pass
|
||||
|
||||
|
||||
def _load_from_bytes(b):
|
||||
return torch.load(io.BytesIO(b))
|
||||
|
||||
@ -320,7 +339,7 @@ class _TypedStorage:
|
||||
"\nNo positional arguments should be given when using "
|
||||
"'wrap_storage'")
|
||||
|
||||
if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
|
||||
if not isinstance(wrap_storage, torch._UntypedStorage):
|
||||
raise TypeError(
|
||||
arg_error_msg +
|
||||
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
|
||||
@ -371,7 +390,7 @@ class _TypedStorage:
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
|
||||
if not isinstance(wrap_storage, torch._UntypedStorage):
|
||||
raise TypeError(
|
||||
arg_error_msg +
|
||||
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
|
||||
@ -382,23 +401,16 @@ class _TypedStorage:
|
||||
self.dtype = torch.get_default_dtype() if dtype is None else dtype
|
||||
device = torch.device('cpu' if device is None else device)
|
||||
|
||||
if device.type == 'cpu':
|
||||
untyped_storage_class = torch._UntypedStorage
|
||||
elif device.type == 'cuda':
|
||||
untyped_storage_class = torch.cuda._UntypedStorage
|
||||
else:
|
||||
raise RuntimeError(f"Storage device not recognized: {device}")
|
||||
|
||||
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
||||
if device.type == 'cuda':
|
||||
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
|
||||
|
||||
if len(args) == 0:
|
||||
self._storage = untyped_storage_class()
|
||||
self._storage = torch._UntypedStorage(device=device)
|
||||
|
||||
elif len(args) == 1:
|
||||
if _isint(args[0]):
|
||||
self._storage = untyped_storage_class(int(args[0]) * self.element_size())
|
||||
self._storage = torch._UntypedStorage(int(args[0]) * self.element_size(), device=device)
|
||||
elif isinstance(args[0], collections.abc.Sequence):
|
||||
self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
|
||||
else:
|
||||
@ -420,16 +432,12 @@ class _TypedStorage:
|
||||
return self._storage
|
||||
|
||||
def _new_wrapped_storage(self, untyped_storage):
|
||||
module = eval(untyped_storage.__module__)
|
||||
assert type(untyped_storage) == module._UntypedStorage
|
||||
assert type(untyped_storage) == torch._UntypedStorage
|
||||
|
||||
if type(self) == _TypedStorage:
|
||||
return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
|
||||
else:
|
||||
# NOTE: We need to use the module of untyped_storage in case self's
|
||||
# module is different, e.g. if self is on CPU and untyped_storage
|
||||
# is on CUDA, and vice versa
|
||||
return getattr(module, type(self).__name__)(wrap_storage=untyped_storage)
|
||||
return type(self)(wrap_storage=untyped_storage)
|
||||
|
||||
def __len__(self):
|
||||
return self._storage.nbytes() // self.element_size()
|
||||
@ -505,7 +513,7 @@ class _TypedStorage:
|
||||
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
|
||||
return tmp_tensor[idx_wrapped].item()
|
||||
|
||||
def copy_(self, source: T, non_blocking=None):
|
||||
def copy_(self, source: T, non_blocking: bool = None):
|
||||
self._storage.copy_(source._untyped(), non_blocking)
|
||||
return self
|
||||
|
||||
@ -527,7 +535,7 @@ class _TypedStorage:
|
||||
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
|
||||
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
|
||||
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
|
||||
cuda_storage = self._storage.cuda(device, non_blocking, **kwargs)
|
||||
cuda_storage: torch._UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
|
||||
return self._new_wrapped_storage(cuda_storage)
|
||||
|
||||
def element_size(self):
|
||||
@ -585,13 +593,12 @@ class _TypedStorage:
|
||||
self._storage.share_memory_()
|
||||
return self
|
||||
|
||||
def _new_shared(self, size):
|
||||
def _new_shared(self, size, *, device=None):
|
||||
"""Creates a new storage in shared memory with the same data type"""
|
||||
if self.is_cuda:
|
||||
untyped_cls = torch.cuda._UntypedStorage
|
||||
else:
|
||||
untyped_cls = torch._UntypedStorage
|
||||
untyped_storage = untyped_cls._new_shared(size * self.element_size())
|
||||
if device is None:
|
||||
device = 'cpu'
|
||||
device = torch.device(device)
|
||||
untyped_storage = torch._UntypedStorage._new_shared(size * self.element_size(), device=device)
|
||||
return _TypedStorage(
|
||||
wrap_storage=untyped_storage,
|
||||
dtype=self.dtype)
|
||||
@ -636,16 +643,9 @@ class _TypedStorage:
|
||||
if cls == _TypedStorage:
|
||||
dtype = torch.get_default_dtype() if dtype is None else dtype
|
||||
device = torch.device('cpu' if device is None else device)
|
||||
|
||||
if device.type == 'cpu':
|
||||
untyped_cls = torch._UntypedStorage
|
||||
elif device.type == 'cuda':
|
||||
untyped_cls = torch.cuda._UntypedStorage
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"_TypedStorage.from_buffer: device '{device}' not recognized")
|
||||
untyped_storage: Union[torch._UntypedStorage, torch.cuda._UntypedStorage]
|
||||
untyped_storage = untyped_cls.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
if device.type != 'cpu':
|
||||
raise RuntimeError(f'_TypedStorage.from_buffer: Not available for device {device.type}')
|
||||
untyped_storage: torch._UntypedStorage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
|
||||
else:
|
||||
if dtype is not None or len(args) == 5:
|
||||
@ -658,7 +658,7 @@ class _TypedStorage:
|
||||
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
|
||||
|
||||
dtype = cls.dtype
|
||||
untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
untyped_storage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
|
||||
|
||||
return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
|
||||
|
||||
@ -770,10 +770,10 @@ class _TypedStorage:
|
||||
|
||||
@classmethod
|
||||
def _new_shared_cuda(cls, *args, **kwargs):
|
||||
return torch.cuda._UntypedStorage._new_shared_cuda(*args, **kwargs)
|
||||
return torch._UntypedStorage._new_shared_cuda(*args, **kwargs)
|
||||
|
||||
def _share_filename_(self, *args, **kwargs):
|
||||
manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs)
|
||||
def _share_filename_cpu_(self, *args, **kwargs):
|
||||
manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
|
||||
return manager_handle, storage_handle, size // self.element_size()
|
||||
|
||||
def _shared_decref(self):
|
||||
@ -781,23 +781,14 @@ class _TypedStorage:
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def _release_ipc_counter(cls, *args, device=None, **kwargs):
|
||||
device = torch.device('cpu' if device is None else device)
|
||||
|
||||
if device.type == 'cpu':
|
||||
untyped_cls = torch._UntypedStorage
|
||||
elif device.type == 'cuda':
|
||||
untyped_cls = torch.cuda._UntypedStorage
|
||||
else:
|
||||
raise RuntimeError(f"device {device} not recognized")
|
||||
|
||||
return untyped_cls._release_ipc_counter(*args, **kwargs)
|
||||
def _release_ipc_counter_cuda(cls, *args, device=None, **kwargs):
|
||||
return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
|
||||
|
||||
def _shared_incref(self, *args, **kwargs):
|
||||
return self._storage._shared_incref(*args, **kwargs)
|
||||
|
||||
def _share_fd_(self, *args, **kwargs):
|
||||
fd, size = self._storage._share_fd_(*args, **kwargs)
|
||||
def _share_fd_cpu_(self, *args, **kwargs):
|
||||
fd, size = self._storage._share_fd_cpu_(*args, **kwargs)
|
||||
return fd, size // self.element_size()
|
||||
|
||||
def _get_legacy_storage_class(self):
|
||||
@ -837,13 +828,13 @@ class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
|
||||
return cls(wrap_storage=untyped_storage)
|
||||
|
||||
@classmethod
|
||||
def _release_ipc_counter(cls, *args, **kwargs):
|
||||
return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
|
||||
def _release_ipc_counter_cuda(cls, *args, **kwargs):
|
||||
return eval(cls.__module__)._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _new_shared_filename(cls, manager, obj, size):
|
||||
def _new_shared_filename_cpu(cls, manager, obj, size):
|
||||
bytes_size = size * torch._utils._element_size(cls.dtype)
|
||||
return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
|
||||
return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
|
||||
|
||||
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
|
||||
try:
|
||||
|
@ -136,7 +136,7 @@ def default_collate(batch):
|
||||
# If we're in a background process, concatenate directly into a
|
||||
# shared memory tensor to avoid an extra copy
|
||||
numel = sum(x.numel() for x in batch)
|
||||
storage = elem.storage()._new_shared(numel)
|
||||
storage = elem.storage()._new_shared(numel, device=elem.device)
|
||||
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
|
||||
return torch.stack(batch, 0, out=out)
|
||||
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
|
||||
|
Reference in New Issue
Block a user