mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Part of #91395 Pull Request resolved: https://github.com/pytorch/pytorch/pull/96801 Approved by: https://github.com/ezyang
462 lines
14 KiB
C++
462 lines
14 KiB
C++
#include <torch/csrc/python_headers.h>
|
|
#ifdef _MSC_VER
|
|
#include <c10/util/win32-headers.h>
|
|
#endif
|
|
#include <structmember.h>
|
|
|
|
#include <ATen/mps/MPSDevice.h>
|
|
#include <c10/core/CPUAllocator.h>
|
|
#include <libshm.h>
|
|
#include <torch/csrc/CudaIPCTypes.h>
|
|
#include <torch/csrc/Device.h>
|
|
#include <torch/csrc/DynamicTypes.h>
|
|
#include <torch/csrc/StorageMethods.h>
|
|
#include <torch/csrc/StorageSharing.h>
|
|
#include <torch/csrc/THP.h>
|
|
#include <torch/csrc/autograd/utils/wrap_outputs.h>
|
|
#include <torch/csrc/copy_utils.h>
|
|
#include <torch/csrc/utils/python_arg_parser.h>
|
|
|
|
#include <c10/util/intrusive_ptr.h>
|
|
#include <fmt/format.h>
|
|
|
|
template <>
|
|
void THPPointer<c10::StorageImpl>::free() {
|
|
if (ptr) {
|
|
c10::raw::intrusive_ptr::decref(ptr);
|
|
}
|
|
}
|
|
|
|
// Python class used by THPStorage_New to allocate storage wrapper objects.
// Starts out null and is filled in by THPStorage_postInit.
PyObject* THPStorageClass = nullptr;
|
|
|
|
// Allocates an instance of THPStorageClass and moves `storage` into its
// `cdata` slot as an owned value.
//
// Returns the allocated object, or nullptr when tp_alloc fails (in which case
// `storage` is simply dropped).
PyObject* THPStorage_New(c10::Storage storage) {
  auto* storage_type = reinterpret_cast<PyTypeObject*>(THPStorageClass);
  PyObject* wrapper = storage_type->tp_alloc(storage_type, 0);
  if (wrapper == nullptr) {
    return nullptr;
  }
  auto* py_storage = reinterpret_cast<THPStorage*>(wrapper);
  py_storage->cdata =
      c10::MaybeOwned<c10::Storage>::owned(std::move(storage));
  return wrapper;
}
|
|
|
|
// tp_dealloc installed on every class created through the storage metaclass
// (see THPStorageMetaType_init). Destroys the C++ payload in place and then
// hands the memory back through the type's tp_free.
static void THPStorage_subclass_dealloc(PyObject* self) {
  PyTypeObject* self_type = Py_TYPE(self);
  // The base class is not GC-tracked, but some subclasses of StorageBase
  // are, and those have to be untracked before teardown.
  if (PyType_HasFeature(self_type, Py_TPFLAGS_HAVE_GC) != 0) {
    PyObject_GC_UnTrack(self);
  }
  auto* storage_obj = reinterpret_cast<THPStorage*>(self);
  // Explicit destructor call: cdata lives inside memory owned by Python.
  storage_obj->cdata.~MaybeOwned<c10::Storage>();
  self_type->tp_free(self);
}
|
|
|
|
// tp_new for subclasses of StorageBase. Accepted call forms (see the parser
// signatures below):
//   (*, allocator=None, device=None)            -> empty storage
//   (size, *, allocator=None, device=None)      -> storage of `size` bytes
//   (sequence, *, allocator=None, device=None)  -> storage of len(sequence)
//                                                  bytes filled from its items
//
// `allocator` is an integer that is reinterpreted as a c10::Allocator*;
// `device` selects an allocator by device type. At most one of the two may be
// given; with neither, the default CPU allocator is used.
//
// Returns a new object of `type`, or nullptr with a Python error set.
// Failing TORCH_CHECKs are translated by the HANDLE_TH_ERRORS wrapper.
static PyObject* THPStorage_pynew(
    PyTypeObject* type,
    PyObject* args,
    PyObject* kwargs) {
  HANDLE_TH_ERRORS
  TORCH_CHECK(
      type != &THPStorageType,
      "Cannot directly construct StorageBase; subclass it and then construct that");
  static torch::PythonArgParser parser({
      THPStorageStr "(*, int64_t allocator=None, Device device=None)",
      THPStorageStr
      "(int64_t size, *, int64_t allocator=None, Device device=None)",
      THPStorageStr
      "(PyObject* sequence, *, int64_t allocator=None, Device device=None)",
  });
  torch::ParsedArgs<3> parsed_args;
  auto r = parser.parse(args, kwargs, parsed_args);

  // Signature 0 has no positional argument, so its keyword arguments start at
  // index 0; in the other two signatures they follow the positional one.
  int64_t allocator_arg_idx = 0;
  int64_t device_arg_idx = 1;

  if (r.idx > 0) {
    allocator_arg_idx = 1;
    device_arg_idx = 2;
  }

  c10::optional<int64_t> allocator_opt = r.toInt64Optional(allocator_arg_idx);
  c10::optional<at::Device> device_opt = r.deviceOptional(device_arg_idx);

  TORCH_CHECK(
      !allocator_opt.has_value() || !device_opt.has_value(),
      THPStorageStr,
      "(): only one or neither of 'allocator' or 'device' can ",
      "be given, but not both");

  THPStoragePtr self((THPStorage*)type->tp_alloc(type, 0));
  THPUtils_assert(self, "failed to allocate a " THPStorageStr " object");
  c10::Allocator* allocator = nullptr;
  at::OptionalDeviceGuard device_guard;

  if (allocator_opt.has_value()) {
    // The caller passed a raw allocator address as an integer.
    allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
  } else if (device_opt.has_value()) {
    at::Device device = device_opt.value();
    if (device.type() == at::kCPU) {
      allocator = c10::GetDefaultCPUAllocator();
#ifdef USE_CUDA
    } else if (device.type() == at::kCUDA) {
      at::globalContext().lazyInitCUDA();
      allocator = c10::cuda::CUDACachingAllocator::get();
#endif
#ifdef USE_MPS
    } else if (device.type() == at::kMPS) {
      allocator = at::mps::GetMPSAllocator();
#endif
    } else if (device.type() == at::DeviceType::XPU) {
      allocator = c10::GetAllocator(device.type());
    } else if (device.type() == at::DeviceType::Meta) {
      allocator = c10::GetAllocator(device.type());
    } else {
      TORCH_CHECK(
          false,
          THPStorageStr,
          "(): Storage device not recognized: ",
          device.type());
    }
    // Point the guard at the requested device for the rest of construction.
    device_guard.reset_device(device);
  } else {
    allocator = c10::GetDefaultCPUAllocator();
  }

  // torch.Storage(*, ...)
  if (r.idx == 0) {
    self->cdata = c10::MaybeOwned<c10::Storage>::owned(
        c10::make_intrusive<c10::StorageImpl>(
            c10::StorageImpl::use_byte_size_t(),
            0,
            allocator,
            /*resizable=*/true));
    return (PyObject*)self.release();

    // torch.Storage(size, *, ...)
  } else if (r.idx == 1) {
    int64_t size = r.toInt64(0);
    self->cdata = c10::MaybeOwned<c10::Storage>::owned(
        c10::make_intrusive<c10::StorageImpl>(
            c10::StorageImpl::use_byte_size_t(),
            size,
            allocator,
            /*resizable=*/true));
    return (PyObject*)self.release();

    // torch.Storage(sequence, *, ...)
  } else if (r.idx == 2) {
    PyObject* sequence = r.pyobject(0);
    // Check the type before asking for the length: PySequence_Length on a
    // non-sequence fails and would leave a Python exception set behind.
    TORCH_CHECK(
        PySequence_Check(sequence),
        THPStorageStr,
        "(): Expected a sequence type, but got ",
        THPUtils_typename(sequence));
    Py_ssize_t length = PySequence_Length(sequence);
    TORCH_CHECK(
        length >= 0,
        THPStorageStr,
        "(): Could not obtain the length of sequence of type ",
        THPUtils_typename(sequence));
    self->cdata = c10::MaybeOwned<c10::Storage>::owned(
        c10::make_intrusive<c10::StorageImpl>(
            c10::StorageImpl::use_byte_size_t(),
            length,
            allocator,
            /*resizable=*/true));
    THPObjectPtr item;
    try {
      for (Py_ssize_t i = 0; i < length; i++) {
        item = PySequence_GetItem(sequence, i);
        if (item.get() == nullptr) {
          // PySequence_GetItem failed and has already set a Python error;
          // propagate it instead of unpacking a null item.
          return nullptr;
        }
        // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
        uint8_t value = THPByteUtils_unpackReal(item.get());
        const auto& storage = THPStorage_Unpack(self);
        if (allocator == c10::GetDefaultCPUAllocator()) {
          // CPU memory can be written directly.
          storage.unsafe_data<uint8_t>()[i] = value;
        } else {
          // TODO: this might be slow - consider batched updates?
          storage_set(storage, i, value);
        }
      }
    } catch (const std::exception&) {
      THPUtils_setError(
          THPStorageStr
          "(): tried to construct a storage from a sequence (%s), "
          "but one of the items was of type %s instead of int",
          THPUtils_typename(sequence),
          THPUtils_typename(item.get()));
      return nullptr;
    }
    return (PyObject*)self.release();
  }
  Py_RETURN_NONE;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// mp_length slot: the length of a storage is its size in bytes.
// Returns -1 if an exception is translated by the error-handling wrapper.
static Py_ssize_t THPStorage_length(THPStorage* self) {
  HANDLE_TH_ERRORS
  const auto& storage = THPStorage_Unpack(self);
  return storage.nbytes();
  END_HANDLE_TH_ERRORS_RET(-1)
}
|
|
|
|
// mp_subscript slot.
//  - Integer index: returns the byte at that position. A negative index is
//    offset by nbytes(); an out-of-range index raises IndexError.
//  - Slice with step 1: returns a new, non-resizable storage object that
//    points into this storage's memory and keeps this storage alive.
//  - Anything else: raises TypeError.
// Returns nullptr with a Python error set on failure.
static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
  HANDLE_TH_ERRORS
  const auto& storage = THPStorage_Unpack(self);

  /* Integer index */
  if (THPUtils_checkLong(index)) {
    int64_t position = THPUtils_unpackLong(index);
    if (position < 0) {
      position += storage.nbytes();
    }
    const bool in_range =
        position >= 0 && position < static_cast<int64_t>(storage.nbytes());
    if (!in_range) {
      PyErr_SetString(
          PyExc_IndexError,
          fmt::format(
              "index {} out of range for storage of size {}",
              position,
              storage.nbytes()));
      return nullptr;
    }
    uint8_t byte_value = storage_get(storage, position);
    return THPByteUtils_newReal(byte_value);
  }

  /* Slice index */
  if (PySlice_Check(index)) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t start, stop, slicelength, step;
    int64_t total_bytes = storage.nbytes();
    const int slice_status = PySlice_GetIndicesEx(
        index, total_bytes, &start, &stop, &step, &slicelength);
    if (slice_status != 0) {
      return nullptr;
    }
    if (step != 1) {
      THPUtils_setError(
          "Trying to slice with a step of %lld, but only a step of "
          "1 is supported",
          (long long)step);
      return nullptr;
    }

    uint8_t* base = storage.data<uint8_t>();

    // The slice shares memory with the parent. Take an extra reference on
    // the parent impl here; the DataPtr deleter below releases it when the
    // slice's data pointer is destroyed.
    at::StorageImpl* parent_impl = storage.unsafeGetStorageImpl();
    c10::raw::intrusive_ptr::incref(parent_impl);
    auto slice_impl = c10::make_intrusive<at::StorageImpl>(
        c10::StorageImpl::use_byte_size_t(),
#ifdef THQUANTIZED
        slicelength * sizeof(quantized_t),
#else
        slicelength,
#endif
        at::DataPtr(
            static_cast<void*>(base + start),
            parent_impl,
            [](void* s) {
              c10::raw::intrusive_ptr::decref(static_cast<at::StorageImpl*>(s));
            },
            parent_impl->device()),
        parent_impl->allocator(),
        /* resizable */ false);

    return THPStorage_New(std::move(slice_impl));
  }

  PyErr_Format(
      PyExc_TypeError,
      "can't index a " THPStorageStr " with %s",
      THPUtils_typename(index));
  return nullptr;
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// mp_ass_subscript slot: writes the byte `value` into the storage.
//  - Integer index: sets that single position through storage_set.
//  - Slice with step 1: sets every position in [start, stop) to `value`.
//  - Anything else: raises an error.
// `value` must pass THPByteUtils_checkReal.
//
// Returns 0 on success and -1 with a Python error set on failure.
static int THPStorage_set(THPStorage* self, PyObject* index, PyObject* value) {
  HANDLE_TH_ERRORS
  if (!THPByteUtils_checkReal(value)) {
    THPUtils_setError(
        "can only set storage content with a int types, but got "
        "%s instead",
        THPUtils_typename(value));
    return -1;
  }

  uint8_t rvalue = THPByteUtils_unpackReal(value);
  const auto& storage = THPStorage_Unpack(self);
  if (THPUtils_checkLong(index)) {
    int64_t nindex = THPUtils_unpackLong(index);
    storage_set(storage, nindex, rvalue);
    return 0;
  } else if (PySlice_Check(index)) {
    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
    Py_ssize_t start, stop, slicelength, step;
    int64_t len = storage.nbytes();
    if (PySlice_GetIndicesEx(index, len, &start, &stop, &step, &slicelength) !=
        0)
      return -1;
    if (step != 1) {
      THPUtils_setError(
          "Trying to slice with a step of %lld, but only a step of "
          "1 is supported",
          (long long)step);
      // An error has been set, so the slot must report failure. Returning 0
      // here would signal success while an exception is pending.
      return -1;
    }
    // TODO: check the bounds only once
    // TODO: fill?
    for (; start < stop; start++)
      storage_set(storage, start, rvalue);
    return 0;
  }
  THPUtils_setError(
      "can't index a " THPStorageStr " with %s", THPUtils_typename(index));
  return -1;
  END_HANDLE_TH_ERRORS_RET(-1)
}
|
|
|
|
// Mapping protocol table for StorageBase: len(), obj[index] and
// obj[index] = value. The handlers take THPStorage* rather than PyObject*,
// hence the function-pointer casts.
static PyMappingMethods THPStorage_mappingmethods = {
    reinterpret_cast<lenfunc>(THPStorage_length), /* mp_length */
    reinterpret_cast<binaryfunc>(THPStorage_get), /* mp_subscript */
    reinterpret_cast<objobjargproc>(THPStorage_set), /* mp_ass_subscript */
};
|
|
|
|
struct THPStorageMeta {
|
|
PyHeapTypeObject base;
|
|
};
|
|
|
|
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs);
|
|
|
|
// Type object for the metaclass "torch._C._StorageMeta". Only the name, the
// instance size, the flags and tp_init are populated statically; tp_base is
// assigned in THPStorage_init before PyType_Ready is called.
PyTypeObject THPStorageMetaType = {
    PyVarObject_HEAD_INIT(
        DEFERRED_ADDRESS(&PyType_Type),
        0) "torch._C._StorageMeta", /* tp_name */
    sizeof(THPStorageMeta), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    nullptr, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods */
    nullptr, /* tp_members */
    nullptr, /* tp_getset */
    DEFERRED_ADDRESS(&PyType_Type), /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    THPStorageMetaType_init, /* tp_init: installs the storage deallocator */
    nullptr, /* tp_alloc */
    nullptr, /* tp_new */
};
|
|
|
|
// TODO: implement equality
|
|
// Type object for "torch._C.StorageBase". Its metatype is THPStorageMetaType,
// it supports the mapping protocol, and it is constructed via
// THPStorage_pynew. tp_methods and tp_getset are assigned in THPStorage_init.
PyTypeObject THPStorageType = {
    PyVarObject_HEAD_INIT(
        &THPStorageMetaType,
        0) "torch._C.StorageBase", /* tp_name */
    sizeof(THPStorage), /* tp_basicsize */
    0, /* tp_itemsize */
    nullptr, /* tp_dealloc */
    0, /* tp_vectorcall_offset */
    nullptr, /* tp_getattr */
    nullptr, /* tp_setattr */
    nullptr, /* tp_reserved */
    nullptr, /* tp_repr */
    nullptr, /* tp_as_number */
    nullptr, /* tp_as_sequence */
    &THPStorage_mappingmethods, /* tp_as_mapping */
    nullptr, /* tp_hash */
    nullptr, /* tp_call */
    nullptr, /* tp_str */
    nullptr, /* tp_getattro */
    nullptr, /* tp_setattro */
    nullptr, /* tp_as_buffer */
    Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, /* tp_flags */
    nullptr, /* tp_doc */
    nullptr, /* tp_traverse */
    nullptr, /* tp_clear */
    nullptr, /* tp_richcompare */
    0, /* tp_weaklistoffset */
    nullptr, /* tp_iter */
    nullptr, /* tp_iternext */
    nullptr, /* tp_methods (will be assigned in init) */
    nullptr, /* tp_members (will be assigned in init) */
    nullptr, /* tp_getset */
    nullptr, /* tp_base */
    nullptr, /* tp_dict */
    nullptr, /* tp_descr_get */
    nullptr, /* tp_descr_set */
    0, /* tp_dictoffset */
    nullptr, /* tp_init */
    nullptr, /* tp_alloc */
    THPStorage_pynew, /* tp_new */
};
|
|
|
|
// tp_init of the storage metaclass. Runs the regular `type` initialisation
// and then points the freshly created class's tp_dealloc at
// THPStorage_subclass_dealloc.
// Returns 0 on success, -1 if the base initialisation failed.
int THPStorageMetaType_init(PyObject* cls, PyObject* args, PyObject* kwargs) {
  const int base_status = PyType_Type.tp_init(cls, args, kwargs);
  if (base_status < 0) {
    return -1;
  }
  auto* created_type = reinterpret_cast<PyTypeObject*>(cls);
  created_type->tp_dealloc =
      static_cast<destructor>(&THPStorage_subclass_dealloc);
  return 0;
}
|
|
|
|
// Getter for the `device` property: wraps the storage's device in a Python
// device object. The closure argument is unused.
static PyObject* THPStorage_device(THPStorage* self, void* unused) {
  HANDLE_TH_ERRORS
  const auto& storage = THPStorage_Unpack(self);
  return THPDevice_New(storage.device());
  END_HANDLE_TH_ERRORS
}
|
|
|
|
// Getter for the `_cdata` property: exposes the address of the underlying
// StorageImpl as a Python integer. The closure argument is unused.
PyObject* THPStorage_get_cdata(THPStorage* self, void* unused) {
  HANDLE_TH_ERRORS
  const auto& storage = THPStorage_Unpack(self);
  void* impl_address = storage.unsafeGetStorageImpl();
  return PyLong_FromVoidPtr(impl_address);
  END_HANDLE_TH_ERRORS
}
|
|
|
|
typedef PyObject* (*getter)(PyObject*, void*);
|
|
|
|
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays,cppcoreguidelines-avoid-non-const-global-variables)
|
|
static struct PyGetSetDef THPStorage_properties[] = {
|
|
{"device", (getter)THPStorage_device, nullptr, nullptr, nullptr},
|
|
{"_cdata", (getter)THPStorage_get_cdata, nullptr, nullptr, nullptr},
|
|
{nullptr}};
|
|
|
|
// Finalises the storage metaclass and the StorageBase type and registers
// both on `module` as "_StorageMeta" and "StorageBase".
//
// The method table is kept in a function-local static vector because
// tp_methods stores a pointer into it.
//
// Returns true on success; false if PyType_Ready or PyModule_AddObject fails
// (a Python error is then set by the failing call).
bool THPStorage_init(PyObject* module) {
  static std::vector<PyMethodDef> methods;
  THPUtils_addPyMethodDefs(methods, THPStorage_getMethods());
  THPUtils_addPyMethodDefs(methods, THPStorage_getSharingMethods());

  THPStorageMetaType.tp_base = &PyType_Type;
  if (PyType_Ready(&THPStorageMetaType) < 0) {
    return false;
  }
  Py_INCREF(&THPStorageMetaType);
  // PyModule_AddObject steals the reference only on success, so give it
  // back and report the failure otherwise.
  if (PyModule_AddObject(
          module, "_StorageMeta", (PyObject*)&THPStorageMetaType) < 0) {
    Py_DECREF(&THPStorageMetaType);
    return false;
  }

  THPStorageType.tp_methods = methods.data();
  THPStorageType.tp_getset = THPStorage_properties;
  if (PyType_Ready(&THPStorageType) < 0) {
    return false;
  }
  Py_INCREF(&THPStorageType);
  if (PyModule_AddObject(module, "StorageBase", (PyObject*)&THPStorageType) <
      0) {
    Py_DECREF(&THPStorageType);
    return false;
  }
  return true;
}
|
|
|
|
// Looks up the "UntypedStorage" attribute on `module` and caches it in
// THPStorageClass. Throws python_error when the lookup fails.
void THPStorage_postInit(PyObject* module) {
  PyObject* untyped_storage_cls =
      PyObject_GetAttrString(module, "UntypedStorage");
  THPStorageClass = untyped_storage_cls;
  if (untyped_storage_cls == nullptr) {
    throw python_error();
  }
}
|