make use_mem_pool threadlocal (#153356)

Partial fix for #152861: makes allocation to a pool thread-local, but does not address the second bug, where multiple threads allocating to multiple pools raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153356
Approved by: https://github.com/Skylion007, https://github.com/eellison
Natalia Gimelshein
2025-05-13 00:16:02 +00:00
committed by PyTorch MergeBot
parent d5d26ce436
commit 0cf61ca7e4
2 changed files with 68 additions and 1 deletion
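
To illustrate the behavior this change establishes, here is a minimal usage sketch (not part of the commit). It assumes a CUDA-enabled PyTorch build that includes this patch and uses torch.cuda.MemPool and torch.cuda.use_mem_pool; only allocations made on the thread that entered the context manager are routed to the pool, while the helper thread's allocation stays in the default pool.

import threading

import torch

pool = torch.cuda.MemPool()
results = []

def helper_alloc(results):
    # This thread never entered use_mem_pool, so with this change its
    # allocation goes to the default pool even while the main thread's
    # context manager is still active.
    results.append(torch.empty(1024 * 1024, dtype=torch.int8, device="cuda"))

with torch.cuda.use_mem_pool(pool):
    # Allocated on the thread that entered the context: routed to `pool`.
    routed = torch.empty(1024 * 1024, dtype=torch.int8, device="cuda")

    t = threading.Thread(target=helper_alloc, args=(results,))
    t.start()
    t.join()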


@@ -14,6 +14,7 @@ import tempfile
 import threading
 import unittest
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from itertools import product
 from random import randint
@@ -5306,6 +5307,63 @@ class TestMemPool(TestCase):
         out_0 = torch.randn(nelem_1mb, device="cuda")
         torch.cuda.memory._set_allocator_settings("expandable_segments:False")

+    def test_mempool_ctx_multithread(self):
+        torch.cuda.empty_cache()
+        segments = torch.cuda.memory._snapshot()["segments"]
+        self.assertEqual(len(segments), 0, "Expected empty pool in the beginning")
+
+        nelem = 1024 * 1024
+        trigger_alloc = threading.Event()
+        done_allocation = threading.Event()
+
+        def main_thread_fn():
+            pool = torch.cuda.MemPool()
+            out1 = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            with torch.cuda.use_mem_pool(pool):
+                out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+                del out
+                trigger_alloc.set()
+                done_allocation.wait()
+
+        def side_thread_fn(segments):
+            trigger_alloc.wait()
+            out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            s = torch.cuda.memory._snapshot()["segments"]
+            segments.append(s)
+            done_allocation.set()
+
+        segments = []
+        main_thread = threading.Thread(target=main_thread_fn)
+        side_thread = threading.Thread(target=side_thread_fn, args=(segments,))
+        main_thread.start()
+        side_thread.start()
+        main_thread.join(timeout=10)
+        side_thread.join(timeout=10)
+
+        if main_thread.is_alive() or side_thread.is_alive():
+            # release threads so that they don't hang forever
+            trigger_alloc.set()
+            done_allocation.set()
+            self.fail(
+                "Test timed out - threads did not complete within the allowed time"
+            )
+
+        self.assertEqual(len(segments), 1, "Expected to have memory snapshot")
+        self.assertEqual(len(segments[0]), 2, "Expected to have 2 segments allocated")
+        active = defaultdict(int)
+        for s in segments[0]:
+            active[s["segment_pool_id"]] += s["active_size"]
+
+        for k, v in active.items():
+            if k == (0, 0):
+                self.assertEqual(
+                    v, 2097152, "Expected to have 2MB allocated in the default pool"
+                )
+            else:
+                self.assertEqual(
+                    v, 0, "Expected to have 0 bytes allocated in the custom pool"
+                )
+

 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
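
The assertions at the end of the new test encode the expected accounting: out1 on the main thread and out on the side thread are each 1 MiB of int8, and both land in the default pool (segment_pool_id (0, 0)), giving 2 * 1048576 = 2097152 active bytes there. The tensor allocated inside use_mem_pool was deleted before the side thread ran, so the custom pool's segment is present but reports 0 active bytes.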


@@ -77,12 +77,16 @@ if not hasattr(torch._C, "_MemPool"):
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
+    torch._C.__dict__["_cuda_beginAllocateCurrentThreadToPool"] = _dummy_type(
+        "_cuda_beginAllocateCurrentThreadToPool"
+    )
     torch._C.__dict__["_cuda_endAllocateToPool"] = _dummy_type(
         "_cuda_endAllocateToPool"
     )
     torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool")

 from torch._C import (  # noqa: F401
+    _cuda_beginAllocateCurrentThreadToPool,
     _cuda_beginAllocateToPool,
     _cuda_CUDAAllocator,
     _cuda_endAllocateToPool,
@@ -1192,12 +1196,17 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
             the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).

+    .. note::
+        This context manager makes only current thread's allocations route to
+        the given pool. If a new thread is spawned inside the context manager
+        (e.g. by calling backward) the allocations in that thread will not
+        route to the given pool.
     """
     ctx = MemPoolContext(pool)
     device_index = (
         torch.cuda.current_device() if device is None else _get_device_index(device)
     )
-    _cuda_beginAllocateToPool(device_index, pool.id)
+    _cuda_beginAllocateCurrentThreadToPool(device_index, pool.id)
     try:
         yield
     finally:
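
The note added to the docstring states the semantics; the enforcement lives in the CUDA caching allocator behind _cuda_beginAllocateCurrentThreadToPool. As a rough, hypothetical analogy only (not PyTorch's actual implementation), thread-local routing state behaves like the following sketch: each thread consults its own value, so entering the context on one thread does not redirect allocations made on another.

import threading
from contextlib import contextmanager

# Hypothetical illustration of thread-local routing, not PyTorch code.
_routing = threading.local()

@contextmanager
def route_to_pool(pool_id):
    # Only the calling thread's routing target is changed.
    previous = getattr(_routing, "pool", None)
    _routing.pool = pool_id
    try:
        yield
    finally:
        _routing.pool = previous

def current_pool():
    # An allocator consulting this state sees only the calling thread's
    # value; other threads keep allocating to the default pool (None).
    return getattr(_routing, "pool", None)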