Add option to use torch._inductor.standalone_compile (#17057)

Signed-off-by: rzou <zou3519@gmail.com>
Author: Richard Zou, 2025-05-09 15:59:04 -04:00 (committed by GitHub)
parent 7d4aedae7c
commit ea2236bf95
3 changed files with 150 additions and 29 deletions

vllm/compilation/backends.py

@@ -17,7 +17,8 @@ from vllm.config import CompilationConfig, VllmConfig
from vllm.logger import init_logger
from vllm.utils import weak_ref_tensors
from .compiler_interface import EagerAdaptor, InductorAdaptor
from .compiler_interface import (CompilerInterface, EagerAdaptor,
InductorAdaptor, InductorStandaloneAdaptor)
from .counter import compilation_counter
from .inductor_pass import InductorPass
from .monitor import end_monitoring_torch_compile
@@ -26,6 +27,19 @@ from .pass_manager import PostGradPassManager
logger = init_logger(__name__)
def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
if compilation_config.use_inductor:
if envs.VLLM_TEST_STANDALONE_COMPILE:
logger.info("Using InductorStandaloneAdaptor")
return InductorStandaloneAdaptor()
else:
logger.info("Using InductorAdaptor")
return InductorAdaptor()
else:
logger.info("Using EagerAdaptor")
return EagerAdaptor()
class CompilerManager:
"""
A manager to manage the compilation process, including
@@ -41,11 +55,11 @@ class CompilerManager:
support int as key.
"""
def __init__(self, use_inductor: bool):
def __init__(self, compilation_config: CompilationConfig):
self.cache: Dict[Tuple[Optional[int], int, str], Any] = dict()
cls = InductorAdaptor if use_inductor else EagerAdaptor
self.compiler = cls()
self.is_cache_updated = False
self.compilation_config = compilation_config
self.compiler = make_compiler(compilation_config)
def compute_hash(self, vllm_config: VllmConfig) -> str:
return self.compiler.compute_hash(vllm_config)
@@ -123,8 +137,15 @@ class CompilerManager:
# no compiler cached the graph, or the cache is disabled,
# we need to compile it
if isinstance(self.compiler, InductorAdaptor):
# Let compile_fx generate a key for us
maybe_key = None
else:
maybe_key = \
f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
compiled_graph, handle = self.compiler.compile(
graph, example_inputs, additional_inductor_config, runtime_shape)
graph, example_inputs, additional_inductor_config, runtime_shape,
maybe_key)
assert compiled_graph is not None, "Failed to compile the graph"
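For illustration, a minimal sketch (with made-up values) of how the per-graph key built above maps to an on-disk location; the real cache_dir is chosen by vLLM's compilation machinery:

import os

runtime_shape, graph_index = 8, 0          # hypothetical shape / subgraph index
cache_dir = "/tmp/vllm_compile_cache"      # illustrative only

key = f"artifact_shape_{runtime_shape}_subgraph_{graph_index}"
path = os.path.join(cache_dir, key)
# -> /tmp/vllm_compile_cache/artifact_shape_8_subgraph_0
# InductorStandaloneAdaptor saves the compiled artifact there and returns
# (key, path) as its cache handle; InductorAdaptor passes key=None and lets
# compile_fx generate its own cache key.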
@@ -336,7 +357,7 @@ class VllmBackend:
self.compilation_config = vllm_config.compilation_config
self.compiler_manager: CompilerManager = CompilerManager(
self.compilation_config.use_inductor)
self.compilation_config)
# `torch.compile` is JIT compiled, so we don't need to
# do anything here

vllm/compilation/compiler_interface.py

@@ -50,7 +50,8 @@ class CompilerInterface:
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
"""
Compile the graph with the given example inputs and compiler config,
@@ -71,6 +72,10 @@ class CompilerInterface:
If the compiler doesn't support caching, it should return None for the
handle. If the compiler fails to compile the graph, it should return
None for the compiled function as well.
`key` is required by InductorStandaloneAdaptor; it specifies where to save
the compiled artifact, which is written to `cache_dir/key`.
"""
return None, None
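As a rough usage sketch of this contract (names like `compiler`, `graph`, and `example_inputs` are placeholders, not vLLM code):

compiled_fn, handle = compiler.compile(
    graph, example_inputs, compiler_config={},
    runtime_shape=None,
    key="artifact_shape_None_subgraph_0")
assert compiled_fn is not None, "compilation failed"
if handle is not None:
    # The handle can be cached; a later run may skip compilation entirely:
    compiled_fn = compiler.load(handle, graph, example_inputs,
                                graph_index=0, runtime_shape=None)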
@@ -127,23 +132,108 @@ class AlwaysHitShapeEnv:
return ""
def get_inductor_factors() -> List[Any]:
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
return factors
class InductorStandaloneAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler, using torch._inductor.standalone_compile.
Requires PyTorch 2.8+.
This is not enabled by default yet, but we plan to make it the default for
PyTorch 2.8.
Use the VLLM_TEST_STANDALONE_COMPILE environment variable to toggle it on or off.
"""
name = "inductor_standalone"
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors = get_inductor_factors()
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()[:10]
return hash_str
def initialize_cache(self, cache_dir: str, disable_cache: bool = False):
self.cache_dir = cache_dir
def compile(
self,
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, runtime_shape)
if isinstance(runtime_shape, int):
dynamic_shapes = "from_example_inputs"
else:
dynamic_shapes = "from_tracing_context"
from torch._inductor import standalone_compile
with pass_context(runtime_shape):
compiled_graph = standalone_compile(
graph,
example_inputs,
dynamic_shapes=dynamic_shapes,
options={"config_patches": current_config})
# Save the compiled artifact to disk at the specified path
assert key is not None
path = os.path.join(self.cache_dir, key)
compiled_graph.save(path=path, format="unpacked")
return compiled_graph, (key, path)
def load(self,
handle: Any,
graph: fx.GraphModule,
example_inputs: List[Any],
graph_index: int,
runtime_shape: Optional[int] = None) -> Callable:
assert isinstance(handle, tuple)
assert isinstance(handle[0], str)
assert isinstance(handle[1], str)
path = handle[1]
inductor_compiled_graph = torch._inductor.CompiledArtifact.load(
path=path, format="unpacked")
from torch._inductor.compile_fx import graph_returns_tuple
returns_tuple = graph_returns_tuple(graph)
def compiled_graph_wrapper(*args):
graph_output = inductor_compiled_graph(*args)
# unpack the tuple if needed
# TODO(rzou): the implication is that we're not
# reading the python bytecode correctly in vLLM?
if returns_tuple:
return graph_output
else:
return graph_output[0]
return compiled_graph_wrapper
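For context, a self-contained sketch of the save/load round trip that this adaptor performs with the PyTorch 2.8 API (the traced function, input shape, and save path below are made up for illustration and are not part of this commit):

import torch
from torch._inductor import standalone_compile

def demo_fn(x):
    return torch.relu(x) + 1

gm = torch.fx.symbolic_trace(demo_fn)      # stand-in for vLLM's captured subgraph
example_inputs = [torch.randn(8)]

artifact = standalone_compile(
    gm, example_inputs,
    dynamic_shapes="from_example_inputs",   # fixed-shape case, as in compile() above
    options={"config_patches": {}})
artifact.save(path="/tmp/standalone_demo", format="unpacked")

# Later, possibly in a different process: reload without recompiling.
loaded = torch._inductor.CompiledArtifact.load(
    path="/tmp/standalone_demo", format="unpacked")
out = loaded(*example_inputs)              # may be a 1-tuple; see the wrapper above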
class InductorAdaptor(CompilerInterface):
"""
The adaptor for the Inductor compiler, version 2.5 and 2.6.
The adaptor for the Inductor compiler, versions 2.5, 2.6, and 2.7.
"""
name = "inductor"
def compute_hash(self, vllm_config: VllmConfig) -> str:
factors: List[Any] = []
# summarize system state
from torch._inductor.codecache import CacheBase
system_factors = CacheBase.get_system()
factors.append(system_factors)
# summarize pytorch state
from torch._inductor.codecache import torch_key
torch_factors = torch_key()
factors.append(torch_factors)
factors = get_inductor_factors()
hash_str = hashlib.md5(str(factors).encode(),
usedforsecurity=False).hexdigest()[:10]
return hash_str
@@ -168,23 +258,19 @@ class InductorAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
current_config = {}
from torch._inductor.compile_fx import compile_fx
current_config = {}
if compiler_config is not None:
current_config.update(compiler_config)
# disable remote cache
current_config["fx_graph_cache"] = True
current_config["fx_graph_remote_cache"] = False
if compiler_config is not None:
current_config.update(compiler_config)
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
current_config["max_autotune"] = True
current_config["coordinate_descent_tuning"] = True
set_inductor_config(current_config, runtime_shape)
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
@@ -422,6 +508,14 @@ class InductorAdaptor(CompilerInterface):
return contextlib.nullcontext()
def set_inductor_config(config, runtime_shape):
if isinstance(runtime_shape, int):
# for a specific batchsize, tuning triton kernel parameters
# can be beneficial
config["max_autotune"] = True
config["coordinate_descent_tuning"] = True
class EagerAdaptor(CompilerInterface):
name = "eager"
@@ -430,7 +524,8 @@ class EagerAdaptor(CompilerInterface):
graph: fx.GraphModule,
example_inputs: List[Any],
compiler_config: Dict[str, Any],
runtime_shape: Optional[int] = None
runtime_shape: Optional[int] = None,
key: Optional[str] = None,
) -> Tuple[Optional[Callable], Optional[Any]]:
# we don't need to compile the graph, just return the graph itself.
# It does not support caching, so return None for the handle.

vllm/envs.py

@@ -263,6 +263,10 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
# Internal flag to enable/disable Inductor standalone compile
"VLLM_TEST_STANDALONE_COMPILE":
lambda: os.environ.get("VLLM_TEST_STANDALONE_COMPILE", "0") != "0",
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK":
@@ -805,6 +809,7 @@ def compute_hash() -> str:
"VLLM_USE_TRITON_AWQ",
"VLLM_DP_RANK",
"VLLM_DP_SIZE",
"VLLM_TEST_STANDALONE_COMPILE",
]
for key in environment_variables_to_hash:
if key in environment_variables:
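Since each entry in environment_variables is a lazily evaluated callable, the new flag can be enabled from Python before vLLM reads it, or simply exported in the shell (a minimal sketch):

import os

# Any value other than "0" turns the standalone Inductor adaptor on,
# matching the lambda registered above.
os.environ["VLLM_TEST_STANDALONE_COMPILE"] = "1"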