Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-21 21:49:24 +08:00
Summary:
See Note [Supervisor deleter] for how SupervisedPtr works. This design is not the obvious one, but there were a lot of constraints feeding into it:

- It must support the reallocation usage pattern, where, given an existing Storage, we allocate a new region of memory, copy the existing data to it, and then deallocate the old region of memory.
- Creation of a deleter for memory MUST avoid dynamic allocations in the common case. We've done some benchmarking in Caffe2 where dynamic allocation for deleters is ruinously expensive, and it's really hard to avoid these performance tarpits in very general function wrappers like std::function or folly::Function (while benchmarking this, we discovered that folly::Function's move constructor was way more expensive than it should be).
- We need to be able to deallocate data that comes from external sources, e.g., dlpack and numpy tensors. Most notably, you often cannot deallocate these with merely the void* data pointer; you need some extra, out-of-band information (e.g., the managing struct) to deallocate it. Sometimes, you may even want to resize data living in an external source!
- The "core" allocators need to support being wrapped in a Thrust allocator, so you need to implement the following two functions (see the sketch after this list): char* allocate(size_t); void deallocate(char*, size_t);
- We need to support tensors which contain non-POD, non-trivially-copyable data; specifically, tensors of std::string. This is an upcoming requirement from Caffe2. It's dirty AF, but it's really useful.
- It should use C++ standard library types like std::unique_ptr (which is hugely problematic, because std::unique_ptr doesn't call the deleter when the pointer is null).
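The Thrust constraint deserves a concrete picture. Below is a minimal sketch of the shape such a wrapper takes; RawAllocator and ThrustWrapper are hypothetical names for illustration, not classes from this patch:

#include <cstddef>
#include <cstdlib>

// Hypothetical stand-in for an allocator that supports "raw" allocation,
// i.e., deallocation given nothing but the pointer itself.
struct RawAllocator {
  void* raw_allocate(std::size_t n) { return std::malloc(n); }
  void raw_deallocate(void* p) { std::free(p); }
};

// Exposes the two-function interface Thrust expects. deallocate() receives
// only a raw char*, which is why the underlying allocator cannot depend on
// any per-allocation context to free the memory.
struct ThrustWrapper {
  RawAllocator* alloc;
  char* allocate(std::size_t n) {
    return static_cast<char*>(alloc->raw_allocate(n));
  }
  void deallocate(char* p, std::size_t /*n*/) {
    alloc->raw_deallocate(p);
  }
};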
Here is the billing of changes:

- Built-in support for realloc() has been DROPPED ENTIRELY. Instead, you're expected to allocate and then copy from the old memory to the new memory if you want to do a reallocation (sketched after this list). This is what you'd generally have expected to occur, and axing realloc() from the design lets us avoid some tricky correctness issues with std::realloc(), namely the fact that we must refuse the realloc if the type of the elements is not trivially copyable. If it really matters, we can add this back, but there really needs to be a good explanation WHY you need fast resizing reallocations (by and large, people don't resize their storages, and it should be acceptable to have a performance degradation when they do).
- TH_STORAGE_FREEMEM is no more; instead, if you want a storage which doesn't free its data, you just give it an empty deleter.
- What we used to call an "allocator" (really, a combined object for allocating/deleting) has been split into two concepts: an allocator, and a smart pointer (SupervisedPtr) which knows how to delete its data.
- Unlike previously, where THAllocator/THCDeviceAllocator could have a per-tensor context storing extra information (e.g., a pointer to the metadata you need to actually free the tensor), there is no context in the allocator or the deleter of the smart pointer; instead, the smart pointer directly holds an owning reference to the metadata necessary to free the data. This metadata is *freshly manufactured* upon every allocation, which permits us to resize tensors even in the absence of built-in support for realloc().
- By default, allocators don't support "raw" allocations and deallocations with raw pointers. This is because some allocations may return a different context every time, in which case you would need to reconstruct the context at delete time (because all you got was a void*, not a unique_ptr that carries the deleter).
- The diff between at::Allocator and THCDeviceAllocator is a bit larger:
  - It used to return a cudaError_t. Now, allocators are expected to check the error status immediately and throw an exception if there was an error. It turns out that this is what was immediately done after all occurrences of allocate/release anyway, so it wasn't a big deal (although some subsidiary interfaces had to themselves be converted to not return cudaError_t). There is one notable exception to this, and it is how we handle CUDA OOM: if this occurs, we attempt to return unused memory to the system and try again. This is now handled by a catch-all try-catch block. The cost of catching the exception is probably the least of your worries if you're about to OOM.
  - It used to take the CUDA stream to perform the allocation on as an argument. However, it turned out that at all call sites, this stream was the stream for the current device. So we can push this into the allocator (and the choice, in the future, could be made explicitly by twiddling thread-local state).
  - It held two extra methods, emptyCache and cacheInfo, specifically for interacting with some state in THCCachingAllocator. But this "generality" was a lie, since THCCachingAllocator was the only allocator that actually implemented these methods, and there is actually a bunch of code in THC which assumes that it is the caching allocator that is the underlying allocator for CUDA allocations. So I folded these two methods into this interface as THCCachingAllocator_emptyCache and THCCachingAllocator_cacheInfo.
  - It held its context directly inside the THCDeviceAllocator struct. This context has been moved out into whatever is holding the at::Allocator*.
- The APIs for getting at allocators/deleters are now a little different.
  - Previously, there were a bunch of static variables you could get the address of (e.g., &THDefaultAllocator); now there is a function getTHDefaultAllocator().
  - Some "allocators" didn't actually know how to allocate (e.g., the IPC "allocator"). These have been deleted; instead, you can wrap the produced pointers into a SupervisedPtr using an appropriate makeSupervisedPtr() static method.
- Storage sharing was a lot of work to wrangle, but I think I've tamed the beast.
  - THMapAllocator and its "subclasses" have been refactored to be proper, honest-to-goodness C++ classes. I used the enum-argument trick to get "named" constructors. We use inheritance to add refcounting and management (in libshm). What we previously called the "Context" class (Context has been dropped from the name) is now the supervisor for the data.
  - Sometimes, we need to pull out the file descriptor from a tensor. Previously, it was pulled out of the allocator context. Now, we pull it out of the supervisor of the SupervisedPtr, using the static method fromSupervisedPtr(), which uses the deleter as the typeid, and refines the type if it matches.
  - I renamed the std::function deleter to InefficientStdFunctionSupervisor, to emphasize the fact that it does a dynamic allocation to store the std::function deleter.
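As a concrete illustration of the first item above, here is roughly what the realloc() replacement looks like when the pointer carries its own deleter. This is a simplified sketch with stand-in types; DataPtr here is just a std::unique_ptr alias, not the actual SupervisedPtr:

#include <algorithm>
#include <cstdlib>
#include <cstring>
#include <memory>

// Simplified stand-in for SupervisedPtr: a pointer that carries its own deleter.
using DataPtr = std::unique_ptr<char[], void (*)(char*)>;

DataPtr allocate(std::size_t n) {
  return DataPtr(static_cast<char*>(std::malloc(n)),
                 [](char* p) { std::free(p); });
}

// The realloc() replacement: allocate fresh memory, copy, and let the old
// pointer's deleter release the previous region. Because the deleter travels
// with the pointer, this works even for memory from external sources.
DataPtr reallocate(DataPtr old, std::size_t old_bytes, std::size_t new_bytes) {
  DataPtr fresh = allocate(new_bytes);
  std::memcpy(fresh.get(), old.get(), std::min(old_bytes, new_bytes));
  return fresh;  // `old` is destroyed here, invoking its deleter
}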
TODO:

- Windows libshm is in shambles and needs to be fixed.

Perhaps for the future:

- newFromFd is now unconditionally calling cudaPointerGetAttributes even though this is unnecessary, because we know what the device is from higher up in the call stack. We can fix this by making newWithDataAndAllocator also take an explicit device argument.
- Consider statically distinguishing between allocators that support raw_allocate/raw_deallocate and those which don't. The Thrust constraint applies only to the CUDA device allocator; you never need to allocate CPU memory this way.
- Really want to get rid of storage views. Ugh.

Nontrivial bugs I noticed when preparing this patch:

- I forgot to placement-new unique pointers and attempted to assign them directly on uninitialized memory; very bad! (See the sketch below.) Sam Gross has encouraged me to replace this with a proper constructor, but I keep putting it off, because once everything goes in StorageImpl there really will be a proper constructor.
- I rewrote a number of APIs to use newWithDataAndAllocator instead of newWithAllocator, calling the allocator at the call site (because they required an "allocation context" which we no longer give to "allocators"). When I did this, I forgot to insert the multiplication by sizeof(real) to scale from number of elements to number of bytes.
- The implementation of swap on storages neglected to swap scalarType and backend. It was benign (because the only case where we call swap is when these are the same), but I fixed it anyway.
- I accidentally returned a nullptr unique_ptr with no deleter, even though there was a legitimate one. This matters, because some code still shoves its hands into the deleter context to get extra metadata about the function.
- I used std::move() on a unique_ptr, and then did a boolean test on the pointer afterwards (always false!). (See the sketch below.)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/9358
Reviewed By: SsnL
Differential Revision: D8811822
Pulled By: ezyang
fbshipit-source-id: 4befe2d12c3e7fd62bad819ff52b054a9bf47c75
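Two of the bugs above are easy to reproduce in miniature. A reduced sketch (not the patch's actual code) of the placement-new mistake and the moved-from boolean test:

#include <cstdlib>
#include <memory>
#include <new>

void placement_new_pitfall() {
  // Raw, uninitialized storage for a unique_ptr; no constructor has run yet.
  void* raw = std::malloc(sizeof(std::unique_ptr<int>));
  auto* slot = static_cast<std::unique_ptr<int>*>(raw);

  // WRONG: assignment runs operator=, which reads the uninitialized
  // left-hand side (to free its old pointee). Undefined behavior.
  //   *slot = std::make_unique<int>(42);

  // RIGHT: placement-new constructs the object in the raw storage first.
  new (slot) std::unique_ptr<int>(std::make_unique<int>(42));

  slot->~unique_ptr();  // destroy manually before releasing the storage
  std::free(raw);
}

void moved_from_pitfall() {
  std::unique_ptr<int> p(new int(1));
  std::unique_ptr<int> q = std::move(p);
  if (p) {
    // Never reached: p is guaranteed null after the move.
  }
}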
49 lines · 1.1 KiB · C++
#define __STDC_FORMAT_MACROS

#include "torch/csrc/python_headers.h"
#ifdef _MSC_VER
#include <Windows.h>
#endif
#include <structmember.h>

#define THP_HOST_HALF

#include <stdbool.h>
#include <TH/TH.h>
// See Note [TH abstraction violation]
// - Used to get at the allocator associated with a storage
#include <TH/THStorage.hpp>
#include <torch/csrc/finalizer.h>
#include <libshm.h>
#include "THP.h"
#include "copy_utils.h"
#include "DynamicTypes.h"

#ifdef USE_CUDA
#include <THC/THCStorage.hpp>
#endif

#include "generic/Storage.cpp"
#include <TH/THGenerateAllTypes.h>

#include "generic/Storage.cpp"
#include <TH/THGenerateHalfType.h>

// NB: If you ever divest libtorch of USE_CUDA, you'll have to virtualize
// the CUDA call.
template<>
void THPPointer<THStorage>::free() {
  if (ptr) {
    if (ptr->data_ptr.device().is_cpu()) {
      THStorage_free(ptr);
    } else {
      AT_ASSERT(ptr->data_ptr.device().is_cuda());
#ifdef USE_CUDA
      THCStorage_free(at::globalContext().lazyInitCUDA(), ptr);
#else
      AT_ERROR("Cannot free THCStorage when not built with CUDA");
#endif
    }
  }
}