diff --git a/test/test_cuda.py b/test/test_cuda.py
index ec84b9fb1708..7d81b3a92c26 100644
--- a/test/test_cuda.py
+++ b/test/test_cuda.py
@@ -14,6 +14,7 @@ import tempfile
 import threading
 import unittest
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from itertools import product
 from random import randint
@@ -5306,6 +5307,63 @@ class TestMemPool(TestCase):
             out_0 = torch.randn(nelem_1mb, device="cuda")
             torch.cuda.memory._set_allocator_settings("expandable_segments:False")

+    def test_mempool_ctx_multithread(self):
+        torch.cuda.empty_cache()
+        segments = torch.cuda.memory._snapshot()["segments"]
+        self.assertEqual(len(segments), 0, "Expected empty pool in the beginning")
+
+        nelem = 1024 * 1024
+        trigger_alloc = threading.Event()
+        done_allocation = threading.Event()
+
+        def main_thread_fn():
+            pool = torch.cuda.MemPool()
+            out1 = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            with torch.cuda.use_mem_pool(pool):
+                out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+                del out
+                trigger_alloc.set()
+                done_allocation.wait()
+
+        def side_thread_fn(segments):
+            trigger_alloc.wait()
+            out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            s = torch.cuda.memory._snapshot()["segments"]
+            segments.append(s)
+            done_allocation.set()
+
+        segments = []
+        main_thread = threading.Thread(target=main_thread_fn)
+        side_thread = threading.Thread(target=side_thread_fn, args=(segments,))
+
+        main_thread.start()
+        side_thread.start()
+        main_thread.join(timeout=10)
+        side_thread.join(timeout=10)
+
+        if main_thread.is_alive() or side_thread.is_alive():
+            # release threads so that they don't hang forever
+            trigger_alloc.set()
+            done_allocation.set()
+            self.fail(
+                "Test timed out - threads did not complete within the allowed time"
+            )
+
+        self.assertEqual(len(segments), 1, "Expected to have memory snapshot")
+        self.assertEqual(len(segments[0]), 2, "Expected to have 2 segments allocated")
+        active = defaultdict(int)
+        for s in segments[0]:
+            active[s["segment_pool_id"]] += s["active_size"]
+        for k, v in active.items():
+            if k == (0, 0):
+                self.assertEqual(
+                    v, 2097152, "Expected to have 2MB allocated in the default pool"
+                )
+            else:
+                self.assertEqual(
+                    v, 0, "Expected to have 0 bytes allocated in the custom pool"
+                )
+

 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
diff --git a/torch/cuda/memory.py b/torch/cuda/memory.py
index 71093e8039ca..5c5f6a6b5118 100644
--- a/torch/cuda/memory.py
+++ b/torch/cuda/memory.py
@@ -77,12 +77,16 @@ if not hasattr(torch._C, "_MemPool"):
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
+    torch._C.__dict__["_cuda_beginAllocateCurrentThreadToPool"] = _dummy_type(
+        "_cuda_beginAllocateCurrentThreadToPool"
+    )
     torch._C.__dict__["_cuda_endAllocateToPool"] = _dummy_type(
         "_cuda_endAllocateToPool"
     )
     torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool")

 from torch._C import (  # noqa: F401
+    _cuda_beginAllocateCurrentThreadToPool,
     _cuda_beginAllocateToPool,
     _cuda_CUDAAllocator,
     _cuda_endAllocateToPool,
@@ -1192,12 +1196,17 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
         the current device, given by :func:`~torch.cuda.current_device`,
         if :attr:`device` is ``None`` (default).

+    .. note::
+        This context manager routes only the current thread's allocations to
+        the given pool. If a new thread is spawned inside the context manager
+        (e.g. by calling ``backward``), allocations made in that thread will
+        not be routed to the given pool.
     """
     ctx = MemPoolContext(pool)
     device_index = (
         torch.cuda.current_device() if device is None else _get_device_index(device)
     )
-    _cuda_beginAllocateToPool(device_index, pool.id)
+    _cuda_beginAllocateCurrentThreadToPool(device_index, pool.id)
     try:
         yield
     finally:
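
For context, the sketch below illustrates the per-thread routing documented by the new docstring note: an allocation made by the thread that entered use_mem_pool() goes to the custom pool, while an allocation made by a thread spawned inside the context goes to the default pool. This is a minimal sketch assuming a CUDA-capable build; the worker helper is illustrative, not part of the change.

import threading

import torch


def worker():
    # Runs in a freshly spawned thread: use_mem_pool() is active only for
    # the thread that entered the context, so this allocation is served
    # from the default pool, not from `pool`.
    t = torch.empty(1024 * 1024, dtype=torch.int8, device="cuda")
    del t


pool = torch.cuda.MemPool()
with torch.cuda.use_mem_pool(pool):
    # Allocated by the current thread -> routed to `pool`.
    x = torch.empty(1024 * 1024, dtype=torch.int8, device="cuda")

    # Allocated by a side thread inside the same context -> default pool.
    side = threading.Thread(target=worker)
    side.start()
    side.join()

# Segments whose segment_pool_id is (0, 0) belong to the default pool;
# segments tagged with pool.id belong to the custom pool.
for seg in torch.cuda.memory._snapshot()["segments"]:
    print(seg["segment_pool_id"], seg["active_size"])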