[6/N] torch.compile rollout to users (#10437)

Signed-off-by: youkaichao <youkaichao@gmail.com>
Author: youkaichao
Date: 2024-11-19 10:09:03 -08:00
Committed by: GitHub
Parent: fd9f124971
Commit: 803f37eaaa

15 changed files with 129 additions and 141 deletions

View File

@ -1,5 +0,0 @@
{
"use_cudagraph": true,
"non_cudagraph_ops": ["silly.attention"],
"cudagraph_copy_inputs": true
}
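
The JSON file above (previously pointed to through VLLM_TORCH_COMPILE_CONFIG) is deleted; the same settings are now built in code. A minimal sketch of the equivalent in-code configuration, mirroring the fields used in the updated test below:

# Equivalent in-code configuration for the deleted JSON file (mirrors the
# updated test below; no API beyond what this commit adds).
from vllm.config import CompilationConfig, CompilationLevel

compilation_config = CompilationConfig(
    level=CompilationLevel.PIECEWISE,        # replaces VLLM_TORCH_COMPILE_LEVEL
    use_cudagraph=True,
    non_cudagraph_ops=["silly.attention"],
    cudagraph_copy_inputs=True,
)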

View File

@ -2,7 +2,6 @@
Test the piecewise compilation with a simple model so that we
can exactly calculate the expected output and side effects.
"""
import os
import torch
from torch import nn
@ -11,7 +10,7 @@ from torch.library import Library
from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationLevel, VllmConfig
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op
@ -77,12 +76,12 @@ class SillyModel(nn.Module):
def test_simple_piecewise_compile():
directory = os.path.dirname(__file__)
config = os.path.join(directory, "piecewise_compilation_config.json")
os.environ["VLLM_TORCH_COMPILE_CONFIG"] = config
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
cudagraph_copy_inputs=True,
))
with set_current_vllm_config(vllm_config):
model = SillyModel(vllm_config=vllm_config, prefix='')
@ -109,6 +108,3 @@ def test_simple_piecewise_compile():
output = model(input)
assert global_counter == 2
assert torch.allclose(output.cpu(), torch.tensor([3., 1.]))
# clean up to avoid side effects for other tests
del os.environ["VLLM_TORCH_COMPILE_CONFIG"]

View File

@ -6,7 +6,6 @@ This is a tractable model, the weights and computation are specially designed
if the config `tractable_init` is set to True. Otherwise, the weights are
initialized randomly with a fixed seed.
"""
import os
from dataclasses import dataclass
from typing import Optional, Tuple
@ -18,7 +17,7 @@ from vllm.compilation.compile_context import set_compile_context
from vllm.compilation.counter import compilation_counter
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.plugins import set_compilation_config, set_current_vllm_config
from vllm.plugins import set_current_vllm_config
from vllm.utils import direct_register_custom_op
# create a library to hold the custom op
@ -254,23 +253,17 @@ def run_model(llama_config,
split_attn: bool = False) -> torch.Tensor:
if use_compile:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.PIECEWISE)
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
)
if split_attn:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
else:
set_compilation_config(CompilationConfig(use_cudagraph=True, ))
compilation_config.non_cudagraph_ops = ["silly.attention"]
else:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(
CompilationLevel.NO_COMPILATION)
set_compilation_config(None)
compilation_config = CompilationConfig(
level=CompilationLevel.NO_COMPILATION, )
vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,
@ -288,10 +281,6 @@ def run_model(llama_config,
input_ids[:2].zero_()
output = model(input_ids[:2], positions[:2])
# manual cleanup
del os.environ["VLLM_TORCH_COMPILE_LEVEL"]
set_compilation_config(None)
output = output.cpu()
if llama_config.tractable_init:
@ -361,7 +350,6 @@ def test_toy_llama():
@torch.inference_mode
def benchmark():
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.PIECEWISE)
from triton.testing import do_bench
# similar to llama 3.1-8B
@ -387,15 +375,16 @@ def benchmark():
for piecewise in [False, True]:
if piecewise:
set_compilation_config(
CompilationConfig(
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
))
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE,
use_cudagraph=True,
non_cudagraph_ops=["silly.attention"],
)
else:
set_compilation_config(None)
compilation_config = CompilationConfig(
level=CompilationLevel.PIECEWISE, )
vllm_config = VllmConfig()
vllm_config = VllmConfig(compilation_config=compilation_config)
with set_current_vllm_config(vllm_config):
model = LlamaModel(config=llama_config,
vllm_config=vllm_config,

View File

@ -96,31 +96,36 @@ def test_compile_correctness(test_setting: TestSetting):
final_args = ["--enforce-eager"] + model_args + ["-pp", str(pp_size)] + \
["-tp", str(tp_size)]
all_args: List[List[str]] = []
all_envs: List[Optional[Dict[str, str]]] = []
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.PIECEWISE,
]:
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
all_args.append(final_args + ["-O", str(level)])
all_envs.append({})
# Inductor may change the output, so we only check that the outputs
# are close, not exactly the same.
compare_all_settings(
model, [final_args] * 2,
model,
all_args,
all_envs,
method=method if method != "generate" else "generate_close")
all_envs.clear()
all_args.clear()
for level in [
CompilationLevel.NO_COMPILATION,
CompilationLevel.DYNAMO_AS_IS,
CompilationLevel.DYNAMO_ONCE,
]:
all_envs.append({"VLLM_TORCH_COMPILE_LEVEL": str(level)})
all_args.append(final_args + ["-O", str(level)])
all_envs.append({})
if level != CompilationLevel.DYNAMO_ONCE and not fullgraph:
# "DYNAMO_ONCE" will always use fullgraph
all_envs[-1][
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "0" # type: ignore
compare_all_settings(model, [final_args] * 3, all_envs, method=method)
compare_all_settings(model, all_args, all_envs, method=method)

View File

@ -4,7 +4,7 @@ import torch
from tests.quantization.utils import is_quant_method_supported
from vllm import LLM, SamplingParams
from vllm.config import CompilationLevel
from vllm.config import CompilationConfig, CompilationLevel
from vllm.platforms import current_platform
TEST_MODELS = [
@ -65,7 +65,6 @@ def check_full_graph_support(model,
optimization_level,
tp_size=1):
# make sure these models can be captured in full graph mode
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(optimization_level)
os.environ["VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE"] = "1"
# The base meta llama uses too much memory.
@ -86,6 +85,7 @@ def check_full_graph_support(model,
enforce_eager=True,
tensor_parallel_size=tp_size,
disable_custom_all_reduce=True,
compilation_config=CompilationConfig(level=optimization_level),
**model_kwargs)
outputs = llm.generate(prompts, sampling_params)

View File

@ -1,4 +1,3 @@
import os
from typing import List
import pytest
@ -53,9 +52,8 @@ class Relu3(ReLUSquaredActivation):
])
def test_enabled_ops(env: str, torch_level: int, ops_enabled: List[int],
default_on: bool):
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(torch_level)
vllm_config = VllmConfig(compilation_config=CompilationConfig(
custom_ops=env.split(",")))
level=torch_level, custom_ops=env.split(",")))
with set_current_vllm_config(vllm_config):
assert CustomOp.default_on() == default_on

View File

@ -1,24 +1,47 @@
import glob
import os
import runpy
import tempfile
import depyf
from vllm.config import CompilationLevel
# disable custom dispatcher, let Dynamo takes over
# all the control
os.environ['VLLM_TORCH_COMPILE_LEVEL'] = str(CompilationLevel.DYNAMO_AS_IS)
from vllm.config import CompilationConfig, CompilationLevel
temp_dir = tempfile.mkdtemp()
with depyf.prepare_debug(temp_dir):
cur_dir = os.path.dirname(__file__)
parent_dir = os.path.dirname(cur_dir)
root_dir = os.path.dirname(parent_dir)
example_file = os.path.join(root_dir, "examples",
"offline_inference_tpu.py")
runpy.run_path(example_file)
from vllm import LLM, SamplingParams
prompts = [
"A robot may not injure a human being",
"It is only with the heart that one can see rightly;",
"The greatest glory in living lies not in never falling,",
]
answers = [
" or, through inaction, allow a human being to come to harm.",
" what is essential is invisible to the eye.",
" but in rising every time we fall.",
]
N = 1
# Currently, top-p sampling is disabled. `top_p` should be 1.0.
sampling_params = SamplingParams(temperature=0.7,
top_p=1.0,
n=N,
max_tokens=16)
# Set `enforce_eager=True` to avoid ahead-of-time compilation.
# In real workloads, `enforce_eager` should be `False`.
# disable custom dispatcher, let Dynamo take over
# all the control
llm = LLM(model="google/gemma-2b",
enforce_eager=True,
compilation_config=CompilationConfig(
level=CompilationLevel.DYNAMO_AS_IS))
outputs = llm.generate(prompts, sampling_params)
for output, answer in zip(outputs, answers):
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
assert generated_text.startswith(answer)
compiled_code = sorted(
glob.glob(os.path.join(temp_dir, "__transformed_code*.py")))

View File

@ -13,7 +13,9 @@ os.environ["VLLM_RPC_TIMEOUT"] = "30000"
def test_custom_dispatcher():
compare_two_settings(
"google/gemma-2b",
arg1=["--enforce-eager"],
arg2=["--enforce-eager"],
env1={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_ONCE)},
env2={"VLLM_TORCH_COMPILE_LEVEL": str(CompilationLevel.DYNAMO_AS_IS)})
arg1=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_ONCE)],
arg2=["--enforce-eager", "-O",
str(CompilationLevel.DYNAMO_AS_IS)],
env1={},
env2={})

View File

@ -2174,8 +2174,14 @@ class CompilationConfig(BaseModel):
enabled_custom_ops: Counter[str] = PrivateAttr
disabled_custom_ops: Counter[str] = PrivateAttr
@classmethod
def from_cli(cls, cli_value: str) -> "CompilationConfig":
"""Parse the CLI value for the compilation config."""
if cli_value in ["0", "1", "2", "3"]:
return cls(level=int(cli_value))
return CompilationConfig.model_validate_json(cli_value)
def model_post_init(self, __context: Any) -> None:
self.level = envs.VLLM_TORCH_COMPILE_LEVEL
count_none = self.custom_ops.count("none")
count_all = self.custom_ops.count("all")
@ -2249,26 +2255,6 @@ class CompilationConfig(BaseModel):
"inductor_specialize_for_cudagraph_no_more_than is None")
self.compile_sizes = self.inductor_compile_sizes
@staticmethod
def select_and_init_config() -> "CompilationConfig":
"""The order of selecting config is:
1. Use the config specified in environment variable.
2. Use the config specified in plugins.
3. Use the default config.
"""
config_path = envs.VLLM_TORCH_COMPILE_CONFIG
if config_path is not None:
with open(config_path) as json_file:
config = CompilationConfig.model_validate_json(
json_file.read())
else:
from vllm.plugins import get_compilation_config
predefined_config = get_compilation_config()
config = predefined_config if predefined_config is not None else (
CompilationConfig())
return config
@dataclass
class VllmConfig:
@ -2354,8 +2340,19 @@ class VllmConfig:
self.model_config, self.load_config)
if self.compilation_config is None:
self.compilation_config = CompilationConfig.select_and_init_config(
)
self.compilation_config = CompilationConfig()
if envs.VLLM_USE_V1:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
self.compilation_config.custom_ops = ["none"]
self.compilation_config.use_cudagraph = True
self.compilation_config.non_cudagraph_ops = [
"vllm.unified_v1_flash_attention"
]
self.compilation_config.use_inductor = True
self.compilation_config.enable_fusion = False
current_platform.check_and_update_config(self)
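
For reference, the new `from_cli` hook accepts either a bare digit (interpreted as the optimization level) or a full JSON document validated against CompilationConfig. A small sketch of both forms, using only fields that appear elsewhere in this commit; the JSON payload is illustrative, not exhaustive:

# Sketch of the two accepted -O value forms (fields taken from this commit).
from vllm.config import CompilationConfig

cfg = CompilationConfig.from_cli("3")      # bare digit -> level=3 (PIECEWISE)
assert cfg.level == 3

cfg = CompilationConfig.from_cli(
    '{"level": 3, "use_cudagraph": true, "non_cudagraph_ops": ["silly.attention"]}')
assert cfg.use_cudagraph
assert cfg.non_cudagraph_ops == ["silly.attention"]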

View File

@ -8,12 +8,13 @@ from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Mapping, Optional,
import torch
import vllm.envs as envs
from vllm.config import (CacheConfig, ConfigFormat, DecodingConfig,
DeviceConfig, HfOverrides, LoadConfig, LoadFormat,
LoRAConfig, ModelConfig, ObservabilityConfig,
ParallelConfig, PoolerConfig, PromptAdapterConfig,
SchedulerConfig, SpeculativeConfig, TaskOption,
TokenizerPoolConfig, VllmConfig)
from vllm.config import (CacheConfig, CompilationConfig, ConfigFormat,
DecodingConfig, DeviceConfig, HfOverrides, LoadConfig,
LoadFormat, LoRAConfig, ModelConfig,
ObservabilityConfig, ParallelConfig, PoolerConfig,
PromptAdapterConfig, SchedulerConfig,
SpeculativeConfig, TaskOption, TokenizerPoolConfig,
VllmConfig)
from vllm.executor.executor_base import ExecutorBase
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
@ -189,6 +190,7 @@ class EngineArgs:
override_neuron_config: Optional[Dict[str, Any]] = None
override_pooler_config: Optional[PoolerConfig] = None
compilation_config: Optional[CompilationConfig] = None
def __post_init__(self):
if not self.tokenizer:
@ -868,6 +870,20 @@ class EngineArgs:
help="Override or set the pooling method in the embedding model. "
"e.g. {\"pooling_type\": \"mean\", \"normalize\": false}.'")
parser.add_argument('--compilation-config',
'-O',
type=CompilationConfig.from_cli,
default=None,
help='torch.compile configuration for the model. '
'When it is a number (0, 1, 2, 3), it will be '
'interpreted as the optimization level.\n'
'NOTE: level 0 is the default level without '
'any optimization. Levels 1 and 2 are for internal '
'testing only. Level 3 is the recommended level '
'for production.\n'
'To specify the full compilation config, '
'use a JSON string.')
return parser
@classmethod
@ -1142,6 +1158,7 @@ class EngineArgs:
decoding_config=decoding_config,
observability_config=observability_config,
prompt_adapter_config=prompt_adapter_config,
compilation_config=self.compilation_config,
)
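
With the `--compilation-config`/`-O` engine argument wired through EngineArgs, users no longer export VLLM_TORCH_COMPILE_LEVEL; on the command line the flag is passed as e.g. `-O 3`, and the offline API takes the config object directly. A hedged sketch of the programmatic form (the model name is only an example; the kwarg path mirrors the tests updated in this commit):

# Hedged usage sketch: pass the compilation config through the offline API
# instead of the removed env vars. The model name is illustrative only.
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig, CompilationLevel

llm = LLM(model="google/gemma-2b",   # model used elsewhere in this commit
          compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE))
outputs = llm.generate(["A robot may not injure a human being"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)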

View File

@ -262,7 +262,8 @@ class LLMEngine:
"num_scheduler_steps=%d, chunked_prefill_enabled=%s "
"multi_step_stream_outputs=%s, enable_prefix_caching=%s, "
"use_async_output_proc=%s, use_cached_outputs=%s, "
"mm_processor_kwargs=%s, pooler_config=%r)",
"mm_processor_kwargs=%s, pooler_config=%r,"
"compilation_config=%r",
VLLM_VERSION,
model_config.model,
speculative_config,
@ -297,6 +298,7 @@ class LLMEngine:
use_cached_outputs,
model_config.mm_processor_kwargs,
model_config.pooler_config,
vllm_config.compilation_config,
)
# TODO(woosuk): Print more configs in debug mode.
self.model_config = model_config

View File

@ -67,8 +67,6 @@ if TYPE_CHECKING:
VLLM_USE_TRITON_AWQ: bool = False
VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
VLLM_SKIP_P2P_CHECK: bool = False
VLLM_TORCH_COMPILE_LEVEL: int = 0
VLLM_TORCH_COMPILE_CONFIG: Optional[str] = None
VLLM_DISABLED_KERNELS: List[str] = []
VLLM_USE_V1: bool = False
VLLM_ENABLE_V1_MULTIPROCESSING: bool = False
@ -209,12 +207,6 @@ environment_variables: Dict[str, Callable[[], Any]] = {
"VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE":
lambda: bool(
os.environ.get("VLLM_TEST_DYNAMO_FULLGRAPH_CAPTURE", "1") != "0"),
"VLLM_TORCH_COMPILE_LEVEL":
lambda: int(os.environ.get("VLLM_TORCH_COMPILE_LEVEL", "0")),
# Path to the config file for torch compile
"VLLM_TORCH_COMPILE_CONFIG":
lambda: os.environ.get("VLLM_TORCH_COMPILE_CONFIG", None),
# local rank of the process in the distributed setting, used to determine
# the GPU device id

View File

@ -1,4 +1,3 @@
import os
from typing import TYPE_CHECKING
import torch
@ -40,7 +39,8 @@ class TpuPlatform(Platform):
def check_and_update_config(cls, vllm_config: VllmConfig) -> None:
from vllm.config import CompilationLevel
compilation_config = vllm_config.compilation_config
if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
if compilation_config.level == CompilationLevel.NO_COMPILATION:
# TPU does not support NO_COMPILATION
compilation_config.level = CompilationLevel.DYNAMO_ONCE
assert compilation_config.level < CompilationLevel.PIECEWISE,\
"TPU does not support Inductor."

View File

@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Optional
import vllm.envs as envs
if TYPE_CHECKING:
from vllm.config import CompilationConfig, VllmConfig
from vllm.config import VllmConfig
logger = logging.getLogger(__name__)
@ -54,18 +54,6 @@ def load_general_plugins():
logger.exception("Failed to load plugin %s", plugin.name)
_compilation_config: Optional["CompilationConfig"] = None
def set_compilation_config(config: Optional["CompilationConfig"]):
global _compilation_config
_compilation_config = config
def get_compilation_config() -> Optional["CompilationConfig"]:
return _compilation_config
_current_vllm_config: Optional["VllmConfig"] = None

View File

@ -8,13 +8,12 @@ import torch.distributed
import torch.nn as nn
from vllm.compilation.compile_context import set_compile_context
from vllm.config import CompilationConfig, CompilationLevel, VllmConfig
from vllm.config import CompilationLevel, VllmConfig
from vllm.forward_context import set_forward_context
from vllm.inputs import INPUT_REGISTRY, InputRegistry
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.multimodal import MultiModalKwargs
from vllm.plugins import set_compilation_config
from vllm.sampling_params import SamplingParams, SamplingType
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, cdiv,
is_pin_memory_available)
@ -508,20 +507,6 @@ class GPUModelRunner:
return model_runner_output
def load_model(self) -> None:
if self.use_cuda_graph:
# NOTE(woosuk): Currently, we use inductor because the piecewise
# CUDA graphs do not work properly with the custom CUDA kernels.
# FIXME(woosuk): Disable inductor to reduce the compilation time
# and avoid any potential issues with the inductor.
set_compilation_config(
CompilationConfig(
custom_ops=["none"],
use_cudagraph=True,
non_cudagraph_ops=["vllm.unified_v1_flash_attention"],
use_inductor=True,
enable_fusion=False,
))
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
self.model = get_model(vllm_config=self.vllm_config)
@ -562,9 +547,8 @@ class GPUModelRunner:
def capture_model(self) -> None:
if not self.use_cuda_graph:
logger.warning(
"Skipping CUDA graph capture. Please set "
"VLLM_TORCH_COMPILE_LEVEL=%d to use CUDA graphs.",
CompilationLevel.PIECEWISE)
"Skipping CUDA graph capture. Please add "
"-O 3 to use CUDA graphs.", CompilationLevel.PIECEWISE)
return
start_time = time.perf_counter()
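
The V1 runner no longer assembles its own CompilationConfig via set_compilation_config; those defaults now live in the VLLM_USE_V1 branch of VllmConfig.__post_init__ (see the config.py hunk above), and the updated warning points users at `-O 3`. A hedged sketch of that opt-in from the offline API (the model name is an example, and setting VLLM_USE_V1 before constructing the engine is an assumption about the usual workflow):

# Hedged sketch: enabling piecewise CUDA graph capture on the V1 engine.
import os
os.environ["VLLM_USE_V1"] = "1"          # opt into the V1 engine

from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel

# level=PIECEWISE (-O 3) is what capture_model() asks for; the V1 branch in
# VllmConfig.__post_init__ fills in use_cudagraph, custom_ops, etc.
llm = LLM(model="meta-llama/Llama-3.1-8B",   # illustrative model
          compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE))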