Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
make use_mem_pool threadlocal (#153356)
Partial fix for #152861: makes allocation to a pool thread-local, but does not address the second bug, where multiple threads allocating to multiple pools raise an error.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153356
Approved by: https://github.com/Skylion007, https://github.com/eellison
Committed by: PyTorch MergeBot
Parent: d5d26ce436
Commit: 0cf61ca7e4
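The behavior this commit targets can be illustrated with a short sketch that mirrors the test added below (this is an illustration, not part of the commit; it requires a CUDA device). With use_mem_pool made thread-local, an allocation made on a second thread while another thread is inside the context manager is served from the default allocator pool rather than the custom pool:

    import threading
    import torch

    pool = torch.cuda.MemPool()          # custom allocator pool
    in_ctx = threading.Event()
    done = threading.Event()

    def owner():
        # Only allocations made on *this* thread are routed to `pool`.
        with torch.cuda.use_mem_pool(pool):
            x = torch.empty(1024 * 1024, dtype=torch.int8, device="cuda")
            in_ctx.set()
            done.wait()

    def bystander():
        in_ctx.wait()
        # Allocated while the owner thread is still inside use_mem_pool,
        # but on a different thread, so it lands in the default pool.
        y = torch.empty(1024 * 1024, dtype=torch.int8, device="cuda")
        done.set()

    t1 = threading.Thread(target=owner)
    t2 = threading.Thread(target=bystander)
    t1.start(); t2.start()
    t1.join(); t2.join()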
@@ -14,6 +14,7 @@ import tempfile
 import threading
 import unittest
 import warnings
+from collections import defaultdict
 from copy import deepcopy
 from itertools import product
 from random import randint
@@ -5306,6 +5307,63 @@ class TestMemPool(TestCase):
         out_0 = torch.randn(nelem_1mb, device="cuda")
         torch.cuda.memory._set_allocator_settings("expandable_segments:False")
 
+    def test_mempool_ctx_multithread(self):
+        torch.cuda.empty_cache()
+        segments = torch.cuda.memory._snapshot()["segments"]
+        self.assertEqual(len(segments), 0, "Expected empty pool in the beginning")
+
+        nelem = 1024 * 1024
+        trigger_alloc = threading.Event()
+        done_allocation = threading.Event()
+
+        def main_thread_fn():
+            pool = torch.cuda.MemPool()
+            out1 = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            with torch.cuda.use_mem_pool(pool):
+                out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+                del out
+                trigger_alloc.set()
+                done_allocation.wait()
+
+        def side_thread_fn(segments):
+            trigger_alloc.wait()
+            out = torch.empty(nelem, dtype=torch.int8, device="cuda")
+            s = torch.cuda.memory._snapshot()["segments"]
+            segments.append(s)
+            done_allocation.set()
+
+        segments = []
+        main_thread = threading.Thread(target=main_thread_fn)
+        side_thread = threading.Thread(target=side_thread_fn, args=(segments,))
+
+        main_thread.start()
+        side_thread.start()
+        main_thread.join(timeout=10)
+        side_thread.join(timeout=10)
+
+        if main_thread.is_alive() or side_thread.is_alive():
+            # release threads so that they don't hang forever
+            trigger_alloc.set()
+            done_allocation.set()
+            self.fail(
+                "Test timed out - threads did not complete within the allowed time"
+            )
+
+        self.assertEqual(len(segments), 1, "Expected to have memory snapshot")
+        self.assertEqual(len(segments[0]), 2, "Expected to have 2 segments allocated")
+        active = defaultdict(int)
+        for s in segments[0]:
+            active[s["segment_pool_id"]] += s["active_size"]
+        for k, v in active.items():
+            if k == (0, 0):
+                self.assertEqual(
+                    v, 2097152, "Expected to have 2MB allocated in the default pool"
+                )
+            else:
+                self.assertEqual(
+                    v, 0, "Expected to have 0 bytes allocated in the custom pool"
+                )
+
 
 @unittest.skipIf(not TEST_CUDA, "CUDA not available, skipping tests")
 @torch.testing._internal.common_utils.markDynamoStrictTest
@@ -77,12 +77,16 @@ if not hasattr(torch._C, "_MemPool"):
     torch._C.__dict__["_cuda_beginAllocateToPool"] = _dummy_type(
         "_cuda_beginAllocateToPool"
     )
+    torch._C.__dict__["_cuda_beginAllocateCurrentThreadToPool"] = _dummy_type(
+        "_cuda_beginAllocateCurrentThreadToPool"
+    )
     torch._C.__dict__["_cuda_endAllocateToPool"] = _dummy_type(
         "_cuda_endAllocateToPool"
    )
     torch._C.__dict__["_cuda_releasePool"] = _dummy_type("_cuda_releasePool")
 
 from torch._C import (  # noqa: F401
+    _cuda_beginAllocateCurrentThreadToPool,
     _cuda_beginAllocateToPool,
     _cuda_CUDAAllocator,
     _cuda_endAllocateToPool,
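For context (not part of this diff): the _dummy_type calls above register placeholder symbols so that torch.cuda.memory can still be imported on builds without the CUDA allocator bindings; the placeholders only fail when actually used. A rough illustrative sketch of that pattern, not the actual torch._utils implementation:

    def _dummy_type(name: str) -> type:
        # Placeholder class: importable and referenceable, but any attempt
        # to construct it raises a clear error.
        def _err(self, *args, **kwargs):
            raise RuntimeError(
                f"Tried to instantiate dummy class {name}; "
                "this build lacks the corresponding CUDA bindings"
            )
        return type(name, (object,), {"__init__": _err})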
@@ -1192,12 +1196,17 @@ def use_mem_pool(pool: MemPool, device: "Device" = None):
         the current device, given by :func:`~torch.cuda.current_device`,
         if :attr:`device` is ``None`` (default).
 
+    .. note::
+        This context manager makes only current thread's allocations route to
+        the given pool. If a new thread is spawned inside the context manager
+        (e.g. by calling backward) the allocations in that thread will not
+        route to the given pool.
     """
     ctx = MemPoolContext(pool)
     device_index = (
         torch.cuda.current_device() if device is None else _get_device_index(device)
     )
-    _cuda_beginAllocateToPool(device_index, pool.id)
+    _cuda_beginAllocateCurrentThreadToPool(device_index, pool.id)
     try:
         yield
     finally:
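The note added to the docstring has a practical consequence worth spelling out with a hedged usage sketch (illustrative, not from the commit; requires a CUDA device): allocations made by threads spawned inside the context, such as the autograd engine's worker threads during backward(), are not routed to the pool; only allocations issued by the thread that entered the context are.

    import torch

    pool = torch.cuda.MemPool()
    x = torch.randn(1024, 1024, device="cuda", requires_grad=True)

    with torch.cuda.use_mem_pool(pool):
        # Allocated on the current thread -> served from `pool`.
        y = (x * 2).sum()
        # backward() may run on an autograd worker thread; with the
        # thread-local behavior, its allocations (e.g. the gradient of x)
        # are NOT routed to `pool`.
        y.backward()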