Mirror of https://github.com/vllm-project/vllm.git, synced 2025-10-20 14:53:52 +08:00
Merge branch 'main' into v1-sched-interface-2
@@ -9,6 +9,7 @@ setuptools-scm>=8
wheel
jinja2
ray[default]
ray[data]

# Install torch_xla
--pre

@@ -2837,6 +2837,9 @@ class KVTransferConfig(BaseModel):
    # The KV connector port, used to build distributed connection
    kv_port: int = 14579

    # any extra config that the connector may need
    kv_connector_extra_config: dict[str, Any] = {}

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
@@ -2896,6 +2899,9 @@ class KVTransferConfig(BaseModel):
        return self.kv_connector is not None and \
            self.kv_role in ["kv_consumer", "kv_both"]

    def get_from_extra_config(self, key, default) -> Any:
        return self.kv_connector_extra_config.get(key, default)


class CompilationLevel:
    # constants for the levels of the compilation process
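
Editor's note: as a point of reference, here is a minimal, self-contained sketch (not part of this diff) of how a component could read an optional setting through the new kv_connector_extra_config / get_from_extra_config pair; the class name and example keys below are illustrative only.

# Hypothetical usage sketch (names are illustrative, not vLLM API):
from typing import Any

class ExtraConfigReader:
    def __init__(self, kv_connector_extra_config: dict[str, Any]):
        self.kv_connector_extra_config = kv_connector_extra_config

    def get_from_extra_config(self, key, default) -> Any:
        # Same behavior as the method added above: fall back to the
        # caller-supplied default when the key is absent.
        return self.kv_connector_extra_config.get(key, default)

reader = ExtraConfigReader({"store_timeout": 120})
assert reader.get_from_extra_config("store_timeout", 300) == 120
assert reader.get_from_extra_config("missing_key", "fallback") == "fallback"
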
@@ -6,7 +6,7 @@
    - Distributed KV cache transmission using PyNccl pipes.
    - Non-blocking `insert`, blocking `drop_select`.
    - Use CPU signal pipe to avoid race conditions.
    - Handles buffer size constraints and provides a backpressure mechanism to
      stop the prefill instance when the decode instance is slow.
"""
import threading

@@ -1,7 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
"""
    This module implements a PyNccl pipe for sending and receiving
    Optional[torch.Tensor] between distributed ranks with advanced
    communication features.

    Key Features:
@@ -59,11 +59,13 @@ class PyNcclPipe(KVPipeBase):
        self.device = self._select_device(device)

        # build distributed connection and send/recv implementation
        store_timeout = self.config.get_from_extra_config("store_timeout", 300)
        self.group = StatelessProcessGroup.create(
            host=self.config.kv_ip,
            port=self.config.kv_port + port_offset,
            rank=self.kv_rank,
            world_size=self.kv_parallel_size,
            store_timeout=store_timeout,
        )
        # add a barrier to make sure the connection is initiated properly
        self.group.barrier()
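
Editor's note: a hedged sketch of how the 300-second default above could be overridden. The constructor reads "store_timeout" from the connector's extra config, so a deployment would place that key in kv_connector_extra_config; the exact wiring from CLI flags to KVTransferConfig is outside this hunk, and the dict below only illustrates the key/value shape assumed by get_from_extra_config.

# Illustrative only: the config shape consumed by the constructor above.
kv_transfer_settings = {
    "kv_connector": "PyNcclConnector",   # example values; exact names are
    "kv_role": "kv_producer",            # not part of this hunk
    "kv_connector_extra_config": {
        # Raise the rendezvous timeout from the 300 s default to 10 minutes
        # for slow-starting decode instances.
        "store_timeout": 600,
    },
}

store_timeout = kv_transfer_settings["kv_connector_extra_config"].get(
    "store_timeout", 300)
assert store_timeout == 600
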
@@ -134,11 +136,11 @@ class PyNcclPipe(KVPipeBase):
        Create a buffer to receive the tensor based on the provided metadata.

        Parameters:
            - metadata: A dictionary with keys "dtype" and "shape", describing
              the tensor's data type and shape.

        Returns:
            - buffer: A tensor of the specified type and shape, allocated on
              self.device.
        """
        return torch.empty(metadata["shape"],
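
Editor's note: to make the metadata contract concrete, a small standalone sketch; the "dtype"/"shape" keys follow the docstring above, everything else is an assumption for illustration.

# Standalone sketch of metadata-driven buffer allocation.
import torch

def allocate_recv_buffer(metadata: dict, device: str = "cpu") -> torch.Tensor:
    # The sender describes the payload; the receiver pre-allocates a matching
    # empty tensor so the transfer can write directly into it.
    return torch.empty(metadata["shape"],
                       dtype=metadata["dtype"],
                       device=device)

meta = {"dtype": torch.float16, "shape": (2, 16, 128)}
buf = allocate_recv_buffer(meta)
assert buf.shape == (2, 16, 128) and buf.dtype == torch.float16
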
@@ -159,18 +161,18 @@ class PyNcclPipe(KVPipeBase):
        Receive the metadata dictionary from the target rank.

        Returns:
            - metadata: A dictionary with keys "dtype" and "shape" describing
              the tensor.
        """
        return self.group.recv_obj(self.target_rank_for_recv)

    def _send_impl(self, tensor: Optional[torch.Tensor]) -> None:
        """
        The actual implementation of sending the tensor and its metadata to the
        target rank.

        Parameters:
            - tensor: The input tensor to be sent, or None if no tensor is
              being sent.
        """
        metadata = self._make_metadata(tensor)
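
Editor's note: for orientation, a simplified, hypothetical sketch of the two-step pattern the docstrings describe (a small metadata object first, then the tensor payload if one exists). The helper names below are placeholders, not the pipe's real send/recv primitives.

# Hypothetical two-step send: metadata object first, optional payload second.
from typing import Any, Callable, Optional
import torch

def send_with_metadata(tensor: Optional[torch.Tensor],
                       send_obj: Callable[[dict], None],
                       send_tensor: Callable[[torch.Tensor], None]) -> None:
    # Describe the payload (or its absence) so the receiver knows whether to
    # allocate a buffer and wait for a tensor at all.
    metadata: dict[str, Any] = (
        {"dtype": None, "shape": None} if tensor is None else
        {"dtype": tensor.dtype, "shape": tuple(tensor.shape)})
    send_obj(metadata)
    if tensor is not None:
        send_tensor(tensor)

# Toy transport: append into in-memory lists instead of a real pipe.
objs, tensors = [], []
send_with_metadata(torch.ones(2, 3), objs.append, tensors.append)
send_with_metadata(None, objs.append, tensors.append)
assert len(objs) == 2 and len(tensors) == 1
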
@@ -181,7 +183,7 @@ class PyNcclPipe(KVPipeBase):

    def _recv_impl(self) -> Optional[torch.Tensor]:
        """
        The actual implementation of receiving a tensor and its metadata from
        the target rank.

        Returns:
@@ -213,7 +215,7 @@ class PyNcclPipe(KVPipeBase):

    def block_if_full(self):
        """
        Block the current thread if the buffer size is larger than the
        threshold.
        """
        while self.buffer_size > self.buffer_size_thresh:
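
Editor's note: a minimal sketch of the backpressure idea in block_if_full, assuming a polling loop with a short sleep; the real threshold value and sleep interval are not shown in this hunk.

# Toy backpressure loop: the producer polls until the consumer drains the
# shared buffer below a threshold.
import threading
import time

buffer_size = 10
buffer_size_thresh = 4
lock = threading.Lock()

def block_if_full() -> None:
    while True:
        with lock:
            if buffer_size <= buffer_size_thresh:
                return
        time.sleep(0.05)  # yield instead of spinning hot

def consumer() -> None:
    global buffer_size
    for _ in range(10):
        time.sleep(0.01)
        with lock:
            buffer_size = max(0, buffer_size - 1)

t = threading.Thread(target=consumer)
t.start()
block_if_full()   # returns once the consumer has drained enough
t.join()
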
@@ -222,7 +224,7 @@ class PyNcclPipe(KVPipeBase):

    def send_tensor(self, tensor: Optional[torch.Tensor]) -> None:
        """
        Sends a tensor and its metadata to the destination rank in a
        non-blocking way.

        Parameters:

@@ -5,6 +5,7 @@
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
import dataclasses
import datetime
import pickle
import time
from collections import deque
@@ -217,6 +218,7 @@ class StatelessProcessGroup:
        rank: int,
        world_size: int,
        data_expiration_seconds: int = 3600,
        store_timeout: int = 300,
    ) -> "StatelessProcessGroup":
        """A replacement for `torch.distributed.init_process_group` that does not
        pollute the global state.
@@ -238,6 +240,7 @@ class StatelessProcessGroup:
            port=port,
            world_size=world_size,
            is_master=(rank == 0),
            timeout=datetime.timedelta(seconds=store_timeout),
        )

        return StatelessProcessGroup(
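
Editor's note: the new store_timeout is passed to the TCP store's timeout as a datetime.timedelta. A small sketch of that conversion; the store construction itself is elided here and the values are illustrative.

# Illustrative conversion: an integer number of seconds from the config
# becomes the timedelta that the underlying key-value store expects.
import datetime

store_timeout = 300                      # default used in the diff above
timeout = datetime.timedelta(seconds=store_timeout)
assert timeout.total_seconds() == 300.0

# A larger value, e.g. pulled from kv_connector_extra_config, simply widens
# the rendezvous window for ranks that start late.
slow_start_timeout = datetime.timedelta(seconds=900)
assert slow_start_timeout > timeout
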
@@ -50,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
            # We prefer to use separate k_scale and v_scale if present
            k_scale = layer.k_scale.to("cpu").tolist()
            v_scale = layer.v_scale.to("cpu").tolist()
            if current_platform.is_rocm():
            if current_platform.is_fp8_fnuz():
                k_scale *= 2
                v_scale *= 2
        elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -66,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
            scale_to_duplicate = max(layer.k_scale, layer.v_scale)
            k_scale = scale_to_duplicate.to("cpu").tolist()
            v_scale = scale_to_duplicate.to("cpu").tolist()
            if current_platform.is_rocm():
            if current_platform.is_fp8_fnuz():
                k_scale *= 2
                v_scale *= 2
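
Editor's note: the doubling is tied to the fnuz fp8 variant used on some ROCm GPUs; my understanding (an assumption, not stated in this diff) is that e4m3fnuz covers roughly half the numeric range of OCP e4m3, so scales calibrated for OCP fp8 are compensated by a factor of 2. A small sketch of that adjustment with a hypothetical platform flag:

# Hypothetical helper: adjust KV-cache scales for an fnuz-style fp8 format.
def adjust_kv_scales(k_scale: float, v_scale: float,
                     is_fp8_fnuz: bool) -> tuple[float, float]:
    # Assumption: fnuz fp8 has about half the representable range of OCP
    # e4m3, so scales calibrated for OCP fp8 are doubled to compensate.
    if is_fp8_fnuz:
        return k_scale * 2, v_scale * 2
    return k_scale, v_scale

assert adjust_kv_scales(0.5, 0.25, is_fp8_fnuz=True) == (1.0, 0.5)
assert adjust_kv_scales(0.5, 0.25, is_fp8_fnuz=False) == (0.5, 0.25)
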
@@ -1330,11 +1330,14 @@ class GGUFModelLoader(BaseModelLoader):
                local_model_path, gguf_weights_map):
            model_config.hf_config.update({"tie_word_embeddings": True})

        target_device = torch.device(device_config.device)
        with set_default_torch_dtype(model_config.dtype):
            with torch.device(device_config.device):
            with target_device:
                model = _initialize_model(vllm_config=vllm_config)
            model.load_weights(
                self._get_weights_iterator(local_model_path, gguf_weights_map))

            _process_weights_after_loading(model, model_config, target_device)
        return model
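
Editor's note: for readers unfamiliar with the pattern, a minimal sketch of constructing a module under a torch.device context so its parameters are created directly on the target device (generic nn.Module, not the vLLM loader).

# Minimal sketch: building a module inside a torch.device context places its
# freshly created parameters on that device without an extra .to() copy.
import torch
import torch.nn as nn

target_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
with target_device:
    model = nn.Linear(16, 4)

assert next(model.parameters()).device.type == target_device.type
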
@@ -91,13 +91,19 @@ class TpuPlatform(Platform):
        parallel_config = vllm_config.parallel_config
        scheduler_config = vllm_config.scheduler_config
        if parallel_config.worker_cls == "auto":
            if envs.VLLM_USE_V1:
                parallel_config.worker_cls = \
                    "vllm.v1.worker.tpu_worker.TPUWorker"
            else:
                if scheduler_config.is_multi_step:
            if scheduler_config.is_multi_step:
                if envs.VLLM_USE_V1:
                    raise NotImplementedError(
                        "Multi-step scheduling is not supported (and not "
                        "needed) on vLLM V1. Please launch without "
                        "--num-scheduler-steps.")
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
            else:
                if envs.VLLM_USE_V1:
                    parallel_config.worker_cls = \
                        "vllm.v1.worker.tpu_worker.TPUWorker"
                else:
                    parallel_config.worker_cls = \
                        "vllm.worker.tpu_worker.TPUWorker"
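
Editor's note: as a reading aid, a condensed sketch of the worker selection this hunk appears to end up with, expressed as a pure function; flags are passed in as plain booleans and this is an interpretation of the diff, not vLLM code.

# Condensed reading of the TPU worker selection after this hunk.
def select_tpu_worker_cls(use_v1: bool, is_multi_step: bool) -> str:
    if is_multi_step:
        if use_v1:
            # V1 handles multi-step natively, so the flag is rejected.
            raise NotImplementedError(
                "Multi-step scheduling is not supported (and not needed) "
                "on vLLM V1. Please launch without --num-scheduler-steps.")
        return "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker"
    if use_v1:
        return "vllm.v1.worker.tpu_worker.TPUWorker"
    return "vllm.worker.tpu_worker.TPUWorker"

assert select_tpu_worker_cls(use_v1=True, is_multi_step=False) \
    == "vllm.v1.worker.tpu_worker.TPUWorker"
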
@@ -23,6 +23,7 @@ from vllm.utils import (get_exception_traceback, resolve_obj_by_qualname,
                        zmq_socket_ctx)
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.core.sched.scheduler import Scheduler as V1Scheduler
from vllm.v1.engine import (EngineCoreOutputs, EngineCoreRequest,
                            EngineCoreRequestType, UtilityOutput)
from vllm.v1.engine.mm_input_cache import MMInputCacheServer
@@ -67,15 +68,21 @@ class EngineCore:

        # Setup scheduler.
        if isinstance(vllm_config.scheduler_config.scheduler_cls, str):
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        # This warning can be removed once the V1 Scheduler interface is
        # finalized and we can maintain support for scheduler classes that
        # implement it
        if Scheduler is not V1Scheduler:
            logger.warning(
                "Using configured V1 scheduler class %s. "
                "This scheduler interface is not public and "
                "compatibility may not be maintained.",
                vllm_config.scheduler_config.scheduler_cls)
            Scheduler = resolve_obj_by_qualname(
                vllm_config.scheduler_config.scheduler_cls)
        else:
            Scheduler = vllm_config.scheduler_config.scheduler_cls

        self.scheduler = Scheduler(
            scheduler_config=vllm_config.scheduler_config,
            model_config=vllm_config.model_config,
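
Editor's note: the string-or-class handling above hinges on resolving a dotted path to an object. Below is a minimal, generic sketch of that mechanism, written from scratch with importlib rather than taken from vLLM's resolve_obj_by_qualname; it accepts either an already-resolved class or a dotted-path string.

# Generic "resolve an object from its qualified name" helper (sketch).
import importlib
from typing import Union

def resolve_scheduler_cls(scheduler_cls: Union[str, type]) -> type:
    if not isinstance(scheduler_cls, str):
        return scheduler_cls                      # already a class object
    module_name, _, attr = scheduler_cls.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, attr)

# Example with a stdlib class so the sketch is runnable anywhere.
assert resolve_scheduler_cls("collections.OrderedDict").__name__ == "OrderedDict"
assert resolve_scheduler_cls(dict) is dict
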