Remove graph_pool as member of VllmBackend and argument to CUDAGraphWrapper (#23385)
Signed-off-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: ProExpertProg <11367180+ProExpertProg@users.noreply.github.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
This commit is contained in:
@ -294,13 +294,12 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
|
||||
def __init__(self, module: torch.fx.GraphModule,
|
||||
compile_submod_names: list[str], vllm_config: VllmConfig,
|
||||
graph_pool, vllm_backend: "VllmBackend"):
|
||||
vllm_backend: "VllmBackend"):
|
||||
super().__init__(module)
|
||||
from torch._guards import detect_fake_mode
|
||||
self.fake_mode = detect_fake_mode()
|
||||
self.compile_submod_names = compile_submod_names
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.graph_pool = graph_pool
|
||||
self.vllm_config = vllm_config
|
||||
self.vllm_backend = vllm_backend
|
||||
# When True, it annoyingly dumps the torch.fx.Graph on errors.
|
||||
@ -359,7 +358,6 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
runnable=piecewise_backend,
|
||||
vllm_config=self.vllm_config,
|
||||
runtime_mode=CUDAGraphMode.PIECEWISE,
|
||||
graph_pool=self.graph_pool,
|
||||
cudagraph_options=CUDAGraphOptions(
|
||||
debug_log_enable=piecewise_backend.is_first_graph,
|
||||
gc_disable=not piecewise_backend.is_first_graph,
|
||||
@ -405,7 +403,6 @@ class VllmBackend:
|
||||
|
||||
vllm_config: VllmConfig
|
||||
compilation_config: CompilationConfig
|
||||
graph_pool: Any
|
||||
_called: bool = False
|
||||
# the graph we compiled
|
||||
graph: fx.GraphModule
|
||||
@ -433,13 +430,6 @@ class VllmBackend:
|
||||
# them, e.g. backbone (default), eagle_head, etc.
|
||||
self.prefix = prefix or model_tag
|
||||
|
||||
global_graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
# only investigate this when we use multiple streams
|
||||
self.graph_pool = global_graph_pool
|
||||
|
||||
# Passes to run on the graph post-grad.
|
||||
self.post_grad_pass_manager = PostGradPassManager()
|
||||
|
||||
@ -586,7 +576,7 @@ class VllmBackend:
|
||||
# propagate the split graph to the piecewise backend,
|
||||
# compile submodules with symbolic shapes
|
||||
PiecewiseCompileInterpreter(self.split_gm, submod_names_to_compile,
|
||||
self.vllm_config, self.graph_pool,
|
||||
self.vllm_config,
|
||||
self).run(*example_inputs)
|
||||
|
||||
graph_path = os.path.join(local_cache_dir, "computation_graph.py")
|
||||
|
@ -13,7 +13,7 @@ class AbstractStaticGraphWrapper(Protocol):
|
||||
"""
|
||||
|
||||
def __init__(self, runnable: Callable, vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode, graph_pool: Any, **kwargs):
|
||||
runtime_mode: CUDAGraphMode, **kwargs):
|
||||
"""
|
||||
Initializes the StaticGraphWrapper class with graph capturing and
|
||||
execution-related configurations.
|
||||
@ -25,9 +25,6 @@ class AbstractStaticGraphWrapper(Protocol):
|
||||
graph runtime. See CUDAGraphMode in vllm/config.py.
|
||||
Note that only the subset enum `NONE`, `PIECEWISE` and `FULL`
|
||||
are used as concrete runtime mode for cudagraph dispatching.
|
||||
graph_pool (Any):
|
||||
Graph memory pool handle, e.g.,
|
||||
`torch.cuda.graph_pool_handle()`.
|
||||
Keyword Args:
|
||||
kwargs: Additional keyword arguments for platform-specific
|
||||
configurations.
|
||||
|
@ -67,11 +67,9 @@ class CUDAGraphWrapper:
|
||||
runnable: Callable,
|
||||
vllm_config: VllmConfig,
|
||||
runtime_mode: CUDAGraphMode,
|
||||
graph_pool: Any = None,
|
||||
cudagraph_options: Optional[CUDAGraphOptions] = None):
|
||||
self.runnable = runnable
|
||||
self.vllm_config = vllm_config
|
||||
self.graph_pool = graph_pool
|
||||
self.runtime_mode = runtime_mode
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
|
||||
@ -81,8 +79,10 @@ class CUDAGraphWrapper:
|
||||
# assert runtime_mode is not NONE(no cudagraph), otherwise, we don't
|
||||
# need to initialize a CUDAGraphWrapper.
|
||||
assert self.runtime_mode != CUDAGraphMode.NONE
|
||||
if self.graph_pool is None:
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
# only investigate this when we use multiple streams
|
||||
self.graph_pool = current_platform.get_global_graph_pool()
|
||||
|
||||
if cudagraph_options is None:
|
||||
cudagraph_options = CUDAGraphOptions()
|
||||
|
Reference in New Issue
Block a user