Remove graph_pool as member of VllmBackend and argument to CUDAGraphWrapper (#23385)

Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
Copilot
2025-08-25 19:34:15 -07:00
committed by GitHub
parent 6fd45e7b8a
commit 6fad29b11b
3 changed files with 7 additions and 20 deletions

View File

@@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
def __init__(self, module: torch.fx.GraphModule,
compile_submod_names: list[str], vllm_config: VllmConfig,
graph_pool, vllm_backend: "VllmBackend"):
vllm_backend: "VllmBackend"):
super().__init__(module)
from torch._guards import detect_fake_mode
self.fake_mode = detect_fake_mode()
self.compile_submod_names = compile_submod_names
self.compilation_config = vllm_config.compilation_config
self.graph_pool = graph_pool
self.vllm_config = vllm_config
self.vllm_backend = vllm_backend
# When True, it annoyingly dumps the torch.fx.Graph on errors.
@@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
runnable=piecewise_backend,
vllm_config=self.vllm_config,
runtime_mode=CUDAGraphMode.PIECEWISE,
graph_pool=self.graph_pool,
cudagraph_options=CUDAGraphOptions(
debug_log_enable=piecewise_backend.is_first_graph,
gc_disable=not piecewise_backend.is_first_graph,
@@ -405,7 +403,6 @@ class VllmBackend:
vllm_config: VllmConfig
compilation_config: CompilationConfig
graph_pool: Any
_called: bool = False
# the graph we compiled
graph: fx.GraphModule
@@ -433,13 +430,6 @@ class VllmBackend:
# them, e.g. backbone (default), eagle_head, etc.
self.prefix = prefix or model_tag
global_graph_pool = current_platform.get_global_graph_pool()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = global_graph_pool
# Passes to run on the graph post-grad.
self.post_grad_pass_manager = PostGradPassManager()
@@ -586,7 +576,7 @@ class VllmBackend:
# propagate the split graph to the piecewise backend,
# compile submodules with symbolic shapes
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
self.vllm_config, self.graph_pool,
self.vllm_config,
self).run(*example_inputs)
graph_path = os.path.join(local_cache_dir, "computation_graph.py")

View File

@@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
"""
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
runtime_mode: CUDAGraphMode, **kwargs):
"""
Initializes the StaticGraphWrapper class with graph capturing and
execution-related configurations.
@@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol):
graph runtime. See CUDAGraphMode in vllm/config.py.
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
are used as concrete runtime mode for cudagraph dispatching.
graph_pool (Any):
Graph memory pool handle, e.g.,
`torch.cuda.graph_pool_handle()`.
Keyword Args:
kwargs: Additional keyword arguments for platform-specific
configurations.

View File

@@ -67,11 +67,9 @@ class CUDAGraphWrapper:
runnable: Callable,
vllm_config: VllmConfig,
runtime_mode: CUDAGraphMode,
graph_pool: Any = None,
cudagraph_options: Optional[CUDAGraphOptions] = None):
self.runnable = runnable
self.vllm_config = vllm_config
self.graph_pool = graph_pool
self.runtime_mode = runtime_mode
self.compilation_config = vllm_config.compilation_config
@@ -81,8 +79,10 @@ class CUDAGraphWrapper:
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
# need to initialize a CUDAGraphWrapper.
assert self.runtime_mode != CUDAGraphMode.NONE
if self.graph_pool is None:
self.graph_pool = current_platform.get_global_graph_pool()
# TODO: in the future, if we want to use multiple
# streams, it might not be safe to share a global pool.
# only investigate this when we use multiple streams
self.graph_pool = current_platform.get_global_graph_pool()
if cudagraph_options is None:
cudagraph_options = CUDAGraphOptions()