Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-20 23:03:52 +08:00)

Compare commits (44 commits): codex/remo... → woosuk/fix...
Commits (SHA1):

936da0f740, 20098c10d9, ee7a66dd9a, 431535b522, 711e912946, e69e0b8b5f,
ddc9048394, b1a63d1b3b, 48ecb4438b, e57fc15971, 4bdf400218, 7852b82b93,
a2a5f79e09, c59a0eca42, b716ab93a7, 138f0d1e75, 2506ce5189, 47fd08aaf9,
12aed7e453, d90e212a3a, 2821986450, 6c117cff7d, 7ac67ea525, ce75e15373,
aed16879a9, cf278ff3b2, 838d7116ba, 5089fd749c, a3d087adec, 058525b997,
1dfea5f4a9, cea91a32f2, a684c0124c, f2718d2948, 825fdb11ad, 8c1d4acbfe,
486c5599e3, a6149aa587, 6c8a3c099b, 31a8a2a7bc, 1a0a04dae9, 6d8246aaff,
9d1c50a5ac, 9a4600e4dc
.github/CODEOWNERS (vendored, 13 lines changed)
@@ -66,18 +66,25 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson
/tests/models/test_transformers.py @hmellor

# Docs
/docs @hmellor
/docs/mkdocs @hmellor
/docs/**/*.yml @hmellor
/requirements/docs.txt @hmellor
.readthedocs.yaml @hmellor
mkdocs.yaml @hmellor

# Linting
.markdownlint.yaml @hmellor
.pre-commit-config.yaml @hmellor

# CPU
-/vllm/v1/worker/^cpu @bigPYJ1151
+/vllm/v1/worker/cpu* @bigPYJ1151
/csrc/cpu @bigPYJ1151
/vllm/platforms/cpu.py @bigPYJ1151
/cmake/cpu_extension.cmake @bigPYJ1151
/docker/Dockerfile.cpu @bigPYJ1151

# Intel GPU
-/vllm/v1/worker/^xpu @jikunshang
+/vllm/v1/worker/xpu* @jikunshang
/vllm/platforms/xpu.py @jikunshang
/docker/Dockerfile.xpu @jikunshang
@@ -49,7 +49,7 @@ repos:
    rev: 0.6.17
    hooks:
      - id: pip-compile
-       args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128]
+       args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128, --python-platform, x86_64-manylinux_2_28]
        files: ^requirements/test\.(in|txt)$
  - repo: local
    hooks:
@@ -11,13 +11,13 @@ from datetime import datetime
from typing import Any

import torch
-import triton
from tqdm import tqdm

from vllm.model_executor.layers.quantization.utils.fp8_utils import (
    _w8a8_block_fp8_matmul,
)
from vllm.platforms import current_platform
+from vllm.triton_utils import triton
from vllm.utils import FlexibleArgumentParser

mp.set_start_method("spawn", force=True)
csrc/launch_bounds_utils.h (new file, 38 lines)
@@ -0,0 +1,38 @@
#pragma once

#include <cuda_runtime_api.h>
#include <algorithm>

// maximum blocks per SM cap
#ifndef VLLM_LAUNCH_BLOCKS_CAP
#define VLLM_LAUNCH_BLOCKS_CAP 4
#endif

// compile-time estimate of max threads per SM for launch bounds.
#ifndef VLLM_MAX_THREADS_PER_SM
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 300
#define VLLM_MAX_THREADS_PER_SM 1536
#else
#define VLLM_MAX_THREADS_PER_SM 2048
#endif
#endif

// compute the number of blocks per SM to request in __launch_bounds__
#define VLLM_BLOCKS_DIV(VAL) (VLLM_MAX_THREADS_PER_SM / (VAL))
#define VLLM_CLAMP_BLOCKS_PER_SM(VAL) \
  (((VAL) <= 0) \
   ? 1 \
   : (((VAL) < VLLM_LAUNCH_BLOCKS_CAP) ? (VAL) : VLLM_LAUNCH_BLOCKS_CAP))
#define VLLM_BLOCKS_PER_SM(BLOCK_THREADS) \
  VLLM_CLAMP_BLOCKS_PER_SM(VLLM_BLOCKS_DIV(BLOCK_THREADS))

// runtime-time helper to compute blocks/SM
static inline int vllm_runtime_blocks_per_sm(int block_threads) {
  int device = -1;
  cudaGetDevice(&device);
  int max_threads_per_sm = VLLM_MAX_THREADS_PER_SM;
  cudaDeviceGetAttribute(&max_threads_per_sm,
                         cudaDevAttrMaxThreadsPerMultiProcessor, device);
  int blocks = (block_threads > 0) ? (max_threads_per_sm / block_threads) : 1;
  return VLLM_CLAMP_BLOCKS_PER_SM(blocks);
}
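For reference (an illustration, not part of the diff), the clamp arithmetic defined above works out as follows; this Python sketch assumes the default values `VLLM_MAX_THREADS_PER_SM = 2048` and `VLLM_LAUNCH_BLOCKS_CAP = 4`:

```python
# Illustration only: the same clamp arithmetic as VLLM_BLOCKS_PER_SM /
# vllm_runtime_blocks_per_sm, rendered in Python to show the values it yields.
def blocks_per_sm(block_threads: int,
                  max_threads_per_sm: int = 2048,
                  blocks_cap: int = 4) -> int:
    blocks = max_threads_per_sm // block_threads if block_threads > 0 else 1
    return max(1, min(blocks, blocks_cap))  # clamp to the range [1, blocks_cap]

for threads in (256, 512, 1024, 2048):
    print(threads, blocks_per_sm(threads))  # 256 -> 4, 512 -> 4, 1024 -> 2, 2048 -> 1
```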
@@ -26,6 +26,7 @@
#include "dispatch_utils.h"

#include "cuda_utils.h"
+#include "launch_bounds_utils.h"
#include "nvfp4_utils.cuh"

namespace vllm {

@@ -63,7 +64,7 @@ __inline__ __device__ PackedVec<Type> compute_silu_mul(PackedVec<Type>& vec,

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
-__global__ void __launch_bounds__(1024, 4)
+__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
    silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                             float const* SFScale, uint32_t* out,
                             uint32_t* SFout) {

@@ -131,7 +132,8 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output,  // [..., d]
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 1024));
-  int const numBlocksPerSM = 2048 / block.x;
+  int const numBlocksPerSM =
+      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  VLLM_DISPATCH_HALF_TYPES(
@@ -26,12 +26,13 @@
#include "dispatch_utils.h"

#include "nvfp4_utils.cuh"
+#include "launch_bounds_utils.h"

namespace vllm {

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
-__global__ void __launch_bounds__(512, 4)
+__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                    float const* SFScale, uint32_t* out, uint32_t* SFout,
                    uint32_t* input_offset_by_experts,

@@ -129,7 +130,7 @@ __global__ void __launch_bounds__(512, 4)

// Kernel for LARGE_M_TOPK = true (large m_topk optimized version)
template <class Type, bool UE8M0_SF = false, bool SMALL_NUM_EXPERTS = false>
-__global__ void __launch_bounds__(1024, 4)
+__global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024))
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                    float const* SFScale, uint32_t* out, uint32_t* SFout,
                    uint32_t* input_offset_by_experts,

@@ -233,8 +234,9 @@ void quant_impl(void* output, void* output_scale, void* input,
  int const workSizePerRow = k / ELTS_PER_THREAD;
  int const totalWorkSize = m_topk * workSizePerRow;
  dim3 block(std::min(workSizePerRow, 512));
-  // Get number of blocks per SM (assume we can fully utilize the SM).
-  int const numBlocksPerSM = 2048 / block.x;
+  // Get number of blocks per SM
+  int const numBlocksPerSM =
+      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  dim3 grid(std::min(static_cast<int>((totalWorkSize + block.x - 1) / block.x),
                     multiProcessorCount * numBlocksPerSM));
  while (grid.x <= multiProcessorCount && block.x > 64) {
@@ -26,13 +26,14 @@
#include "dispatch_utils.h"

#include "cuda_utils.h"
+#include "launch_bounds_utils.h"
#include "nvfp4_utils.cuh"

namespace vllm {

// Use UE4M3 by default.
template <class Type, bool UE8M0_SF = false>
-__global__ void __launch_bounds__(512, 4)
+__global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512))
    cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, Type const* in,
                    float const* SFScale, uint32_t* out, uint32_t* SFout) {
  using PackedVec = PackedVec<Type>;

@@ -75,8 +76,9 @@ void invokeFP4Quantization(int m, int n, T const* input, float const* SFScale,
  // Grid, Block size.
  // Each thread converts 8 values.
  dim3 block(std::min(int(n / ELTS_PER_THREAD), 512));
-  // Get number of blocks per SM (assume we can fully utilize the SM).
-  int const numBlocksPerSM = 2048 / block.x;
+  // Get number of blocks per SM
+  int const numBlocksPerSM =
+      vllm_runtime_blocks_per_sm(static_cast<int>(block.x));
  dim3 grid(std::min(int(m), multiProcessorCount * numBlocksPerSM));

  // Launch the cvt kernel.
@@ -59,7 +59,7 @@ enabling the corresponding APIs:
#### Predefined models

If the [Pooler][vllm.model_executor.layers.pooler.Pooler] defined by the model accepts `pooler_config`,
-you can override some of its attributes via the `--override-pooler-config` option.
+you can override some of its attributes via the `--pooler-config` option.

#### Converted models

@@ -75,7 +75,7 @@ the pooler assigned to each task has the following attributes by default:
When loading [Sentence Transformers](https://huggingface.co/sentence-transformers) models,
its Sentence Transformers configuration file (`modules.json`) takes priority over the model's defaults.

-You can further customize this via the `--override-pooler-config` option,
+You can further customize this via the `--pooler-config` option,
which takes priority over both the model's and Sentence Transformers's defaults.

## Offline Inference

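For illustration only (not part of the diff), the renamed option maps to usage like the following; the `pooler_config` engine keyword and the `PoolerConfig` import path are taken from the test changes elsewhere in this compare view, and the model name is just an example:

```python
# Hedged sketch: offline counterpart of the renamed `--pooler-config` CLI flag.
from vllm import LLM
from vllm.config.pooler import PoolerConfig

llm = LLM(
    model="intfloat/multilingual-e5-large",  # example embedding model
    pooler_config=PoolerConfig(pooling_type="MEAN", normalize=True),
)
# Server equivalent:
#   vllm serve intfloat/multilingual-e5-large \
#     --pooler-config '{"pooling_type": "MEAN", "normalize": true}'
```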
@@ -17,9 +17,24 @@ These models are what we list in [supported-text-models][supported-text-models]

### Transformers

-vLLM also supports model implementations that are available in Transformers. This does not currently work for all models, but most decoder language models and common vision language models are supported! Vision-language models currently accept only image inputs. Support for video inputs will be added in future releases.
+vLLM also supports model implementations that are available in Transformers. You should expect the performance of a Transformers model implementation used in vLLM to be within <1% of the performance of a dedicated vLLM model implementation. We call this feature the "Transformers backend".

-To check if the modeling backend is Transformers, you can simply do this:
+Currently, the Transformers backend works for the following:
+
+- Modalities: embedding models, language models and vision-language models*
+- Architectures: encoder-only, decoder-only
+- Attention types: full attention and/or sliding attention
+
+_*Vision-language models currently accept only image inputs. Support for video inputs will be added in a future release._
+
+If the Transformers model implementation follows all the steps in [writing a custom model](#writing-custom-models) then, when used with the Transformers backend, it will be compatible with the following features of vLLM:
+
+- All the features listed in the [compatibility matrix](../features/compatibility_matrix.md#feature-x-feature)
+- Any combination of the following vLLM parallelisation schemes:
+    - Pipeline parallel
+    - Tensor parallel
+
+Checking if the modeling backend is Transformers is as simple as:

```python
from vllm import LLM

@@ -27,16 +42,12 @@ llm = LLM(model=...)  # Name or path of your model
llm.apply_model(lambda model: print(type(model)))
```

-If it is `TransformersForCausalLM` or `TransformersForMultimodalLM` then it means it's based on Transformers!
+If the printed type starts with `Transformers...` then it's using the Transformers model implementation!

!!! tip
-    You can force the use of `TransformersForCausalLM` by setting `model_impl="transformers"` for [offline-inference](../serving/offline_inference.md) or `--model-impl transformers` for the [openai-compatible-server](../serving/openai_compatible_server.md).
+    If a model has a vLLM implementation but you would prefer to use the Transformers implementation via the Transformers backend, set `model_impl="transformers"` for [offline inference](../serving/offline_inference.md) or `--model-impl transformers` for the [online serving](../serving/openai_compatible_server.md).

-!!! note
-    vLLM may not fully optimise the Transformers implementation so you may see degraded performance if comparing a native model to a Transformers model in vLLM.
-
!!! note
-    In case of vision language models if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.
+    For vision-language models, if you are loading with `dtype="auto"`, vLLM loads the whole model with config's `dtype` if it exists. In contrast the native Transformers will respect the `dtype` attribute of each backbone in the model. That might cause a slight difference in performance.

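Illustrating the tip above (an aside, not part of the diff); the model name is only an example:

```python
# Hedged sketch: opt into the Transformers backend for a model that also has a
# native vLLM implementation, then confirm which implementation was picked.
from vllm import LLM

llm = LLM(model="Qwen/Qwen2-VL-7B-Instruct", model_impl="transformers")
llm.apply_model(lambda model: print(type(model)))  # expect a Transformers* class
```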
#### Custom models

@@ -66,10 +77,11 @@ This section details the necessary modifications to make to a Transformers compa
To make your model compatible with the Transformers backend, it needs:

1. `kwargs` passed down through all modules from `MyModel` to `MyAttention`.
+    1. If your model is encoder-only, you must also add `is_causal = False` to `MyAttention`.
2. `MyAttention` must use `ALL_ATTENTION_FUNCTIONS` to call attention.
3. `MyModel` must contain `_supports_attention_backend = True`.

-<details>
+<details class="code">
<summary>modeling_my_model.py</summary>

```python

@@ -78,6 +90,7 @@ from transformers import PreTrainedModel
from torch import nn

class MyAttention(nn.Module):
+    is_causal = False # Only do this for encoder-only models

    def forward(self, hidden_states, **kwargs):
        ...

@@ -101,13 +114,13 @@ Here is what happens in the background when this model is loaded:

1. The config is loaded.
2. `MyModel` Python class is loaded from the `auto_map` in config, and we check that the model `is_backend_compatible()`.
-3. `MyModel` is loaded into `TransformersForCausalLM` or `TransformersForMultimodalLM` (see <gh-file:vllm/model_executor/models/transformers.py>) which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.
+3. `MyModel` is loaded into one of the Transformers backend classes in <gh-file:vllm/model_executor/models/transformers.py> which sets `self.config._attn_implementation = "vllm"` so that vLLM's attention layer is used.

That's it!

For your model to be compatible with vLLM's tensor parallel and/or pipeline parallel features, you must add `base_model_tp_plan` and/or `base_model_pp_plan` to your model's config class:

-<details>
+<details class="code">
<summary>configuration_my_model.py</summary>

```python
@@ -457,7 +470,7 @@ These models primarily support the [`LLM.embed`](./pooling_models.md#llmembed) A

!!! note
    `ssmits/Qwen2-7B-Instruct-embed-base` has an improperly defined Sentence Transformers config.
-    You need to manually set mean pooling by passing `--override-pooler-config '{"pooling_type": "MEAN"}'`.
+    You need to manually set mean pooling by passing `--pooler-config '{"pooling_type": "MEAN"}'`.

!!! note
    For `Alibaba-NLP/gte-Qwen2-*`, you need to enable `--trust-remote-code` for the correct tokenizer to be loaded.

@@ -552,7 +565,7 @@ If your model is not in the above list, we will try to automatically convert the

!!! important
    For process-supervised reward models such as `peiyi9979/math-shepherd-mistral-7b-prm`, the pooling config should be set explicitly,
-    e.g.: `--override-pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.
+    e.g.: `--pooler-config '{"pooling_type": "STEP", "step_tag_id": 123, "returned_token_ids": [456, 789]}'`.

#### Token Classification

@@ -42,7 +42,7 @@ python client.py

### Server Configuration

-The key parameters for chunked processing are in the `--override-pooler-config`:
+The key parameters for chunked processing are in the `--pooler-config`:

```json
{
@@ -13,7 +13,7 @@ Prerequisites:

# MEAN pooling (processes all chunks, recommended for complete coverage)
vllm serve intfloat/multilingual-e5-large \
-  --override-pooler-config \
+  --pooler-config \
  '{"pooling_type": "MEAN", "normalize": true, ' \
  '"enable_chunked_processing": true, "max_embed_len": 3072000}' \
  --served-model-name multilingual-e5-large \

@@ -23,7 +23,7 @@ Prerequisites:

# OR CLS pooling (native CLS within chunks, MEAN aggregation across chunks)
vllm serve BAAI/bge-large-en-v1.5 \
-  --override-pooler-config \
+  --pooler-config \
  '{"pooling_type": "CLS", "normalize": true, ' \
  '"enable_chunked_processing": true, "max_embed_len": 1048576}' \
  --served-model-name bge-large-en-v1.5 \

@@ -103,7 +103,7 @@ POOLER_CONFIG="{\"pooling_type\": \"$POOLING_TYPE\", \"normalize\": true, \"enab
vllm serve "$MODEL_NAME" \
  --tensor-parallel-size "$GPU_COUNT" \
  --enforce-eager \
-  --override-pooler-config "$POOLER_CONFIG" \
+  --pooler-config "$POOLER_CONFIG" \
  --served-model-name ${MODEL_CODE} \
  --api-key "$API_KEY" \
  --trust-remote-code \
@@ -1,8 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

-import argparse
-import dataclasses
import json
import logging
import os

@@ -327,12 +325,7 @@ def main():


    if args.command == "serialize":
-        eng_args_dict = {f.name: getattr(args, f.name) for f in
-                         dataclasses.fields(EngineArgs)}
-
-        engine_args = EngineArgs.from_cli_args(
-            argparse.Namespace(**eng_args_dict)
-        )
+        engine_args = EngineArgs.from_cli_args(args)

        input_dir = tensorizer_dir.rstrip('/')
        suffix = args.suffix if args.suffix else uuid.uuid4().hex
@@ -24,7 +24,7 @@ outlines_core == 0.2.11
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
-xgrammar == 0.1.23; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
+xgrammar == 0.1.24; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10
filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser # used for parsing partial JSON outputs
@@ -1,5 +1,5 @@
# This file was autogenerated by uv via the following command:
-# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128
+# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128 --python-platform x86_64-manylinux_2_28
absl-py==2.1.0
    # via rouge-score
accelerate==1.0.1
@@ -76,11 +76,6 @@ def test_models(
    model_executor: str,
    enable_prompt_embeds: bool,
) -> None:
-
-    if enable_prompt_embeds and envs.is_set(
-            "VLLM_USE_V1") and envs.VLLM_USE_V1:
-        pytest.skip("enable_prompt_embeds is not supported in v1.")
-
    if not envs.VLLM_USE_V1:
        if async_scheduling:
            pytest.skip("async_scheduling only supported in v1.")

@@ -164,11 +159,6 @@ def test_models_distributed(
    extra_env: dict[str, str],
    enable_prompt_embeds: bool,
) -> None:
-
-    if enable_prompt_embeds and envs.is_set(
-            "VLLM_USE_V1") and envs.VLLM_USE_V1:
-        pytest.skip("enable_prompt_embeds is not supported in v1.")
-
    if test_suite != TARGET_TEST_SUITE:
        pytest.skip(f"Skip test for {test_suite}")
@@ -39,7 +39,8 @@ from vllm import LLM, SamplingParams
from vllm.assets.audio import AudioAsset
from vllm.assets.image import ImageAsset
from vllm.assets.video import VideoAsset
-from vllm.config import ConvertOption, RunnerOption, _get_and_verify_dtype
+from vllm.config.model import (ConvertOption, RunnerOption,
+                               _get_and_verify_dtype)
from vllm.connections import global_http_connection
from vllm.distributed import (cleanup_dist_env_and_memory,
                              init_distributed_environment,

@@ -244,39 +245,6 @@ class DecoderPromptType(Enum):
    EMPTY_STR = 3


-@pytest.fixture
-def example_encoder_decoder_prompts(
-) -> dict[DecoderPromptType, list[ExplicitEncoderDecoderPrompt]]:
-    '''
-    Returns an encoder prompt list and a decoder prompt list, wherein each pair
-    of same-index entries in both lists corresponds to an (encoder prompt,
-    decoder prompt) tuple.
-
-    Returns:
-
-    * Encoder prompt list
-    * Decoder prompt list (reverse of encoder prompt list)
-    '''
-
-    encoder_prompts = []
-    for filename in _TEST_PROMPTS:
-        encoder_prompts += _read_prompts(filename)
-
-    custom_decoder_prompts = encoder_prompts[::-1]
-    empty_str_decoder_prompts = [""] * len(encoder_prompts)
-    none_decoder_prompts = [None] * len(encoder_prompts)
-
-    # NONE decoder prompt type
-    return {
-        DecoderPromptType.NONE:
-        zip_enc_dec_prompts(encoder_prompts, none_decoder_prompts),
-        DecoderPromptType.EMPTY_STR:
-        zip_enc_dec_prompts(encoder_prompts, empty_str_decoder_prompts),
-        DecoderPromptType.CUSTOM:
-        zip_enc_dec_prompts(encoder_prompts, custom_decoder_prompts),
-    }
-
-
@pytest.fixture
def example_long_prompts() -> list[str]:
    prompts = []

@@ -690,68 +658,6 @@ class HfRunner:
        return [(output_ids, output_str, output_logprobs)
                for output_ids, output_str, output_logprobs in outputs]

-    def generate_encoder_decoder_greedy_logprobs_limit(
-        self,
-        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
-        max_tokens: int,
-        num_logprobs: Optional[int],
-        images: Optional[PromptImageInput] = None,
-        **kwargs: Any,
-    ) -> list[TokensTextLogprobs]:
-        '''
-        Greedy logprobs generation for vLLM encoder/decoder models
-        '''
-
-        all_logprobs: list[list[dict[int, float]]] = []
-        all_output_ids: list[list[int]] = []
-        all_output_strs: list[str] = []
-
-        for i, (encoder_prompt, decoder_prompt) in enumerate(
-                to_enc_dec_tuple_list(encoder_decoder_prompts)):
-            processor_kwargs: dict[str, Any] = {
-                "text": encoder_prompt,
-                "return_tensors": "pt",
-            }
-            if images is not None and images[i] is not None:
-                processor_kwargs["images"] = images[i]
-
-            encoder_inputs = self.processor(**processor_kwargs)
-            encoder_inputs = self.wrap_device(encoder_inputs)
-
-            if decoder_prompt is None:
-                decoder_input_ids = None
-            else:
-                decoder_inputs = self.tokenizer(decoder_prompt,
-                                                return_tensors="pt")
-                decoder_input_ids = self.wrap_device(decoder_inputs.input_ids)
-
-            output = self.model.generate(
-                decoder_input_ids=decoder_input_ids,
-                use_cache=True,
-                do_sample=False,
-                max_new_tokens=max_tokens,
-                output_hidden_states=True,
-                return_dict_in_generate=True,
-                **encoder_inputs,
-                **kwargs,
-            )
-
-            (
-                seq_logprobs_lst,
-                output_len,
-            ) = self._hidden_states_to_logprobs(output.decoder_hidden_states,
-                                                num_logprobs)
-
-            all_logprobs.append(seq_logprobs_lst)
-            seq_ids = output.sequences[0]
-            output_ids = seq_ids[-output_len:]
-            all_output_ids.append(output_ids.tolist())
-            all_output_strs.append(self.tokenizer.decode(output_ids))
-
-        outputs = zip(all_output_ids, all_output_strs, all_logprobs)
-        return [(output_ids, output_str, output_logprobs)
-                for output_ids, output_str, output_logprobs in outputs]
-
    def encode(self, prompts: list[str], *args,
               **kwargs) -> list[list[torch.Tensor]]:
        return self.model.encode(prompts, *args, **kwargs)

@@ -940,26 +846,6 @@ class VllmRunner:
                if sampling_params.prompt_logprobs is None else
                toks_str_logsprobs_prompt_logprobs)

-    def generate_encoder_decoder_w_logprobs(
-        self,
-        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
-        sampling_params: SamplingParams,
-    ) -> Union[list[TokensTextLogprobs],
-               list[TokensTextLogprobsPromptLogprobs]]:
-        '''
-        Logprobs generation for vLLM encoder/decoder models
-        '''
-
-        assert sampling_params.logprobs is not None
-        req_outputs = self.llm.generate(encoder_decoder_prompts,
-                                        sampling_params=sampling_params)
-        toks_str_logsprobs_prompt_logprobs = (
-            self._final_steps_generate_w_logprobs(req_outputs))
-        # Omit prompt logprobs if not required by sampling params
-        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
-                if sampling_params.prompt_logprobs is None else
-                toks_str_logsprobs_prompt_logprobs)
-
    def generate_greedy(
        self,
        prompts: Union[list[str], list[torch.Tensor]],

@@ -1037,29 +923,6 @@ class VllmRunner:

        return perplexities

-    def generate_encoder_decoder_greedy_logprobs(
-        self,
-        encoder_decoder_prompts: list[ExplicitEncoderDecoderPrompt[str, str]],
-        max_tokens: int,
-        num_logprobs: Optional[int],
-        num_prompt_logprobs: Optional[int] = None,
-        skip_special_tokens: bool = True,
-    ) -> Union[list[TokensTextLogprobs],
-               list[TokensTextLogprobsPromptLogprobs]]:
-        greedy_logprobs_params = SamplingParams(
-            temperature=0.0,
-            max_tokens=max_tokens,
-            logprobs=num_logprobs,
-            prompt_logprobs=(num_prompt_logprobs),
-            skip_special_tokens=skip_special_tokens,
-        )
-        '''
-        Greedy logprobs generation for vLLM encoder/decoder models
-        '''
-
-        return self.generate_encoder_decoder_w_logprobs(
-            encoder_decoder_prompts, greedy_logprobs_params)
-
    def generate_beam_search(
        self,
        prompts: list[str],
@@ -14,7 +14,7 @@ from typing import Literal, NamedTuple, Optional

import pytest

-from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
+from vllm.config.model import _FLOAT16_NOT_SUPPORTED_MODELS, RunnerOption
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config
@@ -36,7 +36,6 @@ def default_server_args() -> list[str]:
        "--enforce-eager",
        # Prompt Embeds server args
        "--enable-prompt-embeds",
-        "--no-enable-chunked-prefill",
    ]
@@ -287,6 +287,57 @@ async def test_stateful_multi_turn(client: OpenAI, model_name: str):
    assert response3.status == "completed"


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_streaming_types(client: OpenAI, model_name: str):
    prompts = [
        "tell me a story about a cat in 20 words",
    ]

    # this links the "done" type with the "start" type
    # so every "done" type should have a corresponding "start" type
    # and every open block should be closed by the end of the stream
    pairs_of_event_types = {
        "response.completed": "response.created",
        "response.output_item.done": "response.output_item.added",
        "response.content_part.done": "response.content_part.added",
        "response.output_text.done": "response.output_text.delta",
        "response.web_search_call.done": "response.web_search_call.added",
        "response.reasoning_text.done": "response.reasoning_text.delta",
        "response.reasoning_part.done": "response.reasoning_part.added",
    }

    for prompt in prompts:
        response = await client.responses.create(
            model=model_name,
            input=prompt,
            reasoning={"effort": "low"},
            tools=[],
            stream=True,
            background=False,
        )

        stack_of_event_types = []
        async for event in response:
            if event.type == 'response.created':
                stack_of_event_types.append(event.type)
            elif event.type == 'response.completed':
                assert stack_of_event_types[-1] == pairs_of_event_types[
                    event.type]
                stack_of_event_types.pop()
            if event.type.endswith("added"):
                stack_of_event_types.append(event.type)
            elif event.type.endswith("delta"):
                if stack_of_event_types[-1] == event.type:
                    continue
                stack_of_event_types.append(event.type)
            elif event.type.endswith("done"):
                assert stack_of_event_types[-1] == pairs_of_event_types[
                    event.type]
                stack_of_event_types.pop()
        assert len(stack_of_event_types) == 0


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
@pytest.mark.parametrize("background", [True, False])

@@ -343,7 +394,10 @@ async def test_streaming(client: OpenAI, model_name: str, background: bool):
            assert event.item_id == current_item_id

        # verify content_index_id is correct
-        if event.type == "response.content_part.added":
+        if event.type in [
+                "response.content_part.added",
+                "response.reasoning_part.added"
+        ]:
            assert event.content_index != current_content_index
            current_content_index = event.content_index
        elif event.type in [

@@ -461,6 +515,7 @@ async def test_function_calling(client: OpenAI, model_name: str):
        model=model_name,
        input="What's the weather like in Paris today?",
        tools=tools,
+        temperature=0.0,
    )
    assert response is not None
    assert response.status == "completed"

@@ -689,3 +744,18 @@ async def test_function_calling_full_history(client: OpenAI, model_name: str):
    assert response_2 is not None
    assert response_2.status == "completed"
    assert response_2.output_text is not None


@pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_output_messages_enabled(client: OpenAI, model_name: str,
                                       server):
    response = await client.responses.create(
        model=model_name,
        input="What is the capital of South Korea?",
        extra_body={"enable_response_messages": True})

    assert response is not None
    assert response.status == "completed"
    assert len(response.input_messages) > 0
    assert len(response.output_messages) > 0
@@ -216,7 +216,7 @@ def server_with_chunked_processing():
        "--enforce-eager",
        "--max-model-len",
        "512",  # Set smaller max_model_len to trigger chunking mechanism
-        '--override-pooler-config',
+        '--pooler-config',
        ('{"pooling_type": "MEAN", "normalize": true, '
         '"enable_chunked_processing": true, "max_embed_len": 10000}'),
        "--gpu-memory-utilization",
@@ -60,7 +60,7 @@ def test_api_server_process_manager_init(api_server_args, with_stats_update):
    global WORKER_RUNTIME_SECONDS
    WORKER_RUNTIME_SECONDS = 0.5

-    # Copy the args to avoid mutating the
+    # Copy the args to avoid mutating them
    args = api_server_args.copy()

    if not with_stats_update:
@@ -83,7 +83,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
-@pytest.mark.parametrize("sliding_window", [None, 256])
+@pytest.mark.parametrize("sliding_window", [None, 64, 128, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@@ -1,9 +1,12 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+from typing import NamedTuple

import pytest
import torch
+from packaging.version import Version
from transformers import AutoConfig
+from transformers import __version__ as TRANSFORMERS_VERSION

from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.platforms import current_platform

@@ -15,6 +18,7 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
                       head_size: int, max_position_embeddings: int,
                       dtype: torch.dtype, device: torch.device):
    """Generate test data for given configuration."""
+    current_platform.seed_everything(42)
    # Create 2D positions (3, num_tokens) for multimodal case
    positions = torch.randint(0,
                              max_position_embeddings // 4, (3, num_tokens),

@@ -33,22 +37,37 @@ def generate_test_data(num_tokens: int, num_q_heads: int, num_kv_heads: int,
    return positions, query, key


-def unroll_model_tp_dict(model_tp_dict):
-    return [(model_name, tp_size)
-            for model_name, tp_sizes in model_tp_dict.items()
-            for tp_size in tp_sizes]
+class MRoPETestInfo(NamedTuple):
+    model_name: str
+    # https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
+    atol: float = 1e-2
+    rtol: float = 1.6e-2
+    marks: list[pytest.MarkDecorator] = []


-model_tp_dict = {
-    "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-    "Qwen/Qwen2-VL-72B-Instruct": [1, 2],
-    "Qwen/Qwen2.5-VL-72B-Instruct": [1, 2],
-    "zai-org/GLM-4.1V-9B-Thinking": [1, 2],
-}
+TRANSFORMERS_BASE_VERSION = Version(TRANSFORMERS_VERSION).base_version

-# https://github.com/pytorch/pytorch/blob/main/torch/testing/_comparison.py#L1317
-dtype_atol_rtol_list = [
-    [torch.bfloat16, 1e-2, 1.6e-2],
+MODELS_TO_TEST = [
+    MRoPETestInfo(model_name="zai-org/GLM-4.1V-9B-Thinking"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-7B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2-VL-72B-Instruct"),
+    MRoPETestInfo(model_name="Qwen/Qwen2.5-VL-72B-Instruct"),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-4B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
+    MRoPETestInfo(
+        model_name="Qwen/Qwen3-VL-30B-A3B-Instruct",
+        marks=[
+            pytest.mark.skipif(
+                Version(TRANSFORMERS_BASE_VERSION) < Version("4.57.0"),
+                reason="Qwen3-VL only available after Transformers v4.57",
+            )
+        ]),
]

num_tokens_list = [11, 8192]

@@ -56,20 +75,29 @@ num_tokens_list = [11, 8192]

@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize("model_name, tp_size",
-                         unroll_model_tp_dict(model_tp_dict))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("num_tokens", num_tokens_list)
-def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):
+def test_mrope(model_name: str, model_info: MRoPETestInfo, tp_size: int,
+               dtype: torch.dtype, num_tokens: int):

+    atol = model_info.atol
+    rtol = model_info.rtol

    config = AutoConfig.from_pretrained(model_name)
+    config = config.get_text_config()

    # get the model config
    total_num_kv_heads = config.num_key_value_heads
    total_num_heads = config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
    is_neox_style = True

    rope_theta = config.rope_theta

@@ -111,24 +139,30 @@ def test_mrope(model_name, tp_size, dtype, atol, rtol, num_tokens):

@pytest.mark.skipif(not current_platform.is_cuda_alike(),
                    reason="Skipping CUDA/ROCm only tests.")
-@pytest.mark.parametrize(
-    "model_name, tp_size",
-    unroll_model_tp_dict({
-        "Qwen/Qwen2-VL-7B-Instruct": [1, 2],
-        "zai-org/GLM-4.1V-9B-Thinking": [1, 2]
-    }))
-@pytest.mark.parametrize("dtype, atol, rtol", dtype_atol_rtol_list)
-@pytest.mark.parametrize("num_tokens", [4])
-def test_mrope_torch_compile_tracing(model_name, tp_size, dtype, atol, rtol,
-                                     num_tokens):
+@pytest.mark.parametrize("model_info, model_name", [
+    pytest.param(test_config, test_config.model_name, marks=test_config.marks)
+    for test_config in MODELS_TO_TEST
+])
+@pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("num_tokens", num_tokens_list)
+def test_mrope_torch_compile_tracing(model_name: str,
+                                     model_info: MRoPETestInfo, tp_size: int,
+                                     dtype: torch.dtype, num_tokens: int):

+    atol = model_info.atol
+    rtol = model_info.rtol

    config = AutoConfig.from_pretrained(model_name)
+    config = config.get_text_config()

    # get the model config
    total_num_kv_heads = config.num_key_value_heads
    total_num_heads = config.num_attention_heads
    num_heads = total_num_heads // tp_size
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
-    head_dim = config.hidden_size // total_num_heads
+    head_dim = (config.head_dim if hasattr(config, "head_dim") else
+                config.hidden_size // total_num_heads)
    is_neox_style = True
    rope_theta = config.rope_theta
    max_position = config.max_position_embeddings
@@ -11,7 +11,8 @@ import pytest
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
-from vllm.config import VllmConfig, current_platform, set_current_vllm_config
+from vllm.config import VllmConfig, set_current_vllm_config
+from vllm.platforms import current_platform
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
@@ -8,11 +8,12 @@ import torch
from safetensors.torch import load_file
from torch import nn

from vllm.config import ModelConfig, VllmConfig
from vllm.config.lora import LoRAConfig
from vllm.lora.layers import (ColumnParallelLinearWithLoRA,
                              MergedColumnParallelLinearWithLoRA,
                              RowParallelLinearWithLoRA)
-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.models import (LoRAMapping, LoRAModel, LoRAModelManager,
                              LRUCacheLoRAModelManager)
from vllm.lora.peft_helper import PEFTHelper

@@ -435,10 +436,19 @@ def test_lru_cache_worker_adapter_manager(dist_init, dummy_model, device,
        target_modules=["layer1.dense1", "dense2"],
        lora_dtype=DEFAULT_DTYPE,
    )

+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
    worker_adapter_manager = LRUCacheWorkerLoRAManager(
-        4, 2,
-        dummy_model.unpadded_vocab_size - lora_config.lora_extra_vocab_size,
-        lora_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+        vllm_config, device, EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)

    worker_adapter_manager.max_num_seqs = 4
    worker_adapter_manager.max_num_batched_tokens = 2

    worker_adapter_manager.create_lora_manager(dummy_model)

    mapping = LoRAMapping([], [])

@@ -517,10 +527,20 @@ def test_worker_adapter_manager(dist_init, dummy_model_gate_up, device,
        max_cpu_loras=4,
        max_loras=4,
        lora_dtype=DEFAULT_DTYPE)
-    worker_adapter_manager = WorkerLoRAManager(
-        4, 2, dummy_model_gate_up.unpadded_vocab_size -
-        lora_config.lora_extra_vocab_size, lora_config, device,
-        EMBEDDING_MODULES, EMBEDDING_PADDING_MODULES)
+
+    model_config = ModelConfig(max_model_len=16)
+    vllm_config = VllmConfig(model_config=model_config,
+                             lora_config=lora_config)
+
+    vllm_config.scheduler_config.max_num_seqs = 4
+    vllm_config.scheduler_config.max_num_batched_tokens = 2
+
+    worker_adapter_manager = WorkerLoRAManager(vllm_config, device,
+                                               EMBEDDING_MODULES,
+                                               EMBEDDING_PADDING_MODULES)
+    worker_adapter_manager.vocab_size = (
+        dummy_model_gate_up.unpadded_vocab_size -
+        lora_config.lora_extra_vocab_size)
    worker_adapter_manager.create_lora_manager(dummy_model_gate_up)

    dummy_lora_files = f"{tmp_path}/lora_adapter"
@@ -9,7 +9,7 @@ from typing import Optional, Union
import torch
from safetensors.torch import save_file

-from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
+from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights


class DummyLoRAManager:
@@ -125,12 +125,6 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
        # in parts of the operators
        pytest.skip(f"Skipping '{model}' model test with AITER kernel.")

-    # Note: can be removed when
-    # https://github.com/vllm-project/vllm/pull/24278 finished
-    if current_platform.is_cpu() and use_prompt_embeds:
-        pytest.skip("Skipping use_prompt_embeds=True with "
-                    "V1-only CPU backend.")
-
    with hf_runner(model) as hf_model:
        hf_outputs = hf_model.generate_greedy_logprobs_limit(
            example_prompts, max_tokens, num_logprobs)
@@ -58,7 +58,7 @@ def test_models(

    vllm_extra_kwargs = {}
    if model == "ssmits/Qwen2-7B-Instruct-embed-base":
-        vllm_extra_kwargs["override_pooler_config"] = \
+        vllm_extra_kwargs["pooler_config"] = \
            PoolerConfig(pooling_type="MEAN", normalize=False)

    max_model_len: Optional[int] = 512
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

+from vllm.config.pooler import PoolerConfig
from vllm.platforms import current_platform


@@ -99,7 +100,7 @@ def test_gemma_multimodal(
        convert="classify",
        load_format="auto",
        hf_overrides=update_config,
-        override_pooler_config={"pooling_type": "LAST"},
+        pooler_config=PoolerConfig(pooling_type="LAST"),
        max_model_len=512,
        enforce_eager=True,
        tensor_parallel_size=1,
@@ -24,18 +24,18 @@ def test_classify_models_using_activation(
    dtype: str,
) -> None:

-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     override_pooler_config=PoolerConfig(
-                         activation=False)) as vllm_model:
+    with vllm_runner(
+            model,
+            max_model_len=512,
+            dtype=dtype,
+            pooler_config=PoolerConfig(activation=False)) as vllm_model:
        wo_activation_out = vllm_model.classify(example_prompts)

-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     override_pooler_config=PoolerConfig(
-                         activation=True)) as vllm_model:
+    with vllm_runner(
+            model,
+            max_model_len=512,
+            dtype=dtype,
+            pooler_config=PoolerConfig(activation=True)) as vllm_model:
        w_activation_out = vllm_model.classify(example_prompts)

    for wo_activation, w_activation in zip(wo_activation_out,

@@ -43,9 +43,8 @@ def test_classify_models_using_activation(
        wo_activation = torch.tensor(wo_activation)
        w_activation = torch.tensor(w_activation)

-        assert not torch.allclose(
-            wo_activation, w_activation,
-            atol=1e-2), "override_pooler_config is not working"
+        assert not torch.allclose(wo_activation, w_activation,
+                                  atol=1e-2), "pooler_config is not working"
        assert torch.allclose(softmax(wo_activation), w_activation,
                              1e-3 if dtype == "float" else 1e-2)

@@ -65,23 +64,22 @@ def test_embed_models_using_normalize(
    dtype: str,
) -> None:

-    with vllm_runner(model,
-                     max_model_len=512,
-                     dtype=dtype,
-                     override_pooler_config=PoolerConfig(
-                         normalize=False)) as vllm_model:
-        wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
-
    with vllm_runner(
            model,
            max_model_len=512,
            dtype=dtype,
-            override_pooler_config=PoolerConfig(normalize=True)) as vllm_model:
+            pooler_config=PoolerConfig(normalize=False)) as vllm_model:
+        wo_normalize = torch.tensor(vllm_model.embed(example_prompts))
+
+    with vllm_runner(model,
+                     max_model_len=512,
+                     dtype=dtype,
+                     pooler_config=PoolerConfig(normalize=True)) as vllm_model:
        w_normalize = torch.tensor(vllm_model.embed(example_prompts))

    assert not torch.allclose(
        wo_normalize, w_normalize,
-        atol=1e-2), "override_pooler_config normalize is not working"
+        atol=1e-2), "pooler_config normalize is not working"
    assert torch.allclose(
        F.normalize(wo_normalize, p=2, dim=-1), w_normalize,
        atol=1e-2), "w_normal should be close to normal(wo_normal)."

@@ -102,18 +100,16 @@ def test_reward_models_using_softmax(
    dtype: str,
) -> None:

-    with vllm_runner(
-        model,
-        max_model_len=1024,
-        dtype=dtype,
-        override_pooler_config=PoolerConfig(softmax=False)) as vllm_model:
+    with vllm_runner(model,
+                     max_model_len=1024,
+                     dtype=dtype,
+                     pooler_config=PoolerConfig(softmax=False)) as vllm_model:
        wo_softmax = vllm_model.encode(example_prompts)

-    with vllm_runner(
-        model,
-        max_model_len=1024,
-        dtype=dtype,
-        override_pooler_config=PoolerConfig(softmax=True)) as vllm_model:
+    with vllm_runner(model,
+                     max_model_len=1024,
+                     dtype=dtype,
+                     pooler_config=PoolerConfig(softmax=True)) as vllm_model:
        w_softmax = vllm_model.encode(example_prompts)

    for wo, w in zip(wo_softmax, w_softmax):

@@ -121,7 +117,7 @@ def test_reward_models_using_softmax(
        w = torch.tensor(w)

        assert not torch.allclose(
-            wo, w, atol=1e-2), "override_pooler_config softmax is not working"
+            wo, w, atol=1e-2), "pooler_config softmax is not working"
        assert torch.allclose(
            softmax(wo), w,
            atol=1e-2), "w_softmax should be close to softmax(wo_softmax)."
@@ -7,7 +7,6 @@ from unittest.mock import patch
import pytest

from vllm import LLM
-from vllm.config import ModelImpl
from vllm.engine.llm_engine import LLMEngine as V0LLMEngine
from vllm.utils import GiB_bytes
from vllm.v1.core.kv_cache_utils import get_kv_cache_configs

@@ -111,8 +110,8 @@ def can_initialize(model_arch: str, monkeypatch: pytest.MonkeyPatch,
        # these tests seem to produce leftover memory
        gpu_memory_utilization=0.80,
        load_format="dummy",
-        model_impl=ModelImpl.TRANSFORMERS
-        if model_arch in _TRANSFORMERS_BACKEND_MODELS else ModelImpl.VLLM,
+        model_impl="transformers"
+        if model_arch in _TRANSFORMERS_BACKEND_MODELS else "vllm",
        hf_overrides=hf_overrides_fn,
        max_num_seqs=model_info.max_num_seqs)
@@ -9,7 +9,7 @@ from vllm.platforms import current_platform

from ..conftest import HfRunner, VllmRunner
from ..utils import multi_gpu_test, prep_prompts
-from .utils import check_logprobs_close
+from .utils import check_embeddings_close, check_logprobs_close


def check_implementation(

@@ -165,6 +165,40 @@ def test_embed_loading(vllm_runner, model):
    assert model_config.using_transformers_backend()


@pytest.mark.parametrize(
    "model",
    [
        # Encoder model
        "BAAI/bge-base-en-v1.5",
    ])
def test_embed_correctness(hf_runner, vllm_runner, example_prompts, model):
    import transformers
    from packaging.version import Version
    installed = Version(transformers.__version__)
    required = Version("4.57.0.dev0")
    if installed < required:
        pytest.skip("Encoder models with the Transformers backend require "
                    f"transformers>={required}, but got {installed}")

    with vllm_runner(model, max_model_len=512,
                     model_impl="transformers") as vllm_model:
        model_config = vllm_model.llm.llm_engine.model_config
        assert model_config.using_transformers_backend()

        vllm_outputs = vllm_model.embed(example_prompts)

    with hf_runner(model, is_sentence_transformer=True) as hf_model:
        hf_outputs = hf_model.encode(example_prompts)

    check_embeddings_close(
        embeddings_0_lst=hf_outputs,
        embeddings_1_lst=vllm_outputs,
        name_0="hf",
        name_1="vllm",
        tol=1e-2,
    )


@pytest.mark.parametrize(
    "model",
    ["jason9693/Qwen2.5-1.5B-apeach"],
@@ -207,25 +207,19 @@ def test_get_pooling_config():
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
    model_config = ModelConfig(model_id)

-    pooling_config = model_config._init_pooler_config()
-    assert pooling_config is not None
-
-    assert pooling_config.normalize
-    assert pooling_config.pooling_type == PoolingType.MEAN.name
+    assert model_config.pooler_config is not None
+    assert model_config.pooler_config.normalize
+    assert model_config.pooler_config.pooling_type == PoolingType.MEAN.name


@pytest.mark.skipif(current_platform.is_rocm(),
                    reason="Xformers backend is not supported on ROCm.")
def test_get_pooling_config_from_args():
    model_id = "sentence-transformers/all-MiniLM-L12-v2"
-    model_config = ModelConfig(model_id)
+    pooler_config = PoolerConfig(pooling_type="CLS", normalize=True)
+    model_config = ModelConfig(model_id, pooler_config=pooler_config)

-    override_pooler_config = PoolerConfig(pooling_type='CLS', normalize=True)
-    model_config.override_pooler_config = override_pooler_config
-
-    pooling_config = model_config._init_pooler_config()
-    assert pooling_config is not None
-    assert asdict(pooling_config) == asdict(override_pooler_config)
+    assert asdict(model_config.pooler_config) == asdict(pooler_config)


@pytest.mark.parametrize(
@@ -513,27 +513,27 @@ def test_hash_request_tokens_no_mm_inputs(hash_fn):
    assert block_hashes[1] == hash_fn((block_hashes[0], (3, 4, 5), None))


+def _stats(requests: int, queries: int, hits: int) -> PrefixCacheStats:
+    return PrefixCacheStats(requests=requests, queries=queries, hits=hits)
+
+
def test_metrics():
    """
    Test the prefix caching metrics.
    """

-    def stats(requests, queries, hits):
-        return PrefixCacheStats(requests=requests, queries=queries, hits=hits)
-
    metrics = PrefixCachingMetrics(max_recent_requests=5)
    assert metrics.hit_rate == 0.0

-    metrics.observe(stats(1, 20, 9))
+    metrics.observe(_stats(1, 20, 9))
    # 9 / 20 = 0.45
    assert metrics.hit_rate == 0.45

-    metrics.observe(stats(4, 80, 16))
+    metrics.observe(_stats(4, 80, 16))

    # 25 / 100 = 0.25
    assert metrics.hit_rate == 0.25

-    metrics.observe(stats(1, 10, 2))
+    metrics.observe(_stats(1, 10, 2))

    # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2
    assert metrics.aggregated_requests == 5

@@ -549,6 +549,38 @@ def test_metrics():
    assert not metrics.query_queue


def test_metrics_empty_stats():
    """
    Test the prefix caching metrics with empty stats.
    """
    metrics = PrefixCachingMetrics(max_recent_requests=5)
    metrics.observe(_stats(0, 0, 0))
    metrics.observe(_stats(1, 20, 9))
    metrics.observe(_stats(0, 0, 0))
    metrics.observe(_stats(4, 80, 16))
    metrics.observe(_stats(0, 0, 0))
    metrics.observe(_stats(1, 10, 2))
    # Remove (20, 9) and add (10, 2): 18 / 90 = 0.2
    assert metrics.aggregated_requests == 5
    assert metrics.aggregated_query_total == 90
    assert metrics.aggregated_query_hit == 18
    assert metrics.hit_rate == 0.2

    # Only the latest added stats preserved 10 / 20 = 0.5
    metrics.observe(_stats(11, 20, 10))
    assert metrics.aggregated_requests == 11
    assert metrics.aggregated_query_total == 20
    assert metrics.aggregated_query_hit == 10
    assert metrics.hit_rate == 0.5

    # Only the latest added stats preserved 30 / 40 = 0.75
    metrics.observe(_stats(22, 40, 30))
    assert metrics.aggregated_requests == 22
    assert metrics.aggregated_query_total == 40
    assert metrics.aggregated_query_hit == 30
    assert metrics.hit_rate == 0.75


def test_get_kv_cache_configs_multiple_workers():
    model_config = ModelConfig(max_model_len=16)
    vllm_config = VllmConfig(model_config=model_config)
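As an aside (not part of the diff), the sliding-window behaviour these metrics tests exercise can be sketched as follows. This is a simplified stand-in for `PrefixCachingMetrics`, not its actual implementation, written only to make the arithmetic in the test comments concrete:

```python
# Hedged sketch: keep recent (requests, queries, hits) observations until the
# window exceeds `max_recent_requests` requests, then report hits / queries.
from collections import deque


class SlidingHitRate:
    def __init__(self, max_recent_requests: int = 5):
        self.max_recent_requests = max_recent_requests
        self.window = deque()  # (requests, queries, hits) per observation
        self.requests = self.queries = self.hits = 0

    def observe(self, requests: int, queries: int, hits: int) -> None:
        if requests == 0:  # empty stats contribute nothing
            return
        self.window.append((requests, queries, hits))
        self.requests += requests
        self.queries += queries
        self.hits += hits
        # Evict the oldest observations once the request budget is exceeded,
        # always keeping at least the most recent observation.
        while self.requests > self.max_recent_requests and len(self.window) > 1:
            r, q, h = self.window.popleft()
            self.requests -= r
            self.queries -= q
            self.hits -= h

    @property
    def hit_rate(self) -> float:
        return self.hits / self.queries if self.queries else 0.0


m = SlidingHitRate(max_recent_requests=5)
m.observe(1, 20, 9)
m.observe(4, 80, 16)
m.observe(1, 10, 2)  # evicts (1, 20, 9): 18 / 90 = 0.2, as in the test above
assert (m.requests, m.queries, m.hits) == (5, 90, 18)
```

Usage mirrors the assertions in `test_metrics` and `test_metrics_empty_stats`, including the case where a single large observation displaces everything older.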
@@ -18,12 +18,18 @@ import torch

from vllm import LLM
from vllm.config import KVTransferConfig
+from vllm.distributed.kv_transfer.kv_connector.utils import KVOutputAggregator
+from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
+    KVConnectorStats)
+from vllm.distributed.kv_transfer.kv_connector.v1.multi_connector import (
+    MultiKVConnectorStats)
from vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector import (
    KVConnectorRole, NixlAgentMetadata, NixlConnector, NixlConnectorMetadata,
-    NixlConnectorWorker)
+    NixlConnectorWorker, NixlKVConnectorStats)
from vllm.forward_context import ForwardContext
from vllm.sampling_params import SamplingParams
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
from vllm.v1.outputs import KVConnectorOutput, ModelRunnerOutput

from .utils import create_request, create_scheduler, create_vllm_config

@@ -475,6 +481,209 @@ class TestNixlHandshake:
# NOTE: resource cleanup in mp backend is a bit finicky, so the order in which
# we put here is important. First run ray, it will clean up the resources, then
# the rest of the tests.
@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper)
def test_kv_connector_stats(dist_init):
    """Test that KV transfer stats are properly recorded and retrieved."""
    vllm_config = create_vllm_config()

    # Test worker role in decode server.
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    connector.connector_worker = FakeNixlConnectorWorker(vllm_config,
                                                         connector.engine_id,
                                                         hand_shake_latency=0)

    # Verify that xfer_stats starts empty
    initial_stats = connector.get_kv_connector_stats()
    assert initial_stats is None

    # Create transfer metadata
    request_id = "test_req_for_stats"
    metadata = NixlConnectorMetadata()
    metadata.add_new_req(request_id=request_id,
                         local_block_ids=[1, 2, 3],
                         kv_transfer_params={
                             "remote_block_ids": [4, 5, 6],
                             "remote_engine_id":
                             FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
                             "remote_host": "localhost",
                             "remote_port": 1234,
                             "remote_tp_size": 1,
                         })
    connector.bind_connector_metadata(metadata)

    # Start the transfer
    dummy_ctx = ForwardContext(
        no_compile_layers={},
        attn_metadata={},
        virtual_engine=0,
    )
    connector.start_load_kv(dummy_ctx)

    # Verify stats are recorded after transfer is complete
    max_iterations = 2
    # Clear metadata before start_load_kv to prevent reprocessing same request
    connector.bind_connector_metadata(NixlConnectorMetadata())
    for _ in range(max_iterations):
        # Need to call start_load_kv to process completed handshakes
        connector.start_load_kv(dummy_ctx)
        _, done_recving = connector.get_finished(finished_req_ids=set())
        if len(done_recving) > 0 and request_id in done_recving:
            break
        time.sleep(
            0.1)  # Small delay to allow background handshake to complete
    else:
        assert "Transfer did not complete within expected iterations"

    # Now check that stats were recorded
    stats_after_transfer = connector.get_kv_connector_stats()
    assert isinstance(stats_after_transfer, NixlKVConnectorStats)

    # Verify stats values are recorded
    assert not stats_after_transfer.is_empty()
    assert stats_after_transfer.data["num_successful_transfers"] == 1

    # Verify stats are reset after retrieval
    stats_after_reset = connector.get_kv_connector_stats()
    assert stats_after_reset is None


def test_kv_connector_stats_aggregation():
    """
    Test KV transfer stats aggregation across TP ranks using
    KVOutputAggregator (used by MultiprocExecutor).
    """

    # Create KVOutputAggregator for 3 workers (simulating TP=3), same thing
    # done in MultiprocExecutor.execute_model
    aggregator = KVOutputAggregator(world_size=3)

    # Create stats for multiple workers with different transfer patterns
    worker1_stats = NixlKVConnectorStats()
    worker2_stats = NixlKVConnectorStats()
    worker3_stats = NixlKVConnectorStats()

    # Record different transfers on each worker
    # Worker 1: 2 transfers
    worker1_stats.record_transfer()
    worker1_stats.record_transfer()

    # Worker 2: 1 transfer
    worker2_stats.record_transfer()

    # Worker 3: 3 transfers
    worker3_stats.record_transfer()
    worker3_stats.record_transfer()
    worker3_stats.record_transfer()

    # Create ModelRunnerOutput instances for each worker
    worker_outputs = []
    for i, worker_stats in enumerate(
            [worker1_stats, worker2_stats, worker3_stats]):
        output = ModelRunnerOutput(
            req_ids=[f"req_{i}"],
            req_id_to_index={f"req_{i}": 0},
            sampled_token_ids=[[123]],  # dummy token
            logprobs=None,
            prompt_logprobs_dict={},
            pooler_output=[None],
            kv_connector_output=KVConnectorOutput(
|
||||
finished_sending=set([f"req_{i}_send"])
|
||||
if i < 2 else None, # Workers 0,1 finished sending
|
||||
finished_recving=set([f"req_{i}_recv"])
|
||||
if i > 0 else None, # Workers 1,2 finished receiving
|
||||
kv_connector_stats=worker_stats,
|
||||
))
|
||||
worker_outputs.append(output)
|
||||
|
||||
# Use the real aggregation mechanism (like MultiprocExecutor.execute_model)
|
||||
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
|
||||
kv_connector_stats = \
|
||||
aggregated_output.kv_connector_output.kv_connector_stats
|
||||
assert isinstance(kv_connector_stats, NixlKVConnectorStats)
|
||||
# Number of total transfers across all workers.
|
||||
assert kv_connector_stats.data["num_successful_transfers"] == 6
|
||||
|
||||
|
||||
def test_multi_kv_connector_stats_aggregation():
|
||||
"""
|
||||
Test MultiKVConnectorStats aggregation across TP ranks using
|
||||
KVOutputAggregator (used by MultiprocExecutor).
|
||||
"""
|
||||
|
||||
aggregator = KVOutputAggregator(world_size=3)
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
@dataclass
|
||||
class FooKVConnectorStats(KVConnectorStats):
|
||||
|
||||
def reset(self):
|
||||
self.data = {"num_foo_transfers": 0}
|
||||
|
||||
def record_transfer(self):
|
||||
if "num_foo_transfers" not in self.data:
|
||||
self.data["num_foo_transfers"] = 0
|
||||
self.data["num_foo_transfers"] += 1
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self.data["num_foo_transfers"] == 0
|
||||
|
||||
def aggregate(self,
|
||||
other: "FooKVConnectorStats") -> "FooKVConnectorStats":
|
||||
if not other.is_empty():
|
||||
self.data["num_foo_transfers"] += other.data[
|
||||
"num_foo_transfers"]
|
||||
return self
|
||||
|
||||
def make_multi_stats(nixl_count: int,
|
||||
foo_count: int) -> MultiKVConnectorStats:
|
||||
data: dict[str, KVConnectorStats] = {}
|
||||
if nixl_count > 0:
|
||||
nixl_stats = NixlKVConnectorStats()
|
||||
for _ in range(nixl_count):
|
||||
nixl_stats.record_transfer()
|
||||
data["NixlConnector"] = nixl_stats
|
||||
if foo_count > 0:
|
||||
foo_stats = FooKVConnectorStats()
|
||||
for _ in range(foo_count):
|
||||
foo_stats.record_transfer()
|
||||
data["FooConnector"] = foo_stats
|
||||
return MultiKVConnectorStats(data=data)
|
||||
|
||||
# Create heterogeneous stats across 3 workers
|
||||
worker_patterns = [(2, 1), (3, 0), (0, 5)] # (Nixl, Foo)
|
||||
|
||||
worker_outputs: list[ModelRunnerOutput] = []
|
||||
for i, (nixl, foo) in enumerate(worker_patterns):
|
||||
stats = make_multi_stats(nixl, foo)
|
||||
output = ModelRunnerOutput(
|
||||
req_ids=[f"req_{i}"],
|
||||
req_id_to_index={f"req_{i}": 0},
|
||||
sampled_token_ids=[[123]],
|
||||
logprobs=None,
|
||||
prompt_logprobs_dict={},
|
||||
pooler_output=[None],
|
||||
kv_connector_output=KVConnectorOutput(
|
||||
finished_sending=set([f"req_{i}_send"]) if i < 2 else None,
|
||||
finished_recving=set([f"req_{i}_recv"]) if i > 0 else None,
|
||||
kv_connector_stats=stats,
|
||||
),
|
||||
)
|
||||
worker_outputs.append(output)
|
||||
|
||||
aggregated_output = aggregator.aggregate(worker_outputs, output_rank=0)
|
||||
kv_connector_stats = \
|
||||
aggregated_output.kv_connector_output.kv_connector_stats
|
||||
assert isinstance(kv_connector_stats, MultiKVConnectorStats)
|
||||
|
||||
# Validate per-connector totals across workers
|
||||
assert kv_connector_stats["NixlConnector"].data[
|
||||
"num_successful_transfers"] == 5
|
||||
assert kv_connector_stats["FooConnector"].data["num_foo_transfers"] == 6
|
||||
|
||||
|
||||
@pytest.mark.parametrize("distributed_executor_backend", ["ray", None])
|
||||
@patch(
|
||||
"vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
|
||||
|
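The KV-connector stats tests in this file reduce per-rank transfer counters across TP workers through KVOutputAggregator (2 + 1 + 3 transfers aggregate to 6). A minimal hedged sketch of that counter-plus-aggregate() pattern, with illustrative names rather than vLLM's exact classes:

# Hedged sketch of the per-rank counter + aggregate() pattern exercised above;
# field names are illustrative, not vLLM's exact API.
from dataclasses import dataclass, field


@dataclass
class TransferStats:
    data: dict[str, int] = field(
        default_factory=lambda: {"num_successful_transfers": 0})

    def record_transfer(self) -> None:
        self.data["num_successful_transfers"] += 1

    def is_empty(self) -> bool:
        return self.data["num_successful_transfers"] == 0

    def aggregate(self, other: "TransferStats") -> "TransferStats":
        if not other.is_empty():
            self.data["num_successful_transfers"] += \
                other.data["num_successful_transfers"]
        return self


# Reducing three TP ranks' stats the way an output aggregator would:
ranks = [TransferStats(), TransferStats(), TransferStats()]
for _ in range(2):
    ranks[0].record_transfer()
ranks[1].record_transfer()
for _ in range(3):
    ranks[2].record_transfer()

total = ranks[0]
for other in ranks[1:]:
    total = total.aggregate(other)
assert total.data["num_successful_transfers"] == 6  # matches the test above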
505
tests/v1/kv_connector/unit/test_offloading_connector.py
Normal file
@ -0,0 +1,505 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import copy
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm import SamplingParams
|
||||
from vllm.config import KVTransferConfig, VllmConfig
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import KVConnectorRole
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector import (
|
||||
OffloadingConnector, OffloadingConnectorMetadata)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.utils import sha256
|
||||
from vllm.v1.core.kv_cache_utils import (BlockHash, get_request_block_hasher,
|
||||
init_none_hash)
|
||||
from vllm.v1.core.sched.scheduler import Scheduler
|
||||
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
|
||||
OffloadingManager, PrepareStoreOutput)
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
|
||||
TransferResult, TransferSpec)
|
||||
from vllm.v1.outputs import EMPTY_MODEL_RUNNER_OUTPUT, KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
from .utils import (EOS_TOKEN_ID, create_model_runner_output, create_scheduler,
|
||||
create_vllm_config)
|
||||
|
||||
|
||||
class MockLoadStoreSpec(LoadStoreSpec):
|
||||
|
||||
def __init__(self, block_hashes: Iterable[BlockHash]):
|
||||
self.block_hashes: list[BlockHash] = list(block_hashes)
|
||||
|
||||
@staticmethod
|
||||
def medium() -> str:
|
||||
return "Mock"
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return repr(self.block_hashes)
|
||||
|
||||
|
||||
class MockOffloadingHandler(OffloadingHandler):
|
||||
|
||||
def __init__(self):
|
||||
self.completed_transfers: list[TransferResult] = []
|
||||
self.completed_specs: list[TransferSpec] = []
|
||||
|
||||
def get_finished(self) -> list[TransferResult]:
|
||||
finished = self.completed_transfers
|
||||
self.completed_transfers = []
|
||||
return finished
|
||||
|
||||
def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
|
||||
self.completed_specs.append(spec)
|
||||
self.completed_transfers.append((job_id, True))
|
||||
return True
|
||||
|
||||
|
||||
class MockOffloadingSpec(OffloadingSpec):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig):
|
||||
super().__init__(vllm_config)
|
||||
|
||||
self.manager = MagicMock(spec=OffloadingManager)
|
||||
self.manager.lookup.return_value = 0
|
||||
self.manager.prepare_load = lambda block_hashes: (MockLoadStoreSpec(
|
||||
block_hashes))
|
||||
self.handler = MockOffloadingHandler()
|
||||
|
||||
def get_manager(self) -> OffloadingManager:
|
||||
return self.manager
|
||||
|
||||
def get_handlers(
|
||||
self, _
|
||||
) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
|
||||
OffloadingHandler]]:
|
||||
|
||||
yield GPULoadStoreSpec, MockLoadStoreSpec, self.handler
|
||||
yield MockLoadStoreSpec, GPULoadStoreSpec, self.handler
|
||||
|
||||
def get_completed_transfers(self) -> list[TransferSpec]:
|
||||
specs = self.handler.completed_specs
|
||||
self.handler.completed_specs = []
|
||||
return specs
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferSummary:
|
||||
gpu_block_indices: list[int]
|
||||
offload_addresses: list[Any]
|
||||
|
||||
|
||||
class RequestRunner:
|
||||
|
||||
def __init__(self, offloaded_block_size: int, gpu_block_size: int,
|
||||
num_gpu_blocks: int):
|
||||
self.offloaded_block_size: int = offloaded_block_size
|
||||
self.gpu_block_size: int = gpu_block_size
|
||||
self.num_gpu_blocks: int = num_gpu_blocks
|
||||
|
||||
self.req_id: int = -1
|
||||
|
||||
vllm_config = create_vllm_config(block_size=gpu_block_size,
|
||||
max_num_batched_tokens=1000)
|
||||
vllm_config.kv_transfer_config = KVTransferConfig(
|
||||
kv_connector="OffloadingConnector",
|
||||
kv_role="kv_both",
|
||||
kv_connector_extra_config={
|
||||
"spec_name": "MockOffloadingSpec",
|
||||
"spec_module_path":
|
||||
"tests.v1.kv_connector.unit.test_offloading_connector",
|
||||
"block_size": offloaded_block_size,
|
||||
})
|
||||
|
||||
self.scheduler: Scheduler = create_scheduler(vllm_config,
|
||||
num_blocks=num_gpu_blocks)
|
||||
self.worker_connector = OffloadingConnector(vllm_config,
|
||||
KVConnectorRole.WORKER)
|
||||
|
||||
# register worker kv_caches to enable OffloadingWorker creations
|
||||
self.worker_connector.register_kv_caches(
|
||||
kv_caches={"a": torch.empty(0)})
|
||||
|
||||
# extract connector of scheduler
|
||||
scheduler_connector = self.scheduler.connector
|
||||
assert scheduler_connector is not None
|
||||
assert isinstance(scheduler_connector, OffloadingConnector)
|
||||
self.scheduler_connector: OffloadingConnector = scheduler_connector
|
||||
|
||||
# extract mocked OffloadingManager of scheduler connector
|
||||
connector_scheduler = scheduler_connector.connector_scheduler
|
||||
assert connector_scheduler is not None
|
||||
manager = connector_scheduler.manager
|
||||
assert isinstance(manager, MagicMock)
|
||||
self.manager: MagicMock = manager
|
||||
|
||||
assert connector_scheduler.gpu_block_size == gpu_block_size
|
||||
assert connector_scheduler.offloaded_block_size == offloaded_block_size
|
||||
|
||||
# extract OffloadingSpec of worker_connector
|
||||
connector_worker = self.worker_connector.connector_worker
|
||||
assert connector_worker is not None
|
||||
offloading_spec = connector_worker.spec
|
||||
assert isinstance(offloading_spec, MockOffloadingSpec)
|
||||
self.offloading_spec: MockOffloadingSpec = offloading_spec
|
||||
|
||||
# mapping (offloading address) -> gpu_block_index
|
||||
self.offloaded: dict[Any, int] = {}
|
||||
|
||||
self.pending_loads_count: int = 0
|
||||
self.pending_stores_count: int = 0
|
||||
|
||||
self.completed_loads: list[TransferSummary] = []
|
||||
self.completed_stores: list[TransferSummary] = []
|
||||
|
||||
# maps {block_id: block_offset}
|
||||
self.gpu_block_index: dict[int, int] = {}
|
||||
|
||||
init_none_hash(sha256)
|
||||
self._block_hasher = get_request_block_hasher(gpu_block_size, sha256)
|
||||
|
||||
self._dummy_ctx: ForwardContext = ForwardContext(no_compile_layers={},
|
||||
attn_metadata={},
|
||||
virtual_engine=0)
|
||||
|
||||
def new_request(self, token_ids: list[int]):
|
||||
assert not self.scheduler.requests
|
||||
self.req_id += 1
|
||||
|
||||
req = Request(
|
||||
request_id=str(self.req_id),
|
||||
prompt_token_ids=token_ids,
|
||||
sampling_params=SamplingParams(max_tokens=1000),
|
||||
pooling_params=None,
|
||||
eos_token_id=EOS_TOKEN_ID,
|
||||
block_hasher=self._block_hasher,
|
||||
)
|
||||
|
||||
self.scheduler.add_request(req)
|
||||
|
||||
def _wait_for_transfers(self):
|
||||
block_size_factor = self.offloaded_block_size // self.gpu_block_size
|
||||
|
||||
while self.pending_loads_count or self.pending_stores_count:
|
||||
for transfer_spec in (
|
||||
self.offloading_spec.get_completed_transfers()):
|
||||
src_spec, dst_spec = transfer_spec
|
||||
|
||||
if isinstance(src_spec, GPULoadStoreSpec):
|
||||
store = True
|
||||
gpu_spec = src_spec
|
||||
offload_spec = dst_spec
|
||||
else:
|
||||
store = False
|
||||
gpu_spec = dst_spec
|
||||
offload_spec = src_spec
|
||||
|
||||
assert isinstance(offload_spec, MockLoadStoreSpec)
|
||||
assert isinstance(gpu_spec, GPULoadStoreSpec)
|
||||
|
||||
gpu_block_indices: list[int] = []
|
||||
for block_id in gpu_spec.block_ids:
|
||||
gpu_block_indices.append(
|
||||
self.gpu_block_index[block_id.item()])
|
||||
|
||||
# list of (block_hash, sub_block_offset)
|
||||
offload_addresses: list[Any] = []
|
||||
for block_hash in offload_spec.block_hashes:
|
||||
for sub_block_idx in range(block_size_factor):
|
||||
offload_addresses.append((block_hash, sub_block_idx))
|
||||
|
||||
if store:
|
||||
assert len(gpu_block_indices) == len(offload_addresses)
|
||||
|
||||
self.completed_stores.append(
|
||||
TransferSummary(gpu_block_indices, offload_addresses))
|
||||
self.pending_stores_count -= 1
|
||||
else:
|
||||
remainder_sub_block_count = (len(offload_addresses) -
|
||||
len(gpu_block_indices))
|
||||
assert remainder_sub_block_count >= 0
|
||||
assert remainder_sub_block_count < block_size_factor
|
||||
offload_addresses = offload_addresses[
|
||||
remainder_sub_block_count:]
|
||||
|
||||
self.completed_loads.append(
|
||||
TransferSummary(gpu_block_indices, offload_addresses))
|
||||
self.pending_loads_count -= 1
|
||||
|
||||
def _update_gpu_block_idx(self):
|
||||
for blocks in (self.scheduler.kv_cache_manager.coordinator.
|
||||
single_type_managers[0].req_to_blocks.values()):
|
||||
for block_idx, block in enumerate(blocks):
|
||||
self.gpu_block_index[block.block_id] = block_idx
|
||||
|
||||
def _run(self, decoded_tokens: list[int]):
|
||||
"""
|
||||
Runs multiple engine (scheduler + worker) steps.
|
||||
Assumes a single request is running.
|
||||
|
||||
Args:
|
||||
decoded_tokens: the tokens to yield at each step.
|
||||
"""
|
||||
|
||||
tokens_iter = iter(decoded_tokens)
|
||||
token_id = next(tokens_iter, None)
|
||||
while token_id is not None:
|
||||
assert self.scheduler.requests
|
||||
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
self._update_gpu_block_idx()
|
||||
|
||||
kv_connector_metadata = scheduler_output.kv_connector_metadata
|
||||
assert kv_connector_metadata is not None
|
||||
assert isinstance(kv_connector_metadata,
|
||||
OffloadingConnectorMetadata)
|
||||
|
||||
self.pending_loads_count += len(kv_connector_metadata.reqs_to_load)
|
||||
self.pending_stores_count += len(
|
||||
kv_connector_metadata.reqs_to_store)
|
||||
|
||||
self.worker_connector.bind_connector_metadata(
|
||||
kv_connector_metadata)
|
||||
self.worker_connector.start_load_kv(self._dummy_ctx)
|
||||
|
||||
if scheduler_output.total_num_scheduled_tokens > 0:
|
||||
self.worker_connector.wait_for_save()
|
||||
|
||||
finished_sending, finished_recving = (
|
||||
self.worker_connector.get_finished(
|
||||
scheduler_output.finished_req_ids))
|
||||
|
||||
self.worker_connector.clear_connector_metadata()
|
||||
|
||||
model_runner_output = create_model_runner_output(
|
||||
reqs=self.scheduler.running,
|
||||
finished_sending=list(finished_sending),
|
||||
finished_recving=list(finished_recving),
|
||||
token_id=token_id)
|
||||
|
||||
if self.scheduler.running:
|
||||
token_id = next(tokens_iter, None)
|
||||
|
||||
self.scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
|
||||
self._wait_for_transfers()
|
||||
|
||||
# run one more step to update finished stored
|
||||
if EOS_TOKEN_ID in decoded_tokens:
|
||||
assert not self.scheduler.running
|
||||
|
||||
while self.scheduler.requests:
|
||||
scheduler_output = self.scheduler.schedule()
|
||||
|
||||
finished_sending, finished_recving = (
|
||||
self.worker_connector.get_finished(
|
||||
scheduler_output.finished_req_ids))
|
||||
|
||||
assert not finished_recving
|
||||
|
||||
model_runner_output = copy.deepcopy(EMPTY_MODEL_RUNNER_OUTPUT)
|
||||
model_runner_output.kv_connector_output = KVConnectorOutput(
|
||||
finished_sending=finished_sending)
|
||||
|
||||
self.scheduler.update_from_output(scheduler_output,
|
||||
model_runner_output)
|
||||
|
||||
def run(
|
||||
self,
|
||||
decoded_tokens: list[int],
|
||||
expected_stored_gpu_block_indexes: tuple[int, ...] = (),
|
||||
expected_loaded_gpu_block_indexes: tuple[int, ...] = (),
|
||||
):
|
||||
"""
|
||||
Runs multiple engine (scheduler + worker) steps.
|
||||
Assumes a single request is running.
|
||||
|
||||
Args:
|
||||
decoded_tokens: the tokens to yield at each step.
|
||||
expected_stored_gpu_block_indexes: GPU block indexes
|
||||
that are expected to be written during the run.
|
||||
expected_loaded_gpu_block_indexes: GPU block indexes
|
||||
that are expected to be loaded during the run.
|
||||
"""
|
||||
|
||||
self.manager.reset_mock()
|
||||
self._run(decoded_tokens)
|
||||
|
||||
loaded_gpu_block_indexes: set[int] = set()
|
||||
for transfer in self.completed_loads:
|
||||
for gpu_block_idx, offloaded_address in zip(
|
||||
transfer.gpu_block_indices, transfer.offload_addresses):
|
||||
loaded_gpu_block_indexes.add(gpu_block_idx)
|
||||
assert gpu_block_idx == self.offloaded[offloaded_address]
|
||||
|
||||
assert (
|
||||
set(expected_loaded_gpu_block_indexes) == loaded_gpu_block_indexes)
|
||||
self.completed_loads.clear()
|
||||
|
||||
stored_gpu_block_indexes: set[int] = set()
|
||||
for transfer in self.completed_stores:
|
||||
for gpu_block_idx, offloaded_address in zip(
|
||||
transfer.gpu_block_indices, transfer.offload_addresses):
|
||||
stored_gpu_block_indexes.add(gpu_block_idx)
|
||||
self.offloaded[offloaded_address] = gpu_block_idx
|
||||
|
||||
assert (
|
||||
set(expected_stored_gpu_block_indexes) == stored_gpu_block_indexes)
|
||||
self.completed_stores.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def request_runner():
|
||||
runners = []
|
||||
|
||||
def runner_factory(offloaded_block_size, gpu_block_size, num_gpu_blocks):
|
||||
runner = RequestRunner(offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
runners.append(runner)
|
||||
return runner
|
||||
|
||||
yield runner_factory # pass factory to the test
|
||||
|
||||
|
||||
def generate_store_output(block_hashes: Iterable[BlockHash]):
|
||||
block_hashes = list(block_hashes)
|
||||
return PrepareStoreOutput(
|
||||
block_hashes_to_store=list(block_hashes),
|
||||
store_spec=MockLoadStoreSpec(block_hashes),
|
||||
block_hashes_evicted=[],
|
||||
)
|
||||
|
||||
|
||||
def test_offloading_connector(request_runner):
|
||||
offloaded_block_size = 12
|
||||
gpu_block_size = 4
|
||||
num_gpu_blocks = 100
|
||||
block_size_factor = offloaded_block_size // gpu_block_size
|
||||
|
||||
runner = request_runner(offloaded_block_size=offloaded_block_size,
|
||||
gpu_block_size=gpu_block_size,
|
||||
num_gpu_blocks=num_gpu_blocks)
|
||||
|
||||
# 3 blocks, store just the middle block (skip first and last)
|
||||
# blocks = [0, 1, 2], [3, 4, 5], [6, 7, 8]
|
||||
runner.new_request(token_ids=[0] * offloaded_block_size * 3)
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: generate_store_output(list(block_hashes)[1:2])
|
||||
runner.run(decoded_tokens=[0], expected_stored_gpu_block_indexes=(3, 4, 5))
|
||||
|
||||
# add block missing 1 token -> no offload
|
||||
runner.run(decoded_tokens=[0] * (offloaded_block_size - 1))
|
||||
runner.manager.prepare_store.assert_not_called()
|
||||
|
||||
# +1 token -> single block, fail prepare_store
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: None
|
||||
runner.run(decoded_tokens=[0])
|
||||
runner.manager.prepare_store.assert_called()
|
||||
|
||||
# 1 more block, now set block_hashes_to_store = []
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: generate_store_output([])
|
||||
runner.run(decoded_tokens=[0] * offloaded_block_size)
|
||||
|
||||
# 1 more block, now check touch was called with all 6 blocks
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: generate_store_output(block_hashes)
|
||||
runner.run(decoded_tokens=[0] * offloaded_block_size,
|
||||
expected_stored_gpu_block_indexes=(15, 16, 17))
|
||||
runner.manager.touch.assert_called()
|
||||
block_hashes1 = list(runner.manager.touch.call_args.args[0])
|
||||
assert len(block_hashes1) == 6
|
||||
|
||||
# terminate request
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID])
|
||||
|
||||
# create a new request differing only on the last token
|
||||
runner.new_request(token_ids=[0] * (offloaded_block_size * 6 - 1) + [1])
|
||||
runner.run(decoded_tokens=[0],
|
||||
expected_stored_gpu_block_indexes=tuple(
|
||||
range(6 * block_size_factor)))
|
||||
runner.manager.touch.assert_called()
|
||||
block_hashes2 = list(runner.manager.touch.call_args.args[0])
|
||||
assert len(block_hashes2) == 6
|
||||
|
||||
# verify hashes are the same, except for the last block
|
||||
assert block_hashes1[:5] == block_hashes2[:5]
|
||||
assert block_hashes1[5] != block_hashes2[5]
|
||||
|
||||
# terminate request
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID])
|
||||
|
||||
# full_block_tokens - num_computed_tokens < offloaded_block_size
|
||||
runner.new_request(token_ids=[0] * gpu_block_size + [1] *
|
||||
(offloaded_block_size - gpu_block_size))
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: generate_store_output([])
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID])
|
||||
runner.manager.lookup.assert_not_called()
|
||||
|
||||
# single block lookup with no hits
|
||||
runner.new_request(token_ids=[1] * offloaded_block_size)
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: generate_store_output([])
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID])
|
||||
runner.manager.lookup.assert_called()
|
||||
assert len(list(runner.manager.lookup.call_args.args[0])) == 1
|
||||
|
||||
# single block lookup with a hit
|
||||
runner.scheduler.reset_prefix_cache()
|
||||
runner.new_request(token_ids=[0] * offloaded_block_size)
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: generate_store_output([])
|
||||
runner.manager.lookup.return_value = 1
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_loaded_gpu_block_indexes=(0, 1, 2))
|
||||
|
||||
# single block lookup with a hit in a middle block
|
||||
runner.new_request(token_ids=[0] * offloaded_block_size * 2 +
|
||||
[1] * offloaded_block_size)
|
||||
runner.manager.prepare_store.side_effect = \
|
||||
lambda block_hashes: generate_store_output([])
|
||||
runner.manager.lookup.return_value = 1
|
||||
runner.run(decoded_tokens=[EOS_TOKEN_ID],
|
||||
expected_loaded_gpu_block_indexes=(3, 4, 5))
|
||||
|
||||
# test take_events
|
||||
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
|
||||
return [BlockHash(str(i).encode()) for i in int_hashes]
|
||||
|
||||
def take_events() -> Iterable[OffloadingEvent]:
|
||||
yield OffloadingEvent(block_hashes=to_hashes([1, 2, 3]),
|
||||
block_size=16,
|
||||
medium="A",
|
||||
removed=False)
|
||||
yield OffloadingEvent(block_hashes=to_hashes([4, 5, 6]),
|
||||
block_size=32,
|
||||
medium="B",
|
||||
removed=True)
|
||||
|
||||
runner.manager.take_events.side_effect = take_events
|
||||
events = list(runner.scheduler_connector.take_events())
|
||||
assert len(events) == 2
|
||||
event = events[0]
|
||||
assert isinstance(event, BlockStored)
|
||||
assert event.block_hashes == to_hashes([1, 2, 3])
|
||||
assert event.block_size == 16
|
||||
assert event.medium == "A"
|
||||
assert event.token_ids == []
|
||||
assert event.parent_block_hash is None
|
||||
assert event.lora_id is None
|
||||
event = events[1]
|
||||
assert isinstance(event, BlockRemoved)
|
||||
assert event.block_hashes == to_hashes([4, 5, 6])
|
||||
assert event.medium == "B"
|
@ -176,6 +176,7 @@ def create_model_runner_output(
    finished_sending: Optional[list[str]] = None,
    finished_recving: Optional[list[str]] = None,
    use_eos: bool = False,
    token_id: int = 0,
) -> ModelRunnerOutput:
    """Make dummy model runner output for testing."""

@ -184,7 +185,7 @@ def create_model_runner_output(
    req_id_to_index = {req_id: idx for idx, req_id in enumerate(req_ids)}

    # Make sampled tokens.
    sampled_token = EOS_TOKEN_ID if use_eos else 0
    sampled_token = EOS_TOKEN_ID if use_eos else token_id
    sampled_token_ids = [[sampled_token] for _ in req_ids]

    kv_connector_output = None if (
175
tests/v1/kv_offload/test_cpu.py
Normal file
@ -0,0 +1,175 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.kv_offload.abstract import (LoadStoreSpec, OffloadingEvent,
|
||||
PrepareStoreOutput)
|
||||
from vllm.v1.kv_offload.backends.cpu import CPUBackend
|
||||
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpectedPrepareStoreOutput:
|
||||
block_hashes_to_store: list[int]
|
||||
store_block_ids: list[int]
|
||||
block_hashes_evicted: list[int]
|
||||
|
||||
|
||||
def to_hashes(int_hashes: list[int]) -> list[BlockHash]:
|
||||
return [BlockHash(str(i).encode()) for i in int_hashes]
|
||||
|
||||
|
||||
def verify_store_output(
|
||||
prepare_store_output: Optional[PrepareStoreOutput],
|
||||
expected_prepare_store_output: ExpectedPrepareStoreOutput):
|
||||
assert prepare_store_output is not None
|
||||
assert (prepare_store_output.block_hashes_to_store == to_hashes(
|
||||
expected_prepare_store_output.block_hashes_to_store))
|
||||
assert (prepare_store_output.block_hashes_evicted == to_hashes(
|
||||
expected_prepare_store_output.block_hashes_evicted))
|
||||
store_spec = prepare_store_output.store_spec
|
||||
assert isinstance(store_spec, CPULoadStoreSpec)
|
||||
expected_array = np.array(expected_prepare_store_output.store_block_ids,
|
||||
dtype=np.int64)
|
||||
assert np.array_equal(expected_array, store_spec.block_ids)
|
||||
|
||||
|
||||
def verify_load_output(prepare_load_output: LoadStoreSpec,
|
||||
expected_prepare_load_output: list[int]):
|
||||
assert isinstance(prepare_load_output, CPULoadStoreSpec)
|
||||
expected_array = np.array(expected_prepare_load_output, dtype=np.int64)
|
||||
assert np.array_equal(expected_array, prepare_load_output.block_ids)
|
||||
|
||||
|
||||
def verify_events(events: Iterable[OffloadingEvent],
|
||||
block_size: int,
|
||||
expected_stores: tuple[set[int], ...] = (),
|
||||
expected_evictions: tuple[set[int], ...] = ()):
|
||||
stores: list[set[BlockHash]] = []
|
||||
evictions: list[set[BlockHash]] = []
|
||||
for event in events:
|
||||
assert event.medium == CPULoadStoreSpec.medium()
|
||||
assert event.block_size == block_size
|
||||
if event.removed:
|
||||
evictions.append(set(event.block_hashes))
|
||||
else:
|
||||
stores.append(set(event.block_hashes))
|
||||
|
||||
def to_hash_sets(
|
||||
int_sets: tuple[set[int], ...]) -> tuple[set[BlockHash], ...]:
|
||||
return tuple([set(to_hashes(list(int_set))) for int_set in int_sets])
|
||||
|
||||
assert tuple(evictions) == to_hash_sets(expected_evictions)
|
||||
assert tuple(stores) == to_hash_sets(expected_stores)
|
||||
|
||||
|
||||
def test_cpu_manager():
|
||||
"""
|
||||
Tests LRUOffloadingManager with a CPUBackend.
|
||||
"""
|
||||
# initialize a CPU backend with a capacity of 4 blocks
|
||||
block_size = 256
|
||||
cpu_backend = CPUBackend(block_size=block_size, num_blocks=4)
|
||||
cpu_manager = LRUOffloadingManager(cpu_backend, enable_events=True)
|
||||
|
||||
# prepare store [1, 2]
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([1, 2]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[1, 2],
|
||||
store_block_ids=[0, 1],
|
||||
block_hashes_evicted=[],
|
||||
))
|
||||
|
||||
# lookup [1, 2] -> not ready
|
||||
assert cpu_manager.lookup(to_hashes([1, 2])) == 0
|
||||
|
||||
# no events so far
|
||||
assert list(cpu_manager.take_events()) == []
|
||||
|
||||
# complete store [1, 2]
|
||||
cpu_manager.complete_store(to_hashes([1, 2]))
|
||||
verify_events(cpu_manager.take_events(),
|
||||
block_size=block_size,
|
||||
expected_stores=({1, 2}, ))
|
||||
|
||||
# lookup [1, 2]
|
||||
assert cpu_manager.lookup(to_hashes([1])) == 1
|
||||
assert cpu_manager.lookup(to_hashes([1, 2])) == 2
|
||||
assert cpu_manager.lookup(to_hashes([1, 2, 3])) == 2
|
||||
|
||||
# prepare store [2, 3, 4, 5] -> evicts [1]
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([2, 3, 4, 5]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[3, 4, 5],
|
||||
store_block_ids=[2, 3, 0],
|
||||
block_hashes_evicted=[1],
|
||||
))
|
||||
|
||||
# verify eviction event
|
||||
verify_events(cpu_manager.take_events(),
|
||||
block_size=block_size,
|
||||
expected_evictions=({1}, ))
|
||||
|
||||
# prepare store with no space
|
||||
assert cpu_manager.prepare_store(to_hashes([1, 6])) is None
|
||||
|
||||
# complete store [2, 3, 4, 5]
|
||||
cpu_manager.complete_store(to_hashes([2, 3, 4, 5]))
|
||||
|
||||
# prepare load [2, 3]
|
||||
prepare_load_output = cpu_manager.prepare_load(to_hashes([2, 3]))
|
||||
verify_load_output(prepare_load_output, [1, 2])
|
||||
|
||||
# prepare store with no space ([2, 3] is being loaded)
|
||||
assert cpu_manager.prepare_store(to_hashes([6, 7, 8])) is None
|
||||
|
||||
# complete load [2, 3]
|
||||
cpu_manager.complete_load(to_hashes([2, 3]))
|
||||
|
||||
# prepare store [6, 7, 8] -> evicts [2, 3, 4] (oldest)
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([6, 7, 8]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[6, 7, 8],
|
||||
store_block_ids=[3, 2, 1],
|
||||
block_hashes_evicted=[2, 3, 4],
|
||||
))
|
||||
|
||||
# complete store [6, 7, 8]
|
||||
cpu_manager.complete_store(to_hashes([6, 7, 8]))
|
||||
|
||||
# touch [5, 6, 7] (move to end of LRU order)
|
||||
cpu_manager.touch(to_hashes([5, 6, 7]))
|
||||
|
||||
# prepare store [7, 9] -> evicts [8] (oldest following previous touch)
|
||||
prepare_store_output = cpu_manager.prepare_store(to_hashes([9]))
|
||||
verify_store_output(
|
||||
prepare_store_output,
|
||||
ExpectedPrepareStoreOutput(
|
||||
block_hashes_to_store=[9],
|
||||
store_block_ids=[1],
|
||||
block_hashes_evicted=[8],
|
||||
))
|
||||
|
||||
# complete store [7, 9] with failure
|
||||
cpu_manager.complete_store(to_hashes([7, 9]), success=False)
|
||||
|
||||
# assert [7] is still stored, but [9] is not
|
||||
assert cpu_manager.lookup(to_hashes([7])) == 1
|
||||
assert cpu_manager.lookup(to_hashes([9])) == 0
|
||||
|
||||
verify_events(cpu_manager.take_events(),
|
||||
block_size=block_size,
|
||||
expected_stores=({3, 4, 5}, {6, 7, 8}),
|
||||
expected_evictions=({2, 3, 4}, {8}))
|
177
tests/v1/kv_offload/test_cpu_gpu.py
Normal file
@ -0,0 +1,177 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import random
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend
|
||||
from vllm.v1.attention.backends.flashinfer import FlashInferBackend
|
||||
from vllm.v1.attention.backends.mla.flashattn_mla import FlashAttnMLABackend
|
||||
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.worker.cpu_gpu import CpuGpuOffloadingHandler
|
||||
|
||||
NUM_GPU_BLOCKS = [64]
|
||||
NUM_CPU_BLOCKS = [256]
|
||||
GPU_BLOCK_SIZES = [16]
|
||||
GPU_BLOCKS_PER_CPU_BLOCK = [1, 3]
|
||||
HEAD_SIZES = [64]
|
||||
NUM_HEADS = [8]
|
||||
NUM_LAYERS = [4]
|
||||
DTYPES = [torch.bfloat16]
|
||||
SEEDS = [0]
|
||||
CUDA_DEVICES = ['cuda:0']
|
||||
NUM_MAPPINGS = [3]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("gpu_to_cpu", [True, False])
|
||||
@pytest.mark.parametrize("num_mappings", NUM_MAPPINGS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("gpu_block_size", GPU_BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("gpu_blocks_per_cpu_block", GPU_BLOCKS_PER_CPU_BLOCK)
|
||||
@pytest.mark.parametrize("num_gpu_blocks", NUM_GPU_BLOCKS)
|
||||
@pytest.mark.parametrize("num_cpu_blocks", NUM_CPU_BLOCKS)
|
||||
@pytest.mark.parametrize("num_layers", NUM_LAYERS)
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("seed", SEEDS)
|
||||
@pytest.mark.parametrize("device", CUDA_DEVICES)
|
||||
@torch.inference_mode()
|
||||
def test_transfer(
|
||||
gpu_to_cpu: bool,
|
||||
num_mappings: int,
|
||||
head_size: int,
|
||||
num_heads: int,
|
||||
gpu_block_size: int,
|
||||
gpu_blocks_per_cpu_block: int,
|
||||
num_gpu_blocks: int,
|
||||
num_cpu_blocks: int,
|
||||
num_layers: int,
|
||||
dtype: torch.dtype,
|
||||
seed: int,
|
||||
device: str,
|
||||
) -> None:
|
||||
current_platform.seed_everything(seed)
|
||||
|
||||
# create per-layer GPU KV caches
|
||||
attn_backends_list = [
|
||||
FlashAttentionBackend, FlashInferBackend, FlashAttnMLABackend
|
||||
]
|
||||
|
||||
gpu_caches = {}
|
||||
attn_backends = {}
|
||||
for i in range(num_layers):
|
||||
layer_name = f'layer {i}'
|
||||
|
||||
attn_backend = attn_backends_list[i % len(attn_backends_list)]
|
||||
attn_backends[layer_name] = attn_backend
|
||||
|
||||
gpu_cache_shape = attn_backend.get_kv_cache_shape(
|
||||
num_gpu_blocks, gpu_block_size, num_heads, head_size)
|
||||
gpu_caches[layer_name] = torch.rand(gpu_cache_shape,
|
||||
dtype=dtype,
|
||||
device=device)
|
||||
|
||||
# create handler
|
||||
cpu_block_size = gpu_blocks_per_cpu_block * gpu_block_size
|
||||
handler = CpuGpuOffloadingHandler(attn_backends=attn_backends,
|
||||
gpu_block_size=gpu_block_size,
|
||||
cpu_block_size=cpu_block_size,
|
||||
num_cpu_blocks=num_cpu_blocks,
|
||||
gpu_caches=gpu_caches)
|
||||
|
||||
# select block mappings
|
||||
gpu_blocks = random.sample(range(num_gpu_blocks),
|
||||
num_mappings * gpu_blocks_per_cpu_block)
|
||||
cpu_blocks = random.sample(range(num_cpu_blocks), num_mappings)
|
||||
|
||||
# convert cpu blocks to gpu block size
|
||||
cpu_blocks_in_gpu_block_size = []
|
||||
for cpu_block in cpu_blocks:
|
||||
base_block_id = cpu_block * gpu_blocks_per_cpu_block
|
||||
for i in range(gpu_blocks_per_cpu_block):
|
||||
cpu_blocks_in_gpu_block_size.append(i + base_block_id)
|
||||
|
||||
# maybe skip a GPU block to test writing to the middle of a CPU block
|
||||
if gpu_to_cpu:
|
||||
gpu_blocks = gpu_blocks[gpu_blocks_per_cpu_block - 1:]
|
||||
cpu_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size[
|
||||
gpu_blocks_per_cpu_block - 1:]
|
||||
|
||||
# set transfer direction
|
||||
if gpu_to_cpu:
|
||||
src_kv_caches = handler.gpu_tensors
|
||||
dst_kv_caches = handler.cpu_tensors
|
||||
src_spec_class = GPULoadStoreSpec
|
||||
dst_spec_class = CPULoadStoreSpec
|
||||
src_blocks = gpu_blocks
|
||||
dst_blocks = cpu_blocks
|
||||
src_blocks_in_gpu_block_size = gpu_blocks
|
||||
dst_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||
dst_size_in_gpu_blocks = num_cpu_blocks * gpu_blocks_per_cpu_block
|
||||
else:
|
||||
src_kv_caches = handler.cpu_tensors
|
||||
dst_kv_caches = handler.gpu_tensors
|
||||
src_spec_class = CPULoadStoreSpec
|
||||
dst_spec_class = GPULoadStoreSpec
|
||||
src_blocks = cpu_blocks
|
||||
dst_blocks = gpu_blocks
|
||||
src_blocks_in_gpu_block_size = cpu_blocks_in_gpu_block_size
|
||||
dst_blocks_in_gpu_block_size = gpu_blocks
|
||||
dst_size_in_gpu_blocks = num_gpu_blocks
|
||||
|
||||
# build dst -> src mapping
|
||||
dst_to_src = {}
|
||||
for src_block, dst_block in zip(src_blocks_in_gpu_block_size,
|
||||
dst_blocks_in_gpu_block_size):
|
||||
dst_to_src[dst_block] = src_block
|
||||
|
||||
# build transfer specs
|
||||
src_spec = src_spec_class(src_blocks)
|
||||
dst_spec = dst_spec_class(dst_blocks)
|
||||
|
||||
# clone src and dst tensors before transfer
|
||||
orig_src_caches = [x.clone() for x in src_kv_caches]
|
||||
orig_dst_caches = [x.clone() for x in dst_kv_caches]
|
||||
|
||||
# call transfer function
|
||||
assert handler.transfer_async(1, (src_spec, dst_spec))
|
||||
assert set(handler.transfer_events.keys()) == {1}
|
||||
|
||||
# wait for transfer to complete
|
||||
end_time = time.time() + 10
|
||||
while time.time() < end_time:
|
||||
finished = handler.get_finished()
|
||||
if finished:
|
||||
assert finished == [(1, True)]
|
||||
break
|
||||
time.sleep(0.1)
|
||||
|
||||
# verify src tensors did not change
|
||||
for orig_tensor, tensor in zip(orig_src_caches, src_kv_caches):
|
||||
assert torch.equal(orig_tensor, tensor)
|
||||
|
||||
# verify dst tensors
|
||||
for dst_block in range(dst_size_in_gpu_blocks):
|
||||
src_block_candidate = dst_to_src.get(dst_block)
|
||||
for src_cache, dst_cache, orig_dst_cache, kv_dim in zip(
|
||||
src_kv_caches, dst_kv_caches, orig_dst_caches,
|
||||
handler.kv_dim_before_num_blocks):
|
||||
if kv_dim:
|
||||
# iterate over key, value
|
||||
for i in range(2):
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[i][src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[i][dst_block]
|
||||
torch.testing.assert_close(dst_cache[i][dst_block].cpu(),
|
||||
expected_value.cpu())
|
||||
else:
|
||||
if src_block_candidate is not None:
|
||||
expected_value = src_cache[src_block_candidate]
|
||||
else:
|
||||
expected_value = orig_dst_cache[dst_block]
|
||||
torch.testing.assert_close(dst_cache[dst_block].cpu(),
|
||||
expected_value.cpu())
|
@ -3,6 +3,7 @@

import itertools
from collections.abc import Generator
from typing import get_args

import pytest
import torch
@ -464,7 +465,7 @@ def test_all_logprobs(example_prompts, monkeypatch: pytest.MonkeyPatch):
            assert len(prompt_logprob) == vocab_size


@pytest.mark.parametrize("logprobs_mode", list(LogprobsMode))
@pytest.mark.parametrize("logprobs_mode", get_args(LogprobsMode))
def test_logprobs_mode(logprobs_mode: LogprobsMode,
                       monkeypatch: pytest.MonkeyPatch):
    """Test with LLM engine with different logprobs_mode.
@ -493,14 +494,12 @@ def test_logprobs_mode(logprobs_mode: LogprobsMode,
    for logprobs in output.logprobs:
        for token_id in logprobs:
            logprob = logprobs[token_id]
            if logprobs_mode in (LogprobsMode.RAW_LOGPROBS,
                                 LogprobsMode.PROCESSED_LOGPROBS):
            if logprobs_mode in ("raw_logprobs", "processed_logprobs"):
                assert logprob.logprob <= 0
            if logprob.logprob > 0:
                positive_values = positive_values + 1
            total_token_with_logprobs = total_token_with_logprobs + 1
    assert total_token_with_logprobs >= len(results[0].outputs)
    if logprobs_mode in (LogprobsMode.RAW_LOGITS,
                         LogprobsMode.PROCESSED_LOGITS):
    if logprobs_mode in ("raw_logits", "processed_logits"):
        assert positive_values > 0
    del llm
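The parametrization change above (from `list(LogprobsMode)` to `get_args(LogprobsMode)`) suggests `LogprobsMode` is now a `typing.Literal` of strings rather than an `Enum`, which also explains the string comparisons in the hunk. A small sketch of the pattern, with an illustrative alias rather than vLLM's actual definition:

# Hedged sketch of the Literal + get_args pattern the hunk above relies on;
# the alias here is illustrative, not vLLM's actual LogprobsMode definition.
from typing import Literal, get_args

LogprobsModeLike = Literal["raw_logprobs", "processed_logprobs",
                           "raw_logits", "processed_logits"]

# get_args() yields the allowed string values, which is what
# @pytest.mark.parametrize needs in place of iterating an Enum.
print(get_args(LogprobsModeLike))
# ('raw_logprobs', 'processed_logprobs', 'raw_logits', 'processed_logits')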
@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from vllm.platforms import current_platform
|
||||
@ -70,6 +71,8 @@ class ExternalLBServerManager:
|
||||
sargs,
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE":
|
||||
"1",
|
||||
current_platform.device_control_env_var:
|
||||
",".join(
|
||||
str(
|
||||
@ -127,11 +130,19 @@ def default_server_args():
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[1, 4])
|
||||
def servers(request, default_server_args):
|
||||
def server_manager(request, default_server_args):
|
||||
api_server_count = request.param
|
||||
with ExternalLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
|
||||
default_server_args) as server_list:
|
||||
yield server_list
|
||||
server_manager = ExternalLBServerManager(MODEL_NAME, DP_SIZE,
|
||||
api_server_count,
|
||||
default_server_args)
|
||||
|
||||
with server_manager:
|
||||
yield server_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def servers(server_manager):
|
||||
return server_manager.servers
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@ -144,6 +155,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
|
||||
]
|
||||
|
||||
|
||||
def _get_parallel_config(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("server_info?config_format=json"))
|
||||
response.raise_for_status()
|
||||
|
||||
vllm_config = response.json()["vllm_config"]
|
||||
return vllm_config["parallel_config"]
|
||||
|
||||
|
||||
def test_external_lb_server_info(server_manager):
|
||||
servers = server_manager.servers
|
||||
api_server_count = server_manager.api_server_count
|
||||
|
||||
for i, (server, _) in enumerate(servers):
|
||||
print(f"Testing {i=}")
|
||||
|
||||
# Each request will hit one of the API servers
|
||||
# `n_reqs` is set so that there is a good chance each server
|
||||
# receives at least one request
|
||||
n_reqs = 2 * api_server_count * api_server_count
|
||||
parallel_configs = [
|
||||
_get_parallel_config(server) for _ in range(n_reqs)
|
||||
]
|
||||
api_process_counts = [
|
||||
c["_api_process_count"] for c in parallel_configs
|
||||
]
|
||||
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
|
||||
|
||||
assert all(c == api_server_count
|
||||
for c in api_process_counts), api_process_counts
|
||||
assert all(0 <= r < api_server_count
|
||||
for r in api_process_ranks), api_process_ranks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
|
@ -9,6 +9,7 @@ from contextlib import AsyncExitStack
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
@ -92,6 +93,8 @@ class HybridLBServerManager:
|
||||
sargs,
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE":
|
||||
"1",
|
||||
current_platform.device_control_env_var:
|
||||
",".join(
|
||||
str(
|
||||
@ -150,12 +153,20 @@ def default_server_args():
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[1, 4])
|
||||
def servers(request, default_server_args):
|
||||
def server_manager(request, default_server_args):
|
||||
api_server_count = request.param
|
||||
with HybridLBServerManager(MODEL_NAME, DP_SIZE, api_server_count,
|
||||
default_server_args, DP_SIZE_LOCAL,
|
||||
TP_SIZE) as server_list:
|
||||
yield server_list
|
||||
server_manager = HybridLBServerManager(MODEL_NAME, DP_SIZE,
|
||||
api_server_count,
|
||||
default_server_args, DP_SIZE_LOCAL,
|
||||
TP_SIZE)
|
||||
|
||||
with server_manager:
|
||||
yield server_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def servers(server_manager):
|
||||
return server_manager.servers
|
||||
|
||||
|
||||
@pytest_asyncio.fixture
|
||||
@ -168,6 +179,39 @@ async def clients(servers: list[tuple[RemoteOpenAIServer, list[str]]]):
|
||||
]
|
||||
|
||||
|
||||
def _get_parallel_config(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("server_info?config_format=json"))
|
||||
response.raise_for_status()
|
||||
|
||||
vllm_config = response.json()["vllm_config"]
|
||||
return vllm_config["parallel_config"]
|
||||
|
||||
|
||||
def test_hybrid_dp_server_info(server_manager):
|
||||
servers = server_manager.servers
|
||||
api_server_count = server_manager.api_server_count
|
||||
|
||||
for i, (server, _) in enumerate(servers):
|
||||
print(f"Testing {i=}")
|
||||
|
||||
# Each request will hit one of the API servers
|
||||
# `n_reqs` is set so that there is a good chance each server
|
||||
# receives at least one request
|
||||
n_reqs = 2 * api_server_count * api_server_count
|
||||
parallel_configs = [
|
||||
_get_parallel_config(server) for _ in range(n_reqs)
|
||||
]
|
||||
api_process_counts = [
|
||||
c["_api_process_count"] for c in parallel_configs
|
||||
]
|
||||
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
|
||||
|
||||
assert all(c == api_server_count
|
||||
for c in api_process_counts), api_process_counts
|
||||
assert all(0 <= r < api_server_count
|
||||
for r in api_process_ranks), api_process_ranks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
|
@ -10,6 +10,7 @@ from typing import Optional, cast
|
||||
import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import requests
|
||||
|
||||
from tests.utils import RemoteOpenAIServer
|
||||
from tests.v1.test_utils import check_request_balancing
|
||||
@ -101,6 +102,8 @@ class MultinodeInternalLBServerManager:
|
||||
sargs,
|
||||
auto_port=False,
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE":
|
||||
"1",
|
||||
current_platform.device_control_env_var:
|
||||
",".join(
|
||||
str(
|
||||
@ -214,7 +217,10 @@ class APIOnlyServerManager:
|
||||
self.model_name,
|
||||
api_server_args,
|
||||
auto_port=False,
|
||||
env_dict={}) # No GPUs needed for API-only server
|
||||
env_dict={
|
||||
"VLLM_SERVER_DEV_MODE": "1",
|
||||
# No GPUs needed for API-only server
|
||||
})
|
||||
server.__enter__()
|
||||
print(f"API-only server started successfully with "
|
||||
f"{self.api_server_count} API servers")
|
||||
@ -293,14 +299,21 @@ def default_server_args():
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[1, 4])
|
||||
def servers(request, default_server_args):
|
||||
def server_manager(request, default_server_args):
|
||||
api_server_count = request.param
|
||||
with MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
|
||||
api_server_count,
|
||||
default_server_args,
|
||||
DP_SIZE // NUM_NODES,
|
||||
TP_SIZE) as server_list:
|
||||
yield server_list
|
||||
server_manager = MultinodeInternalLBServerManager(MODEL_NAME, DP_SIZE,
|
||||
api_server_count,
|
||||
default_server_args,
|
||||
DP_SIZE // NUM_NODES,
|
||||
TP_SIZE)
|
||||
|
||||
with server_manager:
|
||||
yield server_manager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def servers(server_manager):
|
||||
return server_manager.servers
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", params=[1, 4])
|
||||
@ -331,6 +344,34 @@ async def api_only_client(api_only_servers: list[tuple[RemoteOpenAIServer,
|
||||
yield client
|
||||
|
||||
|
||||
def _get_parallel_config(server: RemoteOpenAIServer):
|
||||
response = requests.get(server.url_for("server_info?config_format=json"))
|
||||
response.raise_for_status()
|
||||
|
||||
vllm_config = response.json()["vllm_config"]
|
||||
return vllm_config["parallel_config"]
|
||||
|
||||
|
||||
def test_multinode_dp_server_info(server_manager):
|
||||
head_server = server_manager.servers[0][0]
|
||||
api_server_count = server_manager.api_server_count
|
||||
|
||||
# Each request will hit one of the API servers
|
||||
# `n_reqs` is set so that there is a good chance each server
|
||||
# receives at least one request
|
||||
n_reqs = 2 * api_server_count * api_server_count
|
||||
parallel_configs = [
|
||||
_get_parallel_config(head_server) for _ in range(n_reqs)
|
||||
]
|
||||
api_process_counts = [c["_api_process_count"] for c in parallel_configs]
|
||||
api_process_ranks = [c["_api_process_rank"] for c in parallel_configs]
|
||||
|
||||
assert all(c == api_server_count
|
||||
for c in api_process_counts), api_process_counts
|
||||
assert all(0 <= r < api_server_count
|
||||
for r in api_process_ranks), api_process_ranks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
|
@ -50,8 +50,8 @@ ALLOWED_FILES = set([
    # cloudpickle
    'vllm/worker/worker_base.py',
    'vllm/executor/mp_distributed_executor.py',
    'vllm/executor/ray_distributed_executor.py',
    'vllm/entrypoints/llm.py',
    'vllm/v1/executor/ray_distributed_executor.py',
    'tests/utils.py',
    # pickle and cloudpickle
    'vllm/utils/__init__.py',
@ -23,14 +23,14 @@ class AttentionType:
    Attention type.
    Use string to be compatible with `torch.compile`.
    """
    # Decoder attention between previous layer Q/K/V
    DECODER = "decoder"
    # Encoder attention between previous layer Q/K/V for encoder-decoder
    """Decoder attention between previous layer Q/K/V."""
    ENCODER = "encoder"
    # Encoder attention between previous layer Q/K/V
    """Encoder attention between previous layer Q/K/V for encoder-decoder."""
    ENCODER_ONLY = "encoder_only"
    # Attention between dec. Q and enc. K/V for encoder-decoder
    """Encoder attention between previous layer Q/K/V."""
    ENCODER_DECODER = "encoder_decoder"
    """Attention between dec. Q and enc. K/V for encoder-decoder."""


class AttentionBackend(ABC):
@ -430,9 +430,11 @@ class MultiHeadAttention(nn.Module):
        key: torch.Tensor,
        value: torch.Tensor,
    ) -> torch.Tensor:
        """Input shape: batch_size x seq_len x hidden_size"""
        # TODO(Isotr0py): Use existing backend implementations and support FA3
        bsz, q_len, _ = query.size()
        """Input shape:
        (batch_size x seq_len x hidden_size) or
        (batch_size x seq_len x num_heads x head_size)
        """
        bsz, q_len = query.size()[:2]
        kv_len = key.size(1)

        query = query.view(bsz, q_len, self.num_heads, self.head_size)
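The hunk above relaxes `MultiHeadAttention.forward` to accept inputs that are either flattened to `hidden_size` or already split into heads, since `view` produces the same `(bsz, q_len, num_heads, head_size)` layout in both cases. A standalone hedged sketch of that reshape (not the vLLM module):

# Hedged sketch of the shape handling the hunk above implies; standalone
# helper, not vLLM's MultiHeadAttention.
import torch


def split_heads(x: torch.Tensor, num_heads: int, head_size: int) -> torch.Tensor:
    """Accept (B, S, H) or (B, S, num_heads, head_size); return the latter."""
    bsz, seq_len = x.size()[:2]
    # view() works for both layouts because the element count
    # B * S * num_heads * head_size is the same either way.
    return x.view(bsz, seq_len, num_heads, head_size)


q = torch.randn(2, 5, 8 * 64)   # flattened hidden size
k = torch.randn(2, 7, 8, 64)    # already per-head
assert split_heads(q, 8, 64).shape == (2, 5, 8, 64)
assert split_heads(k, 8, 64).shape == (2, 7, 8, 64)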
@ -184,8 +184,30 @@ def kernel_unified_attention_2d(
    # this prefix can be skipped)
    num_tiles = cdiv_fn(max_seq_prefix_len, TILE_SIZE)

    # iterate through tiles
    for j in range(0, num_tiles):
    # ---- Sliding-window tile pruning --------------------
    # Default: keep previous global behavior
    tile_start = 0
    tile_end = num_tiles
    if SLIDING_WINDOW > 0:
        # Query rows covered by this Q-block
        qpos_lo = q_block_local_idx * BLOCK_Q
        qpos_hi = tl.minimum(
            qpos_lo + (BLOCK_M - 1) // num_queries_per_kv,
            cur_batch_query_len - 1,
        )
        # For sliding window, each query position q can only attend to
        # keys in the range [q_abs - SLIDING_WINDOW + 1, q_abs]
        # where q_abs = context_len + q
        # The union of allowed key positions for this Q-block is:
        # [context_len + qpos_lo - SLIDING_WINDOW + 1, context_len + qpos_hi]
        first_allowed_key = context_len + qpos_lo - SLIDING_WINDOW + 1
        last_allowed_key = context_len + qpos_hi
        # Convert to tile indices and clamp
        tile_start = tl.maximum(0, first_allowed_key // TILE_SIZE)
        tile_end = tl.minimum((last_allowed_key // TILE_SIZE) + 1, num_tiles)

    # iterate through tiles (now limited to the sliding window range)
    for j in range(tile_start, tile_end):
        seq_offset = j * TILE_SIZE + offs_t
        tile_mask = seq_offset < max_seq_prefix_len
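The pruning added above only visits key tiles that can fall inside the sliding window for the current Q-block. The same index arithmetic in plain Python with illustrative values (this is not the Triton kernel):

# Plain-Python sketch of the tile-pruning arithmetic above (illustrative
# values; not the Triton kernel).
def sliding_window_tile_range(context_len: int, qpos_lo: int, qpos_hi: int,
                              sliding_window: int, tile_size: int,
                              num_tiles: int) -> tuple[int, int]:
    # Allowed key positions for this Q-block:
    # [context_len + qpos_lo - sliding_window + 1, context_len + qpos_hi]
    first_allowed_key = context_len + qpos_lo - sliding_window + 1
    last_allowed_key = context_len + qpos_hi
    tile_start = max(0, first_allowed_key // tile_size)
    tile_end = min(last_allowed_key // tile_size + 1, num_tiles)
    return tile_start, tile_end


# 1024 tokens of context, a Q-block covering rows 0..15, window of 128 keys,
# 32-key tiles: only tiles 28..32 (exclusive end 33) need to be visited.
print(sliding_window_tile_range(1024, 0, 15, 128, 32, 64))  # -> (28, 33)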
@ -8,8 +8,9 @@ import os
import sys
import time
import traceback
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import Optional, Union
from typing import Optional, Protocol, Union

import aiohttp
from tqdm.asyncio import tqdm
@ -92,6 +93,16 @@ class RequestFuncOutput:
    start_time: float = 0.0


class RequestFunc(Protocol):
    def __call__(
        self,
        request_func_input: RequestFuncInput,
        session: aiohttp.ClientSession,
        pbar: Optional[tqdm] = None,
    ) -> Awaitable[RequestFuncOutput]:
        ...


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    session: aiohttp.ClientSession,
@ -507,7 +518,7 @@ async def async_request_openai_embeddings(


# TODO: Add more request functions for different API protocols.
ASYNC_REQUEST_FUNCS = {
ASYNC_REQUEST_FUNCS: dict[str, RequestFunc] = {
    "vllm": async_request_openai_completions,
    "openai": async_request_openai_completions,
    "openai-chat": async_request_openai_chat_completions,
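`RequestFunc` above is a structural (Protocol) type for the per-backend request coroutines registered in `ASYNC_REQUEST_FUNCS`. A hedged, self-contained sketch of a coroutine that satisfies such a protocol; the dataclass fields and endpoint are illustrative, not the benchmark's real ones:

# Hedged sketch: a coroutine that structurally satisfies a RequestFunc-style
# Protocol. Field names and the dummy endpoint are illustrative only.
import asyncio
from dataclasses import dataclass
from typing import Awaitable, Optional, Protocol

import aiohttp
from tqdm.asyncio import tqdm


@dataclass
class FuncInput:
    api_url: str
    prompt: str


@dataclass
class FuncOutput:
    success: bool = False
    latency: float = 0.0


class RequestFuncLike(Protocol):
    def __call__(self, request_func_input: FuncInput,
                 session: aiohttp.ClientSession,
                 pbar: Optional[tqdm] = None) -> Awaitable[FuncOutput]:
        ...


async def async_request_dummy(request_func_input: FuncInput,
                              session: aiohttp.ClientSession,
                              pbar: Optional[tqdm] = None) -> FuncOutput:
    loop = asyncio.get_running_loop()
    start = loop.time()
    async with session.get(request_func_input.api_url) as resp:
        ok = resp.status == 200
    if pbar is not None:
        pbar.update(1)
    return FuncOutput(success=ok, latency=loop.time() - start)


# Static checkers accept this mapping because the signature matches the Protocol.
funcs: dict[str, RequestFuncLike] = {"dummy": async_request_dummy}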
@ -8,11 +8,12 @@ import time
import aiohttp
from tqdm.asyncio import tqdm

from .endpoint_request_func import RequestFuncInput, RequestFuncOutput
from .endpoint_request_func import (RequestFunc, RequestFuncInput,
                                    RequestFuncOutput)


async def wait_for_endpoint(
    request_func,
    request_func: RequestFunc,
    test_input: RequestFuncInput,
    session: aiohttp.ClientSession,
    timeout_seconds: int = 600,
@ -31,8 +31,11 @@ logger = init_logger(__name__)

def make_compiler(compilation_config: CompilationConfig) -> CompilerInterface:
    if compilation_config.use_inductor:
        if envs.VLLM_USE_STANDALONE_COMPILE and is_torch_equal_or_newer(
                "2.8.0.dev"):
        # Use standalone compile only if requested, version is new enough,
        # and the symbol actually exists in this PyTorch build.
        if (envs.VLLM_USE_STANDALONE_COMPILE
                and is_torch_equal_or_newer("2.8.0.dev")
                and hasattr(torch._inductor, "standalone_compile")):
            logger.debug("Using InductorStandaloneAdaptor")
            return InductorStandaloneAdaptor()
        else:
@ -82,7 +82,7 @@ class CUDAGraphWrapper:
        # TODO: in the future, if we want to use multiple
        # streams, it might not be safe to share a global pool.
        # only investigate this when we use multiple streams
        self.graph_pool = current_platform.get_global_graph_pool()
        self.graph_pool = current_platform.graph_pool_handle()

        if cudagraph_options is None:
            cudagraph_options = CUDAGraphOptions()
2009
vllm/config/model.py
Normal file
File diff suppressed because it is too large
@ -193,6 +193,25 @@ class ParallelConfig:
    not change by dcp, it simply reuse the GPUs of TP group, and tp_size
    needs to be divisible by dcp_size."""

    _api_process_count: int = 1
    """
    The number of API processes initialized.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    _api_process_rank: int = 0
    """
    The rank of this API process, or `-1` for engine core processes
    under API server scale-out.

    Note:
        This is an internal config that is only valid for and
        should only be set by API server scale-out.
    """

    @property
    def world_size_across_dp(self) -> int:
        """world_size_across_dp is TPxPPxDP, it is the size of the world
@ -428,6 +447,12 @@ class ParallelConfig:
        if self.distributed_executor_backend is None and self.world_size == 1:
            self.distributed_executor_backend = "uni"

        if not -1 <= self._api_process_rank < self._api_process_count:
            raise ValueError(
                "Invalid value of `_api_process_rank`. "
                f"Expected to be `-1` or `[0, {self._api_process_count})`, "
                f"but found: {self._api_process_rank}")

    @property
    def use_ray(self) -> bool:
        return self.distributed_executor_backend == "ray" or (
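The new validation relies on Python's chained comparison to express the whole allowed range in one expression; a minimal reproduction of just that rule (helper name is hypothetical):

def check_api_process_rank(rank: int, count: int) -> None:
    # Valid values: -1 (engine core process) or 0..count-1 (an API server).
    if not -1 <= rank < count:
        raise ValueError(f"Expected -1 or [0, {count}), but found: {rank}")


check_api_process_rank(-1, 1)   # engine core, always fine
check_api_process_rank(3, 4)    # last of four API server processes
# check_api_process_rank(4, 4)  # would raise ValueError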
97
vllm/config/pooler.py
Normal file
@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import hashlib
from typing import Any, Optional

from pydantic.dataclasses import dataclass

from vllm.config.utils import config


@config
@dataclass
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models."""

    pooling_type: Optional[str] = None
    """
    The pooling method of the pooling model. This should be a key in
    [`vllm.model_executor.layers.pooler.PoolingType`][].
    """

    ## for embeddings models
    normalize: Optional[bool] = None
    """
    Whether to normalize the embeddings outputs. Defaults to True.
    """
    dimensions: Optional[int] = None
    """
    Reduce the dimensions of embeddings if model
    support matryoshka representation. Defaults to None.
    """
    enable_chunked_processing: Optional[bool] = None
    """
    Whether to enable chunked processing for long inputs that exceed the model's
    maximum position embeddings. When enabled, long inputs will be split into
    chunks, processed separately, and then aggregated using weighted averaging.
    This allows embedding models to handle arbitrarily long text without CUDA
    errors. Defaults to False.
    """
    max_embed_len: Optional[int] = None
    """
    Maximum input length allowed for embedding generation. When set, allows
    inputs longer than max_embed_len to be accepted for embedding models.
    When an input exceeds max_embed_len, it will be handled according to
    the original max_model_len validation logic.
    Defaults to None (i.e. set to max_model_len).
    """

    ## for classification models
    activation: Optional[bool] = None
    """
    Whether to apply activation function to the classification outputs.
    Defaults to True.
    """
    logit_bias: Optional[float] = None
    """
    If provided, apply classification logit biases. Defaults to None.
    """

    ## for reward models
    softmax: Optional[bool] = None
    """
    Whether to apply softmax to the reward outputs.
    Defaults to True.
    """
    step_tag_id: Optional[int] = None
    """
    If set, only the score corresponding to the ``step_tag_id`` in the
    generated sentence should be returned. Otherwise, the scores for all tokens
    are returned.
    """
    returned_token_ids: Optional[list[int]] = None
    """
    A list of indices for the vocabulary dimensions to be extracted,
    such as the token IDs of ``good_token`` and ``bad_token`` in the
    ``math-shepherd-mistral-7b-prm`` model.
    """

    def compute_hash(self) -> str:
        """
        WARNING: Whenever a new field is added to this config,
        ensure that it is included in the factors list if
        it affects the computation graph.

        Provide a hash that uniquely identifies all the configs
        that affect the structure of the computation
        graph from input ids/embeddings to the final hidden states,
        excluding anything before input ids/embeddings and after
        the final hidden states.
        """
        # no factors to consider.
        # this config will not affect the computation graph.
        factors: list[Any] = []
        hash_str = hashlib.md5(str(factors).encode(),
                               usedforsecurity=False).hexdigest()
        return hash_str
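Assuming the module lands at the path shown above, the new config can be constructed directly; the field values below are illustrative, and the pooling_type string is assumed to match a member of vLLM's PoolingType enum.

from vllm.config.pooler import PoolerConfig

# Sketch: an embedding-oriented pooler config with Matryoshka truncation.
cfg = PoolerConfig(pooling_type="MEAN", normalize=True, dimensions=256)

# compute_hash() ignores all fields (empty factors list), reflecting that
# pooling settings do not change the compiled computation graph.
print(cfg.compute_hash())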
@ -3,7 +3,7 @@

import hashlib
from dataclasses import field
from typing import TYPE_CHECKING, Any, Literal, Optional, Union
from typing import Any, Literal, Optional, Union

from pydantic import SkipValidation, model_validator
from pydantic.dataclasses import dataclass
@ -15,13 +15,9 @@ from vllm.utils import (DEFAULT_MAX_NUM_BATCHED_TOKENS,
                        MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS,
                        POOLING_MODEL_MAX_NUM_BATCHED_TOKENS)

if TYPE_CHECKING:
    from vllm.config import RunnerType
else:
    RunnerType = Any

logger = init_logger(__name__)

RunnerType = Literal["generate", "pooling", "draft"]
PreemptionMode = Literal["swap", "recompute"]
SchedulerPolicy = Literal["fcfs", "priority"]
@ -1,8 +1,13 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import ast
|
||||
import inspect
|
||||
import textwrap
|
||||
from dataclasses import MISSING, Field, field, fields, is_dataclass
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, TypeVar
|
||||
|
||||
import regex as re
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import DataclassInstance
|
||||
@ -45,3 +50,96 @@ def get_field(cls: ConfigType, name: str) -> Field:
|
||||
return field(default=default)
|
||||
raise ValueError(
|
||||
f"{cls.__name__}.{name} must have a default value or default factory.")
|
||||
|
||||
|
||||
def contains_object_print(text: str) -> bool:
|
||||
"""
|
||||
Check if the text looks like a printed Python object, e.g.
|
||||
contains any substring matching the pattern: "at 0xFFFFFFF>"
|
||||
We match against 0x followed by 2-16 hex chars (there's
|
||||
a max of 16 on a 64-bit system).
|
||||
|
||||
Args:
|
||||
text (str): The text to check
|
||||
|
||||
Returns:
|
||||
result (bool): `True` if a match is found, `False` otherwise.
|
||||
"""
|
||||
pattern = r'at 0x[a-fA-F0-9]{2,16}>'
|
||||
match = re.search(pattern, text)
|
||||
return match is not None
|
||||
|
||||
|
||||
def assert_hashable(text: str) -> bool:
|
||||
if not contains_object_print(text):
|
||||
return True
|
||||
raise AssertionError(
|
||||
f"vLLM tried to hash some configs that may have Python objects ids "
|
||||
f"in them. This is a bug, please file an issue. "
|
||||
f"Text being hashed: {text}")
|
||||
|
||||
|
||||
def get_attr_docs(cls: type[Any]) -> dict[str, str]:
|
||||
"""
|
||||
Get any docstrings placed after attribute assignments in a class body.
|
||||
|
||||
https://davidism.com/mit-license/
|
||||
"""
|
||||
|
||||
def pairwise(iterable):
|
||||
"""
|
||||
Manually implement https://docs.python.org/3/library/itertools.html#itertools.pairwise
|
||||
|
||||
Can be removed when Python 3.9 support is dropped.
|
||||
"""
|
||||
iterator = iter(iterable)
|
||||
a = next(iterator, None)
|
||||
|
||||
for b in iterator:
|
||||
yield a, b
|
||||
a = b
|
||||
|
||||
try:
|
||||
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
|
||||
except (OSError, KeyError, TypeError):
|
||||
# HACK: Python 3.13+ workaround - set missing __firstlineno__
|
||||
# Workaround can be removed after we upgrade to pydantic==2.12.0
|
||||
with open(inspect.getfile(cls)) as f:
|
||||
for i, line in enumerate(f):
|
||||
if f"class {cls.__name__}" in line and ":" in line:
|
||||
cls.__firstlineno__ = i + 1
|
||||
break
|
||||
cls_node = ast.parse(textwrap.dedent(inspect.getsource(cls))).body[0]
|
||||
|
||||
if not isinstance(cls_node, ast.ClassDef):
|
||||
raise TypeError("Given object was not a class.")
|
||||
|
||||
out = {}
|
||||
|
||||
# Consider each pair of nodes.
|
||||
for a, b in pairwise(cls_node.body):
|
||||
# Must be an assignment then a constant string.
|
||||
if (not isinstance(a, (ast.Assign, ast.AnnAssign))
|
||||
or not isinstance(b, ast.Expr)
|
||||
or not isinstance(b.value, ast.Constant)
|
||||
or not isinstance(b.value.value, str)):
|
||||
continue
|
||||
|
||||
doc = inspect.cleandoc(b.value.value)
|
||||
|
||||
# An assignment can have multiple targets (a = b = v), but an
|
||||
# annotated assignment only has one target.
|
||||
targets = a.targets if isinstance(a, ast.Assign) else [a.target]
|
||||
|
||||
for target in targets:
|
||||
# Must be assigning to a plain name.
|
||||
if not isinstance(target, ast.Name):
|
||||
continue
|
||||
|
||||
out[target.id] = doc
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def is_init_field(cls: ConfigType, name: str) -> bool:
|
||||
return next(f for f in fields(cls) if f.name == name).init
|
||||
|
@ -106,3 +106,8 @@ KVConnectorFactory.register_connector(
    "MultiConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.multi_connector",
    "MultiConnector")

KVConnectorFactory.register_connector(
    "OffloadingConnector",
    "vllm.distributed.kv_transfer.kv_connector.v1.offloading_connector",
    "OffloadingConnector")
@ -129,7 +129,7 @@ class KVOutputAggregator:
    def aggregate(self,
                  outputs: list[ModelRunnerOutput],
                  output_rank: int = 0) -> ModelRunnerOutput:
        # aggregate kv_connector_output from all workers
        # Aggregate kv_connector_output from all workers

        def update_finished_set(req_ids: Optional[set[str]],
                                remaining_count_dict: dict[str, int],
@ -142,8 +142,9 @@ class KVOutputAggregator:

        finished_sending = set[str]()
        finished_recving = set[str]()
        for output in outputs:
            output = output.kv_connector_output
        aggregated_kv_connector_stats = None
        for model_runner_output in outputs:
            output = model_runner_output.kv_connector_output
            if not output:
                continue
            update_finished_set(output.finished_sending,
@ -151,12 +152,26 @@ class KVOutputAggregator:
            update_finished_set(output.finished_recving,
                                self._recv_remaining_count, finished_recving)

            # Aggregate kv_connector_stats from all workers.
            if aggregated_kv_connector_stats is None:
                # Use the first worker's kv_connector_stats as accumulator.
                aggregated_kv_connector_stats = output.kv_connector_stats
            elif kv_connector_stats := output.kv_connector_stats:
                if aggregated_kv_connector_stats is None:
                    aggregated_kv_connector_stats = kv_connector_stats
                else:
                    assert isinstance(aggregated_kv_connector_stats,
                                      type(kv_connector_stats))
                    aggregated_kv_connector_stats = \
                        aggregated_kv_connector_stats.aggregate(kv_connector_stats)

        # select output of the worker specified by output_rank
        output = outputs[output_rank]

        output.kv_connector_output = KVConnectorOutput(
            finished_sending=finished_sending or None,
            finished_recving=finished_recving or None,
            kv_connector_stats=aggregated_kv_connector_stats or None,
        )

        return output
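The stats aggregation above follows a lazy-accumulator pattern: the first non-empty worker stats object becomes the accumulator and later ones are folded in with aggregate(). A stripped-down sketch with a hypothetical Stats class:

from typing import Optional


class Stats:
    def __init__(self, n: int = 0) -> None:
        self.n = n

    def aggregate(self, other: "Stats") -> "Stats":
        self.n += other.n
        return self


def aggregate_all(per_worker: list[Optional[Stats]]) -> Optional[Stats]:
    # Mirrors the loop above: first non-empty stats object becomes the
    # accumulator, later ones are folded in via .aggregate().
    acc: Optional[Stats] = None
    for stats in per_worker:
        if stats is None:
            continue
        acc = stats if acc is None else acc.aggregate(stats)
    return acc


print(aggregate_all([Stats(2), None, Stats(3)]).n)  # prints 5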
@ -49,6 +49,8 @@ if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.config import VllmConfig
    from vllm.distributed.kv_events import KVCacheEvent
    from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
        KVConnectorStats)
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request
@ -235,6 +237,12 @@ class KVConnectorBase_V1(ABC):
        """
        return None

    def get_kv_connector_stats(self) -> Optional["KVConnectorStats"]:
        """
        Get the KV connector stats collected during the last interval.
        """
        return None

    # ==============================
    # Scheduler-side methods
    # ==============================
@ -365,4 +373,16 @@ class KVConnectorBase_V1(ABC):
            int: expected sending or receiving completion count.
        """

        return None
        return None

    @classmethod
    def build_kv_connector_stats(
        cls,
        data: Optional[dict[str,
                            Any]] = None) -> Optional["KVConnectorStats"]:
        """
        KVConnectorStats resolution method. This method allows dynamically
        registered connectors to return their own KVConnectorStats object,
        which can implement custom aggregation logic on the data dict.
        """
        return None
100
vllm/distributed/kv_transfer/kv_connector/v1/metrics.py
Normal file
@ -0,0 +1,100 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from dataclasses import dataclass, field
from typing import Any, Optional, Union

from vllm.config.kv_transfer import KVTransferConfig
from vllm.distributed.kv_transfer.kv_connector.factory import (
    KVConnectorFactory)
from vllm.distributed.kv_transfer.kv_transfer_state import (
    has_kv_transfer_group)
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class KVConnectorStats:
    """
    Base class for KV Connector Stats, a container for transfer performance
    metrics or otherwise important telemetry from the connector.
    All sub-classes need to be serializable as stats are sent from worker to
    logger process.
    """
    data: dict[str, Any] = field(default_factory=dict)

    def reset(self):
        """Reset the stats, clear the state."""
        raise NotImplementedError

    def aggregate(self, other: "KVConnectorStats") -> "KVConnectorStats":
        """
        Aggregate stats with another `KVConnectorStats` object.
        """
        raise NotImplementedError

    def reduce(self) -> dict[str, Union[int, float]]:
        """
        Reduce the observations collected during a time interval to one or
        more representative values (eg avg/median/sum of the series).
        This is meant to be called by the logger to produce a summary of the
        stats for the last time interval.
        """
        raise NotImplementedError

    def is_empty(self) -> bool:
        """Return True if the stats are empty."""
        raise NotImplementedError


class KVConnectorLogging:

    def __init__(self, kv_tranfer_config: KVTransferConfig):
        # This should be called on frontend process.
        assert not has_kv_transfer_group()
        # Instantiate the connector's stats class.
        if kv_tranfer_config and kv_tranfer_config.kv_connector:
            self.connector_cls = KVConnectorFactory.get_connector_class(
                kv_tranfer_config)
        self.reset()

    def reset(self):
        self.transfer_stats_accumulator: Optional[KVConnectorStats] = None

    def observe(self, transfer_stats_data: dict[str, Any]):
        # Should not be called when a KVConnector is not configured.
        assert self.connector_cls is not None
        # Called periodically when connector syncs with the scheduler.
        # Note that this is not the same as the logging interval.
        # We expect transfer_stats_data to be aggregated across all workers and
        # consist of observations from a single connector or a MultiConnector.
        transfer_stats = self.connector_cls.build_kv_connector_stats(
            transfer_stats_data)
        if transfer_stats is None:
            logger.warning_once(
                "The connector %s is collecting stats but "
                "does not implement the "
                "`build_kv_connector_stats` method. "
                "Stats will not be logged.", self.connector_cls)
            return

        if self.transfer_stats_accumulator is None:
            self.transfer_stats_accumulator = transfer_stats
        else:
            # Accumulate last interval stats.
            self.transfer_stats_accumulator = \
                self.transfer_stats_accumulator.aggregate(transfer_stats)

    def log(self, log_fn=logger.info):
        """Log transfer metrics periodically, similar to throughput logging"""
        if (self.transfer_stats_accumulator
                and not self.transfer_stats_accumulator.is_empty()):
            # Produce a single cumulative stats object for the last time
            # interval from the recorded observations.
            xfer_metrics = self.transfer_stats_accumulator.reduce()
            xfer_metrics_str = ", ".join(f"{k}={v}"
                                         for k, v in xfer_metrics.items())
            log_fn("KV Transfer metrics: %s", xfer_metrics_str)

            # Reset metrics for next interval
            self.reset()
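A concrete subclass only has to fill in reset/aggregate/reduce/is_empty; the sketch below is a hypothetical counter-based stats class (mirroring the shape of the NixlKVConnectorStats defined later in this diff, but not copied from it) to show how accumulation and reduction fit together.

from dataclasses import dataclass, field
from typing import Any, Union


@dataclass
class CounterStats:
    """Toy KVConnectorStats-like container counting finished transfers."""
    data: dict[str, Any] = field(default_factory=lambda: {"transfers": 0})

    def record(self) -> None:
        self.data["transfers"] += 1

    def aggregate(self, other: "CounterStats") -> "CounterStats":
        self.data["transfers"] += other.data["transfers"]
        return self

    def is_empty(self) -> bool:
        return self.data["transfers"] == 0

    def reduce(self) -> dict[str, Union[int, float]]:
        return {"transfers": self.data["transfers"]}


# Accumulate two observation intervals, then reduce for logging.
acc = CounterStats()
interval = CounterStats()
interval.record()
interval.record()
acc = acc.aggregate(interval)
print(acc.reduce())  # {'transfers': 2}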
@ -9,19 +9,21 @@ import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.config.kv_transfer import KVTransferConfig
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import (
|
||||
KVConnectorFactory)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorStats)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
from vllm.distributed.kv_events import KVCacheEvent
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.request import Request
|
||||
|
||||
logger = init_logger(__name__)
|
||||
@ -33,6 +35,43 @@ class MultiKVConnectorMetadata(KVConnectorMetadata):
|
||||
extra_async_saves: Optional[dict[str, int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class MultiKVConnectorStats(KVConnectorStats):
|
||||
"""
|
||||
Maintain a dict of KVConnectorStats objects, one for each connector.
|
||||
This is used to aggregate the stats from all connectors separately.
|
||||
"""
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
for connector_id, stats in other.data.items():
|
||||
if connector_id not in self.data:
|
||||
self[connector_id] = stats
|
||||
else:
|
||||
assert isinstance(stats, type(self.data[connector_id]))
|
||||
self[connector_id] = self[connector_id].aggregate(stats)
|
||||
return self
|
||||
|
||||
def reset(self):
|
||||
for stats in self.data.values():
|
||||
stats.reset()
|
||||
|
||||
def reduce(self) -> dict[str, Any]:
|
||||
# TODO (NickLucche) Adjust for logging on separate lines
|
||||
return {
|
||||
connector_id: stats.reduce()
|
||||
for connector_id, stats in self.data.items()
|
||||
}
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return all(stats.is_empty() for stats in self.data.values())
|
||||
|
||||
def __getitem__(self, connector_id: str) -> KVConnectorStats:
|
||||
return self.data[connector_id]
|
||||
|
||||
def __setitem__(self, connector_id: str, stats: KVConnectorStats):
|
||||
self.data[connector_id] = stats
|
||||
|
||||
|
||||
class MultiConnector(KVConnectorBase_V1):
|
||||
"""
|
||||
A wrapper for using multiple KVConnectors at the same time.
|
||||
@ -46,6 +85,7 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||
super().__init__(vllm_config=vllm_config, role=role)
|
||||
self._connectors: list[KVConnectorBase_V1] = []
|
||||
self._ktc_kv_transfer_config = []
|
||||
ktcs = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
|
||||
"connectors")
|
||||
assert ktcs is not None
|
||||
@ -57,6 +97,7 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
**ktc, engine_id=engine_id)
|
||||
self._connectors.append(
|
||||
KVConnectorFactory.create_connector(temp_config, role))
|
||||
self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config)
|
||||
|
||||
# A mapping from request id to the index of the connector chosen to
|
||||
# load the request from (if any).
|
||||
@ -227,7 +268,7 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
|
||||
return async_saves > 0, kv_txfer_params
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
def take_events(self) -> Iterable["KVCacheEvent"]:
|
||||
for c in self._connectors:
|
||||
yield from c.take_events()
|
||||
|
||||
@ -264,3 +305,24 @@ class MultiConnector(KVConnectorBase_V1):
|
||||
f"({', '.join(layouts) })."
|
||||
f"All connectors must use the same layout.")
|
||||
return next(iter(layouts), None)
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls,
|
||||
data: Optional[dict[str,
|
||||
Any]] = None) -> Optional[KVConnectorStats]:
|
||||
return MultiKVConnectorStats(data=data) if data is not None \
|
||||
else MultiKVConnectorStats()
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional[MultiKVConnectorStats]:
|
||||
# Group connector stats by connector type.
|
||||
stats_by_connector: Optional[MultiKVConnectorStats] = None
|
||||
for c in self._connectors:
|
||||
stats = c.get_kv_connector_stats()
|
||||
if stats is None:
|
||||
continue
|
||||
if stats_by_connector is None:
|
||||
# Lazy init to allow optional return value.
|
||||
stats_by_connector = MultiKVConnectorStats()
|
||||
stats_by_connector[c.__class__.__name__] = stats
|
||||
return stats_by_connector
|
||||
|
@ -1,6 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import contextlib
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
import queue
|
||||
@ -11,7 +12,7 @@ from collections import defaultdict
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import Future, ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
import msgspec
|
||||
import numpy as np
|
||||
@ -23,6 +24,8 @@ from vllm.attention.selector import backend_name_to_enum, get_attn_backend
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
CopyBlocksOp, KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.metrics import (
|
||||
KVConnectorStats)
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size,
|
||||
get_tp_group)
|
||||
@ -33,7 +36,6 @@ from vllm.platforms import _Backend, current_platform
|
||||
from vllm.utils import make_zmq_path, make_zmq_socket
|
||||
from vllm.v1.attention.backends.utils import get_kv_cache_layout
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@ -206,6 +208,18 @@ class NixlConnector(KVConnectorBase_V1):
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished()
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_kv_connector_stats()
|
||||
|
||||
@classmethod
|
||||
def build_kv_connector_stats(
|
||||
cls,
|
||||
data: Optional[dict[str,
|
||||
Any]] = None) -> Optional[KVConnectorStats]:
|
||||
return NixlKVConnectorStats(data=data) if data is not None \
|
||||
else NixlKVConnectorStats()
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
@ -377,6 +391,7 @@ class NixlConnectorScheduler:
|
||||
Once a request is finished, determine whether request blocks
|
||||
should be freed now or will be sent asynchronously and freed later.
|
||||
"""
|
||||
from vllm.v1.request import RequestStatus
|
||||
|
||||
params = request.kv_transfer_params
|
||||
logger.debug(
|
||||
@ -550,6 +565,7 @@ class NixlConnectorWorker:
|
||||
# With heterogeneous TP, P must wait for all assigned D TP workers to
|
||||
# finish reading before safely freeing the blocks.
|
||||
self.consumer_notification_counts_by_req = defaultdict[ReqId, int](int)
|
||||
self.xfer_stats = NixlKVConnectorStats()
|
||||
|
||||
def __del__(self):
|
||||
"""Cleanup background threads on destruction."""
|
||||
@ -1097,6 +1113,8 @@ class NixlConnectorWorker:
|
||||
xfer_state = self.nixl_wrapper.check_xfer_state(handle)
|
||||
if xfer_state == "DONE":
|
||||
self.nixl_wrapper.release_xfer_handle(handle)
|
||||
# TODO (NickLucche) Get from NIXL telemetry once integrated
|
||||
self.xfer_stats.record_transfer()
|
||||
elif xfer_state == "PROC":
|
||||
in_progress = True
|
||||
continue
|
||||
@ -1248,7 +1266,6 @@ class NixlConnectorWorker:
|
||||
self.nixl_wrapper.transfer(handle)
|
||||
|
||||
# Use handle to check completion in future step().
|
||||
# TODO (NickLucche) surface xfer elapsed time
|
||||
self._recving_transfers[request_id].append(
|
||||
(handle, time.perf_counter()))
|
||||
|
||||
@ -1300,6 +1317,15 @@ class NixlConnectorWorker:
|
||||
block_len = self.block_len
|
||||
return block_len
|
||||
|
||||
def get_kv_connector_stats(self) -> Optional[KVConnectorStats]:
|
||||
"""
|
||||
Get the KV transfer stats for the connector.
|
||||
"""
|
||||
# Clear stats for next iteration
|
||||
if not self.xfer_stats.is_empty():
|
||||
return self.xfer_stats.clone_and_reset()
|
||||
return None
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||
@ -1318,3 +1344,39 @@ def zmq_ctx(socket_type: Any, addr: str) -> Iterator[zmq.Socket]:
|
||||
finally:
|
||||
if ctx is not None:
|
||||
ctx.destroy(linger=0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class NixlKVConnectorStats(KVConnectorStats):
|
||||
"""Container for transfer performance metrics"""
|
||||
|
||||
def __post_init__(self):
|
||||
if "num_successful_transfers" not in self.data:
|
||||
self.data["num_successful_transfers"] = 0
|
||||
|
||||
def reset(self):
|
||||
self.data = {"num_successful_transfers": 0}
|
||||
|
||||
def record_transfer(self):
|
||||
# TODO: record actual transfer stats when available
|
||||
self.data["num_successful_transfers"] += 1
|
||||
|
||||
def clone_and_reset(self) -> "NixlKVConnectorStats":
|
||||
old = copy.copy(self)
|
||||
self.reset()
|
||||
return old
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self.data["num_successful_transfers"] == 0
|
||||
|
||||
def aggregate(self, other: KVConnectorStats) -> KVConnectorStats:
|
||||
if not other.is_empty():
|
||||
self.data["num_successful_transfers"] += other.data[
|
||||
"num_successful_transfers"]
|
||||
return self
|
||||
|
||||
def reduce(self) -> dict[str, Union[int, float]]:
|
||||
# TODO: reduce stats to a single value, calculate latency/throughput
|
||||
return {
|
||||
"num_successful_transfers": self.data["num_successful_transfers"]
|
||||
}
|
@ -0,0 +1,485 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from itertools import islice
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.attention import AttentionMetadata
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_events import BlockRemoved, BlockStored, KVCacheEvent
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1 import (KVConnectorBase_V1,
|
||||
KVConnectorRole)
|
||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||
KVConnectorMetadata)
|
||||
from vllm.forward_context import ForwardContext
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||
from vllm.v1.core.kv_cache_utils import BlockHash
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.kv_offload.abstract import OffloadingManager
|
||||
from vllm.v1.kv_offload.factory import OffloadingSpecFactory
|
||||
from vllm.v1.kv_offload.mediums import GPULoadStoreSpec
|
||||
from vllm.v1.kv_offload.spec import OffloadingSpec
|
||||
from vllm.v1.kv_offload.worker.worker import OffloadingWorker, TransferSpec
|
||||
from vllm.v1.outputs import KVConnectorOutput
|
||||
from vllm.v1.request import Request
|
||||
|
||||
ReqId = str
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class OffloadingConnectorMetadata(KVConnectorMetadata):
|
||||
reqs_to_load: dict[ReqId, TransferSpec]
|
||||
reqs_to_store: dict[ReqId, TransferSpec]
|
||||
|
||||
|
||||
class OffloadingConnector(KVConnectorBase_V1):
|
||||
|
||||
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
|
||||
super().__init__(vllm_config, role)
|
||||
|
||||
spec = OffloadingSpecFactory.create_spec(vllm_config)
|
||||
|
||||
self.connector_scheduler: Optional[OffloadingConnectorScheduler] = None
|
||||
self.connector_worker: Optional[OffloadingConnectorWorker] = None
|
||||
if role == KVConnectorRole.SCHEDULER:
|
||||
self.connector_scheduler = OffloadingConnectorScheduler(spec)
|
||||
elif role == KVConnectorRole.WORKER:
|
||||
self.connector_worker = OffloadingConnectorWorker(spec)
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
assert self.connector_worker is not None
|
||||
self.connector_worker.register_kv_caches(kv_caches)
|
||||
|
||||
def start_load_kv(self, forward_context: "ForwardContext",
|
||||
**kwargs) -> None:
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata,
|
||||
OffloadingConnectorMetadata)
|
||||
self.connector_worker.start_load_kv(self._connector_metadata)
|
||||
|
||||
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||
pass
|
||||
|
||||
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||
pass
|
||||
|
||||
def wait_for_save(self):
|
||||
assert self.connector_worker is not None
|
||||
assert isinstance(self._connector_metadata,
|
||||
OffloadingConnectorMetadata)
|
||||
self.connector_worker.start_store_kv(self._connector_metadata)
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
assert self.connector_worker is not None
|
||||
return self.connector_worker.get_finished(finished_req_ids)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: "Request",
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||
request, num_computed_tokens)
|
||||
|
||||
def update_state_after_alloc(self, request: "Request",
|
||||
blocks: "KVCacheBlocks",
|
||||
num_external_tokens: int):
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.update_state_after_alloc(
|
||||
request, blocks, num_external_tokens)
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.build_connector_meta(scheduler_output)
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
assert self.connector_scheduler is not None
|
||||
self.connector_scheduler.update_connector_output(connector_output)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: "Request",
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.request_finished(request, block_ids)
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
assert self.connector_scheduler is not None
|
||||
return self.connector_scheduler.take_events()
|
||||
|
||||
|
||||
class OffloadingConnectorScheduler:
|
||||
"""Implementation of Scheduler side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.gpu_block_size = spec.gpu_block_size
|
||||
self.offloaded_block_size = spec.offloaded_block_size
|
||||
self.block_size_factor = (self.offloaded_block_size //
|
||||
self.gpu_block_size)
|
||||
self.manager: OffloadingManager = spec.get_manager()
|
||||
|
||||
self._requests: dict[ReqId, Request] = {}
|
||||
# list of GPU block IDs per request
|
||||
self._request_block_ids: dict[ReqId, list[int]] = {}
|
||||
# requests to load for the current scheduler step
|
||||
self._reqs_to_load: dict[ReqId, TransferSpec] = {}
|
||||
# request blocks are stored in order
|
||||
# index of next block (of size offloaded_block_size) to offload
|
||||
self._next_stored_block_idx: dict[ReqId, int] = {}
|
||||
|
||||
# request ID -> set(block hashes being stored/load)
|
||||
self._reqs_being_stored = defaultdict[ReqId, set[BlockHash]](set)
|
||||
self._reqs_being_loaded = defaultdict[ReqId, set[BlockHash]](set)
|
||||
|
||||
def _get_block_hashes(
|
||||
self,
|
||||
req: Request,
|
||||
start_idx: int = 0,
|
||||
end_idx: Optional[int] = None,
|
||||
) -> Iterable[BlockHash]:
|
||||
return islice(
|
||||
req.block_hashes,
|
||||
self.block_size_factor * start_idx + self.block_size_factor - 1,
|
||||
self.block_size_factor * end_idx if end_idx else None,
|
||||
self.block_size_factor)
|
||||
|
||||
def get_num_new_matched_tokens(
|
||||
self, request: Request,
|
||||
num_computed_tokens: int) -> tuple[int, bool]:
|
||||
"""
|
||||
Get number of new tokens that can be loaded beyond the
|
||||
num_computed_tokens.
|
||||
|
||||
Args:
|
||||
request (Request): the request object.
|
||||
num_computed_tokens (int): the number of locally
|
||||
computed tokens for this request
|
||||
|
||||
Returns:
|
||||
A tuple with the following elements:
|
||||
- The number of tokens that can be loaded beyond what is
|
||||
already computed.
|
||||
- `True` if tokens will be loaded asynchronously
|
||||
(between scheduler steps).
|
||||
"""
|
||||
num_blocks = request.num_tokens // self.offloaded_block_size
|
||||
|
||||
assert (len(request.block_hashes) //
|
||||
self.block_size_factor == num_blocks)
|
||||
block_hashes = self._get_block_hashes(request)
|
||||
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
full_block_tokens = self.offloaded_block_size * num_blocks
|
||||
if full_block_tokens - num_computed_tokens < self.offloaded_block_size:
|
||||
# we can load less than a block, skip
|
||||
return 0, False
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
hits = self.manager.lookup(
|
||||
self._get_block_hashes(request, start_idx=start_block_idx))
|
||||
if hits == 0:
|
||||
return 0, False
|
||||
|
||||
num_hit_tokens = (self.offloaded_block_size *
|
||||
(start_block_idx + hits) - num_computed_tokens)
|
||||
logger.debug(
|
||||
"Request %s hit %s offloaded tokens after %s GPU hit tokens",
|
||||
request.request_id,
|
||||
num_hit_tokens,
|
||||
num_computed_tokens,
|
||||
)
|
||||
if num_hit_tokens < self.offloaded_block_size:
|
||||
return 0, False
|
||||
|
||||
return num_hit_tokens, True
|
||||
|
||||
def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks,
|
||||
num_external_tokens: int):
|
||||
self._requests[request.request_id] = request
|
||||
# the block ids are updated in _get_reqs_to_store
|
||||
self._request_block_ids[request.request_id] = []
|
||||
|
||||
if num_external_tokens == 0:
|
||||
return
|
||||
|
||||
block_groups = blocks.get_block_ids()
|
||||
block_ids = block_groups[0]
|
||||
|
||||
num_computed_gpu_blocks = sum(block.block_hash is not None
|
||||
for block in blocks.blocks[0])
|
||||
num_computed_tokens = num_computed_gpu_blocks * self.gpu_block_size
|
||||
full_block_tokens = num_computed_tokens + num_external_tokens
|
||||
assert full_block_tokens % self.offloaded_block_size == 0
|
||||
|
||||
num_pending_gpu_blocks = len(block_ids) - num_computed_gpu_blocks
|
||||
assert (num_external_tokens == num_pending_gpu_blocks *
|
||||
self.gpu_block_size)
|
||||
|
||||
start_block_idx = num_computed_tokens // self.offloaded_block_size
|
||||
num_blocks = full_block_tokens // self.offloaded_block_size
|
||||
|
||||
assert (len(request.block_hashes) // self.block_size_factor
|
||||
>= num_blocks)
|
||||
block_hashes = self._get_block_hashes(request,
|
||||
start_idx=start_block_idx,
|
||||
end_idx=num_blocks)
|
||||
|
||||
src_spec = self.manager.prepare_load(block_hashes)
|
||||
dst_spec = GPULoadStoreSpec(block_ids[num_computed_gpu_blocks:])
|
||||
|
||||
block_hashes = self._get_block_hashes(request,
|
||||
start_idx=start_block_idx,
|
||||
end_idx=num_blocks)
|
||||
|
||||
self._reqs_to_load[request.request_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_loaded[request.request_id].update(block_hashes)
|
||||
self._next_stored_block_idx[request.request_id] = num_blocks
|
||||
|
||||
def _get_reqs_to_store(self, scheduler_output: SchedulerOutput):
|
||||
reqs_to_store: dict[ReqId, TransferSpec] = {}
|
||||
# iterate over both new and cached requests
|
||||
for req_id, new_block_id_groups, preempted in yield_req_data(
|
||||
scheduler_output):
|
||||
|
||||
if preempted:
|
||||
self._request_block_ids[req_id] = []
|
||||
|
||||
if new_block_id_groups:
|
||||
new_block_ids = new_block_id_groups[0]
|
||||
self._request_block_ids[req_id] += new_block_ids
|
||||
|
||||
block_ids = self._request_block_ids[req_id]
|
||||
|
||||
req = self._requests[req_id]
|
||||
new_tokens = scheduler_output.num_scheduled_tokens[req_id]
|
||||
total_tokens = req.num_computed_tokens + new_tokens
|
||||
num_blocks = total_tokens // self.offloaded_block_size
|
||||
start_block_idx = self._next_stored_block_idx.get(req_id, 0)
|
||||
num_new_blocks = num_blocks - start_block_idx
|
||||
|
||||
if num_new_blocks <= 0:
|
||||
continue
|
||||
|
||||
num_gpu_blocks = num_blocks * self.block_size_factor
|
||||
assert len(req.block_hashes) >= num_gpu_blocks
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks)
|
||||
store_output = self.manager.prepare_store(new_block_hashes)
|
||||
if store_output is None:
|
||||
logger.warning("Cannot store %s blocks", num_new_blocks)
|
||||
break
|
||||
|
||||
self._next_stored_block_idx[req_id] = num_blocks
|
||||
|
||||
if not store_output.block_hashes_to_store:
|
||||
continue
|
||||
block_hashes_to_store = set(store_output.block_hashes_to_store)
|
||||
|
||||
block_hashes = self._get_block_hashes(req, end_idx=num_blocks)
|
||||
self.manager.touch(block_hashes)
|
||||
|
||||
new_block_hashes = self._get_block_hashes(
|
||||
req, start_idx=start_block_idx, end_idx=num_blocks)
|
||||
dst_spec = store_output.store_spec
|
||||
src_block_ids: list[int] = []
|
||||
for idx, blk_hash in enumerate(new_block_hashes):
|
||||
if blk_hash not in block_hashes_to_store:
|
||||
continue
|
||||
offloaded_block_idx = start_block_idx + idx
|
||||
gpu_block_idx = offloaded_block_idx * self.block_size_factor
|
||||
for i in range(self.block_size_factor):
|
||||
src_block_ids.append(block_ids[gpu_block_idx + i])
|
||||
src_spec = GPULoadStoreSpec(src_block_ids)
|
||||
|
||||
reqs_to_store[req_id] = (src_spec, dst_spec)
|
||||
self._reqs_being_stored[req_id] |= block_hashes_to_store
|
||||
|
||||
logger.debug(
|
||||
"Request %s offloading %s blocks starting from block #%d",
|
||||
req_id,
|
||||
len(block_hashes_to_store),
|
||||
start_block_idx,
|
||||
)
|
||||
|
||||
return reqs_to_store
|
||||
|
||||
def build_connector_meta(
|
||||
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||
meta = OffloadingConnectorMetadata(
|
||||
reqs_to_load=self._reqs_to_load,
|
||||
reqs_to_store=self._get_reqs_to_store(scheduler_output))
|
||||
self._reqs_to_load = {}
|
||||
return meta
|
||||
|
||||
def update_connector_output(self, connector_output: KVConnectorOutput):
|
||||
"""
|
||||
Update KVConnector state from worker-side connectors output.
|
||||
|
||||
Args:
|
||||
connector_output (KVConnectorOutput): the worker-side
|
||||
connectors output.
|
||||
"""
|
||||
for req_id in connector_output.finished_sending or []:
|
||||
block_hashes = self._reqs_being_stored.pop(req_id, None)
|
||||
if block_hashes:
|
||||
self.manager.complete_store(block_hashes)
|
||||
|
||||
for req_id in connector_output.finished_recving or []:
|
||||
block_hashes = self._reqs_being_loaded.pop(req_id, None)
|
||||
if block_hashes:
|
||||
self.manager.complete_load(block_hashes)
|
||||
|
||||
def request_finished(
|
||||
self,
|
||||
request: Request,
|
||||
block_ids: list[int],
|
||||
) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||
"""
|
||||
Called when a request has finished, before its blocks are freed.
|
||||
|
||||
Returns:
|
||||
True if the request is being saved/sent asynchronously and blocks
|
||||
should not be freed until the request_id is returned from
|
||||
get_finished().
|
||||
Optional KVTransferParams to be included in the request outputs
|
||||
returned by the engine.
|
||||
"""
|
||||
req_id = request.request_id
|
||||
self._requests.pop(req_id, None)
|
||||
self._request_block_ids.pop(req_id, None)
|
||||
self._next_stored_block_idx.pop(req_id, None)
|
||||
|
||||
request_being_stored = req_id in self._reqs_being_stored
|
||||
return request_being_stored, None
|
||||
|
||||
def take_events(self) -> Iterable[KVCacheEvent]:
|
||||
"""Take the KV cache events from the connector.
|
||||
|
||||
Returns:
|
||||
A list of KV cache events.
|
||||
"""
|
||||
for event in self.manager.take_events():
|
||||
if event.removed:
|
||||
yield BlockRemoved(block_hashes=event.block_hashes,
|
||||
medium=event.medium)
|
||||
else:
|
||||
yield BlockStored(block_hashes=event.block_hashes,
|
||||
parent_block_hash=None,
|
||||
token_ids=[],
|
||||
lora_id=None,
|
||||
block_size=event.block_size,
|
||||
medium=event.medium)
|
||||
|
||||
|
||||
class OffloadingConnectorWorker:
|
||||
"""Implementation of Worker side methods"""
|
||||
|
||||
def __init__(self, spec: OffloadingSpec):
|
||||
self.spec = spec
|
||||
self.worker = OffloadingWorker()
|
||||
|
||||
self._job_counter = 0
|
||||
|
||||
# req_id -> (job_id, store)
|
||||
self._jobs: dict[int, tuple[ReqId, bool]] = {}
|
||||
# req_id -> active job IDs
|
||||
self._load_job: dict[ReqId, int] = {}
|
||||
# req_id -> set(active job IDs)
|
||||
self._store_jobs = defaultdict[ReqId, set[int]](set)
|
||||
|
||||
self._finished_reqs_waiting_for_store: set[ReqId] = set()
|
||||
|
||||
def _generate_job_id(self) -> int:
|
||||
job_id = self._job_counter
|
||||
self._job_counter = job_id + 1
|
||||
return job_id
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
for src_cls, dst_cls, handler in (self.spec.get_handlers(kv_caches)):
|
||||
self.worker.register_handler(src_cls, dst_cls, handler)
|
||||
|
||||
def start_load_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_load.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, False)
|
||||
assert req_id not in self._load_job
|
||||
self._load_job[req_id] = job_id
|
||||
assert self.worker.transfer_async(job_id, transfer_spec)
|
||||
|
||||
def start_store_kv(self, metadata: OffloadingConnectorMetadata):
|
||||
for req_id, transfer_spec in metadata.reqs_to_store.items():
|
||||
job_id = self._generate_job_id()
|
||||
self._jobs[job_id] = (req_id, True)
|
||||
self._store_jobs[req_id].add(job_id)
|
||||
assert self.worker.transfer_async(job_id, transfer_spec)
|
||||
|
||||
def get_finished(self,
|
||||
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
|
||||
"""
|
||||
Notifies worker-side connector ids of requests that have
|
||||
finished generating tokens.
|
||||
Returns a list of request IDs that finished loading or storing.
|
||||
|
||||
Returns:
|
||||
ids of requests that have finished asynchronous transfer
|
||||
tuple of (sending/saving ids, recving/loading ids).
|
||||
"""
|
||||
finished_sending = set()
|
||||
finished_recving = set()
|
||||
for job_id, success in self.worker.get_finished():
|
||||
# we currently do not support job failures
|
||||
assert success
|
||||
req_id, store = self._jobs.pop(job_id)
|
||||
if store:
|
||||
req_jobs = self._store_jobs[req_id]
|
||||
req_jobs.remove(job_id)
|
||||
if req_jobs:
|
||||
continue
|
||||
|
||||
if req_id in self._finished_reqs_waiting_for_store:
|
||||
self._finished_reqs_waiting_for_store.remove(req_id)
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
else:
|
||||
req_job = self._load_job[req_id]
|
||||
assert job_id == req_job
|
||||
del self._load_job[req_id]
|
||||
finished_recving.add(req_id)
|
||||
|
||||
for req_id in finished_req_ids:
|
||||
pending_req_jobs = self._store_jobs.get(req_id)
|
||||
if pending_req_jobs:
|
||||
self._finished_reqs_waiting_for_store.add(req_id)
|
||||
elif pending_req_jobs is not None:
|
||||
finished_sending.add(req_id)
|
||||
del self._store_jobs[req_id]
|
||||
|
||||
return finished_sending, finished_recving
|
||||
|
||||
|
||||
def yield_req_data(
|
||||
scheduler_output) -> Iterator[tuple[str, tuple[list[int], ...], bool]]:
|
||||
"""
|
||||
Yields:
|
||||
(req_id, new_block_id_groups, preempted)
|
||||
"""
|
||||
# new requests
|
||||
for req_data in scheduler_output.scheduled_new_reqs:
|
||||
yield req_data.req_id, req_data.block_ids, False
|
||||
|
||||
# cached requests
|
||||
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||
yield from zip(cached_reqs.req_ids, cached_reqs.new_block_ids,
|
||||
cached_reqs.resumed_from_preemption)
|
@ -27,11 +27,11 @@ from vllm.config import (BlockSize, CacheConfig, CacheDType, CompilationConfig,
|
||||
EPLBConfig, HfOverrides, KVEventsConfig,
|
||||
KVTransferConfig, LoadConfig, LogprobsMode,
|
||||
LoRAConfig, MambaDType, MMEncoderTPMode, ModelConfig,
|
||||
ModelDType, ModelImpl, ObservabilityConfig,
|
||||
ParallelConfig, PoolerConfig, PrefixCachingHashAlgo,
|
||||
RunnerOption, SchedulerConfig, SchedulerPolicy,
|
||||
SpeculativeConfig, StructuredOutputsConfig,
|
||||
TaskOption, TokenizerMode, VllmConfig, get_attr_docs)
|
||||
ModelDType, ObservabilityConfig, ParallelConfig,
|
||||
PoolerConfig, PrefixCachingHashAlgo, RunnerOption,
|
||||
SchedulerConfig, SchedulerPolicy, SpeculativeConfig,
|
||||
StructuredOutputsConfig, TaskOption, TokenizerMode,
|
||||
VllmConfig, get_attr_docs)
|
||||
from vllm.config.multimodal import MMCacheType, MultiModalConfig
|
||||
from vllm.config.parallel import ExpertPlacementStrategy
|
||||
from vllm.config.utils import get_field
|
||||
@ -333,6 +333,8 @@ class EngineArgs:
|
||||
enable_eplb: bool = ParallelConfig.enable_eplb
|
||||
expert_placement_strategy: ExpertPlacementStrategy = \
|
||||
ParallelConfig.expert_placement_strategy
|
||||
_api_process_count: int = ParallelConfig._api_process_count
|
||||
_api_process_rank: int = ParallelConfig._api_process_rank
|
||||
num_redundant_experts: int = EPLBConfig.num_redundant_experts
|
||||
eplb_window_size: int = EPLBConfig.window_size
|
||||
eplb_step_interval: int = EPLBConfig.step_interval
|
||||
@ -441,6 +443,7 @@ class EngineArgs:
|
||||
scheduling_policy: SchedulerPolicy = SchedulerConfig.policy
|
||||
scheduler_cls: Union[str, Type[object]] = SchedulerConfig.scheduler_cls
|
||||
|
||||
pooler_config: Optional[PoolerConfig] = ModelConfig.pooler_config
|
||||
override_pooler_config: Optional[Union[dict, PoolerConfig]] = \
|
||||
ModelConfig.override_pooler_config
|
||||
compilation_config: CompilationConfig = \
|
||||
@ -547,7 +550,6 @@ class EngineArgs:
|
||||
model_group.add_argument("--max-logprobs",
|
||||
**model_kwargs["max_logprobs"])
|
||||
model_group.add_argument("--logprobs-mode",
|
||||
choices=[f.value for f in LogprobsMode],
|
||||
**model_kwargs["logprobs_mode"])
|
||||
model_group.add_argument("--disable-sliding-window",
|
||||
**model_kwargs["disable_sliding_window"])
|
||||
@ -579,8 +581,11 @@ class EngineArgs:
|
||||
help=model_kwargs["hf_token"]["help"])
|
||||
model_group.add_argument("--hf-overrides",
|
||||
**model_kwargs["hf_overrides"])
|
||||
model_group.add_argument("--pooler-config",
|
||||
**model_kwargs["pooler_config"])
|
||||
model_group.add_argument("--override-pooler-config",
|
||||
**model_kwargs["override_pooler_config"])
|
||||
**model_kwargs["override_pooler_config"],
|
||||
deprecated=True)
|
||||
model_group.add_argument("--logits-processor-pattern",
|
||||
**model_kwargs["logits_processor_pattern"])
|
||||
model_group.add_argument("--generation-config",
|
||||
@ -589,9 +594,7 @@ class EngineArgs:
|
||||
**model_kwargs["override_generation_config"])
|
||||
model_group.add_argument("--enable-sleep-mode",
|
||||
**model_kwargs["enable_sleep_mode"])
|
||||
model_group.add_argument("--model-impl",
|
||||
choices=[f.value for f in ModelImpl],
|
||||
**model_kwargs["model_impl"])
|
||||
model_group.add_argument("--model-impl", **model_kwargs["model_impl"])
|
||||
model_group.add_argument("--override-attention-dtype",
|
||||
**model_kwargs["override_attention_dtype"])
|
||||
model_group.add_argument("--logits-processors",
|
||||
@ -951,7 +954,10 @@ class EngineArgs:
|
||||
# Get the list of attributes of this dataclass.
|
||||
attrs = [attr.name for attr in dataclasses.fields(cls)]
|
||||
# Set the attributes from the parsed arguments.
|
||||
engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
|
||||
engine_args = cls(**{
|
||||
attr: getattr(args, attr)
|
||||
for attr in attrs if hasattr(args, attr)
|
||||
})
|
||||
return engine_args
|
||||
|
||||
def create_model_config(self) -> ModelConfig:
|
||||
@ -1031,6 +1037,7 @@ class EngineArgs:
|
||||
mm_shm_cache_max_object_size_mb=self.
|
||||
mm_shm_cache_max_object_size_mb,
|
||||
mm_encoder_tp_mode=self.mm_encoder_tp_mode,
|
||||
pooler_config=self.pooler_config,
|
||||
override_pooler_config=self.override_pooler_config,
|
||||
logits_processor_pattern=self.logits_processor_pattern,
|
||||
generation_config=self.generation_config,
|
||||
@ -1364,6 +1371,8 @@ class EngineArgs:
|
||||
worker_cls=self.worker_cls,
|
||||
worker_extension_cls=self.worker_extension_cls,
|
||||
decode_context_parallel_size=self.decode_context_parallel_size,
|
||||
_api_process_count=self._api_process_count,
|
||||
_api_process_rank=self._api_process_rank,
|
||||
)
|
||||
|
||||
speculative_config = self.create_speculative_config(
|
||||
@ -1513,12 +1522,6 @@ class EngineArgs:
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# No text embedding inputs so far.
|
||||
if self.enable_prompt_embeds:
|
||||
_raise_or_fallback(feature_name="--enable-prompt-embeds",
|
||||
recommend_to_remove=False)
|
||||
return False
|
||||
|
||||
# No Mamba or Encoder-Decoder so far.
|
||||
if not model_config.is_v1_compatible:
|
||||
_raise_or_fallback(feature_name=model_config.architectures,
|
||||
@ -1651,6 +1654,13 @@ class EngineArgs:
|
||||
"models in V0 and has been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
if self.enable_prompt_embeds:
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V0. Prefix caching has "
|
||||
"been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
# Set max_num_seqs to 256 for VLLM_V0.
|
||||
if self.max_num_seqs is None:
|
||||
self.max_num_seqs = 256
|
||||
@ -1664,6 +1674,17 @@ class EngineArgs:
|
||||
# For pooling tasks the default is False
|
||||
if model_config.runner_type != "pooling":
|
||||
self.enable_chunked_prefill = True
|
||||
|
||||
# TODO: When prefix caching supports prompt embeds inputs, this
|
||||
# check can be removed.
|
||||
if (self.enable_prompt_embeds
|
||||
and self.enable_prefix_caching is not False):
|
||||
logger.warning(
|
||||
"--enable-prompt-embeds and --enable-prefix-caching "
|
||||
"are not supported together in V1. Prefix caching has "
|
||||
"been disabled.")
|
||||
self.enable_prefix_caching = False
|
||||
|
||||
if self.enable_prefix_caching is None:
|
||||
self.enable_prefix_caching = True
|
||||
else:
|
||||
|
@ -433,9 +433,9 @@ class LLMEngine:
                f"ExecutorBase. Got {distributed_executor_backend}.")
            executor_class = distributed_executor_backend
        elif distributed_executor_backend == "ray":
            raise RuntimeError(
                "The Ray distributed executor is only available in the v1 "
                "engine. Enable it by setting 'VLLM_USE_V1=1'.")
            from vllm.executor.ray_distributed_executor import (
                RayDistributedExecutor)
            executor_class = RayDistributedExecutor
        elif distributed_executor_backend == "mp":
            from vllm.executor.mp_distributed_executor import (
                MultiprocessingDistributedExecutor)
@ -135,23 +135,20 @@ def run_headless(args: argparse.Namespace):
def run_multi_api_server(args: argparse.Namespace):

assert not args.headless
num_api_servers = args.api_server_count
num_api_servers: int = args.api_server_count
assert num_api_servers > 0

orig_mm_processor_cache_gb = args.mm_processor_cache_gb

if num_api_servers > 1:
setup_multiprocess_prometheus()

# Not compatible with API server scale-out
args.mm_processor_cache_gb = 0

listen_address, sock = setup_server(args)

engine_args = vllm.AsyncEngineArgs.from_cli_args(args)
engine_args._api_process_count = num_api_servers
engine_args._api_process_rank = -1

usage_context = UsageContext.OPENAI_API_SERVER
vllm_config = engine_args.create_engine_config(usage_context=usage_context)
model_config = vllm_config.model_config

if num_api_servers > 1:
if not envs.VLLM_USE_V1:
@ -161,10 +158,6 @@ def run_multi_api_server(args: argparse.Namespace):
raise ValueError("VLLM_ALLOW_RUNTIME_LORA_UPDATING cannot be used "
"with api_server_count > 1")

if model_config.is_multimodal_model and orig_mm_processor_cache_gb > 0:
logger.warning("Multi-modal processor cache is disabled because "
"it is not compatible with `api_server_count > 1`.")

executor_class = Executor.get_class(vllm_config)
log_stats = not engine_args.disable_log_stats

@ -221,9 +214,10 @@ def run_api_server_worker_proc(listen_address,
client_config=None,
**uvicorn_kwargs) -> None:
"""Entrypoint for individual API server worker processes."""
client_config = client_config or {}
server_index = client_config.get("client_index", 0)

# Set process title and add process-specific prefix to stdout and stderr.
server_index = client_config.get("client_index", 0) if client_config else 0
set_process_title("APIServer", str(server_index))
decorate_logs()

@ -151,9 +151,11 @@ class LLM:
multi-modal processor obtained from `AutoProcessor.from_pretrained`.
The available overrides depend on the model that is being run.
For example, for Phi-3-Vision: `{"num_crops": 4}`.
override_pooler_config: Initialize non-default pooling config or
override default pooling config for the pooling model.
e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
pooler_config: Initialize non-default pooling config for the pooling
model. e.g. `PoolerConfig(pooling_type="mean", normalize=False)`.
override_pooler_config: [DEPRECATED] Use `pooler_config` instead. This
argument is deprecated and will be removed in v0.12.0 or v1.0.0,
whichever is sooner.
compilation_config: Either an integer or a dictionary. If it is an
integer, it is used as the level of compilation optimization. If it
is a dictionary, it can specify the full compilation configuration.
@ -191,6 +193,7 @@ class LLM:
hf_token: Optional[Union[bool, str]] = None,
hf_overrides: Optional[HfOverrides] = None,
mm_processor_kwargs: Optional[dict[str, Any]] = None,
pooler_config: Optional[PoolerConfig] = None,
override_pooler_config: Optional[PoolerConfig] = None,
structured_outputs_config: Optional[Union[dict[
str, Any], StructuredOutputsConfig]] = None,
@ -288,6 +291,7 @@ class LLM:
hf_token=hf_token,
hf_overrides=hf_overrides,
mm_processor_kwargs=mm_processor_kwargs,
pooler_config=pooler_config,
override_pooler_config=override_pooler_config,
structured_outputs_config=structured_outputs_instance,
compilation_config=compilation_config_instance,
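The docstring and signature changes above introduce `pooler_config` and deprecate `override_pooler_config` (removal targeted for v0.12.0 or v1.0.0, whichever is sooner). A short sketch of the new spelling, assuming `PoolerConfig` is importable from `vllm.config` and using a placeholder pooling model:

```python
from vllm import LLM
from vllm.config import PoolerConfig  # assumed import path

llm = LLM(
    model="intfloat/e5-small-v2",  # placeholder embedding/pooling model
    # New-style argument; override_pooler_config=... still works but is
    # deprecated per the docstring above.
    pooler_config=PoolerConfig(pooling_type="mean", normalize=False),
)
```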
@ -17,13 +17,14 @@ from argparse import Namespace
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Annotated, Any, Callable, Optional
from typing import Annotated, Any, Callable, Literal, Optional

import prometheus_client
import pydantic
import regex as re
import uvloop
from fastapi import APIRouter, Depends, FastAPI, Form, HTTPException, Request
from fastapi import (APIRouter, Depends, FastAPI, Form, HTTPException, Query,
Request)
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response, StreamingResponse
@ -166,6 +167,9 @@ async def build_async_engine_client(
# Context manager to handle engine_client lifecycle
# Ensures everything is shutdown and cleaned up on error/exit
engine_args = AsyncEngineArgs.from_cli_args(args)
if client_config:
engine_args._api_process_count = client_config.get("client_count", 1)
engine_args._api_process_rank = client_config.get("client_index", 0)

if disable_frontend_multiprocessing is None:
disable_frontend_multiprocessing = bool(
@ -209,8 +213,12 @@ async def build_async_engine_client_from_engine_args(

from vllm.v1.engine.async_llm import AsyncLLM
async_llm: Optional[AsyncLLM] = None
client_count = client_config.pop("client_count") if client_config else 1
client_index = client_config.pop("client_index") if client_config else 0

# Don't mutate the input client_config
client_config = dict(client_config) if client_config else {}
client_count = client_config.pop("client_count", 1)
client_index = client_config.pop("client_index", 0)

try:
async_llm = AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
@ -956,9 +964,22 @@ if envs.VLLM_SERVER_DEV_MODE:
logger.warning("SECURITY WARNING: Development endpoints are enabled! "
"This should NOT be used in production!")

PydanticVllmConfig = pydantic.TypeAdapter(VllmConfig)

@router.get("/server_info")
async def show_server_info(raw_request: Request):
server_info = {"vllm_config": str(raw_request.app.state.vllm_config)}
async def show_server_info(
raw_request: Request,
config_format: Annotated[Literal["text", "json"],
Query()] = "text",
):
vllm_config: VllmConfig = raw_request.app.state.vllm_config
server_info = {
"vllm_config":
str(vllm_config)
if config_format == "text" else PydanticVllmConfig.dump_python(
vllm_config, mode="json", fallback=str)
# fallback=str is needed to handle e.g. torch.dtype
}
return JSONResponse(content=server_info)
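With the hunk above, `/server_info` gains a `config_format` query parameter that switches between the plain-text `str(vllm_config)` dump and a JSON dump produced through a pydantic `TypeAdapter`. A client-side sketch, assuming a dev-mode server (`VLLM_SERVER_DEV_MODE`) listening on localhost:8000 and the `requests` package:

```python
import requests

BASE = "http://localhost:8000"  # placeholder address

# Default behaviour: the config is returned as one big string.
text_info = requests.get(f"{BASE}/server_info").json()

# New behaviour from this hunk: a structured JSON dump of VllmConfig.
json_info = requests.get(f"{BASE}/server_info",
                         params={"config_format": "json"}).json()
print(type(text_info["vllm_config"]), type(json_info["vllm_config"]))
```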

@router.post("/reset_prefix_cache")
@ -1856,8 +1877,6 @@ async def run_server_worker(listen_address,
if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3:
ToolParserManager.import_tool_parser(args.tool_parser_plugin)

server_index = client_config.get("client_index", 0) if client_config else 0

# Load logging config for uvicorn if specified
log_config = load_log_config(args.log_config_file)
if log_config is not None:
@ -1873,7 +1892,8 @@ async def run_server_worker(listen_address,
vllm_config = await engine_client.get_vllm_config()
await init_app_state(engine_client, vllm_config, app.state, args)

logger.info("Starting vLLM API server %d on %s", server_index,
logger.info("Starting vLLM API server %d on %s",
vllm_config.parallel_config._api_process_rank,
listen_address)
shutdown_task = await serve_http(
app,
@ -31,6 +31,8 @@ from openai.types.responses import (
|
||||
ResponseReasoningTextDeltaEvent, ResponseReasoningTextDoneEvent,
|
||||
ResponseStatus, ResponseWebSearchCallCompletedEvent,
|
||||
ResponseWebSearchCallInProgressEvent, ResponseWebSearchCallSearchingEvent)
|
||||
from openai.types.responses.response_reasoning_item import (
|
||||
Content as ResponseReasoningTextContent)
|
||||
|
||||
# Backward compatibility for OpenAI client versions
|
||||
try: # For older openai versions (< 1.100.0)
|
||||
@ -260,26 +262,6 @@ ResponseInputOutputItem: TypeAlias = Union[ResponseInputItemParam,
|
||||
ResponseReasoningItem,
|
||||
ResponseFunctionToolCall]
|
||||
|
||||
StreamingResponsesResponse: TypeAlias = Union[
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseCompletedEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseReasoningTextDeltaEvent,
|
||||
ResponseReasoningTextDoneEvent,
|
||||
ResponseCodeInterpreterCallInProgressEvent,
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
ResponseWebSearchCallCompletedEvent,
|
||||
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||
ResponseCodeInterpreterCallInterpretingEvent,
|
||||
ResponseCodeInterpreterCallCompletedEvent,
|
||||
]
|
||||
|
||||
|
||||
class ResponsesRequest(OpenAIBaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
@ -346,6 +328,13 @@ class ResponsesRequest(OpenAIBaseModel):
|
||||
"access by 3rd parties, and long enough to be "
|
||||
"unpredictable (e.g., 43 characters base64-encoded, corresponding "
|
||||
"to 256 bit). Not supported by vLLM engine V0."))
|
||||
|
||||
enable_response_messages: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Dictates whether or not to return messages as part of the "
|
||||
"response object. Currently only supported for non-streaming "
|
||||
"non-background and gpt-oss only. "))
|
||||
# --8<-- [end:responses-extra-params]
|
||||
|
||||
_DEFAULT_SAMPLING_PARAMS = {
|
||||
@ -973,7 +962,6 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
model: Optional[str] = None
|
||||
prompt: Optional[Union[list[int], list[list[int]], str, list[str]]] = None
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
|
||||
best_of: Optional[int] = None
|
||||
echo: Optional[bool] = False
|
||||
frequency_penalty: Optional[float] = 0.0
|
||||
@ -1009,6 +997,7 @@ class CompletionRequest(OpenAIBaseModel):
|
||||
# --8<-- [end:completion-sampling-params]
|
||||
|
||||
# --8<-- [start:completion-extra-params]
|
||||
prompt_embeds: Optional[Union[bytes, list[bytes]]] = None
|
||||
add_special_tokens: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
@ -1849,6 +1838,11 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
model: str
|
||||
object: Literal["response"] = "response"
|
||||
output: list[ResponseOutputItem]
|
||||
# These are populated when enable_response_messages is set to True
|
||||
# TODO: Currently an issue where content of harmony messages
|
||||
# is not available when these are serialized. Metadata is available
|
||||
input_messages: Optional[list[ChatCompletionMessageParam]] = None
|
||||
output_messages: Optional[list[ChatCompletionMessageParam]] = None
|
||||
parallel_tool_calls: bool
|
||||
temperature: float
|
||||
tool_choice: ToolChoice
|
||||
@ -1878,6 +1872,8 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
output: list[ResponseOutputItem],
|
||||
status: ResponseStatus,
|
||||
usage: Optional[ResponseUsage] = None,
|
||||
input_messages: Optional[list[ChatCompletionMessageParam]] = None,
|
||||
output_messages: Optional[list[ChatCompletionMessageParam]] = None,
|
||||
) -> "ResponsesResponse":
|
||||
|
||||
incomplete_details: Optional[IncompleteDetails] = None
|
||||
@ -1886,7 +1882,6 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
# TODO: implement the other reason for incomplete_details,
|
||||
# which is content_filter
|
||||
# incomplete_details = IncompleteDetails(reason='content_filter')
|
||||
|
||||
return cls(
|
||||
id=request.request_id,
|
||||
created_at=created_time,
|
||||
@ -1895,6 +1890,8 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
metadata=request.metadata,
|
||||
model=model_name,
|
||||
output=output,
|
||||
input_messages=input_messages,
|
||||
output_messages=output_messages,
|
||||
parallel_tool_calls=request.parallel_tool_calls,
|
||||
temperature=sampling_params.temperature,
|
||||
tool_choice=request.tool_choice,
|
||||
@ -1916,6 +1913,72 @@ class ResponsesResponse(OpenAIBaseModel):
|
||||
)
|
||||
|
||||
|
||||
# TODO: this code can be removed once
|
||||
# https://github.com/openai/openai-python/issues/2634 has been resolved
|
||||
class ResponseReasoningPartDoneEvent(OpenAIBaseModel):
|
||||
content_index: int
|
||||
"""The index of the content part that is done."""
|
||||
|
||||
item_id: str
|
||||
"""The ID of the output item that the content part was added to."""
|
||||
|
||||
output_index: int
|
||||
"""The index of the output item that the content part was added to."""
|
||||
|
||||
part: ResponseReasoningTextContent
|
||||
"""The content part that is done."""
|
||||
|
||||
sequence_number: int
|
||||
"""The sequence number of this event."""
|
||||
|
||||
type: Literal["response.reasoning_part.done"]
|
||||
"""The type of the event. Always `response.reasoning_part.done`."""
|
||||
|
||||
|
||||
# TODO: this code can be removed once
|
||||
# https://github.com/openai/openai-python/issues/2634 has been resolved
|
||||
class ResponseReasoningPartAddedEvent(OpenAIBaseModel):
|
||||
content_index: int
|
||||
"""The index of the content part that is done."""
|
||||
|
||||
item_id: str
|
||||
"""The ID of the output item that the content part was added to."""
|
||||
|
||||
output_index: int
|
||||
"""The index of the output item that the content part was added to."""
|
||||
|
||||
part: ResponseReasoningTextContent
|
||||
"""The content part that is done."""
|
||||
|
||||
sequence_number: int
|
||||
"""The sequence number of this event."""
|
||||
|
||||
type: Literal["response.reasoning_part.added"]
|
||||
"""The type of the event. Always `response.reasoning_part.added`."""
|
||||
|
||||
|
||||
StreamingResponsesResponse: TypeAlias = Union[
|
||||
ResponseCreatedEvent,
|
||||
ResponseInProgressEvent,
|
||||
ResponseCompletedEvent,
|
||||
ResponseOutputItemAddedEvent,
|
||||
ResponseOutputItemDoneEvent,
|
||||
ResponseContentPartAddedEvent,
|
||||
ResponseContentPartDoneEvent,
|
||||
ResponseReasoningTextDeltaEvent,
|
||||
ResponseReasoningTextDoneEvent,
|
||||
ResponseReasoningPartAddedEvent,
|
||||
ResponseReasoningPartDoneEvent,
|
||||
ResponseCodeInterpreterCallInProgressEvent,
|
||||
ResponseCodeInterpreterCallCodeDeltaEvent,
|
||||
ResponseWebSearchCallInProgressEvent,
|
||||
ResponseWebSearchCallSearchingEvent,
|
||||
ResponseWebSearchCallCompletedEvent,
|
||||
ResponseCodeInterpreterCallCodeDoneEvent,
|
||||
ResponseCodeInterpreterCallInterpretingEvent,
|
||||
ResponseCodeInterpreterCallCompletedEvent,
|
||||
]
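The union above now also carries the two vendored `ResponseReasoningPart*` events defined earlier in this hunk. A small sketch of how a streaming consumer might dispatch on them, keying off the literal `type` strings from these models (the event objects come from whatever stream the caller is iterating):

```python
def handle_stream_event(event) -> None:
    # Dispatch on the fixed literal `type` carried by each event model.
    if event.type == "response.reasoning_part.added":
        print(f"reasoning part opened on item {event.item_id}")
    elif event.type == "response.reasoning_part.done":
        # `part` is a ResponseReasoningTextContent with a `.text` field.
        print(f"reasoning text: {event.part.text}")
    elif event.type == "response.completed":
        print("response finished")
```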
|
||||
|
||||
BatchRequestInputBody = Union[ChatCompletionRequest, EmbeddingRequest,
|
||||
ScoreRequest, RerankRequest]
|
||||
|
||||
|
@ -58,6 +58,8 @@ from vllm.entrypoints.openai.protocol import (DeltaMessage, ErrorResponse,
|
||||
InputTokensDetails,
|
||||
OutputTokensDetails,
|
||||
RequestResponseMetadata,
|
||||
ResponseReasoningPartAddedEvent,
|
||||
ResponseReasoningPartDoneEvent,
|
||||
ResponsesRequest,
|
||||
ResponsesResponse, ResponseUsage,
|
||||
StreamingResponsesResponse)
|
||||
@ -473,9 +475,14 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# "completed" is implemented as the "catch-all" for now.
|
||||
status: ResponseStatus = "completed"
|
||||
|
||||
input_messages = None
|
||||
output_messages = None
|
||||
if self.use_harmony:
|
||||
assert isinstance(context, HarmonyContext)
|
||||
output = self._make_response_output_items_with_harmony(context)
|
||||
if request.enable_response_messages:
|
||||
input_messages = context.messages[:context.num_init_messages]
|
||||
output_messages = context.messages[context.num_init_messages:]
|
||||
num_tool_output_tokens = context.num_tool_output_tokens
|
||||
if len(output) > 0:
|
||||
if context.finish_reason == "length":
|
||||
@ -494,6 +501,12 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
output = self._make_response_output_items(request, final_output,
|
||||
tokenizer)
|
||||
|
||||
# TODO: context for non-gptoss models doesn't use messages
|
||||
# so we can't get them out yet
|
||||
if request.enable_response_messages:
|
||||
raise NotImplementedError(
|
||||
"enable_response_messages is currently"
|
||||
" only supported for gpt-oss")
|
||||
# Calculate usage.
|
||||
assert final_res.prompt_token_ids is not None
|
||||
num_tool_output_tokens = 0
|
||||
@ -517,6 +530,8 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
response = ResponsesResponse.from_request(
|
||||
request,
|
||||
sampling_params,
|
||||
input_messages=input_messages,
|
||||
output_messages=output_messages,
|
||||
model_name=model_name,
|
||||
created_time=created_time,
|
||||
output=output,
|
||||
@ -1280,14 +1295,13 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
# Deal with tool call here
|
||||
pass
|
||||
elif previous_item.channel == "analysis":
|
||||
content = ResponseReasoningTextContent(
|
||||
text=previous_item.content[0].text,
|
||||
type="reasoning_text",
|
||||
)
|
||||
reasoning_item = ResponseReasoningItem(
|
||||
type="reasoning",
|
||||
content=[
|
||||
ResponseReasoningTextContent(
|
||||
text=previous_item.content[0].text,
|
||||
type="reasoning_text",
|
||||
),
|
||||
],
|
||||
content=[content],
|
||||
status="completed",
|
||||
id=current_item_id,
|
||||
summary=[],
|
||||
@ -1301,6 +1315,15 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
content_index=current_content_index,
|
||||
text=previous_item.content[0].text,
|
||||
))
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseReasoningPartDoneEvent(
|
||||
type="response.reasoning_part.done",
|
||||
sequence_number=-1,
|
||||
item_id=current_item_id,
|
||||
output_index=current_output_index,
|
||||
content_index=current_content_index,
|
||||
part=content,
|
||||
))
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseOutputItemDoneEvent(
|
||||
type="response.output_item.done",
|
||||
@ -1412,17 +1435,15 @@ class OpenAIServingResponses(OpenAIServing):
|
||||
))
|
||||
current_content_index += 1
|
||||
yield _increment_sequence_number_and_return(
|
||||
ResponseContentPartAddedEvent(
|
||||
type="response.content_part.added",
|
||||
ResponseReasoningPartAddedEvent(
|
||||
type="response.reasoning_part.added",
|
||||
sequence_number=-1,
|
||||
output_index=current_output_index,
|
||||
item_id=current_item_id,
|
||||
content_index=current_content_index,
|
||||
part=ResponseOutputText(
|
||||
type="output_text",
|
||||
part=ResponseReasoningTextContent(
|
||||
text="",
|
||||
annotations=[],
|
||||
logprobs=[],
|
||||
type="reasoning_text",
|
||||
),
|
||||
))
|
||||
yield _increment_sequence_number_and_return(
|
||||
|
@ -32,6 +32,7 @@ if TYPE_CHECKING:
VLLM_CONFIG_ROOT: str = os.path.expanduser("~/.config/vllm")
VLLM_USAGE_STATS_SERVER: str = "https://stats.vllm.ai"
VLLM_NO_USAGE_STATS: bool = False
VLLM_DISABLE_FLASHINFER_PREFILL: bool = False
VLLM_DO_NOT_TRACK: bool = False
VLLM_USAGE_SOURCE: str = ""
VLLM_CONFIGURE_LOGGING: int = 1
@ -479,6 +480,8 @@ environment_variables: dict[str, Callable[[], Any]] = {
lambda: os.environ.get("VLLM_USAGE_STATS_SERVER", "https://stats.vllm.ai"),
"VLLM_NO_USAGE_STATS":
lambda: os.environ.get("VLLM_NO_USAGE_STATS", "0") == "1",
"VLLM_DISABLE_FLASHINFER_PREFILL":
lambda: os.environ.get("VLLM_DISABLE_FLASHINFER_PREFILL", "0") == "1",
"VLLM_DO_NOT_TRACK":
lambda: (os.environ.get("VLLM_DO_NOT_TRACK", None) or os.environ.get(
"DO_NOT_TRACK", None) or "0") == "1",
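The hunk above registers a new `VLLM_DISABLE_FLASHINFER_PREFILL` flag, parsed as a boolean where only the string "1" enables it. A minimal sketch of toggling it and reading it back through `vllm.envs` (set the variable before vLLM starts doing any work):

```python
import os

# Any value other than "1" leaves the default (False) in place.
os.environ["VLLM_DISABLE_FLASHINFER_PREFILL"] = "1"

import vllm.envs as envs

print(envs.VLLM_DISABLE_FLASHINFER_PREFILL)  # expected: True
```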
699
vllm/executor/ray_distributed_executor.py
Normal file
@ -0,0 +1,699 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import cloudpickle
|
||||
import msgspec
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.executor.executor_base import (
|
||||
DistributedExecutorBase) # yapf: disable
|
||||
from vllm.executor.msgspec_utils import encode_hook
|
||||
from vllm.executor.ray_utils import (RayWorkerWrapper, initialize_ray_cluster,
|
||||
ray)
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.layers.sampler import SamplerOutput
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.ray.ray_env import get_env_vars_to_copy
|
||||
from vllm.sequence import ExecuteModelRequest
|
||||
from vllm.utils import (_run_task_with_lock, get_distributed_init_method,
|
||||
get_ip, get_open_port, make_async)
|
||||
|
||||
if ray is not None:
|
||||
from ray.actor import ActorHandle
|
||||
from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy
|
||||
else:
|
||||
ActorHandle = None
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ray.util.placement_group import PlacementGroup
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RayWorkerMetaData:
|
||||
"""
|
||||
Metadata for a Ray worker.
|
||||
The order of ray worker creation can be random,
|
||||
and we need to reset the rank after creating all workers.
|
||||
"""
|
||||
worker: ActorHandle
|
||||
created_rank: int
|
||||
adjusted_rank: int = -1
|
||||
ip: str = ""
|
||||
|
||||
|
||||
class RayDistributedExecutor(DistributedExecutorBase):
|
||||
"""Ray-based distributed executor"""
|
||||
|
||||
# These env vars are worker-specific, therefore are NOT copied
|
||||
# from the driver to the workers
|
||||
WORKER_SPECIFIC_ENV_VARS = {
|
||||
"VLLM_HOST_IP", "VLLM_HOST_PORT", "LOCAL_RANK", "CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
# These non-vLLM env vars are copied from the driver to workers
|
||||
ADDITIONAL_ENV_VARS = {"HF_TOKEN", "HUGGING_FACE_HUB_TOKEN"}
|
||||
|
||||
uses_ray: bool = True
|
||||
|
||||
def _init_executor(self) -> None:
|
||||
self.forward_dag: Optional[ray.dag.CompiledDAG] = None
|
||||
if envs.VLLM_USE_V1:
|
||||
# V1 uses SPMD worker and compiled DAG
|
||||
os.environ["VLLM_USE_RAY_SPMD_WORKER"] = "1"
|
||||
os.environ["VLLM_USE_RAY_COMPILED_DAG"] = "1"
|
||||
|
||||
# For TPU or XPU, avoid compiling NVIDIA's NCCL
|
||||
if current_platform.is_tpu() or current_platform.is_xpu():
|
||||
os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
|
||||
|
||||
# If the env var is set, it uses the Ray's compiled DAG API
|
||||
# which optimizes the control plane overhead.
|
||||
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
|
||||
# Currently, this requires USE_RAY_SPMD_WORKER=True.
|
||||
self.use_ray_compiled_dag = envs.VLLM_USE_RAY_COMPILED_DAG
|
||||
# If the env var is set, then we do not distinguish between the
|
||||
# "driver worker" vs other workers. Also, the rank 0 worker will
|
||||
# be executed in a remote Ray worker. Currently this requires
|
||||
# USE_RAY_COMPILED_DAG=True.
|
||||
self.use_ray_spmd_worker = envs.VLLM_USE_RAY_SPMD_WORKER
|
||||
if self.use_ray_compiled_dag:
|
||||
assert self.use_ray_spmd_worker, (
|
||||
"VLLM_USE_RAY_COMPILED_DAG=1 requires "
|
||||
"VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
if self.use_ray_spmd_worker:
|
||||
# TODO: Support SPMD worker for non-DAG Ray executor.
|
||||
assert self.use_ray_compiled_dag, (
|
||||
"VLLM_USE_RAY_SPMD_WORKER=1 requires "
|
||||
"VLLM_USE_RAY_COMPILED_DAG=1")
|
||||
|
||||
assert self.uses_ray
|
||||
initialize_ray_cluster(self.parallel_config)
|
||||
placement_group = self.parallel_config.placement_group
|
||||
|
||||
# Disable Ray usage stats collection.
|
||||
ray_usage = os.environ.get("RAY_USAGE_STATS_ENABLED", "0")
|
||||
if ray_usage != "1":
|
||||
os.environ["RAY_USAGE_STATS_ENABLED"] = "0"
|
||||
|
||||
# Create the parallel GPU workers.
|
||||
self._init_workers_ray(placement_group)
|
||||
|
||||
self.input_encoder = msgspec.msgpack.Encoder(enc_hook=encode_hook)
|
||||
self.output_decoder = msgspec.msgpack.Decoder(
|
||||
Optional[List[SamplerOutput]])
|
||||
self.use_v1 = envs.VLLM_USE_V1
|
||||
|
||||
self.pp_locks: Optional[List[asyncio.Lock]] = None
|
||||
if not self.use_ray_compiled_dag:
|
||||
self.driver_exec_method = make_async(
|
||||
self.driver_worker.execute_method)
|
||||
|
||||
def shutdown(self) -> None:
|
||||
if logger:
|
||||
# Somehow logger can be None here.
|
||||
logger.info(
|
||||
"Shutting down Ray distributed executor. If you see error log "
|
||||
"from logging.cc regarding SIGTERM received, please ignore "
|
||||
"because this is the expected termination process in Ray.")
|
||||
if hasattr(self, "forward_dag") and self.forward_dag is not None:
|
||||
self.forward_dag.teardown()
|
||||
import ray
|
||||
for worker in self.workers:
|
||||
ray.kill(worker)
|
||||
self.forward_dag = None
|
||||
|
||||
def _configure_ray_workers_use_nsight(self,
|
||||
ray_remote_kwargs) -> Dict[str, Any]:
|
||||
# If nsight profiling is enabled, we need to set the profiling
|
||||
# configuration for the ray workers as runtime env.
|
||||
runtime_env = ray_remote_kwargs.setdefault("runtime_env", {})
|
||||
runtime_env.update({
|
||||
"nsight": {
|
||||
"t": "cuda,cudnn,cublas",
|
||||
"o": "'worker_process_%p'",
|
||||
"cuda-graph-trace": "node",
|
||||
}
|
||||
})
|
||||
|
||||
return ray_remote_kwargs
|
||||
|
||||
# child class could overwrite this to return actual env vars.
|
||||
def _get_env_vars_to_be_updated(self):
|
||||
return self._env_vars_for_all_workers
|
||||
|
||||
def _init_workers_ray(self, placement_group: "PlacementGroup",
|
||||
**ray_remote_kwargs):
|
||||
num_gpus = envs.VLLM_RAY_PER_WORKER_GPUS
|
||||
|
||||
# The driver dummy worker does not actually use any resources.
|
||||
# It holds the resource for the driver worker.
|
||||
self.driver_dummy_worker: Optional[RayWorkerWrapper] = None
|
||||
# The remaining workers are the actual ray actors.
|
||||
self.workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Used in ray compiled DAG: indexed first by PP rank,
|
||||
# and then TP rank. In other words, the inner list is
|
||||
# the TP group of workers for a PP rank.
|
||||
self.pp_tp_workers: List[List[RayWorkerWrapper]] = []
|
||||
|
||||
if self.parallel_config.ray_workers_use_nsight:
|
||||
ray_remote_kwargs = self._configure_ray_workers_use_nsight(
|
||||
ray_remote_kwargs)
|
||||
|
||||
logger.info("use_ray_spmd_worker: %s", self.use_ray_spmd_worker)
|
||||
|
||||
# Create the workers.
|
||||
bundle_indices: List[int]
|
||||
if envs.VLLM_RAY_BUNDLE_INDICES:
|
||||
# Use the bundle indices specified by the user.
|
||||
bundle_indices = list(
|
||||
map(int, envs.VLLM_RAY_BUNDLE_INDICES.split(",")))
|
||||
assert len(bundle_indices) == self.parallel_config.world_size, \
|
||||
("VLLM_RAY_BUNDLE_INDICES must have the same size"
|
||||
f" as the world size, but got {bundle_indices=} "
|
||||
f"and {self.parallel_config.world_size=}")
|
||||
assert len(set(bundle_indices)) == len(bundle_indices), \
|
||||
("VLLM_RAY_BUNDLE_INDICES cannot have duplicate values,"
|
||||
f" but got {bundle_indices=}")
|
||||
else:
|
||||
# use the first N bundles that have GPU resources.
|
||||
bundle_indices = []
|
||||
for bundle_id, bundle in enumerate(placement_group.bundle_specs):
|
||||
if bundle.get(current_platform.ray_device_key, 0):
|
||||
bundle_indices.append(bundle_id)
|
||||
bundle_indices = bundle_indices[:self.parallel_config.world_size]
|
||||
|
||||
worker_metadata: List[RayWorkerMetaData] = []
|
||||
driver_ip = get_ip()
|
||||
for rank, bundle_id in enumerate(bundle_indices):
|
||||
scheduling_strategy = PlacementGroupSchedulingStrategy(
|
||||
placement_group=placement_group,
|
||||
placement_group_capture_child_tasks=True,
|
||||
placement_group_bundle_index=bundle_id,
|
||||
)
|
||||
|
||||
if current_platform.ray_device_key == "GPU":
|
||||
# NV+AMD GPUs, and Intel XPUs
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=num_gpus,
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
|
||||
rpc_rank=rank)
|
||||
else:
|
||||
worker = ray.remote(
|
||||
num_cpus=0,
|
||||
num_gpus=0,
|
||||
resources={current_platform.ray_device_key: num_gpus},
|
||||
scheduling_strategy=scheduling_strategy,
|
||||
**ray_remote_kwargs,
|
||||
)(RayWorkerWrapper).remote(vllm_config=self.vllm_config,
|
||||
rpc_rank=rank)
|
||||
worker_metadata.append(
|
||||
RayWorkerMetaData(worker=worker, created_rank=rank))
|
||||
|
||||
worker_ips = ray.get([
|
||||
each.worker.get_node_ip.remote() # type: ignore[attr-defined]
|
||||
for each in worker_metadata
|
||||
])
|
||||
|
||||
for each, ip in zip(worker_metadata, worker_ips):
|
||||
each.ip = ip
|
||||
|
||||
if not self.use_ray_spmd_worker:
|
||||
for i, each in enumerate(worker_metadata):
|
||||
# find and remove the dummy worker from the list
|
||||
worker = each.worker
|
||||
worker_ip = each.ip
|
||||
if self.driver_dummy_worker is None and worker_ip == driver_ip:
|
||||
# If the worker is on the same node as the driver, we use it
|
||||
# as the resource holder for the driver process.
|
||||
self.driver_dummy_worker = worker
|
||||
self.driver_worker = RayWorkerWrapper(
|
||||
vllm_config=self.vllm_config, rpc_rank=0)
|
||||
worker_metadata.pop(i)
|
||||
break
|
||||
|
||||
logger.debug("workers: %s", worker_metadata)
|
||||
logger.debug("driver_dummy_worker: %s", self.driver_dummy_worker)
|
||||
if not self.use_ray_spmd_worker and self.driver_dummy_worker is None:
|
||||
raise ValueError(
|
||||
"Ray does not allocate any GPUs on the driver node."
|
||||
f"Driver IP: {driver_ip}, worker IPs: {worker_ips}."
|
||||
"Consider adjusting the Ray placement group or running "
|
||||
"the driver on a GPU node.")
|
||||
|
||||
ip_counts: Dict[str, int] = {}
|
||||
for ip in worker_ips:
|
||||
ip_counts[ip] = ip_counts.get(ip, 0) + 1
|
||||
|
||||
def sort_by_driver_then_worker_ip(item: RayWorkerMetaData):
|
||||
"""
|
||||
Sort the workers based on 3 properties:
|
||||
1. If the worker is on the same node as the driver (vllm engine),
|
||||
it should be placed first.
|
||||
2. Then, if the worker is on a node with fewer workers, it should
|
||||
be placed first.
|
||||
3. Finally, if the work is on a node with smaller IP address, it
|
||||
should be placed first.
|
||||
"""
|
||||
ip = item.ip
|
||||
return (0 if ip == driver_ip else 1, ip_counts[ip], ip)
|
||||
|
||||
# After sorting, the workers on the same node will be
|
||||
# close to each other, and the workers on the driver
|
||||
# node will be placed first.
|
||||
sorted_worker_metadata = sorted(worker_metadata,
|
||||
key=sort_by_driver_then_worker_ip)
|
||||
start_rank = 0 if self.use_ray_spmd_worker else 1
|
||||
for i, item in enumerate(sorted_worker_metadata):
|
||||
item.adjusted_rank = i + start_rank
|
||||
self.workers = [item.worker for item in sorted_worker_metadata]
|
||||
rerank_mapping = {
|
||||
item.created_rank: item.adjusted_rank
|
||||
for item in sorted_worker_metadata
|
||||
}
|
||||
self._run_workers("adjust_rank", rerank_mapping)
|
||||
|
||||
# Get the set of GPU IDs used on each node.
|
||||
worker_node_and_gpu_ids = []
|
||||
for worker in [self.driver_dummy_worker] + self.workers:
|
||||
if worker is None:
|
||||
# driver_dummy_worker can be None when using ray spmd worker.
|
||||
continue
|
||||
worker_node_and_gpu_ids.append(
|
||||
ray.get(worker.get_node_and_gpu_ids.remote()) \
|
||||
) # type: ignore
|
||||
|
||||
node_workers = defaultdict(list) # node id -> list of worker ranks
|
||||
node_gpus = defaultdict(list) # node id -> list of gpu ids
|
||||
|
||||
for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids):
|
||||
node_workers[node_id].append(i)
|
||||
# `gpu_ids` can be a list of strings or integers.
|
||||
# convert them to integers for consistency.
|
||||
# NOTE: gpu_ids can be larger than 9 (e.g. 16 GPUs),
|
||||
# string sorting is not sufficient.
|
||||
# see https://github.com/vllm-project/vllm/issues/5590
|
||||
gpu_ids = [int(x) for x in gpu_ids]
|
||||
node_gpus[node_id].extend(gpu_ids)
|
||||
for node_id, gpu_ids in node_gpus.items():
|
||||
node_gpus[node_id] = sorted(gpu_ids)
|
||||
|
||||
all_ips = set(worker_ips + [driver_ip])
|
||||
n_ips = len(all_ips)
|
||||
n_nodes = len(node_workers)
|
||||
|
||||
if n_nodes != n_ips:
|
||||
raise RuntimeError(
|
||||
f"Every node should have a unique IP address. Got {n_nodes}"
|
||||
f" nodes with node ids {list(node_workers.keys())} and "
|
||||
f"{n_ips} unique IP addresses {all_ips}. Please check your"
|
||||
" network configuration. If you set `VLLM_HOST_IP`"
|
||||
" environment variable, make sure it is unique for"
|
||||
" each node.")
|
||||
|
||||
# Set environment variables for the driver and workers.
|
||||
all_args_to_update_environment_variables = [{
|
||||
current_platform.device_control_env_var:
|
||||
",".join(map(str, node_gpus[node_id])),
|
||||
} for (node_id, _) in worker_node_and_gpu_ids]
|
||||
|
||||
# Environment variables to copy from driver to workers
|
||||
env_vars_to_copy = get_env_vars_to_copy(
|
||||
exclude_vars=self.WORKER_SPECIFIC_ENV_VARS,
|
||||
additional_vars=set(current_platform.additional_env_vars).union(
|
||||
self.ADDITIONAL_ENV_VARS),
|
||||
destination="workers")
|
||||
|
||||
# Copy existing env vars to each worker's args
|
||||
for args in all_args_to_update_environment_variables:
|
||||
# TODO: refactor platform-specific env vars
|
||||
for name in env_vars_to_copy:
|
||||
if name in os.environ:
|
||||
args[name] = os.environ[name]
|
||||
|
||||
self._env_vars_for_all_workers = (
|
||||
all_args_to_update_environment_variables)
|
||||
|
||||
self._run_workers("update_environment_variables",
|
||||
self._get_env_vars_to_be_updated())
|
||||
|
||||
if len(node_gpus) == 1:
|
||||
# in single node case, we don't need to get the IP address.
|
||||
# the loopback address is sufficient
|
||||
# NOTE: a node may have several IP addresses, one for each
|
||||
# network interface. `get_ip()` might return any of them,
|
||||
# while they might not work for communication inside the node
|
||||
# if the network setup is complicated. Using the loopback address
|
||||
# solves this issue, as it always works for communication inside
|
||||
# the node.
|
||||
driver_ip = "127.0.0.1"
|
||||
distributed_init_method = get_distributed_init_method(
|
||||
driver_ip, get_open_port())
|
||||
|
||||
# Initialize the actual workers inside worker wrapper.
|
||||
all_kwargs = []
|
||||
for rank, (node_id, _) in enumerate(worker_node_and_gpu_ids):
|
||||
local_rank = node_workers[node_id].index(rank)
|
||||
kwargs = dict(
|
||||
vllm_config=self.vllm_config,
|
||||
local_rank=local_rank,
|
||||
rank=rank,
|
||||
distributed_init_method=distributed_init_method,
|
||||
is_driver_worker=(not self.parallel_config)
|
||||
or (rank % self.parallel_config.tensor_parallel_size == 0),
|
||||
)
|
||||
all_kwargs.append(kwargs)
|
||||
self._run_workers("init_worker", all_kwargs)
|
||||
|
||||
self._run_workers("init_device")
|
||||
self._run_workers("load_model",
|
||||
max_concurrent_workers=self.parallel_config.
|
||||
max_parallel_loading_workers)
|
||||
|
||||
if self.use_ray_spmd_worker:
|
||||
for pp_rank in range(self.parallel_config.pipeline_parallel_size):
|
||||
self.pp_tp_workers.append([])
|
||||
for tp_rank in range(
|
||||
self.parallel_config.tensor_parallel_size):
|
||||
# PP=2, TP=4
|
||||
# pp_tp_workers = [[0, 1, 2, 3], [4, 5, 6, 7]]
|
||||
rank = (pp_rank * self.parallel_config.tensor_parallel_size
|
||||
) + tp_rank
|
||||
assert len(self.pp_tp_workers[pp_rank]) == tp_rank
|
||||
assert pp_rank < len(self.pp_tp_workers)
|
||||
self.pp_tp_workers[pp_rank].append(self.workers[rank])
|
||||
|
||||
# This is the list of workers that are rank 0 of each TP group EXCEPT
|
||||
# global rank 0. These are the workers that will broadcast to the
|
||||
# rest of the workers.
|
||||
self.tp_driver_workers: List[RayWorkerWrapper] = []
|
||||
# This is the list of workers that are not drivers and not the first
|
||||
# worker in a TP group. These are the workers that will be
|
||||
# broadcasted to.
|
||||
self.non_driver_workers: List[RayWorkerWrapper] = []
|
||||
|
||||
# Enforce rank order for correct rank to return final output.
|
||||
for index, worker in enumerate(self.workers):
|
||||
# The driver worker is rank 0 and not in self.workers.
|
||||
rank = index + 1
|
||||
if rank % self.parallel_config.tensor_parallel_size == 0:
|
||||
self.tp_driver_workers.append(worker)
|
||||
else:
|
||||
self.non_driver_workers.append(worker)
|
||||
|
||||
def _driver_execute_model(
|
||||
self, execute_model_req: Optional[ExecuteModelRequest]
|
||||
) -> Optional[List[SamplerOutput]]:
|
||||
"""Run execute_model in the driver worker.
|
||||
|
||||
Passing None will cause the driver to stop the model execution
|
||||
loop running in each of the remote workers.
|
||||
"""
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
return self.driver_worker.execute_method("execute_model",
|
||||
execute_model_req)
|
||||
|
||||
def execute_model(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if not self.use_ray_spmd_worker:
|
||||
return super().execute_model(execute_model_req)
|
||||
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=False)
|
||||
|
||||
if self.use_v1:
|
||||
serialized_data = execute_model_req
|
||||
else:
|
||||
serialized_data = self.input_encoder.encode(execute_model_req)
|
||||
outputs = ray.get(self.forward_dag.execute(serialized_data))
|
||||
if self.use_v1:
|
||||
output = outputs[0]
|
||||
else:
|
||||
output = self.output_decoder.decode(outputs[0])
|
||||
return output
|
||||
|
||||
def _run_workers(
|
||||
self,
|
||||
method: Union[str, Callable],
|
||||
*args,
|
||||
async_run_tensor_parallel_workers_only: bool = False,
|
||||
max_concurrent_workers: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Any:
|
||||
"""Runs the given method on all workers. Can be used in the following
|
||||
ways:
|
||||
|
||||
Args:
|
||||
- async_run_tensor_parallel_workers_only: If True the method will be
|
||||
run only in the remote TP workers, not the driver worker.
|
||||
It will also be run asynchronously and return a list of futures
|
||||
rather than blocking on the results.
|
||||
- args/kwargs: All workers share the same args/kwargs
|
||||
"""
|
||||
if isinstance(method, str):
|
||||
sent_method = method
|
||||
else:
|
||||
sent_method = cloudpickle.dumps(method)
|
||||
del method
|
||||
if self.use_ray_spmd_worker:
|
||||
assert not async_run_tensor_parallel_workers_only, (
|
||||
"async_run_tensor_parallel_workers_only is not supported for "
|
||||
"spmd mode.")
|
||||
|
||||
if max_concurrent_workers:
|
||||
raise NotImplementedError(
|
||||
"max_concurrent_workers is not supported yet.")
|
||||
|
||||
# Start the ray workers first.
|
||||
ray_workers = self.workers
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
ray_workers = self.non_driver_workers
|
||||
ray_worker_outputs = [
|
||||
worker.execute_method.remote(sent_method, *args, **kwargs)
|
||||
for worker in ray_workers
|
||||
]
|
||||
|
||||
if async_run_tensor_parallel_workers_only:
|
||||
# Just return futures
|
||||
return ray_worker_outputs
|
||||
|
||||
driver_worker_output = []
|
||||
# In SPMD mode, the driver worker is the same as any other worker,
|
||||
# so we only explicitly execute on the driver worker if using a
|
||||
# non-SPMD worker class.
|
||||
if not self.use_ray_spmd_worker:
|
||||
# Start the driver worker after all the ray workers.
|
||||
driver_worker_output = [
|
||||
self.driver_worker.execute_method(sent_method, *args, **kwargs)
|
||||
]
|
||||
|
||||
# Get the results of the ray workers.
|
||||
if self.workers:
|
||||
ray_worker_outputs = ray.get(ray_worker_outputs)
|
||||
|
||||
return driver_worker_output + ray_worker_outputs
|
||||
|
||||
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
||||
"""Wait for futures returned from _run_workers() with
|
||||
async_run_remote_workers_only to complete."""
|
||||
ray.get(parallel_worker_tasks)
|
||||
|
||||
def _check_ray_cgraph_installation(self):
|
||||
import importlib.metadata
|
||||
|
||||
from packaging import version
|
||||
|
||||
required_version = version.parse("2.43.0")
|
||||
current_version = version.parse(importlib.metadata.version("ray"))
|
||||
if current_version < required_version:
|
||||
raise ValueError(f"Ray version {required_version} is "
|
||||
f"required, but found {current_version}")
|
||||
|
||||
import importlib.util
|
||||
cgraph_spec = importlib.util.find_spec(
|
||||
"ray.experimental.compiled_dag_ref")
|
||||
if cgraph_spec is None:
|
||||
raise ValueError("Ray Compiled Graph is not installed. "
|
||||
"Run `pip install ray[cgraph]` to install it.")
|
||||
|
||||
cupy_spec = importlib.util.find_spec("cupy")
|
||||
if (cupy_spec is None
|
||||
and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
|
||||
raise ValueError(
|
||||
"cupy is not installed but required since "
|
||||
"VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
|
||||
"Run `pip install ray[cgraph]` and check cupy installation.")
|
||||
|
||||
def _compiled_ray_dag(self, enable_asyncio: bool):
|
||||
assert self.parallel_config.use_ray
|
||||
self._check_ray_cgraph_installation()
|
||||
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
|
||||
# (it is 10 seconds by default). This is a Ray environment variable to
|
||||
# control the timeout of getting result from a compiled graph execution,
|
||||
# i.e., the distributed execution that includes model forward runs and
|
||||
# intermediate tensor communications, in the case of vllm.
|
||||
# Note: we should set this env var before importing
|
||||
# ray.dag, otherwise it will not take effect.
|
||||
os.environ.setdefault("RAY_CGRAPH_get_timeout", "300") # noqa: SIM112
|
||||
from ray.dag import InputNode, MultiOutputNode
|
||||
logger.info("RAY_CGRAPH_get_timeout is set to %s",
|
||||
os.environ["RAY_CGRAPH_get_timeout"]) # noqa: SIM112
|
||||
logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
|
||||
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
|
||||
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
|
||||
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
|
||||
|
||||
channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
|
||||
if channel_type not in ("auto", "nccl", "shm"):
|
||||
raise ValueError(
|
||||
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
|
||||
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
|
||||
|
||||
with InputNode() as input_data:
|
||||
# Example DAG: PP=2, TP=4
|
||||
#
|
||||
# For V0:
|
||||
# ExecuteModelRequest -> 0 -> (ExecuteModelReq, IntermediateTensors) -> 4 -> SamplerOutput # noqa: E501
|
||||
# ExecuteModelRequest -> 1 -> (ExecuteModelReq, IntermediateTensors) -> 5 -> SamplerOutput # noqa: E501
|
||||
# ExecuteModelRequest -> 2 -> (ExecuteModelReq, IntermediateTensors) -> 6 -> SamplerOutput # noqa: E501
|
||||
# ExecuteModelRequest -> 3 -> (ExecuteModelReq, IntermediateTensors) -> 7 -> SamplerOutput # noqa: E501
|
||||
#
|
||||
# For V1:
|
||||
# SchedulerOutput -> 0 -> (SchedulerOutput, IntermediateTensors) -> 4 -> ModelRunnerOutput # noqa: E501
|
||||
# SchedulerOutput -> 1 -> (SchedulerOutput, IntermediateTensors) -> 5 -> ModelRunnerOutput # noqa: E501
|
||||
# SchedulerOutput -> 2 -> (SchedulerOutput, IntermediateTensors) -> 6 -> ModelRunnerOutput # noqa: E501
|
||||
# SchedulerOutput -> 3 -> (SchedulerOutput, IntermediateTensors) -> 7 -> ModelRunnerOutput # noqa: E501
|
||||
|
||||
# All workers in the first TP group will take in the
|
||||
# ExecuteModelRequest as input.
|
||||
outputs = [input_data for _ in self.pp_tp_workers[0]]
|
||||
for pp_rank, tp_group in enumerate(self.pp_tp_workers):
|
||||
# Each PP worker takes in the output of the previous PP worker,
|
||||
# and the TP group executes in SPMD fashion.
|
||||
if self.use_v1:
|
||||
outputs = [
|
||||
worker.execute_model_ray.
|
||||
bind( # type: ignore[attr-defined]
|
||||
outputs[i]) for i, worker in enumerate(tp_group)
|
||||
]
|
||||
else:
|
||||
outputs = [
|
||||
worker.execute_model_spmd.
|
||||
bind( # type: ignore[attr-defined]
|
||||
outputs[i]) for i, worker in enumerate(tp_group)
|
||||
]
|
||||
|
||||
last_pp_rank = len(self.pp_tp_workers) - 1
|
||||
if (pp_rank < last_pp_rank and
|
||||
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
|
||||
# Specify how intermediate tensors should be passed
|
||||
# between pp stages, no need to specify for the last
|
||||
# pp stage or when using shared memory (the default).
|
||||
transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
|
||||
outputs = [
|
||||
output.with_tensor_transport(transport=transport)
|
||||
for output in outputs
|
||||
]
|
||||
|
||||
forward_dag = MultiOutputNode(outputs)
|
||||
|
||||
if envs.VLLM_USE_RAY_WRAPPED_PP_COMM:
|
||||
from ray.experimental.channel.accelerator_context import (
|
||||
register_accelerator_context)
|
||||
|
||||
from vllm.distributed.device_communicators.ray_communicator import (
|
||||
RayPPCommunicator)
|
||||
register_accelerator_context(torch_module_name="cuda",
|
||||
communicator_cls=RayPPCommunicator)
|
||||
logger.info("Using RayPPCommunicator "
|
||||
"(which wraps vLLM _PP GroupCoordinator) "
|
||||
"for Ray Compiled Graph communication.")
|
||||
else:
|
||||
logger.info("Using Ray's NCCL communicator for "
|
||||
"Ray Compiled Graph communication.")
|
||||
|
||||
return forward_dag.experimental_compile(
|
||||
enable_asyncio=enable_asyncio,
|
||||
_overlap_gpu_communication=envs.
|
||||
VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
|
||||
|
||||
def __del__(self):
|
||||
self.shutdown()
|
||||
|
||||
async def execute_model_async(
|
||||
self,
|
||||
execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]:
|
||||
if not self.use_ray_spmd_worker:
|
||||
return await super().execute_model_async(execute_model_req)
|
||||
|
||||
if self.forward_dag is None:
|
||||
self.forward_dag = self._compiled_ray_dag(enable_asyncio=True)
|
||||
|
||||
serialized_data = self.input_encoder.encode(execute_model_req)
|
||||
dag_future = await self.forward_dag.execute_async(serialized_data)
|
||||
output = await dag_future[0]
|
||||
return self.output_decoder.decode(output)
|
||||
|
||||
async def _driver_execute_model_async(
|
||||
self,
|
||||
execute_model_req: Optional[ExecuteModelRequest] = None
|
||||
) -> List[SamplerOutput]:
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"driver_worker does not exist for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
if not self.tp_driver_workers:
|
||||
return await self.driver_exec_method("execute_model",
|
||||
execute_model_req)
|
||||
if self.pp_locks is None:
|
||||
# This locks each pipeline parallel stage so multiple virtual
|
||||
# engines can't execute on the same stage at the same time
|
||||
# We create the locks here to avoid creating them in the constructor
|
||||
# which uses a different asyncio loop.
|
||||
self.pp_locks = [
|
||||
asyncio.Lock()
|
||||
for _ in range(self.parallel_config.pipeline_parallel_size)
|
||||
]
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(self.driver_exec_method, self.pp_locks[0],
|
||||
"execute_model", execute_model_req))
|
||||
]
|
||||
for pp_rank, driver_worker in enumerate(self.tp_driver_workers,
|
||||
start=1):
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
_run_task_with_lock(driver_worker.execute_method.remote,
|
||||
self.pp_locks[pp_rank],
|
||||
"execute_model", execute_model_req)))
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# Only the last PP stage has the final results.
|
||||
return results[-1]
|
||||
|
||||
async def _start_worker_execution_loop(self):
|
||||
assert not self.use_ray_spmd_worker, (
|
||||
"worker loop is disabled for VLLM_USE_RAY_SPMD_WORKER=1")
|
||||
coros = [
|
||||
worker.execute_method.remote("start_worker_execution_loop")
|
||||
for worker in self.non_driver_workers
|
||||
]
|
||||
return await asyncio.gather(*coros)
|
||||
|
||||
def check_health(self) -> None:
|
||||
# Assume that the Ray workers are healthy.
|
||||
# TODO: check the health of the Ray workers
|
||||
return
|
@ -14,7 +14,7 @@ from torch import nn
from vllm.config.lora import LoRAConfig
from vllm.logger import init_logger
from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping
from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.lora_weights import LoRALayerWeights, PackedLoRALayerWeights
from vllm.lora.peft_helper import PEFTHelper
from vllm.lora.punica_wrapper import get_punica_wrapper
from vllm.lora.utils import (from_layer, from_layer_logits_processor,
@ -6,7 +6,7 @@ from typing import Any, Literal, Optional, Union

import torch

from vllm.config.lora import LoRAConfig
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.lora.models import (LoRAModel, LoRAModelManager,
LRUCacheLoRAModelManager, create_lora_manager)
@ -27,25 +27,26 @@ class WorkerLoRAManager:

def __init__(
self,
max_num_seqs: int,
max_num_batched_tokens: int,
vocab_size: int,
lora_config: LoRAConfig,
vllm_config: VllmConfig,
device: torch.device,
embedding_modules: dict[str, str],
embedding_padding_modules: list[str],
lora_model_cls: type[LoRAModel] = LoRAModel,
max_position_embeddings: Optional[int] = None,
):
self._lora_model_cls = lora_model_cls
self.embedding_modules = embedding_modules
self.embedding_padding_modules = embedding_padding_modules
self._cached_dummy_lora: Union[None, Literal[False], LoRAModel] = False
self.max_num_seqs = max_num_seqs
self.max_num_batched_tokens = max_num_batched_tokens
self.vocab_size = vocab_size
self.lora_config = lora_config
self.max_position_embeddings = max_position_embeddings
self.max_num_seqs = vllm_config.scheduler_config.max_num_seqs
self.max_num_batched_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.vocab_size = vllm_config.model_config.get_vocab_size()
self.lora_config = vllm_config.lora_config

# Use get_text_config() in case of multimodal models
text_config = vllm_config.model_config.hf_config.get_text_config()

self.max_position_embeddings = text_config.max_position_embeddings
self.device = device
# Lazily initialized by create_lora_manager.
self._adapter_manager: LoRAModelManager
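The constructor change above replaces the individual `max_num_seqs` / `max_num_batched_tokens` / `vocab_size` / `lora_config` / `max_position_embeddings` arguments with a single `vllm_config`, from which those values are now derived. A rough sketch of the new call shape, assuming `WorkerLoRAManager` lives in `vllm.lora.worker_manager` and using placeholder model and embedding-module arguments:

```python
import torch
from vllm.engine.arg_utils import EngineArgs
from vllm.lora.worker_manager import WorkerLoRAManager  # assumed module path

# Build a VllmConfig the same way the engine does (placeholder model).
vllm_config = EngineArgs(model="facebook/opt-125m",
                         enable_lora=True).create_engine_config()

manager = WorkerLoRAManager(
    vllm_config=vllm_config,
    device=torch.device("cuda:0"),
    embedding_modules={},            # placeholder; normally model-specific
    embedding_padding_modules=[],    # placeholder
)
```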
@ -78,3 +78,12 @@ if HAS_TRITON:
|
||||
"TritonOrDeepGemmExperts",
|
||||
"BatchedTritonOrDeepGemmExperts",
|
||||
]
|
||||
else:
|
||||
# Some model classes directly use the custom ops. Add placeholders
|
||||
# to avoid import errors.
|
||||
def _raise_exception(method: str):
|
||||
raise NotImplementedError(
|
||||
f"{method} is not implemented as lack of triton.")
|
||||
|
||||
fused_topk = lambda *args, **kwargs: _raise_exception("fused_topk")
|
||||
fused_experts = lambda *args, **kwargs: _raise_exception("fused_experts")
|
||||
|
@ -286,6 +286,7 @@ class BatchedDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
@ -126,6 +126,7 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -136,5 +137,5 @@ class BatchedTritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
assert experts is not None
|
||||
experts.apply(output, hidden_states, w1, w2, topk_weights, topk_ids,
|
||||
activation, global_num_experts, expert_map, a1q_scale,
|
||||
workspace13, workspace2, expert_tokens_meta,
|
||||
a2_scale, workspace13, workspace2, expert_tokens_meta,
|
||||
apply_router_weight_on_input)
|
||||
|
@ -1,7 +1,7 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@ -14,7 +14,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
from vllm.utils import cdiv, has_triton_kernels
|
||||
from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe
|
||||
|
||||
if TYPE_CHECKING and has_triton_kernels:
|
||||
if has_triton_kernels():
|
||||
from triton_kernels.matmul_ogs import PrecisionConfig
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
@ -241,6 +241,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -262,7 +263,7 @@ class CutlassExpertsFp8Base(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
run_cutlass_moe_fp8(
|
||||
output, hidden_states, w1, w2, topk_ids, activation_callable,
|
||||
global_num_experts, expert_map, self.w1_scale, self.w2_scale,
|
||||
a1q_scale, self.a2_scale, self.ab_strides1, self.ab_strides2,
|
||||
a1q_scale, a2_scale, self.ab_strides1, self.ab_strides2,
|
||||
self.c_strides1, self.c_strides2, workspace13, workspace2,
|
||||
expert_num_tokens,
|
||||
self.out_dtype if self.out_dtype is not None else in_dtype,
|
||||
@ -705,6 +706,7 @@ class CutlassExpertsFp4(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor], # unused
|
||||
a2_scale: Optional[torch.Tensor], # unused
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
@ -214,13 +214,14 @@ class DeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
apply_router_weight_on_input: bool,
|
||||
):
|
||||
assert a1q_scale is not None
|
||||
assert self.a2_scale is None
|
||||
assert a2_scale is None
|
||||
assert self.block_shape is not None
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
|
@ -129,6 +129,7 @@ class FlashInferExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: Optional[torch.Tensor],
|
||||
workspace2: Optional[torch.Tensor],
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
|
@ -688,6 +688,7 @@ class NaiveBatchedExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -879,6 +880,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -970,7 +972,7 @@ class BatchedTritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
intermediate_cache1.view(-1, N))
|
||||
|
||||
qintermediate_cache2, a2q_scale = batched_moe_kernel_quantize_input(
|
||||
intermediate_cache2, self.a2_scale, max_num_tokens, E, N,
|
||||
intermediate_cache2, a2_scale, max_num_tokens, E, N,
|
||||
expert_num_tokens, self.quant_dtype, self.per_act_token_quant,
|
||||
self.block_shape)
|
||||
|
||||
|
@ -1598,6 +1598,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
global_num_experts: int,
|
||||
expert_map: Optional[torch.Tensor],
|
||||
a1q_scale: Optional[torch.Tensor],
|
||||
a2_scale: Optional[torch.Tensor],
|
||||
workspace13: torch.Tensor,
|
||||
workspace2: torch.Tensor,
|
||||
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
|
||||
@ -1690,7 +1691,7 @@ class TritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
|
||||
a2q_scale: Optional[torch.Tensor] = None
|
||||
|
||||
qintermediate_cache2, a2q_scale = moe_kernel_quantize_input(
|
||||
intermediate_cache2, self.a2_scale, self.quant_dtype,
|
||||
intermediate_cache2, a2_scale, self.quant_dtype,
|
||||
self.per_act_token_quant, self.block_shape)
|
||||
|
||||
invoke_fused_moe_kernel(
|
||||
|
@ -20,10 +20,10 @@ if has_triton_kernels():
from triton_kernels.matmul_ogs import (FnSpecs, FusedActivation,
matmul_ogs)
from triton_kernels.routing import routing
except ModuleNotFoundError:
except (ModuleNotFoundError, AttributeError) as e:
logger.error(
"Failed to import Triton kernels. Please make sure your triton "
"version is compatible.")
"version is compatible. Error: %s", e)


def triton_kernel_moe_forward(
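The import hunk above widens the guard around the optional triton_kernels dependency: an incompatible install can fail with AttributeError while the module initializes, not only ModuleNotFoundError, and the error is now logged. A small standalone sketch of the same guard, with a hypothetical module name:

import logging

logger = logging.getLogger(__name__)

try:
    # Optional dependency; an incompatible version may raise AttributeError
    # during import, not just ModuleNotFoundError.
    from some_optional_kernels import fused_matmul  # hypothetical module
    HAS_KERNELS = True
except (ModuleNotFoundError, AttributeError) as e:
    HAS_KERNELS = False
    logger.error(
        "Failed to import optional kernels. Please make sure the installed "
        "version is compatible. Error: %s", e)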
@ -179,6 +179,7 @@ class BatchedOAITritonExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],

@ -519,6 +519,7 @@ class FusedMoEPermuteExpertsUnpermute(ABC):
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[ExpertTokensMetadata],

@ -634,6 +635,7 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
expert_tokens_meta: Optional[ExpertTokensMetadata],
apply_router_weight_on_input: bool,
) -> torch.Tensor:

@ -671,6 +673,7 @@ class FusedMoEModularKernel(torch.nn.Module):
global_num_experts=global_num_experts,
expert_map=expert_map,
a1q_scale=a1q_scale,
a2_scale=a2_scale,
workspace13=workspace13,
workspace2=workspace2,
expert_tokens_meta=expert_tokens_meta,

@ -718,6 +721,7 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts=local_num_experts,
expert_map=expert_map,
a1q_scale=a1q_scale,
a2_scale=self.fused_experts.a2_scale,
expert_tokens_meta=expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)

@ -803,6 +807,7 @@ class FusedMoEModularKernel(torch.nn.Module):
local_num_experts=local_num_experts,
expert_map=expert_map,
a1q_scale=c_a1q_scale,
a2_scale=c_a2_scale,
expert_tokens_meta=c_expert_tokens_meta,
apply_router_weight_on_input=apply_router_weight_on_input,
)

@ -111,6 +111,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],

@ -134,6 +135,7 @@ class TritonOrDeepGemmExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts,
expert_map,
a1q_scale,
a2_scale,
workspace13,
workspace2,
expert_tokens_meta,

@ -103,6 +103,7 @@ class TrtLlmGenExperts(mk.FusedMoEPermuteExpertsUnpermute):
global_num_experts: int,
expert_map: Optional[torch.Tensor],
a1q_scale: Optional[torch.Tensor],
a2_scale: Optional[torch.Tensor],
workspace13: torch.Tensor,
workspace2: torch.Tensor,
expert_tokens_meta: Optional[mk.ExpertTokensMetadata],
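Wrapper classes such as TritonOrDeepGemmExperts and TrtLlmGenExperts only need to accept the new a2_scale parameter and forward it unchanged to whichever inner implementation they selected. A sketch of that pass-through pattern, with hypothetical names:

from typing import Optional

import torch


class DispatchingExperts:
    """Sketch: pick one of two backends and pass every argument through."""

    def __init__(self, fast_backend, fallback_backend, use_fast: bool):
        self._backend = fast_backend if use_fast else fallback_backend

    def apply(
        self,
        hidden_states: torch.Tensor,
        a1q_scale: Optional[torch.Tensor],
        a2_scale: Optional[torch.Tensor],
    ) -> torch.Tensor:
        # No logic of its own: the new argument is forwarded so the wrapper
        # stays signature-compatible with the implementations it wraps.
        return self._backend.apply(hidden_states, a1q_scale, a2_scale)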
@ -241,7 +241,7 @@ class AutoRoundConfig(QuantizationConfig):

if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin, layer.moe)
return AWQMoEMethod(quant_args_marlin, layer.moe_config)
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)

@ -327,7 +327,7 @@ class AutoRoundConfig(QuantizationConfig):

if isinstance(layer, FusedMoE):
if use_marlin:
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe)
return GPTQMarlinMoEMethod(quant_args_marlin, layer.moe_config)
else:
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
@ -160,6 +160,7 @@ class ModelOptFp8Config(QuantizationConfig):
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
Handles both exact matching (for fused layers) and substring matching.

This method handles both regular models and multimodal models that use
the language_model prefix. For multimodal models, it checks if the

@ -168,11 +169,18 @@ class ModelOptFp8Config(QuantizationConfig):
if self.exclude_modules is None:
return False

# Check if any excluded module matches the prefix
# First check exact matching with fused layer support
if is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping):
return True

# Then check substring matching for patterns not caught by exact match
for module in self.exclude_modules:
if (module in prefix
or (prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model."))):
# Skip exact matches already handled above
if (module != prefix and
(module in prefix or
(prefix.startswith("language_model.")
and module in prefix.removeprefix("language_model.")))):
return True
return False

@ -180,9 +188,10 @@ class ModelOptFp8Config(QuantizationConfig):
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if (is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping)
or self.is_layer_excluded(prefix)):
if self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
# Check if this is a vision model layer that should not be quantized
if ("vision_tower" in prefix or "vision_model" in prefix):
return UnquantizedLinearMethod()
return ModelOptFp8LinearMethod(self)
elif isinstance(layer, Attention):

@ -778,22 +787,34 @@ class ModelOptNvFp4Config(QuantizationConfig):
return cls(is_checkpoint_nvfp4_serialized, kv_cache_quant_algo,
exclude_modules, group_size)

def is_layer_excluded(self, prefix: str,
exclude_modules: list[str]) -> bool:
def is_layer_excluded(self, prefix: str) -> bool:
"""
Check if a layer should be excluded from quantization.
Handles both exact matching (for fused layers) and pattern matching.
"""
# First check exact matching with fused layer support
if is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping):
return True

# Check regex pattern matching for patterns not caught by exact match
import regex as re
for pattern in exclude_modules:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
if re.fullmatch(regex_str, prefix):
return True
for pattern in self.exclude_modules:
# Skip patterns that would be caught by exact matching
if '*' in pattern or '.' in pattern:
regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
if re.fullmatch(regex_str, prefix):
return True
return False

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
if isinstance(layer, LinearBase):
if (is_layer_skipped(prefix, self.exclude_modules,
self.packed_modules_mapping)
or self.is_layer_excluded(prefix, self.exclude_modules)):
if self.is_layer_excluded(prefix):
return UnquantizedLinearMethod()
# Check if this is a vision model layer that should not be quantized
if ("vision_tower" in prefix or "vision_model" in prefix):
return UnquantizedLinearMethod()
return ModelOptNvFp4LinearMethod(self)
elif isinstance(layer, Attention):
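Both ModelOpt configs now funnel exclusion checks through is_layer_excluded: an exact, fused-layer-aware check first (via is_layer_skipped), then a looser second pass; the FP8 config falls back to substring matching while the NVFP4 config converts '*' glob patterns into regexes. The standalone sketch below illustrates the two-stage idea with plain string logic and the standard re module rather than vLLM's helpers:

import re


def is_layer_excluded(prefix: str, exclude_modules: list[str]) -> bool:
    # Stage 1: exact match (a stand-in for the fused-layer-aware
    # is_layer_skipped() helper used in the real code).
    if prefix in exclude_modules:
        return True

    # Stage 2: pattern match for entries not caught exactly. Glob-style
    # patterns such as "model.layers.*.mlp" become regexes.
    for pattern in exclude_modules:
        if '*' in pattern or '.' in pattern:
            regex_str = pattern.replace('.', r'\.').replace('*', r'.*')
            if re.fullmatch(regex_str, prefix):
                return True
    return False


assert is_layer_excluded("model.layers.0.mlp", ["model.layers.*.mlp"])
assert not is_layer_excluded("lm_head", ["model.layers.*.mlp"])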
@ -638,8 +638,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
return None

if self.mxfp4_backend == Mxfp4Backend.TRITON:
w1_scale = layer.w13_precision_config
w2_scale = layer.w2_precision_config
w1_scale = self.w13_precision_config
w2_scale = self.w2_precision_config
else:
w1_scale = layer.w13_weight_scale
w2_scale = layer.w2_weight_scale
@ -6,6 +6,8 @@ from typing import Optional
import torch

from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.flashinfer import has_flashinfer

from .common import apply_rotary_emb_torch

@ -30,9 +32,17 @@ class RotaryEmbedding(CustomOp):
self.base = base
self.is_neox_style = is_neox_style
self.dtype = dtype
# Flashinfer only supports head_size=64, 128, 256, 512.
# https://github.com/flashinfer-ai/flashinfer/blob/ebfd655efe830048dba5d582aaa61d61d1cf9a87/include/flashinfer/utils.cuh#L174-L202
self.use_flashinfer = (self.enabled()
and dtype in (torch.float16, torch.bfloat16)
and current_platform.is_cuda()
and has_flashinfer()
and self.head_size in [64, 128, 256, 512])

cache = self._compute_cos_sin_cache()
cache = cache.to(dtype)
if not self.use_flashinfer:
cache = cache.to(dtype)
self.cos_sin_cache: torch.Tensor
self.register_buffer("cos_sin_cache", cache, persistent=False)

@ -57,6 +67,14 @@ class RotaryEmbedding(CustomOp):
cache = torch.cat((cos, sin), dim=-1)
return cache

def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# is expensive, so avoid calling it if possible
if self.cos_sin_cache.device != query.device or \
self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)

def forward_native(
self,
positions: torch.Tensor,
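The new _match_cos_sin_cache_dtype helper centralizes the device/dtype check that several forward paths previously inlined. The point of the guard is that assigning to an nn.Module attribute goes through __setattr__, which is relatively expensive, so the cache is only re-cast when it actually differs from the query. A reduced sketch of the same idea outside vLLM:

import torch
from torch import nn


class RopeLike(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.register_buffer("cos_sin_cache", torch.randn(16, 8),
                             persistent=False)

    def _match_cos_sin_cache_dtype(self, query: torch.Tensor) -> None:
        # Only reassign (and pay the nn.Module.__setattr__ cost) when the
        # cache actually needs to move devices or be cast.
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        self._match_cos_sin_cache_dtype(query)
        return query  # the real module would apply the rotary embedding here


m = RopeLike()
m(torch.randn(2, 8, dtype=torch.float16))  # cache cast to float16 once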
@ -94,15 +112,16 @@ class RotaryEmbedding(CustomOp):
query: torch.Tensor,
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
if self.use_flashinfer:
torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style)
return query, key

from vllm import _custom_ops as ops

# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# is expensive, so avoid calling it if possible
if self.cos_sin_cache.device != query.device or \
self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)

self._match_cos_sin_cache_dtype(query)
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
ops.rotary_embedding(positions, query, key, self.head_size,

@ -117,8 +136,7 @@ class RotaryEmbedding(CustomOp):
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
from vllm._ipex_ops import ipex_ops as ops

self.cos_sin_cache = self.cos_sin_cache.to(positions.device,
dtype=query.dtype)
self._match_cos_sin_cache_dtype(query)
# ops.rotary_embedding() is an in-place operation
# that updates the query and key tensors.
if key is None:
@ -6,6 +6,7 @@ import math
import torch

from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op

if current_platform.is_cuda():
from vllm.vllm_flash_attn.layers.rotary import apply_rotary_emb

@ -103,3 +104,48 @@ def yarn_get_mscale(scale: float = 1) -> float:
if scale <= 1:
return 1.0
return 0.1 * math.log(scale) + 1.0


def _flashinfer_rotary_embedding(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
"""Custom op wrapper for flashinfer's rotary embedding.

This is an in-place operation that modifies query and key tensors directly.
"""
from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

apply_rope_with_cos_sin_cache_inplace(
positions=positions,
query=query,
key=key,
head_size=head_size,
cos_sin_cache=cos_sin_cache,
is_neox=is_neox,
)


def _flashinfer_rotary_embedding_fake(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor,
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool,
) -> None:
return


# Register flashinfer rotary embedding custom op
direct_register_custom_op(
op_name="flashinfer_rotary_embedding",
op_func=_flashinfer_rotary_embedding,
mutates_args=["query", "key"], # These tensors are modified in-place
fake_impl=_flashinfer_rotary_embedding_fake,
dispatch_key=current_platform.dispatch_key,
)
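Registering the wrapper through direct_register_custom_op, together with the fake implementation, exposes it under the torch.ops namespace with its in-place mutation of query and key declared, which is what forward_cuda calls above. A hedged usage sketch follows; the shapes are illustrative only, and it assumes a CUDA device, flashinfer installed, and that the vLLM module containing the registration above has already been imported:

import torch

# Illustrative shapes: 8 tokens, 4 query heads / 2 key heads, head_size 128.
positions = torch.arange(8, device="cuda")
query = torch.randn(8, 4 * 128, device="cuda", dtype=torch.float16)
key = torch.randn(8, 2 * 128, device="cuda", dtype=torch.float16)
cos_sin_cache = torch.randn(1024, 128, device="cuda")  # cos||sin per position

# Mutates query and key in place; returns nothing (see the fake impl above).
torch.ops.vllm.flashinfer_rotary_embedding(positions, query, key, 128,
                                           cos_sin_cache, True)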
@ -97,15 +97,13 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
"""PyTorch-native implementation equivalent to forward()."""
assert key is not None
self._match_cos_sin_cache_dtype(query)
query_rot = query[..., :self.rotary_dim]
key_rot = key[..., :self.rotary_dim]
if self.rotary_dim < self.head_size:
query_pass = query[..., self.rotary_dim:]
key_pass = key[..., self.rotary_dim:]

if self.cos_sin_cache.device != positions.device:
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(
positions.device)
cos_sin = self.cos_sin_cache[torch.add(positions, offsets)
if offsets is not None else positions]
cos, sin = cos_sin.chunk(2, dim=-1)

@ -59,7 +59,7 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
key: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
assert key is not None
self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to(query.device)
self._match_cos_sin_cache_dtype(query)
query_ = torch.view_as_complex(query.float().reshape(
*query.shape[:-1], -1, 2))
key_ = torch.view_as_complex(key.float().reshape(
@ -15,7 +15,7 @@ from .common import apply_rotary_emb_dispatch


@triton.jit
def _triton_qwen2vl_mrope_forward(
def _triton_mrope_forward(
q_ptr,
k_ptr,
cos,

@ -30,12 +30,14 @@ def _triton_qwen2vl_mrope_forward(
pad_hd: tl.constexpr,
mrope_section_t: tl.constexpr,
mrope_section_h: tl.constexpr,
mrope_section_w: tl.constexpr,
is_interleaved: tl.constexpr,
):
# Adapted from
# https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
# This version supports flatten input tensors from vllm
# and supports cos and sin cache with shape (3, num_tokens, head_dim // 2)
# instead of (3, bsz, seq_len, head_dim)
# instead of (3, bsz, seq_len, head_dim), also supports interleaved rotary
pid = tl.program_id(0)
# locate start address
q_ptr = q_ptr + pid * (n_qh * hd)

@ -47,9 +49,6 @@ def _triton_qwen2vl_mrope_forward(
# ####################################################################
# Note: cos and sin now have shape (3, num_tokens, head_dim // 2)

t_end = mrope_section_t
h_end = t_end + mrope_section_h

# Updated stride calculation for half head_dim
half_rd = rd // 2
t_cos = cos + pid * half_rd

@ -61,9 +60,18 @@ def _triton_qwen2vl_mrope_forward(

# Updated offsets for half head_dim
cos_offsets = tl.arange(0, pad_hd // 2)
t_mask = cos_offsets < t_end
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)
if is_interleaved:
h_mask = (((cos_offsets % 3) == 1) &
(cos_offsets <= 3 * mrope_section_h))
w_mask = (((cos_offsets % 3) == 2) &
(cos_offsets <= 3 * mrope_section_w))
t_mask = ~(h_mask | w_mask)
else:
t_end = mrope_section_t
h_end = t_end + mrope_section_h
t_mask = cos_offsets < mrope_section_t
h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end)
w_mask = (h_end <= cos_offsets) & (cos_offsets < half_rd)

t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0)
h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0)
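The kernel hunk above selects the temporal/height/width rotary channels either as three contiguous sections or, when is_interleaved is set, as channels interleaved modulo 3. The host-side sketch below mirrors those masks with plain tensor ops purely to illustrate the selection; it is not the Triton kernel:

import torch

half_rd = 12               # rotary_dim // 2, illustrative value
t_n, h_n, w_n = 6, 3, 3    # t/h/w channel counts, summing to half_rd
offs = torch.arange(half_rd)

# Contiguous layout: [t ... t | h ... h | w ... w]
t_mask = offs < t_n
h_mask = (t_n <= offs) & (offs < t_n + h_n)
w_mask = (t_n + h_n <= offs) & (offs < half_rd)

# Interleaved layout: channel i belongs to h if i % 3 == 1, to w if i % 3 == 2
# (up to each section's limit), and to t otherwise.
h_mask_i = ((offs % 3) == 1) & (offs <= 3 * h_n)
w_mask_i = ((offs % 3) == 2) & (offs <= 3 * w_n)
t_mask_i = ~(h_mask_i | w_mask_i)

# Both layouts assign the same number of channels to each section.
assert int(t_mask.sum()) == t_n and int(t_mask_i.sum()) == t_n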
@ -131,6 +139,7 @@ def triton_mrope(
mrope_section: list[int],
head_size: int,
rotary_dim: int,
mrope_interleaved: bool,
) -> tuple[torch.Tensor, torch.Tensor]:
"""Qwen2VL mrope kernel.

@ -158,7 +167,7 @@ def triton_mrope(
cos = cos.contiguous()
sin = sin.contiguous()

_triton_qwen2vl_mrope_forward[(n_row, )](
_triton_mrope_forward[(n_row, )](
q,
k,
cos,

@ -173,6 +182,8 @@ def triton_mrope(
pad_hd,
mrope_section[0],
mrope_section[1],
mrope_section[2],
mrope_interleaved,
)
return q, k
@ -201,7 +212,7 @@ class MRotaryEmbedding(RotaryEmbedding):
is_neox_style: bool,
dtype: torch.dtype,
mrope_section: Optional[list[int]] = None,
mrope_interleaved: Optional[bool] = False,
mrope_interleaved: bool = False,
) -> None:
# In Qwen2.5-VL, the maximum index value is related to the duration of
# the input video. We enlarge max_position_embeddings to 4 times to get

@ -234,6 +245,7 @@ class MRotaryEmbedding(RotaryEmbedding):
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None

self._match_cos_sin_cache_dtype(query)
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)

@ -282,10 +294,7 @@ class MRotaryEmbedding(RotaryEmbedding):
assert positions.ndim == 1 or positions.ndim == 2
assert key is not None

if self.mrope_interleaved:
# TODO: add triton implementation to support mrope-interleaved
return self.forward_native(positions, query, key)

self._match_cos_sin_cache_dtype(query)
num_tokens = positions.shape[-1]
cos_sin = self.cos_sin_cache[positions]
cos, sin = cos_sin.chunk(2, dim=-1)

@ -302,6 +311,7 @@ class MRotaryEmbedding(RotaryEmbedding):
self.mrope_section,
self.head_size,
self.rotary_dim,
self.mrope_interleaved,
)

return q.reshape(query_shape), k.reshape(key_shape)
Some files were not shown because too many files have changed in this diff.