Commit Graph

194 Commits

Author SHA1 Message Date
7ff1f3f3f6 Revert "Revert "Expandable blocks in allocator (#96995)"" (#99275)
This reverts commit 851e89c8e817f28270e0fc21d74ced9446bea747.

Differential Revision: [D45034526](https://our.internmc.facebook.com/intern/diff/D45034526)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99275
Approved by: https://github.com/eellison
2023-04-17 23:46:08 +00:00
851e89c8e8 Revert "Expandable blocks in allocator (#96995)"
This reverts commit 6a50b83b739c2d37d0f518f98b8e624eca0ea153.

Reverted https://github.com/pytorch/pytorch/pull/96995 on behalf of https://github.com/izaitsevfb due to Breaks internal tests
2023-04-16 19:23:37 +00:00
fdbc8625a1 Functionalization of torch.rand/rand_like ops (#97377)
This PR introduces the functionalization of RNG ops. Key points are

* Introduces a new `philox_rand` prim operator that accepts seed, offset.
* Adds decompositions for random operators that use these philox_rand prims
* Adds a PhiloxStateTracker to track the offset for each occurence of rand ops
* Changes calling convention of AOT Autograd and adds <fwd_seed, fwd_base_offset> and <bwd_seed, bwd_base_offset>
* Monkeypatches set_rng_state and get_rng_state while AOT Autograd tracing to record the rng state behavior
* Raises assertion for CPU because CPU does not Philox RNG.

Not dealt in this PR
* dropout op - offset calculation is different
* other distributions like normal, poisson etc
* Inductor support
* Cudagraph support
* Dynamic shape support

An example
~~~

class Custom(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        a = torch.rand_like(x) * x
        a = torch.rand_like(x) * a
        return a

    @staticmethod
    def backward(ctx, grad_out):
        x, = ctx.saved_tensors
        return grad_out * torch.rand_like(grad_out) * torch.cos(x)

====== Forward graph 0 ======
def forward(self, fwd_seed_1: i64[], fwd_base_offset_1: i64[], primals_1: f32[16, 16]):
    # No stacktrace found for following nodes
    add: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 0)
    philox_rand: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add, [16, 1], device(type='cuda', index=0), torch.float32);  add = None
    mul: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand, primals_1);  philox_rand = None
    add_1: i64[] = torch.ops.aten.add.Tensor(fwd_base_offset_1, 4);  fwd_base_offset_1 = None
    philox_rand_1: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], fwd_seed_1, add_1, [16, 1], device(type='cuda', index=0), torch.float32);  fwd_seed_1 = add_1 = None
    mul_1: f32[16, 16] = torch.ops.aten.mul.Tensor(philox_rand_1, mul);  philox_rand_1 = mul = None
    return [mul_1, primals_1]

====== Backward graph 0 ======
def forward(self, bwd_seed_1: i64[], bwd_base_offset_1: i64[], primals_1: f32[16, 16], tangents_1: f32[16, 16]):
    # No stacktrace found for following nodes
    add_2: i64[] = torch.ops.aten.add.Tensor(bwd_base_offset_1, 0);  bwd_base_offset_1 = None
    philox_rand_2: f32[16, 16] = torch.ops.prims.philox_rand.default([16, 16], bwd_seed_1, add_2, [16, 1], device(type='cuda', index=0), torch.float32);  bwd_seed_1 = add_2 = None
    mul_2: f32[16, 16] = torch.ops.aten.mul.Tensor(tangents_1, philox_rand_2);  tangents_1 = philox_rand_2 = None
    cos: f32[16, 16] = torch.ops.aten.cos.default(primals_1);  primals_1 = None
    mul_3: f32[16, 16] = torch.ops.aten.mul.Tensor(mul_2, cos);  mul_2 = cos = None
    return [mul_3]

