make use_mem_pool threadlocal (#153356)

Partial fix for #152861: makes allocation to a pool thread-local, but does not address the second bug, where multiple threads allocating to multiple pools raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153356
Approved by: https://github.com/Skylion007, https://github.com/eellison
Natalia Gimelshein
2025-05-13 00:16:02 +00:00
committed by PyTorch MergeBot
parent d5d26ce436
commit 0cf61ca7e4
2 changed files with 68 additions and 1 deletion
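A minimal sketch of the behavior this change establishes (assumes a CUDA build; the helper name other_thread_alloc is illustrative and not part of the PR): only the thread that entered use_mem_pool routes its allocations to the custom pool, while allocations made concurrently on other threads keep going to the default pool.

    import threading

    import torch

    pool = torch.cuda.MemPool()

    def other_thread_alloc():
        # Runs while the main thread is inside use_mem_pool; with this change
        # the allocation lands in the default pool, not in `pool`.
        torch.empty(1024, device="cuda")

    with torch.cuda.use_mem_pool(pool):
        x = torch.empty(1024, device="cuda")  # current thread: routed to `pool`
        side = threading.Thread(target=other_thread_alloc)
        side.start()
        side.join()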


@@ -14,6 +14,7 @@ import tempfile
 import threading
 import unittest
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from itertools import product
 from random import randint
@@ -5306,6 +5307,63 @@ class TestMemPool(TestCase):
             out_0 = torch.randn(nelem_1mb, device="cuda")

         torch.cuda.memory._set_allocator_settings("expandable_segments:False")

+    def test_mempool_ctx_multithread(self):
+        torch.cuda.empty_cache()
+        segments = torch.cuda.memory._snapshot()["segments"]
+        self.assertEqual(len(segments), 0, "Expected empty pool in the beginning")
+
+        nelem = 1024 * 1024
+        trigger_alloc = threading.Event()
+        done_allocation = threading.Event()
+
+        def main_thread_fn():
+            pool = torch.cuda.MemPool()
+            out1 = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            with torch.cuda.use_mem_pool(pool):
+                out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+                del out
+                trigger_alloc.set()
+                done_allocation.wait()
+
+        def side_thread_fn(segments):
+            trigger_alloc.wait()
+            out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            s = torch.cuda.memory._snapshot()["segments"]
+            segments.append(s)
+            done_allocation.set()
+
+        segments = []
+        main_thread = threading.Thread(target=main_thread_fn)
+        side_thread = threading.Thread(target=side_thread_fn, args=(segments,))
+        main_thread.start()
+        side_thread.start()
+        main_thread.join(timeout=10)
+        side_thread.join(timeout=10)
+
+        if main_thread.is_alive() or side_thread.is_alive():
+            # release threads so that they don't hang forever
+            trigger_alloc.set()
+            done_allocation.set()
+            self.fail(
+                "Test timed out - threads did not complete within the allowed time"
+            )
+
+        self.assertEqual(len(segments), 1, "Expected to have memory snapshot")
+        self.assertEqual(len(segments[0]), 2, "Expected to have 2 segments allocated")
+        active = defaultdict(int)
+        for s in segments[0]:
+            active[s["segment_pool_id"]] += s["active_size"]
+        for k, v in active.items():
+            if k == (0, 0):
+                self.assertEqual(
+                    v, 2097152, "Expected to have 2MB allocated in the default pool"
+                )
+            else:
+                self.assertEqual(
+                    v, 0, "Expected to have 0 bytes allocated in the custom pool"
+                )
+
 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest


@@ -77,12 +77,16 @@ if not hasattr(torch._C, "_MemPool"):
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
+    torch._C.__dict__["_cuda_beginAllocateCurrentThreadToPool"] = _dummy_type(
+        "_cuda_beginAllocateCurrentThreadToPool"
+    )
     torch._C.__dict__["_cuda_endAllocateToPool"] = _dummy_type(
         "_cuda_endAllocateToPool"
     )
     torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool")

 from torch._C import (  # noqa: F401
+    _cuda_beginAllocateCurrentThreadToPool,
     _cuda_beginAllocateToPool,
     _cuda_CUDAAllocator,
     _cuda_endAllocateToPool,
@@ -1192,12 +1196,17 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
             the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).

+    .. note::
+        This context manager makes only the current thread's allocations route to
+        the given pool. If a new thread is spawned inside the context manager
+        (e.g. by calling backward) the allocations in that thread will not
+        route to the given pool.
     """
     ctx = MemPoolContext(pool)
     device_index = (
         torch.cuda.current_device() if device is None else _get_device_index(device)
     )
-    _cuda_beginAllocateToPool(device_index, pool.id)
+    _cuda_beginAllocateCurrentThreadToPool(device_index, pool.id)
     try:
         yield
     finally:
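The docstring note added above has a practical consequence worth illustrating. A hedged sketch (CUDA build assumed; shapes are arbitrary): allocations made by the autograd engine during backward() happen on its own worker thread, so with thread-local routing they fall back to the default pool rather than the active MemPool.

    import torch

    pool = torch.cuda.MemPool()
    x = torch.randn(1024, device="cuda", requires_grad=True)

    with torch.cuda.use_mem_pool(pool):
        y = (x * x).sum()  # forward allocations on this thread go to `pool`
        y.backward()       # backward runs on the autograd engine's thread, so
                           # its allocations go to the default pool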