Previously when we recorded a free action in a memory trace, we would provide
the stack for when the block was allocated. This is faster because we do not
have to record stacks for free, which would otherwise double the number of stacks
collected. However, sometimes knowing the location of a free is useful for
figuring out why a tensor was live. So this PR adds this behavior. If
performance ends up being a concern the old behavior is possible by passing
"alloc" to the context argument rather than "all".
Also refactors some of glue logic to be consistent across C++ and Python and
routes the Python API through the C++ version.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106758
Approved by: https://github.com/albanD
cudaGetLastError and hipGetLastError will clear any error value within CUDA and HIP, respectively. This is often done on purpose to clear benign errors. Discarding the return value should be indicated by casting to void and a nearby comment. This silences warnings from HIP:
warning: ignoring return value of function declared with 'nodiscard' attribute [-Wunused-result]
Performing an audit of pytorch sources found one use of cudaGetLastError that was incorrectly ignored in IndexKernel.cu.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100488
Approved by: https://github.com/ezyang
Previously the allocator would query whether a stream was recording a graph,
and look up the pool associated with a graph. This change has the allocator
directly associate a stream with a mempool, decoupling "record this stream to a pool"
from the action of "record all actions to a cuda graph".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96542
Approved by: https://github.com/eellison
Copying note from cuda caching allocator:
```
* Note [Checkpointing PrivatePoolState]
*
* Refer above to Note [Interaction with CUDA graph capture]. Allocations made
* during graph capture are made from a separate private pool. During graph
* capture allocations behave as usual. During graph replay the allocator
* state does not change even as new tensors are created. The private pool
* will not free its blocks to the main caching allocator until cuda graph use
* is finished to prevent an allocation from eager clobbering the memory from
* a live but unaccounted for tensor that was created during replay.
*
* `make_graphed_callables`, a series of separate callables chained in
* successive cuda graphs, can share a memory pool because after a cuda graph
* recording the allocations in the shared private pool exactly reflect the
* tensors that are allocated.
*
* We would like to extend callable chaining to support a graphed callable
* tree. In this scenario, we have a tree of callable chains which will be
* captured with cuda graphs. In the diagram below, we have a tree with four
* callables, A, B, C, and D. Suppose we have captured, and subsequently
* replayed, A, B, and C. Then on a new invocation, we replay A and B, but
* would now like to record D. At this point the private pool will not reflect
* any of the live tensors created during graph replay. Allocations made
* during a new recording with the pool could overwrite those live tensors.
*
* In order to record a new graph capture after replaying prior callables in
* the tree, we need the allocator to reflect the state of the live tensors.
* We checkpoint the state of the private after each recording, and then
* reapply it when we are starting a new recording chain. Additionally, we
* must free the allocations for any tensors that died between the end of our
* previous graph replaying and our new recording (TODO). All of the allocated
* segments that existed in the checkpointed state must still exist in the
* pool. There may also exist new segments, which we will free (TODO : link
* note [live tensors between iterations] when it exists).
*
*
* ---------------> A ---------------> B ---------------> C
* |
* |
* |
* |
* ---------------> D
```
A few TODOs:
- need to add logic for freeing tensors that have died between a last replay and current new recording
- Add logic for free that might be called on a pointer multiple times (because we are manually freeing live tensors)
The two scenarios above have not been exercised in the tests yet.
Differential Revision: [D43999889](https://our.internmc.facebook.com/intern/diff/D43999889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94653
Approved by: https://github.com/zdevito
Fixes#43144
This uses the Backend system added by [82682](https://github.com/pytorch/pytorch/pull/82682) to change allocators dynamically during the code execution. This will allow us to use RMM, use CUDA managed memory for some portions of the code that do not fit in GPU memory. Write static memory allocators to reduce fragmentation while training models and improve interoperability with external DL compilers/libraries.
For example, we could have the following allocator in c++
```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>
extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
void *ptr;
std::cout<<"alloc "<< size<<std::endl;
cudaMalloc(&ptr, size);
return ptr;
}
void my_free(void* ptr) {
std::cout<<"free "<<std::endl;
cudaFree(ptr);
}
}
```
Compile it as a shared library
```
nvcc allocator.cc -o alloc.so -shared --compiler-options '-fPIC'
```
And use it from PyTorch as follows
```python
import torch
# Init caching
# b = torch.zeros(10, device='cuda')
new_alloc = torch.cuda.memory.CUDAPluggableAllocator('alloc.so', 'my_malloc', 'my_free')
old = torch.cuda.memory.get_current_allocator()
torch.cuda.memory.change_current_allocator(new_alloc)
b = torch.zeros(10, device='cuda')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(old)
```
Things to discuss
- How to test this, needs compiling external code ...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86786
Approved by: https://github.com/albanD
This replaces the manual function pointers, making it easier to write
new drop-in allocators.
Note that most allocation goes through the Allocator interface, which
CUDAAllocator inherits from, and this arrangement avoids adding and
additional layer of dispatch along this pathway compared to what existed before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87251
Approved by: https://github.com/wconstab