[torch.compile] integration with compilation control (#9058)

This commit is contained in:
youkaichao
2024-10-10 12:39:36 -07:00
committed by GitHub
parent 78c0b4166c
commit e4d652ea3e
22 changed files with 404 additions and 98 deletions

View File

@ -121,7 +121,9 @@ steps:
- vllm/core/
- tests/distributed
- tests/spec_decode/e2e/test_integration_dist_tp4
- tests/compile
commands:
- pytest -v -s compile/test_basic_correctness.py
- pytest -v -s distributed/test_pynccl.py
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
@ -231,14 +233,16 @@ steps:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph_smoke.py
- pytest -v -s compile/test_basic_correctness.py
- label: "PyTorch Fullgraph Test" # 18min
source_file_dependencies:
- vllm/
- tests/compile
commands:
- pytest -v -s compile/test_full_graph.py
# TODO: re-write as comparison tests, and fix symbolic shape
# for quantization ops.
# - label: "PyTorch Fullgraph Test" # 18min
# source_file_dependencies:
# - vllm/
# - tests/compile
# commands:
# - pytest -v -s compile/test_full_graph.py
- label: Kernels Test %N # 1h each
mirror_hardwares: [amd]
@ -394,7 +398,7 @@ steps:
- tests/distributed/
- vllm/compilation
commands:
- pytest -v -s ./compile/test_full_graph_multi_gpu.py
- pytest -v -s ./compile/test_basic_correctness.py
- pytest -v -s ./compile/test_wrapper.py
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed'
- TARGET_TEST_SUITE=L4 VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1=1 pytest basic_correctness/ -v -s -m distributed_2_gpus

View File

@ -0,0 +1,48 @@
from typing import Dict, List, Optional
import pytest
from vllm.compilation.levels import CompilationLevel
from vllm.utils import cuda_device_count_stateless
from ..utils import compare_all_settings
# we cannot afford to test the full Cartesian product
# of all models and all levels
@pytest.mark.parametrize(
"model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
[
("meta-llama/Meta-Llama-3-8B", [], 2, 2, "FLASH_ATTN", "generate",
True),
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
["--quantization", "compressed-tensors"
], 1, 1, "FLASH_ATTN", "generate", True),
("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
# TODO: add multi-modality test for llava
("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
])
def test_compile_correctness(model, model_args, pp_size, tp_size, attn_backend,
method, fullgraph):
# this test is run under multiple test suites, with different GPUs.
# make sure we only run the test with the correct number of CUDA devices.
# don't use "<", as it would duplicate the tests across suites.
if cuda_device_count_stateless() != pp_size * tp_size:
pytest.skip("Not correct CUDA devices for the test.")
import os
os.environ["VLLM_ATTENTION_BACKEND"] = attn_backend
if not fullgraph:
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0"
all_args = [["--enforce-eager"] + model_args + ["--max_model_len", "1024"]
+ ["-pp", str(pp_size)] + ["-tp", str(tp_size)]] * 3
# don't test the VLLM_TORCH_COMPILE_LEVEL == 3 (Inductor) case:
# Inductor may change the output, so we cannot compare the results.
all_envs: List[Optional[Dict[str, str]]] = [{
"VLLM_TORCH_COMPILE_LEVEL":
str(level)
} for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
]]
compare_all_settings(model, all_args, all_envs, method=method)

View File

@ -1,13 +1,20 @@
import pytest
from vllm.compilation.backends import vllm_backend
from vllm.compilation.levels import CompilationLevel
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
@pytest.mark.parametrize(
"optimization_level",
[CompilationLevel.DYNAMO_ONCE, CompilationLevel.INDUCTOR])
@fork_new_process_for_each_test
def test_full_graph(model_info, optimization_level):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)
check_full_graph_support(model,
model_kwargs,
optimization_level,
tp_size=1)

View File

@ -1,22 +0,0 @@
import pytest
from vllm.compilation.backends import vllm_backend
from vllm.utils import cuda_device_count_stateless
from ..utils import fork_new_process_for_each_test
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("tp_size", [2])
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
@fork_new_process_for_each_test
def test_full_graph_multi_gpu(model_info, tp_size, backend):
model = model_info[0]
model_kwargs = model_info[1]
# Skip the test if there are not enough CUDA devices.
if cuda_device_count_stateless() < tp_size:
pytest.skip("Not enough CUDA devices for the test.")
check_full_graph_support(model, model_kwargs, backend, tp_size=tp_size)

View File

@ -1,13 +0,0 @@
import pytest
from vllm.compilation.backends import vllm_backend
from .utils import TEST_MODELS_SMOKE, check_full_graph_support
@pytest.mark.parametrize("model_info", TEST_MODELS_SMOKE)
@pytest.mark.parametrize("backend", ["eager", vllm_backend])
def test_full_graph(model_info, backend):
model = model_info[0]
model_kwargs = model_info[1]
check_full_graph_support(model, model_kwargs, backend, tp_size=1)

View File

@ -4,16 +4,9 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.plugins import set_torch_compile_backend
from vllm.compilation.levels import CompilationLevel
from vllm.utils import is_hip
TEST_MODELS_SMOKE = [
("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples", {
"quantization": "compressed-tensors"
}),
("meta-llama/Meta-Llama-3-8B", {}),
]
TEST_MODELS = [
("facebook/opt-125m", {}),
("nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change", {
@ -68,20 +61,21 @@ if not is_hip() and is_quant_method_supported("awq"):
}))
def check_full_graph_support(model, model_kwargs, backend, tp_size=1):
def check_full_graph_support(model,
model_kwargs,
optimization_level,
tp_size=1):
# make sure these models can be captured in full graph mode
if "VLLM_TEST_DYNAMO_GRAPH_CAPTURE" not in os.environ:
os.environ["VLLM_TEST_DYNAMO_GRAPH_CAPTURE"] = "1"
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
# Inductor doesn't support fp8, gptq_marlin, or gptq_marlin_24 yet.
quantization = model_kwargs.get("quantization")
if (quantization == "fp8" or quantization == "gptq_marlin"
or quantization == "gptq_marlin_24") and backend != "eager":
or quantization == "gptq_marlin_24"
) and optimization_level >= CompilationLevel.INDUCTOR:
return
set_torch_compile_backend(backend)
prompts = [
"Hello, my name is",
"The president of the United States is",

View File

@ -5,9 +5,11 @@ import tempfile
import depyf
from vllm.compilation.levels import CompilationLevel
# disable the custom dispatcher and let Dynamo take over
# all the control
os.environ['VLLM_DYNAMO_USE_CUSTOM_DISPATCHER'] = "0"
os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):

View File

@ -1,5 +1,7 @@
import os
from vllm.compilation.levels import CompilationLevel
from ..utils import compare_two_settings
# --enforce-eager on TPU causes graph compilation
@ -9,8 +11,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher():
compare_two_settings("google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER": "0"},
env2={})
compare_two_settings(
"google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})

View File

@ -1,8 +1,17 @@
import copy
import operator
from typing import Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.fx as fx
from vllm.logger import init_logger
from .compile_context import get_compile_context
from .levels import CompilationLevel
logger = init_logger(__name__)
def fix_functionalization(graph: fx.Graph):
"""
@ -148,9 +157,113 @@ def fix_functionalization(graph: fx.Graph):
# print(graph.python_code(root_module="self", verbose=True).src, file=f)
def vllm_backend(graph, example_inputs):
def wrap_inductor(graph, example_inputs, additional_inductor_config):
from torch._inductor import config
current_config = config.shallow_copy_dict()
from torch._inductor.compile_fx import compile_fx
if additional_inductor_config is not None:
current_config.update(additional_inductor_config)
if current_config['post_grad_custom_post_pass'] is not None:
logger.warning(
"post_grad_custom_post_pass is already set in the config. "
"Overwriting it with the fix_functionalization")
current_config['post_grad_custom_post_pass'] = fix_functionalization
return compile_fx(graph, example_inputs, config_patches=current_config)
def vllm_backend(
graph,
example_inputs,
additional_inductor_config: Optional[Dict] = None) -> Callable:
context = get_compile_context()
context = copy.deepcopy(context) if context is not None else []
sizes_to_specialize: List[int] = context
# for every shape seen so far, whether we need to specialize for it
runtime_shapes_to_compile_flags: Dict[Tuple[int, ...], bool] = {}
# for shapes that need specialization, the compiled graph for that shape
runtime_shapes_to_compiled_graph: Dict[Tuple[int, ...], Callable] = {}
# this is the first compilation; we compile a graph with dynamic
# shapes, as the caller will mark the first dimension as dynamic
logger.info("Compiling a graph for general shapes")
graph_for_symbolic_shape = wrap_inductor(graph, example_inputs,
additional_inductor_config)
# TODO: Dynamo does not pass all dynamic shapes.
# Need to investigate why. It works now because all the dynamic
# shapes have the same value, and either of them can be used.
sym_shape_indices = [
i for i, x in enumerate(example_inputs) if isinstance(x, torch.SymInt)
]
first_run = True
# this is the function we finally return to Dynamo to run
def compiled_graph_wrapper(*args):
runtime_shapes: Tuple[int,
...] = tuple(args[i] for i in sym_shape_indices)
nonlocal first_run
nonlocal runtime_shapes_to_compile_flags
nonlocal runtime_shapes_to_compiled_graph
if first_run:
# the first run is the profiling run; execute the general graph directly
first_run = False
return graph_for_symbolic_shape(*args)
if runtime_shapes not in runtime_shapes_to_compile_flags:
# we haven't seen this shape before
# query if we need to specialize for this shape
# we only specialize for the first dimension.
# TODO: investigate if any model needs to specialize
# beyond the first dimension
runtime_shapes_to_compile_flags[runtime_shapes] = runtime_shapes[
0] in sizes_to_specialize
if not runtime_shapes_to_compile_flags[runtime_shapes]:
# we don't need to specialize for this shape
return graph_for_symbolic_shape(*args)
if runtime_shapes not in runtime_shapes_to_compiled_graph:
# we need to specialize for this shape, and we haven't compiled
# it yet; compile the graph for this shape now
logger.info("Compiling a graph for shapes %s", runtime_shapes)
runtime_shapes_to_compiled_graph[runtime_shapes] = wrap_inductor(
graph, args, additional_inductor_config)
return runtime_shapes_to_compiled_graph[runtime_shapes](*args)
return compiled_graph_wrapper
def select_default_backend(level: int) -> Union[str, Callable]:
if level in [CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE]:
backend = "eager"
return backend
assert level in [
CompilationLevel.INDUCTOR, CompilationLevel.INDUCTOR_MAX_AUTOTUNE
], f"Invalid level {level}"
from vllm.compilation.backends import vllm_backend
from vllm.plugins import get_inductor_additional_configs
additional_configs = get_inductor_additional_configs()
if level == CompilationLevel.INDUCTOR_MAX_AUTOTUNE:
if "max_autotune" in additional_configs and not additional_configs[
"max_autotune"]:
logger.warning(
"max_autotune is disabled, but is overridden by level %s",
CompilationLevel.INDUCTOR_MAX_AUTOTUNE)
additional_configs['max_autotune'] = True
from functools import partial
backend = partial(vllm_backend,
additional_inductor_config=additional_configs)
return backend
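A hedged sketch (not part of this diff) of how the selected backend could be consumed; the actual wiring lives in TorchCompileWrapperWithCustomDispatcher further below, and this assumes the compilation level is at least DYNAMO_AS_IS, since select_default_backend asserts on the level:

import torch
import vllm.envs as envs
from vllm.compilation.backends import select_default_backend

def compile_forward(forward_fn):
    # picks "eager" or a (possibly max-autotune) vllm_backend partial
    backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
    return torch.compile(forward_fn,
                         fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
                         backend=backend)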

View File

@ -0,0 +1,23 @@
from contextlib import contextmanager
from typing import Any
_compile_context: Any = None
def get_compile_context() -> Any:
"""Get the current compile context."""
return _compile_context
@contextmanager
def set_compile_context(context: Any):
"""A context manager that stores the current compile context,
which is usually a list of sizes to specialize.
"""
global _compile_context
prev_context = _compile_context
_compile_context = context
try:
yield
finally:
_compile_context = prev_context
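A minimal usage sketch, assuming the context is a list of sizes to specialize (the model runner hunk at the end of this diff passes its CUDA-graph batch sizes this way); the example sizes here are illustrative:

from vllm.compilation.compile_context import (get_compile_context,
                                               set_compile_context)

# hypothetical sizes for illustration
sizes_to_specialize = [1, 2, 4, 8]
with set_compile_context(sizes_to_specialize):
    # vllm_backend reads this via get_compile_context()
    assert get_compile_context() == sizes_to_specialize
# the previous context (None here) is restored on exit
assert get_compile_context() is None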

View File

@ -0,0 +1,85 @@
from typing import List, Optional, Union
import torch
import vllm.envs as envs
from vllm.attention import AttentionMetadata
from vllm.compilation.levels import CompilationLevel
from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher
from vllm.sequence import IntermediateTensors
from vllm.utils import supports_dynamo
def support_compile_llama_style(cls: type):
"""
A decorator to add support for compiling the forward method of a class.
If a module's **forward signature** is compatible with Llama's, this
decorator can be used to enable compilation of the forward method.
"""
# for CompilationLevel.DYNAMO_AS_IS, the upper-level model runner
# will handle the compilation, so we don't need to do anything here.
if envs.VLLM_TORCH_COMPILE_LEVEL in [
CompilationLevel.NO_COMPILATION, CompilationLevel.DYNAMO_AS_IS
] or not supports_dynamo():
return cls
# take care of the method resolution order:
# make sure super().__init__ is called on base classes
# other than TorchCompileWrapperWithCustomDispatcher
cls.__bases__ = cls.__bases__ + (TorchCompileWrapperWithCustomDispatcher, )
old_init = cls.__init__
def __init__(self, *args, **kwargs):
old_init(self, *args, **kwargs)
TorchCompileWrapperWithCustomDispatcher.__init__(self)
cls.__init__ = __init__
def __call__(
self,
input_ids: Optional[torch.Tensor],
positions: torch.Tensor,
kv_caches: List[torch.Tensor],
attn_metadata: AttentionMetadata,
intermediate_tensors: Optional[IntermediateTensors],
inputs_embeds: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, IntermediateTensors]:
# torch.compiler.is_compiling() means we are already inside a
# compilation, e.g. TPU has the compilation logic in the model runner,
# so we don't need to compile the model here.
if torch.compiler.is_compiling():
return self.forward(input_ids, positions, kv_caches, attn_metadata,
intermediate_tensors, inputs_embeds)
# the first compilation needs to have dynamic shapes marked
if len(self.compiled_codes) < 1:
if input_ids is not None:
torch._dynamo.mark_dynamic(input_ids, 0)
torch._dynamo.mark_dynamic(positions, 0)
if inputs_embeds is not None:
torch._dynamo.mark_dynamic(inputs_embeds, 0)
if intermediate_tensors is not None:
for tensors in intermediate_tensors.tensors.values():
torch._dynamo.mark_dynamic(tensors, 0)
# if we don't use custom dispatcher, we can directly call the
# compiled function and let torch.compile handle the dispatching,
# with the overhead of guard evaluation and recompilation.
if len(self.compiled_codes) < 1 or not self.use_custom_dispatcher:
return self.compiled_callable(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
# usually, capturing the model once is enough, and then we can
# dispatch to the compiled code directly, without going through
# the Dynamo guard mechanism.
with self.dispatch_to_code(0):
model_output = self.forward(input_ids, positions, kv_caches,
attn_metadata, intermediate_tensors,
inputs_embeds)
return model_output
cls.__call__ = __call__
return cls
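A hedged usage sketch; the real applications are the @support_compile_llama_style lines added to LlamaModel and Gemma2Model later in this diff, and MyLlamaLikeModel here is hypothetical:

import torch.nn as nn
from vllm.compilation.decorators import support_compile_llama_style

@support_compile_llama_style
class MyLlamaLikeModel(nn.Module):
    # forward must match the llama-style signature expected by the
    # decorator's __call__ wrapper above
    def forward(self, input_ids, positions, kv_caches, attn_metadata,
                intermediate_tensors, inputs_embeds=None):
        ...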

View File

@ -0,0 +1,9 @@
# constants for the levels of the compilation process
class CompilationLevel:
NO_COMPILATION = 0
DYNAMO_AS_IS = 1
DYNAMO_ONCE = 2
INDUCTOR = 3
INDUCTOR_MAX_AUTOTUNE = 4
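A brief sketch of how these constants are consumed: the tests above select a level through the VLLM_TORCH_COMPILE_LEVEL environment variable, which vllm/envs.py (see its hunk below) parses as an int:

import os
from vllm.compilation.levels import CompilationLevel

# must be set before vLLM reads vllm.envs.VLLM_TORCH_COMPILE_LEVEL
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)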

View File

@ -3,12 +3,14 @@ import sys
from abc import abstractmethod
from contextlib import contextmanager
from types import CodeType
from typing import Callable, List
from typing import Callable, List, Optional
import torch
import vllm.envs as envs
from .levels import CompilationLevel
class TorchCompileWrapperWithCustomDispatcher:
"""
@ -23,7 +25,26 @@ class TorchCompileWrapperWithCustomDispatcher:
`torch.compile` over the forward method.
"""
def __init__(self, compiled_callable: Callable):
def __init__(self, compiled_callable: Optional[Callable] = None):
if compiled_callable is None:
# default compilation settings: compile the forward method.
# choose the compile backend: if the user has set one, use it;
# otherwise select a default based on the compilation level.
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend()
if backend is None:
from vllm.compilation.backends import select_default_backend
backend = select_default_backend(envs.VLLM_TORCH_COMPILE_LEVEL)
compiled_callable = torch.compile(
self.forward,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
backend=backend)
self.compiled_callable = compiled_callable
self.original_code_object = self.__class__.forward.__code__
self.compiled_codes: List[CodeType] = []
@ -33,7 +54,7 @@ class TorchCompileWrapperWithCustomDispatcher:
# subclasses can use this to switch between the custom dispatcher
# and the default Dynamo guard mechanism.
self.use_custom_dispatcher: bool = \
envs.VLLM_DYNAMO_USE_CUSTOM_DISPATCHER
envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.DYNAMO_ONCE
def __call__(self, *args, **kwargs):
"""Implement the dispatch logic here, beyond the torch.compile level.

View File

@ -65,6 +65,7 @@ if TYPE_CHECKING:
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_ALLOW_DEPRECATED_BLOCK_MANAGER_V1: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
def get_default_cache_root():
@ -198,23 +199,12 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda: (os.environ.get("VLLM_USE_TRITON_FLASH_ATTN", "True").lower() in
("true", "1")),
# Internal flag to enable Dynamo graph capture
"VLLM_TEST_DYNAMO_GRAPH_CAPTURE":
lambda: int(os.environ.get("VLLM_TEST_DYNAMO_GRAPH_CAPTURE", "0")),
"VLLM_DYNAMO_USE_CUSTOM_DISPATCHER":
lambda:
(os.environ.get("VLLM_DYNAMO_USE_CUSTOM_DISPATCHER", "True").lower() in
("true", "1")),
# Internal flag to control whether we use custom op,
# or use the native pytorch implementation
"VLLM_TEST_COMPILE_NO_CUSTOM_OPS":
lambda: int(os.environ.get("VLLM_TEST_COMPILE_NO_CUSTOM_OPS", "0")),
# Internal flag to enable Dynamo fullgraph capture
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
# local rank of the process in the distributed setting, used to determine
# the GPU device id

View File

@ -1,6 +1,7 @@
import torch.nn as nn
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.platforms import current_platform
from vllm.utils import is_cpu, is_hip, is_xpu
@ -55,7 +56,7 @@ class CustomOp(nn.Module):
# NOTE(woosuk): Here we assume that vLLM was built for only one
# specific backend. Currently, we do not support dynamic dispatching.
if envs.VLLM_TEST_COMPILE_NO_CUSTOM_OPS:
if envs.VLLM_TORCH_COMPILE_LEVEL >= CompilationLevel.INDUCTOR:
return self.forward_native
if is_hip():

View File

@ -21,6 +21,7 @@ from torch import nn
from transformers import Gemma2Config
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_compile_llama_style
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.logger import init_logger
@ -238,6 +239,7 @@ class Gemma2DecoderLayer(nn.Module):
return hidden_states, residual
@support_compile_llama_style
class Gemma2Model(nn.Module):
def __init__(

View File

@ -28,6 +28,7 @@ from torch import nn
from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_compile_llama_style
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
@ -265,6 +266,7 @@ class LlamaDecoderLayer(nn.Module):
return hidden_states, residual
@support_compile_llama_style
class LlamaModel(nn.Module):
def __init__(

View File

@ -365,6 +365,8 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
input_ids = None
inputs_embeds = None
else:
# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
image_input = self._parse_and_validate_image_input(**kwargs)
if image_input is not None:
@ -375,10 +377,10 @@ class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
inputs_embeds = merge_multimodal_embeddings(
input_ids, inputs_embeds, vision_embeddings,
self.config.image_token_index)
input_ids = None
else:
inputs_embeds = None
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)
input_ids = None
hidden_states = self.language_model.model(input_ids,
positions,

View File

@ -1,7 +1,21 @@
import os
import torch
import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.plugins import set_torch_compile_backend
from .interface import Platform, PlatformEnum
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
assert envs.VLLM_TORCH_COMPILE_LEVEL < CompilationLevel.INDUCTOR,\
"TPU does not support Inductor."
set_torch_compile_backend("openxla")
class TpuPlatform(Platform):
_enum = PlatformEnum.TPU

View File

@ -1,5 +1,5 @@
import logging
from typing import Callable, Optional, Union
from typing import Callable, Dict, Optional, Union
import vllm.envs as envs
@ -42,3 +42,15 @@ def set_torch_compile_backend(backend: Union[Callable, str]):
def get_torch_compile_backend() -> Optional[Union[Callable, str]]:
return _torch_compile_backend
_inductor_additional_configs: Dict = {}
def set_inductor_additional_configs(configs: Dict):
global _inductor_additional_configs
_inductor_additional_configs = configs
def get_inductor_additional_configs() -> Dict:
return _inductor_additional_configs
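A usage sketch with a hypothetical config dict; wrap_inductor above forwards these options to Inductor, and select_default_backend force-enables max_autotune at the INDUCTOR_MAX_AUTOTUNE level:

from vllm.plugins import (get_inductor_additional_configs,
                          set_inductor_additional_configs)

set_inductor_additional_configs({"max_autotune": True})
assert get_inductor_additional_configs() == {"max_autotune": True}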

View File

@ -1137,10 +1137,9 @@ class EmbeddingSequenceGroupOutput(
return self.embeddings == other.embeddings
class IntermediateTensors(
msgspec.Struct,
omit_defaults=True, # type: ignore[call-arg]
array_like=True): # type: ignore[call-arg]
# cannot use msgspec.Struct here because Dynamo does not support it
@dataclass
class IntermediateTensors:
"""For all pipeline stages except the last, we need to return the hidden
states and residuals to be sent to the next stage. This data structure
contains the hidden states and residuals for a request.
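Since this hunk elides the class body, here is a hedged sketch of the shape implied by the decorator code earlier in this diff, which iterates intermediate_tensors.tensors.values(); only the `tensors` attribute is confirmed by that usage, everything else is assumed:

from dataclasses import dataclass
from typing import Dict
import torch

@dataclass
class IntermediateTensors:
    """Hidden states and residuals passed between pipeline stages."""
    # keyed by name, e.g. "hidden_states" and "residual"; the key names
    # are assumptions, only the `tensors` attribute is confirmed above
    tensors: Dict[str, torch.Tensor]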

View File

@ -18,6 +18,8 @@ import vllm.envs as envs
from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.attention.backends.abstract import AttentionState
from vllm.attention.backends.utils import CommonAttentionState
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.levels import CompilationLevel
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,
PromptAdapterConfig, SchedulerConfig)
@ -1126,10 +1128,10 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
"provided. Defaulting to scaling factors of 1.0. "
"This may lead to less accurate results!")
if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo():
from vllm.compilation.backends import vllm_backend
if envs.VLLM_TORCH_COMPILE_LEVEL == CompilationLevel.DYNAMO_AS_IS \
and supports_dynamo():
from vllm.plugins import get_torch_compile_backend
backend = get_torch_compile_backend() or vllm_backend
backend = get_torch_compile_backend() or "eager"
self.model = torch.compile(
self.model,
fullgraph=envs.VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE,
@ -1289,7 +1291,15 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
batch_size=batch_size,
dtype=self.model_config.dtype,
device=self.device)
self.execute_model(model_input, kv_caches, intermediate_tensors)
graph_batch_size = self.max_batchsize_to_capture
batch_size_capture_list = [
bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size
]
if self.model_config.enforce_eager:
batch_size_capture_list = []
with set_compile_context(batch_size_capture_list):
self.execute_model(model_input, kv_caches, intermediate_tensors)
torch.cuda.synchronize()
return
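Putting the pieces together, a hedged end-to-end sketch (the model name is taken from the test above; the comments simplify the flow described across the hunks in this diff):

import os
from vllm.compilation.levels import CompilationLevel

# choose the compilation level before vLLM is initialized
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)

from vllm import LLM, SamplingParams

# decorated models (LlamaModel, Gemma2Model) compile their forward via
# TorchCompileWrapperWithCustomDispatcher; at the INDUCTOR levels, the
# model runner's capture sizes flow through set_compile_context into
# vllm_backend, which specializes a graph per batch size.
llm = LLM(model="meta-llama/Meta-Llama-3-8B")
print(llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16)))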