Merge torch.cuda._UntypedStorage into torch._UntypedStorage (#75459)

Fixes #74933

Pull Request resolved: https://github.com/pytorch/pytorch/pull/75459
Approved by: https://github.com/ezyang
Author: Kurt Mohler
Date: 2022-05-19 13:54:37 +00:00
Committed by: PyTorch MergeBot
parent ac1837ddd3
commit aea6e2c396
32 changed files with 357 additions and 565 deletions
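
The net effect for users of the private storage API: the separate torch.cuda._UntypedStorage class is gone, and torch._UntypedStorage now covers both devices through a device argument. A minimal sketch of the new surface, assuming a CUDA-enabled build (sizes and variable names are illustrative only):

import torch

# Before: CUDA byte-level storage had its own class, torch.cuda._UntypedStorage(16).
# After: one class handles both devices.
cpu_storage = torch._UntypedStorage(16)                   # defaults to CPU
cuda_storage = torch._UntypedStorage(16, device='cuda')   # allocated via the CUDA caching allocator

assert type(cpu_storage) is type(cuda_storage)            # both are torch._UntypedStorage
print(cpu_storage.device, cuda_storage.device)            # cpu cuda:0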

View File

@ -570,7 +570,7 @@ class TestCuda(TestCase):
self.assertTrue(isinstance(q_copy[1], torch.cuda.IntTensor))
self.assertTrue(isinstance(q_copy[2], torch.cuda.FloatTensor))
self.assertTrue(isinstance(q_copy[3], torch.storage._TypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch.cuda._UntypedStorage))
self.assertTrue(isinstance(q_copy[3]._storage, torch._UntypedStorage))
q_copy[1].fill_(10)
self.assertEqual(q_copy[3], torch.cuda.IntStorage(10).fill_(10))

View File

@ -41,59 +41,42 @@ class TestPublicBindings(TestCase):
"AVG",
"BenchmarkConfig",
"BenchmarkExecutionStats",
"BFloat16StorageBase",
"Block",
"BoolStorageBase",
"BoolType",
"BufferDict",
"ByteStorageBase",
"StorageBase",
"CallStack",
"Capsule",
"CharStorageBase",
"ClassType",
"clear_autocast_cache",
"Code",
"CompilationUnit",
"CompleteArgumentSpec",
"ComplexDoubleStorageBase",
"ComplexFloatStorageBase",
"ComplexType",
"ConcreteModuleType",
"ConcreteModuleTypeBuilder",
"CONV_BN_FUSION",
"cpp",
"CudaBFloat16StorageBase",
"CudaBFloat16TensorBase",
"CudaBFloat16TensorBase",
"CudaBoolStorageBase",
"CudaBoolTensorBase",
"CudaBoolTensorBase",
"CudaByteStorageBase",
"CudaByteTensorBase",
"CudaByteTensorBase",
"CudaCharStorageBase",
"CudaCharTensorBase",
"CudaCharTensorBase",
"CudaComplexDoubleStorageBase",
"CudaComplexDoubleTensorBase",
"CudaComplexDoubleTensorBase",
"CudaComplexFloatStorageBase",
"CudaComplexFloatTensorBase",
"CudaComplexFloatTensorBase",
"CudaDoubleStorageBase",
"CudaDoubleTensorBase",
"CudaDoubleTensorBase",
"CudaFloatStorageBase",
"CudaFloatTensorBase",
"CudaHalfStorageBase",
"CudaHalfTensorBase",
"CudaIntStorageBase",
"CudaIntTensorBase",
"CudaIntTensorBase",
"CudaLongStorageBase",
"CudaLongTensorBase",
"CudaLongTensorBase",
"CudaShortStorageBase",
"CudaShortTensorBase",
"CudaShortTensorBase",
"DeepCopyMemoTable",
@ -103,7 +86,6 @@ class TestPublicBindings(TestCase):
"DeviceObjType",
"DictType",
"DisableTorchFunction",
"DoubleStorageBase",
"dtype",
"EnumType",
"ErrorReport",
@ -111,7 +93,6 @@ class TestPublicBindings(TestCase):
"FatalError",
"FileCheck",
"finfo",
"FloatStorageBase",
"FloatType",
"fork",
"FunctionSchema",
@ -126,7 +107,6 @@ class TestPublicBindings(TestCase):
"Gradient",
"Graph",
"GraphExecutorState",
"HalfStorageBase",
"has_cuda",
"has_cudnn",
"has_lapack",
@ -143,7 +123,6 @@ class TestPublicBindings(TestCase):
"init_num_threads",
"INSERT_FOLD_PREPACK_OPS",
"InterfaceType",
"IntStorageBase",
"IntType",
"SymIntType",
"IODescriptor",
@ -159,7 +138,6 @@ class TestPublicBindings(TestCase):
"LiteScriptModule",
"LockingLogger",
"LoggerBase",
"LongStorageBase",
"memory_format",
"merge_type_from_type_comment",
"MobileOptimizerType",
@ -177,12 +155,7 @@ class TestPublicBindings(TestCase):
"PyObjectType",
"PyTorchFileReader",
"PyTorchFileWriter",
"QInt32StorageBase",
"QInt8StorageBase",
"qscheme",
"QUInt4x2StorageBase",
"QUInt2x4StorageBase",
"QUInt8StorageBase",
"read_vitals",
"REMOVE_DROPOUT",
"RRefType",
@ -209,7 +182,6 @@ class TestPublicBindings(TestCase):
"set_num_interop_threads",
"set_num_threads",
"set_vital",
"ShortStorageBase",
"Size",
"StaticModule",
"Stream",

View File

@ -6286,7 +6286,7 @@ class TestTorch(TestCase):
torch.storage._LegacyStorage()
for storage_class in torch._storage_classes:
if storage_class in [torch._UntypedStorage, torch.cuda._UntypedStorage, torch._TypedStorage]:
if storage_class in [torch._UntypedStorage, torch._TypedStorage]:
continue
device = 'cuda' if storage_class.__module__ == 'torch.cuda' else 'cpu'
@ -6371,22 +6371,16 @@ class TestTorch(TestCase):
storage_classes = [
torch.cuda.ByteStorage,
torch.cuda.FloatStorage,
torch.cuda._UntypedStorage,
]
for storage_class in storage_classes:
with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
storage_class.from_buffer()
if storage_class == torch.cuda._UntypedStorage:
with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
storage_class._new_with_weak_ptr()
else:
with self.assertRaisesRegex(AttributeError, r'has no attribute'):
storage_class._new_with_weak_ptr()
with self.assertRaisesRegex(RuntimeError, r'Not available for CUDA storage'):
storage_class._new_shared_filename(0, 0, 0)
storage_class._new_shared_filename_cpu(0, 0, 0)
def test_storage_casts(self):
storage = torch.IntStorage([-1, 0, 1, 2, 3, 4])

View File

@ -832,10 +832,8 @@ libtorch_python_cuda_core_sources = [
"torch/csrc/cuda/Event.cpp",
"torch/csrc/cuda/Module.cpp",
"torch/csrc/cuda/python_comm.cpp",
"torch/csrc/cuda/Storage.cpp",
"torch/csrc/cuda/Stream.cpp",
"torch/csrc/cuda/Graph.cpp",
"torch/csrc/cuda/serialization.cpp",
"torch/csrc/cuda/shared/cudart.cpp",
"torch/csrc/cuda/shared/nvtx.cpp",
"torch/csrc/cuda/utils.cpp",

View File

@ -786,33 +786,7 @@ def gen_pyi(
# Generate type signatures for legacy classes
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# TODO: These are deprecated, maybe we shouldn't type hint them
legacy_storage_base_hints = []
dt = (
"Double",
"Float",
"Long",
"Int",
"Short",
"Char",
"Byte",
"Bool",
"Half",
"BFloat16",
"ComplexDouble",
"ComplexFloat",
"QUInt8",
"QInt8",
"QInt32",
"QUInt4x2",
"QUInt2x4",
)
for c in dt:
legacy_storage_base_hints.append("class {}StorageBase(object): ...".format(c))
for c in dt:
legacy_storage_base_hints.append(
"class Cuda{}StorageBase(object): ...".format(c)
)
legacy_storage_base_hints = ["class StorageBase(object): ..."]
legacy_class_hints = []
for c in (

View File

@ -619,14 +619,11 @@ __all__.extend(['e', 'pi', 'nan', 'inf'])
################################################################################
from ._tensor import Tensor
from .storage import _StorageBase, _TypedStorage, _LegacyStorage
from .storage import _StorageBase, _TypedStorage, _LegacyStorage, _UntypedStorage
# NOTE: New <type>Storage classes should never be added. When adding a new
# dtype, use torch.storage._TypedStorage directly.
class _UntypedStorage(_C.ByteStorageBase, _StorageBase):
pass
class ByteStorage(_LegacyStorage):
@classproperty
def dtype(self):
@ -784,7 +781,8 @@ from .functional import * # noqa: F403
# Remove unnecessary members
################################################################################
del ByteStorageBase
del _StorageBase
del _LegacyStorage
################################################################################
# Define _assert
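
The NOTE in the hunk above says new per-dtype Storage classes should not be added; a hedged illustration of the intended pattern (the dtype is chosen arbitrarily):

import torch

# A _TypedStorage with an explicit dtype plays the role a dedicated
# <Type>Storage class used to; it is backed by the merged untyped class.
ts = torch.storage._TypedStorage(4, dtype=torch.bfloat16)
print(ts.dtype, type(ts._storage))   # torch.bfloat16 <class 'torch._UntypedStorage'>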

View File

@ -5,7 +5,7 @@ from torch._C import _add_docstr as add_docstr
storage_classes = [
'ByteStorageBase',
'StorageBase',
]

View File

@ -75,8 +75,7 @@ def _cuda(self, device=None, non_blocking=False, **kwargs):
values = torch.Tensor._values(self).cuda(device, non_blocking)
return new_type(indices, values, self.size())
else:
new_type = getattr(torch.cuda, self.__class__.__name__)
return new_type(self.size()).copy_(self, non_blocking)
return torch._UntypedStorage(self.size(), device=torch.device('cuda')).copy_(self, non_blocking)
def _get_async_or_non_blocking(function_name, non_blocking, kwargs):

View File

@ -9,6 +9,8 @@
#include <torch/csrc/utils/cuda_enabled.h>
#include <torch/csrc/utils/cuda_lazy_init.h>
#include <torch/csrc/utils/object_ptr.h>
#include <torch/csrc/Storage.h>
#include <torch/csrc/Device.h>
#include <ATen/ATen.h>
@ -22,9 +24,6 @@
namespace torch {
namespace {
std::unordered_map<at::DeprecatedTypeProperties*, PyTypeObject*> attype_to_py_storage_type;
std::unordered_map<PyTypeObject*, at::DeprecatedTypeProperties*> py_storage_type_to_attype;
std::array<THPDtype*, static_cast<int>(at::ScalarType::NumOptions)> dtype_registry = {};
std::array<THPLayout*, static_cast<int>(at::Layout::NumOptions)> layout_registry = {};
@ -51,34 +50,8 @@ at::DeprecatedTypeProperties* get_type(at::Backend backend, at::ScalarType scala
}
return &at::getDeprecatedTypeProperties(backend, scalarType);
}
PyTypeObject* getPyTypeObject(const at::Storage& storage) {
// TODO: https://github.com/pytorch/pytorch/issues/47442
if (storage.device_type() == at::DeviceType::Meta) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
}
if (storage.data() == nullptr && storage.nbytes() != 0) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details");
}
at::ScalarType scalarType = at::ScalarType::Byte;
auto attype = &at::getDeprecatedTypeProperties(
at::dispatchKeyToBackend(c10::computeDispatchKey(scalarType, c10::nullopt, storage.device_type())),
scalarType);
auto it = attype_to_py_storage_type.find(attype);
TORCH_INTERNAL_ASSERT(it != attype_to_py_storage_type.end(),
"Failed to get the Python type of `_UntypedStorage`.");
return it->second;
}
} // namespace
void registerStoragePyTypeObject(PyTypeObject *pytype, at::Backend backend, at::ScalarType scalarType) {
auto attype = get_type(backend, scalarType);
if (attype) {
attype_to_py_storage_type[attype] = pytype;
py_storage_type_to_attype[pytype] = attype;
}
}
void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType) {
dtype_registry[static_cast<int>(scalarType)] = dtype;
}
@ -104,7 +77,14 @@ THPLayout* getTHPLayout(at::Layout layout) {
}
PyObject* createPyObject(const at::Storage& storage) {
auto type = getPyTypeObject(storage);
// TODO: https://github.com/pytorch/pytorch/issues/47442
if (storage.device_type() == at::DeviceType::Meta) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings for meta storage objects not supported");
}
if (storage.data() == nullptr && storage.nbytes() != 0) {
TORCH_CHECK_NOT_IMPLEMENTED(false, "python bindings to nullptr storage (e.g., from torch.Tensor._make_wrapper_subclass) are currently unsafe and thus disabled. See https://github.com/pytorch/pytorch/issues/61669 for more details");
}
PyTypeObject* type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
auto obj = THPObjectPtr(type->tp_alloc(type, 0));
if (!obj) throw python_error();
((THPVoidStorage*)obj.get())->cdata = at::Storage(/* copy */ storage).unsafeReleaseStorageImpl();
@ -133,50 +113,56 @@ bool isStorage(PyObject* obj)
return true;
}
auto obj_type = Py_TYPE(obj);
for (auto const& item : py_storage_type_to_attype) {
auto const& storage_type = item.first;
if (obj_type == storage_type) {
return true;
}
}
return false;
return obj_type == reinterpret_cast<PyTypeObject*>(THPStorageClass);
}
at::Storage createStorageGetType(PyObject* obj, at::ScalarType& scalar_type, bool& is_typed_storage)
{
is_typed_storage = PyObject_TypeCheck(obj, getTypedStorageTypeObject());
THPObjectPtr maybe_untyped_storage;
if (is_typed_storage) {
PyObject* maybe_untyped_storage_obj = PyObject_GetAttrString(obj, "_storage");
TORCH_INTERNAL_ASSERT(maybe_untyped_storage_obj);
maybe_untyped_storage = maybe_untyped_storage_obj;
}
PyObject* untyped_storage_obj;
auto obj_type = Py_TYPE(obj);
for (auto const& item : py_storage_type_to_attype) {
auto const& storage_type = item.first;
if (is_typed_storage) {
if (Py_TYPE(maybe_untyped_storage.get()) == storage_type) {
auto& type = *item.second;
auto ret = type.unsafeStorageFromTH(
((THPVoidStorage*)maybe_untyped_storage.get())->cdata,
true);
PyObject* dtype_obj = PyObject_GetAttrString(obj, "dtype");
TORCH_INTERNAL_ASSERT(dtype_obj && THPDtype_Check(dtype_obj));
scalar_type = reinterpret_cast<THPDtype*>(dtype_obj)->scalar_type;
return ret;
}
}
if (obj_type == storage_type) {
auto& type = *item.second;
// _UntypedStorage should always be interpreted with byte dtype
untyped_storage_obj = PyObject_GetAttrString(obj, "_storage");
TORCH_INTERNAL_ASSERT(untyped_storage_obj);
// `PyObject_GetAttrString` increments the refcount to the
// `_UntypedStorage`, so we must decrement it. The refcount will stay
// positive for as long as we need it, since the `_TypedStorage` will
// retain its reference to the `_UntypedStorage`, so decrementing it
// immediately like this is fine.
Py_DECREF(untyped_storage_obj);
} else {
scalar_type = at::kByte;
return type.unsafeStorageFromTH(((THPVoidStorage*)obj)->cdata, true);
}
untyped_storage_obj = obj;
}
if (Py_TYPE(untyped_storage_obj) != reinterpret_cast<PyTypeObject*>(THPStorageClass)) {
throw TypeError("not a storage '%s'", Py_TYPE(obj)->tp_name);
}
c10::StorageImpl* impl = static_cast<c10::StorageImpl*>(((THPVoidStorage*)untyped_storage_obj)->cdata);
c10::DeviceType device_type = impl->device().type();
at::Backend backend;
if (device_type == at::kCPU) {
backend = at::Backend::CPU;
} else if (device_type == at::kCUDA) {
backend = at::Backend::CUDA;
} else {
TORCH_CHECK(false, "Invalid device for storage: ", device_type);
}
auto type_properties = get_type(backend, at::kByte);
return type_properties->unsafeStorageFromTH(((THPVoidStorage*)untyped_storage_obj)->cdata, true);
}
at::Storage createStorage(PyObject* obj) {
at::ScalarType scalar_type;
bool is_typed_storage = false;

View File

@ -21,10 +21,6 @@ struct Storage;
}
namespace torch {
// Register a PyTypeObject* with the given attributes
void registerStoragePyTypeObject(
PyTypeObject *pytype, at::Backend backend, at::ScalarType scalarType);
void registerDtypeObject(THPDtype *dtype, at::ScalarType scalarType);
void registerLayoutObject(THPLayout *thp_layout, at::Layout layout);

View File

@ -133,7 +133,7 @@ static PyObject * THPModule_initExtension(PyObject *_unused, PyObject *shm_manag
auto module = THPObjectPtr(PyImport_ImportModule("torch"));
if (!module) throw python_error();
THPByteStorage_postInit(module);
THPStorage_postInit(module);
THPAutograd_initFunctions();
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
@ -735,7 +735,6 @@ static PyMethodDef TorchMethods[] = {
{nullptr, nullptr, 0, nullptr}
};
bool THCPByteStorage_init(PyObject *module);
void THCPStream_init(PyObject *module);
void THCPEvent_init(PyObject *module);
void THCPGraph_init(PyObject *module);
@ -749,8 +748,6 @@ void initModule(PyObject *module);
}} // namespace torch::cuda
#endif
bool THDPByteStorage_init(PyObject *module);
static std::vector<PyMethodDef> methods;
// In Python we can't use the trick of C10_LOG_API_USAGE_ONCE
@ -853,14 +850,13 @@ PyObject* initModule() {
#ifdef USE_CUDA
torch::cuda::initModule(module);
#endif
ASSERT_TRUE(THPByteStorage_init(module));
ASSERT_TRUE(THPStorage_init(module));
#ifdef USE_CUDA
// This will only initialise base classes and attach them to library namespace
// They won't be ready for real usage until importing cuda module, that will
// complete the process (but it defines Python classes before calling back into
// C, so these lines have to execute first)..
ASSERT_TRUE(THCPByteStorage_init(module));
THCPStream_init(module);
THCPEvent_init(module);
THCPGraph_init(module);

View File

@ -2,17 +2,15 @@
#define THP_STORAGE_INC
#include <torch/csrc/THConcat.h>
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
#define THPStorageStr "torch._UntypedStorage"
#define THPStorage_(NAME) TH_CONCAT_2(THPStorage_,NAME)
#define THPByteStorage_Check(obj) \
PyObject_IsInstance(obj, THPByteStorageClass)
#define THPStorage_Check(obj) \
PyObject_IsInstance(obj, THPStorageClass)
#define THPByteStorage_CData(obj) (obj)->cdata
#define THPStorage_CData(obj) (obj)->cdata
#define THPStorageType TH_CONCAT_3(THP,Real,StorageType)
#define THPStorageBaseStr TH_CONCAT_STRING_2(Real,StorageBase)
#define THPStorageBaseStr "StorageBase"
#include <torch/csrc/generic/Storage.h>
#include <torch/csrc/THGenerateByteType.h>

View File

@ -591,9 +591,6 @@ static PyObject * THCPModule_initExtension(PyObject *self, PyObject *noargs)
auto m = THPObjectPtr(PyImport_ImportModule("torch.cuda"));
if (!m) throw python_error();
// Register Storage Python objects with DynamicTypes.cpp
THCPByteStorage_postInit(m);
bool has_half = true;
auto set_module_attr = [&](const char* name, PyObject* v) {

View File

@ -1,17 +0,0 @@
#define __STDC_FORMAT_MACROS
#include <torch/csrc/python_headers.h>
#include <structmember.h>
#include <fmt/format.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/override_macros.h>
#include <torch/csrc/copy_utils.h>
#include <torch/csrc/DynamicTypes.h>
#include <torch/csrc/CudaIPCTypes.h>
#include <torch/csrc/Device.h>
#include <torch/csrc/autograd/utils/wrap_outputs.h>
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#include <torch/csrc/THCGenerateByteType.h>

View File

@ -1,21 +0,0 @@
#ifndef THCP_STORAGE_INC
#define THCP_STORAGE_INC
#define THCPStorageStr TH_CONCAT_STRING_3(torch.cuda.,Real,Storage)
#define THCPStorageClass TH_CONCAT_3(THCP,Real,StorageClass)
#define THCPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
#define THCPByteStorage_Check(obj) \
PyObject_IsInstance(obj, THCPByteStorageClass)
#define THCPByteStorage_CData(obj) (obj)->cdata
#define THCPStorageType TH_CONCAT_3(THCP,Real,StorageType)
#define THCPStorageBaseStr TH_CONCAT_STRING_3(Cuda,Real,StorageBase)
#include <torch/csrc/cuda/override_macros.h>
#define THC_GENERIC_FILE "torch/csrc/generic/Storage.h"
#include <torch/csrc/THCGenerateByteType.h>
#endif

View File

@ -3,9 +3,7 @@
#include <torch/csrc/python_headers.h>
#include <torch/csrc/THP.h>
#include <torch/csrc/cuda/serialization.h>
#include <torch/csrc/cuda/Module.h>
#include <torch/csrc/cuda/Storage.h>
#include <torch/csrc/cuda/Stream.h>
#include <torch/csrc/cuda/Event.h>
#include <torch/csrc/cuda/utils.h>

View File

@ -1,17 +1,10 @@
#include <torch/csrc/cuda/undef_macros.h>
#define THPStoragePtr THCPStoragePtr
#define THPTensorPtr THCPTensorPtr
#define THWTensor THCTensor
#define THWTensor_(NAME) THCTensor_(NAME)
#define THPStorage_(NAME) TH_CONCAT_4(THCP,Real,Storage_,NAME)
#define THPStorageBaseStr THCPStorageBaseStr
#define THPStorageStr THCPStorageStr
#define THPStorageClass THCPStorageClass
#define THPStorageType THCPStorageType
#define THPTensor_(NAME) TH_CONCAT_4(THCP,Real,Tensor_,NAME)
#define THPTensor_stateless_(NAME) TH_CONCAT_4(THCP,Real,Tensor_stateless_,NAME)
#define THPTensor THCPTensor

View File

@ -8,10 +8,5 @@
#define THPTensorClass TH_CONCAT_3(THP,Real,TensorClass)
#define THPTensor_(NAME) TH_CONCAT_4(THP,Real,Tensor_,NAME)
#define THPStorageStr TH_CONCAT_STRING_3(torch.,Real,Storage)
#define THPStorageClass TH_CONCAT_3(THP,Real,StorageClass)
#define THPStorage_(NAME) TH_CONCAT_4(THP,Real,Storage_,NAME)
#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr)
#define THPStoragePtr TH_CONCAT_3(THP,Real,StoragePtr)
#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr)

View File

@ -1,11 +0,0 @@
#include <torch/csrc/python_headers.h>
#include <torch/csrc/cuda/THCP.h>
#include <torch/csrc/cuda/override_macros.h>
#include <system_error>
#include <memory>
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
#include <torch/csrc/THCGenerateByteType.h>

View File

@ -1,9 +0,0 @@
#ifndef THCP_SERIALIZATION_INC
#define THCP_SERIALIZATION_INC
#include <torch/csrc/cuda/override_macros.h>
#define THC_GENERIC_FILE "torch/csrc/generic/serialization.h"
#include <torch/csrc/THCGenerateByteType.h>
#endif

View File

@ -11,13 +11,6 @@
#undef THPTensorStateless
#undef THPTensorType
#undef THPStorage_
#undef THPStorageBaseStr
#undef THPStorageStr
#undef THPStorageClass
#undef THPStorageType
#undef THPStoragePtr
#undef THPTensorPtr
#undef THWTensor

View File

@ -1,6 +1,7 @@
#ifndef TH_GENERIC_FILE
#define TH_GENERIC_FILE "torch/csrc/generic/Storage.cpp"
#else
#include <torch/csrc/utils/python_arg_parser.h>
PyObject *THPStorageClass = nullptr;
@ -26,68 +27,84 @@ static void THPStorage_(dealloc)(THPStorage* self)
static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObject *kwargs)
{
HANDLE_TH_ERRORS
Py_ssize_t num_args = args ? PyTuple_Size(args) : 0;
static torch::PythonArgParser parser({
THPStorageStr "(*, int64_t allocator=None, Device device=None)",
THPStorageStr "(int64_t size, *, int64_t allocator=None, Device device=None)",
THPStorageStr "(PyObject* sequence, *, int64_t allocator=None, Device device=None)",
});
torch::ParsedArgs<3> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
int64_t allocator_arg_idx = 0;
int64_t device_arg_idx = 1;
if (r.idx > 0) {
allocator_arg_idx = 1;
device_arg_idx = 2;
}
c10::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx);
c10::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx);
TORCH_CHECK(!allocator_opt.has_value() || !device_opt.has_value(),
THPStorageStr, "(): only one or neither of 'allocator' or 'device' can ",
"be given, but not both");
THPStoragePtr self((THPStorage *)type->tp_alloc(type, 0));
THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
c10::Allocator* allocator = nullptr;
at::OptionalDeviceGuard device_guard;
// Internally we allow constructing with a keyword-only argument cdata
if (kwargs != nullptr) {
PyObject *allocator_ptr = PyDict_GetItemString(kwargs, "allocator");
if (allocator_ptr) {
THPUtils_assert(THPUtils_checkLong(allocator_ptr), "invalid allocator");
allocator = static_cast<c10::Allocator*>(PyLong_AsVoidPtr(allocator_ptr));
PyDict_DelItemString(kwargs, "allocator");
}
Py_ssize_t num_kwargs = PyDict_Size(kwargs);
if (num_args == 0) {
PyObject *cdata_ptr = PyDict_GetItemString(kwargs, "cdata");
if (num_kwargs == 1 && cdata_ptr && THPUtils_checkLong(cdata_ptr)) {
c10::StorageImpl *ptr = (c10::StorageImpl*)PyLong_AsVoidPtr(cdata_ptr);
self->cdata = ptr;
return (PyObject*)self.release();
}
}
THPUtils_assert(num_kwargs == 0, THPStorageStr "(): invalid keyword arguments");
}
if (allocator == nullptr) {
#if defined(THC_GENERIC_FILE)
allocator = c10::cuda::CUDACachingAllocator::get();
#else
if (allocator_opt.has_value()) {
allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
} else if (device_opt.has_value()) {
at::Device device = device_opt.value();
if (device.type() == at::kCPU) {
allocator = c10::GetDefaultCPUAllocator();
#ifdef USE_CUDA
} else if (device.type() == at::kCUDA) {
at::globalContext().lazyInitCUDA();
allocator = c10::cuda::CUDACachingAllocator::get();
#endif
} else {
TORCH_CHECK(false,
THPStorageStr, "(): Storage device not recognized: ", device.type());
}
device_guard.reset_device(device);
} else {
allocator = c10::GetDefaultCPUAllocator();
}
// torch.Storage()
if (num_args == 0) {
// torch.Storage(*, ...)
if (r.idx == 0) {
self->cdata = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
0,
allocator,
/*resizable=*/true).release();
return (PyObject*)self.release();
}
PyObject *first_arg = PyTuple_GET_ITEM(args, 0);
// torch.Storage(size)
if (num_args == 1 && THPUtils_checkLong(first_arg)) {
int64_t size = THPUtils_unpackLong(first_arg);
// torch.Storage(size, *, ...)
} else if (r.idx == 1) {
int64_t size = r.toInt64(0);
self->cdata = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
size,
allocator,
/*resizable=*/true).release();
return (PyObject*)self.release();
}
// torch.Storage(sequence)
if (num_args == 1 && PySequence_Check(first_arg)) {
Py_ssize_t length = PySequence_Length(first_arg);
THPUtils_assert(length >= 0, "couldn't obtain the length of %s",
THPUtils_typename(first_arg));
// torch.Storage(sequence, *, ...)
} else if (r.idx == 2) {
PyObject *sequence = r.pyobject(0);
Py_ssize_t length = PySequence_Length(sequence);
TORCH_CHECK(PySequence_Check(sequence),
THPStorageStr, "(): Expected a sequence type, but got ",
THPUtils_typename(sequence));
TORCH_CHECK(length >= 0,
THPStorageStr, "(): Could not obtain the length of sequence of type ",
THPUtils_typename(sequence));
self->cdata = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
length,
@ -97,35 +114,31 @@ static PyObject * THPStorage_(pynew)(PyTypeObject *type, PyObject *args, PyObjec
THPObjectPtr item;
try {
for (Py_ssize_t i = 0; i < length; i++) {
item = PySequence_GetItem(first_arg, i);
item = PySequence_GetItem(sequence, i);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
scalar_t value = THPUtils_(unpackReal)(item.get());
#if !defined(THC_GENERIC_FILE)
if (allocator == c10::GetDefaultCPUAllocator()) {
self->cdata->unsafe_data<scalar_t>()[i] = value;
#else
} else {
// TODO: this might be slow - consider batched updates?
storage_set(
at::unsafeStorageFromTH(self->cdata, /*retain=*/true),
i,
value);
#endif
}
}
} catch (const std::exception &e) {
THPUtils_setError("tried to construct a storage from a sequence (%s), "
THPUtils_setError(THPStorageStr
"(): tried to construct a storage from a sequence (%s), "
"but one of the items was of type %s instead of %s",
THPUtils_typename(first_arg),
THPUtils_typename(sequence),
THPUtils_typename(item.get()),
THPUtils_typeTraits<scalar_t>::python_type_str);
return nullptr;
}
return (PyObject*)self.release();
}
THPUtils_invalidArguments(args, kwargs, THPStorageStr " constructor", 6,
"no arguments",
"(int size)",
"(Sequence data)");
return nullptr;
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
@ -346,17 +359,6 @@ void THPStorage_(postInit)(PyObject *module)
{
THPStorageClass = PyObject_GetAttrString(module, "_UntypedStorage");
if (!THPStorageClass) throw python_error();
at::Backend backend = at::Backend::CPU;
#ifdef THC_GENERIC_FILE
backend = at::Backend::CUDA;
#endif
#ifdef THQUANTIZED
backend = at::Backend::QuantizedCPU;
#endif
torch::registerStoragePyTypeObject((PyTypeObject*)THPStorageClass, backend, TH_CONCAT_2(at::k, Real));
}
#endif

View File

@ -6,15 +6,12 @@
#ifdef USE_CUDA
#include <cuda_runtime.h>
#endif
#if !defined(THC_GENERIC_FILE)
#include <c10/core/CPUAllocator.h>
#include <ATen/native/Resize.h>
#else
#include <ATen/native/cuda/Resize.h>
#endif
#include <c10/core/CPUAllocator.h>
#include <ATen/native/Resize.h>
#ifdef _MSC_VER
#define LSEEK _lseeki64
#else
@ -86,14 +83,11 @@ static PyObject * THPStorage_(new)(PyObject *_self, PyObject *noargs)
{
HANDLE_TH_ERRORS
auto self = (THPStorage*)_self;
c10::Allocator* allocator = self->cdata->allocator();
auto new_storage = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
0,
#if defined(THC_GENERIC_FILE)
c10::cuda::CUDACachingAllocator::get(),
#else
c10::GetDefaultCPUAllocator(),
#endif
allocator,
/*resizable=*/true);
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -108,16 +102,22 @@ static PyObject * THPStorage_(resize_)(PyObject *_self, PyObject *number_arg)
THPUtils_assert(THPUtils_checkLong(number_arg), "resize_ expects an int, "
"but got %s", THPUtils_typename(number_arg));
int64_t newsize = THPUtils_unpackLong(number_arg);
#if defined(THC_GENERIC_FILE)
c10::DeviceType device_type = self->cdata->device_type();
if (device_type == at::kCPU) {
at::native::resize_bytes_cpu(self->cdata, newsize);
#ifdef USE_CUDA
} else if (device_type == at::kCUDA) {
ptrdiff_t size_bytes_i = newsize;
TORCH_CHECK(!c10::overflows<size_t>(size_bytes_i),
"Requested storage size (", size_bytes_i,
") cannot be represented as a size_t");
const auto size_bytes = static_cast<size_t>(size_bytes_i);
at::native::resize_bytes_cuda(self->cdata, size_bytes);
#else
at::native::resize_bytes_cpu(self->cdata, newsize);
#endif
} else {
TORCH_CHECK(false,
"_UntypedStorage.resize_: got unexpected device type ", device_type);
}
Py_INCREF(self);
return (PyObject*)self;
END_HANDLE_TH_ERRORS
@ -138,7 +138,6 @@ static PyObject * THPStorage_(fill_)(PyObject *_self, PyObject *number_arg)
END_HANDLE_TH_ERRORS
}
#if !defined(THC_GENERIC_FILE)
static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyObject *keywds)
{
HANDLE_TH_ERRORS
@ -224,11 +223,7 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
auto storage = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
size_bytes,
#if defined(THC_GENERIC_FILE)
c10::cuda::CUDACachingAllocator::get(),
#else
c10::GetDefaultCPUAllocator(),
#endif
/*resizable=*/true);
if (scalar_type == at::kByte || scalar_type == at::kChar) {
@ -284,7 +279,6 @@ static PyObject * THPStorage_(fromBuffer)(PyObject *_unused, PyObject *args, PyO
return (PyObject*)THPStorage_(New)(storage);
END_HANDLE_TH_ERRORS
}
#endif
static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObject *keywds)
{
@ -302,10 +296,6 @@ static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObj
if (shared)
shared = at::ALLOCATOR_MAPPED_SHARED;
#ifdef THC_GENERIC_FILE
TORCH_CHECK(false, "not available yet for CUDA");
return nullptr;
#else
size_t actual_nbytes = -1;
auto storage = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
@ -320,7 +310,6 @@ static PyObject * THPStorage_(fromFile)(PyObject *_unused, PyObject *args, PyObj
}
return (PyObject*)THPStorage_(New)(std::move(storage));
#endif
END_HANDLE_TH_ERRORS
}
@ -429,16 +418,6 @@ static PyObject *THPStorage_(setFromFile)(PyObject *_self, PyObject *args)
END_HANDLE_TH_ERRORS
}
#ifdef THC_GENERIC_FILE
PyObject * THPStorage_(getDevice)(PyObject *_self, PyObject *noargs)
{
HANDLE_TH_ERRORS
auto self = (THPStorage*)_self;
return THPUtils_packInt32(self->cdata->device().index());
END_HANDLE_TH_ERRORS
}
#endif
PyObject * THPStorage_(_setCdata)(PyObject *_self, PyObject *new_cdata)
{
HANDLE_TH_ERRORS
@ -473,15 +452,10 @@ static PyMethodDef THPStorage_(methods)[] = {
{"_write_file", THPStorage_(writeFile), METH_VARARGS, nullptr},
{"_new_with_file", THPStorage_(newWithFile), METH_VARARGS | METH_STATIC, nullptr},
{"_set_from_file", THPStorage_(setFromFile), METH_VARARGS, nullptr},
#if !defined(THC_GENERIC_FILE)
{"from_buffer", castPyCFunctionWithKeywords(THPStorage_(fromBuffer)),
METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
#endif
{"from_file", castPyCFunctionWithKeywords(THPStorage_(fromFile)),
METH_VARARGS | METH_KEYWORDS | METH_STATIC, nullptr},
#ifdef THC_GENERIC_FILE
{"get_device", THPStorage_(getDevice), METH_NOARGS, nullptr},
#endif
{"_set_cdata", THPStorage_(_setCdata), METH_O, nullptr},
{nullptr}
};

View File

@ -13,13 +13,14 @@ static PyObject * THPStorage_(sharedDecref)(PyObject *_self, PyObject *noargs)
{
HANDLE_TH_ERRORS
auto self = (THPStorage*)_self;
#ifndef THC_GENERIC_FILE
c10::DeviceType device_type = self->cdata->device_type();
if (device_type == at::kCPU) {
c10::StorageImpl *storage = self->cdata;
THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr());
if (ctx) {
ctx->decref();
}
#endif
}
Py_INCREF(self);
return (PyObject *)self;
END_HANDLE_TH_ERRORS
@ -29,19 +30,18 @@ static PyObject * THPStorage_(sharedIncref)(PyObject *_self, PyObject *noargs)
{
HANDLE_TH_ERRORS
auto self = (THPStorage*)_self;
#ifndef THC_GENERIC_FILE
c10::DeviceType device_type = self->cdata->device_type();
if (device_type == at::kCPU) {
c10::StorageImpl *storage = self->cdata;
THManagedMapAllocator *ctx = THManagedMapAllocator::fromDataPtr(storage->data_ptr());
if (ctx) {
ctx->incref();
}
#endif
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS
}
#ifndef THC_GENERIC_FILE
static PyObject * THPStorage_(pyNewFilenameStorage)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
@ -65,6 +65,8 @@ static PyObject * THPStorage_(pyNewFilenameStorage)(PyObject *_unused, PyObject
static PyObject * THPStorage_(shareFilename)(PyObject *_self, PyObject *noargs)
{
HANDLE_TH_ERRORS
TORCH_CHECK(reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCPU,
"_share_filename_: only available on CPU");
auto self = (THPStorage*)_self;
c10::StorageImpl *storage = self->cdata;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -166,6 +168,8 @@ static PyObject * THPStorage_(pyNewFdStorage)(PyObject *_unused, PyObject *args)
static PyObject * THPStorage_(shareFd)(PyObject *_self, PyObject *noargs)
{
HANDLE_TH_ERRORS
TORCH_CHECK(reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCPU,
"_share_fd_: only available on CPU");
auto self = (THPStorage*)_self;
c10::StorageImpl *storage = self->cdata;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -230,11 +234,12 @@ static PyObject * THPStorage_(newSharedFd)(PyObject *_unused, PyObject *args)
END_HANDLE_TH_ERRORS
}
#else // THC_GENERIC_FILE
static PyObject * THPStorage_(shareCuda)(PyObject *_self, PyObject *noargs)
{
HANDLE_TH_ERRORS
#ifdef USE_CUDA
TORCH_CHECK(reinterpret_cast<THPStorage*>(_self)->cdata->device_type() == at::kCUDA,
"_share_cuda_: only available on CUDA");
auto self = (THPStorage*)_self;
c10::StorageImpl *storage = self->cdata;
@ -309,12 +314,16 @@ static PyObject * THPStorage_(shareCuda)(PyObject *_self, PyObject *noargs)
PyTuple_SET_ITEM(tuple.get(), 6, _event_handle.release());
PyTuple_SET_ITEM(tuple.get(), 7, _event_sync_required.release());
return tuple.release();
#else
TORCH_CHECK(false, "CUDA is not available");
#endif
END_HANDLE_TH_ERRORS
}
static PyObject * THPStorage_(releaseIPCCounter)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
#ifdef USE_CUDA
THPUtils_assert(PyTuple_GET_SIZE(args) == 2, "tuple of 2 items expected");
PyObject *_ref_counter = PyTuple_GET_ITEM(args, 0);
PyObject *_ref_counter_offset = PyTuple_GET_ITEM(args, 1);
@ -347,9 +356,13 @@ static PyObject * THPStorage_(releaseIPCCounter)(PyObject *_unused, PyObject *ar
// Already warned inside of producer process
}
Py_RETURN_NONE;
#else
TORCH_CHECK(false, "CUDA is not available");
#endif
END_HANDLE_TH_ERRORS
}
#ifdef USE_CUDA
static std::string THPStorage_(bytesAsHandleString)(PyObject *handle) {
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
char* buffer;
@ -364,10 +377,12 @@ static std::string THPStorage_(bytesAsHandleString)(PyObject *handle) {
handle_size == CUDA_IPC_HANDLE_SIZE, "incorrect handle size");
return std::string(buffer, handle_size);
}
#endif
static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
{
HANDLE_TH_ERRORS
#ifdef USE_CUDA
THPUtils_assert(PyTuple_GET_SIZE(args) == 8, "tuple of 8 items expected");
PyObject *_device = PyTuple_GET_ITEM(args, 0);
PyObject *_handle = PyTuple_GET_ITEM(args, 1);
@ -405,7 +420,7 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
cudaEvent_t event;
cudaIpcOpenEventHandle(&event, *ipc_event_handle);
AT_CUDA_CHECK(
C10_CUDA_CHECK(
cudaStreamWaitEvent(c10::cuda::getCurrentCUDAStream(device), event, 0));
}
@ -485,9 +500,11 @@ static PyObject * THPStorage_(newSharedCuda)(PyObject *_unused, PyObject *args)
base->set_received_cuda(true);
return THPStorage_(New)(std::move(base));
#else
TORCH_CHECK(false, "CUDA is not available");
#endif
END_HANDLE_TH_ERRORS
}
#endif
// Returns an object that holds a "weak" pointer to the c10::StorageImpl. This
// pointer keeps the c10::StorageImpl struct live, but does not retain the data
@ -544,10 +561,10 @@ PyObject * THPStorage_(sharedFd)(PyObject *_self, PyObject *noargs)
HANDLE_TH_ERRORS
auto self = (THPStorage*)_self;
at::MapAllocator *ctx = nullptr;
#ifndef THC_GENERIC_FILE
if (self->cdata->device_type() == at::kCPU) {
c10::StorageImpl *storage = self->cdata;
ctx = at::MapAllocator::fromDataPtr(storage->data_ptr());
#endif
}
THPUtils_assert(ctx, "couldn't retrieve a shared file descriptor");
return THPUtils_packInt32(ctx->fd());
@ -557,33 +574,29 @@ PyObject * THPStorage_(sharedFd)(PyObject *_self, PyObject *noargs)
PyObject * THPStorage_(isShared)(PyObject *_self, PyObject *noargs)
{
auto self = (THPStorage*)_self;
#ifdef THC_GENERIC_FILE
if (self->cdata->device_type() == at::kCUDA) {
Py_RETURN_TRUE;
#else
}
if (at::MapAllocator::fromDataPtr(self->cdata->data_ptr()) ||
THManagedMapAllocator::fromDataPtr(self->cdata->data_ptr())) {
Py_RETURN_TRUE;
} else {
Py_RETURN_FALSE;
}
#endif
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
static PyMethodDef THPStorage_(sharingMethods)[] = {
{"_new_with_weak_ptr", THPStorage_(newWithWeakPtr), METH_O | METH_CLASS, nullptr},
#ifdef THC_GENERIC_FILE
{"_share_cuda_", THPStorage_(shareCuda), METH_NOARGS, nullptr},
{"_new_shared_cuda", THPStorage_(newSharedCuda), METH_VARARGS | METH_STATIC, nullptr},
{"_release_ipc_counter", THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr},
#else
{"_share_fd_", THPStorage_(shareFd), METH_NOARGS, nullptr},
{"_new_shared_fd", THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr},
{"_new_using_fd", THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr},
{"_share_filename_", THPStorage_(shareFilename), METH_NOARGS, nullptr},
{"_new_shared_filename", THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr},
{"_new_using_filename", THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr},
#endif
{"_release_ipc_counter_cuda", THPStorage_(releaseIPCCounter), METH_VARARGS | METH_STATIC, nullptr},
{"_share_fd_cpu_", THPStorage_(shareFd), METH_NOARGS, nullptr},
{"_new_shared_fd_cpu", THPStorage_(newSharedFd), METH_VARARGS | METH_STATIC, nullptr},
{"_new_using_fd_cpu", THPStorage_(pyNewFdStorage), METH_VARARGS | METH_STATIC, nullptr},
{"_share_filename_cpu_", THPStorage_(shareFilename), METH_NOARGS, nullptr},
{"_new_shared_filename_cpu", THPStorage_(newSharedFilename), METH_VARARGS | METH_STATIC, nullptr},
{"_new_using_filename_cpu", THPStorage_(pyNewFilenameStorage), METH_VARARGS | METH_STATIC, nullptr},
{"_weak_ref", THPStorage_(weakRef), METH_NOARGS, nullptr},
{"_free_weak_ref", THPStorage_(freeWeakRef), METH_O | METH_STATIC, nullptr},
{"_expired", THPStorage_(expired), METH_O | METH_STATIC, nullptr},

View File

@ -2,11 +2,7 @@
#define TH_GENERIC_FILE "torch/csrc/generic/serialization.cpp"
#else
#ifdef THC_GENERIC_FILE
#include <c10/cuda/CUDAGuard.h>
#else
#include <c10/core/CPUAllocator.h>
#endif
// save_size is necessary since the old eager format saved storages as
// [size + data], but the v1.5 eager format removes this since size is saved in
@ -14,19 +10,18 @@
template <class io>
void THPStorage_(writeFileRaw)(c10::StorageImpl *self, io fd, bool save_size, uint64_t element_size)
{
#ifdef THC_GENERIC_FILE
c10::cuda::CUDAGuard guard(self->device());
#endif
c10::DeviceGuard guard(self->device());
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
scalar_t *data;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::unique_ptr<char[]> cpu_data;
int64_t size_bytes = self->nbytes();
int64_t numel = size_bytes / element_size;
#ifndef THC_GENERIC_FILE
if (self->device_type() == at::kCPU) {
data = self->data<scalar_t>();
#else
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::unique_ptr<char[]> cpu_data(new char[size_bytes]);
#ifdef USE_CUDA
} else if (self->device_type() == at::kCUDA) {
cpu_data = std::unique_ptr<char[]>(new char[size_bytes]);
data = (scalar_t*)cpu_data.get();
C10_CUDA_CHECK(cudaMemcpy(
data,
@ -34,6 +29,9 @@ void THPStorage_(writeFileRaw)(c10::StorageImpl *self, io fd, bool save_size, ui
size_bytes,
cudaMemcpyDeviceToHost));
#endif
} else {
TORCH_CHECK(false, "writeFileRaw: Device not recognized: ", self->device_type());
}
if (save_size) {
if (torch::utils::THP_nativeByteOrder() ==
torch::utils::THPByteOrder::THP_LITTLE_ENDIAN)
@ -93,13 +91,10 @@ template <class io>
c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
io file, c10::intrusive_ptr<c10::StorageImpl> storage, uint64_t element_size)
{
#ifdef THC_GENERIC_FILE
c10::cuda::OptionalCUDAGuard guard;
c10::OptionalDeviceGuard guard;
if (storage.defined()) {
guard.set_device(storage->device());
guard.reset_device(storage->device());
}
#endif
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
scalar_t *data;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
@ -118,11 +113,7 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
storage = c10::make_intrusive<at::StorageImpl>(
c10::StorageImpl::use_byte_size_t(),
nbytes,
#if defined(THC_GENERIC_FILE)
c10::cuda::CUDACachingAllocator::get(),
#else
c10::GetDefaultCPUAllocator(),
#endif
/*resizable=*/true);
} else {
int64_t _storage_nbytes = storage->nbytes();
@ -133,13 +124,15 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
_storage_nbytes);
}
#ifndef THC_GENERIC_FILE
data = storage->data<scalar_t>();
#else
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
std::unique_ptr<char[]> cpu_data(new char[nbytes]);
std::unique_ptr<char[]> cpu_data;
if (storage->device_type() == at::kCPU) {
data = storage->data<scalar_t>();
} else {
cpu_data = std::unique_ptr<char[]>(new char[nbytes]);
data = (scalar_t*)cpu_data.get();
#endif
}
// fast track for bytes and little endian
if (element_size == 1 ||
@ -179,8 +172,10 @@ c10::intrusive_ptr<c10::StorageImpl> THPStorage_(readFileRaw)(
}
}
#ifdef THC_GENERIC_FILE
#ifdef USE_CUDA
if (storage->device_type() == at::kCUDA) {
C10_CUDA_CHECK(cudaMemcpy(storage->data<scalar_t>(), data, nbytes, cudaMemcpyHostToDevice));
}
#endif
return storage;
}

View File

@ -139,7 +139,7 @@ void THPUtils_addPyMethodDefs(std::vector<PyMethodDef>& vector, PyMethodDef* met
int THPUtils_getCallable(PyObject *arg, PyObject **result);
#define THWTensorPtr TH_CONCAT_3(TH,Real,TensorPtr)
#define THPStoragePtr TH_CONCAT_3(THP,Real,StoragePtr)
#define THPStoragePtr TH_CONCAT_2(THP,StoragePtr)
#define THPTensorPtr TH_CONCAT_3(THP,Real,TensorPtr)
#define THSPTensorPtr TH_CONCAT_3(THSP,Real,TensorPtr)

View File

@ -635,25 +635,6 @@ from .random import * # noqa: F403
# Define Storage and Tensor classes
################################################################################
from ..storage import _StorageBase
if not hasattr(torch._C, 'CudaByteStorageBase'):
# Define dummy base classes
for t in ['Double', 'Float', 'Long', 'Int', 'Short', 'Char', 'Byte', 'Half', 'Bool', 'BFloat16',
'ComplexDouble', 'ComplexFloat']:
tensor_name = 'Cuda{0}TensorBase'.format(t)
torch._C.__dict__[tensor_name] = _dummy_type(tensor_name)
storage_name = 'CudaByteStorageBase'
torch._C.__dict__[storage_name] = _dummy_type(storage_name)
torch._C.__dict__['_CudaStreamBase'] = _dummy_type('CudaStreamBase')
torch._C.__dict__['_CudaEventBase'] = _dummy_type('CudaEventBase')
@staticmethod # type: ignore[misc]
def _lazy_new(cls, *args, **kwargs):
_lazy_init()
@ -675,9 +656,9 @@ class _CudaBase(object):
__new__ = _lazy_new
from torch.storage import _TypedStorage, _LegacyStorage
from torch.storage import _LegacyStorage
class _UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
class _CudaLegacyStorage(_LegacyStorage):
@classmethod
def from_buffer(cls, *args, **kwargs):
raise RuntimeError('from_buffer: Not available for CUDA storage')
@ -687,70 +668,72 @@ class _UntypedStorage(_CudaBase, torch._C.CudaByteStorageBase, _StorageBase):
raise RuntimeError('_new_with_weak_ptr: Not available for CUDA storage')
@classmethod
def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None):
raise RuntimeError('_new_shared_filename: Not available for CUDA storage')
def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None):
raise RuntimeError('_new_shared_filename_cpu: Not available for CUDA storage')
class ByteStorage(_LegacyStorage):
class ByteStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.uint8
class DoubleStorage(_LegacyStorage):
class DoubleStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.double
class FloatStorage(_LegacyStorage):
class FloatStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.float
class HalfStorage(_LegacyStorage):
class HalfStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.half
class LongStorage(_LegacyStorage):
class LongStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.long
class IntStorage(_LegacyStorage):
class IntStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.int
class ShortStorage(_LegacyStorage):
class ShortStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.short
class CharStorage(_LegacyStorage):
class CharStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.int8
class BoolStorage(_LegacyStorage):
class BoolStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.bool
class BFloat16Storage(_LegacyStorage):
class BFloat16Storage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.bfloat16
class ComplexDoubleStorage(_LegacyStorage):
class ComplexDoubleStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.cdouble
class ComplexFloatStorage(_LegacyStorage):
class ComplexFloatStorage(_CudaLegacyStorage):
@classproperty
def dtype(self):
return torch.cfloat
torch._storage_classes.add(_UntypedStorage)
del _LegacyStorage
del _CudaLegacyStorage
torch._storage_classes.add(DoubleStorage)
torch._storage_classes.add(FloatStorage)
torch._storage_classes.add(LongStorage)
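
With the dummy CudaByteStorageBase bindings gone, the legacy torch.cuda.<Type>Storage classes become thin _TypedStorage wrappers over the merged class. A small sketch, assuming a CUDA device is available:

import torch

s = torch.cuda.FloatStorage(4)        # still constructible, now a _CudaLegacyStorage subclass
print(type(s._storage), s.device)     # <class 'torch._UntypedStorage'> cuda:0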

View File

@ -122,7 +122,7 @@ def rebuild_cuda_tensor(tensor_cls, tensor_size, tensor_stride, tensor_offset,
shared_cache[(storage_handle, storage_offset_bytes)] = StorageWeakRef(storage)
else:
# We are already ref counting this Storage, but the producer needs new ref-counters to be released.
storage_cls._release_ipc_counter(ref_counter_handle, ref_counter_offset, device=storage_device)
storage_cls._release_ipc_counter_cuda(ref_counter_handle, ref_counter_offset, device=storage_device)
t = torch._utils._rebuild_tensor(
torch.storage._TypedStorage(wrap_storage=storage._untyped(), dtype=dtype),
@ -299,7 +299,7 @@ def rebuild_storage_fd(cls, df, size):
storage = storage_from_cache(cls, fd_id(fd))
if storage is not None:
return storage
storage = cls._new_shared_fd(fd, size)
storage = cls._new_shared_fd_cpu(fd, size)
shared_cache[fd_id(fd)] = StorageWeakRef(storage)
return storage
finally:
@ -311,10 +311,10 @@ def rebuild_storage_filename(cls, manager, handle, size, dtype=None):
if storage is not None:
return storage._shared_decref()
if dtype is None:
storage = torch._UntypedStorage._new_shared_filename(manager, handle, size)
storage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, size)
else:
byte_size = size * torch._utils._element_size(dtype)
untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename(manager, handle, byte_size)
untyped_storage: torch._UntypedStorage = torch._UntypedStorage._new_shared_filename_cpu(manager, handle, byte_size)
storage = torch._TypedStorage(
wrap_storage=untyped_storage,
dtype=dtype)
@ -344,7 +344,7 @@ def reduce_storage(storage):
if storage.is_cuda:
raise RuntimeError("Cannot pickle CUDA storage; try pickling a CUDA tensor instead")
elif get_sharing_strategy() == 'file_system':
metadata = storage._share_filename_()
metadata = storage._share_filename_cpu_()
cache_key = metadata[1]
rebuild = rebuild_storage_filename
if isinstance(storage, torch._TypedStorage):
@ -355,7 +355,7 @@ def reduce_storage(storage):
# (with size 0) cannot be mmapped.
return (rebuild_storage_empty, (type(storage),))
else:
fd, size = storage._share_fd_()
fd, size = storage._share_fd_cpu_()
df = multiprocessing.reduction.DupFd(fd)
cache_key = fd_id(fd)
metadata = (df, size)

View File

@ -880,19 +880,21 @@ class PackageExporter:
if isinstance(obj, torch.storage._TypedStorage):
# TODO: Once we decide to break serialization FC, we can
# remove this case
storage = obj._storage
untyped_storage = obj._storage
storage_type_str = obj.pickle_storage_type()
storage_type = getattr(torch, storage_type_str)
dtype = obj.dtype
storage_numel = obj.size()
else:
storage = obj
elif isinstance(obj, torch._UntypedStorage):
untyped_storage = obj
storage_type = normalize_storage_type(type(storage))
dtype = torch.uint8
storage_numel = storage.nbytes()
else:
raise RuntimeError(f'storage type not recognized: {type(obj)}')
storage = cast(Storage, storage)
storage: Storage = cast(Storage, untyped_storage)
location = location_tag(storage)
# serialize storage if not already written

View File

@ -115,13 +115,13 @@ def check_module_version_greater_or_equal(module, req_version_tuple, error_if_ma
def _cpu_tag(obj):
if type(obj).__module__ == 'torch':
if obj.device.type == 'cpu':
return 'cpu'
def _cuda_tag(obj):
if type(obj).__module__ == 'torch.cuda':
return 'cuda:' + str(obj.get_device())
if obj.device.type == 'cuda':
return 'cuda:' + str(obj.device.index)
def _cpu_deserialize(obj, location):
@ -151,9 +151,8 @@ def _cuda_deserialize(obj, location):
if location.startswith('cuda'):
device = validate_cuda_device(location)
if getattr(obj, "_torch_load_uninitialized", False):
storage_type = getattr(torch.cuda, type(obj).__name__)
with torch.cuda.device(device):
return storage_type(obj.nbytes())
return torch._UntypedStorage(obj.nbytes(), device=torch.device(location))
else:
return obj.cuda(device)
@ -162,7 +161,7 @@ register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)
def location_tag(storage: Union[Storage, torch.storage._TypedStorage]):
def location_tag(storage: Union[Storage, torch.storage._TypedStorage, torch._UntypedStorage]):
for _, tagger, _ in _package_registry:
location = tagger(storage)
if location:
@ -414,6 +413,8 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
return ('module', obj, source_file, source)
if isinstance(obj, torch.storage._TypedStorage) or torch.is_storage(obj):
storage: torch._UntypedStorage
if isinstance(obj, torch.storage._TypedStorage):
# TODO: Once we decide to break serialization FC, this case
# can be deleted
@ -424,12 +425,14 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
dtype = obj.dtype
storage_numel = obj.size()
else:
elif isinstance(obj, torch._UntypedStorage):
storage = obj
storage_dtype = storage.dtype
storage_dtype = torch.uint8
storage_type = normalize_storage_type(type(obj))
dtype = torch.uint8
storage_numel = cast(Storage, storage).nbytes()
storage_numel = storage.nbytes()
else:
raise TypeError(f'type not recognized: {type(obj)}')
# If storage is allocated, ensure that any other saved storages
# pointing to the same data all have the same dtype. If storage is
@ -444,7 +447,6 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
storage_dtypes[storage.data_ptr()] = storage_dtype
view_metadata: Optional[Tuple[str, int, int]]
storage = cast(Storage, storage)
# Offset is always 0, but we keep it for backwards compatibility
# with the old serialization format (which supported storage views)
@ -552,12 +554,10 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
else:
storage = obj
storage_dtype = storage.dtype
storage_dtype = torch.uint8
storage_type = normalize_storage_type(type(obj))
storage_numel = storage.nbytes()
storage = cast(Storage, storage)
# If storage is allocated, ensure that any other saved storages
# pointing to the same data all have the same dtype. If storage is
# not allocated, don't perform this check
@ -1009,6 +1009,9 @@ def _load(zip_file, map_location, pickle_module, pickle_file='data.pkl', **pickl
assert typename == 'storage', \
f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
storage_type, key, location, numel = data
if storage_type is torch._UntypedStorage:
dtype = torch.uint8
else:
dtype = storage_type.dtype
if key not in loaded_storages:
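
A hedged sketch of the load path touched above, assuming a CUDA device is available (the saved values are arbitrary): _cuda_deserialize now rebuilds the storage as torch._UntypedStorage(nbytes, device=...) rather than looking up a torch.cuda.<Type>Storage class.

import io
import torch

buf = io.BytesIO()
torch.save(torch.arange(4, device='cuda'), buf)
buf.seek(0)
t = torch.load(buf)    # uninitialized storages are recreated on the saved device
print(t.device)        # cuda:0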

View File

@ -24,7 +24,7 @@ class _StorageBase(object):
def __init__(self, *args, **kwargs): ... # noqa: E704
def __len__(self) -> int: ... # noqa: E704
def __getitem__(self, idx): ... # noqa: E704
def copy_(self, source: T) -> T: ... # noqa: E704
def copy_(self, source: T, non_blocking: bool = None) -> T: ... # noqa: E704
def nbytes(self) -> int: ... # noqa: E704
def size(self) -> int:
@ -37,25 +37,38 @@ class _StorageBase(object):
def data_ptr(self) -> int: ... # noqa: E704
# Defined in torch/csrc/generic/StorageSharing.cpp
def _share_filename_(self): ... # noqa: E704
def _share_fd_(self): ... # noqa: E704
def _share_filename_cpu_(self, *args, **kwargs): ... # noqa: E704
def _share_fd_cpu_(self, *args, **kwargs): ... # noqa: E704
@classmethod
def _new_using_filename(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_filename_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
@classmethod
def _new_using_fd(cls: Type[T], size: int) -> T: ... # noqa: E704
def _new_using_fd_cpu(cls: Type[T], size: int) -> T: ... # noqa: E704
@classmethod
def from_buffer(cls, *args, **kwargs) -> T: ... # noqa: E704
@classmethod
def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
def _new_shared_filename_cpu(cls, manager, obj, size, *, device=None, dtype=None) -> T: ... # noqa: E704
@classmethod
def _release_ipc_counter(cls, *args, **kwargs) -> T: ... # noqa: E704
def _release_ipc_counter_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
@classmethod
def _new_with_weak_ptr(cls, *args, **kwargs) -> T: ... # noqa: E704
def _shared_decref(self) -> T: ... # noqa: E704
def _write_file(self, *args, **kwargs): ... # noqa: E704
def resize_(self, size: int): ... # noqa: E704
def _weak_ref(self, *args, **kwargs) -> T: ... # noqa: E704
def is_pinned(self) -> bool: ... # noqa: E704
def _set_from_file(self, *args, **kwargs): ... # noqa: E704
def _set_cdata(self, *args, **kwargs): ... # noqa: E704
def _share_cuda_(self, *args, **kwargs): ... # noqa: E704
def is_shared(self) -> bool: ... # noqa: E704
@classmethod
def _new_shared_cuda(cls, *args, **kwargs) -> T: ... # noqa: E704
def _shared_incref(self, *args, **kwargs): ... # noqa: E704
def __str__(self):
content = ' ' + '\n '.join(str(self[i]) for i in range(len(self)))
return content + f'\n[{torch.typename(self)} of size {len(self)}]'
data_str = ' ' + '\n '.join(str(self[i]) for i in range(self.size()))
return data_str + (
f'\n[{torch.typename(self)}(device={self.device}) '
f'of size {len(self)}]')
def __repr__(self):
return str(self)
@ -84,9 +97,7 @@ class _StorageBase(object):
def clone(self):
"""Returns a copy of this storage"""
device = self.get_device() if self.is_cuda else -1
with torch.cuda.device(device):
return type(self)(self.nbytes()).copy_(self)
return type(self)(self.nbytes(), device=self.device).copy_(self)
def tolist(self):
"""Returns a list containing the elements of this storage"""
@ -94,7 +105,10 @@ class _StorageBase(object):
def cpu(self):
"""Returns a CPU copy of this storage if it's not already on the CPU"""
return _type(self, getattr(torch, self.__class__.__name__))
if self.device.type != 'cpu':
return torch._UntypedStorage(self.size()).copy_(self, False)
else:
return self
def _to(self, dtype):
if not isinstance(dtype, torch.dtype):
@ -157,7 +171,7 @@ class _StorageBase(object):
if self.is_cuda:
raise TypeError(f"cannot pin '{self.type()}' only CPU memory can be pinned")
import torch.cuda
allocator = torch.cuda._host_allocator() # type: ignore[attr-defined]
allocator = torch.cuda.memory._host_allocator() # type: ignore[attr-defined]
return type(self)(self.size(), allocator=allocator).copy_(self)
def share_memory_(self):
@ -173,26 +187,31 @@ class _StorageBase(object):
if self.is_cuda:
pass # CUDA doesn't use POSIX shared memory
elif get_sharing_strategy() == 'file_system':
self._share_filename_()
self._share_filename_cpu_()
else:
self._share_fd_()
self._share_fd_cpu_()
return self
@classmethod
def _new_shared(cls, size):
def _new_shared(cls, size, *, device='cpu'):
"""Creates a new storage in shared memory with the same data type"""
from torch.multiprocessing import get_sharing_strategy
if cls.is_cuda:
return cls(size)
device = torch.device(device)
if device.type == 'cuda':
return cls(size, device=device)
elif get_sharing_strategy() == 'file_system':
return cls._new_using_filename(size)
return cls._new_using_filename_cpu(size)
else:
return cls._new_using_fd(size)
return cls._new_using_fd_cpu(size)
def _untyped(self):
return self
class _UntypedStorage(torch._C.StorageBase, _StorageBase):
pass
def _load_from_bytes(b):
return torch.load(io.BytesIO(b))
@ -320,7 +339,7 @@ class _TypedStorage:
"\nNo positional arguments should be given when using "
"'wrap_storage'")
if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
if not isinstance(wrap_storage, torch._UntypedStorage):
raise TypeError(
arg_error_msg +
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
@ -371,7 +390,7 @@ class _TypedStorage:
self.dtype = dtype
if not isinstance(wrap_storage, (torch._UntypedStorage, torch.cuda._UntypedStorage)):
if not isinstance(wrap_storage, torch._UntypedStorage):
raise TypeError(
arg_error_msg +
f"\nArgument 'wrap_storage' must be _UntypedStorage, but got {type(wrap_storage)}")
@ -382,23 +401,16 @@ class _TypedStorage:
self.dtype = torch.get_default_dtype() if dtype is None else dtype
device = torch.device('cpu' if device is None else device)
if device.type == 'cpu':
untyped_storage_class = torch._UntypedStorage
elif device.type == 'cuda':
untyped_storage_class = torch.cuda._UntypedStorage
else:
raise RuntimeError(f"Storage device not recognized: {device}")
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
if device.type == 'cuda':
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
if len(args) == 0:
self._storage = untyped_storage_class()
self._storage = torch._UntypedStorage(device=device)
elif len(args) == 1:
if _isint(args[0]):
self._storage = untyped_storage_class(int(args[0]) * self.element_size())
self._storage = torch._UntypedStorage(int(args[0]) * self.element_size(), device=device)
elif isinstance(args[0], collections.abc.Sequence):
self._storage = _get_storage_from_sequence(args[0], self.dtype, device)
else:
@ -420,16 +432,12 @@ class _TypedStorage:
return self._storage
def _new_wrapped_storage(self, untyped_storage):
module = eval(untyped_storage.__module__)
assert type(untyped_storage) == module._UntypedStorage
assert type(untyped_storage) == torch._UntypedStorage
if type(self) == _TypedStorage:
return _TypedStorage(wrap_storage=untyped_storage, dtype=self.dtype)
else:
# NOTE: We need to use the module of untyped_storage in case self's
# module is different, e.g. if self is on CPU and untyped_storage
# is on CUDA, and vice versa
return getattr(module, type(self).__name__)(wrap_storage=untyped_storage)
return type(self)(wrap_storage=untyped_storage)
def __len__(self):
return self._storage.nbytes() // self.element_size()
@ -505,7 +513,7 @@ class _TypedStorage:
tmp_tensor = torch.tensor([], dtype=self.dtype, device=self.device).set_(self)
return tmp_tensor[idx_wrapped].item()
def copy_(self, source: T, non_blocking=None):
def copy_(self, source: T, non_blocking: bool = None):
self._storage.copy_(source._untyped(), non_blocking)
return self
@ -527,7 +535,7 @@ class _TypedStorage:
def cuda(self, device=None, non_blocking=False, **kwargs) -> T:
if self.dtype in [torch.quint8, torch.quint4x2, torch.quint2x4, torch.qint32, torch.qint8]:
raise RuntimeError("Cannot create CUDA storage with quantized dtype")
cuda_storage = self._storage.cuda(device, non_blocking, **kwargs)
cuda_storage: torch._UntypedStorage = self._storage.cuda(device, non_blocking, **kwargs)
return self._new_wrapped_storage(cuda_storage)
def element_size(self):
@ -585,13 +593,12 @@ class _TypedStorage:
self._storage.share_memory_()
return self
def _new_shared(self, size):
def _new_shared(self, size, *, device=None):
"""Creates a new storage in shared memory with the same data type"""
if self.is_cuda:
untyped_cls = torch.cuda._UntypedStorage
else:
untyped_cls = torch._UntypedStorage
untyped_storage = untyped_cls._new_shared(size * self.element_size())
if device is None:
device = 'cpu'
device = torch.device(device)
untyped_storage = torch._UntypedStorage._new_shared(size * self.element_size(), device=device)
return _TypedStorage(
wrap_storage=untyped_storage,
dtype=self.dtype)
@ -636,16 +643,9 @@ class _TypedStorage:
if cls == _TypedStorage:
dtype = torch.get_default_dtype() if dtype is None else dtype
device = torch.device('cpu' if device is None else device)
if device.type == 'cpu':
untyped_cls = torch._UntypedStorage
elif device.type == 'cuda':
untyped_cls = torch.cuda._UntypedStorage
else:
raise RuntimeError(
f"_TypedStorage.from_buffer: device '{device}' not recognized")
untyped_storage: Union[torch._UntypedStorage, torch.cuda._UntypedStorage]
untyped_storage = untyped_cls.from_buffer(*args, dtype=dtype, **kwargs)
if device.type != 'cpu':
raise RuntimeError(f'_TypedStorage.from_buffer: Not available for device {device.type}')
untyped_storage: torch._UntypedStorage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
else:
if dtype is not None or len(args) == 5:
@ -658,7 +658,7 @@ class _TypedStorage:
"_UntypedStorage.from_buffer and _TypedStorage.from_buffer"))
dtype = cls.dtype
untyped_storage = eval(cls.__module__)._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
untyped_storage = torch._UntypedStorage.from_buffer(*args, dtype=dtype, **kwargs)
return _TypedStorage(wrap_storage=untyped_storage, dtype=dtype)
@ -770,10 +770,10 @@ class _TypedStorage:
@classmethod
def _new_shared_cuda(cls, *args, **kwargs):
return torch.cuda._UntypedStorage._new_shared_cuda(*args, **kwargs)
return torch._UntypedStorage._new_shared_cuda(*args, **kwargs)
def _share_filename_(self, *args, **kwargs):
manager_handle, storage_handle, size = self._storage._share_filename_(*args, **kwargs)
def _share_filename_cpu_(self, *args, **kwargs):
manager_handle, storage_handle, size = self._storage._share_filename_cpu_(*args, **kwargs)
return manager_handle, storage_handle, size // self.element_size()
def _shared_decref(self):
@ -781,23 +781,14 @@ class _TypedStorage:
return self
@classmethod
def _release_ipc_counter(cls, *args, device=None, **kwargs):
device = torch.device('cpu' if device is None else device)
if device.type == 'cpu':
untyped_cls = torch._UntypedStorage
elif device.type == 'cuda':
untyped_cls = torch.cuda._UntypedStorage
else:
raise RuntimeError(f"device {device} not recognized")
return untyped_cls._release_ipc_counter(*args, **kwargs)
def _release_ipc_counter_cuda(cls, *args, device=None, **kwargs):
return torch._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
def _shared_incref(self, *args, **kwargs):
return self._storage._shared_incref(*args, **kwargs)
def _share_fd_(self, *args, **kwargs):
fd, size = self._storage._share_fd_(*args, **kwargs)
def _share_fd_cpu_(self, *args, **kwargs):
fd, size = self._storage._share_fd_cpu_(*args, **kwargs)
return fd, size // self.element_size()
def _get_legacy_storage_class(self):
@ -837,13 +828,13 @@ class _LegacyStorage(_TypedStorage, metaclass=_LegacyStorageMeta):
return cls(wrap_storage=untyped_storage)
@classmethod
def _release_ipc_counter(cls, *args, **kwargs):
return eval(cls.__module__)._UntypedStorage._release_ipc_counter(*args, **kwargs)
def _release_ipc_counter_cuda(cls, *args, **kwargs):
return eval(cls.__module__)._UntypedStorage._release_ipc_counter_cuda(*args, **kwargs)
@classmethod
def _new_shared_filename(cls, manager, obj, size):
def _new_shared_filename_cpu(cls, manager, obj, size):
bytes_size = size * torch._utils._element_size(cls.dtype)
return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename(manager, obj, bytes_size))
return cls(wrap_storage=eval(cls.__module__)._UntypedStorage._new_shared_filename_cpu(manager, obj, bytes_size))
def _get_dtype_from_pickle_storage_type(pickle_storage_type: str):
try:

View File

@ -136,7 +136,7 @@ def default_collate(batch):
# If we're in a background process, concatenate directly into a
# shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem.storage()._new_shared(numel)
storage = elem.storage()._new_shared(numel, device=elem.device)
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
return torch.stack(batch, 0, out=out)
elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \