Merge branch 'main' into woosuk-jf

This commit is contained in:
Woosuk Kwon
2025-05-03 10:42:43 -07:00
132 changed files with 3685 additions and 738 deletions

View File

@ -57,6 +57,7 @@ steps:
agents:
queue: tpu_queue_postmerge
commands:
- "yes | docker system prune -a"
- "git fetch --all"
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --tag vllm/vllm-tpu:nightly --tag vllm/vllm-tpu:$BUILDKITE_COMMIT --progress plain -f docker/Dockerfile.tpu ."
- "docker push vllm/vllm-tpu:nightly"

View File

@ -293,6 +293,7 @@ steps:
parallelism: 4
- label: PyTorch Compilation Unit Tests
torch_nightly: true
source_file_dependencies:
- vllm/
- tests/compile
@ -302,6 +303,7 @@ steps:
- pytest -v -s compile/test_sequence_parallelism.py
- label: PyTorch Fullgraph Smoke Test # 9min
torch_nightly: true
source_file_dependencies:
- vllm/
- tests/compile
@ -312,6 +314,7 @@ steps:
- pytest -v -s compile/piecewise/test_toy_llama.py
- label: PyTorch Fullgraph Test # 18min
torch_nightly: true
source_file_dependencies:
- vllm/
- tests/compile
@ -436,6 +439,7 @@ steps:
##### models test #####
- label: Basic Models Test # 24min
torch_nightly: true
source_file_dependencies:
- vllm/
- tests/models

View File

@ -46,7 +46,7 @@ repos:
rev: 0.6.17
hooks:
- id: pip-compile
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match]
args: [requirements/test.in, -o, requirements/test.txt, --index-strategy, unsafe-best-match, --torch-backend, cu128]
files: ^requirements/test\.(in|txt)$
- repo: local
hooks:

View File

@ -15,7 +15,6 @@ project(vllm_extensions LANGUAGES CXX)
# CUDA by default, can be overridden by using -DVLLM_TARGET_DEVICE=... (used by setup.py)
set(VLLM_TARGET_DEVICE "cuda" CACHE STRING "Target device backend for vLLM")
message(STATUS "Build type: ${CMAKE_BUILD_TYPE}")
message(STATUS "Target device: ${VLLM_TARGET_DEVICE}")
@ -250,9 +249,8 @@ set(VLLM_EXT_SRC
if(VLLM_GPU_LANG STREQUAL "CUDA")
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
# Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case.
# Please keep this in sync with FetchContent_Declare line below.
set(CUTLASS_REVISION "v3.9.0" CACHE STRING "CUTLASS revision to use")
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
set(CUTLASS_REVISION "v3.9.1" CACHE STRING "CUTLASS revision to use")
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
@ -270,7 +268,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# Please keep this in sync with CUTLASS_REVISION line above.
GIT_TAG v3.9.0
GIT_TAG ${CUTLASS_REVISION}
GIT_PROGRESS TRUE
# Speed up CUTLASS download by retrieving only the specified GIT_TAG instead of the history.
@ -682,6 +680,17 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
endif()
endif()
if(VLLM_GPU_LANG STREQUAL "CUDA")
set(MOE_PERMUTE_SRC
"csrc/moe/permute_unpermute_kernels/moe_permute_unpermute_kernel.cu"
"csrc/moe/moe_permute_unpermute_op.cu")
set_gencode_flags_for_srcs(
SRCS "${MARLIN_PERMUTE_SRC}"
CUDA_ARCHS "${MOE_PERMUTE_ARCHS}")
list(APPEND VLLM_MOE_EXT_SRC "${MOE_PERMUTE_SRC}")
endif()
message(STATUS "Enabling moe extension.")
define_gpu_extension_target(
_moe_C
@ -690,6 +699,8 @@ define_gpu_extension_target(
SOURCES ${VLLM_MOE_EXT_SRC}
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
INCLUDE_DIRECTORIES ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
USE_SABI 3
WITH_SOABI)

View File

@ -90,7 +90,8 @@ def bench_run(results: list[benchmark.Measurement], model: str,
score = torch.randn((m, num_experts), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, renormalize=False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, renormalize=False)
def run_triton_moe(a: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,

View File

@ -115,8 +115,8 @@ def benchmark_config(config: BenchmarkConfig,
from vllm.model_executor.layers.fused_moe import override_config
with override_config(config):
if use_deep_gemm:
topk_weights, topk_ids = fused_topk(x, input_gating, topk,
False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
x, input_gating, topk, False)
return fused_experts(
x,
w1,
@ -442,8 +442,14 @@ class BenchmarkWorker:
hidden_size, search_space,
is_fp16, topk)
with torch.cuda.device(self.device_id) if current_platform.is_rocm(
) else nullcontext():
need_device_guard = False
if current_platform.is_rocm():
visible_device = os.environ.get("ROCR_VISIBLE_DEVICES", None)
if visible_device != f"{self.device_id}":
need_device_guard = True
with torch.cuda.device(
self.device_id) if need_device_guard else nullcontext():
for config in tqdm(search_space):
try:
kernel_time = benchmark_config(
@ -578,6 +584,15 @@ def main(args: argparse.Namespace):
use_deep_gemm = bool(args.use_deep_gemm)
if current_platform.is_rocm() and "HIP_VISIBLE_DEVICES" in os.environ:
# Ray will set ROCR_VISIBLE_DEVICES for device visibility
logger.warning(
"Ray uses ROCR_VISIBLE_DEVICES to control device accessibility."
"Replacing HIP_VISIBLE_DEVICES with ROCR_VISIBLE_DEVICES.")
val = os.environ["HIP_VISIBLE_DEVICES"]
os.environ["ROCR_VISIBLE_DEVICES"] = val
del os.environ["HIP_VISIBLE_DEVICES"]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]

View File

@ -0,0 +1,349 @@
# SPDX-License-Identifier: Apache-2.0
import argparse
from typing import Any, TypedDict
import ray
import torch
from transformers import AutoConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import (
_moe_permute, _moe_unpermute_and_reduce)
from vllm.model_executor.layers.fused_moe.fused_moe import *
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import *
from vllm.model_executor.layers.fused_moe.utils import _fp8_quantize
from vllm.platforms import current_platform
from vllm.utils import FlexibleArgumentParser
FP8_DTYPE = current_platform.fp8_dtype()
class BenchmarkConfig(TypedDict):
BLOCK_SIZE_M: int
BLOCK_SIZE_N: int
BLOCK_SIZE_K: int
GROUP_SIZE_M: int
num_warps: int
num_stages: int
def benchmark_permute(num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
use_customized_permute: bool = False) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
# output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else:
align_block_size = None
qhidden_states = hidden_states
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
dtype=torch.float32)
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False)
def prepare(i: int):
input_gating.copy_(gating_output[i])
def run():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx,
m_indices) = moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
token_expert_indices=token_expert_indices,
topk=topk,
n_expert=num_experts,
n_local_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
else:
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
num_experts, None, align_block_size)
# JIT compilation & warmup
run()
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
prepare(i)
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
def benchmark_unpermute(num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
use_customized_permute: bool = False) -> float:
# init_dtype = torch.float16 if use_fp8_w8a8 else dtype
hidden_states = torch.randn(num_tokens, hidden_size, dtype=dtype)
output_hidden_states = torch.empty_like(hidden_states)
if use_fp8_w8a8:
align_block_size = 128 # deepgemm needs 128 m aligned block
qhidden_states, scale = _fp8_quantize(hidden_states, None, None)
else:
align_block_size = None
qhidden_states = hidden_states
input_gating = torch.randn(num_tokens, num_experts, dtype=torch.float32)
topk_weights, topk_ids, token_expert_indices = fused_topk(
qhidden_states, input_gating, topk, False)
def prepare():
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx,
m_indices) = moe_permute(
qhidden_states,
topk_weights=topk_weights,
topk_ids=topk_ids,
token_expert_indices=token_expert_indices,
topk=topk,
n_expert=num_experts,
n_local_expert=num_experts,
expert_map=None,
align_block_size=align_block_size,
)
# convert to fp16/bf16 as gemm output
return (permuted_hidden_states.to(dtype), first_token_off,
inv_perm_idx, m_indices)
else:
(permuted_qhidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = _moe_permute(qhidden_states, None, topk_ids,
num_experts, None, align_block_size)
# convert to fp16/bf16 as gemm output
return (permuted_qhidden_states.to(dtype), a1q_scale,
sorted_token_ids, expert_ids, inv_perm)
def run(input: tuple):
if use_customized_permute:
(permuted_hidden_states, first_token_off, inv_perm_idx,
m_indices) = input
moe_unpermute(permuted_hidden_states, topk_weights, topk_ids,
inv_perm_idx, first_token_off, topk, num_experts,
num_experts)
else:
(permuted_hidden_states, a1q_scale, sorted_token_ids, expert_ids,
inv_perm) = input
_moe_unpermute_and_reduce(output_hidden_states,
permuted_hidden_states, inv_perm,
topk_weights)
# JIT compilation & warmup
input = prepare()
run(input)
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run(input)
torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)
latencies: list[float] = []
for i in range(num_iters):
torch.cuda.synchronize()
start_event.record()
graph.replay()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
return avg
@ray.remote(num_gpus=1)
class BenchmarkWorker:
def __init__(self, seed: int) -> None:
torch.set_default_device("cuda")
current_platform.seed_everything(seed)
self.seed = seed
# Get the device ID to allocate tensors and kernels
# on the respective GPU. This is required for Ray to work
# correctly with multi-GPU tuning on the ROCm platform.
self.device_id = int(ray.get_gpu_ids()[0])
def benchmark(
self,
num_tokens: int,
num_experts: int,
hidden_size: int,
topk: int,
dtype: torch.dtype,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_customized_permute: bool = False,
) -> tuple[dict[str, int], float]:
current_platform.seed_everything(self.seed)
permute_time = benchmark_permute(
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
use_customized_permute=use_customized_permute)
unpermute_time = benchmark_unpermute(
num_tokens,
num_experts,
hidden_size,
topk,
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=100,
use_customized_permute=use_customized_permute)
return permute_time, unpermute_time
def get_weight_block_size_safety(config, default_value=None):
quantization_config = getattr(config, 'quantization_config', {})
if isinstance(quantization_config, dict):
return quantization_config.get('weight_block_size', default_value)
return default_value
def main(args: argparse.Namespace):
print(args)
config = AutoConfig.from_pretrained(
args.model, trust_remote_code=args.trust_remote_code)
if config.architectures[0] == "DbrxForCausalLM":
E = config.ffn_config.moe_num_experts
topk = config.ffn_config.moe_top_k
elif config.architectures[0] == "JambaForCausalLM":
E = config.num_experts
topk = config.num_experts_per_tok
elif (config.architectures[0] == "DeepseekV3ForCausalLM"
or config.architectures[0] == "DeepseekV2ForCausalLM"):
E = config.n_routed_experts
topk = config.num_experts_per_tok
elif config.architectures[0] in [
"Qwen2MoeForCausalLM", "Qwen3MoeForCausalLM"
]:
E = config.num_experts
topk = config.num_experts_per_tok
else:
# Support for llama4
config = config.get_text_config()
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
hidden_size = config.hidden_size
dtype = torch.float16 if current_platform.is_rocm() else config.torch_dtype
use_fp8_w8a8 = args.dtype == "fp8_w8a8"
use_int8_w8a16 = args.dtype == "int8_w8a16"
use_customized_permute = args.use_customized_permute
if args.batch_size is None:
batch_sizes = [
1, 2, 4, 8, 16, 24, 32, 48, 64, 96, 128, 256, 512, 1024, 1536,
2048, 3072, 4096
]
else:
batch_sizes = [args.batch_size]
ray.init()
num_gpus = int(ray.available_resources()["GPU"])
workers = [BenchmarkWorker.remote(args.seed) for _ in range(num_gpus)]
def _distribute(method: str, inputs: list[Any]) -> list[Any]:
outputs = []
worker_idx = 0
for input_args in inputs:
worker = workers[worker_idx]
worker_method = getattr(worker, method)
output = worker_method.remote(*input_args)
outputs.append(output)
worker_idx = (worker_idx + 1) % num_gpus
return ray.get(outputs)
outputs = _distribute(
"benchmark", [(batch_size, E, hidden_size, topk, dtype, use_fp8_w8a8,
use_int8_w8a16, use_customized_permute)
for batch_size in batch_sizes])
for batch_size, (permute, unpermute) in zip(batch_sizes, outputs):
print(f"Batch size: {batch_size}")
print(f"Permute time: {permute:.2f} us")
print(f"Unpermute time: {unpermute:.2f} us")
if __name__ == "__main__":
parser = FlexibleArgumentParser()
parser.add_argument("--model",
type=str,
default="mistralai/Mixtral-8x7B-Instruct-v0.1")
parser.add_argument("--dtype",
type=str,
choices=["auto", "fp8_w8a8", "int8_w8a16"],
default="auto")
parser.add_argument("--use-customized-permute", action="store_true")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
main(args)

View File

@ -0,0 +1,133 @@
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include "permute_unpermute_kernels/moe_permute_unpermute_kernel.h"
#include "permute_unpermute_kernels/dispatch.h"
#include "core/registration.h"
void moe_permute(
const torch::Tensor& input, // [n_token, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& token_expert_indicies, // [n_token, topk]
const std::optional<torch::Tensor>& expert_map, // [n_expert]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
const std::optional<int64_t>& align_block_size,
torch::Tensor&
permuted_input, // [topk * n_token/align_block_size_m, hidden]
torch::Tensor& expert_first_token_offset, // [n_local_expert + 1]
torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
torch::Tensor& m_indices) { // [align_expand_m]
TORCH_CHECK(topk_weights.scalar_type() == at::ScalarType::Float,
"topk_weights must be float32");
TORCH_CHECK(expert_first_token_offset.scalar_type() == at::ScalarType::Long,
"expert_first_token_offset must be int64");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(token_expert_indicies.scalar_type() == at::ScalarType::Int,
"token_expert_indicies must be int32");
TORCH_CHECK(src_row_id2dst_row_id_map.scalar_type() == at::ScalarType::Int,
"src_row_id2dst_row_id_map must be int32");
TORCH_CHECK(expert_first_token_offset.size(0) == n_local_expert + 1,
"expert_first_token_offset shape != n_local_expert+1")
TORCH_CHECK(
src_row_id2dst_row_id_map.sizes() == token_expert_indicies.sizes(),
"token_expert_indicies shape must be same as src_row_id2dst_row_id_map");
auto n_token = input.sizes()[0];
auto n_hidden = input.sizes()[1];
auto align_block_size_value =
align_block_size.has_value() ? align_block_size.value() : -1;
auto stream = at::cuda::getCurrentCUDAStream().stream();
const long sorter_size =
CubKeyValueSorter::getWorkspaceSize(n_token * topk, n_expert);
auto sort_workspace = torch::empty(
{sorter_size},
torch::dtype(torch::kInt8).device(torch::kCUDA).requires_grad(false));
auto permuted_experts_id = torch::empty_like(topk_ids);
auto dst_row_id2src_row_id_map = torch::empty_like(src_row_id2dst_row_id_map);
auto align_expert_first_token_offset =
torch::zeros_like(expert_first_token_offset);
CubKeyValueSorter sorter{};
int64_t* valid_num_ptr = nullptr;
// Pre-process kernel for expert parallelism:
// non-local expert ids get an "n_expert" offset so that local experts take
// priority, and local expert ids [n, ..., n + n_local_expert - 1] are mapped
// to [0, n_local_expert - 1].
// For example, with 4 experts and ep_size=2, ep_rank=1 owns global expert
// ids [2, 3] and has expert_map = [-1, -1, 0, 1]. preprocessTopkId maps the
// global expert ids [2, 3] to local expert ids [0, 1], and maps the global
// expert ids [0, 1] (not owned by ep_rank=1) to [4, 5] by adding n_expert.
// This mapping gives local experts high priority in the following sort of
// topk_ids and in the scan of the local expert_first_token_offset for each
// ep rank before the next grouped gemm.
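// Concretely (illustrative values, following the example above): with
// expert_map = [-1, -1, 0, 1] and n_expert = 4, topk_ids [0, 1, 2, 3]
// become [4, 5, 0, 1], so local experts sort ahead of non-local ones.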
if (expert_map.has_value()) {
const int* expert_map_ptr = get_ptr<int>(expert_map.value());
valid_num_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
preprocessTopkIdLauncher(get_ptr<int>(topk_ids), n_token * topk,
expert_map_ptr, n_expert, stream);
}
// sort the topk expert ids and scan them to compute expert_first_token_offset
sortAndScanExpert(get_ptr<int>(topk_ids), get_ptr<int>(token_expert_indicies),
get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token,
n_expert, n_local_expert, topk, sorter,
get_ptr<int>(sort_workspace), stream);
// dispatch expandInputRowsKernelLauncher
MOE_DISPATCH(input.scalar_type(), [&] {
expandInputRowsKernelLauncher<scalar_t>(
get_ptr<scalar_t>(input), get_ptr<scalar_t>(permuted_input),
get_ptr<float>(topk_weights), get_ptr<int>(permuted_experts_id),
get_ptr<int>(dst_row_id2src_row_id_map),
get_ptr<int>(src_row_id2dst_row_id_map),
get_ptr<int64_t>(expert_first_token_offset), n_token, valid_num_ptr,
n_hidden, topk, n_local_expert, align_block_size_value, stream);
});
// get m_indices and update expert_first_token_offset with align block
getMIndices(get_ptr<int64_t>(expert_first_token_offset),
get_ptr<int64_t>(align_expert_first_token_offset),
get_ptr<int>(m_indices), n_local_expert, align_block_size_value,
stream);
if (align_block_size.has_value()) {
// update align_expert_first_token_offset
expert_first_token_offset.copy_(align_expert_first_token_offset);
}
}
void moe_unpermute(
const torch::Tensor& permuted_hidden_states, // [n_token * topk, hidden]
const torch::Tensor& topk_weights, //[n_token, topk]
const torch::Tensor& topk_ids, // [n_token, topk]
const torch::Tensor& src_row_id2dst_row_id_map, // [n_token, topk]
const torch::Tensor& expert_first_token_offset, // [n_local_expert+1]
int64_t n_expert, int64_t n_local_expert, int64_t topk,
torch::Tensor& hidden_states // [n_token, hidden]
) {
TORCH_CHECK(src_row_id2dst_row_id_map.sizes() == topk_ids.sizes(),
"topk_ids shape must be same as src_row_id2dst_row_id_map");
TORCH_CHECK(topk_ids.scalar_type() == at::ScalarType::Int,
"topk_ids must be int32");
TORCH_CHECK(
permuted_hidden_states.scalar_type() == hidden_states.scalar_type(),
"topk_ids dtype must be same as src_row_id2dst_row_id_map");
auto n_token = hidden_states.size(0);
auto n_hidden = hidden_states.size(1);
auto stream = at::cuda::getCurrentCUDAStream().stream();
const int64_t* valid_ptr =
get_ptr<int64_t>(expert_first_token_offset) + n_local_expert;
MOE_DISPATCH(hidden_states.scalar_type(), [&] {
finalizeMoeRoutingKernelLauncher<scalar_t, scalar_t>(
get_ptr<scalar_t>(permuted_hidden_states),
get_ptr<scalar_t>(hidden_states), get_ptr<float>(topk_weights),
get_ptr<int>(src_row_id2dst_row_id_map), get_ptr<int>(topk_ids),
n_token, n_hidden, topk, valid_ptr, stream);
});
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("moe_permute", &moe_permute);
m.impl("moe_unpermute", &moe_unpermute);
}

View File

@ -0,0 +1,53 @@
#pragma once
#include <cuda_fp8.h>
#define MOE_SWITCH(TYPE, ...) \
at::ScalarType _st = ::detail::scalar_type(TYPE); \
switch (_st) { \
__VA_ARGS__ \
default: \
TORCH_CHECK(false, "[moe permute]data type dispatch fail!") \
}
#define MOE_DISPATCH_CASE(enum_type, ...) \
case enum_type: { \
using scalar_t = ScalarType2CudaType<enum_type>::type; \
__VA_ARGS__(); \
break; \
}
#define MOE_DISPATCH_FLOAT_CASE(...) \
MOE_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e5m2, __VA_ARGS__) \
MOE_DISPATCH_CASE(at::ScalarType::Float8_e4m3fn, __VA_ARGS__)
#define MOE_DISPATCH(TYPE, ...) \
MOE_SWITCH(TYPE, MOE_DISPATCH_FLOAT_CASE(__VA_ARGS__))
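// Illustrative usage (as in the moe_permute op above):
//   MOE_DISPATCH(input.scalar_type(), [&] {
//     expandInputRowsKernelLauncher<scalar_t>(/* ... */);
//   });
// Inside the lambda, scalar_t is bound to the matching CUDA type via
// ScalarType2CudaType below.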
template <at::ScalarType type>
struct ScalarType2CudaType;
template <>
struct ScalarType2CudaType<at::ScalarType::Float> {
using type = float;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Half> {
using type = half;
};
template <>
struct ScalarType2CudaType<at::ScalarType::BFloat16> {
using type = __nv_bfloat16;
};
// #if __CUDA_ARCH__ >= 890
// fp8
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e5m2> {
using type = __nv_fp8_e5m2;
};
template <>
struct ScalarType2CudaType<at::ScalarType::Float8_e4m3fn> {
using type = __nv_fp8_e4m3;
};
// #endif

View File

@ -0,0 +1,229 @@
#include "moe_permute_unpermute_kernel.h"
// CubKeyValueSorter definition begin
CubKeyValueSorter::CubKeyValueSorter()
: num_experts_(0), num_bits_(sizeof(int) * 8) {}
int CubKeyValueSorter::expertsToBits(int num_experts) {
// The max value we represent is V = num_experts + (num_experts - 1) =
// 2 * num_experts - 1, so the maximum number of bits is floor(log2(V)) + 1.
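// For example, num_experts = 4 gives V = 7 and floor(log2(7)) + 1 = 3 bits,
// matching the (num_key_value_pairs, num_experts, num_bits) = (64, 4, 3)
// case noted in getWorkspaceSize below.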
return static_cast<int>(log2(2 * num_experts - 1)) + 1;
}
CubKeyValueSorter::CubKeyValueSorter(int const num_experts)
: num_experts_(num_experts), num_bits_(expertsToBits(num_experts)) {}
void CubKeyValueSorter::updateNumExperts(int const num_experts) {
num_experts_ = num_experts;
num_bits_ = expertsToBits(num_experts);
}
size_t CubKeyValueSorter::getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts) {
int num_bits = expertsToBits(num_experts);
size_t required_storage = 0;
int* null_int = nullptr;
cub::DeviceRadixSort::SortPairs(nullptr, required_storage, null_int, null_int,
null_int, null_int, num_key_value_pairs, 0,
num_bits);
// When (num_key_value_pairs, num_experts, num_bits) = (64, 4, 3), the
// required_storage reported by cub seems to vary between 0 and 1 for the
// same inputs, so guard against a zero-sized allocation.
if (required_storage == 0) {
required_storage = 1;
}
return required_storage;
}
void CubKeyValueSorter::run(void* workspace, size_t const workspace_size,
int const* keys_in, int* keys_out,
int const* values_in, int* values_out,
size_t const num_key_value_pairs,
cudaStream_t stream) {
size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs, num_experts_);
size_t actual_ws_size = workspace_size;
TORCH_CHECK(expected_ws_size <= workspace_size,
"[CubKeyValueSorter::run] The allocated workspace is too small "
"to run this problem.");
cub::DeviceRadixSort::SortPairs(workspace, actual_ws_size, keys_in, keys_out,
values_in, values_out, num_key_value_pairs, 0,
num_bits_, stream);
}
// CubKeyValueSorter definition end
static inline size_t pad_to_multiple_of_16(size_t const& input) {
static constexpr int ALIGNMENT = 16;
return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}
template <class T>
__device__ inline int64_t findTotalEltsLessThanTarget(T const* sorted_indices,
int64_t const arr_length,
T const target) {
int64_t low = 0, high = arr_length - 1, target_location = -1;
while (low <= high) {
int64_t mid = (low + high) / 2;
if (sorted_indices[mid] >= target) {
high = mid - 1;
} else {
low = mid + 1;
target_location = mid;
}
}
return target_location + 1;
}
// Calculates the start offset of the tokens for a given expert. The last
// element is the total number of valid tokens
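// For instance (illustrative), sorted_experts = [0, 0, 1, 2, 2, 2] with
// num_experts = 3 yields expert_first_token_offset = [0, 2, 3, 6].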
__global__ void computeExpertFirstTokenOffsetKernel(
int const* sorted_experts, int64_t const sorted_experts_len,
int const num_experts, int64_t* expert_first_token_offset) {
// First, compute the global tid. We only need 1 thread per expert.
int const expert = blockIdx.x * blockDim.x + threadIdx.x;
// Note that expert goes [0, num_experts] (inclusive) because we want a count
// for the total number of active tokens at the end of the scan.
if (expert >= num_experts + 1) {
return;
}
expert_first_token_offset[expert] =
findTotalEltsLessThanTarget(sorted_experts, sorted_experts_len, expert);
}
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream) {
int const num_entries = num_experts + 1;
int const threads = std::min(1024, num_entries);
int const blocks = (num_entries + threads - 1) / threads;
computeExpertFirstTokenOffsetKernel<<<blocks, threads, 0, stream>>>(
sorted_indices, total_indices, num_experts, expert_first_token_offset);
}
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream) {
int64_t const expanded_num_rows = static_cast<int64_t>(k) * num_rows;
// We need to use the full num_experts because that is the sentinel value used
// by topk for disabled experts
sorter.updateNumExperts(num_experts);
size_t const sorter_ws_size_bytes = pad_to_multiple_of_16(
sorter.getWorkspaceSize(expanded_num_rows, num_experts));
sorter.run((void*)sorter_ws, sorter_ws_size_bytes, expert_for_source_row,
permuted_experts, source_rows, permuted_rows, expanded_num_rows,
stream);
computeExpertFirstTokenOffset(permuted_experts, expanded_num_rows,
num_experts_per_node, expert_first_token_offset,
stream);
}
__global__ void preprocessTopkIdKernel(int* topk_id_ptr, int size,
const int* expert_map_ptr,
int num_experts) {
auto tidx = threadIdx.x;
auto bidx = blockIdx.x;
auto lidx = tidx & 31;
auto widx = tidx >> 5;
auto warp_count = (blockDim.x + 31) >> 5;
auto offset = bidx * blockDim.x;
auto bound = min(offset + blockDim.x, size);
extern __shared__ int smem_expert_map[];
// store expert_map in smem
for (int i = tidx; i < num_experts; i += blockDim.x) {
smem_expert_map[i] = expert_map_ptr[i];
}
__syncthreads();
// Query the global expert id in the expert map.
// If the global expert id maps to -1 in the expert map, add n_expert;
// otherwise set the global expert id to expert_map[global expert id].
if (offset + tidx < bound) {
auto topk_id = topk_id_ptr[offset + tidx];
auto local_expert_idx = smem_expert_map[topk_id];
if (local_expert_idx == -1) {
topk_id += num_experts;
} else {
topk_id = local_expert_idx;
}
__syncwarp();
topk_id_ptr[offset + tidx] = topk_id;
}
}
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream) {
int block = std::min(size, 1024);
int grid = (size + block - 1) / block;
int smem_size = (num_experts) * sizeof(int);
preprocessTopkIdKernel<<<grid, block, smem_size, stream>>>(
topk_id_ptr, size, expert_map_ptr, num_experts);
}
template <bool ALIGN_BLOCK_SIZE>
__global__ void getMIndicesKernel(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset,
int* m_indices, const int num_local_expert,
const int align_block_size) {
int eidx = blockIdx.x;
int tidx = threadIdx.x;
extern __shared__ int64_t smem_expert_first_token_offset[];
for (int i = tidx; i <= num_local_expert; i += blockDim.x) {
smem_expert_first_token_offset[i] = __ldg(expert_first_token_offset + i);
}
__syncthreads();
auto last_token_offset = smem_expert_first_token_offset[eidx + 1];
auto first_token_offset = smem_expert_first_token_offset[eidx];
int n_token_in_expert = last_token_offset - first_token_offset;
if constexpr (ALIGN_BLOCK_SIZE) {
n_token_in_expert = (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
// round up to ALIGN_BLOCK_SIZE
int64_t accumulate_align_offset = 0;
for (int i = 1; i <= eidx + 1; i++) {
int n_token = smem_expert_first_token_offset[i] -
smem_expert_first_token_offset[i - 1];
accumulate_align_offset =
accumulate_align_offset + (n_token + align_block_size - 1) /
align_block_size * align_block_size;
if (i == eidx) {
first_token_offset = accumulate_align_offset;
}
// last block store align_expert_first_token_offset
if (eidx == num_local_expert - 1 && threadIdx.x == 0) {
align_expert_first_token_offset[i] = accumulate_align_offset;
}
}
}
for (int idx = tidx; idx < n_token_in_expert; idx += blockDim.x) {
// update m_indice with expert id
m_indices[first_token_offset + idx] = eidx;
}
}
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream) {
int block = 256;
int grid = num_local_expert;
int smem_size = sizeof(int64_t) * (num_local_expert + 1);
if (align_block_size == -1) {
getMIndicesKernel<false><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
} else {
getMIndicesKernel<true><<<grid, block, smem_size, stream>>>(
expert_first_token_offset, align_expert_first_token_offset, m_indices,
num_local_expert, align_block_size);
}
}

View File

@ -0,0 +1,95 @@
#pragma once
// reference from tensorrt_llm moe kernel implementation archive in
// https://github.com/BBuf/tensorrt-llm-moe/tree/master
#include <c10/core/ScalarType.h>
#include <torch/all.h>
#include "dispatch.h"
#include <cub/cub.cuh>
#include <cub/device/device_radix_sort.cuh>
#include <cub/util_type.cuh>
#include "cutlass/numeric_size.h"
#include "cutlass/array.h"
template <typename T>
inline T* get_ptr(torch::Tensor& t) {
return reinterpret_cast<T*>(t.data_ptr());
}
template <typename T>
inline const T* get_ptr(const torch::Tensor& t) {
return reinterpret_cast<const T*>(t.data_ptr());
}
class CubKeyValueSorter {
public:
CubKeyValueSorter();
CubKeyValueSorter(int const num_experts);
void updateNumExperts(int const num_experts);
static size_t getWorkspaceSize(size_t const num_key_value_pairs,
int const num_experts);
void run(void* workspace, size_t const workspace_size, int const* keys_in,
int* keys_out, int const* values_in, int* values_out,
size_t const num_key_value_pairs, cudaStream_t stream);
private:
static int expertsToBits(int experts);
int num_experts_;
int num_bits_;
};
void computeExpertFirstTokenOffset(int const* sorted_indices,
int const total_indices,
int const num_experts,
int64_t* expert_first_token_offset,
cudaStream_t stream);
void sortAndScanExpert(int* expert_for_source_row, const int* source_rows,
int* permuted_experts, int* permuted_rows,
int64_t* expert_first_token_offset, int num_rows,
int num_experts, int num_experts_per_node, int k,
CubKeyValueSorter& sorter, void* sorter_ws,
cudaStream_t stream);
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream);
// Final kernel to unpermute and scale
// This kernel unpermutes the original data, does the k-way reduction and
// performs the final skip connection.
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr);
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream);
void preprocessTopkIdLauncher(int* topk_id_ptr, int size,
const int* expert_map_ptr, int num_experts,
cudaStream_t stream);
void getMIndices(int64_t* expert_first_token_offset,
int64_t* align_expert_first_token_offset, int* m_indices,
int num_local_expert, const int align_block_size,
cudaStream_t stream);
#include "moe_permute_unpermute_kernel.inl"

View File

@ -0,0 +1,211 @@
#pragma once
template <typename T, bool CHECK_SKIPPED, bool ALIGN_BLOCK_SIZE>
__global__ void expandInputRowsKernel(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_dest_rows, int64_t const cols, int64_t k,
int num_local_experts, int align_block_size) {
// Reverse permutation map.
// I do this so that later, we can use the source -> dest map to do the k-way
// reduction and unpermuting. I need the reverse map for that reduction to
// allow each threadblock to do 1 k-way reduce without atomics later in MoE. 1
// thread block will be responsible for all k summations.
int64_t expanded_dest_row = blockIdx.x;
int64_t const expanded_source_row =
expanded_dest_row_to_expanded_source_row[expanded_dest_row];
int expert_id = sorted_experts[expanded_dest_row];
extern __shared__ int64_t smem_expert_first_token_offset[];
int64_t align_expanded_row_accumulate = 0;
if constexpr (ALIGN_BLOCK_SIZE) {
// load g2s
for (int idx = threadIdx.x; idx < num_local_experts + 1;
idx += blockDim.x) {
smem_expert_first_token_offset[idx] =
__ldg(expert_first_token_offset + idx);
}
__syncthreads();
int lane_idx = threadIdx.x & 31;
if (lane_idx == 0) {
// set token_offset_in_expert = 0 if this expert is not local expert
int token_offset_in_expert =
expert_id >= num_local_experts
? 0
: expanded_dest_row - smem_expert_first_token_offset[expert_id];
int64_t accumulate_align_offset = 0;
#pragma unroll 1
for (int eidx = 1; eidx <= min(expert_id, num_local_experts); eidx++) {
auto n_token_in_expert = smem_expert_first_token_offset[eidx] -
smem_expert_first_token_offset[eidx - 1];
accumulate_align_offset += (n_token_in_expert + align_block_size - 1) /
align_block_size * align_block_size;
}
expanded_dest_row = accumulate_align_offset + token_offset_in_expert;
}
// lane0 shuffle broadcast align_expanded_dest_row
expanded_dest_row = __shfl_sync(0xffffffff, expanded_dest_row, 0);
}
if (threadIdx.x == 0) {
assert(expanded_dest_row <= INT32_MAX);
expanded_source_row_to_expanded_dest_row[expanded_source_row] =
static_cast<int>(expanded_dest_row);
}
if (!CHECK_SKIPPED || blockIdx.x < *num_dest_rows) {
// Load 128-bits per thread
constexpr int64_t ELEM_PER_THREAD = 128 / cutlass::sizeof_bits<T>::value;
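// e.g. 8 elements per thread for 16-bit types (half/bf16) and 16 elements
// per thread for 8-bit fp8 types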
using DataElem = cutlass::Array<T, ELEM_PER_THREAD>;
// Duplicate and permute rows
int64_t const source_k_rank = expanded_source_row / num_rows;
int64_t const source_row = expanded_source_row % num_rows;
auto const* source_row_ptr =
reinterpret_cast<DataElem const*>(unpermuted_input + source_row * cols);
auto* dest_row_ptr =
reinterpret_cast<DataElem*>(permuted_output + expanded_dest_row * cols);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = cols / ELEM_PER_THREAD;
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
dest_row_ptr[elem_index] = source_row_ptr[elem_index];
}
}
}
template <typename T>
void expandInputRowsKernelLauncher(
T const* unpermuted_input, T* permuted_output,
const float* unpermuted_scales, int* sorted_experts,
int const* expanded_dest_row_to_expanded_source_row,
int* expanded_source_row_to_expanded_dest_row,
int64_t* expert_first_token_offset, int64_t const num_rows,
int64_t const* num_valid_tokens_ptr, int64_t const cols, int const k,
int num_local_experts, const int& align_block_size, cudaStream_t stream) {
int64_t const blocks = num_rows * k;
int64_t const threads = 256;
using FuncPtr = decltype(&expandInputRowsKernel<T, true, true>);
FuncPtr func_map[2][2] = {
{&expandInputRowsKernel<T, false, false>,
&expandInputRowsKernel<T, false, true>},
{&expandInputRowsKernel<T, true, false>,
&expandInputRowsKernel<T, true, true>},
};
bool is_check_skip = num_valid_tokens_ptr != nullptr;
bool is_align_block_size = align_block_size != -1;
auto func = func_map[is_check_skip][is_align_block_size];
int64_t smem_size = sizeof(int64_t) * (num_local_experts + 1);
func<<<blocks, threads, smem_size, stream>>>(
unpermuted_input, permuted_output, unpermuted_scales, sorted_experts,
expanded_dest_row_to_expanded_source_row,
expanded_source_row_to_expanded_dest_row, expert_first_token_offset,
num_rows, num_valid_tokens_ptr, cols, k, num_local_experts,
align_block_size);
}
template <class T, class U>
__host__ __device__ constexpr static U arrayConvert(T const& input) {
using Type = typename U::Element;
static_assert(T::kElements == U::kElements);
U u;
#pragma unroll
for (int i = 0; i < U::kElements; i++) {
u[i] = static_cast<Type>(input[i]);
}
return u;
}
template <typename T, typename OutputType, bool CHECK_SKIPPED>
__global__ void finalizeMoeRoutingKernel(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const orig_cols, int64_t const k,
int64_t const* num_valid_ptr) {
assert(orig_cols % 4 == 0);
int64_t const original_row = blockIdx.x;
int64_t const num_rows = gridDim.x;
auto const offset = original_row * orig_cols;
OutputType* reduced_row_ptr = reduced_unpermuted_output + offset;
int64_t const num_valid = *num_valid_ptr;
// Load 128-bits per thread, according to the smallest data type we read/write
constexpr int64_t FINALIZE_ELEM_PER_THREAD =
128 / std::min(cutlass::sizeof_bits<OutputType>::value,
cutlass::sizeof_bits<T>::value);
int64_t const start_offset = threadIdx.x;
int64_t const stride = blockDim.x;
int64_t const num_elems_in_col = orig_cols / FINALIZE_ELEM_PER_THREAD;
using InputElem = cutlass::Array<T, FINALIZE_ELEM_PER_THREAD>;
using OutputElem = cutlass::Array<OutputType, FINALIZE_ELEM_PER_THREAD>;
using ComputeElem = cutlass::Array<float, FINALIZE_ELEM_PER_THREAD>;
auto const* expanded_permuted_rows_v =
reinterpret_cast<InputElem const*>(expanded_permuted_rows);
auto* reduced_row_ptr_v = reinterpret_cast<OutputElem*>(reduced_row_ptr);
#pragma unroll
for (int elem_index = start_offset; elem_index < num_elems_in_col;
elem_index += stride) {
ComputeElem thread_output;
thread_output.fill(0);
float row_rescale{0.f};
for (int k_idx = 0; k_idx < k; ++k_idx) {
int64_t const expanded_original_row = original_row + k_idx * num_rows;
int64_t const expanded_permuted_row =
expanded_source_row_to_expanded_dest_row[expanded_original_row];
int64_t const k_offset = original_row * k + k_idx;
float const row_scale = scales[k_offset];
// Check after row_rescale has accumulated
if (CHECK_SKIPPED && expanded_permuted_row >= num_valid) {
continue;
}
auto const* expanded_permuted_rows_row_ptr =
expanded_permuted_rows_v + expanded_permuted_row * num_elems_in_col;
int64_t const expert_idx = expert_for_source_row[k_offset];
ComputeElem expert_result = arrayConvert<InputElem, ComputeElem>(
expanded_permuted_rows_row_ptr[elem_index]);
thread_output = thread_output + row_scale * (expert_result);
}
OutputElem output_elem =
arrayConvert<ComputeElem, OutputElem>(thread_output);
reduced_row_ptr_v[elem_index] = output_elem;
}
}
template <class T, class OutputType>
void finalizeMoeRoutingKernelLauncher(
T const* expanded_permuted_rows, OutputType* reduced_unpermuted_output,
float const* scales, int const* expanded_source_row_to_expanded_dest_row,
int const* expert_for_source_row, int64_t const num_rows,
int64_t const cols, int64_t const k, int64_t const* num_valid_ptr,
cudaStream_t stream) {
int64_t const blocks = num_rows;
int64_t const threads = 256;
bool const check_finished = num_valid_ptr != nullptr;
using FuncPtr = decltype(&finalizeMoeRoutingKernel<T, OutputType, false>);
FuncPtr func_map[2] = {&finalizeMoeRoutingKernel<T, OutputType, false>,
&finalizeMoeRoutingKernel<T, OutputType, true>};
auto* const kernel = func_map[check_finished];
kernel<<<blocks, threads, 0, stream>>>(
expanded_permuted_rows, reduced_unpermuted_output, scales,
expanded_source_row_to_expanded_dest_row, expert_for_source_row, cols, k,
num_valid_ptr);
}

View File

@ -53,7 +53,29 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
"int size_m, int size_n, int size_k,"
"bool is_full_k, bool use_atomic_add,"
"bool use_fp32_reduce, bool is_zp_float) -> Tensor");
m.def(
"marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, "
"Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! "
"b_zeros, Tensor! g_idx, Tensor! perm, Tensor! workspace, "
"int b_q_type, SymInt size_m, "
"SymInt size_n, SymInt size_k, bool is_k_full, int num_experts, int "
"topk, "
"int moe_block_size, bool replicate_input, bool apply_weights)"
" -> Tensor");
m.def(
"moe_permute(Tensor input, Tensor topk_weight, Tensor! topk_ids,"
"Tensor token_expert_indicies, Tensor? expert_map, int n_expert,"
"int n_local_expert,"
"int topk, int? align_block_size,Tensor! permuted_input, Tensor! "
"expert_first_token_offset, Tensor! src_row_id2dst_row_id_map, Tensor! "
"m_indices)->()");
m.def(
"moe_unpermute(Tensor permuted_hidden_states, Tensor topk_weights,"
"Tensor topk_ids,Tensor src_row_id2dst_row_id_map, Tensor "
"expert_first_token_offset, int n_expert, int n_local_expert,int "
"topk, Tensor! hidden_states)->()");
// conditionally compiled so impl registration is in source file
#endif

View File

@ -58,6 +58,12 @@ Therefore, we recommend developing with Python 3.12 to minimise the chance of yo
Currently, the repository is not fully checked by `mypy`.
:::
:::{note}
Currently, not all unit tests pass when run on CPU platforms. If you don't have access to a GPU
platform to run unit tests locally, rely on the continuous integration system to run the tests for
now.
:::
## Issues
If you encounter a bug or have a feature request, please [search existing issues](https://github.com/vllm-project/vllm/issues?q=is%3Aissue) first to see if it has already been reported. If not, please [file a new issue](https://github.com/vllm-project/vllm/issues/new/choose), providing as much relevant information as possible.

View File

@ -30,6 +30,7 @@ from vllm import LLM
model = LLM("facebook/opt-125m", quantization="fp8")
# INFO 06-10 17:55:42 model_runner.py:157] Loading model weights took 0.1550 GB
result = model.generate("Hello, my name is")
print(result[0].outputs[0].text)
```
:::{warning}
@ -105,7 +106,8 @@ Load and run the model in `vllm`:
```python
from vllm import LLM
model = LLM("./Meta-Llama-3-8B-Instruct-FP8-Dynamic")
model.generate("Hello my name is")
result = model.generate("Hello my name is")
print(result[0].outputs[0].text)
```
Evaluate accuracy with `lm_eval` (for example on 250 samples of `gsm8k`):
@ -188,4 +190,5 @@ from vllm import LLM
model = LLM(model="Meta-Llama-3-8B-Instruct-FP8/")
# INFO 06-10 21:15:41 model_runner.py:159] Loading model weights took 8.4596 GB
result = model.generate("Hello, my name is")
print(result[0].outputs[0].text)
```

View File

@ -17,6 +17,7 @@ gptqmodel
int4
int8
fp8
modelopt
quark
quantized_kvcache
torchao

View File

@ -0,0 +1,78 @@
# NVIDIA TensorRT Model Optimizer
The [NVIDIA TensorRT Model Optimizer](https://github.com/NVIDIA/TensorRT-Model-Optimizer) is a library designed to optimize models for inference with NVIDIA GPUs. It includes tools for Post-Training Quantization (PTQ) and Quantization Aware Training (QAT) of Large Language Models (LLMs), Vision Language Models (VLMs), and diffusion models.
We recommend installing the library with:
```console
pip install nvidia-modelopt
```
## Quantizing HuggingFace Models with PTQ
You can quantize HuggingFace models using the example scripts provided in the TensorRT Model Optimizer repository. The primary script for LLM PTQ is typically found within the `examples/llm_ptq` directory.
Below is an example showing how to quantize a model using modelopt's PTQ API:
```python
import modelopt.torch.quantization as mtq
from transformers import AutoModelForCausalLM
# Load the model from HuggingFace
model = AutoModelForCausalLM.from_pretrained("<path_or_model_id>")
# Select the quantization config, for example, FP8
config = mtq.FP8_DEFAULT_CFG
# Define a forward loop function for calibration
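# (`calib_set` here stands for a user-provided iterable of calibration inputs)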
def forward_loop(model):
for data in calib_set:
model(data)
# PTQ with in-place replacement of quantized modules
model = mtq.quantize(model, config, forward_loop)
```
After the model is quantized, you can export it to a quantized checkpoint using the export API:
```python
import torch
from modelopt.torch.export import export_hf_checkpoint
with torch.inference_mode():
export_hf_checkpoint(
model, # The quantized model.
export_dir, # The directory where the exported files will be stored.
)
```
The quantized checkpoint can then be deployed with vLLM. As an example, the following code shows how to deploy `nvidia/Llama-3.1-8B-Instruct-FP8`, which is the FP8 quantized checkpoint derived from `meta-llama/Llama-3.1-8B-Instruct`, using vLLM:
```python
from vllm import LLM, SamplingParams
def main():
model_id = "nvidia/Llama-3.1-8B-Instruct-FP8"
# Ensure you specify quantization='modelopt' when loading the modelopt checkpoint
llm = LLM(model=model_id, quantization="modelopt", trust_remote_code=True)
sampling_params = SamplingParams(temperature=0.8, top_p=0.9)
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
outputs = llm.generate(prompts, sampling_params)
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
if __name__ == "__main__":
main()
```

View File

@ -129,7 +129,17 @@ The table below shows the compatibility of various quantization implementations
*
*
*
- * modelopt
* ✅︎
* ✅︎
* ✅︎
* ✅︎
* ✅︎
*
*
*
*
*
:::
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.

View File

@ -47,8 +47,7 @@ def get_mixed_modalities_query() -> QueryResult:
"image":
ImageAsset("cherry_blossom").pil_image.convert("RGB"),
"video":
VideoAsset(name="sample_demo_1.mp4",
num_frames=16).np_ndarrays,
VideoAsset(name="baby_reading", num_frames=16).np_ndarrays,
},
},
limit_mm_per_prompt={
@ -66,7 +65,7 @@ def get_use_audio_in_video_query() -> QueryResult:
"<|im_start|>user\n<|vision_bos|><|VIDEO|><|vision_eos|>"
f"{question}<|im_end|>\n"
f"<|im_start|>assistant\n")
asset = VideoAsset(name="sample_demo_1.mp4", num_frames=16)
asset = VideoAsset(name="baby_reading", num_frames=16)
audio = asset.get_audio(sampling_rate=16000)
assert not envs.VLLM_USE_V1, ("V1 does not support use_audio_in_video. "
"Please launch this example with "

View File

@ -1109,7 +1109,7 @@ def get_multi_modal_input(args):
if args.modality == "video":
# Input video and question
video = VideoAsset(name="sample_demo_1.mp4",
video = VideoAsset(name="baby_reading",
num_frames=args.num_frames).np_ndarrays
vid_questions = ["Why is this video funny?"]

View File

@ -2,5 +2,7 @@
-r common.txt
# Dependencies for Neuron devices
packaging>=24.2
setuptools>=77.0.3,<80.0.0
torch-neuronx >= 2.5.0
neuronx-cc

View File

@ -23,5 +23,11 @@ runai-model-streamer-s3==0.11.0
tensorizer>=2.9.0
lm-eval==0.4.8
buildkite-test-collector==0.1.9
lm-eval[api]==0.4.8 # required for model evaluation test
# required for quantization test
bitsandbytes>=0.45.3
# required for minicpmo_26 test
vector_quantize_pytorch
vocos

View File

@ -1,5 +1,5 @@
# This file was autogenerated by uv via the following command:
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match
# uv pip compile requirements/test.in -o requirements/test.txt --index-strategy unsafe-best-match --torch-backend cu128
absl-py==2.1.0
# via rouge-score
accelerate==1.0.1
@ -349,28 +349,28 @@ numpy==1.26.4
# transformers
# tritonclient
# vocos
nvidia-cublas-cu12==12.6.4.1
nvidia-cublas-cu12==12.8.3.14
# via
# nvidia-cudnn-cu12
# nvidia-cusolver-cu12
# torch
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-cupti-cu12==12.8.57
# via torch
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-nvrtc-cu12==12.8.61
# via torch
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.8.57
# via torch
nvidia-cudnn-cu12==9.5.1.17
nvidia-cudnn-cu12==9.7.1.26
# via torch
nvidia-cufft-cu12==11.3.0.4
nvidia-cufft-cu12==11.3.3.41
# via torch
nvidia-cufile-cu12==1.11.1.6
nvidia-cufile-cu12==1.13.0.11
# via torch
nvidia-curand-cu12==10.3.7.77
nvidia-curand-cu12==10.3.9.55
# via torch
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusolver-cu12==11.7.2.55
# via torch
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparse-cu12==12.5.7.53
# via
# nvidia-cusolver-cu12
# torch
@ -378,13 +378,13 @@ nvidia-cusparselt-cu12==0.6.3
# via torch
nvidia-nccl-cu12==2.26.2
# via torch
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvjitlink-cu12==12.8.61
# via
# nvidia-cufft-cu12
# nvidia-cusolver-cu12
# nvidia-cusparse-cu12
# torch
nvidia-nvtx-cu12==12.6.77
nvidia-nvtx-cu12==12.8.55
# via torch
opencv-python-headless==4.11.0.86
# via
@ -687,7 +687,7 @@ tomli==2.2.1
# via schemathesis
tomli-w==1.2.0
# via schemathesis
torch==2.7.0
torch==2.7.0+cu128
# via
# -r requirements/test.in
# accelerate
@ -705,12 +705,12 @@ torch==2.7.0
# torchvision
# vector-quantize-pytorch
# vocos
torchaudio==2.7.0
torchaudio==2.7.0+cu128
# via
# -r requirements/test.in
# encodec
# vocos
torchvision==0.22.0
torchvision==0.22.0+cu128
# via
# -r requirements/test.in
# timm

View File

@ -103,7 +103,8 @@ def test_compile_correctness(
method = test_setting.method
fullgraph = test_setting.fullgraph
if cuda_device_count_stateless() != pp_size * tp_size:
pytest.skip("Not correct CUDA devices for the test.")
pytest.skip(f"Need exactly {pp_size}*{tp_size} CUDA gpus but got "
f"{cuda_device_count_stateless()}")
with monkeypatch.context() as m:
m.setenv("VLLM_ATTENTION_BACKEND", attn_backend)

View File

@ -1,9 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import json
import os
import tempfile
from collections import UserList
from enum import Enum
from typing import Any, Callable, Optional, TypedDict, TypeVar, Union
@ -58,16 +56,12 @@ def _read_prompts(filename: str) -> list[str]:
return prompts
class _ImageAssetPrompts(TypedDict):
class ImageAssetPrompts(TypedDict):
stop_sign: str
cherry_blossom: str
class _ImageAssetsBase(UserList[ImageAsset]):
pass
class _ImageAssets(_ImageAssetsBase):
class ImageTestAssets(list[ImageAsset]):
def __init__(self) -> None:
super().__init__([
@ -75,7 +69,7 @@ class _ImageAssets(_ImageAssetsBase):
ImageAsset("cherry_blossom"),
])
def prompts(self, prompts: _ImageAssetPrompts) -> list[str]:
def prompts(self, prompts: ImageAssetPrompts) -> list[str]:
"""
Convenience method to define the prompt for each test image.
@ -85,30 +79,27 @@ class _ImageAssets(_ImageAssetsBase):
return [prompts["stop_sign"], prompts["cherry_blossom"]]
class _VideoAssetPrompts(TypedDict):
sample_demo_1: str
class VideoAssetPrompts(TypedDict):
baby_reading: str
class _VideoAssetsBase(UserList[VideoAsset]):
pass
class _VideoAssets(_VideoAssetsBase):
class VideoTestAssets(list[VideoAsset]):
def __init__(self) -> None:
super().__init__([
VideoAsset("sample_demo_1.mp4"),
VideoAsset("baby_reading"),
])
def prompts(self, prompts: _VideoAssetPrompts) -> list[str]:
return [prompts["sample_demo_1"]]
def prompts(self, prompts: VideoAssetPrompts) -> list[str]:
return [prompts["baby_reading"]]
class _AudioAssetsBase(UserList[AudioAsset]):
pass
class AudioAssetPrompts(TypedDict):
mary_had_lamb: str
winning_call: str
class _AudioAssets(_AudioAssetsBase):
class AudioTestAssets(list[AudioAsset]):
def __init__(self) -> None:
super().__init__([
@ -116,13 +107,16 @@ class _AudioAssets(_AudioAssetsBase):
AudioAsset("winning_call"),
])
def prompts(self, prompts: AudioAssetPrompts) -> list[str]:
return [prompts["mary_had_lamb"], prompts["winning_call"]]
IMAGE_ASSETS = _ImageAssets()
"""Singleton instance of :class:`_ImageAssets`."""
VIDEO_ASSETS = _VideoAssets()
"""Singleton instance of :class:`_VideoAssets`."""
AUDIO_ASSETS = _AudioAssets()
"""Singleton instance of :class:`_AudioAssets`."""
IMAGE_ASSETS = ImageTestAssets()
"""Singleton instance of :class:`ImageTestAssets`."""
VIDEO_ASSETS = VideoTestAssets()
"""Singleton instance of :class:`VideoTestAssets`."""
AUDIO_ASSETS = AudioTestAssets()
"""Singleton instance of :class:`AudioTestAssets`."""
@pytest.fixture(scope="function", autouse=True)
@ -270,17 +264,17 @@ def example_long_prompts() -> list[str]:
@pytest.fixture(scope="session")
def image_assets() -> _ImageAssets:
def image_assets() -> ImageTestAssets:
return IMAGE_ASSETS
@pytest.fixture(scope="session")
def video_assets() -> _VideoAssets:
def video_assets() -> VideoTestAssets:
return VIDEO_ASSETS
@pytest.fixture(scope="session")
def audio_assets() -> _AudioAssets:
def audio_assets() -> AudioTestAssets:
return AUDIO_ASSETS
@ -779,7 +773,7 @@ class VllmRunner:
def get_inputs(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
@ -801,16 +795,18 @@ class VllmRunner:
if audios is not None and (audio := audios[i]) is not None:
multi_modal_data["audio"] = audio
inputs.append(
TextPrompt(prompt=prompt,
multi_modal_data=multi_modal_data
if multi_modal_data else None))
text_prompt_kwargs = {
("prompt" if isinstance(prompt, str) else "prompt_embeds"):
prompt,
"multi_modal_data": multi_modal_data or None
}
inputs.append(TextPrompt(**text_prompt_kwargs))
return inputs
def generate(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
@ -836,7 +832,7 @@ class VllmRunner:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
@ -903,7 +899,7 @@ class VllmRunner:
def generate_greedy(
self,
prompts: list[str],
prompts: Union[list[str], list[torch.Tensor]],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,

View File

@ -2,16 +2,18 @@
import time
from collections import deque
from typing import Optional
from unittest.mock import MagicMock
import pytest # noqa
import torch
from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig
from vllm.core.interfaces import AllocStatus
from vllm.core.scheduler import Scheduler, SchedulingBudget
from vllm.lora.request import LoRARequest
from vllm.sequence import SequenceGroup
from vllm.sequence import SequenceGroup, SequenceStatus
from .utils import (append_new_token, append_new_token_seq,
append_new_token_seq_group, create_dummy_prompt,
@ -968,3 +970,73 @@ def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching(
), "A partial prefix of C (4 tokens) should be prefilled, with the "
"remaining tokens fit into 3 token budget (4-1 from the seqA). It will "
"then be rounded down to 2 tokens on block size, thus 6 tokens in total."
def test_no_batches_mixed_with_prompt_tokens_and_prompt_embeds():
"""
Test that the scheduler does not schedule batches with prompt tokens and
prompt embeddings co-mingled.
"""
block_size = 2
max_seq_group = 3
scheduler = initialize_scheduler(
block_size=block_size,
num_cpu_blocks=16,
num_gpu_blocks=16,
max_num_seqs=max_seq_group,
max_model_len=100,
enable_prefix_caching=True,
)
# the even-indexed inputs are passed in via embeddings,
# the odd-indexed ones via token_ids
seq_length = 7
embedding_size = 5
num_seqs = 11
seq_tokens: list[list[int]] = []
seq_embeds: list[Optional[torch.Tensor]] = []
for i in range(num_seqs):
if i % 2:
seq_tokens.append(list(range(seq_length)))
seq_embeds.append(None)
else:
seq_tokens.append([0] * seq_length)
seq_embeds.append(torch.rand(embedding_size))
seq_and_seq_groups = [
create_dummy_prompt(f"{i}",
prompt_tokens=seq_tokens[i],
prompt_embeds=seq_embeds[i],
block_size=block_size)
for i in range(len(seq_tokens))
]
for _, seq_group in seq_and_seq_groups:
scheduler.add_seq_group(seq_group)
while not all(seq.is_finished() for seq, _ in seq_and_seq_groups):
unfinished_seq_groups = [
seq_group for _, seq_group in seq_and_seq_groups
if not seq_group.is_finished()
]
_, out = schedule_and_update_computed_tokens(scheduler)
assert len(out.scheduled_seq_groups) > 0
batch_is_prompt_embeds = out.scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
expected_scheduled_seq_groups = [
seq_group for seq_group in unfinished_seq_groups
if seq_group.uses_prompt_embeds() == batch_is_prompt_embeds
]
# We should have as many scheduled groups as possible, without mixing
assert len(out.scheduled_seq_groups) == min(
max_seq_group, len(expected_scheduled_seq_groups))
assert all(scheduled_seq_group.seq_group.uses_prompt_embeds() ==
batch_is_prompt_embeds
for scheduled_seq_group in out.scheduled_seq_groups)
# Finish the scheduled groups
for scheduled_seq_group in out.scheduled_seq_groups:
for seq in scheduled_seq_group.seq_group.seqs:
seq.status = SequenceStatus.FINISHED_STOPPED
scheduler.free_finished_seq_groups()
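
The invariant this test enforces is that a batch never mixes token-ID prompts with embedding prompts. A hypothetical sketch of that partitioning (not the scheduler's actual code path; only the `uses_prompt_embeds()` predicate comes from the test above):

def partition_by_prompt_kind(seq_groups):
    # Hypothetical helper: split sequence groups so a batch only ever
    # contains one kind of prompt input (token IDs or prompt embeddings).
    token_groups, embed_groups = [], []
    for group in seq_groups:
        target = embed_groups if group.uses_prompt_embeds() else token_groups
        target.append(group)
    return token_groups, embed_groups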

View File

@ -5,9 +5,11 @@ from collections import defaultdict
from collections.abc import Sequence as GenericSequence
from typing import Any, Optional
import torch
from vllm import SamplingParams
from vllm.core.scheduler import Scheduler, SchedulerOutputs
from vllm.inputs import EncoderDecoderInputs, token_inputs
from vllm.inputs import EncoderDecoderInputs, embeds_inputs, token_inputs
from vllm.lora.request import LoRARequest
from vllm.sequence import (Logprob, Sequence, SequenceGroup,
SequenceGroupMetadata)
@ -19,6 +21,7 @@ def create_dummy_prompt(
block_size: Optional[int] = None,
lora_request: Optional[LoRARequest] = None,
prompt_tokens: Optional[list[int]] = None,
prompt_embeds: Optional[torch.Tensor] = None,
min_tokens: int = 0,
max_tokens: int = 16,
) -> tuple[Sequence, SequenceGroup]:
@ -31,9 +34,13 @@ def create_dummy_prompt(
prompt_tokens = list(range(prompt_length))
prompt_str = " ".join([str(t) for t in prompt_tokens])
inputs = token_inputs(
prompt_token_ids=prompt_tokens,
prompt=prompt_str) if prompt_embeds is None else embeds_inputs(
prompt_embeds=prompt_embeds)
prompt = Sequence(
int(request_id),
inputs=token_inputs(prompt_tokens, prompt=prompt_str),
inputs=inputs,
block_size=block_size,
)
seq_group = SequenceGroup(

View File

@ -106,6 +106,8 @@ class DummyConfigClass:
"""List with literal choices"""
literal_literal: Literal[Literal[1], Literal[2]] = 1
"""Literal of literals with default 1"""
json_tip: dict = field(default_factory=dict)
"""Dict which will be JSON in CLI"""
@pytest.mark.parametrize(("type_hint", "expected"), [
@ -137,6 +139,9 @@ def test_get_kwargs():
assert kwargs["list_literal"]["choices"] == [1, 2]
# literals of literals should have merged choices
assert kwargs["literal_literal"]["choices"] == [1, 2]
# dict should have json tip in help
json_tip = "\n\nShould be a valid JSON string."
assert kwargs["json_tip"]["help"].endswith(json_tip)
@pytest.mark.parametrize(("arg", "expected"), [
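
The `json_tip` field above documents the convention that dict-typed config fields are supplied as JSON strings on the command line. A minimal, standalone illustration of that pattern with plain `argparse` (the option name is made up for the example):

import argparse
import json

parser = argparse.ArgumentParser()
# Dict-valued options are accepted as JSON and decoded at parse time.
parser.add_argument("--override", type=json.loads, default={},
                    help="Dict option. Should be a valid JSON string.")

args = parser.parse_args(["--override", '{"max_tokens": 16}'])
assert args.override == {"max_tokens": 16}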

View File

@ -0,0 +1,60 @@
# SPDX-License-Identifier: Apache-2.0
from contextlib import nullcontext
import pytest
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_skip_tokenizer_initialization(model: str):
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(
model=model,
skip_tokenizer_init=True,
enforce_eager=True,
)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
with pytest.raises(ValueError, match="cannot pass text prompts when"):
llm.generate("abc", sampling_params)
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs
assert len(completions) > 0
assert completions[0].text == ""
assert completions[0].token_ids
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
@pytest.mark.parametrize("enable_prompt_embeds", [True, False])
def test_enable_prompt_embeds(hf_runner, model: str,
enable_prompt_embeds: bool):
prompt = "abc"
with hf_runner(model) as hf_model:
token_ids = hf_model.tokenizer(prompt, return_tensors="pt").input_ids
token_ids = token_ids.to(hf_model.model.device)
embed_layer = hf_model.model.get_input_embeddings()
prompt_embeds = embed_layer(token_ids).squeeze(0)
ctx = (nullcontext() if enable_prompt_embeds else pytest.raises(
ValueError, match="set `--enable-prompt-embeds`"))
# This test checks that prompt embeddings are only accepted when the
# `enable_prompt_embeds` flag is set; otherwise passing `prompt_embeds`
# raises a ValueError.
llm = LLM(
model=model,
enable_prompt_embeds=enable_prompt_embeds,
enforce_eager=True,
)
with ctx:
llm.generate({"prompt_embeds": prompt_embeds})
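
For reference, the flow exercised here — embed the prompt with the HF input-embedding layer, then hand the tensor to vLLM — looks roughly like the sketch below (model name and prompt are illustrative, and `enable_prompt_embeds=True` is required as above):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from vllm import LLM

model_id = "distilbert/distilgpt2"  # illustrative small model
tokenizer = AutoTokenizer.from_pretrained(model_id)
hf_model = AutoModelForCausalLM.from_pretrained(model_id)

token_ids = tokenizer("abc", return_tensors="pt").input_ids
with torch.no_grad():
    # Shape: (seq_len, hidden_size) after dropping the batch dimension.
    prompt_embeds = hf_model.get_input_embeddings()(token_ids).squeeze(0)

llm = LLM(model=model_id, enable_prompt_embeds=True, enforce_eager=True)
outputs = llm.generate({"prompt_embeds": prompt_embeds})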

View File

@ -1,29 +0,0 @@
# SPDX-License-Identifier: Apache-2.0
import pytest
from vllm.entrypoints.llm import LLM
from vllm.sampling_params import SamplingParams
@pytest.mark.parametrize("model", ["distilbert/distilgpt2"])
def test_skip_tokenizer_initialization(model: str):
# This test checks if the flag skip_tokenizer_init skips the initialization
# of tokenizer and detokenizer. The generated output is expected to contain
# token ids.
llm = LLM(
model=model,
skip_tokenizer_init=True,
)
sampling_params = SamplingParams(prompt_logprobs=True, detokenize=True)
with pytest.raises(ValueError, match="cannot pass text prompts when"):
llm.generate("abc", sampling_params)
outputs = llm.generate({"prompt_token_ids": [1, 2, 3]},
sampling_params=sampling_params)
assert len(outputs) > 0
completions = outputs[0].outputs
assert len(completions) > 0
assert completions[0].text == ""
assert completions[0].token_ids

View File

@ -420,7 +420,8 @@ def test_fused_marlin_moe(
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, False)
torch_output = torch_moe(a, w_ref1, w_ref2, score, topk, e_map)

View File

@ -0,0 +1,223 @@
# SPDX-License-Identifier: Apache-2.0
"""Tests for the MOE permute/unpermute kernel
Run `pytest tests/kernels/test_moe_permute_unpermute.py`.
"""
from typing import Optional
import numpy as np
import pytest
import torch
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.layer import determine_expert_map
from vllm.model_executor.layers.fused_moe.moe_permute_unpermute import (
moe_permute, moe_unpermute)
from vllm.platforms import current_platform
NUM_EXPERTS = [16, 64]
TOP_KS = [2, 4, 6, 8]
EP_SIZE = [1, 4, 16]
current_platform.seed_everything(0)
def torch_permute(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
topk: int,
n_expert: int,
n_local_expert: int,
start_expert: int,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1) -> list[torch.Tensor]:
n_token, n_hidden = hidden_states.shape[0], hidden_states.shape[1]
if expert_map is not None:
is_local_expert = (expert_map[topk_ids] != -1)
not_local_expert = (expert_map[topk_ids] == -1)
topk_ids = is_local_expert * (
topk_ids - start_expert) + not_local_expert * (topk_ids + n_expert)
sorted_topk_ids, sorted_indices = torch.sort(topk_ids.flatten(),
stable=True)
dst_row_id2src_row_id_map = token_expert_indices.flatten()[sorted_indices]
expert_first_token_offset = torch.zeros(n_local_expert + 1,
dtype=torch.int64,
device="cuda")
idx = 0
for i in range(0, n_local_expert):
cnt = 0
while idx < sorted_topk_ids.numel() and sorted_topk_ids[idx] == i:
cnt += 1
idx += 1
expert_first_token_offset[i + 1] = expert_first_token_offset[i] + cnt
_, src2dst_idx = torch.sort(dst_row_id2src_row_id_map)
valid_row_idx = []
if align_block_size is None:
permuted_hidden_states = hidden_states[dst_row_id2src_row_id_map %
n_token, ...]
permuted_row_size = permuted_hidden_states.shape[0]
m_indices = torch.empty(permuted_row_size,
device="cuda",
dtype=torch.int32).fill_(fill_invalid_expert)
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
m_indices[first_token_offset:last_token_offset] = i - 1
src_row_id2dst_row_id_map = torch.arange(
0, n_token * topk, device="cuda",
dtype=torch.int32)[src2dst_idx].reshape((n_token, topk))
valid_row_idx += [i for i in range(expert_first_token_offset[-1])]
return [
permuted_hidden_states, expert_first_token_offset,
src_row_id2dst_row_id_map, m_indices, valid_row_idx
]
else:
permuted_row_size = (topk * n_token + n_expert *
(align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size
permuted_hidden_states = torch.empty((permuted_row_size, n_hidden),
device="cuda",
dtype=hidden_states.dtype)
align_src_row_id2dst_row_id = torch.empty(n_token * topk,
device="cuda",
dtype=torch.int32)
align_expert_first_token_offset = torch.zeros_like(
expert_first_token_offset)
m_indices = torch.empty(permuted_row_size,
device="cuda",
dtype=torch.int32).fill_(fill_invalid_expert)
# get align_permuted_hidden_states,
# valid row_idx and align_expert_first_token_offset
for i in range(1, n_local_expert + 1):
first_token_offset = expert_first_token_offset[i - 1]
last_token_offset = expert_first_token_offset[i]
n_token_in_expert = last_token_offset - first_token_offset
align_expert_first_token_offset[
i] = align_expert_first_token_offset[
i - 1] + (n_token_in_expert + align_block_size -
1) // align_block_size * align_block_size
align_first_token_offset = align_expert_first_token_offset[i - 1]
align_last_token_offset = align_expert_first_token_offset[i]
dst_row_id2src_row_id_in_expert = dst_row_id2src_row_id_map[
first_token_offset:first_token_offset +
n_token_in_expert] % n_token
# store token in current expert with align_first_token_offset
permuted_hidden_states[align_first_token_offset:\
align_first_token_offset+n_token_in_expert,\
...] = hidden_states[\
dst_row_id2src_row_id_in_expert, ...]
# set current expert m_indices
m_indices[align_first_token_offset:align_last_token_offset] = i - 1
valid_row_idx += [
i for i in range(align_first_token_offset,
align_first_token_offset + n_token_in_expert)
]
# get align_src_row_id2dst_row_id
for i in range(n_token * topk):
eid = sorted_topk_ids[i]
if (eid >= n_local_expert):
# check token not in local expert
align_src_row_id2dst_row_id[
i] = align_expert_first_token_offset[-1]
continue
first_token_offset = expert_first_token_offset[eid]
align_first_token_offset = align_expert_first_token_offset[eid]
token_offset = i - first_token_offset
align_src_row_id2dst_row_id[
i] = align_first_token_offset + token_offset
align_src_row_id2dst_row_id = align_src_row_id2dst_row_id[\
src2dst_idx].reshape((n_token, topk))
return [
permuted_hidden_states, align_expert_first_token_offset,
align_src_row_id2dst_row_id, m_indices, valid_row_idx
]
def torch_unpermute(permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor,
valid_row_idx: torch.Tensor, topk: int,
n_expert: int) -> torch.Tensor:
# ignore invalid row
mask = torch.zeros(permuted_hidden_states.shape[0],
dtype=bool,
device="cuda")
mask[valid_row_idx] = True
permuted_hidden_states[~mask] = 0
idx = src_row_id2dst_row_id_map.flatten()[
token_expert_indices.flatten()].reshape(token_expert_indices.shape)
output = permuted_hidden_states[idx, ...] * topk_weights[..., None]
output = output.sum(dim=1).to(permuted_hidden_states.dtype)
return output
@pytest.mark.parametrize("n_token", [1, 33, 64, 222, 1024, 2048, 3000, 5000])
@pytest.mark.parametrize("n_hidden", [2048, 4096, 7168])
@pytest.mark.parametrize("n_expert", NUM_EXPERTS)
@pytest.mark.parametrize("topk", TOP_KS)
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize("ep_size", EP_SIZE)
@pytest.mark.parametrize("align_block_size", [None, 128])
def test_moe_permute_unpermute(n_token: int, n_hidden: int, topk: int,
n_expert: int, ep_size: int, dtype: torch.dtype,
align_block_size: Optional[int]):
fill_invalid_expert = 0
ep_rank = np.random.randint(0, ep_size)
expert_map = None
n_local_expert = n_expert
if (ep_size != 1):
n_local_expert, expert_map = determine_expert_map(
ep_size, ep_rank, n_expert)
expert_map = expert_map.cuda()
start_expert = n_local_expert * ep_rank
current_platform.seed_everything(0)
hidden_states = torch.randn((n_token, n_hidden), device="cuda").to(dtype)
gating_output = torch.randn((n_token, n_expert), device="cuda").to(dtype)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, False)
gold0, gold1, gold2, gold3, valid_row_idx = torch_permute(
hidden_states,
topk_ids,
token_expert_indices,
topk,
n_expert,
n_local_expert,
start_expert,
expert_map=expert_map,
align_block_size=align_block_size,
fill_invalid_expert=fill_invalid_expert)
result0, result1, result2, result3 = moe_permute(
hidden_states, topk_weights, topk_ids, token_expert_indices, topk,
n_expert, n_local_expert, expert_map, align_block_size,
fill_invalid_expert)
# check expert_first_token_offset
torch.testing.assert_close(gold1, result1, atol=0, rtol=0)
# check src_row_id2dst_row_id_map
torch.testing.assert_close(gold2, result2, atol=0, rtol=0)
# check m_indices
torch.testing.assert_close(gold3, result3, atol=0, rtol=0)
# check permuted_hidden_states, only valid token
torch.testing.assert_close(gold0[valid_row_idx],
result0[valid_row_idx],
atol=0,
rtol=0)
# add a random tensor to simulate group gemm
result0 = 0.5 * result0 + torch.randn_like(result0)
result4 = moe_unpermute(result0, topk_weights, topk_ids, result2, result1,
topk, n_expert, n_local_expert)
gold4 = torch_unpermute(result0, topk_weights, topk_ids,
token_expert_indices, result2, valid_row_idx, topk,
n_local_expert)
# check unpermuted hidden
torch.testing.assert_close(result4, gold4, atol=2e-2, rtol=0)
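
Stripped of the expert-parallel mapping, block alignment, and invalid-expert fill handled above, the permutation reduces to a stable sort of the flattened top-k assignments. A simplified, self-contained sketch of just that grouping step (not the CUDA kernel and not the exact reference layout):

import torch

def group_tokens_by_expert(hidden_states: torch.Tensor,
                           topk_ids: torch.Tensor):
    # Stable-sort the flattened (token, expert) assignments so rows routed to
    # the same expert become contiguous, which is what lets a grouped GEMM
    # process one expert's tokens at a time.
    n_token, topk = topk_ids.shape
    sorted_expert_ids, order = torch.sort(topk_ids.flatten(), stable=True)
    src_token = order // topk                  # flattened slot -> source token
    permuted = hidden_states[src_token]        # (n_token * topk, hidden)
    # Per-expert row offsets, analogous to expert_first_token_offset above.
    counts = torch.bincount(sorted_expert_ids,
                            minlength=int(topk_ids.max()) + 1)
    first_token_offset = torch.cat(
        [counts.new_zeros(1), torch.cumsum(counts, dim=0)])
    return permuted, first_token_offset, src_token

# Example: 4 tokens, hidden size 8, top-2 routing over 3 experts.
hidden = torch.randn(4, 8)
ids = torch.tensor([[0, 2], [1, 2], [0, 1], [2, 2]])
permuted, offsets, src = group_tokens_by_expert(hidden, ids)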

View File

@ -84,7 +84,8 @@ def test_fused_marlin_moe_awq(
score = torch.randn((m, e), device="cuda", dtype=dtype)
topk_weights, topk_ids = fused_topk(a, score, topk, False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score, topk, False)
marlin_output = torch.ops.vllm.fused_marlin_moe(
a,
qweight1,

View File

@ -338,7 +338,8 @@ def deep_gemm_w8a8_block_fp8_moe(M, K, a, w1, w2, w1_s, w2_s, score, topk,
M, K = a.shape
N = w2.shape[-1]
topk_weight, topk_ids = fused_topk(a, score.float(), topk, False)
topk_weight, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
block_m = deep_gemm.get_m_alignment_for_contiguous_layout()
@ -435,7 +436,8 @@ def test_w8a8_block_fp8_deep_gemm_fused_moe(M, N, K, E, topk, seed):
ref_out = torch_w8a8_block_fp8_moe(a, w1, w2, w1_s, w2_s, score,
topk, block_size)
topk_weights, topk_ids = fused_topk(a, score.float(), topk, False)
topk_weights, topk_ids, token_expert_indices = fused_topk(
a, score.float(), topk, False)
out = deep_gemm_moe_fp8(a, w1, w2, w1_s, w2_s, topk_weights, topk_ids)

View File

@ -14,8 +14,8 @@ import torch
# Fixture to set up environment variables and teardown servers after tests
@pytest.fixture(scope="module", autouse=True)
def setup_servers():
if torch.cuda.device_count() < 4:
pytest.skip("Skipping test: fewer than 4 GPUs available")
if torch.cuda.device_count() < 2:
pytest.skip("Skipping test: fewer than 2 GPUs available")
# Set up environment variables
VLLM_HOST_IP = subprocess.check_output("hostname -I | awk '{print $1}'",

View File

@ -1,4 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
import pytest
import torch
@ -106,19 +109,38 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
# in parts of the operators
pytest.skip(f"Skipping '{model}' model test with AITER kernel.")
use_prompt_embeds = os.getenv("VLLM_USE_V1") == "0"
with hf_runner(model) as hf_model:
hf_outputs = hf_model.generate_greedy_logprobs_limit(
example_prompts, max_tokens, num_logprobs)
prompt_embeds: Optional[list[torch.Tensor]] = ([] if use_prompt_embeds
else None)
prompt_token_ids = []
for prompt in example_prompts:
token_ids = hf_model.tokenizer(prompt,
return_tensors="pt").input_ids.to(
hf_model.model.device)
prompt_token_ids.append(token_ids)
if prompt_embeds is not None:
prompt_embeds.append(hf_model.model.get_input_embeddings()(
token_ids).squeeze(0))
with vllm_runner(
model,
tokenizer_name=model_info.tokenizer or model,
tokenizer_mode=model_info.tokenizer_mode,
trust_remote_code=model_info.trust_remote_code,
max_num_seqs=2,
enable_prompt_embeds=use_prompt_embeds,
) as vllm_model:
vllm_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)
if prompt_embeds is not None:
vllm_outputs_from_embeds = vllm_model.generate_greedy_logprobs(
prompt_embeds, max_tokens, num_logprobs)
check_logprobs_close(
outputs_0_lst=hf_outputs,
@ -126,6 +148,14 @@ def test_models(hf_runner, vllm_runner, example_prompts, model: str,
name_0="hf",
name_1="vllm",
)
if prompt_embeds is not None:
check_logprobs_close(
outputs_0_lst=vllm_outputs,
outputs_1_lst=vllm_outputs_from_embeds,
name_0="vllm",
name_1="vllm_from_embeds",
)
if use_rocm_aiter:
# this is to ensure that vllm engine
# has deallocated the memory before running the next

View File

@ -8,13 +8,14 @@ from collections import defaultdict
from pathlib import PosixPath
import pytest
from transformers import AutoModelForImageTextToText, AutoModelForVision2Seq
from transformers import (AutoModelForImageTextToText,
AutoModelForTextToWaveform, AutoModelForVision2Seq)
from vllm.platforms import current_platform
from vllm.utils import identity
from ....conftest import (IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets,
_VideoAssets)
from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
VideoTestAssets, VllmRunner)
from ....utils import (create_new_process_for_each_test, large_gpu_mark,
multi_gpu_marks)
from ...utils import check_outputs_equal
@ -140,7 +141,7 @@ VLM_TEST_SETTINGS = {
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
"qwen2_5_omni": VLMTestInfo(
models=["Qwen/Qwen2.5-Omni-7B"],
models=["Qwen/Qwen2.5-Omni-3B"],
test_type=(
VLMTestType.IMAGE,
VLMTestType.MULTI_IMAGE,
@ -151,8 +152,9 @@ VLM_TEST_SETTINGS = {
video_idx_to_prompt=lambda idx: "<|vision_bos|><|VIDEO|><|vision_eos|>", # noqa: E501
max_model_len=4096,
max_num_seqs=2,
auto_cls=AutoModelForVision2Seq,
auto_cls=AutoModelForTextToWaveform,
vllm_output_post_proc=model_utils.qwen2_vllm_to_hf_output,
patch_hf_runner=model_utils.qwen2_5_omni_patch_hf_runner,
image_size_factors=[(), (0.25,), (0.25, 0.25, 0.25), (0.25, 0.2, 0.15)],
marks=[pytest.mark.core_model, pytest.mark.cpu_model],
),
@ -691,7 +693,7 @@ def test_single_image_models(tmp_path: PosixPath, model_type: str,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets, monkeypatch):
image_assets: ImageTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
@ -716,7 +718,7 @@ def test_multi_image_models(tmp_path: PosixPath, model_type: str,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets, monkeypatch):
image_assets: ImageTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
@ -741,7 +743,7 @@ def test_image_embedding_models(model_type: str,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets, monkeypatch):
image_assets: ImageTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
@ -763,7 +765,7 @@ def test_image_embedding_models(model_type: str,
))
def test_video_models(model_type: str, test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
video_assets: _VideoAssets, monkeypatch):
video_assets: VideoTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
@ -814,7 +816,7 @@ def test_single_image_models_heavy(tmp_path: PosixPath, model_type: str,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets, monkeypatch):
image_assets: ImageTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
@ -840,7 +842,7 @@ def test_multi_image_models_heavy(tmp_path: PosixPath, model_type: str,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets, monkeypatch):
image_assets: ImageTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
@ -866,7 +868,8 @@ def test_image_embedding_models_heavy(model_type: str,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets, monkeypatch):
image_assets: ImageTestAssets,
monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]
@ -889,7 +892,7 @@ def test_image_embedding_models_heavy(model_type: str,
def test_video_models_heavy(model_type: str, test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
video_assets: _VideoAssets, monkeypatch):
video_assets: VideoTestAssets, monkeypatch):
if model_type in REQUIRES_V0_MODELS:
monkeypatch.setenv("VLLM_USE_V1", "0")
model_test_info = VLM_TEST_SETTINGS[model_type]

View File

@ -9,7 +9,7 @@ from vllm.inputs.data import ExplicitEncoderDecoderPrompt, TextPrompt
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets
from ....conftest import IMAGE_ASSETS, HfRunner, ImageTestAssets, VllmRunner
from ...utils import check_logprobs_close
MODELS = ["microsoft/Florence-2-base"]
@ -118,7 +118,7 @@ def run_test(
@pytest.mark.parametrize("max_tokens", [64])
@pytest.mark.parametrize("num_logprobs", [5])
def test_models(hf_runner: type[HfRunner], vllm_runner: type[VllmRunner],
image_assets: _ImageAssets, model: str,
image_assets: ImageTestAssets, model: str,
size_factors: list[int], dtype: str, max_tokens: int,
num_logprobs: int) -> None:
images = [asset.pil_image for asset in image_assets]

View File

@ -9,7 +9,8 @@ from transformers import AutoModelForSpeechSeq2Seq
from vllm.lora.request import LoRARequest
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, PromptAudioInput, VllmRunner, _AudioAssets
from ....conftest import (AudioTestAssets, HfRunner, PromptAudioInput,
VllmRunner)
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
@ -116,9 +117,9 @@ def run_test(
@pytest.mark.parametrize("max_model_len", [2048])
@pytest.mark.parametrize("max_tokens", [128])
@pytest.mark.parametrize("num_logprobs", [10])
def test_models(hf_runner, vllm_runner, model: str, audio_assets: _AudioAssets,
dtype: str, max_model_len: int, max_tokens: int,
num_logprobs: int) -> None:
def test_models(hf_runner, vllm_runner, model: str,
audio_assets: AudioTestAssets, dtype: str, max_model_len: int,
max_tokens: int, num_logprobs: int) -> None:
model_info = HF_EXAMPLE_MODELS.find_hf_info(model)
model_info.check_available_online(on_fail="skip")
model_info.check_transformers_version(on_fail="skip")

View File

@ -29,7 +29,7 @@ def test_models(vllm_runner, model, dtype: str, max_tokens: int) -> None:
image_cherry = ImageAsset("cherry_blossom").pil_image.convert("RGB")
image_stop = ImageAsset("stop_sign").pil_image.convert("RGB")
images = [image_cherry, image_stop]
video = VideoAsset(name="sample_demo_1.mp4", num_frames=16).np_ndarrays
video = VideoAsset(name="baby_reading", num_frames=16).np_ndarrays
inputs = [
(

View File

@ -14,8 +14,8 @@ from vllm.model_executor.models.mllama import MllamaForConditionalGeneration
from vllm.multimodal.image import rescale_image_size
from vllm.sequence import SampleLogprobs
from ....conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner,
_ImageAssets)
from ....conftest import (IMAGE_ASSETS, HfRunner, ImageTestAssets,
PromptImageInput, VllmRunner)
from ....quantization.utils import is_quant_method_supported
from ....utils import (create_new_process_for_each_test, large_gpu_test,
multi_gpu_test)
@ -90,7 +90,7 @@ def vllm_to_hf_output(vllm_output: tuple[list[int], str,
def _get_inputs(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
*,
size_factors: Optional[list[float]] = None,
sizes: Optional[list[tuple[int, int]]] = None,
@ -126,7 +126,7 @@ def _get_inputs(
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model: str,
*,
size_factors: list[float],
@ -143,7 +143,7 @@ def run_test(
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model: str,
*,
sizes: list[tuple[int, int]],
@ -159,7 +159,7 @@ def run_test(
def run_test(
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model: str,
*,
size_factors: Optional[list[float]] = None,
@ -433,7 +433,7 @@ def test_models_distributed(
@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
reason='bitsandbytes is not supported on this GPU type.')
def test_bnb_regression(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model: str,
dtype: str,
max_tokens: int,
@ -473,7 +473,7 @@ def test_bnb_regression(
@pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
def test_explicit_implicit_prompt(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model: str,
dtype: str,
max_tokens: int,

View File

@ -50,7 +50,7 @@ IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
})
VIDEO_PROMPTS = VIDEO_ASSETS.prompts({
"sample_demo_1":
"baby_reading":
qwen2_vl_chat_template(
VIDEO_PLACEHOLDER,
"Describe this video with a short sentence ",

View File

@ -11,13 +11,22 @@ from transformers import AutoModel, AutoTokenizer
from vllm.multimodal.audio import resample_audio_librosa
from vllm.sequence import SampleLogprobs
from ....conftest import HfRunner, VllmRunner, _AudioAssets
from ....conftest import AUDIO_ASSETS, AudioTestAssets, HfRunner, VllmRunner
from ....utils import RemoteOpenAIServer
from ...registry import HF_EXAMPLE_MODELS
from ...utils import check_logprobs_close
MODEL_NAME = "fixie-ai/ultravox-v0_5-llama-3_2-1b"
AUDIO_PROMPTS = AUDIO_ASSETS.prompts({
"mary_had_lamb":
"Transcribe this into English.",
"winning_call":
"What is happening in this audio clip?",
})
MULTI_AUDIO_PROMPT = "Describe each of the audios above."
AudioTuple = tuple[np.ndarray, int]
VLLM_PLACEHOLDER = "<|audio|>"
@ -31,12 +40,6 @@ CHUNKED_PREFILL_KWARGS = {
}
@pytest.fixture(scope="module", params=("mary_had_lamb", "winning_call"))
def audio(request):
from vllm.assets.audio import AudioAsset
return AudioAsset(request.param)
def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
"""Convert kwargs to CLI args."""
args = []
@ -53,7 +56,7 @@ def params_kwargs_to_cli_args(params_kwargs: dict[str, Any]) -> list[str]:
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def server(request, audio_assets: _AudioAssets):
def server(request, audio_assets: AudioTestAssets):
args = [
"--dtype", "bfloat16", "--max-model-len", "4096", "--enforce-eager",
"--limit-mm-per-prompt",
@ -199,15 +202,19 @@ def run_multi_audio_test(
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
num_logprobs: int, vllm_kwargs: dict) -> None:
def test_models(hf_runner, vllm_runner, audio_assets: AudioTestAssets,
dtype: str, max_tokens: int, num_logprobs: int,
vllm_kwargs: dict) -> None:
audio_inputs = [(
_get_prompt(1, audio, VLLM_PLACEHOLDER),
_get_prompt(1, audio, HF_PLACEHOLDER),
audio.audio_and_sample_rate,
) for audio in audio_assets]
vllm_prompt = _get_prompt(1, "Describe the audio above.", VLLM_PLACEHOLDER)
hf_prompt = _get_prompt(1, "Describe the audio above.", HF_PLACEHOLDER)
run_test(
hf_runner,
vllm_runner,
[(vllm_prompt, hf_prompt, audio.audio_and_sample_rate)],
audio_inputs,
MODEL_NAME,
dtype=dtype,
max_tokens=max_tokens,
@ -224,13 +231,12 @@ def test_models(hf_runner, vllm_runner, audio, dtype: str, max_tokens: int,
pytest.param({}, marks=pytest.mark.cpu_model),
pytest.param(CHUNKED_PREFILL_KWARGS),
])
def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets,
dtype: str, max_tokens: int,
num_logprobs: int,
def test_models_with_multiple_audios(vllm_runner,
audio_assets: AudioTestAssets, dtype: str,
max_tokens: int, num_logprobs: int,
vllm_kwargs: dict) -> None:
vllm_prompt = _get_prompt(len(audio_assets),
"Describe each of the audios above.",
vllm_prompt = _get_prompt(len(audio_assets), MULTI_AUDIO_PROMPT,
VLLM_PLACEHOLDER)
run_multi_audio_test(
vllm_runner,
@ -245,7 +251,7 @@ def test_models_with_multiple_audios(vllm_runner, audio_assets: _AudioAssets,
@pytest.mark.asyncio
async def test_online_serving(client, audio_assets: _AudioAssets):
async def test_online_serving(client, audio_assets: AudioTestAssets):
"""Exercises online serving with/without chunked prefill enabled."""
messages = [{

View File

@ -11,7 +11,7 @@ from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.video import (rescale_video_size, resize_video,
sample_frames_from_video)
from .....conftest import _ImageAssets, _VideoAssets
from .....conftest import ImageTestAssets, VideoTestAssets
from .types import (SINGLE_IMAGE_BASE_PROMPTS, TEST_IMG_PLACEHOLDER,
TEST_VIDEO_PLACEHOLDER, VIDEO_BASE_PROMPT,
ImageSizeWrapper, SizeType, VLMTestInfo)
@ -69,7 +69,7 @@ def get_model_prompts(base_prompts: Iterable[str],
def build_single_image_inputs_from_test_info(
test_info: VLMTestInfo,
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None):
if test_info.prompt_formatter is None:
@ -116,7 +116,7 @@ def build_single_image_inputs(images, model_prompts,
def build_multi_image_inputs_from_test_info(
test_info: VLMTestInfo,
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
tmp_path: Optional[PosixPath] = None):
if test_info.prompt_formatter is None:
@ -159,7 +159,7 @@ def build_multi_image_inputs(image_lists, model_prompts,
def build_embedding_inputs_from_test_info(
test_info: VLMTestInfo,
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
size_wrapper: ImageSizeWrapper,
):
# These conditions will always be true if invoked through filtering,
@ -192,7 +192,7 @@ def build_embedding_inputs_from_test_info(
def build_video_inputs_from_test_info(
test_info: VLMTestInfo,
video_assets: _VideoAssets,
video_assets: VideoTestAssets,
size_wrapper: ImageSizeWrapper,
num_frames: int,
):

View File

@ -16,7 +16,7 @@ from transformers import (AutoConfig, AutoTokenizer, BatchFeature,
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import patch_padding_side
from .....conftest import HfRunner, ImageAsset, _ImageAssets
from .....conftest import HfRunner, ImageAsset, ImageTestAssets
from .types import RunnerOutput
@ -238,14 +238,14 @@ def minimax_vl_01_hf_output(hf_output: RunnerOutput,
####### Functions for converting image assets to embeddings
def get_llava_embeddings(image_assets: _ImageAssets):
def get_llava_embeddings(image_assets: ImageTestAssets):
return [asset.image_embeds for asset in image_assets]
####### Prompt path encoders for models that need models on disk
def qwen_prompt_path_encoder(
tmp_path: PosixPath, prompt: str, assets: Union[list[ImageAsset],
_ImageAssets]) -> str:
tmp_path: PosixPath, prompt: str,
assets: Union[list[ImageAsset], ImageTestAssets]) -> str:
"""Given a temporary dir path, export one or more image assets into the
tempdir & replace its contents with the local path to the string so that
the HF version of Qwen-VL can resolve the path and load the image in its
@ -706,3 +706,11 @@ def ovis2_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
hf_model.processor = processor
return hf_model
def qwen2_5_omni_patch_hf_runner(hf_model: HfRunner) -> HfRunner:
"""Patches and returns an instance of the HfRunner for Qwen2.5-Omni."""
thinker = hf_model.model.thinker
thinker.get_output_embeddings = lambda: thinker.lm_head
hf_model.model = thinker
return hf_model

View File

@ -4,7 +4,8 @@ types / modalities.
"""
from pathlib import PosixPath
from .....conftest import HfRunner, VllmRunner, _ImageAssets, _VideoAssets
from .....conftest import (HfRunner, ImageTestAssets, VideoTestAssets,
VllmRunner)
from . import builders, core
from .types import ExpandableVLMTestArgs, VLMTestInfo
@ -14,7 +15,7 @@ def run_single_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets):
image_assets: ImageTestAssets):
assert test_case.size_wrapper is not None
inputs = builders.build_single_image_inputs_from_test_info(
model_test_info, image_assets, test_case.size_wrapper, tmp_path)
@ -37,7 +38,7 @@ def run_multi_image_test(*, tmp_path: PosixPath, model_test_info: VLMTestInfo,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets):
image_assets: ImageTestAssets):
assert test_case.size_wrapper is not None
inputs = builders.build_multi_image_inputs_from_test_info(
model_test_info, image_assets, test_case.size_wrapper, tmp_path)
@ -60,7 +61,7 @@ def run_embedding_test(*, model_test_info: VLMTestInfo,
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets):
image_assets: ImageTestAssets):
assert test_case.size_wrapper is not None
inputs, vllm_embeddings = builders.build_embedding_inputs_from_test_info(
model_test_info, image_assets, test_case.size_wrapper)
@ -86,7 +87,7 @@ def run_video_test(
test_case: ExpandableVLMTestArgs,
hf_runner: type[HfRunner],
vllm_runner: type[VllmRunner],
video_assets: _VideoAssets,
video_assets: VideoTestAssets,
):
assert test_case.size_wrapper is not None
assert test_case.num_video_frames is not None

View File

@ -15,7 +15,7 @@ from vllm.config import TaskOption
from vllm.sequence import SampleLogprobs
from vllm.transformers_utils.tokenizer import AnyTokenizer
from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, _ImageAssets
from .....conftest import IMAGE_ASSETS, HfRunner, ImageAsset, ImageTestAssets
from ....utils import check_logprobs_close
# meta image tag; will be replaced by the appropriate tag for the model
@ -85,7 +85,7 @@ class VLMTestInfo(NamedTuple):
# Function for converting ImageAssets to image embeddings;
# We need to define this explicitly for embedding tests
convert_assets_to_embeddings: Optional[Callable[[_ImageAssets],
convert_assets_to_embeddings: Optional[Callable[[ImageTestAssets],
torch.Tensor]] = None
# Exposed options for vLLM runner; we change these in a several tests,
@ -141,7 +141,7 @@ class VLMTestInfo(NamedTuple):
# for Qwen-VL, which requires encoding the image path / url into the prompt
# for HF runner
prompt_path_encoder: Optional[
Callable[[PosixPath, str, Union[list[ImageAsset], _ImageAssets]],
Callable[[PosixPath, str, Union[list[ImageAsset], ImageTestAssets]],
str]] = None # noqa: E501
# Allows configuring a test to run with custom inputs

View File

@ -1,33 +1,34 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional
import pytest
import torch
import torch.nn as nn
from huggingface_hub import snapshot_download
from transformers import AutoConfig, AutoModel, CLIPImageProcessor
from ....conftest import _ImageAssets
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE
from ....conftest import ImageTestAssets
# we use snapshot_download to prevent conflicts between
# dynamic_module and trust_remote_code for hf_runner
DOWNLOAD_PATTERN = ["*.json", "*.py", "*.safetensors", "*.txt", "*.model"]
@torch.inference_mode()
def run_intern_vit_test(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
*,
dtype: str,
distributed_executor_backend: Optional[str] = None,
):
model = snapshot_download(model_id, allow_patterns=DOWNLOAD_PATTERN)
torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype]
img_processor = CLIPImageProcessor.from_pretrained(model)
images = [asset.pil_image for asset in image_assets]
pixel_values = [
img_processor(images, return_tensors='pt').pixel_values.to(dtype)
img_processor(images, return_tensors='pt').pixel_values.to(torch_dtype)
for images in images
]
@ -36,14 +37,13 @@ def run_intern_vit_test(
config.norm_type = "rms_norm"
hf_model = AutoModel.from_pretrained(model,
torch_dtype=dtype,
torch_dtype=torch_dtype,
trust_remote_code=True).to("cuda")
hf_outputs_per_image = [
hf_model(pixel_value.to("cuda")).last_hidden_state
for pixel_value in pixel_values
]
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.model_executor.models.intern_vit import InternVisionModel
vllm_model = InternVisionModel(config)
vllm_model.load_weights(hf_model.state_dict().items())
@ -51,7 +51,7 @@ def run_intern_vit_test(
del hf_model
cleanup_dist_env_and_memory()
vllm_model = vllm_model.to("cuda", dtype)
vllm_model = vllm_model.to("cuda", torch_dtype)
vllm_outputs_per_image = [
vllm_model(pixel_values=pixel_value.to("cuda"))
for pixel_value in pixel_values
@ -69,8 +69,7 @@ def run_intern_vit_test(
"OpenGVLab/InternViT-300M-448px",
"OpenGVLab/InternViT-6B-448px-V1-5",
])
@pytest.mark.parametrize("dtype", [torch.half])
@torch.inference_mode()
@pytest.mark.parametrize("dtype", ["half"])
def test_models(dist_init, image_assets, model_id, dtype: str) -> None:
run_intern_vit_test(
image_assets,

View File

@ -284,7 +284,7 @@ def _test_processing_correctness_mistral(
"Qwen/Qwen2-VL-2B-Instruct",
"Qwen/Qwen2.5-VL-3B-Instruct",
"Qwen/Qwen2-Audio-7B-Instruct",
"Qwen/Qwen2.5-Omni-7B",
"Qwen/Qwen2.5-Omni-3B",
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",

View File

@ -11,7 +11,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.processing import BaseMultiModalProcessor
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -137,7 +137,7 @@ def _run_check(
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
size_factors: list[int],
min_dynamic_patch: int,
max_dynamic_patch: int,

View File

@ -5,7 +5,7 @@ from transformers import Idefics3Config
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -21,7 +21,7 @@ from ...utils import build_model_context
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,

View File

@ -11,7 +11,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import rescale_image_size
from vllm.multimodal.processing import BaseMultiModalProcessor
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -94,7 +94,7 @@ def _run_check(
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
model_id: str,
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
size_factors: list[int],
min_dynamic_patch: int,
max_dynamic_patch: int,

View File

@ -6,7 +6,7 @@ import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.tokenizer import encode_tokens
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -17,7 +17,7 @@ from ...utils import build_model_context
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
@pytest.mark.parametrize("tokenized_prompt", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
mm_processor_kwargs: dict,
num_imgs: int,

View File

@ -7,14 +7,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.parse import ImageSize
from vllm.multimodal.processing import BaseMultiModalProcessor
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id", ["MiniMaxAI/MiniMax-VL-01"])
@pytest.mark.parametrize("num_imgs", [1, 2])
def test_processor_override(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
num_imgs: int,
):

View File

@ -4,7 +4,7 @@ import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -22,7 +22,7 @@ from ...utils import build_model_context
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,

View File

@ -4,7 +4,7 @@ import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -22,7 +22,7 @@ from ...utils import build_model_context
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
mm_processor_kwargs: dict[str, int],
expected_toks_per_img: int,

View File

@ -4,7 +4,7 @@ import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -19,7 +19,7 @@ from ...utils import build_model_context
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,

View File

@ -5,7 +5,7 @@ from transformers import SmolVLMConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from ....conftest import _ImageAssets
from ....conftest import ImageTestAssets
from ...utils import build_model_context
@ -21,7 +21,7 @@ from ...utils import build_model_context
@pytest.mark.parametrize("num_imgs", [1, 2])
@pytest.mark.parametrize("kwargs_on_init", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
model_id: str,
mm_processor_kwargs: dict[str, object],
expected_toks_per_img: int,

View File

@ -7,7 +7,7 @@ import torch
from vllm.multimodal.image import rescale_image_size
from ...conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets
from ...conftest import IMAGE_ASSETS, ImageTestAssets, VllmRunner
from ..utils import check_logprobs_close
HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
@ -20,7 +20,7 @@ HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({
def run_awq_test(
vllm_runner: type[VllmRunner],
image_assets: _ImageAssets,
image_assets: ImageTestAssets,
source_model: str,
quant_model: str,
*,

View File

@ -72,12 +72,15 @@ class _HfExamplesInfo:
return
current_version = TRANSFORMERS_VERSION
cur_base_version = Version(current_version).base_version
min_version = self.min_transformers_version
max_version = self.max_transformers_version
msg = f"`transformers=={current_version}` installed, but `transformers"
if min_version and Version(current_version) < Version(min_version):
# Only check the base version for the min/max version, otherwise preview
# models cannot be run because `x.yy.0.dev0`<`x.yy.0`
if min_version and Version(cur_base_version) < Version(min_version):
msg += f">={min_version}` is required to run this model."
elif max_version and Version(current_version) > Version(max_version):
elif max_version and Version(cur_base_version) > Version(max_version):
msg += f"<={max_version}` is required to run this model."
else:
return
@ -362,8 +365,8 @@ _MULTIMODAL_EXAMPLE_MODELS = {
"Qwen2AudioForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-Audio-7B-Instruct"), # noqa: E501
"Qwen2VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2-VL-2B-Instruct"), # noqa: E501
"Qwen2_5_VLForConditionalGeneration": _HfExamplesInfo("Qwen/Qwen2.5-VL-3B-Instruct"), # noqa: E501
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-7B", # noqa: E501
min_transformers_version="4.52"), # noqa: E501
"Qwen2_5OmniModel": _HfExamplesInfo("Qwen/Qwen2.5-Omni-3B",
min_transformers_version="4.52"),
"SkyworkR1VChatModel": _HfExamplesInfo("Skywork/Skywork-R1V-38B"),
"SmolVLMForConditionalGeneration": _HfExamplesInfo("HuggingFaceTB/SmolVLM2-2.2B-Instruct"), # noqa: E501
"UltravoxModel": _HfExamplesInfo("fixie-ai/ultravox-v0_5-llama-3_2-1b", # noqa: E501

Binary file not shown (new 1.8 KiB PNG image asset)

Binary file not shown (new 1.8 KiB PNG image asset)

View File

@ -0,0 +1,61 @@
# SPDX-License-Identifier: Apache-2.0
from pathlib import Path
import numpy as np
import pytest
import torch
from PIL import Image, ImageDraw
from vllm.multimodal.hasher import MultiModalHasher
ASSETS_DIR = Path(__file__).parent / "assets"
assert ASSETS_DIR.exists()
# NOTE: Images that are the same visually are allowed to have the same hash
@pytest.mark.parametrize("mode_pair", [("1", "L"), ("RGBA", "CMYK")])
def test_hash_collision_image_mode(mode_pair):
mode1, mode2 = mode_pair
image1 = Image.new(mode1, size=(10, 10), color=1)
image2 = Image.new(mode2, size=(10, 10), color=1)
hasher = MultiModalHasher
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
def test_hash_collision_image_palette():
# These images differ only in Image.palette._palette
image1 = Image.open(ASSETS_DIR / "image1.png")
image2 = Image.open(ASSETS_DIR / "image2.png")
hasher = MultiModalHasher
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
def test_hash_collision_image_transpose():
image1 = Image.new("1", size=(10, 20))
ImageDraw.Draw(image1).line([(0, 0), (10, 0)])
image2 = Image.new("1", size=(20, 10))
ImageDraw.Draw(image2).line([(0, 0), (0, 10)])
hasher = MultiModalHasher
assert hasher.hash_kwargs(image=image1) != hasher.hash_kwargs(image=image2)
def test_hash_collision_tensor_shape():
# The hash should be different though the data is the same when flattened
arr1 = torch.zeros((5, 10, 20, 3))
arr2 = torch.zeros((10, 20, 5, 3))
hasher = MultiModalHasher
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
def test_hash_collision_array_shape():
# The hash should be different though the data is the same when flattened
arr1 = np.zeros((5, 10, 20, 3))
arr2 = np.zeros((10, 20, 5, 3))
hasher = MultiModalHasher
assert hasher.hash_kwargs(data=arr1) != hasher.hash_kwargs(data=arr2)
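
The tensor and array cases above pin down a hasher property: hashing the flattened bytes alone is not enough, the shape (and, for images, mode and palette) must feed the digest as well. A generic `hashlib` illustration of the collision and the fix (not vLLM's `MultiModalHasher` implementation):

import hashlib

import numpy as np

arr1 = np.zeros((5, 10, 20, 3), dtype=np.float32)
arr2 = np.zeros((10, 20, 5, 3), dtype=np.float32)

# Hashing only the raw bytes collides: same data, different layout.
assert (hashlib.sha256(arr1.tobytes()).digest()
        == hashlib.sha256(arr2.tobytes()).digest())

def hash_array(a: np.ndarray) -> bytes:
    # Mix shape and dtype into the digest alongside the data.
    h = hashlib.sha256()
    h.update(repr((a.shape, str(a.dtype))).encode())
    h.update(a.tobytes())
    return h.digest()

assert hash_array(arr1) != hash_array(arr2)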

View File

@ -3,6 +3,7 @@ import importlib.metadata
import importlib.util
import pytest
import torch
DTYPE = ["bfloat16"]
@ -21,5 +22,30 @@ def test_pre_quantized_model(vllm_runner):
print(output)
@pytest.mark.skipif(not TORCHAO_AVAILABLE, reason="torchao is not available")
@pytest.mark.parametrize(
"pt_load_map_location",
[
"cuda:0",
# {"": "cuda"},
])
def test_opt_125m_int4wo_model_loading_with_params(vllm_runner,
pt_load_map_location):
"""
Test loading roberta-base model with no lm_head.
"""
torch._dynamo.reset()
model_name = "jerryzh168/opt-125m-int4wo"
with vllm_runner(model_name=model_name,
quantization="torchao",
dtype="bfloat16",
pt_load_map_location=pt_load_map_location) as llm:
output = llm.generate_greedy(["The capital of France is"],
max_tokens=32)
assert output
print(output)
if __name__ == "__main__":
pytest.main([__file__])

View File

@ -5,7 +5,8 @@ from typing import Literal, Union
import pytest
from vllm.config import ModelConfig, PoolerConfig, config, get_field
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform
@ -410,3 +411,16 @@ def test_generation_config_loading():
override_generation_config=override_generation_config)
assert model_config.get_diff_sampling_param() == override_generation_config
@pytest.mark.parametrize("pt_load_map_location", [
"cuda",
{
"": "cuda"
},
])
def test_load_config_pt_load_map_location(pt_load_map_location):
load_config = LoadConfig(pt_load_map_location=pt_load_map_location)
config = VllmConfig(load_config=load_config)
assert config.load_config.pt_load_map_location == pt_load_map_location
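
`pt_load_map_location` is ultimately what `torch.load` receives as `map_location` when a PyTorch checkpoint is read. As a reminder, the two shapes `torch.load` accepts natively are a device string and a dict remapping storage location tags (the checkpoint path below is illustrative; how vLLM interprets the empty-string-key dict exercised above is its own convention, not shown here):

import torch

# String form: place every tensor from the checkpoint on one device.
state = torch.load("checkpoint.pt", map_location="cuda:0")

# Dict form: remap storage location tags recorded in the checkpoint,
# e.g. tensors saved on cuda:1 are restored onto cuda:0.
state = torch.load("checkpoint.pt", map_location={"cuda:1": "cuda:0"})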

View File

@ -1165,3 +1165,80 @@ def test_kv_connector_handles_preemption():
# All memory should be freed since nothing is running.
assert scheduler.kv_cache_manager.block_pool.get_num_free_blocks() \
== NUM_BLOCKS - 1
def make_output(scheduler: Scheduler):
return ModelRunnerOutput(
req_ids=[req.request_id for req in scheduler.running],
req_id_to_index={
req.request_id: i
for i, req in enumerate(scheduler.running)
},
sampled_token_ids=[[1000]] * len(scheduler.running),
spec_token_ids=None,
logprobs=None,
prompt_logprobs_dict={},
)
def assert_scheduler_empty(scheduler: Scheduler):
"""Confirm the scheduler is "empty" - i.e. no leaks."""
# Scheduler Metadata.
assert len(scheduler.requests) == 0
assert len(scheduler.waiting) == 0
assert len(scheduler.running) == 0
assert len(scheduler.finished_req_ids) == 0
assert len(scheduler._cached_reqs_data) == 0
# EncoderCacheManager.
assert len(scheduler.encoder_cache_manager.freed) == 0
assert len(scheduler.encoder_cache_manager.cached) == 0
# KVCache Manager.
assert len(scheduler.kv_cache_manager.req_to_blocks) == 0
assert len(scheduler.kv_cache_manager.req_to_block_hashes) == 0
assert len(scheduler.kv_cache_manager.num_cached_block) == 0
num_free_blocks = (
scheduler.kv_cache_manager.block_pool.free_block_queue.num_free_blocks)
assert num_free_blocks == (
scheduler.kv_cache_manager.block_pool.num_gpu_blocks - 1)
# NOTE(rob): just the ref count on blocks will be 0. The hash
# value, etc will remain since we lazily evict for prefix cache.
for block in scheduler.kv_cache_manager.block_pool.blocks:
assert block.ref_cnt == 0
# assert block._block_hash is None
# assert (
# len(scheduler.kv_cache_manager.block_pool.cached_block_hash_to_block
# ) == 0)
def test_memory_leak():
"""Test that we do not have a memory leak."""
scheduler = create_scheduler(enable_prefix_caching=True)
NUM_REQUESTS = 5
NUM_TOKENS = 10
MAX_TOKENS = 10
requests = create_requests(num_requests=NUM_REQUESTS,
num_tokens=NUM_TOKENS,
max_tokens=MAX_TOKENS)
# Add each request.
for request in requests:
scheduler.add_request(request)
scheduler_output = scheduler.schedule()
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Iterate until done.
while True:
scheduler_output = scheduler.schedule()
if len(scheduler.running) == 0:
break
model_runner_output = make_output(scheduler)
scheduler.update_from_output(scheduler_output, model_runner_output)
# Confirm no memory leak.
assert_scheduler_empty(scheduler)

View File

@ -31,23 +31,38 @@ def test_deepseek_mla_attn_backend_module():
assert model_runner.attn_backend.__name__ == "TritonMLABackend"
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_prompt(batch_size):
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_prompt(batch_size, use_prompt_embeds, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner(
"facebook/opt-125m",
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enable_prompt_embeds=True,
)
seq_lens: list[int] = []
seq_group_metadata_list: list[SequenceGroupMetadata] = []
block_tables = {0: [1]}
expected_input_embeds_len = 0
for i in range(batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len))
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * seq_len,
prompt_embeds=torch.rand(seq_len, 10),
)
expected_input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(prompt_token_ids=range(seq_len))
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@ -68,6 +83,7 @@ def test_prepare_prompt(batch_size):
seq_group_metadata_list)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
slot_mapping = attn_metadata.slot_mapping
@ -121,7 +137,11 @@ def test_prepare_prompt(batch_size):
assert len(input_tokens) == sum(seq_lens)
assert len(input_positions) == sum(seq_lens)
torch.testing.assert_close(input_tokens, input_positions)
if expected_input_embeds_len == 0:
torch.testing.assert_close(input_tokens, input_positions)
assert input_embeds is None
else:
assert len(input_embeds) == expected_input_embeds_len
sampling_metadata = SamplingMetadata.prepare(
seq_group_metadata_list,
@ -145,8 +165,13 @@ def test_prepare_prompt(batch_size):
torch.testing.assert_close(actual, expected)
@pytest.mark.parametrize("batch_size", list(range(1, 257)))
def test_prepare_decode_cuda_graph(batch_size):
@pytest.mark.parametrize("batch_size", list(range(1, 257, 3)))
@pytest.mark.parametrize("use_prompt_embeds", [True, False])
def test_prepare_decode_cuda_graph(batch_size, use_prompt_embeds, monkeypatch):
if use_prompt_embeds:
# Prompt embeddings are currently only supported on V0
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner(
"facebook/opt-125m",
seed=0,
@ -155,6 +180,7 @@ def test_prepare_decode_cuda_graph(batch_size):
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=False,
enable_prompt_embeds=True,
)
context_lens: list[int] = []
@ -164,10 +190,19 @@ def test_prepare_decode_cuda_graph(batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
context_lens.append(context_len)
seq_data = SequenceData.from_seqs(range(context_len))
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * context_len,
prompt_embeds=torch.rand(context_len, 10),
)
output_embed = torch.rand(10)
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(context_len))
output_embed = None
seq_data.update_num_computed_tokens(context_len)
# Append one token ID since prefill is finished.
seq_data.append_token_id(1, 0)
seq_data.append_token_id(1, 0, output_embed)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=False,
@ -180,9 +215,12 @@ def test_prepare_decode_cuda_graph(batch_size):
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata, slot_mapping = (
model_input.input_tokens, model_input.input_positions,
model_input.attn_metadata, model_input.attn_metadata.slot_mapping)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
slot_mapping = attn_metadata.slot_mapping
assert len(slot_mapping) == len(input_tokens)
expected_bs = model_runner.vllm_config.pad_for_cudagraph(
@ -227,7 +265,7 @@ def test_prepare_decode_cuda_graph(batch_size):
# block table's first index corresponds to each batch, meaning in
# decoding it is each token.
assert attn_metadata.block_tables.shape[0] == len(input_tokens)
# Block table's second dim correspondsd to each token's block number.
# Block table's second dim corresponds to each token's block number.
# It is padded up to
assert attn_metadata.block_tables.shape[1] == (
model_runner.get_max_block_per_batch())
@ -235,7 +273,12 @@ def test_prepare_decode_cuda_graph(batch_size):
assert len(input_tokens) == expected_bs
assert len(input_positions) == expected_bs
torch.allclose(input_tokens, input_positions)
if use_prompt_embeds:
expected_input_embeds_length = start_loc[-1]
assert len(input_embeds) == expected_input_embeds_length
assert expected_input_embeds_length <= expected_bs
else:
assert input_embeds is None
# Verify Sampling
expected_selected_token_indices = []
@ -266,25 +309,27 @@ def test_empty_seq_group():
seq_group_metadata_list: list[SequenceGroupMetadata] = []
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
input_tokens, input_positions, attn_metadata = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
attn_metadata = model_input.attn_metadata
assert input_tokens is None
assert input_positions is None
assert attn_metadata is None
model_input = model_runner._prepare_model_input_tensors(
seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata, return_seq_lens) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
model_input.seq_lens,
)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
return_seq_lens = model_input.seq_lens
assert input_tokens is None
assert input_positions is None
assert input_embeds is None
assert attn_metadata is None
assert return_seq_lens is None
@ -299,9 +344,15 @@ def distributed_init():
ensure_model_parallel_initialized(1, 1)
@pytest.mark.parametrize("batch_size", list(range(2, 128)))
@pytest.mark.parametrize("batch_size", list(range(2, 128, 3)))
@pytest.mark.parametrize("enforce_eager", [True, False])
def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
@pytest.mark.parametrize('use_prompt_embeds', [True, False])
def test_hybrid_batches(batch_size, enforce_eager, use_prompt_embeds,
distributed_init, monkeypatch):
if use_prompt_embeds:
        # Prompt embeddings are currently only supported on V0.
monkeypatch.setenv("VLLM_USE_V1", "0")
model_runner = _create_model_runner(
"facebook/opt-125m",
seed=0,
@ -310,6 +361,7 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
max_num_batched_tokens=100000,
max_num_seqs=100000,
enable_chunked_prefill=True,
enable_prompt_embeds=True,
)
# Add prefill requests.
@ -320,11 +372,20 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
block_tables = {0: [1]}
prefill_batch_size = batch_size // 2
decode_batch_size = batch_size - prefill_batch_size
expected_input_embeds_len = 0
for i in range(prefill_batch_size):
# make sure all tokens fit into one block
seq_len = i % (model_runner.block_size - 1) + 1
seq_lens.append(seq_len)
seq_data = SequenceData.from_seqs(range(seq_len))
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * seq_len,
prompt_embeds=torch.rand(seq_len, 10),
)
expected_input_embeds_len += seq_len
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(seq_len), )
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
is_prompt=True,
@ -340,8 +401,21 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
for i in range(prefill_batch_size, batch_size):
# make sure all tokens fit into one block
context_len = i % (model_runner.block_size - 1) + 1
seq_data = SequenceData.from_seqs(range(context_len))
seq_data.append_token_id(1, 0)
if use_prompt_embeds:
seq_data = SequenceData.from_seqs(
prompt_token_ids=[0] * context_len,
prompt_embeds=torch.rand(context_len, 10),
)
output_embed = torch.rand(10)
            # This also increments the expected input_embeds length, because the
            # model needs both the input and output embeddings passed in together.
expected_input_embeds_len += 1
else:
seq_data = SequenceData.from_seqs(
prompt_token_ids=range(context_len), )
output_embed = None
assert len(seq_data.prompt_token_ids) == context_len
seq_data.append_token_id(1, 0, output_embed)
seq_data.update_num_computed_tokens(context_len)
seq_group_metadata = SequenceGroupMetadata(
request_id=f"test_{i}",
@ -355,11 +429,11 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
decode_metadata_list.append(seq_group_metadata)
model_input = model_runner.prepare_model_input(seq_group_metadata_list)
(input_tokens, input_positions, attn_metadata) = (
model_input.input_tokens,
model_input.input_positions,
model_input.attn_metadata,
)
input_tokens = model_input.input_tokens
input_positions = model_input.input_positions
input_embeds = model_input.inputs_embeds
attn_metadata = model_input.attn_metadata
prefill_meta_actual = attn_metadata.prefill_metadata
decode_meta_actual = attn_metadata.decode_metadata
@ -369,6 +443,10 @@ def test_hybrid_batches(batch_size, enforce_eager, distributed_init):
assert attn_metadata.num_prefills == prefill_batch_size
assert attn_metadata.num_decode_tokens == decode_batch_size
assert attn_metadata.num_prefill_tokens == sum(seq_lens)
if expected_input_embeds_len == 0:
assert input_embeds is None
else:
assert len(input_embeds) == expected_input_embeds_len
# Verify attn metadata is consistent. We don't need to test individual
# values here because they are tested above.

View File

@ -559,7 +559,6 @@ def cutlass_scaled_mm(a: torch.Tensor,
scale_a.shape * [1, 128] == a.shape
scale_b.shape * [128, 128] == b.shape
"""
assert (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
assert (out_dtype is torch.bfloat16 or out_dtype is torch.float16)
assert bias is None or bias.shape[0] == b.shape[
1] and bias.dtype == out_dtype
@ -567,7 +566,8 @@ def cutlass_scaled_mm(a: torch.Tensor,
m = a.shape[0]
n = b.shape[1]
if current_platform.is_rocm():
cutlass_compatible_b = (b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0)
if current_platform.is_rocm() or not cutlass_compatible_b:
triton_scaled_mm_module = importlib.import_module(
"vllm.model_executor.layers.quantization.compressed_tensors."
"triton_scaled_mm")

View File

@ -18,19 +18,25 @@ except ImportError:
ASSET_DIR = "multimodal_asset"
AudioAssetName = Literal["winning_call", "mary_had_lamb"]
@dataclass(frozen=True)
class AudioAsset:
name: Literal["winning_call", "mary_had_lamb"]
name: AudioAssetName
@property
def filename(self) -> str:
return f"{self.name}.ogg"
@property
def audio_and_sample_rate(self) -> tuple[npt.NDArray, float]:
audio_path = get_vllm_public_assets(filename=f"{self.name}.ogg",
audio_path = get_vllm_public_assets(filename=self.filename,
s3_prefix=ASSET_DIR)
return librosa.load(audio_path, sr=None)
def get_local_path(self) -> Path:
return get_vllm_public_assets(filename=f"{self.name}.ogg",
return get_vllm_public_assets(filename=self.filename,
s3_prefix=ASSET_DIR)
@property

View File

@ -10,10 +10,12 @@ from .base import get_vllm_public_assets
VLM_IMAGES_DIR = "vision_model_images"
ImageAssetName = Literal["stop_sign", "cherry_blossom"]
@dataclass(frozen=True)
class ImageAsset:
name: Literal["stop_sign", "cherry_blossom"]
name: ImageAssetName
@property
def pil_image(self) -> Image.Image:

View File

@ -2,7 +2,7 @@
from dataclasses import dataclass
from functools import lru_cache
from typing import Literal, Optional
from typing import ClassVar, Literal, Optional
import cv2
import numpy as np
@ -76,20 +76,31 @@ def video_to_pil_images_list(path: str,
]
VideoAssetName = Literal["baby_reading"]
@dataclass(frozen=True)
class VideoAsset:
name: Literal["sample_demo_1.mp4"]
name: VideoAssetName
num_frames: int = -1
_NAME_TO_FILE: ClassVar[dict[VideoAssetName, str]] = {
"baby_reading": "sample_demo_1.mp4",
}
@property
def filename(self) -> str:
return self._NAME_TO_FILE[self.name]
@property
def pil_images(self) -> list[Image.Image]:
video_path = download_video_asset(self.name)
video_path = download_video_asset(self.filename)
ret = video_to_pil_images_list(video_path, self.num_frames)
return ret
@property
def np_ndarrays(self) -> npt.NDArray:
video_path = download_video_asset(self.name)
video_path = download_video_asset(self.filename)
ret = video_to_ndarrays(video_path, self.num_frames)
return ret
@ -99,5 +110,5 @@ class VideoAsset:
See also: examples/offline_inference/qwen2_5_omni/only_thinker.py
"""
video_path = download_video_asset(self.name)
video_path = download_video_asset(self.filename)
return librosa.load(video_path, sr=sampling_rate)[0]
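
# A minimal usage sketch of the asset defined above: the logical name now resolves
# to its backing file via _NAME_TO_FILE.
asset = VideoAsset(name="baby_reading", num_frames=4)
print(asset.filename)        # sample_demo_1.mp4
frames = asset.pil_images    # downloads the asset and decodes 4 frames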

View File

@ -281,8 +281,7 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
# remove padding
output = output.view(-1, self.num_heads,
q.shape[-1])[..., :v.shape[-1]]
output = output.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
return output.reshape(-1, self.num_heads * v.shape[-1])
def _forward_decode(
self,
@ -303,4 +302,4 @@ class CPUMLAImpl(MLACommonImpl[CPUMLAMetadata]):
ops.mla_decode_kvcache_cpu(o, q, kv_c_and_k_pe_cache, self.scale,
decode_meta.block_tables,
decode_meta.seq_lens_tensor)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -367,9 +367,17 @@ class FlashInferState(AttentionState):
# scheduled while CUDA graph mode is enabled. We don't run graph in that
# case.
if use_cuda_graph and is_decode:
batch_size = model_input.input_tokens.shape[0]
state = (self.runner.graph_runners[model_input.virtual_engine]
[batch_size].attn_state)
if model_input.inputs_embeds is None:
batch_size = model_input.input_tokens.shape[0]
state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, False)].attn_state)
else:
batch_size = model_input.inputs_embeds.shape[0]
state = (
self.runner.graph_runners[model_input.virtual_engine][(
batch_size, True)].attn_state)
model_input.attn_metadata.prefill_wrapper = state._get_prefill_wrapper(
)
model_input.attn_metadata.decode_wrapper = state._get_decode_wrapper()
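
# A toy sketch of the lookup change above: CUDA-graph runners are now keyed by
# (batch_size, uses_inputs_embeds) instead of batch_size alone, so token-id and
# prompt-embeds inputs resolve to separately captured graphs.
graph_runners: dict[tuple[int, bool], str] = {
    (8, False): "graph captured for token-id inputs",
    (8, True): "graph captured for prompt-embeds inputs",
}
print(graph_runners[(8, True)])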

View File

@ -239,4 +239,4 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
causal=True,
)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -207,7 +207,7 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import get_flash_attn_version
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
LinearBase,
UnquantizedLinearMethod)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
@ -1032,12 +1032,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: ColumnParallelLinear,
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
@ -1055,9 +1050,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.triton_fa_func = triton_attention
# Handle the differences between the flash_attn_varlen from flash_attn
@ -1141,27 +1134,13 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return attn_out, rest[0]
return attn_out
def _v_up_proj_and_o_proj(self, x):
def _v_up_proj(self, x):
# Convert from (B, N, L) to (N, B, L)
x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
# Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x = torch.bmm(x, self.W_UV)
# Convert from (N, B, V) to (B, N * V)
x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
return self.o_proj(x)[0]
# Return `ql_nope`, `q_pe`
def _q_proj_and_k_up_proj(self, x):
q_nope, q_pe = self.q_proj(x)[0]\
.view(-1, self.num_heads, self.qk_head_dim)\
.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
return x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
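
# A toy shape check for _v_up_proj above (sizes are arbitrary): a per-head bmm
# with W_UV maps (B, N, L) to (B, N * V).
import torch

B, N, L, V = 4, 8, 16, 32                      # batch, heads, kv_lora_rank, v_head_dim
x = torch.randn(B, N, L)
W_UV = torch.randn(N, L, V)
out = torch.bmm(x.transpose(0, 1), W_UV)       # (N, B, L) x (N, L, V) -> (N, B, V)
out = out.transpose(0, 1).reshape(B, N * V)    # (B, N * V)
print(out.shape)                               # torch.Size([4, 256])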
def process_weights_after_loading(self, act_dtype: torch.dtype):
@ -1345,7 +1324,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
suffix_lse=suffix_lse,
)
return self.o_proj(output.flatten(start_dim=-2))[0]
return output.flatten(start_dim=-2)
@abstractmethod
def _forward_decode(
@ -1360,7 +1339,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
q: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
@ -1391,27 +1370,32 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
assert hasattr(attn_metadata, "input_positions")
num_prefill_tokens: int = attn_metadata.num_prefill_tokens
q = q.view(-1, self.num_heads, self.qk_head_dim)
decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
decode_q = q[num_prefill_tokens:]
decode_k_pe = k_pe[num_prefill_tokens:]
decode_input_positions = \
attn_metadata.input_positions[num_prefill_tokens:]
prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
prefill_q = q[:num_prefill_tokens]
prefill_k_pe = k_pe[:num_prefill_tokens]
prefill_input_positions = \
attn_metadata.input_positions[:num_prefill_tokens]
prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
if has_decode:
decode_ql_nope, decode_q_pe = \
self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_nope, decode_q_pe = decode_q.split(
[self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
# Convert from (B, N, P) to (N, B, P)
decode_q_nope = decode_q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
decode_ql_nope = torch.bmm(decode_q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
decode_ql_nope = decode_ql_nope.transpose(0, 1)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
prefill_input_positions, prefill_q_pe, prefill_k_pe)
@ -1429,9 +1413,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
output = torch.empty(attn_metadata.num_prefill_tokens +
attn_metadata.num_decode_tokens,
self.o_proj.output_size,
device=hidden_states_or_q_c.device,
dtype=hidden_states_or_q_c.dtype)
self.v_head_dim * self.num_heads,
device=q.device,
dtype=q.dtype)
if has_prefill:
output[:num_prefill_tokens] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,

View File

@ -409,4 +409,4 @@ class AiterMLAImpl(MLACommonImpl[AiterMLAMetadata]):
attn_metadata.paged_kv_indices,
attn_metadata.paged_kv_last_page_lens)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -110,4 +110,4 @@ class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
decode_meta.seq_lens_tensor, attn_logits,
num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
return self._v_up_proj(o)

View File

@ -289,7 +289,7 @@ def chunked_prefill_paged_decode(
max_num_partitions = ((max_seq_len + _PARTITION_SIZE_ROCM - 1) //
_PARTITION_SIZE_ROCM)
assert _PARTITION_SIZE_ROCM % block_size == 0
total_num_seq = query.shape[0]
total_num_seq = block_table.shape[0]
tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions,
head_size),

View File

@ -268,7 +268,7 @@ class ModelConfig:
It can be a branch name, a tag name, or a commit id. If unspecified, will
use the default version."""
rope_scaling: dict[str, Any] = field(default_factory=dict)
"""RoPE scaling configuration in JSON format. For example,
"""RoPE scaling configuration. For example,
`{"rope_type":"dynamic","factor":2.0}`."""
rope_theta: Optional[float] = None
"""RoPE theta. Use with `rope_scaling`. In some cases, changing the RoPE
@ -321,6 +321,10 @@ class ModelConfig:
"""Skip initialization of tokenizer and detokenizer. Expects valid
`prompt_token_ids` and `None` for prompt from the input. The generated
output will contain token ids."""
enable_prompt_embeds: bool = False
"""If `True`, enables passing text embeddings as inputs via the
`prompt_embeds` key. Note that enabling this will double the time required
for graph compilation."""
served_model_name: Optional[Union[str, list[str]]] = None
"""The model name(s) used in the API. If multiple names are provided, the
server will respond to any of the provided names. The model name in the
@ -346,14 +350,13 @@ class ModelConfig:
(stored in `~/.huggingface`)."""
hf_overrides: HfOverrides = field(default_factory=dict)
"""If a dictionary, contains arguments to be forwarded to the Hugging Face
config. If a callable, it is called to update the HuggingFace config. When
specified via CLI, the argument must be a valid JSON string."""
config. If a callable, it is called to update the HuggingFace config."""
mm_processor_kwargs: Optional[dict[str, Any]] = None
"""Arguments to be forwarded to the model's processor for multi-modal data,
e.g., image processor. Overrides for the multi-modal processor obtained
from `AutoProcessor.from_pretrained`. The available overrides depend on the
model that is being run. For example, for Phi-3-Vision: `{"num_crops": 4}`.
When specified via CLI, the argument must be a valid JSON string."""
"""
disable_mm_preprocessor_cache: bool = False
"""If `True`, disable caching of the multi-modal preprocessor/mapper (not
recommended)."""
@ -361,15 +364,14 @@ class ModelConfig:
"""Initialize non-default neuron config or override default neuron config
that are specific to Neuron devices, this argument will be used to
configure the neuron config that can not be gathered from the vllm
arguments. e.g. `{"cast_logits_dtype": "bloat16"}`. When specified via CLI,
the argument must be a valid JSON string."""
arguments. e.g. `{"cast_logits_dtype": "bloat16"}`."""
pooler_config: Optional["PoolerConfig"] = field(init=False)
"""Pooler config which controls the behaviour of output pooling in pooling
models."""
override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
"""Initialize non-default pooling config or override default pooling config
for the pooling model. e.g. `{"pooling_type": "mean", "normalize": false}`.
When specified via CLI, the argument must be a valid JSON string."""
"""
logits_processor_pattern: Optional[str] = None
"""Optional regex pattern specifying valid logits processor qualified names
that can be passed with the `logits_processors` extra completion argument.
@ -385,8 +387,7 @@ class ModelConfig:
"""Overrides or sets generation config. e.g. `{"temperature": 0.5}`. If
used with `--generation-config auto`, the override parameters will be
merged with the default config from the model. If used with
`--generation-config vllm`, only the override parameters are used.
When specified via CLI, the argument must be a valid JSON string."""
`--generation-config vllm`, only the override parameters are used."""
enable_sleep_mode: bool = False
"""Enable sleep mode for the engine (only cuda platform is supported)."""
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO.value
@ -1556,14 +1557,23 @@ class LoadConfig:
cache directory of Hugging Face."""
model_loader_extra_config: dict = field(default_factory=dict)
"""Extra config for model loader. This will be passed to the model loader
corresponding to the chosen load_format. This should be a JSON string that
will be parsed into a dictionary."""
corresponding to the chosen load_format."""
ignore_patterns: Optional[Union[list[str], str]] = None
"""The list of patterns to ignore when loading the model. Default to
"original/**/*" to avoid repeated loading of llama's checkpoints."""
use_tqdm_on_load: bool = True
"""Whether to enable tqdm for showing progress bar when loading model
weights."""
pt_load_map_location: Union[str, dict[str, str]] = "cpu"
"""
    pt_load_map_location: the map location for loading PyTorch checkpoints, to
    support checkpoints that can only be loaded on certain devices such as
    "cuda"; this is equivalent to {"": "cuda"}. Another supported format is a
    mapping between devices, e.g. from GPU 1 to GPU 0: {"cuda:1": "cuda:0"}.
    Note that when passed from the command line, the strings in the dictionary
    need to be double quoted for JSON parsing. For more details, see the
    original doc for `map_location` in https://pytorch.org/docs/stable/generated/torch.load.html
"""
def compute_hash(self) -> str:
"""
@ -2816,7 +2826,6 @@ class MultiModalConfig:
"limit_mm_per_prompt")
"""
The maximum number of input items allowed per prompt for each modality.
This should be a JSON string that will be parsed into a dictionary.
Defaults to 1 (V0) or 999 (V1) for each modality.
For example, to allow up to 16 images and 2 videos per prompt:
@ -3135,6 +3144,14 @@ def _get_and_verify_max_len(
# derived length from the HF model config.
if max_model_len is None:
max_model_len = int(derived_max_model_len)
if current_platform.is_tpu():
logger.warning(
"--max-model-len is not specified, "
"it's currently using model's default length %s, "
"which might be too large."
"Please input with --max-model-len based on your "
"request input length and output length, to avoid "
"unnecessary degradation.", max_model_len)
elif max_model_len > derived_max_model_len:
# Some models might have a separate key for specifying model_max_length
# that will be bigger than derived_max_model_len. We compare user input
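
# A minimal sketch of the two forms accepted by pt_load_map_location above; per
# the docstring they follow torch.load's map_location semantics (the path and
# values here are placeholders).
import torch

torch.save({"w": torch.zeros(2)}, "/tmp/ckpt.pt")
state = torch.load("/tmp/ckpt.pt", map_location="cpu")   # string form, like "cuda"
# The dict form remaps devices, e.g. tensors saved on GPU 1 load onto GPU 0:
# state = torch.load("/tmp/ckpt.pt", map_location={"cuda:1": "cuda:0"})
print(state["w"].device)                                 # cpu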

View File

@ -1071,6 +1071,7 @@ class Scheduler:
)
ignored_seq_groups: List[SequenceGroup] = []
seq_groups: List[ScheduledSequenceGroup] = []
using_prompt_embeds: bool = False
waiting_queue = self.waiting
@ -1138,6 +1139,15 @@ class Scheduler:
waiting_queue.popleft()
continue
# We cannot mix sequence groups that use prompt embeds and
# those that do not.
if len(seq_groups) == 0:
using_prompt_embeds = seq_group.uses_prompt_embeds()
if using_prompt_embeds != seq_group.uses_prompt_embeds():
leftover_waiting_sequences.appendleft(seq_group)
waiting_queue.popleft()
continue
lora_int_id = 0
if self.lora_enabled:
lora_int_id = seq_group.lora_int_id
@ -1295,17 +1305,39 @@ class Scheduler:
# Merge lists
num_prefill_groups = len(prefills.seq_groups)
ignored_seq_groups_for_embeds = list[SequenceGroup]()
if num_prefill_groups > 0:
scheduled_seq_groups = prefills.seq_groups
scheduled_seq_groups.extend(running_scheduled.decode_seq_groups)
ignored_seq_groups_for_embeds.clear()
else:
scheduled_seq_groups = running_scheduled.decode_seq_groups
if len(scheduled_seq_groups) > 0:
using_prompt_embeds = scheduled_seq_groups[
0].seq_group.uses_prompt_embeds()
ignored_seq_groups_for_embeds.clear()
indices_ignored = list[int]()
for i, schedule_seq_group in enumerate(scheduled_seq_groups):
if using_prompt_embeds !=\
schedule_seq_group.seq_group.uses_prompt_embeds():
ignored_seq_groups_for_embeds.append(
schedule_seq_group.seq_group)
indices_ignored.append(i)
if len(ignored_seq_groups_for_embeds) > 0:
scheduled_seq_groups = [
group for i, group in enumerate(scheduled_seq_groups)
if i not in indices_ignored
]
else:
ignored_seq_groups_for_embeds.clear()
scheduled_seq_groups.extend(swapped_in.decode_seq_groups)
blocks_to_copy = running_scheduled.blocks_to_copy
blocks_to_copy.extend(swapped_in.blocks_to_copy)
ignored_seq_groups = prefills.ignored_seq_groups
ignored_seq_groups.extend(ignored_seq_groups_for_embeds)
ignored_seq_groups.extend(swapped_in.infeasible_seq_groups)
return SchedulerOutputs(
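
# A toy illustration of the rule enforced above: groups whose uses_prompt_embeds()
# disagrees with the first scheduled group are deferred to a later scheduling step.
def split_by_prompt_embeds(groups: list[tuple[str, bool]], uses_embeds: bool):
    scheduled, deferred = [], []
    for name, has_embeds in groups:
        (scheduled if has_embeds == uses_embeds else deferred).append(name)
    return scheduled, deferred

print(split_by_prompt_embeds([("a", True), ("b", False), ("c", True)], uses_embeds=True))
# (['a', 'c'], ['b'])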

View File

@ -64,6 +64,13 @@ def optional_type(
return _optional_type
def union_dict_and_str(val: str) -> Optional[Union[str, dict[str, str]]]:
if not re.match("^{.*}$", val):
return str(val)
else:
return optional_type(json.loads)(val)
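
# Example inputs for the helper defined above: JSON-looking values become dicts,
# anything else stays a plain string (used e.g. for --pt-load-map-location).
print(union_dict_and_str("cuda"))                   # 'cuda'
print(union_dict_and_str('{"cuda:1": "cuda:0"}'))   # {'cuda:1': 'cuda:0'}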
@deprecated(
"Passing a JSON argument as a string containing comma separated key=value "
"pairs is deprecated. This will be removed in v0.10.0. Please use a JSON "
@ -143,7 +150,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
# Get the help text for the field
name = field.name
help = cls_docs[name]
help = cls_docs[name].strip()
# Escape % for argparse
help = help.replace("%", "%%")
@ -158,6 +165,7 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
type_hints.add(field.type)
# Set other kwargs based on the type hints
json_tip = "\n\nShould be a valid JSON string."
if contains_type(type_hints, bool):
# Creates --no-<name> and --<name> flags
kwargs[name]["action"] = argparse.BooleanOptionalAction
@ -187,9 +195,14 @@ def get_kwargs(cls: ConfigType) -> dict[str, Any]:
kwargs[name]["type"] = human_readable_int
elif contains_type(type_hints, float):
kwargs[name]["type"] = float
elif contains_type(type_hints,
dict) and (contains_type(type_hints, str) or any(
is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = union_dict_and_str
elif contains_type(type_hints, dict):
# Dict arguments will always be optional
kwargs[name]["type"] = optional_type(json.loads)
kwargs[name]["help"] += json_tip
elif (contains_type(type_hints, str)
or any(is_not_builtin(th) for th in type_hints)):
kwargs[name]["type"] = str
@ -221,6 +234,7 @@ class EngineArgs:
hf_config_path: Optional[str] = ModelConfig.hf_config_path
task: TaskOption = ModelConfig.task
skip_tokenizer_init: bool = ModelConfig.skip_tokenizer_init
enable_prompt_embeds: bool = ModelConfig.enable_prompt_embeds
tokenizer_mode: TokenizerMode = ModelConfig.tokenizer_mode
trust_remote_code: bool = ModelConfig.trust_remote_code
allowed_local_media_path: str = ModelConfig.allowed_local_media_path
@ -371,6 +385,7 @@ class EngineArgs:
reasoning_parser: str = DecodingConfig.reasoning_backend
use_tqdm_on_load: bool = LoadConfig.use_tqdm_on_load
pt_load_map_location: str = LoadConfig.pt_load_map_location
def __post_init__(self):
# support `EngineArgs(compilation_config={...})`
@ -431,6 +446,8 @@ class EngineArgs:
**model_kwargs["disable_cascade_attn"])
model_group.add_argument("--skip-tokenizer-init",
**model_kwargs["skip_tokenizer_init"])
model_group.add_argument("--enable-prompt-embeds",
**model_kwargs["enable_prompt_embeds"])
model_group.add_argument("--served-model-name",
**model_kwargs["served_model_name"])
# This one is a special case because it is the
@ -491,6 +508,8 @@ class EngineArgs:
type=str,
default=None,
help='Name or path of the QLoRA adapter.')
load_group.add_argument('--pt-load-map-location',
**load_kwargs["pt_load_map_location"])
# Guided decoding arguments
guided_decoding_kwargs = get_kwargs(DecodingConfig)
@ -858,6 +877,7 @@ class EngineArgs:
disable_sliding_window=self.disable_sliding_window,
disable_cascade_attn=self.disable_cascade_attn,
skip_tokenizer_init=self.skip_tokenizer_init,
enable_prompt_embeds=self.enable_prompt_embeds,
served_model_name=self.served_model_name,
limit_mm_per_prompt=self.limit_mm_per_prompt,
use_async_output_proc=not self.disable_async_output_proc,
@ -883,12 +903,14 @@ class EngineArgs:
if self.quantization == "bitsandbytes":
self.load_format = "bitsandbytes"
return LoadConfig(
load_format=self.load_format,
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
use_tqdm_on_load=self.use_tqdm_on_load,
pt_load_map_location=self.pt_load_map_location,
)
def create_speculative_config(
@ -1423,8 +1445,8 @@ class EngineArgs:
# as the platform that vLLM is running on (e.g. the case of scaling
# vLLM with Ray) and has no GPUs. In this case we use the default
# values for non-H100/H200 GPUs.
from vllm.platforms import current_platform
try:
from vllm.platforms import current_platform
device_memory = current_platform.get_device_total_memory()
except Exception:
# This is only used to set default_max_num_batched_tokens
@ -1445,11 +1467,37 @@ class EngineArgs:
}
default_max_num_seqs = 256
# tpu specific default values.
if current_platform.is_tpu():
default_max_num_batched_tokens_tpu = {
UsageContext.LLM_CLASS: {
'V6E': 2048,
'V5E': 1024,
'V5P': 512,
},
UsageContext.OPENAI_API_SERVER: {
'V6E': 1024,
'V5E': 512,
'V5P': 256,
}
}
use_context_value = usage_context.value if usage_context else None
if (self.max_num_batched_tokens is None
and usage_context in default_max_num_batched_tokens):
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context]
if current_platform.is_tpu():
chip_name = current_platform.get_device_name()
if chip_name in default_max_num_batched_tokens_tpu[
usage_context]:
self.max_num_batched_tokens = \
default_max_num_batched_tokens_tpu[
usage_context][chip_name]
else:
self.max_num_batched_tokens = \
default_max_num_batched_tokens[usage_context]
else:
self.max_num_batched_tokens = default_max_num_batched_tokens[
usage_context]
logger.debug(
"Setting max_num_batched_tokens to %d for %s usage context.",
self.max_num_batched_tokens, use_context_value)
@ -1513,7 +1561,7 @@ def _warn_or_fallback(feature_name: str) -> bool:
def human_readable_int(value):
"""Parse human-readable integers like '1k', '2M', etc.
Including decimal values with decimal multipliers.
Examples:
- '1k' -> 1,000
- '1K' -> 1,024

View File

@ -489,9 +489,13 @@ class _AsyncLLMEngine(LLMEngine):
if arrival_time is None:
arrival_time = time.time()
if self.tokenizer is not None:
tokenizer = await self.get_tokenizer_async(lora_request)
self._validate_token_prompt(prompt, tokenizer=tokenizer)
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
# We use the -2 dimension (instead of 0) in case a batched input
# of batch size 1 is passed in.
prompt["prompt_token_ids"] = [0
] * prompt["prompt_embeds"].shape[-2]
processed_inputs = await self.input_preprocessor.preprocess_async(
prompt,

View File

@ -30,7 +30,7 @@ from vllm.entrypoints.openai.logits_processors import (
get_logits_processors as get_openai_logits_processors)
from vllm.executor.executor_base import ExecutorBase
from vllm.inputs import ProcessorInputs, PromptType, SingletonInputs
from vllm.inputs.parse import is_token_prompt, split_enc_dec_inputs
from vllm.inputs.parse import split_enc_dec_inputs
from vllm.inputs.preprocess import InputPreprocessor
from vllm.logger import init_logger
from vllm.logits_process import get_bad_words_logits_processors
@ -753,10 +753,11 @@ class LLMEngine:
if arrival_time is None:
arrival_time = time.time()
if self.tokenizer is not None:
self._validate_token_prompt(
prompt,
tokenizer=self.get_tokenizer(lora_request=lora_request))
if (isinstance(prompt, dict)
and prompt.get("prompt_embeds", None) is not None
and not prompt.get("prompt_token_ids", None)):
seq_len = prompt["prompt_embeds"].shape[0]
prompt["prompt_token_ids"] = [0] * seq_len
processed_inputs = self.input_preprocessor.preprocess(
prompt,
@ -776,27 +777,6 @@ class LLMEngine:
priority=priority,
)
def _validate_token_prompt(self, prompt: PromptType,
tokenizer: AnyTokenizer):
# Guard against out-of-vocab tokens.
# For some tokenizers, tokenizer.decode will happily return empty text
# for token ids that are out of vocab, and we don't detect token ids
# that are greater than the max token id before running the model.
# However, these token ids will later crash a cuda kernel at runtime
# with an index out of bounds error. This will crash the entire engine.
# This needs to happen before multimodal input pre-processing, which
# may add dummy <image> tokens that aren't part of the tokenizer's
# vocabulary.
if is_token_prompt(prompt):
prompt_ids = prompt["prompt_token_ids"]
if len(prompt_ids) == 0:
# Empty prompt check is handled later
return
max_input_id = max(prompt_ids)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
"Token id {} is out of vocabulary".format(max_input_id))
def _create_sequence_group_with_sampling(
self,
request_id: str,
@ -1267,11 +1247,13 @@ class LLMEngine:
if self.scheduler_config.is_multi_step:
is_prefill_append = seq.data.get_num_uncomputed_tokens(
) == 0
seq.append_token_id(sample.output_token, sample.logprobs)
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
if not is_prefill_append:
seq_group.update_num_computed_tokens(1)
else:
seq.append_token_id(sample.output_token, sample.logprobs)
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
def step(self) -> List[Union[RequestOutput, PoolingRequestOutput]]:
"""Performs one decoding iteration and returns newly generated results.
@ -2032,13 +2014,21 @@ class LLMEngine:
tokenizer = (None if self.tokenizer is None else
self.tokenizer.get_lora_tokenizer(lora_request))
prompt_ids = prompt_inputs["prompt_token_ids"]
prompt_ids = prompt_inputs.get("prompt_token_ids", [])
if not prompt_ids:
if prompt_type == "encoder" and model_config.is_multimodal_model:
pass # Mllama may have empty encoder inputs for text-only data
if prompt_inputs["type"] == "embeds":
pass
else:
raise ValueError(f"The {prompt_type} prompt cannot be empty")
if tokenizer is not None:
max_input_id = max(prompt_ids, default=0)
if max_input_id > tokenizer.max_token_id:
raise ValueError(
f"Token id {max_input_id} is out of vocabulary")
max_prompt_len = self.model_config.max_model_len
if len(prompt_ids) > max_prompt_len:
if prompt_type == "encoder" and model_config.is_multimodal_model:

View File

@ -167,6 +167,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
sampling_params: SamplingParams) -> None:
output_token_ids = [sample.output_token for sample in valid_samples]
output_logprobs = [sample.logprobs for sample in valid_samples]
output_embeds = [sample.output_embed for sample in valid_samples]
# Truncate to max_tokens if necessary.
remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() +
@ -190,11 +191,12 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
is_prefill_sampled_token = seq.data.get_num_uncomputed_tokens() == 0
# Incrementally append tokens to the sequence, as if we had only one new
# token.
for output_token_id, output_logprob in zip(output_token_ids,
output_logprobs):
for output_token_id, output_logprob, output_embed in zip(
output_token_ids, output_logprobs, output_embeds):
seq.append_token_id(
token_id=output_token_id,
logprobs=output_logprob,
token_embed=output_embed,
)
if is_prefill_sampled_token:

View File

@ -119,7 +119,8 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor):
sample = outputs.samples[0]
seq = seq_group.first_seq
if not is_async:
seq.append_token_id(sample.output_token, sample.logprobs)
seq.append_token_id(sample.output_token, sample.logprobs,
sample.output_embed)
if sampling_params.detokenize and self.detokenizer:
new_char_count = self.detokenizer.decode_sequence_inplace(
seq, sampling_params)

View File

@ -83,6 +83,9 @@ class EngineClient(ABC):
else:
processed_inputs = preprocessor._prompt_to_llm_inputs(prompt)
if processed_inputs["type"] == "embeds":
raise NotImplementedError
prompt_token_ids = processed_inputs["prompt_token_ids"]
prompt_text = processed_inputs.get("prompt")
multi_modal_data = processed_inputs.get("multi_modal_data")

View File

@ -27,7 +27,7 @@ from vllm.entrypoints.score_utils import (_cosine_similarity,
_validate_score_input_lens)
from vllm.entrypoints.utils import _validate_truncation_size
from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt
from vllm.inputs.parse import is_token_prompt, parse_and_batch_prompt
from vllm.inputs.parse import parse_and_batch_prompt
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.model_executor.guided_decoding.guided_fields import (
@ -567,10 +567,12 @@ class LLM:
mm_kwargs["mm_processor_kwargs"] = prompt[
"mm_processor_kwargs"]
if is_token_prompt(prompt):
if "prompt_token_ids" in prompt:
prompt = cast(TokensPrompt, prompt) # Needed for mypy
prompt_tokens = prompt["prompt_token_ids"]
else:
prompt_tokens = tokenizer.encode(prompt["prompt"])
instances.append(
BeamSearchInstance(prompt_tokens, logprobs=None, **mm_kwargs))

View File

@ -1,9 +1,9 @@
# SPDX-License-Identifier: Apache-2.0
from .data import (DecoderOnlyInputs, EncoderDecoderInputs,
from .data import (DecoderOnlyInputs, EmbedsInputs, EncoderDecoderInputs,
ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, build_explicit_enc_dec_prompt,
TokensPrompt, build_explicit_enc_dec_prompt, embeds_inputs,
to_enc_dec_tuple_list, token_inputs, zip_enc_dec_prompts)
from .registry import (DummyData, InputContext, InputProcessingContext,
InputRegistry)
@ -21,7 +21,9 @@ __all__ = [
"SingletonPrompt",
"ExplicitEncoderDecoderPrompt",
"TokenInputs",
"EmbedsInputs",
"token_inputs",
"embeds_inputs",
"DecoderOnlyInputs",
"EncoderDecoderInputs",
"ProcessorInputs",

View File

@ -2,6 +2,7 @@
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, Generic, Literal, Optional, Union, cast
import torch
from typing_extensions import NotRequired, TypedDict, TypeVar
if TYPE_CHECKING:
@ -63,12 +64,25 @@ class TokensPrompt(TypedDict):
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt]
class EmbedsPrompt(TypedDict):
"""Schema for a prompt provided via token embeddings."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
SingletonPrompt = Union[str, TextPrompt, TokensPrompt, EmbedsPrompt]
"""
Set of possible schemas for a single prompt:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
Note that "singleton" is as opposed to a data structure
which encapsulates multiple prompts, i.e. of the sort
@ -129,6 +143,7 @@ both decoder-only and encoder/decoder input types:
- A text prompt (:class:`str` or :class:`TextPrompt`)
- A tokenized prompt (:class:`TokensPrompt`)
- An embeddings prompt (:class:`EmbedsPrompt`)
- A single data structure containing both an encoder and a decoder prompt
(:class:`ExplicitEncoderDecoderPrompt`)
"""
@ -176,7 +191,35 @@ def token_inputs(
return inputs
DecoderOnlyInputs = Union[TokenInputs, "MultiModalInputs"]
class EmbedsInputs(TypedDict):
"""Represents embeddings-based inputs."""
type: Literal["embeds"]
"""The type of inputs."""
prompt_embeds: torch.Tensor
"""The embeddings of the prompt."""
cache_salt: NotRequired[str]
"""
Optional cache salt to be used for prefix caching.
"""
def embeds_inputs(
prompt_embeds: torch.Tensor,
cache_salt: Optional[str] = None,
) -> EmbedsInputs:
"""Construct :class:`EmbedsInputs` from optional values."""
inputs = EmbedsInputs(type="embeds", prompt_embeds=prompt_embeds)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
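
# A minimal sketch of building embedding inputs from a user-supplied EmbedsPrompt;
# the hidden size of 4096 is an arbitrary placeholder (use the target model's).
import torch

prompt: EmbedsPrompt = {"prompt_embeds": torch.randn(16, 4096), "cache_salt": "demo"}
inputs = embeds_inputs(prompt_embeds=prompt["prompt_embeds"],
                       cache_salt=prompt.get("cache_salt"))
assert inputs["type"] == "embeds"
assert inputs["prompt_embeds"].shape == (16, 4096)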
DecoderOnlyInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
@ -198,7 +241,7 @@ class EncoderDecoderInputs(TypedDict):
"""The inputs for the decoder portion."""
SingletonInputs = Union[TokenInputs, "MultiModalInputs"]
SingletonInputs = Union[TokenInputs, EmbedsInputs, "MultiModalInputs"]
"""
A processed :class:`SingletonPrompt` which can be passed to
:class:`vllm.sequence.Sequence`.

View File

@ -6,8 +6,9 @@ from typing_extensions import TypeIs
from vllm.utils import is_list_of
from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
from .data import (EmbedsPrompt, ExplicitEncoderDecoderPrompt, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, TextPrompt,
TokensPrompt)
class ParsedText(TypedDict):
@ -84,23 +85,51 @@ class ParsedTokensPrompt(TypedDict):
content: TokensPrompt
def parse_singleton_prompt(
prompt: SingletonPrompt,
) -> Union[ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt]:
class ParsedEmbedsPrompt(TypedDict):
type: Literal['embeds']
content: EmbedsPrompt
ParsedSingletonPrompt = Union[ParsedStrPrompt, ParsedTextPrompt,
ParsedTokensPrompt, ParsedEmbedsPrompt]
@overload
def parse_singleton_prompt(prompt: str) -> ParsedStrPrompt:
...
@overload
def parse_singleton_prompt(prompt: TextPrompt) -> ParsedTextPrompt:
...
@overload
def parse_singleton_prompt(prompt: TokensPrompt) -> ParsedTokensPrompt:
...
@overload
def parse_singleton_prompt(prompt: EmbedsPrompt) -> ParsedEmbedsPrompt:
...
def parse_singleton_prompt(prompt: SingletonPrompt) -> ParsedSingletonPrompt:
if isinstance(prompt, str):
return ParsedStrPrompt(type="str", content=prompt)
elif isinstance(prompt, dict):
if "prompt_token_ids" in prompt:
return ParsedTokensPrompt(type="tokens",
content=prompt) # type: ignore
        # The type ignores are because mypy does not correctly infer the
        # TypedDicts, whereas pyright does.
if "prompt_embeds" in prompt:
return ParsedEmbedsPrompt(
type="embeds", content=prompt) # type: ignore[typeddict-item]
elif "prompt_token_ids" in prompt:
return ParsedTokensPrompt(
type="tokens", content=prompt) # type: ignore[typeddict-item]
elif "prompt" in prompt:
return ParsedTextPrompt(type="text", content=prompt)
raise TypeError("inputs must be a string, TextPrompt, or TokensPrompt")
def is_token_prompt(prompt: PromptType) -> TypeIs[TokensPrompt]:
return isinstance(prompt, dict) and "prompt_token_ids" in prompt
raise TypeError(
"inputs must be a string, TextPrompt, TokensPrompt, or EmbedsPrompt")
def is_explicit_encoder_decoder_prompt(

View File

@ -6,6 +6,7 @@ from typing import Any, Optional, Union, cast
from typing_extensions import assert_never
from vllm import envs
from vllm.config import ModelConfig
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
@ -13,12 +14,14 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry
from vllm.multimodal.inputs import (MultiModalDataDict, MultiModalEncDecInputs,
MultiModalInputs)
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.transformers_utils.tokenizer import AnyTokenizer
from vllm.transformers_utils.tokenizer_group import TokenizerGroup
from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs,
PromptType, SingletonInputs, SingletonPrompt, token_inputs)
from .parse import (ParsedStrPrompt, ParsedTextPrompt, ParsedTokensPrompt,
is_explicit_encoder_decoder_prompt, parse_singleton_prompt)
from .data import (DecoderOnlyInputs, EmbedsInputs, EmbedsPrompt,
EncoderDecoderInputs, ProcessorInputs, PromptType,
SingletonInputs, SingletonPrompt, TextPrompt, TokenInputs,
TokensPrompt, embeds_inputs, token_inputs)
from .parse import is_explicit_encoder_decoder_prompt, parse_singleton_prompt
logger = init_logger(__name__)
@ -137,13 +140,10 @@ class InputPreprocessor:
"""
Prepares `decoder_input_ids` for generation with encoder-decoder models.
Based on
https://github.com/huggingface/transformers/blob/
4037a2b5b1278736e566aec12e169100275545ea/
src/transformers/generation/utils.py
specifically GenerationMixin._prepare_decoder_input_ids_for_generation()
Based on:
https://github.com/huggingface/transformers/blob/4037a2b5b1278736e566aec12e169100275545ea/src/transformers/generation/utils.py
specifically,
`GenerationMixin._prepare_decoder_input_ids_for_generation()`.
Arguments:
@ -180,6 +180,23 @@ class InputPreprocessor:
return prompt_token_ids
def _get_tokenization_kw(
self,
overrides: Optional[dict[str, Any]] = None,
) -> dict[str, Any]:
kwargs = dict[str, Any]()
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
kwargs["add_special_tokens"] = False
if overrides:
kwargs.update(overrides)
return kwargs
def _tokenize_prompt(
self,
prompt: str,
@ -191,18 +208,11 @@ class InputPreprocessor:
corresponding token IDs.
"""
tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None:
tokenization_kwargs = {}
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
encoder_config = self.model_config.encoder_config
if (self.model_config.encoder_config is not None
and self.model_config.encoder_config.get(
"do_lower_case", False)):
if encoder_config and encoder_config.get("do_lower_case", False):
prompt = prompt.lower()
return tokenizer.encode(prompt=prompt,
@ -217,18 +227,36 @@ class InputPreprocessor:
) -> list[int]:
"""Async version of :meth:`_tokenize_prompt`."""
tokenizer = self.get_tokenizer_group()
if tokenization_kwargs is None:
tokenization_kwargs = {}
tokenization_kwargs = self._get_tokenization_kw(tokenization_kwargs)
if self.model_config.hf_config.model_type == "whisper":
# For Whisper, special tokens should be provided by the user based
# on the task and language of their request. Also needed to avoid
# appending an EOS token to the prompt which disrupts generation.
tokenization_kwargs["add_special_tokens"] = False
return await tokenizer.encode_async(prompt=prompt,
lora_request=lora_request,
**tokenization_kwargs)
def _get_mm_tokenizer(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while also using multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return tokenizer_group.get_lora_tokenizer(lora_request)
async def _get_mm_tokenizer_async(
self,
lora_request: Optional[LoRARequest],
) -> AnyTokenizer:
# PrithviGeoSpatialMAE needs to be initialized without a tokenizer
        # while also using multi-modal input
if not self.tokenizer:
return cast(AnyTokenizer, object()) # Dummy
tokenizer_group = self.get_tokenizer_group()
return await tokenizer_group.get_lora_tokenizer_async(lora_request)
def _process_multimodal(
self,
prompt: Union[str, list[int]],
@ -241,13 +269,7 @@ class InputPreprocessor:
Apply the model's multi-modal processor to a multi-modal prompt,
returning the corresponding token IDs and metadata.
"""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
tokenizer = self._get_mm_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
@ -267,14 +289,7 @@ class InputPreprocessor:
return_mm_hashes: bool = False,
) -> MultiModalInputs:
"""Async version of :meth:`_process_multimodal`."""
# At the moment on model (PrithviGeoSpatialMAE) requires to be
# initialized without a tokenizer while using also multi-modal input
if not self.tokenizer:
tokenizer = object() # Dummy
else:
tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request)
tokenizer = await self._get_mm_tokenizer_async(lora_request)
mm_processor = self.mm_registry.create_processor(self.model_config,
tokenizer=tokenizer)
@ -284,28 +299,163 @@ class InputPreprocessor:
return mm_processor.apply(prompt, mm_data, mm_processor_kwargs,
return_mm_hashes)
def _get_prompt_data(self, parsed_prompt: Union[ParsedStrPrompt,
ParsedTextPrompt,
ParsedTokensPrompt]):
prompt_text = None
prompt_token_ids = None
token_type_ids = None
cache_salt = None
def _process_embeds(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
if not self.model_config.enable_prompt_embeds:
raise ValueError("You must set `--enable-prompt-embeds` to input "
"`prompt_embeds`.")
if envs.VLLM_USE_V1:
raise ValueError("`prompt_embeds` is only available in V0.")
if parsed_prompt["type"] == "str":
prompt_text = parsed_prompt["content"]
prompt_embeds = parsed_content["prompt_embeds"]
# prompt_embeds must be (seq_len, hidden_size), but if the user
# passes in a batch of size 1, i.e. (1, seq_len, hidden_size),
# we can unambiguously process the intent by squeezing the batch
# dimension.
if prompt_embeds.ndim == 3:
prompt_embeds = prompt_embeds.squeeze(dim=0)
if prompt_embeds.ndim != 2:
raise ValueError(
"prompt_embeds must be of shape (seq_len, hidden_size).")
return embeds_inputs(prompt_embeds=prompt_embeds,
cache_salt=parsed_content.get("cache_salt"))
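
# A standalone sketch of the shape handling above: either (seq_len, hidden_size)
# or a batch of one, (1, seq_len, hidden_size), is accepted; anything else raises.
import torch

def normalize_prompt_embeds(prompt_embeds: torch.Tensor) -> torch.Tensor:
    if prompt_embeds.ndim == 3:
        prompt_embeds = prompt_embeds.squeeze(dim=0)
    if prompt_embeds.ndim != 2:
        raise ValueError("prompt_embeds must be of shape (seq_len, hidden_size).")
    return prompt_embeds

print(normalize_prompt_embeds(torch.randn(1, 5, 8)).shape)   # torch.Size([5, 8])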
async def _process_embeds_async(
self,
parsed_content: EmbedsPrompt,
) -> EmbedsInputs:
return self._process_embeds(parsed_content)
def _process_tokens(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
cache_salt = parsed_prompt["content"].get("cache_salt")
if parsed_prompt["type"] == "text":
prompt_text = parsed_prompt["content"]["prompt"]
elif parsed_prompt["type"] == "tokens":
prompt_token_ids = parsed_prompt["content"].get(
"prompt_token_ids")
token_type_ids = parsed_prompt["content"].get("token_type_ids")
else:
assert_never(parsed_prompt)
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
return prompt_text, prompt_token_ids, token_type_ids, cache_salt
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
async def _process_tokens_async(
self,
parsed_content: TokensPrompt,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_token_ids = parsed_content["prompt_token_ids"]
token_type_ids = parsed_content.get("token_type_ids")
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_token_ids,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
inputs = token_inputs(
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _process_text(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = self._process_multimodal(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
async def _process_text_async(
self,
parsed_content: TextPrompt,
tokenization_kwargs: Optional[dict[str, Any]] = None,
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> Union[TokenInputs, MultiModalInputs]:
prompt_text = parsed_content["prompt"]
inputs: Union[TokenInputs, MultiModalInputs]
if multi_modal_data := parsed_content.get("multi_modal_data"):
inputs = await self._process_multimodal_async(
prompt_text,
multi_modal_data,
parsed_content.get("mm_processor_kwargs"),
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
else:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
tokenization_kwargs=tokenization_kwargs,
)
inputs = token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
)
if cache_salt := parsed_content.get("cache_salt"):
inputs["cache_salt"] = cache_salt
return inputs
def _prompt_to_llm_inputs(
self,
@ -328,36 +478,31 @@ class InputPreprocessor:
* :class:`SingletonInputs` instance
"""
parsed = parse_singleton_prompt(prompt)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
# If multimodal data is present, process and return immediately
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = self._process_multimodal(
prompt_text if prompt_text is not None else prompt_token_ids,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
if parsed["type"] == "embeds":
return self._process_embeds(parsed["content"])
if parsed["type"] == "tokens":
return self._process_tokens(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
if prompt_token_ids is None:
prompt_token_ids = self._tokenize_prompt(
prompt_text,
lora_request=lora_request,
if parsed["type"] == "text":
return self._process_text(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return self._process_text(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)
assert_never(parsed)
async def _prompt_to_llm_inputs_async(
self,
@ -366,49 +511,49 @@ class InputPreprocessor:
lora_request: Optional[LoRARequest] = None,
return_mm_hashes: bool = False,
) -> SingletonInputs:
"""Async version of :meth:`_extract_prompt_components`."""
"""Async version of :meth:`_prompt_to_llm_inputs`."""
parsed = parse_singleton_prompt(prompt)
prompt_text, prompt_token_ids, token_type_ids, cache_salt = \
self._get_prompt_data(parsed)
if parsed["type"] != "str" and parsed["content"].get(
"multi_modal_data") is not None:
inputs = await self._process_multimodal_async(
prompt_token_ids if prompt_text is None else prompt_text,
parsed["content"]["multi_modal_data"],
parsed["content"].get("mm_processor_kwargs"),
if parsed["type"] == "embeds":
return await self._process_embeds_async(parsed["content"])
if parsed["type"] == "tokens":
return await self._process_tokens_async(
parsed["content"],
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if cache_salt is not None:
inputs["cache_salt"] = cache_salt
return inputs
if prompt_token_ids is None:
prompt_token_ids = await self._tokenize_prompt_async(
prompt_text,
lora_request=lora_request,
if parsed["type"] == "text":
return await self._process_text_async(
parsed["content"],
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
if parsed["type"] == "str":
return await self._process_text_async(
TextPrompt(prompt=parsed["content"]),
tokenization_kwargs=tokenization_kwargs,
lora_request=lora_request,
return_mm_hashes=return_mm_hashes,
)
return token_inputs(
prompt=prompt_text,
prompt_token_ids=prompt_token_ids,
token_type_ids=token_type_ids,
cache_salt=cache_salt,
)
assert_never(parsed)
def _build_enc_dec_llm_inputs(
self,
encoder_inputs: SingletonInputs,
decoder_inputs: Optional[SingletonInputs],
) -> EncoderDecoderInputs:
if (encoder_inputs["type"] == "token"
or encoder_inputs["type"] == "multimodal"):
pass
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
if (encoder_inputs["type"] == "embeds"
or decoder_inputs and decoder_inputs["type"] == "embeds"):
raise ValueError("Embedding inputs are not supported for encoder-"
"decoder models")
# Needed for mypy
encoder_inputs = cast(Union[TokenInputs, MultiModalInputs],
encoder_inputs)
decoder_inputs = cast(Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs)
if decoder_inputs is None:
if self.model_config.hf_config.model_type == "whisper":
@ -421,73 +566,78 @@ class InputPreprocessor:
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
None)
decoder_inputs = token_inputs(dec_token_ids)
elif (decoder_inputs["type"] == "token"
or decoder_inputs["type"] == "multimodal"):
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids
else:
if "multi_modal_data" in decoder_inputs:
raise ValueError("Multi-modal decoder inputs of encoder-"
"decoder models are not supported yet")
else:
assert_never(encoder_inputs) # type: ignore[arg-type]
dec_token_ids = self._prepare_decoder_input_ids_for_generation(
decoder_inputs["prompt_token_ids"])
decoder_inputs["prompt_token_ids"] = dec_token_ids
return EncoderDecoderInputs(
encoder=encoder_inputs,
decoder=decoder_inputs,
)
def _separate_enc_dec_inputs_from_mm_processor_outputs(
def _split_enc_dec_mm_inputs(
self,
inputs: SingletonInputs,
inputs: Union[SingletonInputs, MultiModalEncDecInputs],
decoder_inputs_to_override: Optional[SingletonInputs] = None,
) -> tuple[SingletonInputs, SingletonInputs]:
"""
For encoder/decoder models only:
Separate Encoder/Decoder inputs from a MultiModalEncDecInputs
"""
if (inputs["type"] == "embeds" or decoder_inputs_to_override
and decoder_inputs_to_override["type"] == "embeds"):
raise ValueError("Embedding inputs are not supported for encoder-"
"decoder models")
# Needed for mypy
inputs = cast(
Union[TokenInputs, MultiModalInputs, MultiModalEncDecInputs],
inputs,
)
decoder_inputs_to_override = cast(
Optional[Union[TokenInputs, MultiModalInputs]],
decoder_inputs_to_override,
)
encoder_inputs: SingletonInputs
decoder_inputs: SingletonInputs
if inputs["type"] == "multimodal":
# Multimodal data inputs
assert ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs)
if inputs["type"] == "multimodal": # Multimodal data inputs
if not ("encoder_prompt" in inputs
and "encoder_prompt_token_ids" in inputs):
raise RuntimeError("You should register an encoder-decoder "
"multi-modal processor for encoder-decoder "
"models.")
inputs = cast(MultiModalEncDecInputs, inputs)
encoder_inputs = token_inputs(
prompt=inputs["encoder_prompt"],
prompt_token_ids=inputs["encoder_prompt_token_ids"],
)
if decoder_inputs_to_override is not None:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_inputs_to_override.get("prompt", ""),
prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
else:
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
cache_salt = inputs.get("cache_salt")
if cache_salt is not None:
decoder_prompt_inputs = decoder_inputs_to_override or inputs
decoder_inputs = MultiModalInputs(
type="multimodal",
prompt=decoder_prompt_inputs.get("prompt", ""),
prompt_token_ids=decoder_prompt_inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"],
)
if cache_salt := inputs.get("cache_salt"):
decoder_inputs["cache_salt"] = cache_salt
elif inputs["type"] == "token":
# Text-only inputs
elif inputs["type"] == "token": # Text-only inputs
encoder_inputs = token_inputs(prompt="", prompt_token_ids=[])
decoder_inputs = decoder_inputs_to_override or inputs
else:
assert_never(inputs) # type: ignore[arg-type]
return encoder_inputs, decoder_inputs
def _process_encoder_decoder_prompt(
@ -541,8 +691,8 @@ class InputPreprocessor:
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
inputs = self._prompt_to_llm_inputs(
prompt,
@ -551,11 +701,9 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
self._split_enc_dec_mm_inputs(inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
@ -591,8 +739,8 @@ class InputPreprocessor:
# with explicit decoder prompt.
if self.model_config.is_multimodal_model:
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
encoder_inputs, decoder_inputs))
self._split_enc_dec_mm_inputs(encoder_inputs,
decoder_inputs))
else:
inputs = await self._prompt_to_llm_inputs_async(
prompt,
@ -601,11 +749,9 @@ class InputPreprocessor:
if self.model_config.is_multimodal_model:
# Encoder-Decoder Multimodal model
encoder_inputs, decoder_inputs = (
self._separate_enc_dec_inputs_from_mm_processor_outputs(
inputs))
self._split_enc_dec_mm_inputs(inputs))
else:
encoder_inputs = inputs
decoder_inputs = None
return self._build_enc_dec_llm_inputs(encoder_inputs, decoder_inputs)
@ -615,14 +761,13 @@ class InputPreprocessor:
prompt_inputs: DecoderOnlyInputs,
prompt_adapter_request: Optional[PromptAdapterRequest],
) -> DecoderOnlyInputs:
if (prompt_inputs["type"] == "token"
or prompt_inputs["type"] == "multimodal"):
if "prompt_token_ids" in prompt_inputs:
prompt_inputs = cast(Union[TokenInputs, MultiModalInputs],
prompt_inputs) # Needed for mypy
prompt_inputs["prompt_token_ids"] = self._apply_prompt_adapter(
prompt_inputs["prompt_token_ids"],
prompt_adapter_request=prompt_adapter_request,
)
else:
assert_never(prompt_inputs) # type: ignore[arg-type]
return prompt_inputs

View File

@ -0,0 +1,200 @@
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
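# (Illustrative sketch, not part of the diff) The JSON above maps a token count
# M to a Triton tile configuration (BLOCK_SIZE_M/N/K, GROUP_SIZE_M, num_warps,
# ...) tuned for that batch size. A minimal sketch of how such a table can be
# consulted, roughly mirroring the nearest-key lookup in
# try_get_optimal_moe_config; the helper name and file path are illustrative.
import json

def lookup_moe_config(path: str, num_tokens: int) -> dict:
    """Return the tuned kernel config whose benchmarked M is closest."""
    with open(path) as f:
        configs = {int(m): cfg for m, cfg in json.load(f).items()}
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

# e.g. lookup_moe_config("path/to/tuned_config.json", num_tokens=100) falls
# back to the entry tuned for M=96.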

View File

@ -71,8 +71,8 @@ def single_marlin_moe(
E = w.shape[0]
N = w.shape[2] // (num_bits // 2)
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize)
# This might not be an optimal config for a single MMM
get_config_func = functools.partial(try_get_optimal_moe_config,

View File

@ -854,7 +854,7 @@ def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
@ -868,20 +868,19 @@ def fused_topk(
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indices = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
gating_output_float = gating_output.float() # TODO(woosuk): Optimize this.
topk_func = dispatch_topk_func()
topk_weights, topk_ids = topk_func(topk_weights, topk_ids,
token_expert_indicies,
token_expert_indices,
gating_output_float, renormalize)
del token_expert_indicies # Not used. Will be used in the future.
return topk_weights, topk_ids
return topk_weights, topk_ids, token_expert_indices
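# (Illustrative note, not part of the diff) fused_topk now returns a third
# tensor. With the shapes used above (hidden_states [M, K], gating_output
# [M, E]):
#   topk_weights:         [M, topk] float routing weights
#   topk_ids:             [M, topk] int32 expert ids
#   token_expert_indices: [M, topk] int32 auxiliary indices consumed by the
#                         permute/unpermute path
# Callers that do not need the index tensor can simply discard it, e.g.
#   topk_weights, topk_ids, _ = fused_topk(hidden_states, gating_output,
#                                          topk, renormalize)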
# This is used by the Deepseek-V2 and Deepseek-V3 model
@ -1510,8 +1509,8 @@ def fused_moe(
topk, renormalize,
num_expert_group, topk_group)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states, gating_output, topk, renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize)

View File

@ -801,10 +801,11 @@ class FusedMoE(torch.nn.Module):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
topk_weights, topk_ids, token_expert_indices = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,

View File

@ -0,0 +1,116 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Tuple
import torch
def moe_permute(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
topk: int,
n_expert: int,
n_local_expert: int,
expert_map: Optional[torch.Tensor] = None,
align_block_size: Optional[int] = None,
fill_invalid_expert: int = -1
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
This function expands and permutes the activations so that the tokens routed
to each expert become contiguous.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- topk_weights (torch.Tensor): topk expert route weight for each token.
- topk_ids (torch.Tensor): topk expert route id for each token.
- token_expert_indices (torch.Tensor): indices for the expanded hidden states.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of experts.
- n_local_expert (int): The number of experts in the current EP rank.
- expert_map (Optional[torch.Tensor]): A tensor mapping expert indices
from the global expert space to the local expert space of the expert
parallel shard.
- align_block_size (Optional[int]): grouped gemm block size to align to,
for DeepGEMM.
- fill_invalid_expert (int): expert id written into m_indices for invalid
rows, to work around DeepGEMM not supporting -1 in m_indices.
Returns:
- permuted_hidden_states (torch.Tensor): permuted activation.
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for standard grouped gemm. If 'align_block_size' is given,
each offset is aligned up to 'align_block_size'.
- src_row_id2dst_row_id_map (torch.Tensor): index map for moe_unpermute.
- m_indices (torch.Tensor): m_indices for grouped gemm in DeepGEMM;
`m_indices[i]` records the group that the i-th row of the LHS belongs to.
"""
n_token, n_hidden = hidden_states.shape
assert (n_hidden * hidden_states.element_size()
) % 16 == 0, "permute kernel requires the hidden dim to be 16-byte aligned"
permuted_row_size = n_token * topk
if align_block_size is not None:
permuted_row_size = (permuted_row_size + n_expert *
(align_block_size - 1) + align_block_size -
1) // align_block_size * align_block_size
permuted_hidden_states = torch.empty(
(permuted_row_size, n_hidden),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
m_indices = torch.full((permuted_row_size, ),
fill_invalid_expert,
dtype=torch.int32,
device=hidden_states.device)
expert_first_token_offset = torch.empty(n_local_expert + 1,
dtype=torch.int64,
device=hidden_states.device)
src_row_id2dst_row_id_map = torch.empty((n_token, topk),
dtype=torch.int32,
device=hidden_states.device)
torch.ops._moe_C.moe_permute(hidden_states, topk_weights, topk_ids,
token_expert_indices, expert_map, n_expert,
n_local_expert, topk, align_block_size,
permuted_hidden_states,
expert_first_token_offset,
src_row_id2dst_row_id_map, m_indices)
return (permuted_hidden_states, expert_first_token_offset,
src_row_id2dst_row_id_map, m_indices)
def moe_unpermute(
permuted_hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
src_row_id2dst_row_id_map: torch.Tensor,
expert_first_token_offset: torch.Tensor,
topk: int,
n_expert: int,
n_local_expert: int,
) -> torch.Tensor:
"""
This function reduces the permuted activations back to the original token
order, combining each token's expert outputs weighted by topk_weights.
Parameters:
- permuted_hidden_states (torch.Tensor): permuted activation.
- topk_weights (torch.Tensor): topk expert route weight for each token.
- topk_ids (torch.Tensor): topk expert route id for each token.
- expert_first_token_offset (torch.Tensor): offset of the first token
of each expert for grouped gemm.
- topk (int): The number of top-k experts to select.
- n_expert (int): The number of experts.
- n_local_expert (int): The number of experts in the current EP rank.
Returns:
- hidden_states (torch.Tensor): The reduced and unpermuted activation
tensor.
"""
n_token, n_hidden = topk_weights.shape[0], permuted_hidden_states.shape[-1]
assert (n_hidden * permuted_hidden_states.element_size()
) % 16 == 0, "unpermute kernel requires the hidden dim to be 16-byte aligned"
hidden_states = torch.empty((n_token, n_hidden),
dtype=permuted_hidden_states.dtype,
device=permuted_hidden_states.device)
torch.ops._moe_C.moe_unpermute(permuted_hidden_states, topk_weights,
topk_ids, src_row_id2dst_row_id_map,
expert_first_token_offset, n_expert,
n_local_expert, topk, hidden_states)
return hidden_states
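# (Illustrative sketch, not part of the diff) Rough round-trip of the two
# helpers above, assuming topk_weights/topk_ids/token_expert_indices come from
# fused_topk and that the grouped GEMMs / expert MLPs run on the permuted
# activations in between; the shapes and expert-parallel arguments here are
# assumptions for illustration only.
#
#   permuted, first_token_offset, row_map, m_indices = moe_permute(
#       hidden_states, topk_weights, topk_ids, token_expert_indices,
#       topk=2, n_expert=64, n_local_expert=64)
#   # ... grouped GEMM over `permuted`, grouped by `first_token_offset`
#   # (or `m_indices` for DeepGEMM) ...
#   out = moe_unpermute(permuted, topk_weights, topk_ids, row_map,
#                       first_token_offset, 2, 64, 64)
#   # `out` is back in the original [n_token, n_hidden] layout, with each
#   # token's expert outputs combined using topk_weights.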

View File

@ -140,7 +140,7 @@ class AWQMarlinConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_one(
logger.warning_once(
f"Layer '{prefix}' is not supported by AWQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(

View File

@ -34,6 +34,7 @@ __all__ = [
"CompressedTensorsMoEMethod",
"CompressedTensorsW8A8Fp8MoEMethod",
"CompressedTensorsW8A8Fp8MoECutlassMethod",
"CompressedTensorsW8A8Int8MoEMethod",
"CompressedTensorsWNA16MarlinMoEMethod",
"CompressedTensorsWNA16MoEMethod",
]
@ -71,6 +72,8 @@ class CompressedTensorsMoEMethod(FusedMoEMethodBase):
return CompressedTensorsW8A8Fp8MoECutlassMethod(quant_config)
elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
elif quant_config._is_dynamic_token_w8a8(weight_quant, input_quant):
return CompressedTensorsW8A8Int8MoEMethod(quant_config)
else:
raise RuntimeError(
f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}")
@ -545,6 +548,138 @@ class CompressedTensorsW8A8Fp8MoECutlassMethod(CompressedTensorsMoEMethod):
)
class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
def __init__(
self,
quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
):
self.quant_config = quant_config
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get(
"weights")
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
"input_activations")
per_channel = (
self.weight_quant.strategy == QuantizationStrategy.CHANNEL
and self.input_quant.strategy == QuantizationStrategy.TOKEN)
if not per_channel:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found "
f"{self.weight_quant}, {self.input_quant}")
self.static_input_scales = not self.input_quant.dynamic
if self.static_input_scales:
raise ValueError(
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype = torch.int8
# WEIGHTS
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# WEIGHT_SCALES
assert self.weight_quant.strategy == QuantizationStrategy.CHANNEL
w13_weight_scale = torch.nn.Parameter(torch.ones(
num_experts,
2 * intermediate_size_per_partition,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
w2_weight_scale = torch.nn.Parameter(torch.ones(num_experts,
hidden_size,
1,
dtype=torch.float32),
requires_grad=False)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
# Add PER-CHANNEL quantization for FusedMoE.weight_loader.
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value})
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# INPUT_SCALES
assert not self.static_input_scales
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
pass
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale)
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
def __init__(

View File

@ -134,7 +134,7 @@ class GPTQBitBLASConfig(QuantizationConfig):
@classmethod
def get_min_capability(cls) -> int:
return 70
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:

View File

@ -157,7 +157,7 @@ class GPTQMarlinConfig(QuantizationConfig):
from vllm.model_executor.layers.quantization.moe_wna16 import (
MoeWNA16Config)
if not check_moe_marlin_supports_layer(layer, self.group_size):
logger.warning_one(
logger.warning_once(
f"Layer '{prefix}' is not supported by GPTQMoeMarlin. "
"Falling back to Moe WNA16 kernels.")
return MoeWNA16Config.from_config(

Some files were not shown because too many files have changed in this diff