Compare commits

2 Commits

936da0f740  update
            Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
            2025-09-19 23:30:15 +00:00

20098c10d9  Remove global CUDA graph pool
            Signed-off-by: Woosuk Kwon <woosuk@thinkingmachines.ai>
            2025-09-19 23:27:51 +00:00
3 changed files with 2 additions and 13 deletions

@@ -82,7 +82,7 @@ class CUDAGraphWrapper:
         # TODO: in the future, if we want to use multiple
         # streams, it might not be safe to share a global pool.
         # only investigate this when we use multiple streams
-        self.graph_pool = current_platform.get_global_graph_pool()
+        self.graph_pool = current_platform.graph_pool_handle()
 
         if cudagraph_options is None:
             cudagraph_options = CUDAGraphOptions()
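
For context on what the wrapper now stores: on the CUDA backend a graph pool handle is an opaque token from PyTorch that is passed at capture time so that graphs share (or deliberately do not share) a memory pool. A minimal sketch in plain PyTorch, assuming a CUDA device (the mapping from graph_pool_handle() to torch.cuda.graph_pool_handle() is an inference from the method name, not shown in this diff):

import torch

# An opaque token naming a CUDA graph memory pool.
pool = torch.cuda.graph_pool_handle()

# Pass the handle at capture time; graphs captured with the same
# handle allocate from the same pool.
static_in = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g, pool=pool):
    static_out = static_in * 2

static_in.fill_(3.0)
g.replay()  # static_out now holds 6.0 in every element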

@@ -140,8 +140,6 @@ class Platform:
     additional_env_vars: list[str] = []
 
-    _global_graph_pool: Optional[Any] = None
-
     @property
     def supported_dtypes(self) -> list[torch.dtype]:
         """Returns the supported dtypes for the current platform."""
@@ -535,15 +533,6 @@ class Platform:
             " attribute.", self.device_type, key)
         return None
 
-    def get_global_graph_pool(self) -> Any:
-        """
-        Return the global graph pool for this platform.
-        """
-        cls = self.__class__
-        if cls._global_graph_pool is None:
-            cls._global_graph_pool = self.graph_pool_handle()
-        return cls._global_graph_pool
-
     @classmethod
     def get_cu_count(cls, device_id: int = 0) -> int:
         """

@@ -54,7 +54,7 @@ class UBatchWrapper:
         if runtime_mode is not CUDAGraphMode.NONE:
             self.cudagraph_wrapper = CUDAGraphWrapper(
                 runnable, vllm_config, runtime_mode=runtime_mode)
-        self.graph_pool = current_platform.get_global_graph_pool()
+        self.graph_pool = current_platform.graph_pool_handle()
 
     def __getattr__(self, key: str):
         # allow accessing the attributes of the runnable.
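
Net effect of the two call-site changes above: CUDAGraphWrapper and UBatchWrapper now each obtain a pool handle directly from the platform at construction time, instead of sharing the one handle previously cached process-wide via the removed _global_graph_pool attribute and get_global_graph_pool() helper.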