mirror of https://github.com/vllm-project/vllm.git
[Misc] Set block size at initialization & Fix test_model_runner (#4705)
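The diff below moves block-size setup into the runner constructors: ModelRunner and CPUModelRunner now receive cache_config and read block_size from it in __init__, the separate set_block_size() step is dropped, and the model-runner tests build their runners from a full engine config. As a minimal sketch (not part of the diff; the model name and the final assert are illustrative assumptions), the post-change construction pattern looks roughly like this:

from vllm.engine.arg_utils import EngineArgs
from vllm.worker.model_runner import ModelRunner

# Build every config object in one place, then hand them to the runner.
engine_config = EngineArgs("facebook/opt-125m").create_engine_config()
runner = ModelRunner(
    model_config=engine_config.model_config,
    parallel_config=engine_config.parallel_config,
    scheduler_config=engine_config.scheduler_config,
    device_config=engine_config.device_config,
    cache_config=engine_config.cache_config,   # block_size is taken from here
    load_config=engine_config.load_config,
    lora_config=engine_config.lora_config,
    is_driver_worker=True,
)
# No model_runner.set_block_size(...) call is needed any more.
assert runner.block_size == engine_config.cache_config.block_size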
@@ -1,27 +1,38 @@
import pytest
import torch

from vllm.config import ModelConfig, SchedulerConfig
from vllm.distributed.parallel_state import init_distributed_environment
from vllm.engine.arg_utils import EngineArgs
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata
from vllm.utils import get_open_port
from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size


def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner:
    engine_args = EngineArgs(model, *args, **kwargs)
    engine_config = engine_args.create_engine_config()
    model_runner = ModelRunner(
        model_config=engine_config.model_config,
        parallel_config=engine_config.parallel_config,
        scheduler_config=engine_config.scheduler_config,
        device_config=engine_config.device_config,
        cache_config=engine_config.cache_config,
        load_config=engine_config.load_config,
        lora_config=engine_config.lora_config,
        is_driver_worker=True,
    )
    return model_runner


@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
    scheduler_config = SchedulerConfig(100000,
                                       100000,
                                       100000,
                                       enable_chunked_prefill=False)
    model_runner = ModelRunner(model_config=None,
                               parallel_config=None,
                               scheduler_config=scheduler_config,
                               device_config=None,
                               load_config=None,
                               lora_config=None)
    model_runner.set_block_size(16)
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )

    seq_lens = []
    seq_group_metadata_list = []
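Hypothetical usage of the helper above (not taken from the diff): because *args and **kwargs are forwarded to EngineArgs, a test can override any engine option, for example the KV-cache block size, and the value reaches the runner through cache_config:

model_runner = _create_model_runner(
    "facebook/opt-125m",
    block_size=16,                  # forwarded to EngineArgs (assumed field)
    max_num_batched_tokens=100000,
    max_num_seqs=100000,
)
assert model_runner.block_size == 16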
@@ -123,27 +134,15 @@ def test_prepare_prompt(batch_size):

@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
    model_config = ModelConfig(
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        "facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=False,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=False,
    )
    scheduler_config = SchedulerConfig(100000,
                                       100000,
                                       100000,
                                       enable_chunked_prefill=False)
    model_runner = ModelRunner(model_config=model_config,
                               parallel_config=None,
                               scheduler_config=scheduler_config,
                               device_config=None,
                               load_config=None,
                               lora_config=None)
    model_runner.set_block_size(16)

    seq_lens = []
    seq_group_metadata_list = []
@@ -214,23 +213,12 @@ def test_prepare_decode_cuda_graph(batch_size):

def test_empty_seq_group():
    """Verify prepare prompt and decode returns empty output."""
    model_config = ModelConfig(
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        "facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=False,
    )
    model_runner = ModelRunner(model_config=model_config,
                               parallel_config=None,
                               scheduler_config=None,
                               device_config=None,
                               load_config=None,
                               lora_config=None)
    model_runner.set_block_size(16)
    seq_group_metadata_list = []
    input_tokens, input_positions, attn_metadata, _, _, _, slot_mapping = (
        model_runner._prepare_decode(seq_group_metadata_list))
@@ -260,29 +248,15 @@ def distributed_init():
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):

    model_config = ModelConfig(
    model_runner = _create_model_runner(
        "facebook/opt-125m",
        "facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        seed=0,
        dtype="float16",
        revision=None,
        enforce_eager=enforce_eager,
        max_num_batched_tokens=100000,
        max_num_seqs=100000,
        enable_chunked_prefill=True,
    )
    scheduler_config = SchedulerConfig(100000,
                                       100000,
                                       100000,
                                       enable_chunked_prefill=True)
    model_runner = ModelRunner(model_config=model_config,
                               parallel_config=None,
                               scheduler_config=scheduler_config,
                               device_config=None,
                               load_config=None,
                               lora_config=None,
                               is_driver_worker=True)
    model_runner.set_block_size(16)

    # Add prefill requests.
    seq_lens = []
@@ -4,8 +4,9 @@ import torch
from torch import nn

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
@@ -26,6 +27,7 @@ class CPUModelRunner:
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        vision_language_config: Optional[VisionLanguageConfig],
@@ -39,27 +41,22 @@ class CPUModelRunner:
        self.scheduler_config = scheduler_config
        # Currently, CPU worker doesn't support chunked prefill.
        assert self.scheduler_config.chunked_prefill_enabled is False
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.vision_language_config = vision_language_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device

        self.kv_cache_dtype = kv_cache_dtype

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.attn_backend = get_attn_backend(self.model_config.dtype)

        # Lazy initialization.
        self.model: nn.Module  # Set after init_Model
        self.block_size: int  # Set after initial profiling.

    def load_model(self) -> None:
        self.model = get_model(
@@ -151,6 +151,7 @@ class CPUWorker(LoraNotSupportedWorkerBase):
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=self.load_config,
            lora_config=self.lora_config,
            vision_language_config=self.vision_language_config,
@@ -9,8 +9,9 @@ import torch.nn as nn

from vllm.attention import (AttentionMetadata, AttentionMetadataPerStage,
                            get_attn_backend)
from vllm.config import (DeviceConfig, LoadConfig, LoRAConfig, ModelConfig,
                         ParallelConfig, SchedulerConfig, VisionLanguageConfig)
from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
                         ModelConfig, ParallelConfig, SchedulerConfig,
                         VisionLanguageConfig)
from vllm.distributed import broadcast_tensor_dict, with_pynccl_for_all_reduce
from vllm.distributed.device_communicators import (custom_all_reduce,
                                                    pynccl_utils)
@@ -106,6 +107,7 @@ class ModelRunner:
        parallel_config: ParallelConfig,
        scheduler_config: SchedulerConfig,
        device_config: DeviceConfig,
        cache_config: CacheConfig,
        load_config: LoadConfig,
        lora_config: Optional[LoRAConfig],
        kv_cache_dtype: Optional[str] = "auto",
@@ -115,48 +117,40 @@ class ModelRunner:
        self.model_config = model_config
        self.parallel_config = parallel_config
        self.scheduler_config = scheduler_config
        self.device_config = device_config
        self.cache_config = cache_config
        self.lora_config = lora_config
        self.load_config = load_config
        self.is_driver_worker = is_driver_worker
        self.vision_language_config = vision_language_config

        # model_config can be None in tests/samplers/test_sampler.py.
        # FIXME(woosuk): This is a hack to make the tests work. Refactor this.
        self.sliding_window = (model_config.get_sliding_window()
                               if model_config is not None else None)
        self.device_config = (device_config
                              if device_config is not None else DeviceConfig())
        self.device = self.device_config.device
        self.pin_memory = is_pin_memory_available()

        # Set after load_model.
        self.lora_manager: LRUCacheWorkerLoRAManager = None

        self.kv_cache_dtype = kv_cache_dtype
        self.sliding_window = model_config.get_sliding_window()
        self.block_size = cache_config.block_size
        self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture
        self.graph_runners: Dict[int, CUDAGraphRunner] = {}
        self.graph_memory_pool: Optional[Tuple[
            int, int]] = None  # Set during graph capture.

        self.max_seq_len_to_capture = (self.model_config.max_seq_len_to_capture
                                       if self.model_config is not None else 0)

        self.pin_memory = is_pin_memory_available()
        self.kv_cache_dtype = kv_cache_dtype
        self.vision_language_config = vision_language_config

        self.attn_backend = get_attn_backend(
            self.model_config.dtype if model_config is not None else None)

        # Lazy initialization
        self.model: torch.nn.Module  # Set after load_model
        self.block_size: int  # Set after initial profiling.
        # When using CUDA graph, the input block tables must be padded to
        # max_seq_len_to_capture. However, creating the block table in
        # Python can be expensive. To optimize this, we cache the block table
        # in numpy and only copy the actual input content at every iteration.
        # The shape of the cached block table will be
        # (max batch size to capture, max context len to capture / block size).
        self.graph_block_tables: torch.Tensor  # Set after initial profiling.
        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)
        self.attn_backend = get_attn_backend(self.model_config.dtype)

        # Lazy initialization
        self.model: torch.nn.Module  # Set after load_model
        # Set if the backend is flashinfer.
        self.flashinfer_workspace_buffer: torch.Tensor
        # Set after load_model.
        self.lora_manager: Optional[LRUCacheWorkerLoRAManager] = None

    def load_model(self) -> None:
        with CudaMemoryProfiler() as m:
@@ -211,13 +205,6 @@ class ModelRunner:
                    "but the KV cache data type is not FP8. "
                    "KV cache scaling factors will not be used.")

    def set_block_size(self, block_size: int) -> None:
        self.block_size = block_size

        self.graph_block_tables = np.zeros(
            (max(_BATCH_SIZES_TO_CAPTURE), self.get_max_block_per_batch()),
            dtype=np.int32)

    def get_max_block_per_batch(self) -> int:
        block_size = self.block_size
        return (self.max_seq_len_to_capture + block_size - 1) // block_size
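Illustrative arithmetic (assumed numbers, not from the diff) for the ceiling division in get_max_block_per_batch(): with max_seq_len_to_capture = 8192 and block_size = 16, the padded block table needs (8192 + 16 - 1) // 16 == 512 block slots per sequence, and a partially filled block still costs a full slot:

def max_block_per_batch(max_seq_len_to_capture: int, block_size: int) -> int:
    # Same rounding-up division as ModelRunner.get_max_block_per_batch().
    return (max_seq_len_to_capture + block_size - 1) // block_size

assert max_block_per_batch(8192, 16) == 512
assert max_block_per_batch(8193, 16) == 513  # one extra, partially filled block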
@@ -835,6 +822,7 @@ class ModelRunner:
        dummy_lora_requests = []
        dummy_lora_requests_per_seq = []
        if self.lora_config:
            assert self.lora_manager is not None
            with self.lora_manager.dummy_lora_cache():
                for idx in range(self.lora_config.max_loras):
                    lora_id = idx + 1
@@ -75,6 +75,7 @@ class Worker(WorkerBase):
            parallel_config,
            scheduler_config,
            device_config,
            cache_config,
            load_config=load_config,
            lora_config=self.lora_config,
            kv_cache_dtype=self.cache_config.cache_dtype,
@@ -184,7 +185,6 @@ class Worker(WorkerBase):
        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
                                        self.parallel_config)
        self.gpu_cache = self.cache_engine.gpu_cache
        self.model_runner.set_block_size(self.cache_engine.block_size)

    def _warm_up_model(self) -> None:
        if not self.model_config.enforce_eager: