[torch.compile] integration with compilation control (#9058)
@@ -121,7 +121,9 @@ steps:
  - vllm/core/
  - tests/distributed
  - tests/spec_decode/e2e/test_integration_dist_tp4
  - tests/compile
  commands:
  - pytest -v -s compile/test_basic_correctness.py
  - pytest -v -s distributed/test_pynccl.py
  - pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py

@@ -231,14 +233,16 @@ steps:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_full_graph_smoke.py
  - pytest -v -s compile/test_basic_correctness.py

- label: "PyTorch Fullgraph Test" # 18min
  source_file_dependencies:
  - vllm/
  - tests/compile
  commands:
  - pytest -v -s compile/test_full_graph.py
# TODO: rewrite as comparison tests, and fix symbolic shape
# for quantization ops.
# - label: "PyTorch Fullgraph Test" # 18min
#   source_file_dependencies:
#   - vllm/
#   - tests/compile
#   commands:
#   - pytest -v -s compile/test_full_graph.py

- label: Kernels Test %N # 1h each
  mirror_hardwares: [amd]
@@ -394,7 +398,7 @@ steps:
  - tests/distributed/
  - vllm/compilation
  commands:
  - pytest -v -s ./compile/test_full_graph_multi_gpu.py
  - pytest -v -s ./compile/test_basic_correctness.py
  - pytest -v -s ./compile/test_wrapper.py
  - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
  - TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus
tests/compile/test_basic_correctness.py (new file, 48 lines)
@@ -0,0 +1,48 @@
from typing import Dict, List, Optional

import pytest

from vllm.compilation.levels import CompilationLevel
from vllm.utils import cuda_device_count_stateless

from ..utils import compare_all_settings


# we cannot afford testing the full Cartesian product
# of all models and all levels
@pytest.mark.parametrize(
    "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
    [
        ("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
         True),
        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
         ["--quantization", "compressed-tensors"
          ], 1, 1, "FLASH_ATTN", "generate", True),
        ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
        # TODO: add multi-modality test for llava
        ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
    ])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
                             method, fullgraph):
    # this test is run under multiple suites, with different GPUs.
    # make sure we only run the test with the correct CUDA devices.
    # don't use "<", as it will duplicate the tests.
    if cuda_device_count_stateless() != pp_size * tp_size:
        pytest.skip("Not correct CUDA devices for the test.")
    import os
    os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
    if not fullgraph:
        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
    all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"]
                + ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3
    # don't test the VLLM_TORCH_COMPILE_LEVEL == 3 (Inductor) case:
    # Inductor will change the output, so we cannot compare the results.
    all_envs: List[Optional[Dict[str, str]]] = [{
        "VLLM_TORCH_COMPILE_LEVEL":
        str(level)
    } for level in [
        CompilationLevel.NO_COMPILATION,
        CompilationLevel.DYNAMO_AS_IS,
        CompilationLevel.DYNAMO_ONCE,
    ]]
    compare_all_settings(model, all_args, all_envs, method=method)
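The test above compares outputs across compilation levels through `compare_all_settings`, which manages the processes for each setting. For a rough idea of what that comparison boils down to, here is a standalone sketch using the offline `LLM` API; the model choice, the single-GPU setup, and running both levels in one process are illustrative simplifications, not part of the test suite.

```python
import os

from vllm.compilation.levels import CompilationLevel


def generate_with_level(level: int, prompts):
    # the level must be set before the engine (and its model) is built
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(level)
    from vllm import LLM, SamplingParams
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)
    params = SamplingParams(temperature=0.0, max_tokens=16)
    outputs = llm.generate(prompts, params)
    return [out.outputs[0].text for out in outputs]


if __name__ == "__main__":
    prompts = ["Hello, my name is", "The president of the United States is"]
    # in practice each level would run in its own process (as the test does);
    # one process is used here only to keep the sketch short
    baseline = generate_with_level(CompilationLevel.NO_COMPILATION, prompts)
    compiled = generate_with_level(CompilationLevel.DYNAMO_ONCE, prompts)
    assert baseline == compiled, "outputs diverged under compilation"
```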
@@ -1,13 +1,20 @@
import pytest

from vllm.compilation.backends import vllm_backend
from vllm.compilation.levels import CompilationLevel

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
@pytest.mark.parametrize(
    "optimization_level",
    [CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
@fork_new_process_for_each_test
def test_full_graph(model_info, optimization_level):
    model = model_info[0]
    model_kwargs = model_info[1]
    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
    check_full_graph_support(model,
                             model_kwargs,
                             optimization_level,
                             tp_size=1)

@@ -1,22 +0,0 @@
import pytest

from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless

from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
    model = model_info[0]
    model_kwargs = model_info[1]

    # Skip the test if there are not enough CUDA devices.
    if cuda_device_count_stateless() < tp_size:
        pytest.skip("Not enough CUDA devices for the test.")

    check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)
@@ -1,13 +0,0 @@
import pytest

from vllm.compilation.backends import vllm_backend

from .utils import TEST_MODELS_SMOKE, check_full_graph_support


@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
    model = model_info[0]
    model_kwargs = model_info[1]
    check_full_graph_support(model, model_kwargs, backend, tp_size=1)
@@ -4,16 +4,9 @@ import torch

from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip

TEST_MODELS_SMOKE = [
    ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
        "quantization": "compressed-tensors"
    }),
    ("meta-llama/Meta-Llama-3-8B", {}),
]

TEST_MODELS = [
    ("facebook/opt-125m", {}),
    ("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@@ -68,20 +61,21 @@ if not is_hip() and is_quant_method_supported("awq"):
    }))


def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
def check_full_graph_support(model,
                             model_kwargs,
                             optimization_level,
                             tp_size=1):
    # make sure these models can be captured in full graph mode
    if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
        os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
        os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
    os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"

    # Inductor doesn't support fp8/gptq_marlin_24 yet.
    quantization = model_kwargs.get("quantization")
    if (quantization == "fp8" or quantization == "gptq_marlin"
            or quantization == "gptq_marlin_24") and backend != "eager":
            or quantization == "gptq_marlin_24"
        ) and optimization_level >= CompilationLevel.INDUCTOR:
        return

    set_torch_compile_backend(backend)

    prompts = [
        "Hello, my name is",
        "The president of the United States is",

@@ -5,9 +5,11 @@ import tempfile

import depyf

from vllm.compilation.levels import CompilationLevel

# disable the custom dispatcher and let Dynamo take over
# all the control
os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)

temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):

@@ -1,5 +1,7 @@
import os

from vllm.compilation.levels import CompilationLevel

from ..utils import compare_two_settings

# --enforce-eager on TPU causes graph compilation
@@ -9,8 +11,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"


def test_custom_dispatcher():
    compare_two_settings("google/gemma-2b",
                         arg1=["--enforce-eager"],
                         arg2=["--enforce-eager"],
                         env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
                         env2={})
    compare_two_settings(
        "google/gemma-2b",
        arg1=["--enforce-eager"],
        arg2=["--enforce-eager"],
        env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
        env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})

@@ -1,8 +1,17 @@
import copy
import operator
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
import torch.fx as fx

from vllm.logger import init_logger

from .compile_context import get_compile_context
from .levels import CompilationLevel

logger = init_logger(__name__)


def fix_functionalization(graph: fx.Graph):
    """
@@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
    # print(graph.python_code(root_module="self", verbose=True).src, file=f)


def vllm_backend(graph, example_inputs):
def wrap_inductor(graph, example_inputs, additional_inductor_config):
    from torch._inductor import config
    current_config = config.shallow_copy_dict()
    from torch._inductor.compile_fx import compile_fx

    if additional_inductor_config is not None:
        current_config.update(additional_inductor_config)
    if current_config['post_grad_custom_post_pass'] is not None:
        logger.warning(
            "post_grad_custom_post_pass is already set in the config. "
            "Overwriting it with fix_functionalization")
    current_config['post_grad_custom_post_pass'] = fix_functionalization
    return compile_fx(graph, example_inputs, config_patches=current_config)


def vllm_backend(
        graph,
        example_inputs,
        additional_inductor_config: Optional[Dict] = None) -> Callable:

    context = get_compile_context()
    context = copy.deepcopy(context) if context is not None else []
    sizes_to_specialize: List[int] = context

    # flags for all the seen shapes, whether we need to specialize
    runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}

    # if we need to specialize, the compiled graph for that shape
    runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}

    # this is the first compilation: we compile a graph with dynamic
    # shape, as the caller will mark the first dimension as dynamic
    logger.info("Compiling a graph for general shapes")
    graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
                                             additional_inductor_config)

    # TODO: Dynamo does not pass all dynamic shapes.
    # Need to investigate why. It works now because all the dynamic
    # shapes have the same value, and either of them can be used.
    sym_shape_indices = [
        i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
    ]

    first_run = True

    # this is the function we finally return to Dynamo to run
    def compiled_graph_wrapper(*args):

        runtime_shapes: Tuple[int,
                              ...] = tuple(args[i] for i in sym_shape_indices)

        nonlocal first_run
        nonlocal runtime_shapes_to_compile_flags
        nonlocal runtime_shapes_to_compiled_graph

        if first_run:
            # the first compilation is for profiling, we directly run it
            first_run = False
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compile_flags:
            # we haven't seen this shape before:
            # query whether we need to specialize for this shape.
            # we only specialize for the first dimension.
            # TODO: investigate if any model needs to specialize
            # beyond the first dimension
            runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
                0] in sizes_to_specialize

        if not runtime_shapes_to_compile_flags[runtime_shapes]:
            # we don't need to specialize for this shape
            return graph_for_symbolic_shape(*args)

        if runtime_shapes not in runtime_shapes_to_compiled_graph:
            # we need to specialize for this shape, and we haven't
            # compiled a graph for it yet, so compile one now
            logger.info("Compiling a graph for shapes %s", runtime_shapes)
            runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
                graph, args, additional_inductor_config)

        return runtime_shapes_to_compiled_graph[runtime_shapes](*args)

    return compiled_graph_wrapper


def select_default_backend(level: int) -> Union[str, Callable]:
    if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
        backend = "eager"
        return backend
    assert level in [
        CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
    ], f"Invalid level {level}"

    from vllm.compilation.backends import vllm_backend
    from vllm.plugins import get_inductor_additional_configs
    additional_configs = get_inductor_additional_configs()

    if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
        if "max_autotune" in additional_configs and not additional_configs[
                "max_autotune"]:
            logger.warning(
                "max_autotune is disabled, but is overridden by level %s",
                CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
        additional_configs['max_autotune'] = True

    from functools import partial
    backend = partial(vllm_backend,
                      additional_inductor_config=additional_configs)

    return backend
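The dispatch strategy inside `compiled_graph_wrapper` above, one general graph plus lazily compiled and cached specializations for selected sizes, can be sketched independently of Dynamo and Inductor. The snippet below is illustrative only; the names are made up and it is not part of vLLM.

```python
from typing import Callable, Dict, Sequence


def make_size_dispatcher(general_fn: Callable,
                         compile_for_size: Callable[[int], Callable],
                         sizes_to_specialize: Sequence[int]) -> Callable:
    # mirrors runtime_shapes_to_compiled_graph: size -> compiled callable
    specialized: Dict[int, Callable] = {}

    def dispatcher(x, size: int):
        if size not in sizes_to_specialize:
            # fall back to the graph compiled for general (symbolic) shapes
            return general_fn(x)
        if size not in specialized:
            # first time this size is seen: compile and cache a specialization
            specialized[size] = compile_for_size(size)
        return specialized[size](x)

    return dispatcher


# toy usage: pretend "compilation" just produces a size-specific callable
dispatch = make_size_dispatcher(
    general_fn=lambda x: x * 2,
    compile_for_size=lambda size: (lambda x: x * 2),
    sizes_to_specialize=[1, 2, 4, 8])
assert dispatch(3.0, size=3) == 6.0  # general path
assert dispatch(3.0, size=4) == 6.0  # specialized path, compiled once and cached
```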
vllm/compilation/compile_context.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from contextlib import contextmanager
from typing import Any

_compile_context: Any = None


def get_compile_context() -> Any:
    """Get the current compile context."""
    return _compile_context


@contextmanager
def set_compile_context(context: Any):
    """A context manager that stores the current compile context,
    usually a list of sizes to specialize.
    """
    global _compile_context
    prev_context = _compile_context
    _compile_context = context
    try:
        yield
    finally:
        _compile_context = prev_context
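A short usage sketch of the new context manager, mirroring how the GPU model runner later in this diff passes `batch_size_capture_list` so that `vllm_backend` knows which sizes to specialize; the batch sizes here are illustrative.

```python
from vllm.compilation.compile_context import (get_compile_context,
                                               set_compile_context)

batch_sizes = [1, 2, 4, 8]  # e.g. the CUDA-graph capture sizes

with set_compile_context(batch_sizes):
    # any torch.compile backend invoked inside this block (such as
    # vllm_backend) can read the sizes via get_compile_context()
    assert get_compile_context() == [1, 2, 4, 8]

assert get_compile_context() is None  # restored on exit
```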
vllm/compilation/decorators.py (new file, 85 lines)
@@ -0,0 +1,85 @@
from typing import List, Optional, Union

import torch

import vllm.envs as envs
from vllm.attention import AttentionMetadata
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo


def support_compile_llama_style(cls: type):
    """
    A decorator to add support for compiling the forward method of a class.
    If a module's **forward signature** is compatible with llama, this
    decorator can be used to enable the compilation of the forward method.
    """

    # for CompilationLevel.DYNAMO_AS_IS, the upper-level model runner
    # will handle the compilation, so we don't need to do anything here.
    if envs.VLLM_TORCH_COMPILE_LEVEL in [
            CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
    ] or not supports_dynamo():
        return cls

    # take care of method resolution order:
    # make sure super().__init__ is called on the base class
    # other than TorchCompileWrapperWithCustomDispatcher
    cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )

    old_init = cls.__init__

    def __init__(self, *args, **kwargs):
        old_init(self, *args, **kwargs)
        TorchCompileWrapperWithCustomDispatcher.__init__(self)

    cls.__init__ = __init__

    def __call__(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # torch.compiler.is_compiling() means we are inside the compilation,
        # e.g. TPU has the compilation logic in the model runner, so we don't
        # need to compile the model inside.
        if torch.compiler.is_compiling():
            return self.forward(input_ids, positions, kv_caches, attn_metadata,
                                intermediate_tensors, inputs_embeds)

        # the first compilation needs to have dynamic shapes marked
        if len(self.compiled_codes) < 1:
            if input_ids is not None:
                torch._dynamo.mark_dynamic(input_ids, 0)
            torch._dynamo.mark_dynamic(positions, 0)
            if inputs_embeds is not None:
                torch._dynamo.mark_dynamic(inputs_embeds, 0)
            if intermediate_tensors is not None:
                for tensors in intermediate_tensors.tensors.values():
                    torch._dynamo.mark_dynamic(tensors, 0)

        # if we don't use the custom dispatcher, we can directly call the
        # compiled function and let torch.compile handle the dispatching,
        # with the overhead of guard evaluation and recompilation.
        if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
            return self.compiled_callable(input_ids, positions, kv_caches,
                                          attn_metadata, intermediate_tensors,
                                          inputs_embeds)

        # usually, capturing the model once is enough, and then we can
        # dispatch to the compiled code directly, without going through
        # the Dynamo guard mechanism.
        with self.dispatch_to_code(0):
            model_output = self.forward(input_ids, positions, kv_caches,
                                        attn_metadata, intermediate_tensors,
                                        inputs_embeds)
            return model_output

    cls.__call__ = __call__
    return cls
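As a rough illustration of what the decorator expects, the toy module below has the llama-style forward signature used by `__call__` above. It is not a real vLLM model, and with the default `VLLM_TORCH_COMPILE_LEVEL` (NO_COMPILATION) the decorator is simply a no-op.

```python
from typing import List, Optional, Union

import torch
import torch.nn as nn

from vllm.attention import AttentionMetadata
from vllm.compilation.decorators import support_compile_llama_style
from vllm.sequence import IntermediateTensors


@support_compile_llama_style
class ToyLlamaStyleModel(nn.Module):
    """Hypothetical module whose forward matches the expected signature."""

    def __init__(self, hidden_size: int = 16):
        super().__init__()
        self.embed = nn.Embedding(128, hidden_size)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        positions: torch.Tensor,
        kv_caches: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        # toy computation only; a real model would run its decoder layers here
        hidden = self.embed(input_ids) if inputs_embeds is None else inputs_embeds
        return hidden + positions.unsqueeze(-1)
```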
vllm/compilation/levels.py (new file, 9 lines)
@@ -0,0 +1,9 @@
# constants for the levels of the compilation process


class CompilationLevel:
    NO_COMPILATION = 0
    DYNAMO_AS_IS = 1
    DYNAMO_ONCE = 2
    INDUCTOR = 3
    INDUCTOR_MAX_AUTOTUNE = 4
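The level is selected through the `VLLM_TORCH_COMPILE_LEVEL` environment variable added to `vllm/envs.py` below, so it has to be set before the engine builds the model. A minimal sketch:

```python
import os

from vllm.compilation.levels import CompilationLevel

# capture the model once with Dynamo and dispatch to the captured code
# afterwards; no Inductor compilation involved
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
```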
@@ -3,12 +3,14 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List
from typing import Callable, List, Optional

import torch

import vllm.envs as envs

from .levels import CompilationLevel


class TorchCompileWrapperWithCustomDispatcher:
    """
@@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher:
    `torch.compile` over the forward method.
    """

    def __init__(self, compiled_callable: Callable):
    def __init__(self, compiled_callable: Optional[Callable] = None):

        if compiled_callable is None:
            # default compilation settings
            # compiling the forward method

            # choose the compile backend

            # if the user has set the backend, use it
            from vllm.plugins import get_torch_compile_backend
            backend = get_torch_compile_backend()
            if backend is None:
                from vllm.compilation.backends import select_default_backend
                backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)

            compiled_callable = torch.compile(
                self.forward,
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                backend=backend)

        self.compiled_callable = compiled_callable
        self.original_code_object = self.__class__.forward.__code__
        self.compiled_codes: List[CodeType] = []
@@ -33,7 +54,7 @@ class TorchCompileWrapperWithCustomDispatcher:
        # subclasses can use this to switch between the custom dispatcher
        # and the default Dynamo guard mechanism.
        self.use_custom_dispatcher: bool = \
            envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER
            envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE

    def __call__(self, *args, **kwargs):
        """Implement the dispatch logic here, beyond the torch.compile level.
vllm/envs.py (16 lines)
@@ -65,6 +65,7 @@ if TYPE_CHECKING:
    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
    VLLM_SKIP_P2P_CHECK: bool = False
    VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
    VLLM_TORCH_COMPILE_LEVEL: int = 0


def get_default_cache_root():
@@ -198,23 +199,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
    lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
             ("true", "1")),

    # Internal flag to enable Dynamo graph capture
    "VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
    lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
    "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
    lambda:
    (os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
     ("true", "1")),

    # Internal flag to control whether we use the custom op
    # or the native PyTorch implementation
    "VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
    lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),

    # Internal flag to enable Dynamo fullgraph capture
    "VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
    lambda: bool(
        os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
    "VLLM_TORCH_COMPILE_LEVEL":
    lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),

    # local rank of the process in the distributed setting, used to determine
    # the GPU device id
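For readers tracking the removed test flags: their roles are folded into the single level knob. The mapping below is a reading of the other hunks in this commit (model_runner.py, wrapper.py, custom_op.py), written out for illustration only; it is not an official API.

```python
from vllm.compilation.levels import CompilationLevel

# rough correspondence between the removed flags and the new single knob
legacy_flag_to_level = {
    # the model runner used to wrap the whole model in torch.compile when
    # VLLM_TEST_DYNAMO_GRAPH_CAPTURE was set
    "VLLM_TEST_DYNAMO_GRAPH_CAPTURE": CompilationLevel.DYNAMO_AS_IS,
    # the custom dispatcher (skip Dynamo guards after the first capture)
    # is now enabled whenever the level is at least DYNAMO_ONCE
    "VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": CompilationLevel.DYNAMO_ONCE,
    # custom ops fall back to their native PyTorch implementations once
    # the level reaches INDUCTOR, replacing VLLM_TEST_COMPILE_NO_CUSTOM_OPS
    "VLLM_TEST_COMPILE_NO_CUSTOM_OPS": CompilationLevel.INDUCTOR,
}
```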
@@ -1,6 +1,7 @@
import torch.nn as nn

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu

@@ -55,7 +56,7 @@ class CustomOp(nn.Module):
        # NOTE(woosuk): Here we assume that vLLM was built for only one
        # specific backend. Currently, we do not support dynamic dispatching.

        if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
        if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
            return self.forward_native

        if is_hip():
@@ -21,6 +21,7 @@ from torch import nn
from transformers import Gemma2Config

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_compile_llama_style
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@@ -238,6 +239,7 @@ class Gemma2DecoderLayer(nn.Module):
        return hidden_states, residual


@support_compile_llama_style
class Gemma2Model(nn.Module):

    def __init__(
@@ -28,6 +28,7 @@ from torch import nn
from transformers import LlamaConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_compile_llama_style
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size)
@@ -265,6 +266,7 @@ class LlamaDecoderLayer(nn.Module):
        return hidden_states, residual


@support_compile_llama_style
class LlamaModel(nn.Module):

    def __init__(
@@ -365,6 +365,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                input_ids = None
                inputs_embeds = None
            else:
                # always pass the input via `inputs_embeds`
                # to make sure the computation graph is consistent
                image_input = self._parse_and_validate_image_input(**kwargs)

                if image_input is not None:
@@ -375,10 +377,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
                    inputs_embeds = merge_multimodal_embeddings(
                        input_ids, inputs_embeds, vision_embeddings,
                        self.config.image_token_index)

                    input_ids = None
                else:
                    inputs_embeds = None
                    inputs_embeds = self.language_model.model.get_input_embeddings(
                        input_ids)
                    input_ids = None

        hidden_states = self.language_model.model(input_ids,
                                                  positions,
@@ -1,7 +1,21 @@
import os

import torch

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_torch_compile_backend

from .interface import Platform, PlatformEnum

if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
    os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)

assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\
    "TPU does not support Inductor."

set_torch_compile_backend("openxla")


class TpuPlatform(Platform):
    _enum = PlatformEnum.TPU
@@ -1,5 +1,5 @@
import logging
from typing import Callable, Optional, Union
from typing import Callable, Dict, Optional, Union

import vllm.envs as envs

@@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]):

def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
    return _torch_compile_backend


_inductor_additional_configs: Dict = {}


def set_inductor_additional_configs(configs: Dict):
    global _inductor_additional_configs
    _inductor_additional_configs = configs


def get_inductor_additional_configs() -> Dict:
    return _inductor_additional_configs
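A hedged usage sketch for the plugin hooks touched here: a deployment can either force a specific `torch.compile` backend, or keep the default backend selection and pass extra Inductor options that `wrap_inductor` forwards to `compile_fx`. The specific option shown (`max_autotune`) is the one referenced in `select_default_backend` above.

```python
import os

from vllm.compilation.levels import CompilationLevel
from vllm.plugins import (set_inductor_additional_configs,
                          set_torch_compile_backend)

# option 1: take over the backend entirely (e.g. plain eager for debugging)
set_torch_compile_backend("eager")

# option 2: keep the default backend selection but tweak Inductor;
# the level must be high enough for Inductor to be used at all
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.INDUCTOR)
set_inductor_additional_configs({"max_autotune": True})
```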
@@ -1137,10 +1137,9 @@ class EmbeddingSequenceGroupOutput(
        return self.embeddings == other.embeddings


class IntermediateTensors(
        msgspec.Struct,
        omit_defaults=True,  # type: ignore[call-arg]
        array_like=True):  # type: ignore[call-arg]
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
    """For all pipeline stages except the last, we need to return the hidden
    states and residuals to be sent to the next stage. This data structure
    contains the hidden states and residuals for a request.
@@ -18,6 +18,8 @@ import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ObservabilityConfig, ParallelConfig,
                         PromptAdapterConfig, SchedulerConfig)
@@ -1126,10 +1128,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                    "provided. Defaulting to scaling factors of 1.0. "
                    "This may lead to less accurate results!")

        if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
            from vllm.compilation.backends import vllm_backend
        if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
                and supports_dynamo():
            from vllm.plugins import get_torch_compile_backend
            backend = get_torch_compile_backend() or vllm_backend
            backend = get_torch_compile_backend() or "eager"
            self.model = torch.compile(
                self.model,
                fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
@@ -1289,7 +1291,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
                batch_size=batch_size,
                dtype=self.model_config.dtype,
                device=self.device)
        self.execute_model(model_input, kv_caches, intermediate_tensors)

        graph_batch_size = self.max_batchsize_to_capture
        batch_size_capture_list = [
            bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
        ]
        if self.model_config.enforce_eager:
            batch_size_capture_list = []
        with set_compile_context(batch_size_capture_list):
            self.execute_model(model_input, kv_caches, intermediate_tensors)
        torch.cuda.synchronize()
        return