mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
[Compile][Platform] Make PiecewiseBackend pluggable and extendable (#18076)
Signed-off-by: Mengqing Cao <cmq0113@163.com> Co-authored-by: youkaichao <youkaichao@gmail.com>
This commit is contained in:
@ -6,9 +6,7 @@ import os
|
||||
import pprint
|
||||
import time
|
||||
from collections.abc import Sequence
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
@ -16,13 +14,13 @@ import torch.fx as fx
|
||||
import vllm.envs as envs
|
||||
from vllm.config import CompilationConfig, VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import weak_ref_tensors
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import resolve_obj_by_qualname
|
||||
|
||||
from .compiler_interface import (CompilerInterface, EagerAdaptor,
|
||||
InductorAdaptor, InductorStandaloneAdaptor)
|
||||
from .counter import compilation_counter
|
||||
from .inductor_pass import InductorPass
|
||||
from .monitor import end_monitoring_torch_compile
|
||||
from .pass_manager import PostGradPassManager
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -297,7 +295,9 @@ class PiecewiseCompileInterpreter(torch.fx.Interpreter):
|
||||
num_graphs=len(self.compile_submod_names),
|
||||
runtime_shape=None)
|
||||
|
||||
self.module.__dict__[target] = PiecewiseBackend(
|
||||
piecewise_backend = resolve_obj_by_qualname(
|
||||
current_platform.get_piecewise_backend_cls())
|
||||
self.module.__dict__[target] = piecewise_backend(
|
||||
submod, self.vllm_config, self.graph_pool, index,
|
||||
len(self.compile_submod_names), sym_shape_indices,
|
||||
compiled_graph_for_general_shape, self.vllm_backend)
|
||||
@ -341,7 +341,7 @@ class VllmBackend:
|
||||
):
|
||||
global global_graph_pool
|
||||
if global_graph_pool is None:
|
||||
global_graph_pool = torch.cuda.graph_pool_handle()
|
||||
global_graph_pool = current_platform.graph_pool_handle()
|
||||
|
||||
# TODO: in the future, if we want to use multiple
|
||||
# streams, it might not be safe to share a global pool.
|
||||
@ -558,197 +558,3 @@ class VllmBackend:
|
||||
return self.split_gm(*list_args)
|
||||
|
||||
return copy_and_call
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
need_to_compile: bool # the size is in compile_sizes
|
||||
use_cudagraph: bool # the size is in cudagraph_capture_sizes
|
||||
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
num_finished_warmup: int = 0
|
||||
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
class PiecewiseBackend:
|
||||
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
graph_pool: Any, piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int, sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation and cudagraph capturing.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
|
||||
Independently, we will capture cudagraph for different shapes.
|
||||
|
||||
If a shape needs both compilation and cudagraph, we will
|
||||
compile it first, and then capture cudagraph.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.graph_pool = graph_pool
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = (
|
||||
piecewise_compile_index == total_piecewise_compiles - 1)
|
||||
|
||||
self.compile_sizes: set[int] = set(
|
||||
self.compilation_config.compile_sizes)
|
||||
self.cudagraph_capture_sizes: set[int] = set(
|
||||
self.compilation_config.cudagraph_capture_sizes
|
||||
) if self.compilation_config.use_cudagraph else set()
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# the entries for different shapes that we need to either
|
||||
# compile or capture cudagraph
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
need_to_compile=shape in self.compile_sizes,
|
||||
use_cudagraph=shape in self.cudagraph_capture_sizes,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if not self.first_run_finished:
|
||||
self.first_run_finished = True
|
||||
self.check_for_ending_compilation()
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if entry.runnable is None:
|
||||
entry.runnable = self.compiled_graph_for_general_shape
|
||||
|
||||
if entry.need_to_compile and not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
if not entry.use_cudagraph:
|
||||
return entry.runnable(*args)
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
|
||||
entry.num_finished_warmup += 1
|
||||
if self.is_first_graph:
|
||||
logger.debug(
|
||||
"Warming up %s/%s for shape %s",
|
||||
entry.num_finished_warmup,
|
||||
self.compilation_config.cudagraph_num_of_warmups,
|
||||
runtime_shape)
|
||||
return entry.runnable(*args)
|
||||
|
||||
if self.is_first_graph:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every shape.
|
||||
# We only log it in the debug mode.
|
||||
logger.debug("Capturing a cudagraph for shape %s",
|
||||
runtime_shape)
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if not self.is_first_graph:
|
||||
# during every model forward, we will capture
|
||||
# many pieces of cudagraphs (roughly one per layer).
|
||||
# running gc again and again across layers will
|
||||
# make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = entry.runnable(*args)
|
||||
if self.is_last_graph:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph, because the output of the last graph
|
||||
# will not be used by any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_caputured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
"Input addresses for cudagraphs are different during replay."
|
||||
f" Expected {entry.input_addresses}, got {new_input_addresses}"
|
||||
)
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
||||
|
71
vllm/compilation/base_piecewise_backend.py
Normal file
71
vllm/compilation/base_piecewise_backend.py
Normal file
@ -0,0 +1,71 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
from typing import Any, Callable, Protocol
|
||||
|
||||
import torch.fx as fx
|
||||
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.config import VllmConfig
|
||||
|
||||
|
||||
class AbstractPiecewiseBackend(Protocol):
|
||||
"""
|
||||
PiecewiseBackend interface that allows platforms to extend
|
||||
piecewise static graph.
|
||||
"""
|
||||
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
graph_pool: Any, piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int, sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend, **kwargs):
|
||||
"""
|
||||
Initializes the PiecewiseBackend class with compilation and
|
||||
execution-related configurations.
|
||||
|
||||
This class handles piecewise compilation, graph capturing,
|
||||
and dispatching for specific input shapes.
|
||||
|
||||
Args:
|
||||
graph (fx.GraphModule): The graph represented in fx.
|
||||
vllm_config (VllmConfig): Global configuration for vLLM.
|
||||
graph_pool (Any):
|
||||
Graph memory pool handle, e.g.,
|
||||
`torch.cuda.graph_pool_handle()`.
|
||||
piecewise_compile_index (int):
|
||||
Index of the current piecewise subgraph.
|
||||
total_piecewise_compiles (int):
|
||||
Total number of piecewise-compiled graphs.
|
||||
sym_shape_indices (list[int]):
|
||||
Indices of symbolic shape.
|
||||
compiled_graph_for_general_shape (Callable):
|
||||
Callable that executes the graph compiled for general shapes.
|
||||
vllm_backend (VllmBackend):
|
||||
Backend compiler that manages compilation and graph runtime
|
||||
for vLLM.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Additional keyword arguments reserved for future
|
||||
extensions or custom platforms.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
"""Executes the compiled graph for given input args.
|
||||
|
||||
If this is the first invocation, executes the general compiled graph
|
||||
and initiates the compilation process tracking. For subsequent calls,
|
||||
dynamically dispatches execution to either a compiled graph or a static
|
||||
graph based on the input shape.
|
||||
|
||||
Args:
|
||||
*args: Variable length input arguments to be passed into the
|
||||
graph. The symbolic shape is expected to be in position
|
||||
`sym_shape_indices[0]`.
|
||||
|
||||
Returns:
|
||||
Any: Output of the executed graph. This can be from the general
|
||||
compiled graph, a specialized compiled version for the given shape,
|
||||
or a replayed static graph.
|
||||
"""
|
||||
raise NotImplementedError
|
213
vllm/compilation/cuda_piecewise_backend.py
Normal file
213
vllm/compilation/cuda_piecewise_backend.py
Normal file
@ -0,0 +1,213 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
import dataclasses
|
||||
from contextlib import ExitStack
|
||||
from typing import Any, Callable, Optional
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.backends import VllmBackend
|
||||
from vllm.compilation.counter import compilation_counter
|
||||
from vllm.compilation.monitor import end_monitoring_torch_compile
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import weak_ref_tensors
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ConcreteSizeEntry:
|
||||
runtime_shape: int
|
||||
need_to_compile: bool # the size is in compile_sizes
|
||||
use_cudagraph: bool # the size is in cudagraph_capture_sizes
|
||||
|
||||
compiled: bool = False
|
||||
runnable: Callable = None # type: ignore
|
||||
num_finished_warmup: int = 0
|
||||
cudagraph: Optional[torch.cuda.CUDAGraph] = None
|
||||
output: Optional[Any] = None
|
||||
|
||||
# for cudagraph debugging, track the input addresses
|
||||
# during capture, and check if they are the same during replay
|
||||
input_addresses: Optional[list[int]] = None
|
||||
|
||||
|
||||
class CUDAPiecewiseBackend:
|
||||
|
||||
def __init__(self, graph: fx.GraphModule, vllm_config: VllmConfig,
|
||||
graph_pool: Any, piecewise_compile_index: int,
|
||||
total_piecewise_compiles: int, sym_shape_indices: list[int],
|
||||
compiled_graph_for_general_shape: Callable,
|
||||
vllm_backend: VllmBackend):
|
||||
"""
|
||||
The backend for piecewise compilation.
|
||||
It mainly handles the compilation and cudagraph capturing.
|
||||
|
||||
We will compile `self.graph` once for the general shape,
|
||||
and then compile for different shapes specified in
|
||||
`compilation_config.compile_sizes`.
|
||||
|
||||
Independently, we will capture cudagraph for different shapes.
|
||||
|
||||
If a shape needs both compilation and cudagraph, we will
|
||||
compile it first, and then capture cudagraph.
|
||||
"""
|
||||
self.graph = graph
|
||||
self.vllm_config = vllm_config
|
||||
self.compilation_config = vllm_config.compilation_config
|
||||
self.graph_pool = graph_pool
|
||||
self.piecewise_compile_index = piecewise_compile_index
|
||||
self.total_piecewise_compiles = total_piecewise_compiles
|
||||
self.vllm_backend = vllm_backend
|
||||
|
||||
self.is_first_graph = piecewise_compile_index == 0
|
||||
self.is_last_graph = (
|
||||
piecewise_compile_index == total_piecewise_compiles - 1)
|
||||
|
||||
self.compile_sizes: set[int] = set(
|
||||
self.compilation_config.compile_sizes)
|
||||
self.cudagraph_capture_sizes: set[int] = set(
|
||||
self.compilation_config.cudagraph_capture_sizes
|
||||
) if self.compilation_config.use_cudagraph else set()
|
||||
|
||||
self.first_run_finished = False
|
||||
|
||||
self.compiled_graph_for_general_shape = compiled_graph_for_general_shape # noqa
|
||||
|
||||
self.sym_shape_indices = sym_shape_indices
|
||||
|
||||
self.is_debugging_mode = envs.VLLM_LOGGING_LEVEL == "DEBUG"
|
||||
|
||||
# the entries for different shapes that we need to either
|
||||
# compile or capture cudagraph
|
||||
self.concrete_size_entries: dict[int, ConcreteSizeEntry] = {}
|
||||
|
||||
# to_be_compiled_sizes tracks the remaining sizes to compile,
|
||||
# and updates during the compilation process, so we need to copy it
|
||||
self.to_be_compiled_sizes: set[int] = self.compile_sizes.copy()
|
||||
for shape in self.compile_sizes.union(self.cudagraph_capture_sizes):
|
||||
self.concrete_size_entries[shape] = ConcreteSizeEntry(
|
||||
runtime_shape=shape,
|
||||
need_to_compile=shape in self.compile_sizes,
|
||||
use_cudagraph=shape in self.cudagraph_capture_sizes,
|
||||
)
|
||||
|
||||
def check_for_ending_compilation(self):
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
# no specific sizes to compile
|
||||
# save the hash of the inductor graph for the next run
|
||||
self.vllm_backend.compiler_manager.save_to_file()
|
||||
end_monitoring_torch_compile(self.vllm_config)
|
||||
|
||||
def __call__(self, *args) -> Any:
|
||||
if not self.first_run_finished:
|
||||
self.first_run_finished = True
|
||||
self.check_for_ending_compilation()
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
runtime_shape = args[self.sym_shape_indices[0]]
|
||||
if runtime_shape not in self.concrete_size_entries:
|
||||
# we don't need to do anything for this shape
|
||||
return self.compiled_graph_for_general_shape(*args)
|
||||
|
||||
entry = self.concrete_size_entries[runtime_shape]
|
||||
|
||||
if entry.runnable is None:
|
||||
entry.runnable = self.compiled_graph_for_general_shape
|
||||
|
||||
if entry.need_to_compile and not entry.compiled:
|
||||
entry.compiled = True
|
||||
self.to_be_compiled_sizes.remove(runtime_shape)
|
||||
# args are real arguments
|
||||
entry.runnable = self.vllm_backend.compiler_manager.compile(
|
||||
self.graph,
|
||||
args,
|
||||
self.compilation_config.inductor_compile_config,
|
||||
self.compilation_config,
|
||||
graph_index=self.piecewise_compile_index,
|
||||
num_graphs=self.total_piecewise_compiles,
|
||||
runtime_shape=runtime_shape)
|
||||
|
||||
# finished compilations for all required shapes
|
||||
if self.is_last_graph and not self.to_be_compiled_sizes:
|
||||
self.check_for_ending_compilation()
|
||||
|
||||
if not entry.use_cudagraph:
|
||||
return entry.runnable(*args)
|
||||
|
||||
if entry.cudagraph is None:
|
||||
if entry.num_finished_warmup < self.compilation_config.cudagraph_num_of_warmups: # noqa
|
||||
entry.num_finished_warmup += 1
|
||||
if self.is_first_graph:
|
||||
logger.debug(
|
||||
"Warming up %s/%s for shape %s",
|
||||
entry.num_finished_warmup,
|
||||
self.compilation_config.cudagraph_num_of_warmups,
|
||||
runtime_shape)
|
||||
return entry.runnable(*args)
|
||||
|
||||
if self.is_first_graph:
|
||||
# Since we capture cudagraph for many different shapes and
|
||||
# capturing is fast, we don't need to log it for every shape.
|
||||
# We only log it in the debug mode.
|
||||
logger.debug("Capturing a cudagraph for shape %s",
|
||||
runtime_shape)
|
||||
|
||||
input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
entry.input_addresses = input_addresses
|
||||
cudagraph = torch.cuda.CUDAGraph()
|
||||
|
||||
with ExitStack() as stack:
|
||||
if not self.is_first_graph:
|
||||
# during every model forward, we will capture
|
||||
# many pieces of cudagraphs (roughly one per layer).
|
||||
# running gc again and again across layers will
|
||||
# make the cudagraph capture very slow.
|
||||
# therefore, we only run gc for the first graph,
|
||||
# and disable gc for the rest of the graphs.
|
||||
stack.enter_context(patch("gc.collect", lambda: None))
|
||||
stack.enter_context(
|
||||
patch("torch.cuda.empty_cache", lambda: None))
|
||||
|
||||
# mind-exploding: carefully manage the reference and memory.
|
||||
with torch.cuda.graph(cudagraph, pool=self.graph_pool):
|
||||
# `output` is managed by pytorch's cudagraph pool
|
||||
output = entry.runnable(*args)
|
||||
if self.is_last_graph:
|
||||
# by converting it to weak ref,
|
||||
# the original `output` will immediately be released
|
||||
# to save memory. It is only safe to do this for
|
||||
# the last graph, because the output of the last graph
|
||||
# will not be used by any other cuda graph.
|
||||
output = weak_ref_tensors(output)
|
||||
|
||||
# here we always use weak ref for the output
|
||||
# to save memory
|
||||
entry.output = weak_ref_tensors(output)
|
||||
entry.cudagraph = cudagraph
|
||||
|
||||
compilation_counter.num_cudagraph_caputured += 1
|
||||
|
||||
# important: we need to return the output, rather than
|
||||
# the weak ref of the output, so that pytorch can correctly
|
||||
# manage the memory during cuda graph capture
|
||||
return output
|
||||
|
||||
if self.is_debugging_mode:
|
||||
# check if the input addresses are the same
|
||||
new_input_addresses = [
|
||||
x.data_ptr() for x in args if isinstance(x, torch.Tensor)
|
||||
]
|
||||
assert new_input_addresses == entry.input_addresses, (
|
||||
"Input addresses for cudagraphs are different during replay."
|
||||
f" Expected {entry.input_addresses}, got {new_input_addresses}"
|
||||
)
|
||||
|
||||
entry.cudagraph.replay()
|
||||
return entry.output
|
@ -311,6 +311,10 @@ class CudaPlatformBase(Platform):
|
||||
def use_custom_allreduce(cls) -> bool:
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def get_piecewise_backend_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
|
||||
|
||||
|
||||
# NVML utils
|
||||
# Note that NVML is not affected by `CUDA_VISIBLE_DEVICES`,
|
||||
|
@ -478,6 +478,13 @@ class Platform:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def get_piecewise_backend_cls(cls) -> str:
|
||||
"""
|
||||
Get piecewise backend class for piecewise graph.
|
||||
"""
|
||||
return "vllm.compilation.base_piecewise_backend.AbstractPiecewiseBackend" # noqa
|
||||
|
||||
|
||||
class UnspecifiedPlatform(Platform):
|
||||
_enum = PlatformEnum.UNSPECIFIED
|
||||
|
@ -382,3 +382,7 @@ class RocmPlatform(Platform):
|
||||
@classmethod
|
||||
def is_navi(cls) -> bool:
|
||||
return 'gfx1' in torch.cuda.get_device_properties(0).gcnArchName
|
||||
|
||||
@classmethod
|
||||
def get_piecewise_backend_cls(cls) -> str:
|
||||
return "vllm.compilation.cuda_piecewise_backend.CUDAPiecewiseBackend" # noqa
|
||||
|
Reference in New Issue
Block a user