make use_mem_pool threadlocal (#153356)
Partial fix for #152861: makes allocation to a pool thread-local, but does not address the second bug, where multiple threads allocating to multiple pools raise an error.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153356
Approved by: https://github.com/Skylion007, https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: d5d26ce436
Commit: 0cf61ca7e4
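The behavior this change establishes is easiest to see with a small example. The following is a minimal sketch condensed from the test added below (it is not part of the commit), assuming a CUDA device and a build that includes this change: only the thread that entered use_mem_pool has its allocations routed to the custom pool, while a second thread keeps allocating from the default pool, which the memory snapshot identifies as segment_pool_id == (0, 0).

import threading

import torch

pool = torch.cuda.MemPool()
keep = []  # keep the side thread's tensor alive so its segment stays active

def side_thread_fn():
    # Runs while the main thread is still inside use_mem_pool; with the
    # thread-local routing this allocation stays in the default pool.
    keep.append(torch.empty(1024 * 1024, dtype=torch.int8, device="cuda"))

with torch.cuda.use_mem_pool(pool):
    x = torch.empty(1024 * 1024, dtype=torch.int8, device="cuda")  # routed to pool
    t = threading.Thread(target=side_thread_fn)
    t.start()
    t.join()

# Each snapshot segment reports the pool it belongs to; (0, 0) is the default pool.
for seg in torch.cuda.memory._snapshot()["segments"]:
    print(seg["segment_pool_id"], seg["active_size"])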
test/test_cuda.py

@@ -14,6 +14,7 @@ import tempfile
+import threading
 import unittest
 import warnings
 from collections import defaultdict
 from copy import deepcopy
 from itertools import product
 from random import randint
@@ -5306,6 +5307,63 @@ class TestMemPool(TestCase):
         out_0 = torch.randn(nelem_1mb, device="cuda")
         torch.cuda.memory._set_allocator_settings("expandable_segments:False")
 
+    def test_mempool_ctx_multithread(self):
+        torch.cuda.empty_cache()
+        segments = torch.cuda.memory._snapshot()["segments"]
+        self.assertEqual(len(segments), 0, "Expected empty pool in the beginning")
+
+        nelem = 1024 * 1024
+        trigger_alloc = threading.Event()
+        done_allocation = threading.Event()
+
+        def main_thread_fn():
+            pool = torch.cuda.MemPool()
+            out1 = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            with torch.cuda.use_mem_pool(pool):
+                out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+                del out
+                trigger_alloc.set()
+                done_allocation.wait()
+
+        def side_thread_fn(segments):
+            trigger_alloc.wait()
+            out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            s = torch.cuda.memory._snapshot()["segments"]
+            segments.append(s)
+            done_allocation.set()
+
+        segments = []
+        main_thread = threading.Thread(target=main_thread_fn)
+        side_thread = threading.Thread(target=side_thread_fn, args=(segments,))
+
+        main_thread.start()
+        side_thread.start()
+        main_thread.join(timeout=10)
+        side_thread.join(timeout=10)
+
+        if main_thread.is_alive() or side_thread.is_alive():
+            # release threads so that they don't hang forever
+            trigger_alloc.set()
+            done_allocation.set()
+            self.fail(
+                "Test timed out - threads did not complete within the allowed time"
+            )
+
+        self.assertEqual(len(segments), 1, "Expected to have memory snapshot")
+        self.assertEqual(len(segments[0]), 2, "Expected to have 2 segments allocated")
+        active = defaultdict(int)
+        for s in segments[0]:
+            active[s["segment_pool_id"]] += s["active_size"]
+        for k, v in active.items():
+            if k == (0, 0):
+                self.assertEqual(
+                    v, 2097152, "Expected to have 2MB allocated in the default pool"
+                )
+            else:
+                self.assertEqual(
+                    v, 0, "Expected to have 0 bytes allocated in the custom pool"
+                )
+
 
 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
torch/cuda/memory.py

@@ -77,12 +77,16 @@ if not hasattr(torch._C, "_MemPool"):
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
+    torch._C.__dict__["_cuda_beginAllocateCurrentThreadToPool"] = _dummy_type(
+        "_cuda_beginAllocateCurrentThreadToPool"
+    )
     torch._C.__dict__["_cuda_endAllocateToPool"] = _dummy_type(
         "_cuda_endAllocateToPool"
     )
     torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool")
 
 from torch._C import (  # noqa: F401
+    _cuda_beginAllocateCurrentThreadToPool,
     _cuda_beginAllocateToPool,
     _cuda_CUDAAllocator,
     _cuda_endAllocateToPool,
@@ -1192,12 +1196,17 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
             the current device, given by :func:`~torch.cuda.current_device`,
             if :attr:`device` is ``None`` (default).
 
+    .. note::
+        This context manager makes only current thread's allocations route to
+        the given pool. If a new thread is spawned inside the context manager
+        (e.g. by calling backward) the allocations in that thread will not
+        route to the given pool.
     """
     ctx = MemPoolContext(pool)
     device_index = (
         torch.cuda.current_device() if device is None else _get_device_index(device)
     )
-    _cuda_beginAllocateToPool(device_index, pool.id)
+    _cuda_beginAllocateCurrentThreadToPool(device_index, pool.id)
     try:
         yield
     finally:
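The note added to the docstring above has a practical consequence worth spelling out: autograd's backward pass runs on its own worker threads, so allocations it makes are not routed to the pool even when backward() is called inside the context manager. A minimal sketch of that caveat (not part of the diff), assuming a CUDA device and a build that includes this change:

import torch

pool = torch.cuda.MemPool()
x = torch.randn(1024, 1024, device="cuda", requires_grad=True)

with torch.cuda.use_mem_pool(pool):
    # Forward-pass temporaries are allocated by this thread, so they come from the pool.
    loss = (x * x).sum()
    # backward() executes on autograd worker threads; their allocations (e.g. x.grad)
    # are not guaranteed to come from the pool, which is the caveat in the note above.
    loss.backward()

print(x.grad.is_cuda)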