Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)
Add option to use torch._inductor.standalone_compile (#17057)
Signed-off-by: rzou <zou3519@gmail.com>
vllm/compilation/backends.py
@@ -17,7 +17,8 @@ from vllm.config import CompilationConfig, VllmConfig
 from vllm.logger import init_logger
 from vllm.utils import weak_ref_tensors
 
-from .compiler_interface import EagerAdaptor, InductorAdaptor
+from .compiler_interface import (CompilerInterface, EagerAdaptor,
+                                 InductorAdaptor, InductorStandaloneAdaptor)
 from .counter import compilation_counter
 from .inductor_pass import InductorPass
 from .monitor import end_monitoring_torch_compile
@@ -26,6 +27,19 @@ from .pass_manager import PostGradPassManager
 logger = init_logger(__name__)
 
 
+def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
+    if compilation_config.use_inductor:
+        if envs.VLLM_TEST_STANDALONE_COMPILE:
+            logger.info("Using InductorStandaloneAdaptor")
+            return InductorStandaloneAdaptor()
+        else:
+            logger.info("Using InductorAdaptor")
+            return InductorAdaptor()
+    else:
+        logger.info("Using EagerAdaptor")
+        return EagerAdaptor()
+
+
 class CompilerManager:
     """
     A manager to manage the compilation process, including
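A minimal sketch of how this factory behaves, assuming a vLLM checkout that includes this change (the config construction here is illustrative):

import os
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"  # read lazily by vllm.envs

from vllm.config import CompilationConfig
from vllm.compilation.backends import make_compiler

# use_inductor=True plus the flag selects the standalone adaptor;
# with the flag unset, make_compiler returns InductorAdaptor instead.
compiler = make_compiler(CompilationConfig(use_inductor=True))
print(type(compiler).__name__)  # InductorStandaloneAdaptor (PyTorch 2.8+)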
@@ -41,11 +55,11 @@ class CompilerManager:
     support int as key.
     """
 
-    def __init__(self, use_inductor: bool):
+    def __init__(self, compilation_config: CompilationConfig):
         self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
-        cls = InductorAdaptor if use_inductor else EagerAdaptor
-        self.compiler = cls()
         self.is_cache_updated = False
+        self.compilation_config = compilation_config
+        self.compiler = make_compiler(compilation_config)
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
         return self.compiler.compute_hash(vllm_config)
@@ -123,8 +137,15 @@ class CompilerManager:
 
         # no compiler cached the graph, or the cache is disabled,
         # we need to compile it
+        if isinstance(self.compiler, InductorAdaptor):
+            # Let compile_fx generate a key for us
+            maybe_key = None
+        else:
+            maybe_key = \
+                f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
         compiled_graph, handle = self.compiler.compile(
-            graph, example_inputs, additional_inductor_config, runtime_shape)
+            graph, example_inputs, additional_inductor_config, runtime_shape,
+            maybe_key)
 
         assert compiled_graph is not None, "Failed to compile the graph"
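For reference, the key built on the non-InductorAdaptor path is a plain string; evaluated for, say, subgraph 3 compiled at batch size 16:

runtime_shape, graph_index = 16, 3
maybe_key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
print(maybe_key)  # artifact_shape_16_subgraph_3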
@@ -336,7 +357,7 @@ class VllmBackend:
         self.compilation_config = vllm_config.compilation_config
 
         self.compiler_manager: CompilerManager = CompilerManager(
-            self.compilation_config.use_inductor)
+            self.compilation_config)
 
         # `torch.compile` is JIT compiled, so we don't need to
         # do anything here
vllm/compilation/compiler_interface.py
@@ -50,7 +50,8 @@ class CompilerInterface:
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         """
         Compile the graph with the given example inputs and compiler config,
@@ -71,6 +72,10 @@ class CompilerInterface:
         If the compiler doesn't support caching, it should return None for the
         handle. If the compiler fails to compile the graph, it should return
         None for the compiled function as well.
+
+        `key` is required for InductorStandaloneAdaptor; it specifies where to
+        save the compiled artifact. The compiled artifact gets saved to
+        `cache_dir/key`.
         """
         return None, None
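The `cache_dir/key` contract is a plain path join; a tiny sketch with illustrative values (the directory is normally set via `initialize_cache`):

import os

cache_dir = "/tmp/vllm_cache"  # illustrative
key = "artifact_shape_16_subgraph_3"
print(os.path.join(cache_dir, key))  # /tmp/vllm_cache/artifact_shape_16_subgraph_3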
@@ -127,23 +132,108 @@ class AlwaysHitShapeEnv:
         return ""
 
 
+def get_inductor_factors() -> List[Any]:
+    factors: List[Any] = []
+    # summarize system state
+    from torch._inductor.codecache import CacheBase
+    system_factors = CacheBase.get_system()
+    factors.append(system_factors)
+
+    # summarize pytorch state
+    from torch._inductor.codecache import torch_key
+    torch_factors = torch_key()
+    factors.append(torch_factors)
+    return factors
+
+
+class InductorStandaloneAdaptor(CompilerInterface):
+    """
+    The adaptor for the Inductor compiler.
+    Requires PyTorch 2.8+.
+    This is not on by default yet, but we plan to turn it on by default for
+    PyTorch 2.8.
+
+    Use VLLM_TEST_STANDALONE_COMPILE to toggle this on or off.
+    """
+    name = "inductor_standalone"
+
+    def compute_hash(self, vllm_config: VllmConfig) -> str:
+        factors = get_inductor_factors()
+        hash_str = hashlib.md5(str(factors).encode(),
+                               usedforsecurity=False).hexdigest()[:10]
+        return hash_str
+
+    def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
+        self.cache_dir = cache_dir
+
+    def compile(
+        self,
+        graph: fx.GraphModule,
+        example_inputs: List[Any],
+        compiler_config: Dict[str, Any],
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
+    ) -> Tuple[Optional[Callable], Optional[Any]]:
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
+        set_inductor_config(current_config, runtime_shape)
+
+        if isinstance(runtime_shape, int):
+            dynamic_shapes = "from_example_inputs"
+        else:
+            dynamic_shapes = "from_tracing_context"
+
+        from torch._inductor import standalone_compile
+        with pass_context(runtime_shape):
+            compiled_graph = standalone_compile(
+                graph,
+                example_inputs,
+                dynamic_shapes=dynamic_shapes,
+                options={"config_patches": current_config})
+
+        # Save the compiled artifact to disk in the specified path
+        assert key is not None
+        path = os.path.join(self.cache_dir, key)
+        compiled_graph.save(path=path, format="unpacked")
+        return compiled_graph, (key, path)
+
+    def load(self,
+             handle: Any,
+             graph: fx.GraphModule,
+             example_inputs: List[Any],
+             graph_index: int,
+             runtime_shape: Optional[int] = None) -> Callable:
+        assert isinstance(handle, tuple)
+        assert isinstance(handle[0], str)
+        assert isinstance(handle[1], str)
+        path = handle[1]
+        inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
+            path=path, format="unpacked")
+        from torch._inductor.compile_fx import graph_returns_tuple
+        returns_tuple = graph_returns_tuple(graph)
+
+        def compiled_graph_wrapper(*args):
+            graph_output = inductor_compiled_graph(*args)
+            # unpack the tuple if needed
+            # TODO(rzou): the implication is that we're not
+            # reading the python bytecode correctly in vLLM?
+            if returns_tuple:
+                return graph_output
+            else:
+                return graph_output[0]
+
+        return compiled_graph_wrapper
+
+
 class InductorAdaptor(CompilerInterface):
     """
-    The adaptor for the Inductor compiler, version 2.5 and 2.6.
+    The adaptor for the Inductor compiler, versions 2.5, 2.6, and 2.7.
     """
     name = "inductor"
 
     def compute_hash(self, vllm_config: VllmConfig) -> str:
-        factors: List[Any] = []
-        # summarize system state
-        from torch._inductor.codecache import CacheBase
-        system_factors = CacheBase.get_system()
-        factors.append(system_factors)
-
-        # summarize pytorch state
-        from torch._inductor.codecache import torch_key
-        torch_factors = torch_key()
-        factors.append(torch_factors)
+        factors = get_inductor_factors()
         hash_str = hashlib.md5(str(factors).encode(),
                                usedforsecurity=False).hexdigest()[:10]
         return hash_str
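For orientation, a rough sketch of the save/load round trip the new adaptor performs, using the same calls as in the diff. This is an untested sketch for PyTorch 2.8+: the toy graph and path are illustrative, and a symbolic trace stands in for vLLM's dynamo-captured graph.

import torch
from torch._inductor import standalone_compile

def f(x):
    # toy two-output graph, so the artifact's tuple output is unambiguous
    return x.relu() * 2, x + 1

gm = torch.fx.symbolic_trace(f)
example_inputs = [torch.randn(16, 8)]

compiled = standalone_compile(gm, example_inputs,
                              dynamic_shapes="from_example_inputs",
                              options={"config_patches": {}})
compiled.save(path="/tmp/artifact_demo", format="unpacked")

# Later, possibly in a different process:
loaded = torch._inductor.CompiledArtifact.load(path="/tmp/artifact_demo",
                                               format="unpacked")
out = loaded(*example_inputs)

The `compiled_graph_wrapper` above exists because the loaded artifact hands back a sequence of outputs even when the traced graph returned a single value (hence the TODO). `graph_returns_tuple` distinguishes the two cases; a small illustration on toy graphs:

import torch
from torch._inductor.compile_fx import graph_returns_tuple

def single(x):
    return x + 1

def pair(x):
    return x + 1, x - 1

print(graph_returns_tuple(torch.fx.symbolic_trace(single)))  # False
print(graph_returns_tuple(torch.fx.symbolic_trace(pair)))    # True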
@@ -168,23 +258,19 @@ class InductorAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
-        current_config = {}
         from torch._inductor.compile_fx import compile_fx
+        current_config = {}
+        if compiler_config is not None:
+            current_config.update(compiler_config)
 
         # disable remote cache
         current_config["fx_graph_cache"] = True
         current_config["fx_graph_remote_cache"] = False
 
-        if compiler_config is not None:
-            current_config.update(compiler_config)
-
-        if isinstance(runtime_shape, int):
-            # for a specific batchsize, tuning triton kernel parameters
-            # can be beneficial
-            current_config["max_autotune"] = True
-            current_config["coordinate_descent_tuning"] = True
+        set_inductor_config(current_config, runtime_shape)
 
         # inductor can inplace modify the graph, so we need to copy it
         # see https://github.com/pytorch/pytorch/issues/138980
@@ -422,6 +508,14 @@
         return contextlib.nullcontext()
 
 
+def set_inductor_config(config, runtime_shape):
+    if isinstance(runtime_shape, int):
+        # for a specific batchsize, tuning triton kernel parameters
+        # can be beneficial
+        config["max_autotune"] = True
+        config["coordinate_descent_tuning"] = True
+
+
 class EagerAdaptor(CompilerInterface):
     name = "eager"
 
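The extracted helper only touches the config for static (int) shapes; a quick sketch, assuming `set_inductor_config` from above is in scope:

config: dict = {}
set_inductor_config(config, runtime_shape=None)  # dynamic shape: no-op
print(config)  # {}
set_inductor_config(config, runtime_shape=32)    # static batch size
print(config)  # {'max_autotune': True, 'coordinate_descent_tuning': True}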
@@ -430,7 +524,8 @@ class EagerAdaptor(CompilerInterface):
         graph: fx.GraphModule,
         example_inputs: List[Any],
         compiler_config: Dict[str, Any],
-        runtime_shape: Optional[int] = None
+        runtime_shape: Optional[int] = None,
+        key: Optional[str] = None,
     ) -> Tuple[Optional[Callable], Optional[Any]]:
         # we don't need to compile the graph, just return the graph itself.
         # It does not support caching, return None for the handle.
vllm/envs.py
@@ -263,6 +263,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
     lambda: bool(
         os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
 
+    # Internal flag to enable/disable Inductor standalone compile
+    "VLLM_TEST_STANDALONE_COMPILE":
+    lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
+
     # local rank of the process in the distributed setting, used to determine
     # the GPU device id
     "LOCAL_RANK":
@@ -805,6 +809,7 @@ def compute_hash() -> str:
         "VLLM_USE_TRITON_AWQ",
         "VLLM_DP_RANK",
         "VLLM_DP_SIZE",
+        "VLLM_TEST_STANDALONE_COMPILE",
     ]
     for key in environment_variables_to_hash:
         if key in environment_variables:
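The flag follows the usual vllm.envs convention: any value other than "0" enables it, and lookups go through the lambda table above, so they reflect the environment at access time. A short sketch:

import os
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"

import vllm.envs as envs
print(envs.VLLM_TEST_STANDALONE_COMPILE)  # True

os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "0"
print(envs.VLLM_TEST_STANDALONE_COMPILE)  # False (re-evaluated on access)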