~~~

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97377
Approved by: https://github.com/ezyang
2023-04-16 09:55:56 +00:00
6a50b83b73 Expandable blocks in allocator (#96995)
Common advice we give for handling memory fragmentation issues is to
allocate a big block upfront to reserve memory which will get split up later.
For programs with changing tensor sizes this can be especially helpful to
avoid OOMs that happen the first time we see a new largest input and would
otherwise have to allocate new segments.

However the issue with allocating a block upfront is that is nearly impossible
to correctly estimate the size of that block. If too small, space in the block
will run out and the allocator will allocate separate blocks anyway. Too large,
and other non-PyTorch libraries might stop working because they cannot allocate
any memory.

This patch provides the same benefits as using a pre-allocating block but
without having to choose its size upfront. Using the cuMemMap-style APIs,
it adds the ability to expand the last block in a segment when more memory is
needed.

Compared to universally using cudaMallocAsync to avoid fragmentation,
this patch can fix this common fragmentation issue while preserving most
of the existing allocator behavior. This behavior can be enabled and disabled dynamically.
 This should allow users to, for instance, allocate long-lived parameters and state in individual buffers,
and put temporary state into the large expandable blocks, further reducing
fragmentation.

See inline comments for information about the implementation and its limitations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/96995
Approved by: https://github.com/eellison
2023-04-14 09:49:11 +00:00
69eef5a4be [CUDA12] set_device change (#94864)
This PR adds workaround for CUDA 12 [`cudaSetDevice` change](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb) which will always create primary context on target device. So operations like this:
```Python
import torch
x = torch.randn(1, device="cuda:1")
```
would always create primary context on on device `cuda:1` because it is creating a tensor on it and on device `cuda:0` because the destructor of CUDA Device guard calls `cudaSetDevice(0)`.
After this PR the CUDA Device guard will not call `cudaSetDevice(0)` if primary context does not exist on `cuda:0`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94864
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/ezyang
2023-04-10 17:31:12 +00:00
5c8fea5647 Reduce overhead in CUDAGraph Trees (#98529)
Significantly reduces overhead of constructing Tensors and Storages and checking Storage Liveness. Removes the regression for HF models that I tested and removes 75% of overhead of the extremely overhead bound resnet50 training we have in torchbench. (.91x base commit, 1.02x torchinductor default, 1.16x this PR, 1.25 previous cudagraphs impl).

This PR takes care of all of the lower hanging fruit.

- Computes storage aliasing at record time instead of during at runtime. We no longer need to use a runtime storage cache, and can instead index directly into the existing alias if there is one, or construct a new Storage

- Moves the heavyweight C++ calls into a batch - getting storage weakrefs and constructing tensors

Pull Request resolved: https://github.com/pytorch/pytorch/pull/98529
Approved by: https://github.com/jansel, https://github.com/ngimel
2023-04-07 05:46:08 +00:00
279ca5f9db Revert "[CUDA12] set_device change (#94864)"
This reverts commit c18be2b2ec00133abe28efcdd0462e50ddd45a1a.

Reverted https://github.com/pytorch/pytorch/pull/94864 on behalf of https://github.com/ezyang due to avoid affecting cuda 11
2023-04-05 14:53:00 +00:00
c18be2b2ec [CUDA12] set_device change (#94864)
This PR adds workaround for CUDA 12 [`cudaSetDevice` change](https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__DEVICE.html#group__CUDART__DEVICE_1g159587909ffa0791bbe4b40187a4c6bb) which will always create primary context on target device. So operations like this:
```Python
import torch
x = torch.randn(1, device="cuda:1")
```
would always create primary context on on device `cuda:1` because it is creating a tensor on it and on device `cuda:0` because the destructor of CUDA Device guard calls `cudaSetDevice(0)`.
After this PR the CUDA Device guard will not call `cudaSetDevice(0)` if primary context does not exist on `cuda:0`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94864
Approved by: https://github.com/malfet, https://github.com/atalman, https://github.com/ezyang
2023-04-05 14:34:00 +00:00
da28af3286 distinguish mutability of StorageImpl::data_ptr() member (#97651)
See D44409928.

Differential Revision: [D44410323](https://our.internmc.facebook.com/intern/diff/D44410323/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97651
Approved by: https://github.com/ezyang
2023-03-30 19:13:56 +00:00
24ce3a7c34 Move hasPrimaryContext to c10::cuda (#96800)
This method has to be accessible from `c10` to enable CUDA-12 integration.
Implemented by providing private `c10::cuda:_internal::setHasPrimaryContext` that passes the pointer to the implementation (in `torch_cuda`) back to c10.
Use global class constructor/destructor to guarantee RAII.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96800
Approved by: https://github.com/ngimel
2023-03-17 04:50:35 +00:00
571f96bf59 cudagraph trees (#89146)
CUDA Graph Trees

Design doc: https://docs.google.com/document/d/1ZrxLGWz7T45MSX6gPsL6Ln4t0eZCSfWewtJ_qLd_D0E/edit

Not currently implemented :

- Right now, we are using weak tensor refs from outputs to check if a tensor has dies. This doesn't work because a) aliasing, and b) aot_autograd detaches tensors (see note [Detaching saved tensors in AOTAutograd]). Would need either https://github.com/pytorch/pytorch/issues/91395 to land to use storage weak refs or manually add a deleter fn that does what I want. This is doable but theres some interactions with the caching allocator checkpointing so saving for a stacked pr.

- Reclaiming memory from the inputs during model recording. This isn't terribly difficult but deferring to another PR. You would need to write over the input memory during warmup, and therefore copy the inputs to cpu. Saving for a stacked pr.

- Warning on overwriting previous generation outputs. and handling nested torch.compile() calls in generation tracking

Differential Revision: [D43999887](https://our.internmc.facebook.com/intern/diff/D43999887)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/89146
Approved by: https://github.com/ezyang
2023-03-17 02:47:03 +00:00
ea7415087a Expose Stream Recording Apis in python (#96384)
Differential Revision: [D43999891](https://our.internmc.facebook.com/intern/diff/D43999891)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96384
Approved by: https://github.com/zdevito
2023-03-16 23:45:43 +00:00
e74f70d212 Revert "Revert "[memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)"" (#96878)
This reverts commit e1ea584b1caf9c50de25ce69396dfeb523a452c0.
Adds __has_include check to fix fbcode build.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96878
Approved by: https://github.com/ezyang
2023-03-16 04:12:54 +00:00
e1ea584b1c Revert "[memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)"
This reverts commit 4e1060c609c094fd5f58041ebed803f74410ee36.

Reverted https://github.com/pytorch/pytorch/pull/95541 on behalf of https://github.com/DanilBaibak due to breaking internal builds
2023-03-15 13:28:41 +00:00
85639c1a88 [allocator] Generalize recording to a pool (#96542)
Previously the allocator would query whether a stream was recording a graph,
and look up the pool associated with a graph. This change has the allocator
directly associate a stream with a mempool, decoupling "record this stream to a pool"
from the action of "record all actions to a cuda graph".
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96542
Approved by: https://github.com/eellison
2023-03-15 04:28:49 +00:00
4e1060c609 [memory profiling] add a facility to gather combined C++/Python/TorchScript stack traces. (#95541)
This refactors the stack trace facility specific to memory profiling
    in python+cuda to make a generic facility to generate combined stack
    traces.

    The generic facility (combined_traceback.h) does not require
    python to be around to work, but will return python stacks if it is
    present.

    This facility is then used to add support for stack trace gathering in memory profiling that
    happens directly from C++.

    It is also used to expose a python API for gathering and symbolizing
    combineds stacks.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95541
Approved by: https://github.com/ezyang
2023-03-14 18:26:05 +00:00
da265652d6 Return Live Data Pointers from Checkpoint, swap onto tensors (#95020)
When we checkpoint the state of the private pool allocator, we will need to make sure that its current live allocated blocks will get properly cleaned up when the tensors they correspond to die. Return DataPtrs for these new allocated blocks that the callee can swap onto live Tensors.

The exact api for setting the checkpoint can be manipulated after this as the cudagraph implementation is built out, but this at least shows its sufficiently general.

This should be the last PR touching cuda caching allocator necessary for new cudagraphs integration.

Differential Revision: [D43999888](https://our.internmc.facebook.com/intern/diff/D43999888)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95020
Approved by: https://github.com/zdevito
2023-03-14 01:22:19 +00:00
1cc32aedb0 Handle additional live allocations not in checkpointed state (#94943)
We choose to ignore certain blocks that are currently allocated when we set the pool to its checkpoint. For those blocks, we need to swap out the deleter function of their corresponding blocks so that a deallocation is not triggered when they die.

Differential Revision: [D43999886](https://our.internmc.facebook.com/intern/diff/D43999886)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94943
Approved by: https://github.com/zdevito
2023-03-14 01:00:47 +00:00
d798de2b05 Checkpoint CUDA Allocator Private Pool State (#94653)
Copying note from cuda caching allocator:

```
   * Note [Checkpointing PrivatePoolState]
   *
   * Refer above to Note [Interaction with CUDA graph capture]. Allocations made
   * during graph capture are made from a separate private pool. During graph
   * capture allocations behave as usual. During graph replay the allocator
   * state does not change even as new tensors are created. The private pool
   * will not free its blocks to the main caching allocator until cuda graph use
   * is finished to prevent an allocation from eager clobbering the memory from
   * a live but unaccounted for tensor that was created during replay.
   *
   * `make_graphed_callables`, a series of separate callables chained in
   * successive cuda graphs, can share a memory pool because after a cuda graph
   * recording the allocations in the shared private pool exactly reflect the
   * tensors that are allocated.
   *
   * We would like to extend callable chaining to support a graphed callable
   * tree. In this scenario, we have a tree of callable chains which will be
   * captured with cuda graphs. In the diagram below, we have a tree with four
   * callables, A, B, C, and D. Suppose we have captured, and subsequently
   * replayed, A, B, and C. Then on a new invocation, we replay A and B, but
   * would now like to record D. At this point the private pool will not reflect
   * any of the live tensors created during graph replay. Allocations made
   * during a new recording with the pool could overwrite those live tensors.
   *
   * In order to record a new graph capture after replaying prior callables in
   * the tree, we need the allocator to reflect the state of the live tensors.
   * We checkpoint the state of the private after each recording, and then
   * reapply it when we are starting a new recording chain. Additionally, we
   * must free the allocations for any tensors that died between the end of our
   * previous graph replaying and our new recording (TODO). All of the allocated
   * segments that existed in the checkpointed state must still exist in the
   * pool. There may also exist new segments, which we will free (TODO : link
   * note [live tensors between iterations] when it exists).
   *
   *
   *  ---------------> A ---------------> B ---------------> C
   *                                |
   *                                |
   *                                |
   *                                |
   *                                  ---------------> D
```

A few TODOs:
- need to add logic for freeing tensors that have died between a last replay and current new recording
- Add logic for free that might be called on a pointer multiple times (because we are manually freeing live tensors)

The two scenarios above have not been exercised in the tests yet.

Differential Revision: [D43999889](https://our.internmc.facebook.com/intern/diff/D43999889)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94653
Approved by: https://github.com/zdevito
2023-03-14 00:47:30 +00:00
4b372e3958 [memory profiling] C++ tracing support (#95357)
Adds the ability to quickly generate stack traces for C++,
and combine Python, TorchScript, and C++ frames into a single trace.

This makes it possible for the memory tracer to record allocations inside
C++ code (e.g. convolution temporaries, backward operators).

The unwinder code is ~10x faster than execinfo.h's backward because it
cache fast unwinder routines for instruction pointers that have already been seen.
It is also only 1.2--2x slower than copying the entire stack (the approach perf takes),
while using 2 orders of magnitude less space per stack.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95357
Approved by: https://github.com/bertmaher
2023-03-12 07:24:14 +00:00
48490cec28 [memory profiling] Move Context object to c10 (#96280)
Minor refactor so that follow up PR can have objects that meet the GatheredContext
inferface without having to depend on CUDA.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96280
Approved by: https://github.com/eellison
2023-03-12 07:24:14 +00:00
266089a3fe [memory snapshots] record scripted stack traces (#95356)
Adds support for seeing both python and script stack traces in memory
debugging.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95356
Approved by: https://github.com/aaronenyeshi
2023-03-12 07:24:14 +00:00
cyy
6786a24fd2 fix some tiny code issues (#95757)
This PR tries to fix:
1. a misspelled NDEBUG preprocessing condition.
2. get ride of all writable-strings warnings.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/95757
Approved by: https://github.com/soulitzer
2023-03-01 23:27:32 +00:00
4f84c57c87 Fix potential deadlock when recording memory traces (#95273)
See comment in the diff

Differential Revision: [D43490668](https://our.internmc.facebook.com/intern/diff/D43490668)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95273
Approved by: https://github.com/eellison
2023-02-27 19:04:47 +00:00
54b7c7d5e9 Added requested_bytes to CUDA Caching Allocator Stats (#88575)
Summary:
The caching allocator can be configured to round memory allocations in order to reduce fragmentation. Sometimes however, the overhead from rounding can be higher than the fragmentation it helps reduce.

We have added a new stat to CUDA caching allocator stats to help track if rounding is adding too much overhead and help tune the roundup_power2_divisions flag:
    - "requested_bytes.{current,peak,allocated,freed}": memory requested by client code, compare this with allocated_bytes to check if allocation rounding adds too much overhead

Test Plan: Added test case in caffe2/test/test_cuda.py

Differential Revision: D40810674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88575
Approved by: https://github.com/zdevito
2023-02-09 21:37:25 +00:00
cyy
27efdc5eed fix writable-strings warnings (#93246)
clang reports "ISO C++11 does not allow conversion from string
literal to 'char *'"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/93246
Approved by: https://github.com/malfet
2023-02-04 02:11:15 +00:00
cyy
bfe5e1258b avoid unnecessary static_cast (#93898)
avoid unnecessary static_cast
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93898
Approved by: https://github.com/Skylion007
2023-02-03 03:44:43 +00:00
cyy
e292ddff4e More clang-tidy fixes (#92944)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92944
Approved by: https://github.com/Skylion007
2023-01-25 19:11:51 +00:00
523d4f2562 Revert "[cuDNN][cuDNN V8 API] Always build assuming cuDNN >= 8.0 (#91527)"
This reverts commit 4d07ad74f1c11efa55501433d6cf1f06840f5207.

Reverted https://github.com/pytorch/pytorch/pull/91527 on behalf of https://github.com/DanilBaibak due to Break internal build
2023-01-16 13:28:09 +00:00
4d07ad74f1 [cuDNN][cuDNN V8 API] Always build assuming cuDNN >= 8.0 (#91527)
We've been building with V8 (incl. V8 API) by default for a while now; this PR cleans up some guards for cuDNN < 8.0.

CC @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91527
Approved by: https://github.com/ngimel
2023-01-13 18:55:37 +00:00
eece6da162 [inductor] Reduce device context manager overhead (#91045)
This adds `torch.cuda._DeviceGuard` which is a stripped down version of
`torch.cuda.device` with lower overhead. To do this, it only accepts `int` as
the device so we don't need to call `_get_device_index` and is implemented
with a new C++ helper `torch._C._cuda_exchangeDevice` that allows
`_DeviceGuard.__enter__` to be just a single function call. On my machine,
I see a drop from 3.8us of overhead to 0.94 us with this simple benchmark:

```python
def set_device():
    with torch.cuda.device(0):
        pass

%timeit set_device()
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/91045
Approved by: https://github.com/ngimel, https://github.com/anijain2305
2023-01-12 16:51:59 +00:00
e096d2db5a [BC-Breaking] Separate stream_id, device_index, and device_type in pack and unpack for Streams (#81596)
#75854

A naive attempt at working around the limitations of using a single 64-bit integer to pack `stream_id`, `device_index`, and `device_type`.

Stills needs sanity checks, testing, and minimization of BC-breaking changes.

Currently a Holder for the `StreamData3` struct is used for `IValue` compatibility. While doing this seems to work for `ivalue.h` and `ivalue_inl.h`, this doesn't seem to be naively working for the JIT CUDA stream wrapper? (Something about ambiguous calls if an `intrusive_ptr` to `c10::ivalue::StreamData3Holder` is used as the return type for `pack()`. It turns out that the methods required to access the fields for rematerializing a CUDA Stream are basically already present anyway, so `pack` is simply removed in the wrapper for now and the methods to access the required fields are called directly.

CC @ptrblck

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81596
Approved by: https://github.com/ezyang
2023-01-12 14:16:49 +00:00
07e595e88a Add device_idx to free_fn in CUDAPluggableAllocator (#91398)
This was requested by nvidia folks, track also the device_id in the free function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91398
Approved by: https://github.com/albanD
2023-01-12 05:03:48 +00:00
c9d4390d13 Add Pluggable CUDA allocator backend (#86786)
Fixes #43144

This uses the Backend system added by [82682](https://github.com/pytorch/pytorch/pull/82682) to change allocators dynamically during the code execution. This will allow us to use RMM, use CUDA managed memory for some portions of the code that do not fit in GPU memory. Write static memory allocators to reduce fragmentation while training models and improve interoperability with external DL compilers/libraries.

For example, we could have the following allocator in c++

```c++
#include <sys/types.h>
#include <cuda_runtime_api.h>
#include <iostream>

extern "C" {
void* my_malloc(ssize_t size, int device, cudaStream_t stream) {
   void *ptr;
   std::cout<<"alloc "<< size<<std::endl;
   cudaMalloc(&ptr, size);
   return ptr;
}

void my_free(void* ptr) {
   std::cout<<"free "<<std::endl;
   cudaFree(ptr);
}
}
```

Compile it as a shared library
```
nvcc allocator.cc -o alloc.so -shared --compiler-options '-fPIC'
```

And use it from PyTorch as follows

```python
import torch

# Init caching
# b = torch.zeros(10, device='cuda')
new_alloc = torch.cuda.memory.CUDAPluggableAllocator('alloc.so', 'my_malloc', 'my_free')
old = torch.cuda.memory.get_current_allocator()
torch.cuda.memory.change_current_allocator(new_alloc)
b = torch.zeros(10, device='cuda')
# This will error since the current allocator was already instantiated
torch.cuda.memory.change_current_allocator(old)
```

Things to discuss
- How to test this, needs compiling external code ...

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86786
Approved by: https://github.com/albanD
2022-11-23 17:54:36 +00:00
0d2c2110f1 [allocator] Introduce the abstract class CUDACachingAllocator (#87251)
This replaces the manual function pointers, making it easier to write
new drop-in allocators.

Note that most allocation goes through the Allocator interface, which
CUDAAllocator inherits from, and this arrangement avoids adding and
additional layer of dispatch along this pathway compared to what existed before.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87251
Approved by: https://github.com/wconstab
2022-10-20 01:17:00 +00:00
f56ce8dbad [allocator] Move getFreeMutex (#87237)
It isn't used at all the allocators and this change makes that more clear.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87237
Approved by: https://github.com/wconstab
2022-10-19 18:00:40 +00:00
25725fd624 (Re-open) Adds cudaMallocAsync as an alternative backend for the CUDA allocator (#82682)
Rebased version of @mcarilli 's cudaMallocAsync #65365 for continued testing
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82682
Approved by: https://github.com/ngimel
2022-10-12 03:44:21 +00:00
eqy
352d926482 [CUBLAS][CUDA GRAPHS] (re-re-re-re-open of #83461) Explicitly set the workspace for cuBLAS handles (#86645)
re-opening (again) in hopes of working around failed/stuck CLA check

CC @ptrblck @ngimel @huydhn
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86645
Approved by: https://github.com/zdevito
2022-10-11 16:03:49 +00:00
91b1bae1df Caching allocator tracing (#86241)
We currently can take snapshots of the state of the allocated cuda memory, but we do not have a way to correlate these snapshots with the actions the allocator that were taken between snapshots. This PR adds a simple fixed-sized buffer that records the major actions that the allocator takes (ALLOC, FREE, SEGMENT_ALLOC, SEGMENT_FREE, OOM, SNAPSHOT) and includes these with the snapshot information. Capturing period snapshots with a big enough trace buffer makes it possible to see how the allocator state changes over time.

We plan to use this functionality to guide how settings in the allocator can be adjusted and eventually have a more robust overall algorithm.

As a component of this functionality, we also add the ability to get a callback when the allocator will throw an OOM, primarily so that snapshots can be taken immediately to see why the program ran out of memory (most programs have some C++ state that would free tensors before the OutOfMemory exception can be caught).

This PR also updates the _memory_viz.py script to pretty-print the trace information and provide a better textual summary of snapshots distinguishing between internal and external fragmentation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86241
Approved by: https://github.com/ngimel
2022-10-07 23:19:54 +00:00
adf5919720 Add option to record C++ backtraces in _record_memory_history (#86145)
I used this to debug https://github.com/pytorch/pytorch/issues/86136 so it is useful. The implementation is not so fast so it is not enabled by default.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86145
Approved by: https://github.com/albanD, https://github.com/zdevito
2022-10-06 04:07:37 +00:00
97d6b5bbf8 Refactor _cuda_recordMemoryHistory to use pybind11 (#86139)
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86139
Approved by: https://github.com/albanD
2022-10-06 04:07:37 +00:00
db13049b88 [allocator tracing] missing GIL acquire (#86254)
Bug where the context destructor needs to hold the GIL to free the context.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86254
Approved by: https://github.com/ezyang
2022-10-05 05:54:05 +00:00
71eb04403c Revert "[CUBLAS][CUDA GRAPHS] (re-re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85447)"
This reverts commit b04b2fa9aa52cacbdc9aaaf477d55b0af845ce81.

Reverted https://github.com/pytorch/pytorch/pull/85447 on behalf of https://github.com/seemethere due to Caused a CUDA memory leak, detected by our performance benchmark suite
2022-09-30 20:53:41 +00:00
b04b2fa9aa [CUBLAS][CUDA GRAPHS] (re-re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85447)
Now includes @dagitses 's optimizations and fixes for teardown

CC @ngimel @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85447
Approved by: https://github.com/malfet
2022-09-28 16:04:58 +00:00
0ac6311356 Revert "[CUBLAS][CUDA GRAPHS] (re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85292)"
This reverts commit 4012e623e8689b873cae94590766d990d155017c.

Reverted https://github.com/pytorch/pytorch/pull/85292 on behalf of https://github.com/dagitses due to broke an internal test during shutdown. Re-submit with #85399 in stack
2022-09-21 17:57:49 +00:00
eqy
4012e623e8 [CUBLAS][CUDA GRAPHS] (re-open of #83461) Explicitly set the workspace for cuBLAS handles (#85292)
re-open of #83461 with fix for 10.2 build

CC @ngimel @malfet
Pull Request resolved: https://github.com/pytorch/pytorch/pull/85292
Approved by: https://github.com/malfet
2022-09-20 16:31:54 +00:00
d23ce29761 allow changing the cuda allocator settings even after the process started (#84970)
Summary:
- expose a python call to set the allocator settings, it uses the same format as the value for PYTORCH_CUDA_ALLOCATOR
- keep the implementation contained within the cpp file to avoid increasing build times, only expose a function to call the setting
- make some of the Allocator Config methods public, now it looks more like a singleton

Test Plan: added the unit test

Differential Revision: D39487522

Pull Request resolved: https://github.com/pytorch/pytorch/pull/84970
Approved by: https://github.com/zdevito
2022-09-17 09:42:42 +00:00
726d040692 annotated allocator snapshots (#82146)
Record stack trace information for each allocated segment in the allocator.
It takes around 1.5us to record 50 stack frames of context.
Since invoking a Pytorch operator is around 8us, this adds minimal overhead but we still leave it disabled by default so that we can test it more on real workloads first.

Stack information is kept both for allocated blocks and the last allocation used inactive blocks. We could potential keep around the _first_ allocation that caused the block to get allocated from cuda as well.

Potential Followups:
* stack frame entries are small (16 bytes), but the list of Frames is not compressed eventhough most frames will share some entries. So far this doesn't produce huge dumps (7MB for one real workload that uses all memory on the GPU), but it can be much smaller through compression.
* Code to format the information is slow (a few seconds) because it uses python and FlameGraph.pl
* Things allocated during the backward pass have no stack frames because they are run on another C++ thread.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/82146
Approved by: https://github.com/albanD
2022-08-09 17:21:35 +00:00
56dea92d97 Fix set_requires_cuda_init (#81183)
Fixes the buggy `set_requires_cuda_init` introduced in #80788.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81183
Approved by: https://github.com/ezyang
2022-07-11 15:36:42 +00:00
ae6dd20ba7 [cuDNN V8 API] (reopen 2) Allow the number of kernels profiled under torch.backends.cudnn.benchmark = True to be limitedCudnnv8 benchmark limit (#78299)
Reopen of #77002 to address comments by @malfet

CC @ngimel @ptrblck
Pull Request resolved: https://github.com/pytorch/pytorch/pull/78299
Approved by: https://github.com/ngimel
2022-07-07 23:25:23 +00:00