Modify StorageImplCreateHelper (#118459)

Using ``tensor.untyped_storage()[a:b]`` with a ``PrivateUse1`` backend currently fails. The slicing call ends up in ``THPStorage_get``:
bb6eba189f/torch/csrc/Storage.cpp (L525-L540)

There, ``torch`` creates a new plain ``c10::StorageImpl`` and does not take the ``PrivateUse1`` backend into account.
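
For context (this block is not part of the commit, just a hedged sketch): the fix routes the slicing path through ``make_storage_impl``, which dispatches to a per-device ``StorageImplCreateHelper`` registered via ``c10::SetStorageImplCreate``. Below is a minimal sketch of how an out-of-tree ``PrivateUse1`` backend could plug into this, assuming a hypothetical ``MyStorageImpl`` subclass (all names are illustrative, not from this PR):

```cpp
// Sketch only: a hypothetical PrivateUse1 backend registering its own
// StorageImpl constructor so that make_storage_impl() builds the backend's
// subclass instead of a plain c10::StorageImpl.
#include <c10/core/StorageImpl.h>
#include <utility>

namespace my_backend {

// Hypothetical StorageImpl subclass; reuses the base-class constructors.
struct MyStorageImpl : public c10::StorageImpl {
  using c10::StorageImpl::StorageImpl;
};

// Matches the updated StorageImplCreateHelper signature, which now also
// receives the DataPtr so sliced views can be wrapped as well.
c10::intrusive_ptr<c10::StorageImpl> make_my_storage_impl(
    c10::StorageImpl::use_byte_size_t use_byte_size,
    c10::SymInt size_bytes,
    c10::DataPtr data_ptr,
    c10::Allocator* allocator,
    bool resizable) {
  if (data_ptr != nullptr) {
    // Slicing passes an existing DataPtr; keep it instead of allocating.
    return c10::make_intrusive<MyStorageImpl>(
        use_byte_size,
        std::move(size_bytes),
        std::move(data_ptr),
        allocator,
        resizable);
  }
  return c10::make_intrusive<MyStorageImpl>(
      use_byte_size, std::move(size_bytes), allocator, resizable);
}

void register_my_storage_impl_create() {
  // Registration hook (see the c10 header diff below).
  c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_my_storage_impl);
}

} // namespace my_backend
```

With such a helper registered, ``tensor.untyped_storage()[a:b]`` goes through ``make_storage_impl`` and ends up in the backend's constructor, which is what the new ``z3[0:3]`` assertion in the test extension below exercises.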

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118459
Approved by: https://github.com/albanD
Chen_Liqing
2024-03-07 06:26:51 +00:00
committed by PyTorch MergeBot
parent f848e9c646
commit 291ce86a6c
6 changed files with 109 additions and 74 deletions

View File

@@ -146,12 +146,13 @@ inline Tensor new_qtensor(
auto scalar_type = typeMetaToScalarType(dtype);
int64_t size_bytes = get_sub_byte_tensor_size(sizes, dtype.itemsize(), scalar_type);
auto storage = c10::make_intrusive<StorageImpl>(
auto storage = make_storage_impl(
StorageImpl::use_byte_size_t(),
size_bytes,
allocator->allocate(size_bytes),
allocator,
/*resizable=*/true);
/*resizable=*/true,
device);
auto tensor = detail::make_tensor<QTensorImpl>(
storage, at::DispatchKeySet(tensorDispatchKey), dtype, quantizer);
get_qtensorimpl(tensor)->set_sizes_contiguous(sizes);

View File

@@ -36,4 +36,43 @@ StorageImplCreateHelper GetStorageImplCreate(DeviceType t) {
return StorageImplCreate[device_type];
}
c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
c10::StorageImpl::use_byte_size_t use_byte_size,
c10::SymInt size_bytes,
c10::DataPtr data_ptr,
c10::Allocator* allocator,
bool resizable,
c10::optional<at::Device> device_opt) {
// This will be non-nullptr only when there is a custom StorageImpl
// constructor for the given device
c10::StorageImplCreateHelper fptr = nullptr;
if (device_opt.has_value()) {
// We only need to check this here as this is the only case where we can
// have a device that is not CPU (and thus for which the StorageImpl
// constructor can be overwritten).
fptr = c10::GetStorageImplCreate(device_opt.value().type());
}
if (fptr != nullptr) {
return fptr(
use_byte_size,
std::move(size_bytes),
std::move(data_ptr),
allocator,
resizable);
}
// Create a c10::StorageImpl object.
if (data_ptr != nullptr) {
return c10::make_intrusive<c10::StorageImpl>(
use_byte_size,
std::move(size_bytes),
std::move(data_ptr),
allocator,
resizable);
}
return c10::make_intrusive<c10::StorageImpl>(
use_byte_size, std::move(size_bytes), allocator, resizable);
}
} // namespace c10

View File

@@ -257,6 +257,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
using StorageImplCreateHelper = intrusive_ptr<StorageImpl> (*)(
StorageImpl::use_byte_size_t,
SymInt size_bytes,
DataPtr data_ptr,
Allocator* allocator,
bool resizable);
@@ -264,4 +265,12 @@ C10_API void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr);
C10_API StorageImplCreateHelper GetStorageImplCreate(DeviceType t);
C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
c10::StorageImpl::use_byte_size_t use_byte_size,
c10::SymInt size_bytes,
c10::DataPtr data_ptr,
c10::Allocator* allocator,
bool resizable,
c10::optional<at::Device> device_opt);
} // namespace c10

View File

@@ -137,10 +137,17 @@ void custom_set_backend_meta(const at::Tensor& t) {
// A dummy storageImpl for our custom device, that secretly uses the CPU
c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,
c10::SymInt size_bytes,
c10::DataPtr data_ptr,
c10::Allocator* allocator,
bool resizable) {
c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
if (data_ptr == nullptr){
custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
} else {
custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
}
storageImpl_counter += 1;
return custom_storage_impl;
}

View File

@@ -263,6 +263,9 @@ class TestCppExtensionOpenRgistration(common.TestCase):
self.assertFalse(self.module.custom_storageImpl_called())
z3 = z3.foo()
self.assertTrue(self.module.custom_storageImpl_called())
self.assertFalse(self.module.custom_storageImpl_called())
z3 = z3[0:3]
self.assertTrue(self.module.custom_storageImpl_called())
def test_open_device_storage_pin_memory():
torch.utils.rename_privateuse1_backend('foo')

View File

@@ -290,72 +290,6 @@ static void THPStorage_subclass_dealloc(PyObject* self) {
Py_DECREF(type);
}
c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
c10::StorageImpl::use_byte_size_t use_byte_size,
c10::SymInt size_bytes,
c10::Allocator* allocator,
bool resizable,
c10::optional<int64_t> allocator_opt,
c10::optional<at::Device> device_opt) {
at::OptionalDeviceGuard device_guard;
// This will be non-nullptr only when there is a custom StorageImpl
// constructor for the given device
c10::StorageImplCreateHelper fptr = nullptr;
// For directly passing allocator scenarios, only c10::StorageImpl objects can
// be created. If you need to create a storageimpl object of a subclass, you
// need to pass in the device information.
if (allocator_opt.has_value()) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
} else if (device_opt.has_value()) {
at::Device device = device_opt.value();
// We only need to check this here as this is the only case where we can
// have a device that is not CPU (and thus for which the StorageImpl
// constructor can be overwritten).
fptr = c10::GetStorageImplCreate(device.type());
if (device.type() == at::kCPU) {
allocator = c10::GetDefaultCPUAllocator();
#ifdef USE_CUDA
} else if (device.type() == at::kCUDA) {
at::globalContext().lazyInitCUDA();
allocator = c10::cuda::CUDACachingAllocator::get();
#endif
#ifdef USE_MPS
} else if (device.type() == at::kMPS) {
allocator = at::mps::GetMPSAllocator();
#endif
// NOLINTBEGIN(bugprone-branch-clone)
} else if (device.type() == at::DeviceType::XPU) {
allocator = c10::GetAllocator(device.type());
} else if (device.type() == at::DeviceType::HPU) {
allocator = c10::GetAllocator(device.type());
} else if (device.type() == at::DeviceType::Meta) {
allocator = c10::GetAllocator(device.type());
} else if (device.type() == at::DeviceType::PrivateUse1) {
at::globalContext().lazyInitPrivateUse1();
allocator = c10::GetAllocator(device.type());
} else {
// NOLINTEND(bugprone-branch-clone)
TORCH_CHECK(
false,
THPStorageStr,
"(): Storage device not recognized: ",
device.type());
}
device_guard.reset_device(device);
} else {
allocator = c10::GetDefaultCPUAllocator();
}
if (fptr != nullptr) {
return fptr(use_byte_size, std::move(size_bytes), allocator, resizable);
}
// Create a c10::StorageImpl object.
return c10::make_intrusive<c10::StorageImpl>(
use_byte_size, std::move(size_bytes), allocator, resizable);
}
static PyObject* THPStorage_pynew(
PyTypeObject* type,
PyObject* args,
@@ -393,6 +327,46 @@ static PyObject* THPStorage_pynew(
PyObject* self = nullptr;
c10::Allocator* allocator = nullptr;
at::OptionalDeviceGuard device_guard;
if (allocator_opt.has_value()) {
// NOLINTNEXTLINE(performance-no-int-to-ptr)
allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
} else if (device_opt.has_value()) {
at::Device device = device_opt.value();
if (device.type() == at::kCPU) {
allocator = c10::GetDefaultCPUAllocator();
#ifdef USE_CUDA
} else if (device.type() == at::kCUDA) {
at::globalContext().lazyInitCUDA();
allocator = c10::cuda::CUDACachingAllocator::get();
#endif
#ifdef USE_MPS
} else if (device.type() == at::kMPS) {
allocator = at::mps::GetMPSAllocator();
#endif
// NOLINTBEGIN(bugprone-branch-clone)
} else if (device.type() == at::DeviceType::XPU) {
allocator = c10::GetAllocator(device.type());
} else if (device.type() == at::DeviceType::HPU) {
allocator = c10::GetAllocator(device.type());
} else if (device.type() == at::DeviceType::Meta) {
allocator = c10::GetAllocator(device.type());
} else if (device.type() == at::DeviceType::PrivateUse1) {
at::globalContext().lazyInitPrivateUse1();
allocator = c10::GetAllocator(device.type());
} else {
// NOLINTEND(bugprone-branch-clone)
TORCH_CHECK(
false,
THPStorageStr,
"(): Storage device not recognized: ",
device.type());
}
device_guard.reset_device(device);
} else {
allocator = c10::GetDefaultCPUAllocator();
}
// torch.Storage(*, ...)
if (r.idx == 0) {
@@ -401,9 +375,9 @@ static PyObject* THPStorage_pynew(
make_storage_impl(
c10::StorageImpl::use_byte_size_t(),
0,
at::DataPtr(),
allocator,
/*resizable=*/true,
allocator_opt,
device_opt),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
@@ -415,9 +389,9 @@ static PyObject* THPStorage_pynew(
make_storage_impl(
c10::StorageImpl::use_byte_size_t(),
size,
at::DataPtr(),
allocator,
/*resizable=*/true,
allocator_opt,
device_opt),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
@@ -440,9 +414,9 @@ static PyObject* THPStorage_pynew(
make_storage_impl(
c10::StorageImpl::use_byte_size_t(),
length,
at::DataPtr(),
allocator,
/*resizable=*/true,
allocator_opt,
device_opt),
c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
THPObjectPtr item;
@@ -522,7 +496,8 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
at::StorageImpl* old_storage_impl = storage.unsafeGetStorageImpl();
c10::raw::intrusive_ptr::incref(old_storage_impl);
auto new_storage_impl = c10::make_intrusive<at::StorageImpl>(
c10::optional<at::Device> device_opt = old_storage_impl->device();
auto new_storage_impl = make_storage_impl(
c10::StorageImpl::use_byte_size_t(),
#ifdef THQUANTIZED
slicelength * sizeof(quantized_t),
@@ -537,7 +512,8 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
},
old_storage_impl->device()),
old_storage_impl->allocator(),
/* resizable */ false);
/* resizable */ false,
device_opt);
PyObject* _ret = THPStorage_NewWithStorage(
Py_TYPE(self),