[core] fix sleep mode in pytorch 2.6 (#13456)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2025-02-18 13:48:10 +08:00 (committed by GitHub)
Parent: a1074b3efe
Commit: ac19b519ed


@@ -9,7 +9,7 @@
 # the only successful approach is to call cuda driver API in C.
 import dataclasses
 from contextlib import contextmanager
-from typing import Callable, Dict, Optional, Tuple, Union
+from typing import Any, Callable, Dict, Optional, Tuple, Union

 import torch
@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
     new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
     mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
     with torch.cuda.memory.use_mem_pool(mem_pool):
-        yield mem_pool
+        yield mem_pool, new_alloc


 class CuMemAllocator:
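For context on how the new return shape is consumed, here is a minimal, hypothetical caller sketch (the import path, callback names, and callback bodies below are assumptions for illustration, not part of this commit): the context manager now yields a (MemPool, allocator) tuple rather than the pool alone, so any caller unpacking the yielded value must take both.

import torch

# module path assumed to match the file this commit touches
from vllm.device_allocator.cumem import use_memory_pool_with_allocator

def noop_malloc(*args):  # callback signature elided; illustrative only
    pass

def noop_free(*args):  # callback signature elided; illustrative only
    pass

with use_memory_pool_with_allocator(noop_malloc, noop_free) as (pool, alloc):
    # CUDA allocations inside this block are routed through `pool`;
    # holding `alloc` alongside `pool` keeps the pluggable allocator alive.
    t = torch.empty(16, device="cuda")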
@@ -142,6 +142,7 @@ class CuMemAllocator:
     def __init__(self):
         self.pointer_to_data: Dict[int, AllocationData] = {}
         self.current_tag: str = CuMemAllocator.default_tag
+        self.allocator_and_pools: Dict[str, Any] = {}

     def python_malloc_callback(self, allocation_handle: HandleType) -> None:
         """
@ -231,7 +232,13 @@ class CuMemAllocator:
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback):
self.python_free_callback) as data:
# start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and
# the memory pool.
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see
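Stepping back from the diff, the essence of the workaround is to pin the yielded (pool, allocator) pair in a long-lived container so the Python garbage collector cannot reclaim them while pool-backed memory may still be in use. A standalone sketch of the same keep-alive pattern, with all names hypothetical:

# Hypothetical sketch of the keep-alive pattern used in this commit.
_keep_alive = {}  # tag -> (mem_pool, allocator); prevents premature gc

def allocate_under_tag(tag, malloc_cb, free_cb):
    with use_memory_pool_with_allocator(malloc_cb, free_cb) as data:
        # Pin the pool/allocator pair; dropping this reference can crash
        # PyTorch 2.6 (https://github.com/pytorch/pytorch/issues/146431).
        _keep_alive[tag] = data
        # ... allocate tensors under this tag ...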