Merge branch 'main' into v1-sched-interface-2

This commit is contained in:
Woosuk Kwon
2025-03-12 21:45:44 -07:00
9 changed files with 51 additions and 23 deletions

View File

@ -9,6 +9,7 @@ setuptools-scm>=8
wheel
jinja2
ray[default]
ray[data]
# Install torch_xla
--pre

View File

@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
# The KV connector port, used to build distributed connection
kv_port: int = 14579
# any extra config that the connector may need
kv_connector_extra_config: dict[str, Any] = {}
def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
return self.kv_connector is not None and \
self.kv_role in ["kv_consumer", "kv_both"]
def get_from_extra_config(self, key, default) -> Any:
return self.kv_connector_extra_config.get(key, default)
class CompilationLevel:
# constants for the levels of the compilation process

View File

@ -6,7 +6,7 @@
- Distributed KV cache transmission using PyNccl pipes.
- Non-blocking `insert`, blocking `drop_select`.
- Use CPU signal pipe to avoid racing condition
- Handles buffer size constraints and provide backpressure mechanism to
- Handles buffer size constraints and provide backpressure mechanism to
stop the prefill instance when the decode instance is slow.
"""
import threading

View File

@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
This module implements a PyNccl pipe for sending and receiving
Optional[torch.Tensor] between distributed ranks with advanced
communication features.
Key Features:
@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
self.device = self._select_device(device)
# build distributed connection and send/recv implementation
store_timeout = self.config.get_from_extra_config("store_timeout", 300)
self.group = StatelessProcessGroup.create(
host=self.config.kv_ip,
port=self.config.kv_port + port_offset,
rank=self.kv_rank,
world_size=self.kv_parallel_size,
store_timeout=store_timeout,
)
# add a barrier to make sure the connection is initiated properly
self.group.barrier()
@ -134,11 +136,11 @@ class PyNcclPipe(KVPipeBase):
Create a buffer to receive the tensor based on the provided metadata.
Parameters:
- metadata: A dictionary with keys "dtype" and "shape", describing
- metadata: A dictionary with keys "dtype" and "shape", describing
the tensor's data type and shape.
Returns:
- buffer: A tensor of the specified type and shape, allocated on
- buffer: A tensor of the specified type and shape, allocated on
self.device.
"""
return torch.empty(metadata["shape"],
@ -159,18 +161,18 @@ class PyNcclPipe(KVPipeBase):
Receive the metadata dictionary from the target rank.
Returns:
- metadata: A dictionary with keys "dtype" and "shape" describing
- metadata: A dictionary with keys "dtype" and "shape" describing
the tensor.
"""
return self.group.recv_obj(self.target_rank_for_recv)
def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
"""
The actual implementation of sending the tensor and its metadata to the
The actual implementation of sending the tensor and its metadata to the
target rank.
Parameters:
- tensor: The input tensor to be sent, or None if no tensor is
- tensor: The input tensor to be sent, or None if no tensor is
being sent.
"""
metadata = self._make_metadata(tensor)
@ -181,7 +183,7 @@ class PyNcclPipe(KVPipeBase):
def _recv_impl(self) -> Optional[torch.Tensor]:
"""
The actual implementation of receiving a tensor and its metadata from
The actual implementation of receiving a tensor and its metadata from
the target rank.
Returns:
@ -213,7 +215,7 @@ class PyNcclPipe(KVPipeBase):
def block_if_full(self):
"""
Block the current thread if the buffer size is larger than the
Block the current thread if the buffer size is larger than the
threshold.
"""
while self.buffer_size > self.buffer_size_thresh:
@ -222,7 +224,7 @@ class PyNcclPipe(KVPipeBase):
def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
"""
Sends a tensor and its metadata to the destination rank in a
Sends a tensor and its metadata to the destination rank in a
non-blocking way.
Parameters:

View File

@ -5,6 +5,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import datetime
import pickle
import time
from collections import deque
@ -217,6 +218,7 @@ class StatelessProcessGroup:
rank: int,
world_size: int,
data_expiration_seconds: int = 3600,
store_timeout: int = 300,
) -> "StatelessProcessGroup":
"""A replacement for `torch.distributed.init_process_group` that does not
pollute the global state.
@ -238,6 +240,7 @@ class StatelessProcessGroup:
port=port,
world_size=world_size,
is_master=(rank == 0),
timeout=datetime.timedelta(seconds=store_timeout),
)
return StatelessProcessGroup(

View File

@ -50,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
# We prefer to use separate k_scale and v_scale if present
k_scale = layer.k_scale.to("cpu").tolist()
v_scale = layer.v_scale.to("cpu").tolist()
if current_platform.is_rocm():
if current_platform.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2
elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@ -66,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
scale_to_duplicate = max(layer.k_scale, layer.v_scale)
k_scale = scale_to_duplicate.to("cpu").tolist()
v_scale = scale_to_duplicate.to("cpu").tolist()
if current_platform.is_rocm():
if current_platform.is_fp8_fnuz():
k_scale *= 2
v_scale *= 2

View File

@ -1330,11 +1330,14 @@ class GGUFModelLoader(BaseModelLoader):
local_model_path, gguf_weights_map):
model_config.hf_config.update({"tie_word_embeddings": True})
target_device = torch.device(device_config.device)
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
with target_device:
model = _initialize_model(vllm_config=vllm_config)
model.load_weights(
self._get_weights_iterator(local_model_path, gguf_weights_map))
_process_weights_after_loading(model, model_config, target_device)
return model

View File

@ -91,13 +91,19 @@ class TpuPlatform(Platform):
parallel_config = vllm_config.parallel_config
scheduler_config = vllm_config.scheduler_config
if parallel_config.worker_cls == "auto":
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.tpu_worker.TPUWorker"
else:
if scheduler_config.is_multi_step:
if scheduler_config.is_multi_step:
if envs.VLLM_USE_V1:
raise NotImplementedError(
"Multi-step scheduling is not supported (and not "
"needed) on vLLM V1. Please launch without "
"--num-scheduler-steps.")
else:
parallel_config.worker_cls = \
"vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
else:
if envs.VLLM_USE_V1:
parallel_config.worker_cls = \
"vllm.v1.worker.tpu_worker.TPUWorker"
else:
parallel_config.worker_cls = \
"vllm.worker.tpu_worker.TPUWorker"

View File

@ -23,6 +23,7 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
@ -67,15 +68,21 @@ class EngineCore:
# Setup scheduler.
if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
Scheduler = resolve_obj_by_qualname(
vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = vllm_config.scheduler_config.scheduler_cls
# This warning can be removed once the V1 Scheduler interface is
# finalized and we can maintain support for scheduler classes that
# implement it
if Scheduler is not V1Scheduler:
logger.warning(
"Using configured V1 scheduler class %s. "
"This scheduler interface is not public and "
"compatibility may not be maintained.",
vllm_config.scheduler_config.scheduler_cls)
Scheduler = resolve_obj_by_qualname(
vllm_config.scheduler_config.scheduler_cls)
else:
Scheduler = vllm_config.scheduler_config.scheduler_cls
self.scheduler = Scheduler(
scheduler_config=vllm_config.scheduler_config,
model_config=vllm_config.model_config,