AOT Compilation for torch.compile (Bundled) (#24274)

Signed-off-by: zhxchen17 <zhxchen17@fb.com>
Author: Zhengxu Chen
Date: 2025-10-10 19:02:11 -04:00
Committed by: GitHub
Parent: e317414ce1
Commit: eef921f45e
9 changed files with 484 additions and 41 deletions

View File

@ -403,6 +403,7 @@ steps:
- pytest -v -s compile/test_fusion_all_reduce.py
- pytest -v -s compile/test_decorator.py
- pytest -v -s compile/test_noop_elimination.py
- pytest -v -s compile/test_aot_compile.py
- label: PyTorch Fullgraph Smoke Test # 15min
timeout_in_minutes: 30

View File

@ -0,0 +1,139 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import tempfile
from contextlib import contextmanager

import pytest
import torch

from vllm.compilation.decorators import support_torch_compile
from vllm.config import (
    CompilationConfig,
    CompilationLevel,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.forward_context import set_forward_context
from vllm.utils import is_torch_equal_or_newer


def reference_fn(x: torch.Tensor):
    assert x.shape[0] <= 42
    assert x.shape[0] % 2 == 0
    for _ in range(3000):
        x = x + x.shape[0]
    return x


@support_torch_compile
class CompiledMod(torch.nn.Module):
    def __init__(self, **kwargs):
        super().__init__()

    def forward(self, x: torch.Tensor):
        return reference_fn(x)


def make_vllm_config() -> VllmConfig:
    return VllmConfig(
        compilation_config=CompilationConfig(
            level=CompilationLevel.PIECEWISE,
        )
    )


@contextmanager
def use_vllm_config(vllm_config: VllmConfig):
    with set_forward_context({}, vllm_config), set_current_vllm_config(vllm_config):
        yield


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_no_dynamo_cache_entry(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        vllm_config = make_vllm_config()
        args = (torch.randn(10, 10),)
        expected = reference_fn(*args)
        with use_vllm_config(vllm_config):
            m.setenv("VLLM_USE_AOT_COMPILE", "0")
            with (
                pytest.raises(RuntimeError, match="Detected recompile"),
                torch.compiler.set_stance("fail_on_recompile"),
            ):
                CompiledMod(vllm_config=vllm_config)(*args)

            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            torch._dynamo.reset()
            with torch.compiler.set_stance("fail_on_recompile"):
                actual = CompiledMod(vllm_config=vllm_config)(*args)
            assert torch.allclose(actual, expected)


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_force_aot_load(monkeypatch: pytest.MonkeyPatch):
    with tempfile.TemporaryDirectory() as tmpdirname, monkeypatch.context() as m:
        args = (torch.randn(10, 10),)
        m.setenv("VLLM_USE_AOT_COMPILE", "1")
        m.setenv("VLLM_FORCE_AOT_LOAD", "1")
        m.setenv("VLLM_CACHE_ROOT", tmpdirname)
        vllm_config = make_vllm_config()
        with use_vllm_config(vllm_config), pytest.raises(FileNotFoundError):
            CompiledMod(vllm_config=vllm_config)(*args)


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_save_and_load(monkeypatch: pytest.MonkeyPatch):
    with monkeypatch.context() as m:
        args = (torch.randn(10, 10),)
        with tempfile.TemporaryDirectory() as tmpdirname:
            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            vllm_config = make_vllm_config()
            with use_vllm_config(vllm_config):
                expected = CompiledMod(vllm_config=vllm_config)(*args)

            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
            vllm_config = make_vllm_config()
            with use_vllm_config(vllm_config):
                ret = CompiledMod(vllm_config=vllm_config)(*args)
            assert torch.allclose(ret, expected)


@pytest.mark.skipif(
    not is_torch_equal_or_newer("2.10.0.dev"), reason="requires torch 2.10"
)
def test_shape_env(monkeypatch: pytest.MonkeyPatch):
    """
    Test that the shape environment is correctly serialized and preserved
    when loading from cache.
    """
    with monkeypatch.context() as m:
        args = (torch.randn(10, 10),)
        with tempfile.TemporaryDirectory() as tmpdirname:
            m.setenv("VLLM_CACHE_ROOT", tmpdirname)
            m.setenv("VLLM_USE_AOT_COMPILE", "1")
            vllm_config = make_vllm_config()
            with use_vllm_config(vllm_config):
                compiled_mod = CompiledMod(vllm_config=vllm_config)
                compiled_mod(*args)
                artifacts = compiled_mod.aot_compiled_fn._artifacts
                guards_string = artifacts.compiled_fn.shape_env.format_guards()
                assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"

            m.setenv("VLLM_FORCE_AOT_LOAD", "1")
            vllm_config = make_vllm_config()
            with use_vllm_config(vllm_config):
                compiled_mod = CompiledMod(vllm_config=vllm_config)
                compiled_mod(*args)
                artifacts = compiled_mod.aot_compiled_fn._artifacts
                guards_string = artifacts.compiled_fn.shape_env.format_guards()
                assert guards_string == " - s77 <= 42\n - Eq(Mod(s77, 2), 0)"
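
For readers of test_shape_env above: the expected guard string is just the two assertions in reference_fn rendered by the shape environment, with s77 being Dynamo's internal symbol for the dynamic dimension 0 of x (the exact symbol name is a detail of the torch version pinned by the skipif):

# assert x.shape[0] <= 42      ->  " - s77 <= 42"
# assert x.shape[0] % 2 == 0   ->  " - Eq(Mod(s77, 2), 0)"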

View File

@ -22,6 +22,7 @@ ALLOWED_FILES = {
"vllm/multimodal/hasher.py",
"vllm/transformers_utils/config.py",
"vllm/model_executor/models/registry.py",
"vllm/compilation/caching.py",
"tests/utils_/test_utils.py",
"tests/tokenization/test_cached_tokenizer.py",
"vllm/distributed/utils.py",

View File

@ -3,6 +3,7 @@
import ast
import dataclasses
import hashlib
import os
import pprint
import time
@ -25,6 +26,7 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import is_torch_equal_or_newer, resolve_obj_by_qualname
from .caching import VllmSerializableFunction
from .compiler_interface import (
CompilerInterface,
EagerAdaptor,
@ -195,6 +197,7 @@ class CompilerManager:
# there can be multiple graphs due to piecewise compilation.
now = time.time()
elapsed = now - compilation_start_time
compilation_config.compilation_time += elapsed
if runtime_shape is None:
logger.info(
"Directly load the compiled graph(s) for dynamic shape "
@ -549,7 +552,11 @@ class VllmBackend:
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager
def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
def __call__(
self, graph: fx.GraphModule, example_inputs
) -> VllmSerializableFunction:
from .caching import _compute_code_hash, compilation_config_hash_factors
vllm_config = self.vllm_config
if not self.compilation_config.cache_dir:
# no provided cache dir, generate one based on the known factors
@ -557,39 +564,11 @@ class VllmBackend:
# the cache dir will be the same so that we can reuse the compiled
# graph.
factors = []
# 0. factors come from the env, for example, The values of
# VLLM_PP_LAYER_PARTITION will affect the computation graph.
env_hash = envs.compute_hash()
factors.append(env_hash)
# 1. factors come from the vllm_config (it mainly summarizes how the
# model is created)
config_hash = vllm_config.compute_hash()
factors.append(config_hash)
factors = compilation_config_hash_factors(vllm_config)
# 2. factors come from the code files that are traced by Dynamo (
# it mainly summarizes how the model is used in forward pass)
forward_code_files = list(sorted(self.compilation_config.traced_files))
code_hash = _compute_code_hash(self.compilation_config.traced_files)
self.compilation_config.traced_files.clear()
logger.debug(
"Traced files (to be considered for compilation cache):\n%s",
"\n".join(forward_code_files),
)
hash_content = []
for filepath in forward_code_files:
hash_content.append(filepath)
if filepath == "<string>":
# This means the function was dynamically generated, with
# e.g. exec(). We can't actually check these.
continue
with open(filepath) as f:
hash_content.append(f.read())
import hashlib
code_hash = hashlib.md5(
"\n".join(hash_content).encode(), usedforsecurity=False
).hexdigest()
factors.append(code_hash)
# 3. compiler hash
@ -695,7 +674,9 @@ class VllmBackend:
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE
or not self.compilation_config.cudagraph_copy_inputs
):
return self.split_gm
return VllmSerializableFunction(
graph, example_inputs, self.prefix, self.split_gm
)
# if we need to copy input buffers for cudagraph
from torch._guards import detect_fake_mode
@ -740,4 +721,6 @@ class VllmBackend:
list_args[index] = static_tensor
return self.split_gm(*list_args)
return copy_and_call
return VllmSerializableFunction(
graph, example_inputs, self.prefix, copy_and_call
)
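
Illustrative only (graph, example_inputs, and vllm_backend below are placeholders for the objects handled in __call__ above; VllmSerializableFunction is defined in vllm/compilation/caching.py below): the backend's return value behaves exactly like the previous bare callable, but now also exposes the serialization hooks that AOT compile needs. Serialization stores the Dynamo graph plus example inputs, not the optimized callable itself.

fn = vllm_backend(graph, example_inputs)      # a VllmSerializableFunction
out = fn(*example_inputs)                     # same call path as split_gm before
blob = VllmSerializableFunction.serialize_compile_artifacts(fn)
restored = VllmSerializableFunction.deserialize_compile_artifacts(blob)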

vllm/compilation/caching.py (new file, 176 lines)
View File

@ -0,0 +1,176 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import hashlib
import inspect
import pickle
from unittest.mock import patch

import torch
from torch.utils import _pytree as pytree

import vllm.envs as envs
from vllm.config import VllmConfig, get_current_vllm_config
from vllm.logger import init_logger

try:
    from torch._dynamo.aot_compile import SerializableCallable
except ImportError:
    SerializableCallable = object

assert isinstance(SerializableCallable, type)

logger = init_logger(__name__)


class VllmSerializableFunction(SerializableCallable):
    """
    A wrapper around a compiled function by vllm. It will forward the tensor
    inputs to the compiled function and return the result.

    It also implements a serialization interface to support PyTorch's precompile
    with custom backend, so that we can save and load the compiled function on
    disk. There's no need to wrap around the compiled function if we don't want
    to serialize them in particular cases.

    Right now serialization for the custom backend is done via
    serializing the Dynamo fx graph plus example inputs.
    """

    def __init__(self, graph_module, example_inputs, prefix, optimized_call):
        assert isinstance(graph_module, torch.fx.GraphModule)
        self.graph_module = graph_module
        self.example_inputs = example_inputs
        self.prefix = prefix
        self.optimized_call = optimized_call
        self.shape_env = None
        sym_input = next(
            (i for i in self.example_inputs if isinstance(i, torch.SymInt)), None
        )
        if sym_input is not None:
            self.shape_env = sym_input.node.shape_env

    def __call__(self, *args, **kwargs):
        return self.optimized_call(*args, **kwargs)

    @classmethod
    def serialize_compile_artifacts(
        cls, compiled_fn: "VllmSerializableFunction"
    ) -> bytes:
        import sympy
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler, Options

        state = compiled_fn.__dict__.copy()
        state.pop("optimized_call")
        state.pop("shape_env")
        for node in state["graph_module"].graph.nodes:
            node.meta.pop("source_fn_stack", None)
            node.meta.pop("nn_module_stack", None)

        graph_reducer_override = GraphPickler.reducer_override

        def _graph_reducer_override(self, obj):
            if (
                inspect.isclass(obj)
                and issubclass(obj, sympy.Function)
                and hasattr(obj, "_torch_unpickler")
            ):
                return obj._torch_unpickler, (obj._torch_handler_name,)
            if isinstance(obj, FakeTensorMode):
                return type(None), ()
            return graph_reducer_override(self, obj)

        # Mask off tensor inputs since they are large and not needed.
        state["example_inputs"] = pytree.tree_map_only(
            torch.Tensor, lambda _: None, state["example_inputs"]
        )
        with patch.object(GraphPickler, "reducer_override", _graph_reducer_override):
            state["graph_module"] = GraphPickler.dumps(
                state["graph_module"], Options(ops_filter=None)
            )
            state["example_inputs"] = GraphPickler.dumps(state["example_inputs"])
        return pickle.dumps(state)

    @classmethod
    def deserialize_compile_artifacts(cls, data: bytes) -> "VllmSerializableFunction":
        from torch._guards import TracingContext, tracing
        from torch._subclasses import FakeTensorMode
        from torch.fx._graph_pickler import GraphPickler
        from torch.fx.experimental.symbolic_shapes import ShapeEnv

        from vllm.compilation.backends import VllmBackend

        state = pickle.loads(data)
        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
        state["graph_module"] = GraphPickler.loads(state["graph_module"], fake_mode)
        state["example_inputs"] = GraphPickler.loads(state["example_inputs"], fake_mode)
        vllm_backend = VllmBackend(get_current_vllm_config(), state["prefix"])

        def optimized_call(*example_inputs):
            """
            On the first run of the optimized call, we rerun the compiler
            backend which should result in a cache hit. After the backend
            call returns, we just do a one-time replacement of the optimized
            call with the compiled function, so that subsequent calls are on
            the AOT compiled path.
            """
            compile_inputs = [
                inp or example_inputs[i] for i, inp in enumerate(fn.example_inputs)
            ]
            with tracing(TracingContext(fake_mode)):
                fn.optimized_call = vllm_backend(
                    state["graph_module"], compile_inputs
                ).optimized_call
            return fn.optimized_call(*example_inputs)

        fn = cls(**state, optimized_call=optimized_call)
        return fn

    @property
    def co_name(self):
        """
        Used for depyf debugging.
        """
        return "VllmSerializableFunction"


def compilation_config_hash_factors(vllm_config: VllmConfig) -> list[str]:
    factors = []
    # 0. factors come from the env, for example, The values of
    # VLLM_PP_LAYER_PARTITION will affect the computation graph.
    env_hash = envs.compute_hash()
    factors.append(env_hash)

    # 1. factors come from the vllm_config (it mainly summarizes how the
    # model is created)
    config_hash = vllm_config.compute_hash()
    factors.append(config_hash)
    return factors


def _compute_code_hash_with_content(file_contents: dict[str, str]) -> str:
    items = list(sorted(file_contents.items(), key=lambda x: x[0]))
    hash_content = []
    for filepath, content in items:
        hash_content.append(filepath)
        if filepath == "<string>":
            # This means the function was dynamically generated, with
            # e.g. exec(). We can't actually check these.
            continue
        hash_content.append(content)
    return hashlib.md5(
        "\n".join(hash_content).encode(), usedforsecurity=False
    ).hexdigest()


def _compute_code_hash(files: set[str]) -> str:
    logger.debug(
        "Traced files (to be considered for compilation cache):\n%s", "\n".join(files)
    )
    file_contents = {}
    for filepath in files:
        if filepath == "<string>":
            file_contents[filepath] = ""
        else:
            with open(filepath) as f:
                file_contents[filepath] = f.read()
    return _compute_code_hash_with_content(file_contents)
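
A minimal sketch (not part of the diff; assumes a working vLLM install and uses a toy traced function) of the wrapper's pass-through behavior and of the content-hash helper's properties:

import torch

from vllm.compilation.caching import (
    VllmSerializableFunction,
    _compute_code_hash_with_content,
)

# The wrapper simply forwards tensor inputs to the optimized callable.
gm = torch.fx.symbolic_trace(lambda x: x + 1)
fn = VllmSerializableFunction(gm, [torch.randn(4)], "model", gm.forward)
assert torch.allclose(fn(torch.ones(4)), torch.ones(4) + 1)

# The content hash is order-insensitive and skips "<string>" bodies, so
# dynamically exec()'d code cannot be checked and never contributes content.
h1 = _compute_code_hash_with_content({"a.py": "x = 1", "<string>": "ignored"})
h2 = _compute_code_hash_with_content({"<string>": "different", "a.py": "x = 1"})
assert h1 == h2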

View File

@ -199,6 +199,7 @@ class InductorStandaloneAdaptor(CompilerInterface):
if compiler_config is not None:
current_config.update(compiler_config)
set_inductor_config(current_config, runtime_shape)
set_functorch_config()
if isinstance(runtime_shape, int):
dynamic_shapes = "from_example_inputs"
@ -307,6 +308,7 @@ class InductorAdaptor(CompilerInterface):
current_config["fx_graph_remote_cache"] = False
set_inductor_config(current_config, runtime_shape)
set_functorch_config()
# inductor can inplace modify the graph, so we need to copy it
# see https://github.com/pytorch/pytorch/issues/138980
@ -596,6 +598,10 @@ def set_inductor_config(config, runtime_shape):
)
def set_functorch_config():
torch._functorch.config.bundled_autograd_cache = False
class EagerAdaptor(CompilerInterface):
name = "eager"

View File

@ -2,7 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import contextlib
import hashlib
import inspect
import os
import sys
from typing import Callable, Optional, TypeVar, Union, overload
from unittest.mock import patch
@ -11,9 +14,10 @@ import torch.nn as nn
from packaging import version
from torch._dynamo.symbolic_convert import InliningInstructionTranslator
import vllm.envs as envs
from vllm.compilation.counter import compilation_counter
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.logger import init_logger
from vllm.sequence import IntermediateTensors
from vllm.utils import resolve_obj_by_qualname, supports_dynamo
@ -176,6 +180,33 @@ def support_torch_compile(
return cls_decorator_helper
def _model_hash_key(fn) -> str:
import vllm
sha256_hash = hashlib.sha256()
sha256_hash.update(vllm.__version__.encode())
sha256_hash.update(fn.__qualname__.encode())
sha256_hash.update(str(fn.__code__.co_firstlineno).encode())
return sha256_hash.hexdigest()
def _verify_source_unchanged(source_info, vllm_config) -> None:
from .caching import _compute_code_hash, _compute_code_hash_with_content
file_contents = {}
for source in source_info.inlined_sources:
module = sys.modules[source.module]
file = inspect.getfile(module)
vllm_config.compilation_config.traced_files.add(file)
file_contents[file] = source.content
expected_checksum = _compute_code_hash_with_content(file_contents)
actual_checksum = _compute_code_hash(set(file_contents.keys()))
if expected_checksum != actual_checksum:
raise RuntimeError(
"Source code has changed since the last compilation. Recompiling the model."
)
def _support_torch_compile(
cls: _T,
dynamic_arg_dims: dict[str, Union[int, list[int]]],
@ -227,6 +258,64 @@ def _support_torch_compile(
if self.do_not_compile or torch.compiler.is_compiling():
return self.forward(*args, **kwargs)
if getattr(self, "aot_compiled_fn", None) is not None:
return self.aot_compiled_fn(self, *args, **kwargs)
cache_dir = None
aot_compilation_path = None
if envs.VLLM_USE_AOT_COMPILE:
"""
When using torch.compile in AOT mode, we store the cache artifacts
under VLLM_CACHE_ROOT/torch_aot_compile/{hash}/rank_i_j. The {hash}
contains all of the factors except for the source files being
traced through, because we don't actually know which source files
to check at this point (before dynamo runs).
On loading we will actually look at the source files being traced
through. If any source file has changed (compared with the
serialized backend artifacts), then we need to generate a new AOT
compile artifact from scratch.
"""
from .caching import compilation_config_hash_factors
factors: list[str] = compilation_config_hash_factors(self.vllm_config)
factors.append(_model_hash_key(self.forward))
hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
cache_dir = os.path.join(
envs.VLLM_CACHE_ROOT,
"torch_aot_compile",
hash_key,
)
rank = self.vllm_config.parallel_config.rank
dp_rank = self.vllm_config.parallel_config.data_parallel_rank
cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
aot_compilation_path = os.path.join(cache_dir, "model")
try:
with (
set_current_vllm_config(self.vllm_config),
open(aot_compilation_path, "rb") as f,
):
start_monitoring_torch_compile(self.vllm_config)
loaded_fn = torch.compiler.load_compiled_function(f)
_verify_source_unchanged(loaded_fn.source_info(), self.vllm_config)
self.aot_compiled_fn = loaded_fn
except Exception as e:
if os.path.exists(aot_compilation_path):
logger.warning(
"Cannot load aot compilation from path %s, error: %s",
aot_compilation_path,
str(e),
)
if envs.VLLM_FORCE_AOT_LOAD:
raise e
if getattr(self, "aot_compiled_fn", None) is not None:
logger.info(
"Directly load AOT compilation from path %s", aot_compilation_path
)
return self.aot_compiled_fn(self, *args, **kwargs)
# the first compilation needs to have dynamic shapes marked
if len(self.compiled_codes) < 1:
sig = inspect.signature(self.__class__.forward)
@ -275,15 +364,15 @@ def _support_torch_compile(
)
# 2. every time Dynamo sees a function call, it will inline
# the function by calling InliningInstructionTranslator.inline_call
# the function by calling InliningInstructionTranslator.inline_call_
# we hijack this function to know all the functions called
# during Dynamo tracing, and their corresponding files
inline_call = InliningInstructionTranslator.inline_call
inline_call = InliningInstructionTranslator.inline_call_
def patched_inline_call(parent, func, args, kwargs):
code = func.get_code()
def patched_inline_call(self_):
code = self_.f_code
self.vllm_config.compilation_config.traced_files.add(code.co_filename)
return inline_call(parent, func, args, kwargs)
return inline_call(self_)
# Disable the C++ compilation of symbolic shape guards. C++-fication
# of symbolic shape guards can improve guard overhead. But, since
@ -300,13 +389,21 @@ def _support_torch_compile(
with (
patch.object(
InliningInstructionTranslator, "inline_call", patched_inline_call
InliningInstructionTranslator, "inline_call_", patched_inline_call
),
torch._dynamo.config.patch(**dynamo_config_patches),
maybe_use_cudagraph_partition_wrapper(self.vllm_config),
_torch27_patch_tensor_subclasses(),
):
output = self.compiled_callable(*args, **kwargs)
if envs.VLLM_USE_AOT_COMPILE:
self.aot_compiled_fn = self.aot_compile(*args, **kwargs)
output = self.aot_compiled_fn(self, *args, **kwargs)
assert aot_compilation_path is not None
assert cache_dir is not None
os.makedirs(cache_dir, exist_ok=True)
self.aot_compiled_fn.save_compiled_function(aot_compilation_path)
else:
output = self.compiled_callable(*args, **kwargs)
return output
# usually, capturing the model once is enough, and then we can
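
A condensed sketch of where the AOT artifact lands on disk, mirroring the hashing logic above (model and vllm_config are placeholders; _model_hash_key is the helper defined earlier in this file and is assumed importable from vllm.compilation.decorators):

import hashlib
import os

import vllm.envs as envs
from vllm.compilation.caching import compilation_config_hash_factors
from vllm.compilation.decorators import _model_hash_key


def aot_artifact_path(model, vllm_config) -> str:
    # Everything except the traced source files goes into the key, because
    # the traced files are only known after Dynamo has run.
    factors = compilation_config_hash_factors(vllm_config)  # env hash + config hash
    factors.append(_model_hash_key(model.forward))          # vllm version, qualname, lineno
    hash_key = hashlib.sha256(str(factors).encode()).hexdigest()
    rank = vllm_config.parallel_config.rank
    dp_rank = vllm_config.parallel_config.data_parallel_rank
    return os.path.join(
        envs.VLLM_CACHE_ROOT,
        "torch_aot_compile",
        hash_key,
        f"rank_{rank}_{dp_rank}",
        "model",
    )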

View File

@ -10,6 +10,7 @@ from typing import Callable, Optional
import torch
import vllm.envs as envs
from vllm.config import CompilationLevel, CUDAGraphMode, get_current_vllm_config
from vllm.logger import init_logger
@ -44,6 +45,19 @@ class TorchCompileWrapperWithCustomDispatcher:
options = (
get_current_vllm_config().compilation_config.inductor_compile_config
)
if envs.VLLM_USE_AOT_COMPILE:
options = options or {}
# This effectively drops all the guards.
# We need this because the bytecode hook is no longer used to
# drop guards in AOT compile mode.
options["guard_filter_fn"] = lambda guards: [False for _ in guards]
if hasattr(torch._dynamo.config, "enable_aot_compile"):
torch._dynamo.config.enable_aot_compile = True
else:
msg = "torch._dynamo.config.enable_aot_compile is not "
msg += "available. AOT compile is disabled and please "
msg += "upgrade PyTorch version to use AOT compile."
logger.warning(msg)
compiled_callable = torch.compile(
self.forward, fullgraph=True, backend=backend, options=options
@ -61,6 +75,15 @@ class TorchCompileWrapperWithCustomDispatcher:
compilation_level >= CompilationLevel.DYNAMO_ONCE
)
def aot_compile(self, *args, **kwargs):
if not hasattr(self.compiled_callable, "aot_compile"):
raise RuntimeError(
"aot_compile is not supported by the current configuration. "
+ "Please make sure torch.compile is enabled with the latest "
+ f"version of PyTorch (current using torch: {torch.__version__})"
)
return self.compiled_callable.aot_compile((args, kwargs))
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.
NOTE: this function can have additional arguments beyond the forward
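
A minimal sketch of the underlying torch flow the wrapper builds on; the API names are taken from the calls above (aot_compile, save_compiled_function, torch.compiler.load_compiled_function), torch >= 2.10.0.dev is assumed, and the path is hypothetical:

import torch


def f(x):
    return x * 2


torch._dynamo.config.enable_aot_compile = True
compiled = torch.compile(f, fullgraph=True)

# aot_compile takes a single (args, kwargs) tuple describing one example call.
aot_fn = compiled.aot_compile(((torch.randn(4),), {}))
aot_fn.save_compiled_function("/tmp/aot_model")

# Loading round-trips through the serialized artifact; for a bound method the
# module instance is passed explicitly, as the decorator code earlier in this
# diff does with self.aot_compiled_fn(self, *args, **kwargs).
with open("/tmp/aot_model", "rb") as fp:
    loaded = torch.compiler.load_compiled_function(fp)
out = loaded(torch.randn(4))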

View File

@ -89,6 +89,8 @@ if TYPE_CHECKING:
VLLM_TORCH_PROFILER_DIR: Optional[str] = None
VLLM_TORCH_PROFILER_RECORD_SHAPES: bool = False
VLLM_TORCH_PROFILER_WITH_PROFILE_MEMORY: bool = False
VLLM_USE_AOT_COMPILE: bool = False
VLLM_FORCE_AOT_LOAD: bool = False
VLLM_TORCH_PROFILER_WITH_STACK: bool = True
VLLM_TORCH_PROFILER_WITH_FLOPS: bool = False
VLLM_USE_TRITON_AWQ: bool = False
@ -235,6 +237,13 @@ def maybe_convert_bool(value: Optional[str]) -> Optional[bool]:
return bool(int(value))
def use_aot_compile() -> bool:
from vllm.utils import is_torch_equal_or_newer
default_value = "1" if is_torch_equal_or_newer("2.10.0.dev") else "0"
return os.environ.get("VLLM_USE_AOT_COMPILE", default_value) == "1"
def env_with_choices(
env_name: str,
default: Optional[str],
@ -494,6 +503,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
# Dump fx graphs to the given directory.
# It will override CompilationConfig.debug_dump_path if set.
"VLLM_DEBUG_DUMP_PATH": lambda: os.environ.get("VLLM_DEBUG_DUMP_PATH", None),
# Feature flag to enable/disable AOT compilation. This will ensure
# compilation is done in warmup phase and the compilation will be
# reused in subsequent calls.
"VLLM_USE_AOT_COMPILE": use_aot_compile,
# Force vllm to always load AOT compiled models from disk. Failure
# to load will result in a hard error when this is enabled.
# Will be ignored when VLLM_USE_AOT_COMPILE is disabled.
"VLLM_FORCE_AOT_LOAD": lambda: os.environ.get("VLLM_FORCE_AOT_LOAD", "0") == "1",
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK": lambda: int(os.environ.get("LOCAL_RANK", "0")),