Fixes #43144
This uses the Backend system added by [#82682](https://github.com/pytorch/pytorch/pull/82682) to change allocators dynamically during code execution. This will allow us to use RMM, use CUDA managed memory for portions of the code that do not fit in GPU memory, write static memory allocators to reduce fragmentation while training models, and improve interoperability with external DL compilers/libraries.
For example, we could have the following allocator in C++:
```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>

extern "C" {
// Allocation callback: receives the requested size, device index and stream.
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
  void* ptr;
  std::cout << "alloc " << size << std::endl;
  cudaMalloc(&ptr, size);
  return ptr;
}

// Free callback: releases a pointer previously returned by my_malloc.
void my_free(void* ptr) {
  std::cout << "free" << std::endl;
  cudaFree(ptr);
}
}
```
Compile it as a shared library
```
nvcc allocator.cc -o alloc.so -shared --compiler-options '-fPIC'
```
And use it from PyTorch as follows
```python
import torch
# Initializing the caching allocator here (e.g. by uncommenting the next line)
# would make it impossible to change the allocator afterwards.
# b = torch.zeros(10, device='cuda')
new_alloc = torch.cuda.memory.CUDAPluggableAllocator('alloc.so', 'my_malloc', 'my_free')
old = torch.cuda.memory.get_current_allocator()
torch.cuda.memory.change_current_allocator(new_alloc)
b = torch.zeros(10, device='cuda')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(old)
```
Things to discuss:
- How to test this; it needs compiling external code ...
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86786
Approved by: https://github.com/albanD
This replaces the manual function pointers, making it easier to write
new drop-in allocators.
Note that most allocation goes through the Allocator interface, which
CUDAAllocator inherits from, and this arrangement avoids adding an
additional layer of dispatch along this pathway compared to what existed before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87251
Approved by: https://github.com/wconstab
We currently can take snapshots of the state of the allocated CUDA memory, but we do not have a way to correlate these snapshots with the actions the allocator took between snapshots. This PR adds a simple fixed-sized buffer that records the major actions the allocator takes (ALLOC, FREE, SEGMENT_ALLOC, SEGMENT_FREE, OOM, SNAPSHOT) and includes them with the snapshot information. Capturing periodic snapshots with a big enough trace buffer makes it possible to see how the allocator state changes over time.
We plan to use this functionality to guide how settings in the allocator can be adjusted and eventually have a more robust overall algorithm.
As a component of this functionality, we also add the ability to get a callback when the allocator will throw an OOM, primarily so that snapshots can be taken immediately to see why the program ran out of memory (most programs have some C++ state that would free tensors before the OutOfMemory exception can be caught).
This PR also updates the _memory_viz.py script to pretty-print the trace information and provide a better textual summary of snapshots distinguishing between internal and external fragmentation.
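A minimal sketch of how periodic snapshots could be collected from Python while an allocation-heavy workload runs, using the public `torch.cuda.memory_snapshot()` helper (the trace entries attached alongside the snapshot come from this PR and their exact layout may differ by version); the same capture function could also be invoked from the new OOM callback:
```python
import torch

def capture(tag):
    # memory_snapshot() returns a list of segment dicts describing the
    # current allocator state; with tracing enabled, the recorded actions
    # (ALLOC, FREE, SEGMENT_ALLOC, ...) are reported alongside it.
    snap = torch.cuda.memory_snapshot()
    print(tag, "segments:", len(snap))
    return snap

x = torch.randn(1024, 1024, device='cuda')
before = capture("after first alloc")
y = [torch.randn(256, 256, device='cuda') for _ in range(8)]
after = capture("after more allocs")
```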
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86241
Approved by: https://github.com/ngimel
Summary:
- expose a Python call to set the allocator settings; it uses the same format as the `PYTORCH_CUDA_ALLOC_CONF` environment variable (see the sketch after this list)
- keep the implementation contained within the cpp file to avoid increasing build times; only expose a function to apply the settings
- make some of the Allocator Config methods public, so it now looks more like a singleton
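A minimal usage sketch, assuming the new call is exposed as the private helper `torch.cuda.memory._set_allocator_settings` (treat the name as an assumption if your build differs) and accepts the same key:value string format as `PYTORCH_CUDA_ALLOC_CONF`:
```python
import torch

# Assumed private helper; takes the same "key:value,key:value" format as the
# PYTORCH_CUDA_ALLOC_CONF environment variable, but applied at runtime.
torch.cuda.memory._set_allocator_settings("max_split_size_mb:128")
```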
Test Plan: added the unit test
Differential Revision: D39487522
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84970
Approved by: https://github.com/zdevito
Record stack trace information for each allocated segment in the allocator.
It takes around 1.5us to record 50 stack frames of context.
Since invoking a PyTorch operator takes around 8us, this adds minimal overhead, but we still leave it disabled by default so that we can test it more on real workloads first.
Stack information is kept both for allocated blocks and for the last allocation that used now-inactive blocks. We could potentially also keep around the _first_ allocation that caused the block to be allocated from CUDA.
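A rough sketch of how the recorded context could be exercised once enabled, assuming the private `torch.cuda.memory._record_memory_history` hook (name and signature are version-dependent, so treat this as illustrative only):
```python
import torch

# Assumed private hook: enable recording of stack traces for new allocations
# (left disabled by default because of the per-allocation cost).
torch.cuda.memory._record_memory_history(True)

x = torch.randn(1024, 1024, device='cuda')

# Segments in subsequent snapshots now carry the captured allocation context;
# the exact field layout depends on the PyTorch version.
for seg in torch.cuda.memory_snapshot():
    print(seg["device"], seg["total_size"])
```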
Potential Followups:
* stack frame entries are small (16 bytes), but the list of Frames is not compressed even though most frames will share some entries. So far this doesn't produce huge dumps (7 MB for one real workload that uses all memory on the GPU), but it could be made much smaller through compression.
* Code to format the information is slow (a few seconds) because it uses Python and FlameGraph.pl.
* Things allocated during the backward pass have no stack frames because they are run on another C++ thread.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82146
Approved by: https://github.com/albanD
Resubmit of https://github.com/pytorch/pytorch/pull/77673, which was reverted due to Windows test failures: https://github.com/pytorch/pytorch/pull/77673#issuecomment-1130425845.
I suspect these failures happened because I don't explicitly set a side stream for graph capture in the new test.
Not setting a side stream explicitly is alright on Linux because cuda tests implicitly use a side stream.
I think Windows cuda tests implicitly use the default stream, breaking capture and leaving the backend in a bad state.
Other graphs tests explicitly set side streams and don't error in Windows builds, so I'm 95% sure doing the same for the new test will work.
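For reference, a sketch of the documented side-stream warmup-then-capture pattern the other graphs tests follow (the actual test may set things up differently):
```python
import torch

static_input = torch.zeros(8, device='cuda')

# Run warmup work on an explicit side stream instead of the default stream.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    for _ in range(3):
        warmup_out = static_input * 2
torch.cuda.current_stream().wait_stream(s)

# Capture the work into a CUDA graph, then mutate the static input and replay.
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input * 2

static_input.fill_(3.0)
g.replay()  # static_output now reflects the updated static_input
```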
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77789
Approved by: https://github.com/ezyang
In preparation for adopting future rocBLAS library options, it is necessary to track when the backward pass of training is executing. The scope-based helper class `BackwardPassGuard` is provided to toggle this state.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71881
Approved by: https://github.com/albanD
This PR allows users to author a CUDA kernel in Python.
```python
import torch
from torch.cuda.jiterator import create_jit_fn

code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x * y + x - y + alpha; }"
jitted_fn = create_jit_fn(code_string, alpha=0)
a = torch.rand(3, device='cuda')
b = torch.rand(3, device='cuda')
result = jitted_fn(a, b, alpha=1.0)
```
Limitations:
- Only supports elementwise kernels
- 1~8 tensor inputs (empty input, e.g. factory methods, is not supported)
- input tensors must live on a CUDA device
- CPU scalars are not supported
- kwargs must be pre-declared when calling create_jit_fn
- kwargs must be convertible to at::Scalar, one of float64, int64_t, or bool (complex is not supported for now)
TODOs:
- [x] consolidate union and c10::variant implementation
- [x] plug into existing op testing framework
- [ ] rename files, place files in the right folder
- [ ] place util functions in the right file
- [x] enforce assumptions in the Python interface, e.g. <8 inputs, kwargs types
- [x] Add user-facing documentation
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76394
Approved by: https://github.com/mruberry
Summary:
This patch moves a CUDA-specific file, `CUDAGeneratorImpl.h` to `ATen/cuda` as the following TODO comment in `CUDAGeneratorImpl.h` suggests:
```
// TODO: this file should be in ATen/cuda, not top level
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70650
Reviewed By: jianyuh, xw285cornell
Differential Revision: D33414890
Pulled By: shintaro-iwasaki
fbshipit-source-id: 4ff839205f4e4ea4c8767f164d583eb7072f1b8b
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69041
`TH_CONCAT_{N}` is still being used by THP so I've moved that into
its own header, but all the compiled code is gone.
Test Plan: Imported from OSS
Reviewed By: anjali411
Differential Revision: D32872477
Pulled By: ngimel
fbshipit-source-id: 06c82d8f96dbcee0715be407c61dfc7d7e8be47a
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62030
Remove dtype tracking from Python Storage interface, remove all the different `<type>Storage` classes except for `ByteStorage`, and update serialization accordingly, while maintaining as much FC/BC as possible
Fixes https://github.com/pytorch/pytorch/issues/47442
* **THE SERIALIZATION FORMAT IS FULLY FC/BC.** We worked very hard to make sure this is the case. We will probably want to break FC at some point to make the serialization structure of tensors make more sense, but not today.
* There is now only a single torch.ByteStorage class. Methods like `Tensor.set_` no longer check that the dtype of storage is appropriate.
* As we no longer know what the dtype of a storage is, we've **removed** the size method from Storage, replacing it with nbytes. This is to help catch otherwise silent errors where you confuse the number of elements with the number of bytes (see the sketch after this list).
* `Storage._new_shared` takes a `nbytes` kwarg and will reject previous positional only calls. `Storage._new_with_file` and `_set_from_file` require explicit element size arguments.
* It's no longer possible to convert storages to different types using the float/double/etc methods. Instead, do the conversion using a tensor.
* It's no longer possible to allocate a typed storage directly using FloatStorage/DoubleStorage/etc constructors. Instead, construct a tensor and extract its storage. The classes still exist but they are used purely for unpickling.
* The preexisting serialization format stores dtype with storage, and in fact this dtype is used to determine the dtype of the tensor overall.
To accommodate this case, we introduce a new TypedStorage concept that exists only during unpickling time which is used to temporarily store the dtype so we can construct a tensor. **If you overrode the handling of pickling/unpickling, you MUST add handling for TypedStorage** or your serialization code will degrade to standard file-based serialization.
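A short sketch of the resulting idioms (hedged; exact storage method availability varies across PyTorch versions):
```python
import torch

t = torch.arange(4, dtype=torch.float32)

# Storages no longer report a dtype-based element count; query raw bytes instead.
print(t.storage().nbytes())           # 16 bytes for four float32 elements

# Dtype conversions now go through tensors rather than storage methods.
print(t.double().storage().nbytes())  # 32 bytes for four float64 elements
```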
Original pull request: https://github.com/pytorch/pytorch/pull/59671
Reviewed By: soulitzer, ngimel
Differential Revision: D29466819
Pulled By: ezyang
fbshipit-source-id: 4a14e5d3c2b08e06e558683d97f7378a3180b00e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65746
This also removes the cudaHostAllocator field on THCState, since there
doesn't seem to be an API anywhere for customizing it.
Test Plan: Imported from OSS
Reviewed By: mrshenli
Differential Revision: D31236630
Pulled By: ngimel
fbshipit-source-id: 2a8e756222ae70565e77f8e7139d60ec5be32276
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65610
- Replace HIP_PLATFORM_HCC with USE_ROCM
- Don't rely on CUDA_VERSION or HIP_VERSION; use USE_ROCM and ROCM_VERSION instead.
- In the next PR:
  - Remove the mapping from CUDA_VERSION to HIP_VERSION and from CUDA to HIP in hipify.
  - HIP_PLATFORM_HCC is deprecated, so add HIP_PLATFORM_AMD to support HIP host code compilation with gcc.
cc jeffdaily sunway513 jithunnair-amd ROCmSupport amathews-amd
Reviewed By: jbschlosser
Differential Revision: D30909053
Pulled By: ezyang
fbshipit-source-id: 224a966ebf1aaec79beccbbd686fdf3d49267e06
Summary:
This creates a `torch.cuda.set_warn_on_synchronization()` function that will warn or error when a synchronizing operation is performed. We could wrap it in a context manager for ease of use, but that would be a lie, because it sets global, not thread-local, state. Since it's intended for debugging, maybe that's ok though.
As with all `torch.cuda.*` functions, it goes through CPython, not pybind, so the argument is converted to a long before being passed to the c10 function. I'll make the Python argument a Python enum class, but without pybind it will still have to go through the long conversion.
For a test script
```
import torch
torch.cuda.set_warn_on_synchronization(1)
x=torch.randn(10, device="cuda")
x.nonzero()
y=torch.randn((), device="cuda")
if y:
    print("something")
torch.multinomial(x.abs(), 10, replacement=False)
torch.randperm(20000, device="cuda")
ind = torch.randint(10, (3,), device="cuda")
mask = torch.randint(2, (10,), device="cuda", dtype=torch.bool)
val = torch.randn((), device="cuda")
x[mask]=1.
x[mask] = val
torch.cuda.synchronize()
```
the output is
```
/../playground/sync_warn_test.py:4: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x.nonzero()
/../playground/sync_warn_test.py:7: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
if y:
something
/../playground/sync_warn_test.py:9: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
torch.multinomial(x.abs(), 10, replacement=False)
/../playground/sync_warn_test.py:15: UserWarning: called a synchronizing operation (Triggered internally at ../c10/cuda/CUDAFunctions.cpp:145.)
x[mask] = val
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62092
Reviewed By: mruberry
Differential Revision: D29968792
Pulled By: ngimel
fbshipit-source-id: cc6f817212c164727ed99ecf6ab050dc29631b9e
Summary:
As the GoogleTest `TEST` macro, as well as `DEFINE_DISPATCH`, is non-compliant with the `cppcoreguidelines-avoid-non-const-global-variables` check, the corresponding `NOLINTNEXTLINE` suppressions are removed.
All changes but the ones to `.clang-tidy` are generated using the following script:
```
for i in `find . -type f -iname "*.c*" -or -iname "*.h"|xargs grep cppcoreguidelines-avoid-non-const-global-variables|cut -f1 -d:|sort|uniq`; do sed -i "/\/\/ NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)/d" $i; done
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62008
Reviewed By: driazati, r-barnes
Differential Revision: D29838584
Pulled By: malfet
fbshipit-source-id: 1b2f8602c945bd4ce50a9bfdd204755556e31d13
Summary:
This PR suppresses clang-tidy warnings in the codebase (for now) so that we can re-enable clang-tidy checks on master.
I ran this script to add the `NOLINTNEXTLINE` comments (on a devserver):
```bash
python3 setup.py develop
# Uses same script that's run on CI and adds the -j (parallel), -s (add comments), -k (continue if diagnostic errors are found) options
python3 tools/clang_tidy.py \
-j \
-s \
-k \
-v \
--paths torch/csrc/ \
-g"-torch/csrc/jit/passes/onnx/helper.cpp" \
-g"-torch/csrc/jit/passes/onnx/shape_type_inference.cpp" \
-g"-torch/csrc/jit/serialization/onnx.cpp" \
-g"-torch/csrc/jit/serialization/export.cpp" \
-g"-torch/csrc/jit/serialization/import.cpp" \
-g"-torch/csrc/jit/serialization/import_legacy.cpp" \
-g"-torch/csrc/onnx/init.cpp" \
-g"-torch/csrc/cuda/nccl.*" \
-g"-torch/csrc/cuda/python_nccl.cpp" \
-g"-torch/csrc/autograd/FunctionsManual.cpp" \
-g"-torch/csrc/generic/*.cpp" \
-g"-torch/csrc/jit/codegen/cuda/runtime/*" \
-g"-torch/csrc/deploy/interpreter/interpreter.cpp" \
-g"-torch/csrc/deploy/interpreter/interpreter.h" \
-g"-torch/csrc/deploy/interpreter/interpreter_impl.h" \
-g"-torch/csrc/deploy/interpreter/test_main.cpp"
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60649
Test Plan: Verified changes by re-running the script (without the `-s` option) and seeing no warnings/errors.
Reviewed By: walterddr, janeyx99
Differential Revision: D29504258
Pulled By: 1ntEgr8
fbshipit-source-id: 78310b30ee8213b73ddb4771ad874665323e7a4e
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59655
CUDAHooks is to be used solely when you need to call into CUDA
functionality from a context where you cannot directly link to
CUDA libraries. Neither hasPrimaryContext nor
getDevceIndexWithPrimaryContext (sic) needs to be used in such
contexts. By moving them out of CUDAHooks and calling them
directly, a dynamic dispatch can be skipped.
I also fixed the typo in getDev(i)ceIndexWithPrimaryContext.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D28972946
Pulled By: ezyang
fbshipit-source-id: edcd7a7b62aec97928f07fbf3bf413b9fb027517
Summary:
Fixes https://github.com/pytorch/pytorch/issues/35901
This change is designed to prevent fragmentation in the Caching Allocator. Permissive block splitting in the allocator allows very large blocks to be split into many pieces. Once split too finely, it is unlikely that all the pieces will be free at the same time, so the original allocation can never be returned. Anecdotally, we've seen a model run out of memory failing to allocate a 50 MB block on a 32 GB card while the caching allocator was holding 13 GB of 'split free blocks'.
Approach:
- Large blocks above a certain size are designated "oversize". This limit is currently set one decade above large, at 200 MB.
- Oversize blocks cannot be split.
- Oversize blocks must closely match the requested size (e.g. a 200 MB request will match an existing 205 MB block, but not a 300 MB block).
- In lieu of splitting oversize blocks, there is a mechanism to quickly free a single oversize block (to the system allocator) to allow an appropriately sized block to be allocated. This is activated under memory pressure and prevents _release_cached_blocks()_ from triggering.
Initial performance tests show this is similar to or quicker than the original strategy. Additional tests are ongoing.
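For context, a minimal sketch of steering this behavior from user code, assuming the oversize threshold is the one exposed through the `max_split_size_mb` knob of `PYTORCH_CUDA_ALLOC_CONF`:
```python
import os

# Must be set before the process allocates any CUDA memory.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:200"

import torch

# Blocks larger than 200 MB are now treated as unsplittable "oversize" blocks.
x = torch.randn(1024, 1024, device='cuda')
```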
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44742
Reviewed By: zou3519
Differential Revision: D29186394
Pulled By: ezyang
fbshipit-source-id: c88918836db3f51df59de6d1b3e03602ebe306a9
Summary:
Switches most of the simple for loops outside of `jit` directories to use `c10::irange`.
Generated with D28874212.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59481
Test Plan: Sandcastle
Reviewed By: ngimel
Differential Revision: D28909681
fbshipit-source-id: ec9ab1bd602933238d9d0f73d4d8d027b75d9d85
Summary:
In my last PR I missed the CUDA and distributed folders; fixing this now.
This change is autogenerated by `python tools/clang_tidy.py -s`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57235
Reviewed By: janeyx99
Differential Revision: D28084444
Pulled By: malfet
fbshipit-source-id: bf222f69ee90c7872c3cb0931e8cdb84f0cb3cda
Summary:
Fixes https://github.com/pytorch/pytorch/issues/35901
This change is designed to prevent fragmentation in the Caching Allocator. Permissive block splitting in the allocator allows very large blocks to be split into many pieces. Once split too finely, it is unlikely that all the pieces will be free at the same time, so the original allocation can never be returned. Anecdotally, we've seen a model run out of memory failing to allocate a 50 MB block on a 32 GB card while the caching allocator was holding 13 GB of 'split free blocks'.
Approach:
- Large blocks above a certain size are designated "oversize". This limit is currently set one decade above large, at 200 MB.
- Oversize blocks cannot be split.
- Oversize blocks must closely match the requested size (e.g. a 200 MB request will match an existing 205 MB block, but not a 300 MB block).
- In lieu of splitting oversize blocks, there is a mechanism to quickly free a single oversize block (to the system allocator) to allow an appropriately sized block to be allocated. This is activated under memory pressure and prevents _release_cached_blocks()_ from triggering.
Initial performance tests show this is similar to or quicker than the original strategy. Additional tests are ongoing.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/44742
Reviewed By: ngimel
Differential Revision: D23752058
Pulled By: ezyang
fbshipit-source-id: ccb7c13e3cf8ef2707706726ac9aaac3a5e3d5c8
Summary:
And the underlying `torch._C._cuda_canDeviceAccessPeer`, which is a wrapper around `cudaDeviceCanAccessPeer`.
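A quick usage sketch, assuming the functionality is surfaced publicly as `torch.cuda.can_device_access_peer` (treat the public name as an assumption):
```python
import torch

# True if device 0 can directly access memory allocated on device 1
# (requires at least two visible GPUs).
if torch.cuda.device_count() >= 2:
    print(torch.cuda.can_device_access_peer(0, 1))
```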
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50446
Reviewed By: mrshenli
Differential Revision: D25890405
Pulled By: malfet
fbshipit-source-id: ef09405f115bbe73ba301d608d56cd8f8453201b
Summary:
Add a new function, torch.cuda.set_per_process_memory_fraction(fraction, device), to torch.cuda. Related: https://github.com/pytorch/pytorch/issues/18626
The fraction (a float from 0 to 1) limits the memory available to the caching allocator on the given GPU device. It can be set on any visible GPU. The allowed memory equals total memory * fraction. An OOM error is raised when trying to allocate more GPU memory than the allowed value. This function is similar to TensorFlow's per_process_gpu_memory_fraction.
Note that this setting only limits the caching allocator within one process. If you are using multiprocessing, you need to apply the setting in each subprocess to limit its GPU memory, because each subprocess has its own allocator.
## usage
In some cases, one needs to split a GPU device into two parts. The limit can be set before any GPU memory is used. E.g. on device 0, to give each part half of the memory:
```
torch.cuda.set_per_process_memory_fraction(0.5, 0)
```
Here is an example showing the behavior:
```python
import torch
torch.cuda.set_per_process_memory_fraction(0.5, 0)
torch.cuda.empty_cache()
total_memory = torch.cuda.get_device_properties(0).total_memory
# less than 0.5 will be ok:
tmp_tensor = torch.empty(int(total_memory * 0.499), dtype=torch.int8, device='cuda')
del tmp_tensor
torch.cuda.empty_cache()
# this allocation will raise an OOM:
torch.empty(total_memory // 2, dtype=torch.int8, device='cuda')
"""
It raises an error as follows:
RuntimeError: CUDA out of memory. Tried to allocate 5.59 GiB (GPU 0; 11.17 GiB total capacity; 0 bytes already allocated; 10.91 GiB free; 5.59 GiB allowed; 0 bytes reserved in total by PyTorch)
"""
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/48172
Reviewed By: bdhirsh
Differential Revision: D25275381
Pulled By: VitalyFedyunin
fbshipit-source-id: d8e7af31902c2eb795d416b57011cc8a22891b8f