diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp
index c19925acb8a3..ef8f8deb4973 100644
--- a/aten/src/ATen/quantized/Quantizer.cpp
+++ b/aten/src/ATen/quantized/Quantizer.cpp
@@ -146,12 +146,13 @@ inline Tensor new_qtensor(
   auto scalar_type = typeMetaToScalarType(dtype);
   int64_t size_bytes =
       get_sub_byte_tensor_size(sizes, dtype.itemsize(), scalar_type);
-  auto storage = c10::make_intrusive<StorageImpl>(
+  auto storage = make_storage_impl(
       StorageImpl::use_byte_size_t(),
       size_bytes,
       allocator->allocate(size_bytes),
       allocator,
-      /*resizable=*/true);
+      /*resizable=*/true,
+      device);
   auto tensor = detail::make_tensor<QTensorImpl>(
       storage, at::DispatchKeySet(tensorDispatchKey), dtype, quantizer);
   get_qtensorimpl(tensor)->set_sizes_contiguous(sizes);
diff --git a/c10/core/StorageImpl.cpp b/c10/core/StorageImpl.cpp
index eb7312527f24..8196acdd9457 100644
--- a/c10/core/StorageImpl.cpp
+++ b/c10/core/StorageImpl.cpp
@@ -36,4 +36,43 @@ StorageImplCreateHelper GetStorageImplCreate(DeviceType t) {
   return StorageImplCreate[device_type];
 }
 
+c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
+    c10::StorageImpl::use_byte_size_t use_byte_size,
+    c10::SymInt size_bytes,
+    c10::DataPtr data_ptr,
+    c10::Allocator* allocator,
+    bool resizable,
+    c10::optional<c10::Device> device_opt) {
+  // This will be non-nullptr only when there is a custom StorageImpl
+  // constructor for the given device
+  c10::StorageImplCreateHelper fptr = nullptr;
+  if (device_opt.has_value()) {
+    // We only need to check this here as this is the only case where we can
+    // have a device that is not CPU (and thus for which the StorageImpl
+    // constructor can be overwritten).
+    fptr = c10::GetStorageImplCreate(device_opt.value().type());
+  }
+
+  if (fptr != nullptr) {
+    return fptr(
+        use_byte_size,
+        std::move(size_bytes),
+        std::move(data_ptr),
+        allocator,
+        resizable);
+  }
+
+  // Create a c10::StorageImpl object.
+  if (data_ptr != nullptr) {
+    return c10::make_intrusive<c10::StorageImpl>(
+        use_byte_size,
+        std::move(size_bytes),
+        std::move(data_ptr),
+        allocator,
+        resizable);
+  }
+  return c10::make_intrusive<c10::StorageImpl>(
+      use_byte_size, std::move(size_bytes), allocator, resizable);
+}
+
 } // namespace c10
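Note (illustration only, not part of the patch): with make_storage_impl now living in c10 and taking a DataPtr plus an optional device, any C++ call site that wants a device-appropriate StorageImpl can route through it instead of constructing c10::StorageImpl directly, which is exactly what new_qtensor does above. A minimal usage sketch; the wrapper name allocate_storage is hypothetical:

#include <c10/core/StorageImpl.h>

// Allocate a resizable storage for `device`. make_storage_impl first consults
// GetStorageImplCreate(device.type()); if a backend registered a custom
// creator it is used, otherwise a plain c10::StorageImpl is constructed.
// `allocate_storage` is an illustrative helper, not part of the patch.
c10::intrusive_ptr<c10::StorageImpl> allocate_storage(
    int64_t size_bytes,
    c10::Allocator* allocator,
    c10::Device device) {
  return c10::make_storage_impl(
      c10::StorageImpl::use_byte_size_t(),
      size_bytes,
      allocator->allocate(size_bytes),
      allocator,
      /*resizable=*/true,
      device);
}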
diff --git a/c10/core/StorageImpl.h b/c10/core/StorageImpl.h
index ba1cef6f0e19..4b7d837ae8b4 100644
--- a/c10/core/StorageImpl.h
+++ b/c10/core/StorageImpl.h
@@ -257,6 +257,7 @@ struct C10_API StorageImpl : public c10::intrusive_ptr_target {
 using StorageImplCreateHelper = intrusive_ptr<StorageImpl> (*)(
     StorageImpl::use_byte_size_t,
     SymInt size_bytes,
+    DataPtr data_ptr,
     Allocator* allocator,
     bool resizable);
 
@@ -264,4 +265,12 @@ C10_API void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr);
 
 C10_API StorageImplCreateHelper GetStorageImplCreate(DeviceType t);
 
+C10_API c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
+    c10::StorageImpl::use_byte_size_t use_byte_size,
+    c10::SymInt size_bytes,
+    c10::DataPtr data_ptr,
+    c10::Allocator* allocator,
+    bool resizable,
+    c10::optional<c10::Device> device_opt);
+
 } // namespace c10
diff --git a/test/cpp_extensions/open_registration_extension.cpp b/test/cpp_extensions/open_registration_extension.cpp
index 5f2eac14aeb8..26f6a1c4fde6 100644
--- a/test/cpp_extensions/open_registration_extension.cpp
+++ b/test/cpp_extensions/open_registration_extension.cpp
@@ -137,10 +137,17 @@ void custom_set_backend_meta(const at::Tensor& t) {
 // A dummy storageImpl for our custom device, that secretly uses the CPU
 c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,
                                                               c10::SymInt size_bytes,
+                                                              c10::DataPtr data_ptr,
                                                               c10::Allocator* allocator,
                                                               bool resizable) {
-  c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
+  c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
+  if (data_ptr == nullptr){
+    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
       c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
+  } else {
+    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
+      c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
+  }
   storageImpl_counter += 1;
   return custom_storage_impl;
 }
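Note (illustration only, not part of the patch): a PrivateUse1 backend opts into this path by registering a creator that matches the widened StorageImplCreateHelper signature, as the test extension above does. A minimal sketch using SetStorageImplCreate from the header; the names my_make_storage_impl and register_my_storage_impl are hypothetical, and the branch on data_ptr mirrors make_custom_storage_impl:

#include <c10/core/StorageImpl.h>

// Creator matching the new StorageImplCreateHelper signature (now takes a DataPtr).
c10::intrusive_ptr<c10::StorageImpl> my_make_storage_impl(
    c10::StorageImpl::use_byte_size_t use_byte_size,
    c10::SymInt size_bytes,
    c10::DataPtr data_ptr,
    c10::Allocator* allocator,
    bool resizable) {
  // A non-null data_ptr means the caller already owns the memory (e.g. the
  // storage slicing path); otherwise allocation is left to the allocator,
  // exactly as before this change.
  if (data_ptr != nullptr) {
    return c10::make_intrusive<c10::StorageImpl>(
        use_byte_size, std::move(size_bytes), std::move(data_ptr), allocator, resizable);
  }
  return c10::make_intrusive<c10::StorageImpl>(
      use_byte_size, std::move(size_bytes), allocator, resizable);
}

// Route every StorageImpl created for the PrivateUse1 device through the creator above.
// (A real backend would typically return its own StorageImpl subclass here.)
void register_my_storage_impl() {
  c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &my_make_storage_impl);
}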
diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py
index 3a99423d7094..57138ad2b069 100644
--- a/test/test_cpp_extensions_open_device_registration.py
+++ b/test/test_cpp_extensions_open_device_registration.py
@@ -263,6 +263,9 @@ class TestCppExtensionOpenRgistration(common.TestCase):
             self.assertFalse(self.module.custom_storageImpl_called())
             z3 = z3.foo()
             self.assertTrue(self.module.custom_storageImpl_called())
+            self.assertFalse(self.module.custom_storageImpl_called())
+            z3 = z3[0:3]
+            self.assertTrue(self.module.custom_storageImpl_called())
 
         def test_open_device_storage_pin_memory():
             torch.utils.rename_privateuse1_backend('foo')
diff --git a/torch/csrc/Storage.cpp b/torch/csrc/Storage.cpp
index e19e0ae5a105..93dbc9c09bb2 100644
--- a/torch/csrc/Storage.cpp
+++ b/torch/csrc/Storage.cpp
@@ -290,72 +290,6 @@ static void THPStorage_subclass_dealloc(PyObject* self) {
   Py_DECREF(type);
 }
 
-c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
-    c10::StorageImpl::use_byte_size_t use_byte_size,
-    c10::SymInt size_bytes,
-    c10::Allocator* allocator,
-    bool resizable,
-    c10::optional<int64_t> allocator_opt,
-    c10::optional<at::Device> device_opt) {
-  at::OptionalDeviceGuard device_guard;
-  // This will be non-nullptr only when there is a custom StorageImpl
-  // constructor for the given device
-  c10::StorageImplCreateHelper fptr = nullptr;
-  // For directly passing allocator scenarios, only c10::StorageImpl objects can
-  // be created. If you need to create a storageimpl object of a subclass, you
-  // need to pass in the device information.
-  if (allocator_opt.has_value()) {
-    // NOLINTNEXTLINE(performance-no-int-to-ptr)
-    allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
-  } else if (device_opt.has_value()) {
-    at::Device device = device_opt.value();
-    // We only need to check this here as this is the only case where we can
-    // have a device that is not CPU (and thus for which the StorageImpl
-    // constructor can be overwritten).
-    fptr = c10::GetStorageImplCreate(device.type());
-    if (device.type() == at::kCPU) {
-      allocator = c10::GetDefaultCPUAllocator();
-#ifdef USE_CUDA
-    } else if (device.type() == at::kCUDA) {
-      at::globalContext().lazyInitCUDA();
-      allocator = c10::cuda::CUDACachingAllocator::get();
-#endif
-#ifdef USE_MPS
-    } else if (device.type() == at::kMPS) {
-      allocator = at::mps::GetMPSAllocator();
-#endif
-      // NOLINTBEGIN(bugprone-branch-clone)
-    } else if (device.type() == at::DeviceType::XPU) {
-      allocator = c10::GetAllocator(device.type());
-    } else if (device.type() == at::DeviceType::HPU) {
-      allocator = c10::GetAllocator(device.type());
-    } else if (device.type() == at::DeviceType::Meta) {
-      allocator = c10::GetAllocator(device.type());
-    } else if (device.type() == at::DeviceType::PrivateUse1) {
-      at::globalContext().lazyInitPrivateUse1();
-      allocator = c10::GetAllocator(device.type());
-    } else {
-      // NOLINTEND(bugprone-branch-clone)
-      TORCH_CHECK(
-          false,
-          THPStorageStr,
-          "(): Storage device not recognized: ",
-          device.type());
-    }
-    device_guard.reset_device(device);
-  } else {
-    allocator = c10::GetDefaultCPUAllocator();
-  }
-
-  if (fptr != nullptr) {
-    return fptr(use_byte_size, std::move(size_bytes), allocator, resizable);
-  }
-
-  // Create a c10::StorageImpl object.
-  return c10::make_intrusive<c10::StorageImpl>(
-      use_byte_size, std::move(size_bytes), allocator, resizable);
-}
-
 static PyObject* THPStorage_pynew(
     PyTypeObject* type,
     PyObject* args,
@@ -393,6 +327,46 @@ static PyObject* THPStorage_pynew(
   PyObject* self = nullptr;
   c10::Allocator* allocator = nullptr;
+  at::OptionalDeviceGuard device_guard;
+
+  if (allocator_opt.has_value()) {
+    // NOLINTNEXTLINE(performance-no-int-to-ptr)
+    allocator = reinterpret_cast<c10::Allocator*>(allocator_opt.value());
+  } else if (device_opt.has_value()) {
+    at::Device device = device_opt.value();
+    if (device.type() == at::kCPU) {
+      allocator = c10::GetDefaultCPUAllocator();
+#ifdef USE_CUDA
+    } else if (device.type() == at::kCUDA) {
+      at::globalContext().lazyInitCUDA();
+      allocator = c10::cuda::CUDACachingAllocator::get();
+#endif
+#ifdef USE_MPS
+    } else if (device.type() == at::kMPS) {
+      allocator = at::mps::GetMPSAllocator();
+#endif
+      // NOLINTBEGIN(bugprone-branch-clone)
+    } else if (device.type() == at::DeviceType::XPU) {
+      allocator = c10::GetAllocator(device.type());
+    } else if (device.type() == at::DeviceType::HPU) {
+      allocator = c10::GetAllocator(device.type());
+    } else if (device.type() == at::DeviceType::Meta) {
+      allocator = c10::GetAllocator(device.type());
+    } else if (device.type() == at::DeviceType::PrivateUse1) {
+      at::globalContext().lazyInitPrivateUse1();
+      allocator = c10::GetAllocator(device.type());
+    } else {
+      // NOLINTEND(bugprone-branch-clone)
+      TORCH_CHECK(
+          false,
+          THPStorageStr,
+          "(): Storage device not recognized: ",
+          device.type());
+    }
+    device_guard.reset_device(device);
+  } else {
+    allocator = c10::GetDefaultCPUAllocator();
+  }
 
   // torch.Storage(*, ...)
   if (r.idx == 0) {
@@ -401,9 +375,9 @@ static PyObject* THPStorage_pynew(
         make_storage_impl(
             c10::StorageImpl::use_byte_size_t(),
             0,
+            at::DataPtr(),
             allocator,
             /*resizable=*/true,
-            allocator_opt,
             device_opt),
         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
 
@@ -415,9 +389,9 @@ static PyObject* THPStorage_pynew(
         make_storage_impl(
             c10::StorageImpl::use_byte_size_t(),
             size,
+            at::DataPtr(),
             allocator,
             /*resizable=*/true,
-            allocator_opt,
             device_opt),
         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
 
@@ -440,9 +414,9 @@ static PyObject* THPStorage_pynew(
         make_storage_impl(
             c10::StorageImpl::use_byte_size_t(),
             length,
+            at::DataPtr(),
             allocator,
             /*resizable=*/true,
-            allocator_opt,
             device_opt),
         c10::impl::PyInterpreterStatus::DEFINITELY_UNINITIALIZED);
     THPObjectPtr item;
@@ -522,7 +496,8 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
     at::StorageImpl* old_storage_impl = storage.unsafeGetStorageImpl();
     c10::raw::intrusive_ptr::incref(old_storage_impl);
-    auto new_storage_impl = c10::make_intrusive<at::StorageImpl>(
+    c10::optional<at::Device> device_opt = old_storage_impl->device();
+    auto new_storage_impl = make_storage_impl(
         c10::StorageImpl::use_byte_size_t(),
 #ifdef THQUANTIZED
         slicelength * sizeof(quantized_t),
 #else
@@ -537,7 +512,8 @@ static PyObject* THPStorage_get(THPStorage* self, PyObject* index) {
             },
             old_storage_impl->device()),
         old_storage_impl->allocator(),
-        /* resizable */ false);
+        /* resizable */ false,
+        device_opt);
 
     PyObject* _ret = THPStorage_NewWithStorage(
         Py_TYPE(self),
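Note (illustration only, not part of the patch): the THPStorage_get change above is what the new z3[0:3] test exercises. A fresh torch.Storage passes an empty at::DataPtr() and lets the allocator produce memory, while a slice hands the already-existing buffer to make_storage_impl as a DataPtr together with the parent's device, so a registered backend creator still runs. A minimal sketch of the second pattern; wrap_existing_memory is a hypothetical helper and uses a non-owning DataPtr for brevity (the real slicing path attaches a deleter that keeps the parent StorageImpl alive):

#include <c10/core/StorageImpl.h>

// Build a storage around memory that already exists (cf. the slicing path in
// THPStorage_get): the buffer travels in the DataPtr, the storage is not
// resizable, and the device is forwarded so a backend-specific StorageImpl
// can still be produced for it.
c10::intrusive_ptr<c10::StorageImpl> wrap_existing_memory(
    void* data,
    c10::SymInt size_bytes,
    c10::Allocator* allocator,
    c10::Device device) {
  c10::DataPtr ptr(data, device); // non-owning; real callers attach a deleter
  return c10::make_storage_impl(
      c10::StorageImpl::use_byte_size_t(),
      std::move(size_bytes),
      std::move(ptr),
      allocator,
      /*resizable=*/false,
      device);
}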