mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[core] fix sleep mode in pytorch 2.6 (#13456)
Signed-off-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@ -9,7 +9,7 @@
|
||||
# the only successful approach is to call cuda driver API in C.
|
||||
import dataclasses
|
||||
from contextlib import contextmanager
|
||||
from typing import Callable, Dict, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
|
||||
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
|
||||
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
|
||||
with torch.cuda.memory.use_mem_pool(mem_pool):
|
||||
yield mem_pool
|
||||
yield mem_pool, new_alloc
|
||||
|
||||
|
||||
class CuMemAllocator:
|
||||
@ -142,6 +142,7 @@ class CuMemAllocator:
|
||||
def __init__(self):
|
||||
self.pointer_to_data: Dict[int, AllocationData] = {}
|
||||
self.current_tag: str = CuMemAllocator.default_tag
|
||||
self.allocator_and_pools: Dict[str, Any] = {}
|
||||
|
||||
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
|
||||
"""
|
||||
@ -231,7 +232,13 @@ class CuMemAllocator:
|
||||
old_tag = self.current_tag
|
||||
self.current_tag = tag
|
||||
with use_memory_pool_with_allocator(self.python_malloc_callback,
|
||||
self.python_free_callback):
|
||||
self.python_free_callback) as data:
|
||||
# start to hit another PyTorch bug in PyTorch 2.6,
|
||||
# possibly because of gc-related issue w.r.t. the allocator and
|
||||
# the memory pool.
|
||||
# to avoid the issue, we keep a reference of the data.
|
||||
# see https://github.com/pytorch/pytorch/issues/146431 .
|
||||
self.allocator_and_pools[tag] = data
|
||||
yield
|
||||
# PyTorch's bug, calling torch.cuda.empty_cache() will error
|
||||
# when using pluggable allocator, see
|
||||
|
Reference in New Issue
Block a user