mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
147 Commits
fix-doc-bu
...
codex/chan
Author | SHA1 | Date | |
---|---|---|---|
5bb81e284b | |||
01513a334a | |||
ac2bf41e53 | |||
a931b4cdcf | |||
a0f8a79646 | |||
18bdcf4113 | |||
1c3198b6c4 | |||
260127ea54 | |||
d0dc4cfca4 | |||
d31a647124 | |||
85431bd9ad | |||
c11013db8b | |||
1eb2b9c102 | |||
6ebf313790 | |||
cfbcb9ed87 | |||
76ddeff293 | |||
f46098335b | |||
e9534c7202 | |||
7976446015 | |||
fcb9f879c1 | |||
3ed94f9d0a | |||
fa839565f2 | |||
75a99b98bf | |||
b5c3b68359 | |||
6cbc4d4bea | |||
153c6f1e61 | |||
34cda778a0 | |||
30800b01c2 | |||
10be209493 | |||
19c863068b | |||
f29fd8a7f8 | |||
ed10f3cea1 | |||
b637e9dcb8 | |||
1e36c8687e | |||
5bac61362b | |||
313ae8c16a | |||
c847e34b39 | |||
e7e3e6d263 | |||
4ffd963fa0 | |||
56fe4bedd6 | |||
d91278181d | |||
20149d84d9 | |||
3534c39a20 | |||
c586b55667 | |||
33d560001e | |||
f148c44c6a | |||
235bfd5dfe | |||
68d28e37b0 | |||
37a7d5d74a | |||
d4d309409f | |||
85bd6599e4 | |||
91b3d190ae | |||
fc017915f5 | |||
9ad0a4588b | |||
016b8d1b7f | |||
80305c1b24 | |||
37e2ecace2 | |||
054c8657e3 | |||
d4170fad39 | |||
946aadb4a0 | |||
bcdfb2a330 | |||
ba8c300018 | |||
8cdc371217 | |||
61e20828da | |||
55e1c66da5 | |||
86f3ac21ce | |||
149f2435a5 | |||
c0569dbc82 | |||
8bb43b9c9e | |||
559756214b | |||
6d0cf239c6 | |||
3fc964433a | |||
0caf61c08a | |||
667624659b | |||
38efa28278 | |||
e8cc53af5e | |||
a4851cfe68 | |||
9887e8ec50 | |||
f326ab9c88 | |||
dcf2a5e208 | |||
1e9438e0b0 | |||
697ef765ee | |||
a99b9f7dee | |||
c488b928a7 | |||
2c7fa47161 | |||
88fc8a97e3 | |||
66f6fbd393 | |||
8632e831ba | |||
4bbfc36b16 | |||
80d38b8ac8 | |||
211b6a6113 | |||
247102f07f | |||
bd4c1e6fdb | |||
99b4f080d8 | |||
020f58abcd | |||
c1acd6d7d4 | |||
3b3b778d4a | |||
42d440c22b | |||
f45a332886 | |||
6e2c176e1f | |||
a86754a12b | |||
c2a2f19aba | |||
2c11a738b3 | |||
b639327ad9 | |||
4afe687a82 | |||
5de8d9f111 | |||
c1c8ca57ff | |||
a3a5a47e48 | |||
fb25e95688 | |||
0d4891cd03 | |||
f56d2996ca | |||
147afb448b | |||
3c7d942da8 | |||
890323dc1b | |||
01cae37713 | |||
11c0198615 | |||
b1235c3e10 | |||
44d02f54db | |||
a8593237c0 | |||
fc0f41d10a | |||
7b828e30d5 | |||
5f0af36af5 | |||
0d21b2664c | |||
9907fc4494 | |||
d47661f0cd | |||
53fa457391 | |||
6fb162447b | |||
66177189c5 | |||
b4f0b5f9aa | |||
cbd14ed561 | |||
7bd4c37ae7 | |||
8020e98c9f | |||
762be26a8e | |||
6a9e6b2abf | |||
5d09152ff1 | |||
31d5c1797f | |||
35514b682a | |||
e2de455c34 | |||
5b032352cc | |||
922f316441 | |||
5923ab9524 | |||
0cf893cae1 | |||
cf75cd2098 | |||
b854321ffe | |||
5b6fe23d05 | |||
f0c98cae27 | |||
574ad60db9 |
@ -6,19 +6,17 @@ set -exuo pipefail
|
||||
|
||||
# Try building the docker image
|
||||
cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - .
|
||||
FROM 1.22-413-pt2.7.1:latest
|
||||
FROM gaudi-base-image:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
|
||||
RUN VLLM_TARGET_DEVICE=empty pip install .
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
@ -117,7 +117,7 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s core
|
||||
|
||||
- label: Entrypoints Test # 40min
|
||||
- label: Entrypoints Test (LLM) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
@ -125,8 +125,6 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/llm
|
||||
- tests/entrypoints/openai
|
||||
- tests/entrypoints/test_chat_utils
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
@ -135,9 +133,21 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Test (API Server) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/openai
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -630,6 +640,18 @@ steps:
|
||||
# e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
|
||||
# *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*
|
||||
|
||||
- label: Transformers Nightly Models Test
|
||||
working_dir: "/vllm-workspace/"
|
||||
optional: true
|
||||
commands:
|
||||
- pip install --upgrade git+https://github.com/huggingface/transformers
|
||||
- pytest -v -s tests/models/test_initialization.py
|
||||
- pytest -v -s tests/models/multimodal/processing/
|
||||
- pytest -v -s tests/models/multimodal/test_mapping.py
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
- python3 examples/offline_inference/audio_language.py --model-type whisper
|
||||
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
|
6
.gemini/config.yaml
Normal file
6
.gemini/config.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
# https://developers.google.com/gemini-code-assist/docs/customize-gemini-behavior-github
|
||||
have_fun: false # Just review the code
|
||||
code_review:
|
||||
comment_severity_threshold: HIGH # Reduce quantity of comments
|
||||
pull_request_opened:
|
||||
summary: false # Don't summarize the PR in a separate comment
|
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@ -16,6 +16,7 @@
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm
|
||||
/vllm/entrypoints @aarnphm
|
||||
/vllm/compilation @zou3519 @youkaichao
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||
|
@ -21,7 +21,7 @@ repos:
|
||||
- id: ruff-format
|
||||
files: ^(.buildkite|benchmarks|examples)/.*
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.32.0
|
||||
rev: v1.34.0
|
||||
hooks:
|
||||
- id: typos
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
@ -166,7 +166,7 @@ repos:
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
files: vllm/config.py|tests/test_config.py
|
||||
files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
|
@ -171,16 +171,6 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Set nvcc fatbin compression.
|
||||
#
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8)
|
||||
list(APPEND VLLM_GPU_FLAGS "-Xfatbin" "-compress-all" "-compress-mode=size")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
|
||||
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
|
||||
@ -563,7 +553,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
|
@ -63,13 +63,11 @@ vLLM is fast with:
|
||||
- Speculative decoding
|
||||
- Chunked prefill
|
||||
|
||||
**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular Hugging Face models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron
|
||||
|
@ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset):
|
||||
input_low = int(real_input_len * (1 - range_ratio))
|
||||
input_high = int(real_input_len * (1 + range_ratio))
|
||||
output_low = int(output_len * (1 - range_ratio))
|
||||
# Ensure the lower bound for output length is at least 1 to prevent
|
||||
# sampling 0 tokens, which can cause request failures.
|
||||
output_low = max(output_low, 1)
|
||||
output_high = int(output_len * (1 + range_ratio))
|
||||
|
||||
# Add logging for debugging
|
||||
|
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
# TODO(luka): use standalone_compile utility
|
||||
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||
def inner(*args):
|
||||
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
|
||||
return fn(*args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
compilation_config = CompilationConfig(custom_ops=["none"])
|
||||
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
|
||||
torch_per_token_quant_fp8 = torch.compile(
|
||||
QuantFP8(False, GroupShape.PER_TOKEN),
|
||||
fullgraph=True,
|
||||
dynamic=False, # recompile for different shapes
|
||||
)
|
||||
|
||||
# First dim is explicitly dynamic to simulate vLLM usage
|
||||
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
|
||||
|
||||
|
||||
def cuda_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input)
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between Triton and CUDA implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
|
||||
|
||||
torch_out, torch_scale = torch_per_token_quant_fp8(x)
|
||||
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
|
||||
|
||||
if torch.allclose(
|
||||
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "cuda"],
|
||||
line_names=["Torch", "CUDA"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
elif provider == "cuda":
|
||||
fn = lambda: cuda_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark_quantization.run(print_data=True)
|
@ -86,6 +86,9 @@ def benchmark_config(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_deep_gemm:
|
||||
# we use the default block shape for deepgemm
|
||||
block_quant_shape = [128, 128]
|
||||
if use_fp8_w8a8:
|
||||
if block_quant_shape:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
|
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
@ -0,0 +1,240 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_decode(
|
||||
num_seqs,
|
||||
max_seq_len,
|
||||
page_size=16,
|
||||
dtype=torch.bfloat16,
|
||||
kv_layout="HND",
|
||||
num_kv_heads=8,
|
||||
kv_cache_dtype="auto",
|
||||
head_dim=128,
|
||||
warmup=10,
|
||||
trials=20,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Currently only HEAD_GRP_SIZE == 8 is supported
|
||||
HEAD_GRP_SIZE = 8
|
||||
MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
# large number to reduce kv_cache reuse
|
||||
NUM_BLOCKS = int(256000 / page_size)
|
||||
|
||||
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
|
||||
|
||||
# For decode, batch_size is num_decode_token
|
||||
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
|
||||
sm_scale = float(1.0 / (head_dim**0.5))
|
||||
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
|
||||
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
|
||||
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
|
||||
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
|
||||
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache, _ = to_float8(kv_cache)
|
||||
|
||||
# Benchmark TRT decode
|
||||
def trt_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
q,
|
||||
kv_cache,
|
||||
workspace_buffer,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
sm_scale,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
page_size,
|
||||
max_kv_len,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
for i in range(trials):
|
||||
start.record()
|
||||
fn()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(start.elapsed_time(end)) # ms
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
# TRT Decode
|
||||
trt_mean, trt_std = time_fn(trt_decode)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + page_size - 1) // page_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % page_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = page_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
|
||||
)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
|
||||
|
||||
baseline_mean, baseline_std = time_fn(baseline_decode)
|
||||
|
||||
# Calculate percentage speedup (positive means TRT is faster)
|
||||
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
|
||||
|
||||
print(
|
||||
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
|
||||
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
|
||||
)
|
||||
|
||||
# Return results for CSV writing
|
||||
return {
|
||||
"num_seqs": num_seqs,
|
||||
"trt_mean": trt_mean,
|
||||
"trt_std": trt_std.item(),
|
||||
"baseline_mean": baseline_mean,
|
||||
"baseline_std": baseline_std.item(),
|
||||
"speedup_percent": speedup_percent,
|
||||
"q_dtype": str(dtype),
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"page_size": page_size,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_dim": head_dim,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
|
||||
def write_results_to_csv(results, filename=None):
|
||||
"""Write benchmark results to CSV file."""
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||
|
||||
fieldnames = [
|
||||
"num_seqs",
|
||||
"trt_mean",
|
||||
"trt_std",
|
||||
"baseline_mean",
|
||||
"baseline_std",
|
||||
"speedup_percent",
|
||||
"q_dtype",
|
||||
"kv_cache_dtype",
|
||||
"page_size",
|
||||
"num_kv_heads",
|
||||
"head_dim",
|
||||
"max_seq_len",
|
||||
]
|
||||
|
||||
file_exists = os.path.exists(filename)
|
||||
|
||||
with open(filename, "a", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
|
||||
for result in results:
|
||||
writer.writerow(result)
|
||||
|
||||
print(f"Results written to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
all_results = []
|
||||
|
||||
print("Running benchmark for kv_cache_dtype: bfloat16")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Write all results to CSV
|
||||
write_results_to_csv(all_results)
|
@ -24,6 +24,7 @@
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#include "../quantization/fp8/nvidia/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
372
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
372
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
@ -0,0 +1,372 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "../kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
// printf("set_split_kv start");
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
// std::cout << H << " " << K << " " << D << " " << B << "\n";
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
// printf(" sm_count = %d\n", sm_count);
|
||||
int max_splits = ceil_div(K, 128);
|
||||
max_splits = min(16, max_splits);
|
||||
// printf(" max_splits = %d\n", max_splits);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
// printf(" sms_per_batch = %d\n", sms_per_batch);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
// printf(" args.split_kv = %d\n", args.split_kv);
|
||||
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
@ -0,0 +1,203 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,165 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
283
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
Normal file
283
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
Normal file
@ -0,0 +1,283 @@
|
||||
/*
|
||||
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/kernel_hardware_info.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
// clang-format off
|
||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||
}
|
||||
#else
|
||||
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape,
|
||||
Element,
|
||||
ElementAcc,
|
||||
ElementOut,
|
||||
ElementAcc,
|
||||
TileScheduler,
|
||||
/*kIsCpAsync=*/!IsPaged128>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
float scale = float(sm_scale);
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_nope = cute::make_tuple(
|
||||
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
||||
StrideQ stride_Q_pe = cute::make_tuple(
|
||||
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
||||
|
||||
StrideK stride_C = cute::make_tuple(
|
||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale,
|
||||
Q_nope_ptr,
|
||||
stride_Q_nope,
|
||||
Q_pe_ptr,
|
||||
stride_Q_pe,
|
||||
C_ptr,
|
||||
stride_C,
|
||||
C_ptr + D_latent,
|
||||
stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()),
|
||||
stride_PT,
|
||||
page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
num_kv_splits, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||
void runMla(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
at::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
|
||||
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
||||
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
|
||||
// Maybe per batch split kv will fix this.
|
||||
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
||||
// which are float, so Element type here doesn't matter.
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
|
||||
|
||||
// Get split kv. Requires problem shape and sm_count only.
|
||||
typename MlaSm100Type::Fmha::Arguments arguments;
|
||||
using TileShapeH = typename MlaSm100Type::TileShapeH;
|
||||
using TileShapeD = typename MlaSm100Type::TileShapeD;
|
||||
arguments.problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = num_kv_splits;
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
|
||||
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
|
||||
}
|
||||
|
||||
// clang-format on
|
@ -18,12 +18,7 @@
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -187,7 +182,6 @@ void paged_attention_v1(
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
@ -18,12 +18,7 @@
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -197,7 +192,6 @@ void paged_attention_v2(
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
@ -58,7 +58,7 @@ namespace {
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention")
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
|
@ -126,7 +126,7 @@ void fused_experts_int4_w4a16_kernel_impl(
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implememntation for int8 w8a8
|
||||
// shared expert implementation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
|
@ -41,7 +41,7 @@ struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but luckly we use 4x2
|
||||
// oops! 4x4 spills but luckily we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
|
@ -37,7 +37,7 @@ inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vecto
|
||||
#define CVT_FP16_TO_FP32(a) \
|
||||
_mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
|
||||
// this doesn't hanel NaN.
|
||||
// this doesn't handle NaN.
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
||||
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
|
||||
|
@ -4,10 +4,10 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#if defined(USE_ROCM) && defined(__GFX9__)
|
||||
#define WARP_SIZE 64
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
@ -7,7 +7,11 @@
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#ifdef USE_ROCM
|
||||
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
|
||||
#else
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
@ -312,19 +316,25 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifdef USE_ROCM
|
||||
C10_HIP_CHECK(hipFuncSetAttribute(
|
||||
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -612,19 +622,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = delta;
|
||||
TORCH_CHECK(ssm_states.scalar_type() == input_type);
|
||||
@ -653,4 +664,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -201,11 +201,10 @@ void run_blockwise_scaled_group_mm(
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
|
||||
layout_sfb.data_ptr())};
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = a_ptrs.get_device();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
int device_id = a_ptrs.device().index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
|
@ -30,35 +30,40 @@
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
// Kernel Perf config
|
||||
template <typename T>
|
||||
struct KernelTraits;
|
||||
|
||||
template <>
|
||||
struct KernelTraits<float> {
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
// Configuration for M in (256, inf)
|
||||
struct sm100_fp4_config_default {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in (16, 256]
|
||||
struct sm100_fp4_config_M256 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _128, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in [1, 16]
|
||||
struct sm100_fp4_config_M16 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::half_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::bfloat16_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename Config, typename OutType>
|
||||
struct Fp4GemmSm100 {
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
|
||||
static constexpr int AlignmentB = 32;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = T;
|
||||
using ElementC = T;
|
||||
using ElementD = OutType;
|
||||
using ElementC = OutType;
|
||||
using LayoutCTag = cutlass::layout::RowMajor;
|
||||
using LayoutDTag = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
|
||||
// Use config's tile shapes
|
||||
using MmaTileShape = typename Config::TileShape;
|
||||
using ClusterShape = typename Config::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Gemm::Arguments args_from_options(
|
||||
template <typename Config>
|
||||
typename Config::Gemm::Arguments args_from_options(
|
||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||
int64_t M, int64_t N, int64_t K) {
|
||||
using ElementA = typename T::Gemm::ElementA;
|
||||
using ElementB = typename T::Gemm::ElementB;
|
||||
using ElementA = typename Config::Gemm::ElementA;
|
||||
using ElementB = typename Config::Gemm::ElementB;
|
||||
using ElementSFA = cutlass::float_ue4m3_t;
|
||||
using ElementSFB = cutlass::float_ue4m3_t;
|
||||
using ElementD = typename T::Gemm::ElementD;
|
||||
using ElementD = typename Config::Gemm::ElementD;
|
||||
using ElementCompute = float;
|
||||
using StrideA = typename T::StrideA;
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using StrideA = typename Config::StrideA;
|
||||
using StrideB = typename Config::StrideB;
|
||||
using StrideD = typename Config::StrideD;
|
||||
using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
|
||||
CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
|
||||
cute::make_shape(m, n, k, 1));
|
||||
|
||||
typename T::Gemm::Arguments arguments{
|
||||
typename Config::Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
{// Mainloop arguments
|
||||
@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename Config>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
||||
typename Config::Gemm gemm;
|
||||
|
||||
auto arguments =
|
||||
args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (16, 256]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
// m in (256, inf)
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
}
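The dispatch above keys off the next power of two of M, clamped to a minimum of 16, and buckets it into the three tile configurations. Restated as a minimal Python sketch (illustrative only, not part of vLLM's API), the selection works as follows:

```python
# Illustrative restatement of the M-based config selection in
# cutlass_fp4_gemm_dispatch; config names mirror the structs defined above.
def next_pow_2(x: int) -> int:
    return 1 if x <= 1 else 1 << (x - 1).bit_length()

def select_fp4_config(m: int) -> str:
    mp2 = max(16, next_pow_2(m))
    if mp2 <= 16:
        return "sm100_fp4_config_M16"       # m in [1, 16]
    if mp2 <= 256:
        return "sm100_fp4_config_M256"      # m in (16, 256]
    return "sm100_fp4_config_default"       # m in (256, inf)

assert select_fp4_config(8) == "sm100_fp4_config_M16"
assert select_fp4_config(200) == "sm100_fp4_config_M256"
assert select_fp4_config(1024) == "sm100_fp4_config_default"
```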
|
||||
|
||||
#else
|
||||
template <typename T>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||
k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::Float) {
|
||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||
m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
||||
")");
|
||||
}
|
||||
}
|
||||
|
@ -514,6 +514,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, Tensor workspace, float "
|
||||
"scale,"
|
||||
" int num_kv_splits) -> ()");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// SM100 CUTLASS MLA workspace
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||
" int sm_count, int num_kv_splits) "
|
||||
"-> int");
|
||||
// conditionally compiled so impl in source file
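Once registered, these ops are reachable from Python through `torch.ops`. A rough usage sketch for the workspace query, assuming a CUDA 12.4+ build with SM100 support and the usual `_C` extension namespace (both are assumptions, not a stable public API):

```python
# Hypothetical sketch: query the workspace size for the SM100 CUTLASS MLA
# decode kernel through the registered custom op, then allocate the buffer.
import torch
import vllm._C  # noqa: F401  # assumed name of the compiled extension module

workspace_bytes = torch.ops._C.sm100_cutlass_mla_get_workspace_size(
    4096,  # max_seq_len
    8,     # num_batches
    -1,    # sm_count (<= 0 lets the kernel query the device's SM count)
    1,     # num_kv_splits
)
workspace = torch.empty(workspace_bytes, dtype=torch.uint8, device="cuda")
```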
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
|
@ -63,7 +63,7 @@ ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
|
||||
ARG PIP_KEYRING_PROVIDER=disabled
|
||||
ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
|
||||
|
||||
# Flag enables build-in KV-connector dependency libs into docker images
|
||||
# Flag enables built-in KV-connector dependency libs into docker images
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
@ -207,6 +207,19 @@ ARG SCCACHE_ENDPOINT
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# Flag to control whether to use pre-built vLLM wheels
|
||||
ARG VLLM_USE_PRECOMPILED
|
||||
# TODO: in setup.py VLLM_USE_PRECOMPILED is sensitive to truthiness, it will take =0 as "true", this should be fixed
|
||||
ENV VLLM_USE_PRECOMPILED=""
|
||||
RUN if [ "${VLLM_USE_PRECOMPILED}" = "1" ]; then \
|
||||
export VLLM_USE_PRECOMPILED=1 && \
|
||||
echo "Using precompiled wheels"; \
|
||||
else \
|
||||
unset VLLM_USE_PRECOMPILED && \
|
||||
echo "Leaving VLLM_USE_PRECOMPILED unset to build wheels from source"; \
|
||||
fi
|
||||
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
@ -408,7 +421,8 @@ RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
|
||||
# Needed to build AOT kernels
|
||||
pushd flashinfer
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation .
|
||||
popd
|
||||
|
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="6487649"
|
||||
ARG AITER_BRANCH="916bf3c"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
@ -36,7 +36,7 @@ vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular HuggingFace models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
||||
|
@ -8,7 +8,6 @@ API documentation for vLLM's configuration classes.
|
||||
|
||||
- [vllm.config.ModelConfig][]
|
||||
- [vllm.config.CacheConfig][]
|
||||
- [vllm.config.TokenizerPoolConfig][]
|
||||
- [vllm.config.LoadConfig][]
|
||||
- [vllm.config.ParallelConfig][]
|
||||
- [vllm.config.SchedulerConfig][]
|
||||
|
BIN
docs/assets/deployment/dp_external_lb.png
Normal file (binary image, not shown; 84 KiB)
BIN
docs/assets/deployment/dp_internal_lb.png
Normal file (binary image, not shown; 68 KiB)
BIN (modified binary image, not shown; 68 KiB -> 57 KiB)
@ -1,3 +1,7 @@
|
||||
---
|
||||
toc_depth: 4
|
||||
---
|
||||
|
||||
# vLLM CLI Guide
|
||||
|
||||
The vllm command-line tool is used to run and manage vLLM models. You can start by viewing the help message with:
|
||||
@ -37,8 +41,15 @@ Start the vLLM OpenAI Compatible API server.
|
||||
|
||||
# To search by keyword
|
||||
vllm serve --help=max
|
||||
|
||||
# To view full help with pager (less/more)
|
||||
vllm serve --help=page
|
||||
```
|
||||
|
||||
### Options
|
||||
|
||||
--8<-- "docs/argparse/serve.md"
|
||||
|
||||
## chat
|
||||
|
||||
Generate chat completions via the running API server.
|
||||
|
@ -5,7 +5,7 @@ The `vllm serve` command is used to launch the OpenAI-compatible server.
|
||||
## CLI Arguments
|
||||
|
||||
The `vllm serve` command is used to launch the OpenAI-compatible server.
|
||||
To see the available CLI arguments, run `vllm serve --help`!
|
||||
To see the available options, take a look at the [CLI Reference](../cli/README.md#options)!
|
||||
|
||||
## Configuration file
|
||||
|
||||
|
@ -73,6 +73,8 @@ def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
...
|
||||
```
|
||||
|
@ -3,6 +3,15 @@
|
||||
[](){ #deployment-anyscale }
|
||||
|
||||
[Anyscale](https://www.anyscale.com) is a managed, multi-cloud platform developed by the creators of Ray.
|
||||
It hosts Ray clusters inside your own AWS, GCP, or Azure account, delivering the flexibility of open-source Ray
|
||||
without the operational overhead of maintaining Kubernetes control planes, configuring autoscalers, or managing observability stacks.
|
||||
|
||||
Anyscale automates the entire lifecycle of Ray clusters in your AWS, GCP, or Azure account, delivering the flexibility of open-source Ray
|
||||
without the operational overhead of maintaining Kubernetes control planes, configuring autoscalers, managing observability stacks, or manually managing head and worker nodes with helper scripts like <gh-file:examples/online_serving/run_cluster.sh>.
|
||||
|
||||
When serving large language models with vLLM, Anyscale can rapidly provision [production-ready HTTPS endpoints](https://docs.anyscale.com/examples/deploy-ray-serve-llms) or [fault-tolerant batch inference jobs](https://docs.anyscale.com/examples/ray-data-llm).
|
||||
|
||||
## Production-ready vLLM on Anyscale quickstarts
|
||||
|
||||
- [Offline batch inference](https://console.anyscale.com/template-preview/llm_batch_inference?utm_source=vllm_docs)
|
||||
- [Deploy vLLM services](https://console.anyscale.com/template-preview/llm_serving?utm_source=vllm_docs)
|
||||
- [Curate a dataset](https://console.anyscale.com/template-preview/audio-dataset-curation-llm-judge?utm_source=vllm_docs)
|
||||
- [Finetune an LLM](https://console.anyscale.com/template-preview/entity-recognition-with-llms?utm_source=vllm_docs)
|
||||
|
@ -1,26 +1,42 @@
|
||||
# Open WebUI
|
||||
|
||||
1. Install the [Docker](https://docs.docker.com/engine/install/)
|
||||
[Open WebUI](https://github.com/open-webui/open-webui) is an extensible, feature-rich,
|
||||
and user-friendly self-hosted AI platform designed to operate entirely offline.
|
||||
It supports various LLM runners like Ollama and OpenAI-compatible APIs,
|
||||
with built-in RAG capabilities, making it a powerful AI deployment solution.
|
||||
|
||||
2. Start the vLLM server with the supported chat completion model, e.g.
|
||||
To get started with Open WebUI using vLLM, follow these steps:
|
||||
|
||||
```bash
|
||||
vllm serve qwen/Qwen1.5-0.5B-Chat
|
||||
```
|
||||
1. Install [Docker](https://docs.docker.com/engine/install/).
|
||||
|
||||
1. Start the [Open WebUI](https://github.com/open-webui/open-webui) docker container (replace the vllm serve host and vllm serve port):
|
||||
2. Start the vLLM server with a supported chat completion model:
|
||||
|
||||
```bash
|
||||
docker run -d -p 3000:8080 \
|
||||
--name open-webui \
|
||||
-v open-webui:/app/backend/data \
|
||||
-e OPENAI_API_BASE_URL=http://<vllm serve host>:<vllm serve port>/v1 \
|
||||
--restart always \
|
||||
ghcr.io/open-webui/open-webui:main
|
||||
```
|
||||
```console
|
||||
vllm serve Qwen/Qwen3-0.6B-Chat
|
||||
```
|
||||
|
||||
1. Open it in the browser: <http://open-webui-host:3000/>
|
||||
!!! note
|
||||
When starting the vLLM server, be sure to specify the host and port using the `--host` and `--port` flags.
|
||||
For example:
|
||||
|
||||
On the top of the web page, you can see the model `qwen/Qwen1.5-0.5B-Chat`.
|
||||
```console
|
||||
python -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||

|
||||
3. Start the Open WebUI Docker container:
|
||||
|
||||
```console
|
||||
docker run -d \
|
||||
--name open-webui \
|
||||
-p 3000:8080 \
|
||||
-v open-webui:/app/backend/data \
|
||||
-e OPENAI_API_BASE_URL=http://0.0.0.0:8000/v1 \
|
||||
--restart always \
|
||||
ghcr.io/open-webui/open-webui:main
|
||||
```
|
||||
|
||||
4. Open it in the browser: <http://open-webui-host:3000/>
|
||||
|
||||
At the top of the page, you should see the model `Qwen/Qwen3-0.6B-Chat`.
|
||||
|
||||

|
||||
|
20
docs/deployment/integrations/kuberay.md
Normal file
@ -0,0 +1,20 @@
|
||||
# KubeRay
|
||||
|
||||
[KubeRay](https://github.com/ray-project/kuberay) provides a Kubernetes-native way to run vLLM workloads on Ray clusters.
|
||||
A Ray cluster can be declared in YAML, and the operator then handles pod scheduling, networking configuration, restarts, and blue-green deployments — all while preserving the familiar Kubernetes experience.
|
||||
|
||||
## Why KubeRay instead of manual scripts?
|
||||
|
||||
| Feature | Manual scripts | KubeRay |
|
||||
|---------|-----------------------------------------------------------|---------|
|
||||
| Cluster bootstrap | Manually SSH into every node and run a script | One command to create or update the whole cluster: `kubectl apply -f cluster.yaml` |
|
||||
| Autoscaling | Manual | Automatically patches CRDs for adjusting cluster size |
|
||||
| Upgrades | Tear down & re-create manually | Blue/green deployment updates supported |
|
||||
| Declarative config | Bash flags & environment variables | Git-ops-friendly YAML CRDs (RayCluster/RayService) |
|
||||
|
||||
Using KubeRay reduces the operational burden and simplifies integration of Ray + vLLM with existing Kubernetes workflows (CI/CD, secrets, storage classes, etc.).
|
||||
|
||||
## Learn more
|
||||
|
||||
* ["Serve a Large Language Model using Ray Serve LLM on Kubernetes"](https://docs.ray.io/en/master/cluster/kubernetes/examples/rayserve-llm-example.html) - An end-to-end example of how to serve a model using vLLM, KubeRay, and Ray Serve.
|
||||
* [KubeRay documentation](https://docs.ray.io/en/latest/cluster/kubernetes/index.html)
|
@ -13,6 +13,7 @@ Alternatively, you can deploy vLLM to Kubernetes using any of the following:
|
||||
- [Helm](frameworks/helm.md)
|
||||
- [InftyAI/llmaz](integrations/llmaz.md)
|
||||
- [KServe](integrations/kserve.md)
|
||||
- [KubeRay](integrations/kuberay.md)
|
||||
- [kubernetes-sigs/lws](frameworks/lws.md)
|
||||
- [meta-llama/llama-stack](integrations/llamastack.md)
|
||||
- [substratusai/kubeai](integrations/kubeai.md)
|
||||
|
@ -279,64 +279,64 @@ Some models, e.g., [Granite Speech](https://huggingface.co/ibm-granite/granite-s
|
||||
|
||||
To this end, we allow registration of default multimodal LoRAs to handle this automatically, where users can map each modality to a LoRA adapter that is applied automatically whenever the corresponding inputs are present. Note that currently we only allow one LoRA per prompt; if inputs from several modalities are provided and each of those modalities is registered to a LoRA, none of the LoRAs will be applied.
|
||||
|
||||
Example usage for offline inference:
|
||||
??? code "Example usage for offline inference"
|
||||
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
```python
|
||||
from transformers import AutoTokenizer
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.assets.audio import AudioAsset
|
||||
|
||||
model_id = "ibm-granite/granite-speech-3.3-2b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
model_id = "ibm-granite/granite-speech-3.3-2b"
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
||||
|
||||
def get_prompt(question: str, has_audio: bool):
|
||||
"""Build the input prompt to send to vLLM."""
|
||||
if has_audio:
|
||||
question = f"<|audio|>{question}"
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": question
|
||||
def get_prompt(question: str, has_audio: bool):
|
||||
"""Build the input prompt to send to vLLM."""
|
||||
if has_audio:
|
||||
question = f"<|audio|>{question}"
|
||||
chat = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": question
|
||||
}
|
||||
]
|
||||
return tokenizer.apply_chat_template(chat, tokenize=False)
|
||||
|
||||
|
||||
model = LLM(
|
||||
model=model_id,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_model_len=2048,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
# Will always pass a `LoRARequest` with the `model_id`
|
||||
# whenever audio is contained in the request data.
|
||||
default_mm_loras = {"audio": model_id},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
question = "can you transcribe the speech into a written format?"
|
||||
prompt_with_audio = get_prompt(
|
||||
question=question,
|
||||
has_audio=True,
|
||||
)
|
||||
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
|
||||
|
||||
inputs = {
|
||||
"prompt": prompt_with_audio,
|
||||
"multi_modal_data": {
|
||||
"audio": audio,
|
||||
}
|
||||
]
|
||||
return tokenizer.apply_chat_template(chat, tokenize=False)
|
||||
|
||||
|
||||
model = LLM(
|
||||
model=model_id,
|
||||
enable_lora=True,
|
||||
max_lora_rank=64,
|
||||
max_model_len=2048,
|
||||
limit_mm_per_prompt={"audio": 1},
|
||||
# Will always pass a `LoRARequest` with the `model_id`
|
||||
# whenever audio is contained in the request data.
|
||||
default_mm_loras = {"audio": model_id},
|
||||
enforce_eager=True,
|
||||
)
|
||||
|
||||
question = "can you transcribe the speech into a written format?"
|
||||
prompt_with_audio = get_prompt(
|
||||
question=question,
|
||||
has_audio=True,
|
||||
)
|
||||
audio = AudioAsset("mary_had_lamb").audio_and_sample_rate
|
||||
|
||||
inputs = {
|
||||
"prompt": prompt_with_audio,
|
||||
"multi_modal_data": {
|
||||
"audio": audio,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
outputs = model.generate(
|
||||
inputs,
|
||||
sampling_params=SamplingParams(
|
||||
temperature=0.2,
|
||||
max_tokens=64,
|
||||
),
|
||||
)
|
||||
```
|
||||
outputs = model.generate(
|
||||
inputs,
|
||||
sampling_params=SamplingParams(
|
||||
temperature=0.2,
|
||||
max_tokens=64,
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
You can also pass a json dictionary of `--default-mm-loras` mapping modalities to LoRA model IDs. For example, when starting the server:
|
||||
|
||||
|
@ -10,6 +10,7 @@ Contents:
|
||||
- [BitBLAS](bitblas.md)
|
||||
- [GGUF](gguf.md)
|
||||
- [GPTQModel](gptqmodel.md)
|
||||
- [INC](inc.md)
|
||||
- [INT4 W4A16](int4.md)
|
||||
- [INT8 W8A8](int8.md)
|
||||
- [FP8 W8A8](fp8.md)
|
||||
|
56
docs/features/quantization/inc.md
Normal file
@ -0,0 +1,56 @@
|
||||
---
|
||||
title: FP8 INC
|
||||
---
|
||||
[](){ #inc }
|
||||
|
||||
vLLM supports FP8 (8-bit floating point) weight and activation quantization using Intel® Neural Compressor (INC) on Intel® Gaudi® 2 and Intel® Gaudi® 3 AI accelerators.
|
||||
Currently, quantization has been validated only for Llama models.
|
||||
|
||||
Intel Gaudi supports quantization of various modules and functions, including, but not limited to `Linear`, `KVCache`, `Matmul` and `Softmax`. For more information, please refer to:
|
||||
[Supported Modules\\Supported Functions\\Custom Patched Modules](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-modules).
|
||||
|
||||
!!! note
|
||||
Measurement files are required to run quantized models with vLLM on Gaudi accelerators. The FP8 model calibration procedure is described in the [vllm-hpu-extension](https://github.com/HabanaAI/vllm-hpu-extension/tree/main/calibration/README.md) package.
|
||||
|
||||
!!! note
|
||||
`QUANT_CONFIG` is an environment variable that points to the measurement or quantization [JSON config file](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Quantization/Inference_Using_FP8.html#supported-json-config-file-options).
|
||||
The measurement configuration file is used during the calibration procedure to collect measurements for a given model. The quantization configuration is used during inference.
|
||||
|
||||
## Run Online Inference Using FP8
|
||||
|
||||
Once you've completed the model calibration process and collected the measurements, you can run FP8 inference with vLLM using the following command:
|
||||
|
||||
```bash
|
||||
export QUANT_CONFIG=/path/to/quant/config/inc/meta-llama-3.1-405b-instruct/maxabs_measure_g3.json
|
||||
vllm serve meta-llama/Llama-3.1-405B-Instruct --quantization inc --kv-cache-dtype fp8_inc --tensor-parallel-size 8
|
||||
```
|
||||
|
||||
!!! tip
|
||||
If you are just prototyping or testing your model with FP8, you can use the `VLLM_SKIP_WARMUP=true` environment variable to disable the warmup stage, which can take a long time. However, we do not recommend disabling this feature in production environments as it causes a significant performance drop.
|
||||
|
||||
!!! tip
|
||||
When using FP8 models, you may experience timeouts caused by the long compilation time of FP8 operations. To mitigate this problem, you can use the environment variables below:
|
||||
`VLLM_ENGINE_ITERATION_TIMEOUT_S` - to adjust the vLLM server timeout. You can set the value in seconds, e.g., 600 equals 10 minutes.
|
||||
`VLLM_RPC_TIMEOUT` - to adjust the RPC protocol timeout used by the OpenAI-compatible API. This value is in milliseconds, e.g., 600000 equals 10 minutes.
|
||||
|
||||
## Run Offline Inference Using FP8
|
||||
|
||||
To run offline inference (after completing the model calibration process):
|
||||
|
||||
* Set the "QUANT_CONFIG" environment variable to point to a JSON configuration file with QUANTIZE mode.
|
||||
* Pass `quantization=inc` and `kv_cache_dtype=fp8_inc` as parameters to the `LLM` object.
|
||||
* Call the shutdown method of the `model_executor` at the end of the run.
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
llm = LLM("llama3.1/Meta-Llama-3.1-8B-Instruct", quantization="inc", kv_cache_dtype="fp8_inc")
|
||||
...
|
||||
# Call llm.generate on the required prompts and sampling params.
|
||||
...
|
||||
llm.llm_engine.model_executor.shutdown()
|
||||
```
|
||||
|
||||
## Device for the Model's Weights Uploading
|
||||
|
||||
The unquantized weights are first loaded onto the CPU, then quantized and transferred to the target device (HPU) for model execution.
|
||||
This reduces the device memory footprint of model weights, as only quantized weights are stored in the device memory.
|
@ -2,18 +2,19 @@
|
||||
|
||||
The table below shows the compatibility of various quantization implementations with different hardware platforms in vLLM:
|
||||
|
||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | x86 CPU | AWS Neuron | Google TPU |
|
||||
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-----------|------------------|--------------|
|
||||
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
|
||||
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ✅︎ | ❌ | ❌ |
|
||||
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
|
||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ✅︎ | ❌ |
|
||||
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ |
|
||||
| Implementation | Volta | Turing | Ampere | Ada | Hopper | AMD GPU | Intel GPU | Intel Gaudi | x86 CPU | AWS Neuron | Google TPU |
|
||||
|-----------------------|---------|----------|----------|-------|----------|-----------|-------------|-------------|-----------|--------------|--------------|
|
||||
| AWQ | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
|
||||
| GPTQ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ✅︎ | ❌ | ✅︎ | ❌ | ❌ |
|
||||
| Marlin (GPTQ/AWQ/FP8) | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| INT8 (W8A8) | ❌ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ |
|
||||
| FP8 (W8A8) | ❌ | ❌ | ❌ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ✅︎ | ❌ |
|
||||
| BitBLAS (GPTQ) | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| AQLM | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| bitsandbytes | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| DeepSpeedFP | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| GGUF | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ✅︎ | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| INC (W8A8) | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅︎ | ❌ | ❌ | ❌ |
|
||||
|
||||
- Volta refers to SM 7.0, Turing to SM 7.5, Ampere to SM 8.0/8.6, Ada to SM 8.9, and Hopper to SM 9.0.
|
||||
- ✅︎ indicates that the quantization method is supported on the specified hardware.
|
||||
|
@ -256,12 +256,12 @@ speculative decoding, breaking down the guarantees into three key areas:
|
||||
2. **Algorithmic Losslessness**
|
||||
\- vLLM’s implementation of speculative decoding is algorithmically validated to be lossless. Key validation tests include:
|
||||
|
||||
> - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
|
||||
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
|
||||
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
|
||||
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
|
||||
> provides a lossless guarantee. Almost all of the tests in <gh-dir:tests/spec_decode/e2e>.
|
||||
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291)
|
||||
> - **Rejection Sampler Convergence**: Ensures that samples from vLLM’s rejection sampler align with the target
|
||||
> distribution. [View Test Code](https://github.com/vllm-project/vllm/blob/47b65a550866c7ffbd076ecb74106714838ce7da/tests/samplers/test_rejection_sampler.py#L252)
|
||||
> - **Greedy Sampling Equality**: Confirms that greedy sampling with speculative decoding matches greedy sampling
|
||||
> without it. This verifies that vLLM's speculative decoding framework, when integrated with the vLLM forward pass and the vLLM rejection sampler,
|
||||
> provides a lossless guarantee. Almost all of the tests in <gh-dir:tests/spec_decode/e2e>.
|
||||
> verify this property using [this assertion implementation](https://github.com/vllm-project/vllm/blob/b67ae00cdbbe1a58ffc8ff170f0c8d79044a684a/tests/spec_decode/e2e/conftest.py#L291)
|
||||
|
||||
3. **vLLM Logprob Stability**
|
||||
\- vLLM does not currently guarantee stable token log probabilities (logprobs). This can result in different outputs for the
|
||||
|
@ -103,9 +103,7 @@ When tool_choice='required' is set, the model is guaranteed to generate one or m
|
||||
|
||||
vLLM supports the `tool_choice='none'` option in the chat completion API. When this option is set, the model will not generate any tool calls and will respond with regular text content only, even if tools are defined in the request.
|
||||
|
||||
By default, when `tool_choice='none'` is specified, vLLM excludes tool definitions from the prompt to optimize context usage. To include tool definitions even with `tool_choice='none'`, use the `--expand-tools-even-if-tool-choice-none` option.
|
||||
|
||||
Note: This behavior will change in v0.10.0, where tool definitions will be included by default even with `tool_choice='none'`.
|
||||
However, when `tool_choice='none'` is specified, vLLM includes the tool definitions in the prompt.
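As a rough illustration, the sketch below sends tool definitions together with `tool_choice='none'`; the endpoint, API key, model name, and tool schema are placeholder assumptions. The response contains plain text and no `tool_calls`:

```python
# Hedged example: tools are defined in the request, but tool_choice="none"
# means the model responds with regular text instead of calling them.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",  # any tool-calling-capable model
    messages=[{"role": "user", "content": "What's the weather in Berlin?"}],
    tools=tools,
    tool_choice="none",
)
print(response.choices[0].message.tool_calls)  # None
print(response.choices[0].message.content)     # plain-text answer
```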
|
||||
|
||||
## Automatic Function Calling
|
||||
|
||||
@ -282,6 +280,14 @@ Supported models:
|
||||
|
||||
Flags: `--tool-call-parser deepseek_v3 --chat-template {see_above}`
|
||||
|
||||
### Kimi-K2 Models (`kimi_k2`)
|
||||
|
||||
Supported models:
|
||||
|
||||
* `moonshotai/Kimi-K2-Instruct`
|
||||
|
||||
Flags: `--tool-call-parser kimi_k2`
|
||||
|
||||
### Models with Pythonic Tool Calls (`pythonic`)
|
||||
|
||||
A growing number of models output a python list to represent tool calls instead of using JSON. This has the advantage of inherently supporting parallel tool calls and removing ambiguity around the JSON schema required for tool calls. The `pythonic` tool parser can support such models.
|
||||
|
@ -28,7 +28,7 @@ To verify that the Intel Gaudi software was correctly installed, run:
|
||||
hl-smi # verify that hl-smi is in your PATH and each Gaudi accelerator is visible
|
||||
apt list --installed | grep habana # verify that habanalabs-firmware-tools, habanalabs-graph, habanalabs-rdma-core, habanalabs-thunk and habanalabs-container-runtime are installed
|
||||
pip list | grep habana # verify that habana-torch-plugin, habana-torch-dataloader, habana-pyhlml and habana-media-loader are installed
|
||||
pip list | grep neural # verify that neural_compressor is installed
|
||||
pip list | grep neural # verify that neural_compressor_pt is installed
|
||||
```
|
||||
|
||||
Refer to [Intel Gaudi Software Stack Verification](https://docs.habana.ai/en/latest/Installation_Guide/SW_Verification.html#platform-upgrade)
|
||||
@ -120,12 +120,13 @@ docker run \
|
||||
- Inference with [HPU Graphs](https://docs.habana.ai/en/latest/PyTorch/Inference_on_PyTorch/Inference_Using_HPU_Graphs.html)
|
||||
for accelerating low-batch latency and throughput
|
||||
- Attention with Linear Biases (ALiBi)
|
||||
- INC quantization
|
||||
|
### Unsupported features

- Beam search
- LoRA adapters
- Quantization
- AWQ quantization
- Prefill chunking (mixed-batch inferencing)

### Supported configurations

@@ -133,36 +134,20 @@ docker run \

The following configurations have been validated to function with
Gaudi2 devices. Configurations that are not listed may or may not work.

- [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b)
  on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
  datatype with random or greedy sampling
- [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf)
  on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
  datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B)
  on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
  datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct)
  on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
  datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B)
  on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
  datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct)
  on single HPU, or with tensor parallelism on 2x and 8x HPU, BF16
  datatype with random or greedy sampling
- [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b)
  with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf)
  with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B)
  with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct)
  with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B)
  with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling
- [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct)
  with tensor parallelism on 8x HPU, BF16 datatype with random or greedy sampling

| Model | TP Size | dtype | Sampling |
|-------|---------|-------|----------|
| [meta-llama/Llama-2-7b](https://huggingface.co/meta-llama/Llama-2-7b) | 1, 2, 8 | BF16 | Random / Greedy |
| [meta-llama/Llama-2-7b-chat-hf](https://huggingface.co/meta-llama/Llama-2-7b-chat-hf) | 1, 2, 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3-8B](https://huggingface.co/meta-llama/Meta-Llama-3-8B) | 1, 2, 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3.1-8B](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B) | 1, 2, 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct) | 1, 2, 8 | BF16 | Random / Greedy |
| [meta-llama/Llama-2-70b](https://huggingface.co/meta-llama/Llama-2-70b) | 8 | BF16 | Random / Greedy |
| [meta-llama/Llama-2-70b-chat-hf](https://huggingface.co/meta-llama/Llama-2-70b-chat-hf) | 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3-70B](https://huggingface.co/meta-llama/Meta-Llama-3-70B) | 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-70B-Instruct) | 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3.1-70B](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B) | 8 | BF16 | Random / Greedy |
| [meta-llama/Meta-Llama-3.1-70B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-70B-Instruct) | 8 | BF16 | Random / Greedy |

## Performance tuning

@@ -13,10 +13,10 @@ ARGPARSE_DOC_DIR = ROOT_DIR / "docs/argparse"
sys.path.insert(0, str(ROOT_DIR))
sys.modules["aiohttp"] = MagicMock()
sys.modules["blake3"] = MagicMock()
sys.modules["gguf"] = MagicMock()
sys.modules["vllm._C"] = MagicMock()

from vllm.engine.arg_utils import AsyncEngineArgs, EngineArgs  # noqa: E402
from vllm.entrypoints.openai.cli_args import make_arg_parser  # noqa: E402
from vllm.utils import FlexibleArgumentParser  # noqa: E402

logger = logging.getLogger("mkdocs")

@@ -25,15 +25,18 @@ logger = logging.getLogger("mkdocs")
class MarkdownFormatter(HelpFormatter):
    """Custom formatter that generates markdown for argument groups."""

    def __init__(self, prog):
    def __init__(self, prog, starting_heading_level=3):
        super().__init__(prog,
                         max_help_position=float('inf'),
                         width=float('inf'))
        self._section_heading_prefix = "#" * starting_heading_level
        self._argument_heading_prefix = "#" * (starting_heading_level + 1)
        self._markdown_output = []

    def start_section(self, heading):
        if heading not in {"positional arguments", "options"}:
            self._markdown_output.append(f"\n### {heading}\n\n")
            heading_md = f"\n{self._section_heading_prefix} {heading}\n\n"
            self._markdown_output.append(heading_md)

    def end_section(self):
        pass

@@ -47,9 +50,13 @@ class MarkdownFormatter(HelpFormatter):

    def add_arguments(self, actions):
        for action in actions:
            if (len(action.option_strings) == 0
                    or "--help" in action.option_strings):
                continue

            option_strings = f'`{"`, `".join(action.option_strings)}`'
            self._markdown_output.append(f"#### {option_strings}\n\n")
            heading_md = f"{self._argument_heading_prefix} {option_strings}\n\n"
            self._markdown_output.append(heading_md)

            if choices := action.choices:
                choices = f'`{"`, `".join(str(c) for c in choices)}`'

@@ -82,6 +89,14 @@ def create_parser(cls, **kwargs) -> FlexibleArgumentParser:
    return cls.add_cli_args(parser, **kwargs)


def create_serve_parser() -> FlexibleArgumentParser:
    """Create a parser for the serve command with markdown formatting."""
    parser = FlexibleArgumentParser()
    parser.formatter_class = lambda prog: MarkdownFormatter(
        prog, starting_heading_level=4)
    return make_arg_parser(parser)


def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
    logger.info("Generating argparse documentation")
    logger.debug("Root directory: %s", ROOT_DIR.resolve())

@@ -96,6 +111,7 @@ def on_startup(command: Literal["build", "gh-deploy", "serve"], dirty: bool):
        "engine_args": create_parser(EngineArgs),
        "async_engine_args": create_parser(AsyncEngineArgs,
                                           async_args_only=True),
        "serve": create_serve_parser(),
    }

    # Generate documentation for each parser

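As a hedged usage note on the hunk above: the hook's `on_startup` signature already covers the `build`, `gh-deploy`, and `serve` mkdocs commands, so regenerating the argparse pages (including the new serve parser) locally is simply a docs build. A minimal sketch, assuming the docs requirements are installed:

```bash
# Sketch: trigger the on_startup hook above, which writes the argparse markdown
# (engine_args, async_engine_args, serve) before the site is built.
mkdocs build
```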
@@ -316,6 +316,7 @@ Specified using `--task generate`.
| `AquilaForCausalLM` | Aquila, Aquila2 | `BAAI/Aquila-7B`, `BAAI/AquilaChat-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `ArcticForCausalLM` | Arctic | `Snowflake/snowflake-arctic-base`, `Snowflake/snowflake-arctic-instruct`, etc. | | ✅︎ | ✅︎ |
| `BaiChuanForCausalLM` | Baichuan2, Baichuan | `baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `BailingMoeForCausalLM` | Ling | `inclusionAI/Ling-lite-1.5`, `inclusionAI/Ling-plus`, etc. | | ✅︎ | ✅︎ |
| `BambaForCausalLM` | Bamba | `ibm-ai-platform/Bamba-9B-fp8`, `ibm-ai-platform/Bamba-9B` | ✅︎ | ✅︎ | ✅︎ |
| `BloomForCausalLM` | BLOOM, BLOOMZ, BLOOMChat | `bigscience/bloom`, `bigscience/bloomz`, etc. | | ✅︎ | |
| `BartForConditionalGeneration` | BART | `facebook/bart-base`, `facebook/bart-large-cnn`, etc. | | | |
@@ -374,6 +375,7 @@ Specified using `--task generate`.
| `Phi3ForCausalLM` | Phi-4, Phi-3 | `microsoft/Phi-4-mini-instruct`, `microsoft/Phi-4`, `microsoft/Phi-3-mini-4k-instruct`, `microsoft/Phi-3-mini-128k-instruct`, `microsoft/Phi-3-medium-128k-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi3SmallForCausalLM` | Phi-3-Small | `microsoft/Phi-3-small-8k-instruct`, `microsoft/Phi-3-small-128k-instruct`, etc. | | ✅︎ | ✅︎ |
| `PhiMoEForCausalLM` | Phi-3.5-MoE | `microsoft/Phi-3.5-MoE-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Phi4FlashForCausalLM` | Phi-4-mini-flash-reasoning | `microsoft/microsoft/Phi-4-mini-instruct`, etc. | | | |
| `PersimmonForCausalLM` | Persimmon | `adept/persimmon-8b-base`, `adept/persimmon-8b-chat`, etc. | | ✅︎ | ✅︎ |
| `Plamo2ForCausalLM` | PLaMo2 | `pfnet/plamo-2-1b`, `pfnet/plamo-2-8b`, etc. | | | |
| `QWenLMHeadModel` | Qwen | `Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
@@ -579,14 +581,14 @@ Specified using `--task generate`.
| `KeyeForConditionalGeneration` | Keye-VL-8B-Preview | T + I<sup>E+</sup> + V<sup>E+</sup> | `Kwai-Keye/Keye-VL-8B-Preview` | | | ✅︎ |
| `KimiVLForConditionalGeneration` | Kimi-VL-A3B-Instruct, Kimi-VL-A3B-Thinking | T + I<sup>+</sup> | `moonshotai/Kimi-VL-A3B-Instruct`, `moonshotai/Kimi-VL-A3B-Thinking` | | | ✅︎ |
| `Llama4ForConditionalGeneration` | Llama 4 | T + I<sup>+</sup> | `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc. | | ✅︎ | ✅︎ |
| `LlavaForConditionalGeneration` | LLaVA-1.5 | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), etc. | | ✅︎ | ✅︎ |
| `LlavaForConditionalGeneration` | LLaVA-1.5, Pixtral (HF Transformers) | T + I<sup>E+</sup> | `llava-hf/llava-1.5-7b-hf`, `TIGER-Lab/Mantis-8B-siglip-llama3` (see note), `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |
| `LlavaNextForConditionalGeneration` | LLaVA-NeXT | T + I<sup>E+</sup> | `llava-hf/llava-v1.6-mistral-7b-hf`, `llava-hf/llava-v1.6-vicuna-7b-hf`, etc. | | ✅︎ | ✅︎ |
| `LlavaNextVideoForConditionalGeneration` | LLaVA-NeXT-Video | T + V | `llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. | | ✅︎ | ✅︎ |
| `LlavaOnevisionForConditionalGeneration` | LLaVA-Onevision | T + I<sup>+</sup> + V<sup>+</sup> | `llava-hf/llava-onevision-qwen2-7b-ov-hf`, `llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. | | ✅︎ | ✅︎ |
| `MiniCPMO` | MiniCPM-O | T + I<sup>E+</sup> + V<sup>E+</sup> + A<sup>E+</sup> | `openbmb/MiniCPM-o-2_6`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MiniCPMV` | MiniCPM-V | T + I<sup>E+</sup> + V<sup>E+</sup> | `openbmb/MiniCPM-V-2` (see note), `openbmb/MiniCPM-Llama3-V-2_5`, `openbmb/MiniCPM-V-2_6`, etc. | ✅︎ | | ✅︎ |
| `MiniMaxVL01ForConditionalGeneration` | MiniMax-VL | T + I<sup>E+</sup> | `MiniMaxAI/MiniMax-VL-01`, etc. | | ✅︎ | ✅︎ |
| `Mistral3ForConditionalGeneration` | Mistral3 | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Mistral3ForConditionalGeneration` | Mistral3 (HF Transformers) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `MllamaForConditionalGeneration` | Llama 3.2 | T + I<sup>+</sup> | `meta-llama/Llama-3.2-90B-Vision-Instruct`, `meta-llama/Llama-3.2-11B-Vision`, etc. | | | |
| `MolmoForCausalLM` | Molmo | T + I<sup>+</sup> | `allenai/Molmo-7B-D-0924`, `allenai/Molmo-7B-O-0924`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `NVLM_D_Model` | NVLM-D 1.0 | T + I<sup>+</sup> | `nvidia/NVLM-D-72B`, etc. | | ✅︎ | ✅︎ |
@@ -594,7 +596,7 @@ Specified using `--task generate`.
| `PaliGemmaForConditionalGeneration` | PaliGemma, PaliGemma 2 | T + I<sup>E</sup> | `google/paligemma-3b-pt-224`, `google/paligemma-3b-mix-224`, `google/paligemma2-3b-ft-docci-448`, etc. | | ✅︎ | ⚠️ |
| `Phi3VForCausalLM` | Phi-3-Vision, Phi-3.5-Vision | T + I<sup>E+</sup> | `microsoft/Phi-3-vision-128k-instruct`, `microsoft/Phi-3.5-vision-instruct`, etc. | | ✅︎ | ✅︎ |
| `Phi4MMForCausalLM` | Phi-4-multimodal | T + I<sup>+</sup> / T + A<sup>+</sup> / I<sup>+</sup> + A<sup>+</sup> | `microsoft/Phi-4-multimodal-instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `PixtralForConditionalGeneration` | Pixtral | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistral-community/pixtral-12b`, etc. | | ✅︎ | ✅︎ |
| `PixtralForConditionalGeneration` | Mistral 3 (Mistral format), Pixtral (Mistral format) | T + I<sup>+</sup> | `mistralai/Mistral-Small-3.1-24B-Instruct-2503`, `mistralai/Pixtral-12B-2409`, etc. | | ✅︎ | ✅︎ |
| `QwenVLForConditionalGeneration`<sup>^</sup> | Qwen-VL | T + I<sup>E+</sup> | `Qwen/Qwen-VL`, `Qwen/Qwen-VL-Chat`, etc. | ✅︎ | ✅︎ | ✅︎ |
| `Qwen2AudioForConditionalGeneration` | Qwen2-Audio | T + A<sup>+</sup> | `Qwen/Qwen2-Audio-7B-Instruct` | | ✅︎ | ✅︎ |
| `Qwen2VLForConditionalGeneration` | QVQ, Qwen2-VL | T + I<sup>E+</sup> + V<sup>E+</sup> | `Qwen/QVQ-72B-Preview`, `Qwen/Qwen2-VL-7B-Instruct`, `Qwen/Qwen2-VL-72B-Instruct`, etc. | ✅︎ | ✅︎ | ✅︎ |

docs/serving/data_parallel_deployment.md (new file, 120 lines)

@@ -0,0 +1,120 @@
# Data Parallel Deployment

vLLM supports Data Parallel deployment, where model weights are replicated across separate instances/GPUs to process independent batches of requests.

This will work with both dense and MoE models.

For MoE models, particularly those like DeepSeek that employ MLA (Multi-head Latent Attention), it can be advantageous to use data parallel for the attention layers and expert or tensor parallel (EP or TP) for the expert layers.

In these cases, the data parallel ranks are not completely independent. Forward passes must be aligned, and expert layers across all ranks are required to synchronize during every forward pass, even when there are fewer requests to be processed than DP ranks.

The expert layers will by default form a (DP x TP) sized tensor parallel group. To enable expert parallelism, include the `--enable-expert-parallel` CLI arg (on all nodes in the multi-node case).
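As a concrete sketch of the two flags just described (the model name is a placeholder for an MoE checkpoint):

```bash
# Sketch: 4-way data parallel attention with expert parallelism for the MoE layers.
vllm serve $MODEL --data-parallel-size 4 --enable-expert-parallel
```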
In vLLM, each DP rank is deployed as a separate "core engine" process that communicates with front-end process(es) via ZMQ sockets. Data Parallel attention can be combined with Tensor Parallel attention, in which case each DP engine owns a number of per-GPU worker processes equal to the configured TP size.

For MoE models, when any requests are in progress in any rank, we must ensure that empty "dummy" forward passes are performed in all ranks that don't currently have any requests scheduled. This is handled via a separate DP Coordinator process that communicates with all ranks, and a collective operation performed every N steps to determine when all ranks become idle and can be paused. When TP is used in conjunction with DP, expert layers form an EP or TP group of size (DP x TP).

In all cases, it is beneficial to load-balance requests between DP ranks. For online deployments, this balancing can be optimized by taking into account the state of each DP engine - in particular its currently scheduled and waiting (queued) requests, and KV cache state. Each DP engine has an independent KV cache, and the benefit of prefix caching can be maximized by directing prompts intelligently.

This document focuses on online deployments (with the API server). DP + EP is also supported for offline usage (via the LLM class), for an example see <gh-file:examples/offline_inference/data_parallel.py>.

There are two distinct modes supported for online deployments - self-contained with internal load balancing, or externally per-rank process deployment and load balancing.

## Internal Load Balancing

vLLM supports "self-contained" data parallel deployments that expose a single API endpoint.

It can be configured by simply including e.g. `--data-parallel-size=4` in the vllm serve command line arguments. This will require 4 GPUs. It can be combined with tensor parallel, for example `--data-parallel-size=4 --tensor-parallel-size=2`, which would require 8 GPUs.

Running a single data parallel deployment across multiple nodes requires a different `vllm serve` to be run on each node, specifying which DP ranks should run on that node. In this case, there will still be a single HTTP entrypoint - the API server(s) will run only on one node, but it doesn't necessarily need to be co-located with the DP ranks.

This will run DP=4, TP=2 on a single 8-GPU node:

```bash
vllm serve $MODEL --data-parallel-size 4 --tensor-parallel-size 2
```

This will run DP=4 with DP ranks 0 and 1 on the head node and ranks 2 and 3 on the second node:

```bash
# Node 0 (with ip address 10.99.48.128)
vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 2 \
                  --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
# Node 1
vllm serve $MODEL --headless --data-parallel-size 4 --data-parallel-size-local 2 \
                  --data-parallel-start-rank 2 \
                  --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
```

This will run DP=4 with only the API server on the first node and all engines on the second node:

```bash
# Node 0 (with ip address 10.99.48.128)
vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 0 \
                  --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
# Node 1
vllm serve $MODEL --headless --data-parallel-size 4 --data-parallel-size-local 4 \
                  --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
```

This DP mode can also be used with Ray by specifying `--data-parallel-backend=ray`:

```bash
vllm serve $MODEL --data-parallel-size 4 --data-parallel-size-local 2 \
                  --data-parallel-backend=ray
```

There are several notable differences when using Ray:

- A single launch command (on any node) is needed to start all local and remote DP ranks, therefore it is more convenient compared to launching on each node
- There is no need to specify `--data-parallel-address`, and the node where the command is run is used as `--data-parallel-address`
- There is no need to specify `--data-parallel-rpc-port`
- Remote DP ranks will be allocated based on node resources of the Ray cluster

Currently, the internal DP load balancing is done within the API server process(es) and is based on the running and waiting queues in each of the engines. This could be made more sophisticated in future by incorporating KV cache aware logic.

When deploying large DP sizes using this method, the API server process can become a bottleneck. In this case, the orthogonal `--api-server-count` command line option can be used to scale this out (for example `--api-server-count=4`). This is transparent to users - a single HTTP endpoint / port is still exposed. Note that this API server scale-out is "internal" and still confined to the "head" node.

<figure markdown="1">
  
</figure>
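For illustration, a minimal sketch of the API-server scale-out described above (the DP size and count are arbitrary values):

```bash
# Sketch: DP=8 with 4 API server processes on the head node, still one HTTP endpoint.
vllm serve $MODEL --data-parallel-size 8 --api-server-count 4
```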
## External Load Balancing

For larger scale deployments especially, it can make sense to handle the orchestration and load balancing of data parallel ranks externally.

In this case, it's more convenient to treat each DP rank like a separate vLLM deployment, with its own endpoint, and have an external router balance HTTP requests between them, making use of appropriate real-time telemetry from each server for routing decisions.

This can already be done trivially for non-MoE models, since each deployed server is fully independent. No data parallel CLI options need to be used for this.

We support an equivalent topology for MoE DP+EP which can be configured via the following CLI arguments.

If DP ranks are co-located (same node / ip address), a default RPC port is used, but a different HTTP server port must be specified for each rank:

```bash
# Rank 0
CUDA_VISIBLE_DEVICES=0 vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 0 \
                                         --port 8000
# Rank 1
CUDA_VISIBLE_DEVICES=1 vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 1 \
                                         --port 8001
```

For multi-node cases, the address/port of rank 0 must also be specified:

```bash
# Rank 0 (with ip address 10.99.48.128)
vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 0 \
                  --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
# Rank 1
vllm serve $MODEL --data-parallel-size 2 --data-parallel-rank 1 \
                  --data-parallel-address 10.99.48.128 --data-parallel-rpc-port 13345
```

The coordinator process also runs in this scenario, co-located with the DP rank 0 engine.

<figure markdown="1">
  
</figure>

In the above diagram, each of the dotted boxes corresponds to a separate launch of `vllm serve` - these could be separate Kubernetes pods, for example.
@@ -15,6 +15,10 @@ After adding enough GPUs and nodes to hold the model, you can run vLLM first, wh
!!! note
    There is one edge case: if the model fits in a single node with multiple GPUs, but the number of GPUs cannot divide the model size evenly, you can use pipeline parallelism, which splits the model along layers and supports uneven splits. In this case, the tensor parallel size should be 1 and the pipeline parallel size should be the number of GPUs.
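A hedged sketch of that edge case (the model and GPU count are illustrative only):

```bash
# Sketch: a model that fits on one 6-GPU node but cannot be split evenly across GPUs;
# split by layers (pipeline parallel) instead of by tensors.
vllm serve $MODEL --tensor-parallel-size 1 --pipeline-parallel-size 6
```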
### Distributed serving of MoE (Mixture of Experts) models

It is often advantageous to exploit the inherent parallelism of experts by using a separate parallelism strategy for the expert layers. vLLM supports large-scale deployment combining Data Parallel attention with Expert or Tensor Parallel MoE layers. See the page on [Data Parallel Deployment](data_parallel_deployment.md) for more information.

## Running vLLM on a single node

vLLM supports distributed tensor-parallel and pipeline-parallel inference and serving. Currently, we support [Megatron-LM's tensor parallel algorithm](https://arxiv.org/pdf/1909.08053.pdf). We manage the distributed runtime with either [Ray](https://github.com/ray-project/ray) or python native multiprocessing. Multiprocessing can be used when deploying on a single node, multi-node inference currently requires Ray.
@@ -106,14 +106,13 @@ to enable simultaneous generation and embedding using the same engine instance i

Models using selective state-space mechanisms instead of standard transformer attention are partially supported.
Models that use Mamba-2 layers (e.g., `Mamba2ForCausalLM`) are supported, but models that use older Mamba-1 layers
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet suported. Please note that these models currently require
(e.g., `MambaForCausalLM`, `JambaForCausalLM`) are not yet supported. Please note that these models currently require
enforcing eager mode and disabling prefix caching in V1.

Models that combine Mamba-2 layers with standard attention layers are also supported (e.g., `BambaForCausalLM`,
`Zamba2ForCausalLM`, `NemotronHForCausalLM`, `FalconH1ForCausalLM` and `GraniteMoeHybridForCausalLM`). Please note that
these models currently require enforcing eager mode, disabling prefix caching, and using the FlashInfer attention
backend in V1. It is also necessary to pass a non-standard block size for attention layers (this is not possible
using the `vllm serve` CLI yet).
backend in V1.
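An illustrative sketch of the constraints described above, using one of the listed hybrid models; the exact flag spellings are an assumption about the current V1 CLI rather than something stated in this diff:

```bash
# Sketch: serve a hybrid Mamba-2 + attention model with eager mode enforced,
# prefix caching disabled, and the FlashInfer attention backend selected.
VLLM_ATTENTION_BACKEND=FLASHINFER \
  vllm serve ibm-ai-platform/Bamba-9B --enforce-eager --no-enable-prefix-caching
```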
#### Encoder-Decoder Models

@@ -10,7 +10,7 @@ on HuggingFace model repository.

import os
from dataclasses import asdict
from typing import NamedTuple, Optional
from typing import Any, NamedTuple, Optional

from huggingface_hub import snapshot_download
from transformers import AutoTokenizer
@@ -30,7 +30,9 @@ question_per_audio_count = {

class ModelRequestData(NamedTuple):
    engine_args: EngineArgs
    prompt: str
    prompt: Optional[str] = None
    prompt_token_ids: Optional[dict[str, list[int]]] = None
    multi_modal_data: Optional[dict[str, Any]] = None
    stop_token_ids: Optional[list[int]] = None
    lora_requests: Optional[list[LoRARequest]] = None

@@ -40,6 +42,60 @@ class ModelRequestData(NamedTuple):
# Unless specified, these settings have been tested to work on a single L4.


# Voxtral
def run_voxtral(question: str, audio_count: int) -> ModelRequestData:
    from mistral_common.audio import Audio
    from mistral_common.protocol.instruct.messages import (
        AudioChunk,
        RawAudio,
        TextChunk,
        UserMessage,
    )
    from mistral_common.protocol.instruct.request import ChatCompletionRequest
    from mistral_common.tokens.tokenizers.mistral import MistralTokenizer

    model_name = "mistralai/Voxtral-Mini-3B-2507"
    tokenizer = MistralTokenizer.from_hf_hub(model_name)

    engine_args = EngineArgs(
        model=model_name,
        max_model_len=8192,
        max_num_seqs=2,
        limit_mm_per_prompt={"audio": audio_count},
        config_format="mistral",
        load_format="mistral",
        tokenizer_mode="mistral",
        enforce_eager=True,
        enable_chunked_prefill=False,
    )

    text_chunk = TextChunk(text=question)
    audios = [
        Audio.from_file(str(audio_assets[i].get_local_path()), strict=False)
        for i in range(audio_count)
    ]
    audio_chunks = [
        AudioChunk(input_audio=RawAudio.from_audio(audio)) for audio in audios
    ]

    messages = [UserMessage(content=[*audio_chunks, text_chunk])]

    req = ChatCompletionRequest(messages=messages, model=model_name)

    tokens = tokenizer.encode_chat_completion(req)
    prompt_ids, audios = tokens.tokens, tokens.audios

    audios_and_sr = [(au.audio_array, au.sampling_rate) for au in audios]

    multi_modal_data = {"audio": audios_and_sr}

    return ModelRequestData(
        engine_args=engine_args,
        prompt_token_ids=prompt_ids,
        multi_modal_data=multi_modal_data,
    )


# Granite Speech
def run_granite_speech(question: str, audio_count: int) -> ModelRequestData:
    # NOTE - the setting in this example are somehat different than what is
@@ -243,6 +299,7 @@ def run_whisper(question: str, audio_count: int) -> ModelRequestData:


model_example_map = {
    "voxtral": run_voxtral,
    "granite_speech": run_granite_speech,
    "minicpmo": run_minicpmo,
    "phi4_mm": run_phi4mm,
@@ -311,16 +368,24 @@ def main(args):
        temperature=0.2, max_tokens=64, stop_token_ids=req_data.stop_token_ids
    )

    mm_data = {}
    if audio_count > 0:
        mm_data = {
            "audio": [
                asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
            ]
        }
    mm_data = req_data.multi_modal_data
    if not mm_data:
        mm_data = {}
        if audio_count > 0:
            mm_data = {
                "audio": [
                    asset.audio_and_sample_rate for asset in audio_assets[:audio_count]
                ]
            }

    assert args.num_prompts > 0
    inputs = {"prompt": req_data.prompt, "multi_modal_data": mm_data}
    inputs = {"multi_modal_data": mm_data}

    if req_data.prompt:
        inputs["prompt"] = req_data.prompt
    else:
        inputs["prompt_token_ids"] = req_data.prompt_token_ids

    if args.num_prompts > 1:
        # Batch inference
        inputs = [inputs] * args.num_prompts

@@ -1,17 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
a simple demonstration of RLHF with vLLM, inspired by
the OpenRLHF framework https://github.com/OpenRLHF/OpenRLHF .
It follows the design that, training processes and inference processes
are different, and they live on different GPUs.
Training processes send prompts to inference processes to generate data,
and also synchronize the weights of the model by broadcasting the weights
from the training process to the inference process.
Note that this is a simple demonstration of one training instance and one
inference instance. In practice, there could be multiple training instances
and multiple inference instances. For the full implementation, please refer
to the OpenRLHF framework.
Demonstrates reinforcement learning from human feedback (RLHF) using vLLM and Ray.

The script separates training and inference workloads onto distinct GPUs
so that Ray can manage process placement and inter-process communication.
A Hugging Face Transformer model occupies GPU 0 for training, whereas a
tensor-parallel vLLM inference engine occupies GPU 1–2.

The example performs the following steps:

* Load the training model on GPU 0.
* Split the inference model across GPUs 1–2 using vLLM's tensor parallelism
  and Ray placement groups.
* Generate text from a list of prompts using the inference engine.
* Update the weights of the training model and broadcast the updated weights
  to the inference engine by using a Ray collective RPC group. Note that
  for demonstration purposes we simply zero out the weights.

For a production-ready implementation that supports multiple training and
inference replicas, see the OpenRLHF framework:
https://github.com/OpenRLHF/OpenRLHF

This example assumes a single-node cluster with three GPUs, but Ray
supports multi-node clusters. vLLM expects the GPUs are only used for vLLM
workloads. Residual GPU activity interferes with vLLM memory profiling and
causes unexpected behavior.
"""

import os
@@ -28,29 +42,27 @@ from vllm.utils import get_ip, get_open_port


class MyLLM(LLM):
    """Configure the vLLM worker for Ray placement group execution."""

    def __init__(self, *args, **kwargs):
        # a hack to make the script work.
        # stop ray from manipulating CUDA_VISIBLE_DEVICES
        # at the top-level
        # Remove the top-level CUDA_VISIBLE_DEVICES variable set by Ray
        # so that vLLM can manage its own device placement within the worker.
        os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        super().__init__(*args, **kwargs)


"""
Start the training process, here we use huggingface transformers
as an example to hold a model on GPU 0.
"""

# Load the OPT-125M model onto GPU 0 for the training workload.
train_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
train_model.to("cuda:0")
"""
Start the inference process, here we use vLLM to hold a model on GPU 1 and
GPU 2. For the details on how to use ray, please refer to the ray
documentation https://docs.ray.io/en/latest/ .
"""

# Initialize Ray and set the visible devices. The vLLM engine will
# be placed on GPUs 1 and 2.
os.environ["CUDA_VISIBLE_DEVICES"] = "1,2"
ray.init()

# Create a placement group that reserves GPU 1–2 for the vLLM inference engine.
# Learn more about Ray placement groups:
# https://docs.ray.io/en/latest/placement-groups.html
pg_inference = placement_group([{"GPU": 1, "CPU": 0}] * 2)
ray.get(pg_inference.ready())
scheduling_inference = PlacementGroupSchedulingStrategy(
@@ -58,10 +70,9 @@ scheduling_inference = PlacementGroupSchedulingStrategy(
    placement_group_capture_child_tasks=True,
    placement_group_bundle_index=0,
)
"""
launch the vLLM inference engine.
here we use `enforce_eager` to reduce the start time.
"""

# Launch the vLLM inference engine. The `enforce_eager` flag reduces
# start-up latency.
llm = ray.remote(
    num_cpus=0,
    num_gpus=0,
@@ -74,7 +85,7 @@ llm = ray.remote(
    distributed_executor_backend="ray",
)

# Generate texts from the prompts.
# Generate text from the prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
@@ -93,8 +104,8 @@ for output in outputs:
    print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

# set up the communication between the training process
# and the inference engine.
# Set up the communication channel between the training process and the
# inference engine.
master_address = get_ip()
master_port = get_open_port()

@@ -107,21 +118,23 @@ model_update_group = stateless_init_process_group(
)
ray.get(handle)

# simulate training, modify the weights of the model.
# Simulate a training step by zeroing out all model weights.
# In a real RLHF training loop the weights would be updated using the gradient
# from an RL objective such as PPO on a reward model.
for name, p in train_model.named_parameters():
    p.data.zero_()

# sync weight from the training process to the inference engine.
# Synchronize the updated weights to the inference engine.
for name, p in train_model.named_parameters():
    handle = llm.collective_rpc.remote("update_weight", args=(name, p.dtype, p.shape))
    model_update_group.broadcast(p, src=0, stream=torch.cuda.current_stream())
    ray.get(handle)

# check if the weights are updated.
# Verify that the inference weights have been updated.
assert all(ray.get(llm.collective_rpc.remote("check_weights_changed")))

# use the updated model to generate texts, they will be nonsense
# because the weights are all zeros.
# Generate text with the updated model. The output is expected to be nonsense
# because the weights are zero.
outputs_updated = ray.get(llm.generate.remote(prompts, sampling_params))
print("-" * 50)
for output in outputs_updated:

@@ -84,6 +84,7 @@ def main():
        gpu_memory_utilization=0.8,
        speculative_config=speculative_config,
        disable_log_stats=False,
        max_model_len=16384,
    )

    sampling_params = SamplingParams(temperature=args.temp, max_tokens=args.output_len)

@@ -1,35 +1,81 @@
#!/bin/bash
#
# Launch a Ray cluster inside Docker for vLLM inference.
#
# This script can start either a head node or a worker node, depending on the
# --head or --worker flag provided as the third positional argument.
#
# Usage:
# 1. Designate one machine as the head node and execute:
#    bash run_cluster.sh \
#         vllm/vllm-openai \
#         <head_node_ip> \
#         --head \
#         /abs/path/to/huggingface/cache \
#         -e VLLM_HOST_IP=<head_node_ip>
#
# 2. On every worker machine, execute:
#    bash run_cluster.sh \
#         vllm/vllm-openai \
#         <head_node_ip> \
#         --worker \
#         /abs/path/to/huggingface/cache \
#         -e VLLM_HOST_IP=<worker_node_ip>
#
# Each worker requires a unique VLLM_HOST_IP value.
# Keep each terminal session open. Closing a session stops the associated Ray
# node and thereby shuts down the entire cluster.
# Every machine must be reachable at the supplied IP address.
#
# The container is named "node-<random_suffix>". To open a shell inside
# a container after launch, use:
#    docker exec -it node-<random_suffix> /bin/bash
#
# Then, you can execute vLLM commands on the Ray cluster as if it were a
# single machine, e.g. vllm serve ...
#
# To stop the container, use:
#    docker stop node-<random_suffix>

# Check for minimum number of required arguments
# Check for minimum number of required arguments.
if [ $# -lt 4 ]; then
    echo "Usage: $0 docker_image head_node_address --head|--worker path_to_hf_home [additional_args...]"
    echo "Usage: $0 docker_image head_node_ip --head|--worker path_to_hf_home [additional_args...]"
    exit 1
fi

# Assign the first three arguments and shift them away
# Extract the mandatory positional arguments and remove them from $@.
DOCKER_IMAGE="$1"
HEAD_NODE_ADDRESS="$2"
NODE_TYPE="$3"  # Should be --head or --worker
NODE_TYPE="$3"  # Should be --head or --worker.
PATH_TO_HF_HOME="$4"
shift 4

# Additional arguments are passed directly to the Docker command
# Preserve any extra arguments so they can be forwarded to Docker.
ADDITIONAL_ARGS=("$@")

# Validate node type
# Validate the NODE_TYPE argument.
if [ "${NODE_TYPE}" != "--head" ] && [ "${NODE_TYPE}" != "--worker" ]; then
    echo "Error: Node type must be --head or --worker"
    exit 1
fi

# Define a function to cleanup on EXIT signal
# Generate a unique container name with random suffix.
# Docker container names must be unique on each host.
# The random suffix allows multiple Ray containers to run simultaneously on the same machine,
# for example, on a multi-GPU machine.
CONTAINER_NAME="node-${RANDOM}"

# Define a cleanup routine that removes the container when the script exits.
# This prevents orphaned containers from accumulating if the script is interrupted.
cleanup() {
    docker stop node
    docker rm node
    docker stop "${CONTAINER_NAME}"
    docker rm "${CONTAINER_NAME}"
}
trap cleanup EXIT

# Command setup for head or worker node
# Build the Ray start command based on the node role.
# The head node manages the cluster and accepts connections on port 6379,
# while workers connect to the head's address.
RAY_START_CMD="ray start --block"
if [ "${NODE_TYPE}" == "--head" ]; then
    RAY_START_CMD+=" --head --port=6379"
@@ -37,11 +83,15 @@ else
    RAY_START_CMD+=" --address=${HEAD_NODE_ADDRESS}:6379"
fi

# Run the docker command with the user specified parameters and additional arguments
# Launch the container with the assembled parameters.
# --network host: Allows Ray nodes to communicate directly via host networking
# --shm-size 10.24g: Increases shared memory
# --gpus all: Gives container access to all GPUs on the host
# -v HF_HOME: Mounts HuggingFace cache to avoid re-downloading models
docker run \
    --entrypoint /bin/bash \
    --network host \
    --name node \
    --name "${CONTAINER_NAME}" \
    --shm-size 10.24g \
    --gpus all \
    -v "${PATH_TO_HF_HOME}:/root/.cache/huggingface" \
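As a hedged follow-up sketch of the workflow the script header describes (container name, model, and parallel sizes are illustrative), once the head and worker containers are running, the cluster is used from the head container:

```bash
# Sketch: open a shell in the head container and serve a model across the Ray cluster.
docker exec -it node-12345 /bin/bash
vllm serve meta-llama/Meta-Llama-3-70B-Instruct \
    --tensor-parallel-size 8 --pipeline-parallel-size 2
```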
@@ -61,6 +61,7 @@ plugins:
  - search
  - autorefs
  - awesome-nav
  - glightbox
  # For API reference generation
  - api-autonav:
      modules: ["vllm"]
pyproject.toml (183 lines changed)

@@ -174,3 +174,186 @@ respect-ignore-files = true

[tool.ty.environment]
python = "./.venv"

[tool.typos.files]
# these files may be written in non english words
extend-exclude = ["tests/models/fixtures/*", "tests/prompts/*",
    "benchmarks/sonnet.txt", "tests/lora/data/*", "build/*",
    "vllm/third_party/*"]
ignore-hidden = true
ignore-files = true
ignore-dot = true
ignore-vcs = true
ignore-global = true
ignore-parent = true

[tool.typos.default]
binary = false
check-filename = false
check-file = true
unicode = true
ignore-hex = true
identifier-leading-digits = false
locale = "en"
extend-ignore-identifiers-re = ["NVML_*", ".*Unc.*", ".*_thw",
    ".*UE8M0.*", ".*[UE4M3|ue4m3].*", ".*eles.*",
    ".*[Tt]h[rR].*"]
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.default.extend-identifiers]
bbc5b7ede = "bbc5b7ede"
womens_doubles = "womens_doubles"
v_2nd = "v_2nd"
# splitted_input = "splitted_input"
NOOPs = "NOOPs"
typ = "typ"
nin_shortcut = "nin_shortcut"
UperNetDecoder = "UperNetDecoder"
subtile = "subtile"
cudaDevAttrMaxSharedMemoryPerBlockOptin = "cudaDevAttrMaxSharedMemoryPerBlockOptin"
SFOuput = "SFOuput"
# huggingface transformers repo uses these words
depthwise_seperable_out_channel = "depthwise_seperable_out_channel"
DepthWiseSeperableConv1d = "DepthWiseSeperableConv1d"
depthwise_seperable_CNN = "depthwise_seperable_CNN"

[tool.typos.default.extend-words]
iy = "iy"
tendencias = "tendencias"
# intel cpu features
tme = "tme"
dout = "dout"
Pn = "Pn"
arange = "arange"

[tool.typos.type.py]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.py.extend-identifiers]
arange = "arange"
NDArray = "NDArray"
EOFError = "EOFError"
fo = "fo"
ba = "ba"

[tool.typos.type.py.extend-words]

[tool.typos.type.cpp]
extend-glob = ["*.cu"]
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.cpp.extend-identifiers]
countr_one = "countr_one"
k_ot = "k_ot"
ot = "ot"

[tool.typos.type.cpp.extend-words]

[tool.typos.type.rust]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.rust.extend-identifiers]
flate2 = "flate2"

[tool.typos.type.rust.extend-words]
ser = "ser"

[tool.typos.type.lock]
extend-glob = []
check-file = false
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.lock.extend-identifiers]

[tool.typos.type.lock.extend-words]

[tool.typos.type.jl]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.jl.extend-identifiers]

[tool.typos.type.jl.extend-words]
modul = "modul"
egals = "egals"
usig = "usig"
egal = "egal"

[tool.typos.type.go]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.go.extend-identifiers]
flate = "flate"

[tool.typos.type.go.extend-words]

[tool.typos.type.css]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.css.extend-identifiers]
nd = "nd"

[tool.typos.type.css.extend-words]

[tool.typos.type.man]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.man.extend-identifiers]
Nd = "Nd"

[tool.typos.type.man.extend-words]

[tool.typos.type.cert]
extend-glob = []
check-file = false
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.cert.extend-identifiers]

[tool.typos.type.cert.extend-words]

[tool.typos.type.sh]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.sh.extend-identifiers]
ot = "ot"

[tool.typos.type.sh.extend-words]

[tool.typos.type.vimscript]
extend-glob = []
extend-ignore-identifiers-re = []
extend-ignore-words-re = []
extend-ignore-re = []

[tool.typos.type.vimscript.extend-identifiers]
windo = "windo"

[tool.typos.type.vimscript.extend-words]
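A hedged usage note on the configuration above: assuming the `typos` CLI is installed and that it picks up the `[tool.typos]` tables from `pyproject.toml`, the checker can be run locally from the repo root:

```bash
# Sketch: run the spell checker with the pyproject.toml configuration above.
typos
```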
@@ -7,7 +7,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.51.1
transformers >= 4.53.2
huggingface-hub[hf_xet] >= 0.33.0  # Required for Xet downloads.
tokenizers >= 0.21.1  # Required for fast incremental detokenization.
protobuf  # Required by LlamaTokenizer.
@@ -25,7 +25,7 @@ outlines_core == 0.2.10
# required for outlines backend disk cache
diskcache == 5.6.3
lark == 1.2.2
xgrammar == 0.1.19; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
xgrammar == 0.1.21; platform_machine == "x86_64" or platform_machine == "aarch64" or platform_machine == "arm64"
typing_extensions >= 4.10
filelock >= 3.16.1  # need to contain https://github.com/tox-dev/filelock/pull/317
partial-json-parser  # used for parsing partial JSON outputs
@@ -33,17 +33,18 @@ pyzmq >= 25.0.0
msgspec
gguf >= 0.13.0
importlib_metadata; python_version < '3.10'
mistral_common[opencv] >= 1.6.2
mistral_common[opencv] >= 1.8.0
opencv-python-headless >= 4.11.0  # required for video IO
pyyaml
six>=1.16.0; python_version > '3.11'  # transitive dependency of pandas that needs to be the latest version for python 3.12
setuptools>=77.0.3,<80; python_version > '3.11'  # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12
einops  # Required for Qwen2-VL.
compressed-tensors == 0.10.2  # required for compressed-tensors
depyf==0.18.0  # required for profiling and debugging with compilation config
depyf==0.19.0  # required for profiling and debugging with compilation config
cloudpickle  # allows pickling lambda functions in model_executor/models/registry.py
watchfiles  # required for http server to monitor the updates of TLS files
python-json-logger  # Used by logging as per examples/others/logging_configuration.md
scipy  # Required for phi-4-multimodal-instruct
ninja  # Required for xgrammar, rocm, tpu, xpu
pybase64  # fast base64 implementation
cbor2  # Required for cross-language serialization of hashable objects

@@ -4,6 +4,7 @@ mkdocs-material
mkdocstrings-python
mkdocs-gen-files
mkdocs-awesome-nav
mkdocs-glightbox
python-markdown-math
regex
ruff
@@ -11,10 +12,12 @@ ruff
# Required for argparse hook only
-f https://download.pytorch.org/whl/cpu
cachetools
cbor2
cloudpickle
fastapi
msgspec
openai
partial-json-parser
pillow
psutil
pybase64

@@ -23,7 +23,7 @@ jiwer # required for audio tests
timm # required for internvl test
transformers_stream_generator # required for qwen-vl test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.6.2 # required for pixtral test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test

@@ -28,13 +28,13 @@ torchvision==0.22.0
transformers_stream_generator # required for qwen-vl test
mamba_ssm # required for plamo2 test
matplotlib # required for qwen-vl test
mistral_common[opencv] >= 1.6.2 # required for pixtral test
mistral_common[opencv] >= 1.8.0 # required for voxtral test
num2words # required for smolvlm test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
mteb[bm25s]>=1.38.11, <2 # required for mteb test
transformers==4.52.4
transformers==4.53.2
tokenizers==0.21.1
huggingface-hub[hf_xet]>=0.33.0 # Required for Xet downloads.
schemathesis>=3.39.15 # Required for openai schema test.

@@ -305,7 +305,7 @@ mbstrdecoder==1.1.3
    # typepy
mdurl==0.1.2
    # via markdown-it-py
mistral-common==1.6.2
mistral-common==1.8.0
    # via -r requirements/test.in
more-itertools==10.5.0
    # via lm-eval
@@ -518,6 +518,8 @@ pyasn1-modules==0.4.2
    # via google-auth
pybind11==2.13.6
    # via lm-eval
pycountry==24.6.1
    # via pydantic-extra-types
pycparser==2.22
    # via cffi
pycryptodomex==3.22.0
@@ -528,9 +530,12 @@ pydantic==2.11.5
    # datamodel-code-generator
    # mistral-common
    # mteb
    # pydantic-extra-types
    # ray
pydantic-core==2.33.2
    # via pydantic
pydantic-extra-types==2.10.5
    # via mistral-common
pygments==2.18.0
    # via rich
pyparsing==3.2.0
@@ -800,7 +805,7 @@ tqdm==4.66.6
    # transformers
tqdm-multiprocess==0.0.11
    # via lm-eval
transformers==4.52.4
transformers==4.53.2
    # via
    # -r requirements/test.in
    # genai-perf
@@ -835,6 +840,7 @@ typing-extensions==4.12.2
    # pqdm
    # pydantic
    # pydantic-core
    # pydantic-extra-types
    # torch
    # typer
    # typing-inspection

@@ -18,9 +18,9 @@ setuptools==78.1.0
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.9.0.dev20250703
torchvision==0.24.0.dev20250703
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250703-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch==2.9.0.dev20250711
torchvision==0.24.0.dev20250711
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.9.0.dev20250711-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"

setup.py (3 lines changed)

@@ -692,7 +692,8 @@ setup(
        "tensorizer": ["tensorizer==2.10.1"],
        "fastsafetensors": ["fastsafetensors >= 0.1.10"],
        "runai": ["runai-model-streamer", "runai-model-streamer-s3", "boto3"],
        "audio": ["librosa", "soundfile"],  # Required for audio processing
        "audio": ["librosa", "soundfile",
                  "mistral_common[audio]"],  # Required for audio processing
        "video": []  # Kept for backwards compatibility
    },
    cmdclass=cmdclass,
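A hedged install sketch showing how the updated `audio` extra above would be pulled in when installing from a source checkout (install mode is an assumption, not part of this diff):

```bash
# Sketch: editable install with the audio extra, which now also brings in
# mistral_common[audio] alongside librosa and soundfile.
pip install -e ".[audio]"
```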
@@ -29,7 +29,7 @@ def _query_server_long(prompt: str) -> dict:


@pytest.fixture
def api_server(tokenizer_pool_size: int, distributed_executor_backend: str):
def api_server(distributed_executor_backend: str):
    script_path = Path(__file__).parent.joinpath(
        "api_server_async_engine.py").absolute()
    commands = [
@@ -40,8 +40,6 @@ def api_server(tokenizer_pool_size: int, distributed_executor_backend: str):
        "facebook/opt-125m",
        "--host",
        "127.0.0.1",
        "--tokenizer-pool-size",
        str(tokenizer_pool_size),
        "--distributed-executor-backend",
        distributed_executor_backend,
    ]
@@ -54,10 +52,8 @@ def api_server(tokenizer_pool_size: int, distributed_executor_backend: str):
    uvicorn_process.terminate()


@pytest.mark.parametrize("tokenizer_pool_size", [0, 2])
@pytest.mark.parametrize("distributed_executor_backend", ["mp", "ray"])
def test_api_server(api_server, tokenizer_pool_size: int,
                    distributed_executor_backend: str):
def test_api_server(api_server, distributed_executor_backend: str):
    """
    Run the API server and test it.

@@ -26,6 +26,30 @@ def test_use_cudagraphs_dynamic(monkeypatch):
    assert not vllm_config.compilation_config.use_cudagraph


# NB: We don't test VLLM_DISABLE_COMPILE_CACHE=0 because that depends
# on the state of the cache directory on the current machine, which
# may be influenced by other tests.
@pytest.mark.parametrize("val", ["1"])
def test_VLLM_DISABLE_COMPILE_CACHE(vllm_runner, monkeypatch, val):
    assert vllm.envs.VLLM_USE_V1

    # spawn means that the counters are in the same process.
    monkeypatch.setenv('VLLM_WORKER_MULTIPROC_METHOD', "spawn")
    monkeypatch.setenv('VLLM_DISABLE_COMPILE_CACHE', val)

    compilation_config = {
        "use_cudagraph": False,  # speed things up a bit
    }
    with (
            compilation_counter.expect(num_cache_entries_updated=0,
                                       num_compiled_artifacts_saved=0),
            # loading the model causes compilation (if enabled) to happen
            vllm_runner('facebook/opt-125m',
                        compilation_config=compilation_config,
                        gpu_memory_utilization=0.4) as _):
        pass


@pytest.mark.parametrize("enabled", [True, False])
def test_use_cudagraphs(vllm_runner, monkeypatch, enabled):
    assert vllm.envs.VLLM_USE_V1

@@ -3,6 +3,7 @@

from __future__ import annotations

import tempfile
from typing import Any, Optional, Union

import pytest
@@ -111,6 +112,11 @@ def test_full_graph(
        pass_config=PassConfig(enable_fusion=True,
                               enable_noop=True)), model)
        for model in models_list(keywords=["FP8-dynamic", "quantized.w8a8"])
    ] + [
        # Test depyf integration works
        (CompilationConfig(level=CompilationLevel.PIECEWISE,
                           debug_dump_path=tempfile.gettempdir()),
         ("facebook/opt-125m", {})),
    ])
# only test some of the models
@create_new_process_for_each_test()

@@ -44,7 +44,9 @@ class TestModel(torch.nn.Module):
        ]
        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            use_per_token_if_dynamic=True)
            act_quant_static=static,
            act_quant_group_shape=group_shape,
        )

    def forward(self, x):
        resid = torch.sqrt(x)
@@ -91,9 +93,10 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps, static,
    maybe_create_device_identity()  # needed for certain non-cutlass fp8 paths

    vllm_config = VllmConfig(compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE, custom_ops=["+rms_norm"]))
    vllm_config.compilation_config.pass_config = \
        PassConfig(enable_fusion=True, enable_noop=True)
        level=CompilationLevel.PIECEWISE,
        custom_ops=["+rms_norm", "+quant_fp8"],
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    ))
    with vllm.config.set_current_vllm_config(vllm_config):
        # Reshape pass is needed for the fusion pass to work
        noop_pass = NoOpEliminationPass(vllm_config)

150
tests/compile/test_fusion_all_reduce.py
Normal file
150
tests/compile/test_fusion_all_reduce.py
Normal file
@ -0,0 +1,150 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from importlib.util import find_spec
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
import vllm.envs as envs
|
||||
from vllm.compilation.collective_fusion import AllReduceFusionPass
|
||||
from vllm.config import (CompilationConfig, CompilationLevel, DeviceConfig,
|
||||
ModelConfig, PassConfig, VllmConfig)
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
from vllm.distributed.parallel_state import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.utils import update_environment_variables
|
||||
|
||||
from ..utils import multi_gpu_test
|
||||
from .backend import TestBackend
|
||||
|
||||
|
||||
class TestAllReduceRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm = self.norm(all_reduce)
|
||||
return norm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
class TestAllReduceFusedAddRMSNormModel(torch.nn.Module):
|
||||
|
||||
def __init__(self, hidden_size=16, eps=1e-6):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.eps = eps
|
||||
self.norm = RMSNorm(hidden_size, eps)
|
||||
|
||||
def forward(self, hidden_states, residual):
|
||||
view = hidden_states.reshape(-1, self.hidden_size)
|
||||
all_reduce = tensor_model_parallel_all_reduce(view)
|
||||
norm, _ = self.norm(all_reduce, residual)
|
||||
return norm
|
||||
|
||||
def ops_in_model_before(self):
|
||||
return [torch.ops.vllm.all_reduce.default]
|
||||
|
||||
def ops_in_model_after(self):
|
||||
return [torch.ops.vllm.flashinfer_trtllm_fused_allreduce_norm.default]
|
||||
|
||||
|
||||
@multi_gpu_test(num_gpus=2)
|
||||
@pytest.mark.parametrize(
|
||||
"test_model",
|
||||
[TestAllReduceRMSNormModel, TestAllReduceFusedAddRMSNormModel])
|
||||
@pytest.mark.parametrize("batch_size", [8])
|
||||
@pytest.mark.parametrize("seq_len", [8])
|
||||
@pytest.mark.parametrize("hidden_size", [4096])
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda"],
|
||||
reason="Only test on CUDA")
|
||||
@pytest.mark.skipif(not find_spec("flashinfer"),
|
||||
reason="flashinfer is not installed")
|
||||
@pytest.mark.skipif(not current_platform.is_device_capability(100),
|
||||
reason="Only test on SM100")
|
||||
def test_all_reduce_fusion_pass_replace(test_model: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
num_processes = 2
|
||||
|
||||
def run_torch_spawn(fn, nprocs):
|
||||
torch.multiprocessing.spawn(fn,
|
||||
args=(num_processes, test_model,
|
||||
batch_size, seq_len, hidden_size,
|
||||
dtype),
|
||||
nprocs=nprocs)
|
||||
|
||||
run_torch_spawn(all_reduce_fusion_pass_on_test_model, num_processes)
|
||||
|
||||
|
||||
def all_reduce_fusion_pass_on_test_model(local_rank: int, world_size: int,
|
||||
test_model_cls: torch.nn.Module,
|
||||
batch_size: int, seq_len: int,
|
||||
hidden_size: int, dtype: torch.dtype):
|
||||
current_platform.seed_everything(0)
|
||||
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
torch.cuda.set_device(device)
|
||||
torch.set_default_device(device)
|
||||
torch.set_default_dtype(dtype)
|
||||
|
||||
update_environment_variables({
|
||||
'RANK': str(local_rank),
|
||||
'LOCAL_RANK': str(local_rank),
|
||||
'WORLD_SIZE': str(world_size),
|
||||
'MASTER_ADDR': 'localhost',
|
||||
'MASTER_PORT': '12345',
|
||||
})
|
||||
|
||||
init_distributed_environment()
|
||||
initialize_model_parallel(tensor_model_parallel_size=world_size)
|
||||
|
||||
vllm_config = VllmConfig(
|
||||
compilation_config=CompilationConfig(level=CompilationLevel.PIECEWISE,
|
||||
custom_ops=["+rms_norm"],
|
||||
compile_sizes=[2, 4, 8]))
|
||||
vllm_config.compilation_config.pass_config = PassConfig(
|
||||
enable_fi_allreduce_fusion=True)
|
||||
vllm_config.device_config = DeviceConfig(device=torch.device("cuda"))
|
||||
|
||||
# this is a fake model name to construct the model config
|
||||
# in the vllm_config, it's not really used.
|
||||
model_name = "nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"
|
||||
vllm_config.model_config = ModelConfig(model=model_name,
|
||||
task="auto",
|
||||
tokenizer=model_name,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
seed=42)
|
||||
|
||||
all_reduce_fusion_pass = AllReduceFusionPass(vllm_config)
|
||||
backend = TestBackend(all_reduce_fusion_pass)
|
||||
|
||||
model = test_model_cls(hidden_size)
|
||||
|
||||
hidden_states = torch.randn((batch_size * seq_len, hidden_size),
|
||||
requires_grad=False)
|
||||
residual = torch.randn((batch_size * seq_len, hidden_size),
|
||||
requires_grad=False)
|
||||
|
||||
compiled_model = torch.compile(model, backend=backend)
|
||||
compiled_model(hidden_states, residual)
|
||||
|
||||
backend.check_before_ops(model.ops_in_model_before(), fully_replaced=False)
|
||||
backend.check_after_ops(model.ops_in_model_after())
|
||||
del all_reduce_fusion_pass
|
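Aside (not part of the diff): the TestBackend checks above amount to scanning the
compiled FX graph for specific call targets before and after the fusion pass runs.
A single-process illustration of that idea with plain torch.fx, using a stand-in
norm module (no vLLM or multi-GPU setup required):

# Hedged sketch only; TinyNorm is a stand-in, not vLLM's RMSNorm.
import torch
import torch.fx as fx


class TinyNorm(torch.nn.Module):
    # Stand-in normalization: x / sqrt(mean(x^2) + eps)
    def forward(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)


def count_calls(gm: fx.GraphModule, target) -> int:
    # Count call_function nodes whose target is the given op.
    return sum(1 for n in gm.graph.nodes
               if n.op == "call_function" and n.target is target)


gm = fx.symbolic_trace(TinyNorm())
assert count_calls(gm, torch.rsqrt) == 1  # op present before any rewrite pass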
@ -50,6 +50,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend_unfused",
        custom_ops=["+quant_fp8"],
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
    backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
@ -73,6 +74,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
        # DYNAMO_ONCE does not properly propagate shapes.
        level=CompilationLevel.DYNAMO_AS_IS,
        backend="tests.compile.test_fusion_attn.backend",
        custom_ops=["+quant_fp8"],
    )
    vllm_config = VllmConfig(compilation_config=compile_config)
@ -4,33 +4,56 @@ import pytest
import torch

import vllm.envs as envs
from vllm._custom_ops import scaled_fp8_quant
from vllm.compilation.activation_quant_fusion import ActivationQuantFusionPass
from vllm.compilation.fx_utils import find_auto_fn, find_auto_fn_maybe
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
    GroupShape)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
    CUTLASS_FP8_SUPPORTED, Fp8LinearOp)
from vllm.platforms import current_platform

from .backend import TestBackend


class TestModel(torch.nn.Module):

    def __init__(self, *args, **kwargs):
    def __init__(self, hidden_size: int, cutlass_fp8_enabled: bool, *args,
                 **kwargs):
        super().__init__(*args, **kwargs)
        self.silu_and_mul = SiluAndMul()
        self.wscale = torch.rand(1, dtype=torch.float32)
        self.scale = torch.rand(1, dtype=torch.float32)

        self.w = (torch.rand(
            hidden_size,
            hidden_size).to(dtype=current_platform.fp8_dtype()).t())

        self.fp8_linear = Fp8LinearOp(
            cutlass_fp8_supported=cutlass_fp8_enabled,
            act_quant_static=True,
            act_quant_group_shape=GroupShape.PER_TENSOR,
        )

    def forward(self, x):
        y = self.silu_and_mul(x)
        x2 = scaled_fp8_quant(y, self.scale)
        x2 = self.fp8_linear.apply(y,
                                   self.w,
                                   self.wscale,
                                   input_scale=self.wscale)
        return x2


@pytest.mark.parametrize("num_tokens", [256])
@pytest.mark.parametrize("hidden_size", [64])
@pytest.mark.parametrize("cutlass_fp8_enabled",
                         [True, False] if CUTLASS_FP8_SUPPORTED else [False])
@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"],
                    reason="Only test on CUDA and ROCm")
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
def test_fusion_silu_and_mul_quant(num_tokens, hidden_size,
                                   cutlass_fp8_enabled):
    torch.set_default_device("cuda")
    torch.set_default_dtype(torch.float16)

@ -40,11 +63,11 @@ def test_fusion_silu_and_mul_quant(num_tokens, hidden_size):
        pass_config=PassConfig(enable_fusion=True, enable_noop=True))
    fusion_pass = ActivationQuantFusionPass(config)

    backend = TestBackend(fusion_pass)
    model = TestModel()
    backend = TestBackend(NoOpEliminationPass(config), fusion_pass)
    model = TestModel(hidden_size, cutlass_fp8_enabled)

    # First dimension dynamic
    x = torch.rand(num_tokens, hidden_size)
    x = torch.rand(num_tokens, hidden_size * 2)
    torch._dynamo.mark_dynamic(x, 0)

    result = model(x)
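Aside (not part of the diff): the pattern being fused above is SiluAndMul followed by
a static FP8 quantization. A reference sketch of the unfused computation in plain
PyTorch, assuming per-tensor scaling as in GroupShape.PER_TENSOR:

# Hedged sketch only; mirrors the math, not vLLM's kernels.
import torch


def silu_and_mul_ref(x: torch.Tensor) -> torch.Tensor:
    # SiluAndMul halves the last dimension: silu(gate) * up
    d = x.shape[-1] // 2
    return torch.nn.functional.silu(x[..., :d]) * x[..., d:]


def fp8_quant_ref(y: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Saturating cast to float8_e4m3fn after dividing by the static scale.
    info = torch.finfo(torch.float8_e4m3fn)
    return (y / scale).clamp(info.min, info.max).to(torch.float8_e4m3fn)


x = torch.randn(4, 128)
out = fp8_quant_ref(silu_and_mul_ref(x), torch.tensor(0.5))
assert out.shape == (4, 64) and out.dtype == torch.float8_e4m3fn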
@ -804,7 +804,7 @@ class VllmRunner:

    def get_inputs(
        self,
        prompts: Union[list[str], list[torch.Tensor]],
        prompts: Union[list[str], list[torch.Tensor], list[int]],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
@ -826,11 +826,16 @@ class VllmRunner:
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio

            text_prompt_kwargs = {
                ("prompt" if isinstance(prompt, str) else "prompt_embeds"):
                prompt,
            text_prompt_kwargs: dict[str, Any] = {
                "multi_modal_data": multi_modal_data or None
            }
            if isinstance(prompt, str):
                text_prompt_kwargs["prompt"] = prompt
            elif isinstance(prompt, list):
                text_prompt_kwargs["prompt_token_ids"] = prompt
            else:
                text_prompt_kwargs["prompt_embeds"] = prompt

            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs
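Aside (not part of the diff): with the branch added above, a list of token IDs is
forwarded as prompt_token_ids. A hedged usage sketch, assuming the conventional
vllm_runner fixture from this test suite; the model and token IDs are placeholders:

# Hypothetical test, not taken from the diff.
def test_token_id_prompts(vllm_runner):
    with vllm_runner("facebook/opt-125m") as llm:
        token_id_prompts = [[1, 2, 3, 4, 5], [42, 43, 44]]
        outputs = llm.generate_greedy(token_id_prompts, max_tokens=8)
        assert len(outputs) == len(token_id_prompts)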
@ -14,8 +14,9 @@ from typing import Literal, NamedTuple, Optional

import pytest

from vllm.config import TaskOption
from vllm.config import _FLOAT16_NOT_SUPPORTED_MODELS, TaskOption
from vllm.logger import init_logger
from vllm.transformers_utils.config import get_config

from ..models.registry import HF_EXAMPLE_MODELS
from ..utils import compare_two_settings, create_new_process_for_each_test
@ -158,7 +159,7 @@ TEXT_GENERATION_MODELS = {
    "databricks/dbrx-instruct": PPTestSettings.fast(load_format="dummy"),
    "Deci/DeciLM-7B-instruct": PPTestSettings.fast(),
    "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(),
    "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(tp_base=2),
    "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(),
    "tiiuae/falcon-7b": PPTestSettings.fast(),
    "google/gemma-1.1-2b-it": PPTestSettings.fast(),
@ -210,9 +211,11 @@ TEXT_GENERATION_MODELS = {

EMBEDDING_MODELS = {  # type: ignore[var-annotated]
    # [Text-only]
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(load_format="dummy"),
    "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(task="embed"),
    "BAAI/bge-multilingual-gemma2": PPTestSettings.fast(task="embed"),
    "Qwen/Qwen2.5-Math-RM-72B": PPTestSettings.fast(
        load_format="dummy", task="embed"
    ),
}

MULTIMODAL_MODELS = {
@ -248,6 +251,7 @@ TEST_MODELS = [
    "meta-llama/Llama-3.2-1B-Instruct",
    "ArthurZ/Ilama-3.2-1B",
    "ibm/PowerLM-3b",
    "deepseek-ai/DeepSeek-V2-Lite-Chat",
    # [LANGUAGE EMBEDDING]
    "intfloat/e5-mistral-7b-instruct",
    "BAAI/bge-multilingual-gemma2",
@ -287,6 +291,11 @@ def _compare_tp(
    trust_remote_code = model_info.trust_remote_code
    tokenizer_mode = model_info.tokenizer_mode
    hf_overrides = model_info.hf_overrides
    hf_config = get_config(model_id, trust_remote_code)

    dtype = "float16"
    if hf_config.model_type in _FLOAT16_NOT_SUPPORTED_MODELS:
        dtype = "bfloat16"

    if load_format == "dummy":
        # Avoid OOM
@ -316,7 +325,7 @@ def _compare_tp(
    common_args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "float16",
        dtype,
        "--max-model-len",
        "2048",
        "--max-num-seqs",
@ -338,6 +347,7 @@ def _compare_tp(
        common_args.extend(["--hf-overrides", json.dumps(hf_overrides)])

    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
    testing_ray_compiled_graph = False
    if distributed_backend == "ray" and (vllm_major_version == "1"
                                         or specific_case):
        # For V1, test Ray Compiled Graph for all the tests
@ -351,6 +361,7 @@ def _compare_tp(
        # Temporary. Currently when zeromq + SPMD is used, it does not properly
        # terminate because of a Ray Compiled Graph issue.
        common_args.append("--disable-frontend-multiprocessing")
        testing_ray_compiled_graph = True
    elif distributed_backend == "mp":
        # Both V0/V1 of multiprocessing executor support PP
        pp_env = {
@ -394,7 +405,6 @@ def _compare_tp(
                             tp_env,
                             method=method)
    except Exception:
        testing_ray_compiled_graph = pp_env is not None
        if testing_ray_compiled_graph and vllm_major_version == "0":
            # Ray Compiled Graph tests are flaky for V0,
            # so we don't want to fail the test
@ -4,6 +4,7 @@
import multiprocessing
import os

import numpy as np
import pytest
import torch
import torch.distributed
@ -177,6 +178,38 @@ def test_pynccl_all_gather():
    distributed_run(all_gather_worker_fn, 2)


@worker_fn_wrapper
def all_gatherv_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    assert world_size <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
    num_elems = sizes[rank]
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * 100
    result = torch.zeros(sum(sizes), dtype=torch.float32, device=device)

    expected = torch.cat([
        torch.arange(sizes[r], dtype=torch.float32) + r * 100
        for r in range(world_size)
    ]).to(device)

    pynccl_comm.all_gatherv(result, tensor, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_all_gatherv():
    distributed_run(all_gatherv_worker_fn, 2)


@worker_fn_wrapper
def reduce_scatter_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
@ -214,6 +247,43 @@ def test_pynccl_reduce_scatter():
    distributed_run(reduce_scatter_worker_fn, 2)


@worker_fn_wrapper
def reduce_scatterv_worker_fn():
    pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group,
                                     device=get_world_group().device)

    rank = pynccl_comm.rank
    world_size = pynccl_comm.world_size
    device = f'cuda:{pynccl_comm.rank}'

    assert world_size <= 8
    sizes = [81, 20, 57, 52, 81, 5, 49, 49][:world_size]
    num_elems = sum(sizes)
    tensor = torch.arange(num_elems, dtype=torch.float32,
                          device=device) + rank * 100
    result = torch.zeros(sizes[rank], dtype=torch.float32, device=device)

    # Calculate expected result for this rank's chunk
    all_tensors = [
        torch.arange(num_elems, dtype=torch.float32) + r * 100
        for r in range(world_size)
    ]
    sizes_cumsum = np.cumsum(sizes)
    start = 0 if rank == 0 else sizes_cumsum[rank - 1]
    end = sizes_cumsum[rank]
    expected = sum(tensor[start:end] for tensor in all_tensors).to(device)

    pynccl_comm.reduce_scatterv(result, tensor, sizes=sizes)
    torch.cuda.synchronize()
    torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_reduce_scatterv():
    distributed_run(reduce_scatterv_worker_fn, 2)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
def test_pynccl_with_cudagraph():
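Aside (not part of the diff): the expected values in the two new tests follow the
usual variable-size collective semantics. A single-process reference sketch of what
all_gatherv and reduce_scatterv compute, given one tensor per rank and a sizes list:

# Hedged sketch only; plain-torch reference, no NCCL involved.
import torch


def all_gatherv_ref(per_rank: list[torch.Tensor]) -> torch.Tensor:
    # Concatenate every rank's (variably sized) contribution in rank order.
    return torch.cat(per_rank)


def reduce_scatterv_ref(per_rank: list[torch.Tensor],
                        sizes: list[int]) -> list[torch.Tensor]:
    # Sum the full-size tensors elementwise, then split the result by `sizes`;
    # rank r keeps chunk r.
    total = torch.stack(per_rank).sum(dim=0)
    return list(torch.split(total, sizes))


sizes = [3, 2]
per_rank = [torch.arange(sum(sizes), dtype=torch.float32) + r * 100
            for r in range(2)]
chunks = reduce_scatterv_ref(per_rank, sizes)
assert [c.numel() for c in chunks] == sizes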
|
@ -2,7 +2,7 @@
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import json
|
||||
from argparse import ArgumentError, ArgumentTypeError
|
||||
from argparse import ArgumentError
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Annotated, Literal, Optional
|
||||
@ -12,8 +12,8 @@ import pytest
|
||||
from vllm.config import CompilationConfig, config
|
||||
from vllm.engine.arg_utils import (EngineArgs, contains_type, get_kwargs,
|
||||
get_type, get_type_hints, is_not_builtin,
|
||||
is_type, literal_to_kwargs, nullable_kvs,
|
||||
optional_type, parse_type)
|
||||
is_type, literal_to_kwargs, optional_type,
|
||||
parse_type)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
|
||||
@ -25,18 +25,10 @@ from vllm.utils import FlexibleArgumentParser
|
||||
"foo": 1,
|
||||
"bar": 2
|
||||
}),
|
||||
(json.loads, "foo=1,bar=2", {
|
||||
"foo": 1,
|
||||
"bar": 2
|
||||
}),
|
||||
])
|
||||
def test_parse_type(type, value, expected):
|
||||
parse_type_func = parse_type(type)
|
||||
context = nullcontext()
|
||||
if value == "foo=1,bar=2":
|
||||
context = pytest.warns(DeprecationWarning)
|
||||
with context:
|
||||
assert parse_type_func(value) == expected
|
||||
assert parse_type_func(value) == expected
|
||||
|
||||
|
||||
def test_optional_type():
|
||||
@ -203,34 +195,6 @@ def test_get_kwargs():
|
||||
assert kwargs["from_cli_config2"]["type"]('{"field": 2}').field == 4
|
||||
|
||||
|
||||
@pytest.mark.parametrize(("arg", "expected"), [
|
||||
(None, dict()),
|
||||
("image=16", {
|
||||
"image": 16
|
||||
}),
|
||||
("image=16,video=2", {
|
||||
"image": 16,
|
||||
"video": 2
|
||||
}),
|
||||
("Image=16, Video=2", {
|
||||
"image": 16,
|
||||
"video": 2
|
||||
}),
|
||||
])
|
||||
def test_limit_mm_per_prompt_parser(arg, expected):
|
||||
"""This functionality is deprecated and will be removed in the future.
|
||||
This argument should be passed as JSON string instead.
|
||||
|
||||
TODO: Remove with nullable_kvs."""
|
||||
parser = EngineArgs.add_cli_args(FlexibleArgumentParser())
|
||||
if arg is None:
|
||||
args = parser.parse_args([])
|
||||
else:
|
||||
args = parser.parse_args(["--limit-mm-per-prompt", arg])
|
||||
|
||||
assert args.limit_mm_per_prompt == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("arg", "expected"),
|
||||
[
|
||||
@ -326,18 +290,6 @@ def test_prefix_cache_default():
|
||||
assert not engine_args.enable_prefix_caching
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("arg"),
|
||||
[
|
||||
"image", # Missing =
|
||||
"image=4,image=5", # Conflicting values
|
||||
"image=video=4" # Too many = in tokenized arg
|
||||
])
|
||||
def test_bad_nullable_kvs(arg):
|
||||
with pytest.raises(ArgumentTypeError):
|
||||
nullable_kvs(arg)
|
||||
|
||||
|
||||
# yapf: disable
|
||||
@pytest.mark.parametrize(("arg", "expected", "option"), [
|
||||
(None, None, "mm-processor-kwargs"),
|
||||
|
@ -69,6 +69,11 @@ def test_lm_eval_accuracy_v1_engine(model, monkeypatch: pytest.MonkeyPatch):
|
||||
more_args = None
|
||||
if current_platform.is_tpu():
|
||||
# Limit compilation time for TPU V1
|
||||
|
||||
if model == "google/gemma-3-1b-it":
|
||||
# TPU + google/gemma-3-1b-it + xet doesn't work well.
|
||||
m.setenv("HF_HUB_DISABLE_XET", "1")
|
||||
|
||||
more_args = "max_model_len=2048,max_num_seqs=64"
|
||||
|
||||
# Add TP test (if provided)
|
||||
|
@ -1113,10 +1113,7 @@ async def test_http_chat_no_model_name_with_curl(server: RemoteOpenAIServer):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("model_name", [MODEL_NAME, ""])
|
||||
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
|
||||
model_name: str):
|
||||
|
||||
async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer):
|
||||
openai_api_key = "EMPTY"
|
||||
openai_api_base = f"http://localhost:{server.port}/v1"
|
||||
|
||||
@ -1135,3 +1132,35 @@ async def test_http_chat_no_model_name_with_openai(server: RemoteOpenAIServer,
|
||||
messages=messages,
|
||||
)
|
||||
assert response.model == MODEL_NAME
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
messages = [{
|
||||
"role": "system",
|
||||
"content": "you are a helpful assistant"
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "what is 1+1?"
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"max_completion_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": False,
|
||||
}
|
||||
|
||||
chat_completion = await client.chat.completions.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_completion.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
assert chat_output["choices"] == invocation_output["choices"]
|
||||
|
@ -155,3 +155,29 @@ def test_batch_classification_empty_list(server: RemoteOpenAIServer,
|
||||
assert output.object == "list"
|
||||
assert isinstance(output.data, list)
|
||||
assert len(output.data) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": "This product was excellent and exceeded my expectations"
|
||||
}
|
||||
|
||||
classification_response = requests.post(server.url_for("classify"),
|
||||
json=request_args)
|
||||
classification_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
classification_output = classification_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert classification_output.keys() == invocation_output.keys()
|
||||
for classification_data, invocation_data in zip(
|
||||
classification_output["data"], invocation_output["data"]):
|
||||
assert classification_data.keys() == invocation_data.keys()
|
||||
assert classification_data["probs"] == pytest.approx(
|
||||
invocation_data["probs"], rel=0.01)
|
||||
|
@ -11,6 +11,7 @@ import openai # use the official client for correctness check
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
import regex as re
|
||||
import requests
|
||||
# downloading lora to test lora requests
|
||||
from huggingface_hub import snapshot_download
|
||||
from openai import BadRequestError
|
||||
@ -833,3 +834,27 @@ async def test_echo_stream_completion(client: openai.AsyncOpenAI,
|
||||
assert content is not None and saying in content
|
||||
else:
|
||||
assert content is not None and saying not in content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Hello, my name is",
|
||||
"max_tokens": 5,
|
||||
"temperature": 0.0,
|
||||
"logprobs": None,
|
||||
}
|
||||
|
||||
completion = await client.completions.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
assert completion_output["choices"] == invocation_output["choices"]
|
||||
|
@ -14,6 +14,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...models.language.pooling.embed_utils import (
|
||||
run_embedding_correctness_test)
|
||||
from ...models.utils import check_embeddings_close
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "intfloat/multilingual-e5-small"
|
||||
@ -296,3 +297,75 @@ async def test_single_embedding_truncation_invalid(client: openai.AsyncOpenAI,
|
||||
assert "error" in response.object
|
||||
assert "truncate_prompt_tokens value is greater than max_model_len. "\
|
||||
"Please, select a smaller truncation size." in response.message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer,
|
||||
client: openai.AsyncOpenAI):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = await client.embeddings.create(**request_args)
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.model_dump()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
for completion_data, invocation_data in zip(completion_output["data"],
|
||||
invocation_output["data"]):
|
||||
assert completion_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(embeddings_0_lst=[completion_data["embedding"]],
|
||||
embeddings_1_lst=[invocation_data["embedding"]],
|
||||
name_0="completion",
|
||||
name_1="invocation")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("v1/embeddings"),
|
||||
json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
for chat_data, invocation_data in zip(chat_output["data"],
|
||||
invocation_output["data"]):
|
||||
assert chat_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(embeddings_0_lst=[chat_data["embedding"]],
|
||||
embeddings_1_lst=[invocation_data["embedding"]],
|
||||
name_0="chat",
|
||||
name_1="invocation")
|
||||
|
@ -1,5 +1,6 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import json
|
||||
from typing import Final
|
||||
|
||||
import pytest
|
||||
@ -29,7 +30,7 @@ def server():
|
||||
"--enforce-eager",
|
||||
"--trust-remote-code",
|
||||
"--limit-mm-per-prompt",
|
||||
f"image={MAXIMUM_IMAGES}",
|
||||
json.dumps({"image": MAXIMUM_IMAGES}),
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
|
@ -13,7 +13,7 @@ from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MODEL_NAME = "jason9693/Qwen2.5-1.5B-apeach"
|
||||
MODEL_NAME = "internlm/internlm2-1_8b-reward"
|
||||
DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501
|
||||
|
||||
|
||||
@ -21,15 +21,16 @@ DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' +
|
||||
def server():
|
||||
args = [
|
||||
"--task",
|
||||
"classify",
|
||||
"reward",
|
||||
# use half precision for speed and memory savings in CI environment
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
"--enforce-eager",
|
||||
"--max-model-len",
|
||||
"8192",
|
||||
"512",
|
||||
"--chat-template",
|
||||
DUMMY_CHAT_TEMPLATE,
|
||||
"--trust-remote-code",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@ -57,10 +58,10 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 7
|
||||
assert poolings.usage.total_tokens == 7
|
||||
assert poolings.usage.prompt_tokens == 8
|
||||
assert poolings.usage.total_tokens == 8
|
||||
|
||||
# test using token IDs
|
||||
input_tokens = [1, 1, 1, 1, 1]
|
||||
@ -77,7 +78,7 @@ async def test_single_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 1
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 5
|
||||
assert poolings.usage.total_tokens == 5
|
||||
@ -104,10 +105,10 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 3
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 8
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 25
|
||||
assert poolings.usage.total_tokens == 25
|
||||
assert poolings.usage.prompt_tokens == 29
|
||||
assert poolings.usage.total_tokens == 29
|
||||
|
||||
# test list[list[int]]
|
||||
input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
|
||||
@ -125,7 +126,7 @@ async def test_batch_pooling(server: RemoteOpenAIServer, model_name: str):
|
||||
|
||||
assert poolings.id is not None
|
||||
assert len(poolings.data) == 4
|
||||
assert len(poolings.data[0].data) == 2
|
||||
assert len(poolings.data[0].data) == 5
|
||||
assert poolings.usage.completion_tokens == 0
|
||||
assert poolings.usage.prompt_tokens == 17
|
||||
assert poolings.usage.total_tokens == 17
|
||||
@ -157,7 +158,11 @@ async def test_conversation_pooling(server: RemoteOpenAIServer,
|
||||
chat_response.raise_for_status()
|
||||
chat_poolings = PoolingResponse.model_validate(chat_response.json())
|
||||
|
||||
tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast")
|
||||
tokenizer = get_tokenizer(
|
||||
tokenizer_name=model_name,
|
||||
tokenizer_mode="fast",
|
||||
trust_remote_code=True,
|
||||
)
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages,
|
||||
chat_template=DUMMY_CHAT_TEMPLATE,
|
||||
@ -206,6 +211,9 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
)
|
||||
float_response.raise_for_status()
|
||||
responses_float = PoolingResponse.model_validate(float_response.json())
|
||||
float_data = [
|
||||
np.array(d.data).squeeze(-1).tolist() for d in responses_float.data
|
||||
]
|
||||
|
||||
base64_response = requests.post(
|
||||
server.url_for("pooling"),
|
||||
@ -224,11 +232,10 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
np.frombuffer(base64.b64decode(data.data),
|
||||
dtype="float32").tolist())
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[d.data for d in responses_float.data],
|
||||
embeddings_1_lst=decoded_responses_base64_data,
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
check_embeddings_close(embeddings_0_lst=float_data,
|
||||
embeddings_1_lst=decoded_responses_base64_data,
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
|
||||
# Default response is float32 decoded from base64 by OpenAI Client
|
||||
default_response = requests.post(
|
||||
@ -240,9 +247,83 @@ async def test_batch_base64_pooling(server: RemoteOpenAIServer,
|
||||
)
|
||||
default_response.raise_for_status()
|
||||
responses_default = PoolingResponse.model_validate(default_response.json())
|
||||
default_data = [
|
||||
np.array(d.data).squeeze(-1).tolist() for d in responses_default.data
|
||||
]
|
||||
|
||||
check_embeddings_close(
|
||||
embeddings_0_lst=[d.data for d in responses_default.data],
|
||||
embeddings_1_lst=[d.data for d in responses_default.data],
|
||||
name_0="float32",
|
||||
name_1="base64")
|
||||
check_embeddings_close(embeddings_0_lst=float_data,
|
||||
embeddings_1_lst=default_data,
|
||||
name_0="float32",
|
||||
name_1="default")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations(server: RemoteOpenAIServer):
|
||||
input_texts = [
|
||||
"The chef prepared a delicious meal.",
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"input": input_texts,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
completion_response = requests.post(server.url_for("pooling"),
|
||||
json=request_args)
|
||||
completion_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
completion_output = completion_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert completion_output.keys() == invocation_output.keys()
|
||||
for completion_data, invocation_data in zip(completion_output["data"],
|
||||
invocation_output["data"]):
|
||||
assert completion_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(embeddings_0_lst=completion_data["data"],
|
||||
embeddings_1_lst=invocation_data["data"],
|
||||
name_0="completion",
|
||||
name_1="invocation")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invocations_conversation(server: RemoteOpenAIServer):
|
||||
messages = [{
|
||||
"role": "user",
|
||||
"content": "The cat sat on the mat.",
|
||||
}, {
|
||||
"role": "assistant",
|
||||
"content": "A feline was resting on a rug.",
|
||||
}, {
|
||||
"role": "user",
|
||||
"content": "Stars twinkle brightly in the night sky.",
|
||||
}]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"messages": messages,
|
||||
"encoding_format": "float",
|
||||
}
|
||||
|
||||
chat_response = requests.post(server.url_for("pooling"), json=request_args)
|
||||
chat_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
chat_output = chat_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert chat_output.keys() == invocation_output.keys()
|
||||
for chat_data, invocation_data in zip(chat_output["data"],
|
||||
invocation_output["data"]):
|
||||
assert chat_data.keys() == invocation_data.keys()
|
||||
check_embeddings_close(embeddings_0_lst=chat_data["data"],
|
||||
embeddings_1_lst=invocation_data["data"],
|
||||
name_0="chat",
|
||||
name_1="invocation")
|
||||
|
@ -94,3 +94,34 @@ def test_rerank_max_model_len(server: RemoteOpenAIServer, model_name: str):
|
||||
# Assert just a small fragments of the response
|
||||
assert "Please reduce the length of the input." in \
|
||||
rerank_response.text
|
||||
|
||||
|
||||
def test_invocations(server: RemoteOpenAIServer):
|
||||
query = "What is the capital of France?"
|
||||
documents = [
|
||||
"The capital of Brazil is Brasilia.", "The capital of France is Paris."
|
||||
]
|
||||
|
||||
request_args = {
|
||||
"model": MODEL_NAME,
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
}
|
||||
|
||||
rerank_response = requests.post(server.url_for("rerank"),
|
||||
json=request_args)
|
||||
rerank_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
rerank_output = rerank_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert rerank_output.keys() == invocation_output.keys()
|
||||
for rerank_result, invocations_result in zip(rerank_output["results"],
|
||||
invocation_output["results"]):
|
||||
assert rerank_result.keys() == invocations_result.keys()
|
||||
assert rerank_result["relevance_score"] == pytest.approx(
|
||||
invocations_result["relevance_score"], rel=0.01)
|
||||
|
@ -191,3 +191,32 @@ class TestModel:
|
||||
assert score_response.status_code == 400
|
||||
assert "Please, select a smaller truncation size." in \
|
||||
score_response.text
|
||||
|
||||
def test_invocations(self, server: RemoteOpenAIServer, model: dict[str,
|
||||
Any]):
|
||||
text_1 = "What is the capital of France?"
|
||||
text_2 = "The capital of France is Paris."
|
||||
|
||||
request_args = {
|
||||
"model": model["name"],
|
||||
"text_1": text_1,
|
||||
"text_2": text_2,
|
||||
}
|
||||
|
||||
score_response = requests.post(server.url_for("score"),
|
||||
json=request_args)
|
||||
score_response.raise_for_status()
|
||||
|
||||
invocation_response = requests.post(server.url_for("invocations"),
|
||||
json=request_args)
|
||||
invocation_response.raise_for_status()
|
||||
|
||||
score_output = score_response.json()
|
||||
invocation_output = invocation_response.json()
|
||||
|
||||
assert score_output.keys() == invocation_output.keys()
|
||||
for score_data, invocation_data in zip(score_output["data"],
|
||||
invocation_output["data"]):
|
||||
assert score_data.keys() == invocation_data.keys()
|
||||
assert score_data["score"] == pytest.approx(
|
||||
invocation_data["score"], rel=0.01)
|
||||
|
@ -32,6 +32,7 @@ def server(zephyr_lora_added_tokens_files: str): # noqa: F811
|
||||
f"zephyr-lora2={zephyr_lora_added_tokens_files}",
|
||||
"--max-lora-rank",
|
||||
"64",
|
||||
"--enable-tokenizer-info-endpoint",
|
||||
]
|
||||
|
||||
with RemoteOpenAIServer(MODEL_NAME, args) as remote_server:
|
||||
@ -283,3 +284,106 @@ async def test_detokenize(
|
||||
response.raise_for_status()
|
||||
|
||||
assert response.json() == {"prompt": prompt}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"model_name,tokenizer_name",
|
||||
[(MODEL_NAME, MODEL_NAME), ("zephyr-lora2", "zephyr-lora2")],
|
||||
indirect=["tokenizer_name"],
|
||||
)
|
||||
async def test_tokenizer_info_basic(
|
||||
server: RemoteOpenAIServer,
|
||||
model_name: str,
|
||||
tokenizer_name: str,
|
||||
):
|
||||
"""Test basic tokenizer info endpoint functionality."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
assert "tokenizer_class" in result
|
||||
assert isinstance(result["tokenizer_class"], str)
|
||||
assert result["tokenizer_class"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_schema(server: RemoteOpenAIServer):
|
||||
"""Test that the response matches expected schema types."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
field_types = {
|
||||
"add_bos_token": bool,
|
||||
"add_prefix_space": bool,
|
||||
"clean_up_tokenization_spaces": bool,
|
||||
"split_special_tokens": bool,
|
||||
"bos_token": str,
|
||||
"eos_token": str,
|
||||
"pad_token": str,
|
||||
"unk_token": str,
|
||||
"chat_template": str,
|
||||
"errors": str,
|
||||
"model_max_length": int,
|
||||
"additional_special_tokens": list,
|
||||
"added_tokens_decoder": dict,
|
||||
}
|
||||
for field, expected_type in field_types.items():
|
||||
if field in result and result[field] is not None:
|
||||
assert isinstance(
|
||||
result[field],
|
||||
expected_type), (f"{field} should be {expected_type.__name__}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_added_tokens_structure(
|
||||
server: RemoteOpenAIServer, ):
|
||||
"""Test added_tokens_decoder structure if present."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
added_tokens = result.get("added_tokens_decoder")
|
||||
if added_tokens:
|
||||
for token_id, token_info in added_tokens.items():
|
||||
assert isinstance(token_id, str), "Token IDs should be strings"
|
||||
assert isinstance(token_info, dict), "Token info should be a dict"
|
||||
assert "content" in token_info, "Token info should have content"
|
||||
assert "special" in token_info, (
|
||||
"Token info should have special flag")
|
||||
assert isinstance(token_info["special"],
|
||||
bool), ("Special flag should be boolean")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_consistency_with_tokenize(
|
||||
server: RemoteOpenAIServer, ):
|
||||
"""Test that tokenizer info is consistent with tokenization endpoint."""
|
||||
info_response = requests.get(server.url_for("tokenizer_info"))
|
||||
info_response.raise_for_status()
|
||||
info = info_response.json()
|
||||
tokenize_response = requests.post(
|
||||
server.url_for("tokenize"),
|
||||
json={
|
||||
"model": MODEL_NAME,
|
||||
"prompt": "Hello world!"
|
||||
},
|
||||
)
|
||||
tokenize_response.raise_for_status()
|
||||
tokenize_result = tokenize_response.json()
|
||||
info_max_len = info.get("model_max_length")
|
||||
tokenize_max_len = tokenize_result.get("max_model_len")
|
||||
if info_max_len and tokenize_max_len:
|
||||
assert info_max_len >= tokenize_max_len, (
|
||||
"Info max length should be >= tokenize max length")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tokenizer_info_chat_template(server: RemoteOpenAIServer):
|
||||
"""Test chat template is properly included."""
|
||||
response = requests.get(server.url_for("tokenizer_info"))
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
chat_template = result.get("chat_template")
|
||||
if chat_template:
|
||||
assert isinstance(chat_template,
|
||||
str), ("Chat template should be a string")
|
||||
assert chat_template.strip(), "Chat template should not be empty"
|
@ -17,6 +17,11 @@ from vllm.assets.audio import AudioAsset
|
||||
|
||||
from ...utils import RemoteOpenAIServer
|
||||
|
||||
MISTRAL_FORMAT_ARGS = [
|
||||
"--tokenizer_mode", "mistral", "--config_format", "mistral",
|
||||
"--load_format", "mistral"
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mary_had_lamb():
|
||||
@ -33,9 +38,15 @@ def winning_call():
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_audio(mary_had_lamb):
|
||||
model_name = "openai/whisper-large-v3-turbo"
|
||||
@pytest.mark.parametrize(
|
||||
"model_name",
|
||||
["openai/whisper-large-v3-turbo", "mistralai/Voxtral-Mini-3B-2507"])
|
||||
async def test_basic_audio(mary_had_lamb, model_name):
|
||||
server_args = ["--enforce-eager"]
|
||||
|
||||
if model_name.startswith("mistralai"):
|
||||
server_args += MISTRAL_FORMAT_ARGS
|
||||
|
||||
# Based on https://github.com/openai/openai-cookbook/blob/main/examples/Whisper_prompting_guide.ipynb.
|
||||
with RemoteOpenAIServer(model_name, server_args) as remote_server:
|
||||
client = remote_server.get_async_client()
|
||||
@ -65,10 +76,13 @@ async def test_bad_requests(mary_had_lamb):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_long_audio_request(mary_had_lamb):
|
||||
model_name = "openai/whisper-large-v3-turbo"
|
||||
@pytest.mark.parametrize("model_name", ["openai/whisper-large-v3-turbo"])
|
||||
async def test_long_audio_request(mary_had_lamb, model_name):
|
||||
server_args = ["--enforce-eager"]
|
||||
|
||||
if model_name.startswith("openai"):
|
||||
return
|
||||
|
||||
mary_had_lamb.seek(0)
|
||||
audio, sr = librosa.load(mary_had_lamb)
|
||||
# Add small silence after each audio for repeatability in the split process
|
||||
@ -87,7 +101,8 @@ async def test_long_audio_request(mary_had_lamb):
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(transcription)['text']
|
||||
assert out.count("Mary had a little lamb") == 10
|
||||
counts = out.count("Mary had a little lamb")
|
||||
assert counts == 10, counts
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -154,7 +169,8 @@ async def test_streaming_response(winning_call):
|
||||
file=winning_call,
|
||||
language="en",
|
||||
temperature=0.0,
|
||||
extra_body=dict(stream=True))
|
||||
extra_body=dict(stream=True),
|
||||
timeout=30)
|
||||
# Reconstruct from chunks and validate
|
||||
async for chunk in res:
|
||||
# just a chunk
|
||||
@ -184,7 +200,8 @@ async def test_stream_options(winning_call):
|
||||
temperature=0.0,
|
||||
extra_body=dict(stream=True,
|
||||
stream_include_usage=True,
|
||||
stream_continuous_usage_stats=True))
|
||||
stream_continuous_usage_stats=True),
|
||||
timeout=30)
|
||||
final = False
|
||||
continuous = True
|
||||
async for chunk in res:
|
||||
|
@ -39,8 +39,8 @@ async def test_basic_audio(foscolo):
|
||||
# TODO remove once language detection is implemented
|
||||
extra_body=dict(language="it"),
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip()
|
||||
assert "Nor will I ever touch the sacred" in out
|
||||
out = json.loads(translation)['text'].strip().lower()
|
||||
assert "greek sea" in out
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -168,5 +168,4 @@ async def test_long_audio_request(foscolo):
|
||||
response_format="text",
|
||||
temperature=0.0)
|
||||
out = json.loads(translation)['text'].strip().lower()
|
||||
# TODO investigate higher model uncertainty in for longer translations.
|
||||
assert out.count("nor will i ever") == 2
|
||||
assert out.count("greek sea") == 2
|
||||
|
@ -0,0 +1,140 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from typing import Optional
|
||||
|
||||
import flashinfer
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
if not current_platform.is_device_capability(100):
|
||||
pytest.skip("This TRTLLM kernel requires NVIDIA Blackwell.",
|
||||
allow_module_level=True)
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
NUM_HEADS = [(64, 8), (16, 16), (40, 8), (32, 8)]
|
||||
HEAD_SIZES = [128]
|
||||
BLOCK_SIZES = [16, 32]
|
||||
DTYPES = [torch.float16, torch.bfloat16]
|
||||
NUM_BLOCKS = 32768 # Large enough to test overflow in index calculation.
|
||||
SOFT_CAPS = [None, 30.0, 50.0]
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
|
||||
@pytest.mark.parametrize("num_heads", NUM_HEADS)
|
||||
@pytest.mark.parametrize("head_size", HEAD_SIZES)
|
||||
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
|
||||
@pytest.mark.parametrize("kv_layout", ["HND"])
|
||||
@pytest.mark.parametrize("dtype", DTYPES)
|
||||
@pytest.mark.parametrize("soft_cap", SOFT_CAPS)
|
||||
@torch.inference_mode
|
||||
def test_flashinfer_trtllm_decode_with_baseline(
|
||||
kv_lens: list[int],
|
||||
num_heads: tuple[int, int],
|
||||
head_size: int,
|
||||
dtype: torch.dtype,
|
||||
block_size: int,
|
||||
soft_cap: Optional[float],
|
||||
kv_layout: str,
|
||||
) -> None:
|
||||
torch.set_default_device("cuda")
|
||||
current_platform.seed_everything(0)
|
||||
num_seqs = len(kv_lens)
|
||||
num_query_heads = num_heads[0]
|
||||
num_kv_heads = num_heads[1]
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
max_kv_len = max(kv_lens)
|
||||
scale = head_size**-0.5
|
||||
|
||||
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
|
||||
kv_cache_shape = None
|
||||
if kv_layout == "NHD":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, block_size, num_kv_heads, head_size)
|
||||
elif kv_layout == "HND":
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, block_size, head_size)
|
||||
else:
|
||||
raise ValueError(f"Invalid kv_layout: {kv_layout}")
|
||||
key_value_cache = torch.randn(kv_cache_shape, dtype=dtype)
|
||||
|
||||
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
|
||||
block_tables = torch.randint(0,
|
||||
NUM_BLOCKS,
|
||||
(num_seqs, max_num_blocks_per_seq),
|
||||
dtype=torch.int32)
|
||||
k_scale = v_scale = 1.0
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + block_size - 1) // block_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % block_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = block_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
|
||||
wrapper = flashinfer.\
|
||||
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, kv_layout,
|
||||
use_tensor_cores=(
|
||||
(num_query_heads//num_kv_heads) > 4)
|
||||
)
|
||||
wrapper.plan(kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
head_size,
|
||||
block_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=dtype,
|
||||
logits_soft_cap=soft_cap)
|
||||
|
||||
output = wrapper.run(query, key_value_cache, scale)
|
||||
|
||||
# TRTLLM Decode
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens,
|
||||
dtype=torch.int,
|
||||
device=query.device)
|
||||
output_trtllm = flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
query.contiguous(),
|
||||
key_value_cache,
|
||||
workspace_buffer,
|
||||
num_query_heads,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
block_size,
|
||||
max_kv_len,
|
||||
"auto",
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
torch.testing.assert_close(output, output_trtllm, atol=1e-2, rtol=1e-2), \
|
||||
f"{torch.max(torch.abs(output - output_trtllm))}"
|
0
tests/kernels/moe/modular_kernel_tools/__init__.py
Normal file
160
tests/kernels/moe/modular_kernel_tools/cli_args.py
Normal file
@ -0,0 +1,160 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import argparse
|
||||
|
||||
import torch
|
||||
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
|
||||
from .common import Config
|
||||
from .mk_objects import (MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES,
|
||||
MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)
|
||||
|
||||
|
||||
def make_config_arg_parser(description: str):
|
||||
|
||||
def to_pf_class_type(s: str) -> mk.FusedMoEPrepareAndFinalize:
|
||||
for pf in MK_ALL_PREPARE_FINALIZE_TYPES:
|
||||
if pf.__name__ == s:
|
||||
return pf
|
||||
raise ValueError(
|
||||
f"Cannot find a PrepareFinalize type that matches {s}")
|
||||
|
||||
def to_experts_class_type(s: str) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
for fe in MK_FUSED_EXPERT_TYPES:
|
||||
if fe.__name__ == s:
|
||||
return fe
|
||||
raise ValueError(f"Cannot find a FusedExperts type that matches {s}")
|
||||
|
||||
def to_quant_torch_dtype(s: str) -> torch.dtype:
|
||||
if s == "torch.float8_e4m3fn":
|
||||
return torch.float8_e4m3fn
|
||||
raise ValueError(f"Unsupported quant type {s}")
|
||||
|
||||
parser = argparse.ArgumentParser(description=description)
|
||||
|
||||
parser.add_argument(
|
||||
"--world-size",
|
||||
type=int,
|
||||
default=2,
|
||||
help="Number of ranks that participate in all2all",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pf-type",
|
||||
type=to_pf_class_type,
|
||||
required=True,
|
||||
help=("Choose a PrepareFinalize Type : "
|
||||
f"{[x.__name__ for x in MK_ALL_PREPARE_FINALIZE_TYPES]}"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--experts-type",
|
||||
type=to_experts_class_type,
|
||||
required=True,
|
||||
help=(f"Choose a FusedExpert type : "
|
||||
f"{[x.__name__ for x in MK_FUSED_EXPERT_TYPES]}"),
|
||||
)
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[64],
|
||||
help="num tokens per rank",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-k",
|
||||
type=int,
|
||||
default=7168,
|
||||
help="hidden-size",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
type=int,
|
||||
default=1024,
|
||||
help="N dimension of the first fused-moe matmul",
|
||||
)
|
||||
parser.add_argument("--num-experts",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Global num experts")
|
||||
parser.add_argument("--topk",
|
||||
nargs="+",
|
||||
type=int,
|
||||
default=[4, 1],
|
||||
help="num topk")
|
||||
parser.add_argument(
|
||||
"--fused-moe-chunk-size",
|
||||
nargs="+",
|
||||
type=int,
|
||||
help="Fused moe chunk size used for the non-batched fused experts impl."
|
||||
)
|
||||
|
||||
# Quant args
|
||||
parser.add_argument("--quant-dtype",
|
||||
type=to_quant_torch_dtype,
|
||||
help="Quant datatype")
|
||||
parser.add_argument("--per-token-quantized-activations",
|
||||
action='store_true',
|
||||
help=("The input activations must be per-token "
|
||||
"quantized"))
|
||||
parser.add_argument("--per-channel-quantized-weights",
|
||||
action="store_true",
|
||||
help="The weights must be per-channel quantized.")
|
||||
parser.add_argument("--block-shape",
|
||||
nargs="+",
|
||||
type=int,
|
||||
help="Quantization block shape")
|
||||
|
||||
# Torch trace profile generation args
|
||||
parser.add_argument("--torch-trace-dir-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Get torch trace for single execution")
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _validate_args(args: argparse.Namespace):
|
||||
|
||||
if args.quant_dtype is not None:
|
||||
assert args.quant_dtype == torch.float8_e4m3fn
|
||||
if args.block_shape is not None:
|
||||
assert len(args.block_shape) == 2, (
|
||||
f"block shape must have 2 elements. got {args.block_shape}")
|
||||
|
||||
if args.experts_type in MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES:
|
||||
assert args.world_size == 1, (
|
||||
"Single GPU objects need world size set to 1")
|
||||
|
||||
if args.torch_trace_dir_path is not None:
|
||||
from pathlib import Path
|
||||
assert Path(args.torch_trace_dir_path).is_dir(), (
|
||||
f"Please create {args.torch_trace_dir_path}")
|
||||
|
||||
|
||||
def make_config(args: argparse.Namespace) -> Config:
|
||||
|
||||
_validate_args(args)
|
||||
|
||||
quant_config = None
|
||||
if args.quant_dtype is not None:
|
||||
quant_config = FusedMoEQuantConfig(
|
||||
quant_dtype=args.quant_dtype,
|
||||
per_act_token_quant=args.per_token_quantized_activations,
|
||||
per_out_ch_quant=args.per_channel_quantized_weights,
|
||||
block_shape=args.block_shape)
|
||||
|
||||
return Config(
|
||||
Ms=args.m,
|
||||
K=args.k,
|
||||
N=args.n,
|
||||
E=args.num_experts,
|
||||
topks=args.topk,
|
||||
dtype=torch.bfloat16, # hard-code
|
||||
quant_config=quant_config,
|
||||
prepare_finalize_type=args.pf_type,
|
||||
fused_experts_type=args.experts_type,
|
||||
fused_moe_chunk_size=args.fused_moe_chunk_size,
|
||||
world_size=args.world_size,
|
||||
torch_trace_dir_path=args.torch_trace_dir_path)
|
641
tests/kernels/moe/modular_kernel_tools/common.py
Normal file
@ -0,0 +1,641 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
import vllm.model_executor.layers.fused_moe.modular_kernel as mk
|
||||
from tests.kernels.utils import torch_experts
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import get_dp_group, get_tensor_model_parallel_world_size
|
||||
# Fused experts and PrepareFinalize imports
|
||||
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
|
||||
BatchedDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import ( # noqa: E501
|
||||
BatchedTritonOrDeepGemmExperts)
|
||||
from vllm.model_executor.layers.fused_moe.config import (
|
||||
FusedMoEConfig, FusedMoEParallelConfig, FusedMoEQuantConfig)
|
||||
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
|
||||
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
|
||||
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
|
||||
BatchedTritonExperts, NaiveBatchedExperts)
|
||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
|
||||
from vllm.model_executor.layers.fused_moe.layer import (FusedMoEMethodBase,
|
||||
TritonExperts)
|
||||
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
|
||||
MoEPrepareAndFinalizeNoEP)
|
||||
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
|
||||
TritonOrDeepGemmExperts)
|
||||
from vllm.utils import has_deep_ep, has_deep_gemm, has_pplx
|
||||
|
||||
from .parallel_utils import ProcessGroupInfo
|
||||
from .utils import (make_block_quant_fp8_weights, make_non_quant_weights,
|
||||
make_quant_fp8_weights, per_token_cast_to_fp8)
|
||||
|
||||
if has_pplx():
|
||||
from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
|
||||
PplxPrepareAndFinalize)
|
||||
if has_deep_ep():
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import ( # noqa: E501
|
||||
DeepEPHTPrepareAndFinalize)
|
||||
from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import ( # noqa: E501
|
||||
DeepEPLLPrepareAndFinalize)
|
||||
|
||||
|
||||
def _describe_tensor(t: Optional[torch.Tensor], name: str) -> str:
|
||||
if t is None:
|
||||
return f"{name} : None"
|
||||
else:
|
||||
return f"{name} : {t.shape} {t.dtype} {t.device}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Config:
|
||||
Ms: Union[list[int], int]
|
||||
K: int
|
||||
N: int
|
||||
E: int
|
||||
topks: Union[list[int], int]
|
||||
dtype: torch.dtype
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
|
||||
prepare_finalize_type: mk.FusedMoEPrepareAndFinalize
|
||||
fused_experts_type: mk.FusedMoEPermuteExpertsUnpermute
|
||||
|
||||
fused_moe_chunk_size: Optional[int]
|
||||
world_size: int
|
||||
|
||||
torch_trace_dir_path: Optional[str] = None
|
||||
|
||||
def describe(self) -> str:
|
||||
s = ""
|
||||
s += "== Config: \n"
|
||||
s += f" world_size={self.world_size} \n"
|
||||
s += f" PF={self.prepare_finalize_type.__name__} \n"
|
||||
s += f" FE={self.fused_experts_type.__name__} \n"
|
||||
s += f" topk={self.topks} \n"
|
||||
s += f" dtype={self.dtype} \n"
|
||||
s += f" fused_moe_chunk_size={self.fused_moe_chunk_size} \n"
|
||||
s += " Quant: \n"
|
||||
|
||||
if self.quant_config is not None:
|
||||
s += f" q_dtype={self.quant_dtype} \n"
|
||||
s += f" q_block_shape={self.quant_block_shape} \n"
|
||||
s += f" q_per_out_ch_quant={self.is_per_out_ch_quant} \n"
|
||||
s += f" q_per_act_token={self.is_per_act_token_quant} \n"
|
||||
else:
|
||||
s += " quant=None \n"
|
||||
return s
|
||||
|
||||
@property
|
||||
def M(self) -> int:
|
||||
assert isinstance(self.Ms, int)
|
||||
return self.Ms
|
||||
|
||||
@property
|
||||
def quant_dtype(self) -> Optional[torch.dtype]:
|
||||
if self.quant_config is None:
|
||||
return None
|
||||
return self.quant_config.quant_dtype
|
||||
|
||||
@property
|
||||
def is_per_act_token_quant(self) -> bool:
|
||||
if self.quant_config is None:
|
||||
return False
|
||||
return self.quant_config.per_act_token_quant
|
||||
|
||||
@property
|
||||
def is_per_tensor_act_quant(self) -> bool:
|
||||
if self.quant_config is None:
|
||||
return False
|
||||
return (not self.is_per_act_token_quant
|
||||
and self.quant_block_shape is None)
|
||||
|
||||
@property
|
||||
def is_per_out_ch_quant(self) -> bool:
|
||||
if self.quant_config is None:
|
||||
return False
|
||||
return self.quant_config.per_out_ch_quant
|
||||
|
||||
@property
|
||||
def quant_block_shape(self) -> Optional[list[int]]:
|
||||
if self.quant_config is None:
|
||||
return None
|
||||
return self.quant_config.block_shape
|
||||
|
||||
@property
|
||||
def topk(self) -> int:
|
||||
assert isinstance(self.topks, int)
|
||||
return self.topks
|
||||
|
||||
@property
|
||||
def topk_ids_dtype(self) -> Optional[torch.dtype]:
|
||||
topk_ids_dtype = None
|
||||
if self.prepare_finalize_type == PplxPrepareAndFinalize:
|
||||
topk_ids_dtype = torch.uint32
|
||||
elif self.prepare_finalize_type in [
|
||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
||||
]:
|
||||
topk_ids_dtype = torch.int64
|
||||
return topk_ids_dtype
|
||||
|
||||
@property
|
||||
def num_local_experts(self) -> int:
|
||||
return self.E // self.world_size
|
||||
|
||||
def make_env_data(self) -> tuple[VllmConfig, dict[Any, Any]]:
|
||||
"""
|
||||
make env data for vllm launch.
|
||||
"""
|
||||
vllm_config = VllmConfig()
|
||||
vllm_config.parallel_config.data_parallel_size = self.world_size
|
||||
vllm_config.parallel_config.enable_expert_parallel = True
|
||||
|
||||
env_dict = {
|
||||
"VLLM_ALL2ALL_BACKEND": self.all2all_backend(),
|
||||
"VLLM_USE_DEEP_GEMM": str(int(self.needs_deep_gemm())),
|
||||
}
|
||||
if self.fused_moe_chunk_size is not None:
|
||||
env_dict.update(
|
||||
{"VLLM_FUSED_MOE_CHUNK_SIZE": str(self.fused_moe_chunk_size)})
|
||||
return vllm_config, env_dict
|
||||
|
||||
def is_fp8_block_quantized(self):
|
||||
return (self.quant_dtype == torch.float8_e4m3fn
|
||||
and self.quant_block_shape is not None)
|
||||
|
||||
def is_batched_prepare_finalize(self):
|
||||
return self.prepare_finalize_type in [
|
||||
PplxPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
||||
]
|
||||
|
||||
def is_batched_fused_experts(self):
|
||||
return self.fused_experts_type in [
|
||||
CutlassExpertsFp8, BatchedDeepGemmExperts, BatchedTritonExperts,
|
||||
NaiveBatchedExperts, BatchedTritonOrDeepGemmExperts
|
||||
]
|
||||
|
||||
def is_standard_fused_experts(self):
|
||||
return self.fused_experts_type in [
|
||||
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
|
||||
TritonExperts
|
||||
]
|
||||
|
||||
def is_fe_16bit_supported(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedTritonExperts, BatchedTritonOrDeepGemmExperts,
|
||||
NaiveBatchedExperts, TritonExperts
|
||||
]
|
||||
|
||||
def is_fe_fp8_supported(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedDeepGemmExperts,
|
||||
BatchedTritonExperts,
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
CutlassExpertsFp8,
|
||||
DeepGemmExperts,
|
||||
TritonExperts,
|
||||
TritonOrDeepGemmExperts,
|
||||
NaiveBatchedExperts,
|
||||
]
|
||||
|
||||
def is_fe_block_fp8_supported(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedDeepGemmExperts,
|
||||
BatchedTritonOrDeepGemmExperts,
|
||||
DeepGemmExperts,
|
||||
TritonExperts,
|
||||
TritonOrDeepGemmExperts,
|
||||
BatchedTritonExperts,
|
||||
NaiveBatchedExperts,
|
||||
]
|
||||
|
||||
def is_fe_supports_chunking(self):
|
||||
return self.fused_experts_type in [
|
||||
CutlassExpertsFp8, DeepGemmExperts, TritonOrDeepGemmExperts,
|
||||
TritonExperts
|
||||
]
|
||||
|
||||
def needs_deep_gemm(self):
|
||||
return self.fused_experts_type in [
|
||||
BatchedDeepGemmExperts,
|
||||
DeepGemmExperts,
|
||||
]
|
||||
|
||||
def needs_pplx(self):
|
||||
return self.prepare_finalize_type in [PplxPrepareAndFinalize]
|
||||
|
||||
def needs_deep_ep(self):
|
||||
return self.prepare_finalize_type in [
|
||||
DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
|
||||
]
|
||||
|
||||
def all2all_backend(self):
|
||||
if self.needs_pplx():
|
||||
return "pplx"
|
||||
if self.prepare_finalize_type == DeepEPHTPrepareAndFinalize:
|
||||
return "deepep_high_throughput"
|
||||
if self.prepare_finalize_type == DeepEPLLPrepareAndFinalize:
|
||||
return "deepep_low_latency"
|
||||
return "naive"
|
||||
|
||||
def needs_all2all(self):
|
||||
return self.prepare_finalize_type in [
|
||||
PplxPrepareAndFinalize, DeepEPHTPrepareAndFinalize,
|
||||
DeepEPLLPrepareAndFinalize
|
||||
]
|
||||
|
||||
def is_valid(self):
|
||||
# Check prepare-finalize and fused-experts compatibility
|
||||
if self.is_batched_prepare_finalize():
|
||||
if not self.is_batched_fused_experts():
|
||||
return False
|
||||
else:
|
||||
if not self.is_standard_fused_experts():
|
||||
return False
|
||||
|
||||
use_chunking = self.fused_moe_chunk_size is not None
|
||||
if use_chunking and not self.is_fe_supports_chunking():
|
||||
return False
|
||||
|
||||
# Check quantization sanity
|
||||
if (int(self.is_per_act_token_quant) +
|
||||
int(self.is_per_tensor_act_quant) +
|
||||
int(self.quant_block_shape is not None)) > 1:
|
||||
# invalid quant config
|
||||
return False
|
||||
|
||||
# check bf16 / fp16 support
|
||||
is_16bit = (self.dtype.itemsize == 2 and self.quant_dtype is None)
|
||||
if is_16bit and not self.is_fe_16bit_supported():
|
||||
return False
|
||||
|
||||
# Check fp8 support
|
||||
is_fp8 = self.quant_dtype == torch.float8_e4m3fn
|
||||
if is_fp8 and not self.is_fe_fp8_supported():
|
||||
return False
|
||||
|
||||
# Check fp8 block quantization support
|
||||
is_block_quatized = self.quant_block_shape is not None
|
||||
if is_block_quatized and not is_fp8:
|
||||
return False
|
||||
if is_block_quatized and not self.is_fe_block_fp8_supported():
|
||||
return False
|
||||
|
||||
# deep_gemm only works with block-quantized
|
||||
if self.needs_deep_gemm() and not is_block_quatized:
|
||||
return False
|
||||
|
||||
# Check dependencies
|
||||
if self.needs_deep_ep() and not has_deep_ep():
|
||||
return False
|
||||
if self.needs_deep_gemm() and not has_deep_gemm():
|
||||
return False
|
||||
if self.needs_pplx() and not has_pplx(): # noqa: SIM103
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@dataclass
|
||||
class WeightTensors:
|
||||
w1: torch.Tensor
|
||||
w2: torch.Tensor
|
||||
w1_scale: Optional[torch.Tensor]
|
||||
w2_scale: Optional[torch.Tensor]
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Weight Tensors: \n"
|
||||
s += f' - {_describe_tensor(self.w1, "w1")} \n'
|
||||
s += f' - {_describe_tensor(self.w2, "w2")} \n'
|
||||
s += f' - {_describe_tensor(self.w1_scale, "w1_scale")} \n'
|
||||
s += f' - {_describe_tensor(self.w2_scale, "w2_scale")} \n'
|
||||
return s
|
||||
|
||||
def to_current_device(self):
|
||||
self.w1 = self.w1.to(device=torch.cuda.current_device())
|
||||
self.w2 = self.w2.to(device=torch.cuda.current_device())
|
||||
is_quantized = self.w1.dtype == torch.float8_e4m3fn
|
||||
if is_quantized:
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
self.w1_scale = self.w1_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
self.w2_scale = self.w2_scale.to(
|
||||
device=torch.cuda.current_device())
|
||||
|
||||
def slice_weights(self, rank: int,
|
||||
num_local_experts: int) -> "WeightTensors":
|
||||
s = rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
w1 = self.w1[s:e, :, :]
|
||||
w2 = self.w2[s:e, :, :]
|
||||
is_quantized = self.w1.dtype == torch.float8_e4m3fn
|
||||
w1_scale, w2_scale = (None, None)
|
||||
if is_quantized:
|
||||
assert self.w1_scale is not None
|
||||
assert self.w2_scale is not None
|
||||
w1_scale = self.w1_scale[s:e, :, :]
|
||||
w2_scale = self.w2_scale[s:e, :, :]
|
||||
return WeightTensors(w1, w2, w1_scale, w2_scale)
|
||||
|
||||
@staticmethod
|
||||
def make(config: Config) -> "WeightTensors":
|
||||
|
||||
if config.quant_dtype is None:
|
||||
# just make normal dtype weights
|
||||
w1, w2 = make_non_quant_weights(e=config.E,
|
||||
n=config.N,
|
||||
k=config.K,
|
||||
dtype=config.dtype)
|
||||
return WeightTensors(w1=w1, w2=w2, w1_scale=None, w2_scale=None)
|
||||
|
||||
assert config.quant_dtype == torch.float8_e4m3fn
|
||||
if not config.is_fp8_block_quantized():
|
||||
w1, w2, w1_scale, w2_scale = make_quant_fp8_weights(
|
||||
e=config.E,
|
||||
n=config.N,
|
||||
k=config.K,
|
||||
per_out_channel_quant=config.is_per_out_ch_quant,
|
||||
)
|
||||
return WeightTensors(w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale)
|
||||
|
||||
assert config.quant_block_shape is not None
|
||||
w1, w2, w1_scale, w2_scale = make_block_quant_fp8_weights(
|
||||
e=config.E,
|
||||
n=config.N,
|
||||
k=config.K,
|
||||
block_size=config.quant_block_shape,
|
||||
)
|
||||
return WeightTensors(w1=w1,
|
||||
w2=w2,
|
||||
w1_scale=w1_scale,
|
||||
w2_scale=w2_scale)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RankTensors:
|
||||
hidden_states: torch.Tensor
|
||||
hidden_states_scale: Optional[torch.Tensor]
|
||||
|
||||
topk_weights: torch.Tensor
|
||||
topk_ids: torch.Tensor
|
||||
expert_map: Optional[torch.Tensor]
|
||||
|
||||
quant_config: Optional[FusedMoEQuantConfig]
|
||||
|
||||
def describe(self):
|
||||
s = ""
|
||||
s += "== Rank Tensors: \n"
|
||||
s += f' - {_describe_tensor(self.hidden_states, "HS")} \n'
|
||||
s += f' - {_describe_tensor(self.hidden_states_scale, "HS_scale")} \n'
|
||||
s += f' - {_describe_tensor(self.topk_weights, "topk_weights")} \n'
|
||||
s += f' - {_describe_tensor(self.topk_ids, "topk_ids")} \n'
|
||||
s += f' - {_describe_tensor(self.expert_map, "expert_map")} \n'
|
||||
return s
|
||||
|
||||
@staticmethod
|
||||
def make_hidden_states(
|
||||
config: Config) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Return hidden_states
|
||||
"""
|
||||
m, k, dtype = (config.M, config.K, config.dtype)
|
||||
a = (torch.randn(
|
||||
(m, k), device=torch.cuda.current_device(), dtype=dtype) / 15.0)
|
||||
|
||||
if config.quant_dtype is None:
|
||||
return a, None
|
||||
|
||||
# We dequant and use that as hidden_states so the tests are stable.
|
||||
# quantizing and dequantizing yield slightly different results
|
||||
# depending on the hardware. Here we quantize and dequantize
|
||||
# first - so further quantize and dequantize will yield the same
|
||||
# values.
|
||||
if config.is_per_tensor_act_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(
|
||||
a, use_per_token_if_dynamic=False)
|
||||
return a_q.float().mul(a_scales).to(dtype), a_scales
|
||||
|
||||
if config.is_per_act_token_quant:
|
||||
a_q, a_scales = ops.scaled_fp8_quant(a,
|
||||
use_per_token_if_dynamic=True)
|
||||
return a_q.float().mul(a_scales).to(dtype), None
|
||||
|
||||
assert config.quant_block_shape is not None
|
||||
block_k = config.quant_block_shape[1]
|
||||
a_q, a_scales = per_token_cast_to_fp8(a, block_size=block_k)
|
||||
return a_q.float().view(
|
||||
(-1, block_k)).mul(a_scales.view(-1, 1)).view(m, k).to(dtype), None
|
||||
|
||||
@staticmethod
|
||||
def make(config: Config, pgi: ProcessGroupInfo):
|
||||
|
||||
dtype = config.dtype
|
||||
topk, m, _ = (config.topk, config.M, config.K)
|
||||
hidden_states, hidden_states_scale = RankTensors.make_hidden_states(
|
||||
config)
|
||||
|
||||
num_local_experts, global_num_experts = (config.num_local_experts,
|
||||
config.E)
|
||||
score = torch.randn((m, global_num_experts),
|
||||
device="cuda",
|
||||
dtype=dtype)
|
||||
topk_weights, topk_ids, _ = fused_topk(hidden_states, score, topk,
|
||||
False)
|
||||
topk_ids = topk_ids.to(config.topk_ids_dtype)
|
||||
|
||||
# distribute topk_ids evenly
|
||||
for mi in range(m):
|
||||
topk_ids[mi] = torch.randperm(config.E)[:topk]
|
||||
topk_ids = topk_ids.to(device=torch.cuda.current_device())
|
||||
|
||||
expert_map = None
|
||||
if config.world_size > 1:
|
||||
expert_map = torch.full((global_num_experts, ),
|
||||
fill_value=-1,
|
||||
dtype=torch.int32)
|
||||
s = pgi.rank * num_local_experts
|
||||
e = s + num_local_experts
|
||||
expert_map[s:e] = torch.tensor(list(range(num_local_experts)))
|
||||
expert_map = expert_map.to(device=torch.cuda.current_device(),
|
||||
dtype=torch.int32)
|
||||
|
||||
return RankTensors(
|
||||
hidden_states=hidden_states,
|
||||
hidden_states_scale=hidden_states_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
quant_config=config.quant_config,
|
||||
)
|
||||
|
||||
|
||||
def reference_moe_impl(config: Config, weights: WeightTensors,
|
||||
rank_tensors: RankTensors) -> torch.Tensor:
|
||||
|
||||
return torch_experts(a=rank_tensors.hidden_states,
|
||||
w1=weights.w1,
|
||||
w2=weights.w2,
|
||||
topk_weight=rank_tensors.topk_weights,
|
||||
topk_ids=rank_tensors.topk_ids,
|
||||
global_num_experts=config.E,
|
||||
expert_map=None,
|
||||
w1_scale=weights.w1_scale,
|
||||
w2_scale=weights.w2_scale,
|
||||
a1_scale=rank_tensors.hidden_states_scale,
|
||||
quant_dtype=config.quant_dtype,
|
||||
per_act_token_quant=config.is_per_act_token_quant,
|
||||
block_shape=config.quant_block_shape,
|
||||
apply_router_weights_on_input=config.topk == 1)
|
||||
|
||||
|
||||
def make_fused_experts(
|
||||
config: Config, moe: FusedMoEConfig,
|
||||
num_dispatchers: int) -> mk.FusedMoEPermuteExpertsUnpermute:
|
||||
|
||||
use_fp8 = config.quant_dtype == torch.float8_e4m3fn
|
||||
batch_kwargs = {
|
||||
"max_num_tokens": moe.max_num_tokens,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
}
|
||||
quant_kwargs = {
|
||||
"use_fp8_w8a8": use_fp8,
|
||||
"use_int8_w8a8": False,
|
||||
"use_int8_w8a16": False,
|
||||
"use_int4_w4a16": False,
|
||||
"block_shape": config.quant_block_shape,
|
||||
"per_act_token_quant": config.is_per_act_token_quant,
|
||||
}
|
||||
deepgemm_kwargs = {"allow_deep_gemm": has_deep_gemm()}
|
||||
|
||||
if config.fused_experts_type == BatchedDeepGemmExperts:
|
||||
kwargs = batch_kwargs | {
|
||||
"block_shape": config.quant_block_shape,
|
||||
"per_act_token_quant": config.is_per_act_token_quant,
|
||||
}
|
||||
print(f"Making BatchedDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedDeepGemmExperts(**kwargs)
|
||||
elif config.fused_experts_type == BatchedTritonExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making BatchedTritonExperts {kwargs} ...")
|
||||
experts = BatchedTritonExperts(**kwargs)
|
||||
elif config.fused_experts_type == BatchedTritonOrDeepGemmExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs | deepgemm_kwargs
|
||||
print(f"Making BatchedTritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = BatchedTritonOrDeepGemmExperts(**kwargs)
|
||||
elif config.fused_experts_type == DeepGemmExperts:
|
||||
print("Making DeepGemmExperts () ...")
|
||||
experts = DeepGemmExperts()
|
||||
elif config.fused_experts_type == TritonExperts:
|
||||
kwargs = quant_kwargs
|
||||
print(f"Making TritonExperts {kwargs} ...")
|
||||
experts = TritonExperts(**kwargs)
|
||||
elif config.fused_experts_type == TritonOrDeepGemmExperts:
|
||||
kwargs = quant_kwargs | deepgemm_kwargs
|
||||
print(f"Making TritonOrDeepGemmExperts {kwargs} ...")
|
||||
experts = TritonOrDeepGemmExperts(**kwargs)
|
||||
elif config.fused_experts_type == NaiveBatchedExperts:
|
||||
kwargs = batch_kwargs | quant_kwargs
|
||||
print(f"Making NaiveBatchedExperts {kwargs} ...")
|
||||
experts = NaiveBatchedExperts(**kwargs)
|
||||
elif config.fused_experts_type == CutlassExpertsFp8:
|
||||
use_batched_format = config.is_batched_prepare_finalize()
|
||||
num_experts = (moe.num_local_experts
|
||||
if use_batched_format else moe.num_experts)
|
||||
kwargs = {
|
||||
"max_experts_per_worker": num_experts,
|
||||
"out_dtype": moe.in_dtype,
|
||||
"per_act_token_quant": config.is_per_act_token_quant,
|
||||
"per_out_ch_quant": config.is_per_out_ch_quant,
|
||||
"block_shape": config.quant_block_shape,
|
||||
"num_dispatchers": num_dispatchers,
|
||||
"use_batched_format": use_batched_format
|
||||
}
|
||||
print(f"Making CutlassExpertsFp8 {kwargs} ...")
|
||||
experts = CutlassExpertsFp8(**kwargs)
|
||||
|
||||
return experts
|
||||
|
||||
|
||||
def make_modular_kernel(config: Config,
|
||||
vllm_config: VllmConfig) -> mk.FusedMoEModularKernel:
|
||||
|
||||
def next_power_of_2(x):
|
||||
import math
|
||||
if x == 0:
|
||||
return 1
|
||||
return 2**math.ceil(math.log2(x))
|
||||
|
||||
# make moe config
|
||||
moe_parallel_config: FusedMoEParallelConfig = FusedMoEParallelConfig.make(
|
||||
tp_size_=get_tensor_model_parallel_world_size(),
|
||||
dp_size_=get_dp_group().world_size,
|
||||
vllm_parallel_config=vllm_config.parallel_config,
|
||||
)
|
||||
moe = FusedMoEConfig(
|
||||
num_experts=config.E,
|
||||
experts_per_token=config.topk,
|
||||
hidden_dim=config.K,
|
||||
num_local_experts=config.num_local_experts,
|
||||
moe_parallel_config=moe_parallel_config,
|
||||
in_dtype=config.dtype,
|
||||
quant_config=config.quant_config,
|
||||
max_num_tokens=next_power_of_2(config.M),
|
||||
)
|
||||
|
||||
# make modular kernel
|
||||
prepare_finalize = None
|
||||
if config.needs_all2all():
|
||||
prepare_finalize = FusedMoEMethodBase.maybe_make_prepare_finalize(moe)
|
||||
assert prepare_finalize is not None
|
||||
else:
|
||||
prepare_finalize = MoEPrepareAndFinalizeNoEP()
|
||||
|
||||
fused_experts = make_fused_experts(config, moe,
|
||||
prepare_finalize.num_dispatchers())
|
||||
|
||||
modular_kernel = mk.FusedMoEModularKernel(
|
||||
prepare_finalize=prepare_finalize, fused_experts=fused_experts)
|
||||
|
||||
return modular_kernel
|
||||
|
||||
|
||||
def run_modular_kernel(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
rank_tensors: RankTensors,
|
||||
) -> torch.Tensor:
|
||||
assert isinstance(config.Ms, int)
|
||||
assert isinstance(config.topks, int)
|
||||
|
||||
# weights for rank
|
||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||
|
||||
mk = make_modular_kernel(config, vllm_config)
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states": rank_tensors.hidden_states.clone(
|
||||
), # impls might update the tensor in place
|
||||
"w1": rank_weights.w1,
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": rank_tensors.topk_ids,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"w1_scale": rank_weights.w1_scale,
|
||||
"w2_scale": rank_weights.w2_scale,
|
||||
"a1_scale": rank_tensors.hidden_states_scale,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1,
|
||||
}
|
||||
out = mk.forward(**mk_kwargs)
|
||||
|
||||
return out
|
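To make the quant-related properties above concrete, here is a small sketch that builds a block-quantized FP8 Config and only inspects its properties; no GPU work is launched (assumes a vLLM checkout where these test modules are importable):

# Sketch: a block-quantized FP8 config; only Config properties are exercised.
import torch

from tests.kernels.moe.modular_kernel_tools.common import Config
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)

quant = FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                            per_act_token_quant=False,
                            per_out_ch_quant=False,
                            block_shape=[128, 128])
cfg = Config(Ms=64, K=7168, N=2048, E=32, topks=4, dtype=torch.bfloat16,
             quant_config=quant,
             prepare_finalize_type=MoEPrepareAndFinalizeNoEP,
             fused_experts_type=DeepGemmExperts,
             fused_moe_chunk_size=None,
             world_size=1)
print(cfg.quant_dtype, cfg.quant_block_shape)    # float8_e4m3fn, [128, 128]
print(cfg.is_fp8_block_quantized())              # True
print(cfg.is_per_tensor_act_quant)               # False: block-quantized
print(cfg.describe())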
tests/kernels/moe/modular_kernel_tools/make_feature_matrix.py (new file, 173 lines)
@@ -0,0 +1,173 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from enum import Enum
|
||||
from itertools import product
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import (Config, RankTensors, WeightTensors, reference_moe_impl,
|
||||
run_modular_kernel)
|
||||
from .mk_objects import (MK_FUSED_EXPERT_TYPES,
|
||||
MK_MULTI_GPU_PREPARE_FINALIZE_TYPES, MK_QUANT_CONFIGS)
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
class Result(Enum):
|
||||
PASS = 1
|
||||
FAIL = 2
|
||||
SKIP = 3
|
||||
|
||||
|
||||
def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
):
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
print(f"Running m={m}, topk={topk} ...")
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
|
||||
# modular kernel out
|
||||
mk_out = run_modular_kernel(pgi, vllm_config, cfgx, weights,
|
||||
rank_tensors)
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
ref_out = reference_moe_impl(cfgx, weights, rank_tensors)
|
||||
|
||||
torch.testing.assert_close(ref_out, mk_out, atol=3e-2, rtol=3e-2)
|
||||
|
||||
|
||||
def make_feature_matrix(csv_file_path: str):
|
||||
|
||||
from dataclasses import asdict
|
||||
|
||||
import pandas as pd
|
||||
|
||||
def add_to_results(config: Config,
|
||||
success: Result,
|
||||
results_df: Optional[pd.DataFrame] = None):
|
||||
config_dict = asdict(config)
|
||||
config_dict['prepare_finalize_type'] = config_dict[
|
||||
'prepare_finalize_type'].__name__
|
||||
config_dict['fused_experts_type'] = config_dict[
|
||||
'fused_experts_type'].__name__
|
||||
config_dict['per_tensor_act_quant'] = config.is_per_tensor_act_quant
|
||||
quant_config_dict = config_dict['quant_config']
|
||||
del config_dict['quant_config']
|
||||
if quant_config_dict is None:
|
||||
quant_config = FusedMoEQuantConfig(None)
|
||||
quant_config_dict = asdict(quant_config)
|
||||
|
||||
config_dict |= quant_config_dict
|
||||
result_dict = config_dict | {'success': success.name}
|
||||
|
||||
result_df = pd.DataFrame([result_dict])
|
||||
if results_df is None:
|
||||
results_df = result_df
|
||||
else:
|
||||
results_df = pd.concat([results_df, result_df], ignore_index=True)
|
||||
|
||||
return results_df
|
||||
|
||||
Ms = [64]
|
||||
Ks = [7168] # hidden sizes
|
||||
Ns = [2048]
|
||||
TOPKs = [[4, 1]]
|
||||
Es = [32]
|
||||
DTYPEs = [torch.bfloat16]
|
||||
PF_TYPES = MK_MULTI_GPU_PREPARE_FINALIZE_TYPES
|
||||
FE_TYPES = MK_FUSED_EXPERT_TYPES
|
||||
Q_TYPES = MK_QUANT_CONFIGS
|
||||
|
||||
combinations = list(
|
||||
product(Ms, Ks, Ns, Es, TOPKs, DTYPEs, PF_TYPES, FE_TYPES, Q_TYPES))
|
||||
|
||||
results_df: Optional[pd.DataFrame] = None
|
||||
for m, k, n, e, topks, dtype, pf_type, experts_type, quant_config in tqdm(
|
||||
combinations): #noqa: E501
|
||||
config = Config(Ms=[m],
|
||||
K=k,
|
||||
N=n,
|
||||
E=e,
|
||||
topks=topks,
|
||||
dtype=dtype,
|
||||
prepare_finalize_type=pf_type,
|
||||
fused_experts_type=experts_type,
|
||||
quant_config=quant_config,
|
||||
world_size=2,
|
||||
fused_moe_chunk_size=None)
|
||||
|
||||
success = None
|
||||
if config.is_valid():
|
||||
print(f"Running config : {config.describe()} ...")
|
||||
try:
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(config.world_size, rank_worker,
|
||||
vllm_config, env_dict, config,
|
||||
weights)
|
||||
success = Result.PASS
|
||||
except Exception as _:
|
||||
success = Result.FAIL
|
||||
else:
|
||||
success = Result.SKIP
|
||||
|
||||
results_df = add_to_results(config, success, results_df)
|
||||
|
||||
if results_df is not None:
|
||||
results_df.to_csv(f"{csv_file_path}")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
parser = argparse.ArgumentParser(description=(
|
||||
"Make ModularKernel feature matrix \n"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.make_feature_matrix " #noqa: E501
|
||||
"-f ./feature_matrices/feature_matrix.csv"))
|
||||
|
||||
parser.add_argument("-f",
|
||||
"--feature-matrix-csv-file-path",
|
||||
type=str,
|
||||
required=True,
|
||||
help="File name to Generate a .csv file")
|
||||
args = parser.parse_args()
|
||||
|
||||
csv_path = args.feature_matrix_csv_file_path
|
||||
assert csv_path.endswith(
|
||||
'.csv'), f"Need a file path ending with .csv, got {csv_path}"
|
||||
assert Path(csv_path).parent.is_dir(
|
||||
), f"Cannot find parent directory for {Path(csv_path).parent}"
|
||||
|
||||
make_feature_matrix(args.feature_matrix_csv_file_path)
|
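Once a CSV has been produced, it can be folded into a readable feature matrix with a short pandas snippet like the sketch below (the column names follow the asdict(Config) fields written by add_to_results; the path is illustrative):

# Sketch: pivot the results CSV into a prepare-finalize x fused-experts view.
import pandas as pd

df = pd.read_csv("feature_matrices/feature_matrix.csv")
matrix = df.pivot_table(index="prepare_finalize_type",
                        columns="fused_experts_type",
                        values="success",
                        aggfunc=lambda s: "/".join(sorted(set(s))))
print(matrix.to_string())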
tests/kernels/moe/modular_kernel_tools/mk_objects.py (new file, 87 lines)
@@ -0,0 +1,87 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

# Fused experts and PrepareFinalize imports
from vllm.model_executor.layers.fused_moe.batched_deep_gemm_moe import (
    BatchedDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.batched_triton_or_deep_gemm_moe import (  # noqa: E501
    BatchedTritonOrDeepGemmExperts)
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.cutlass_moe import CutlassExpertsFp8
from vllm.model_executor.layers.fused_moe.deep_gemm_moe import DeepGemmExperts
from vllm.model_executor.layers.fused_moe.fused_batched_moe import (
    BatchedTritonExperts, NaiveBatchedExperts)
from vllm.model_executor.layers.fused_moe.layer import TritonExperts
from vllm.model_executor.layers.fused_moe.prepare_finalize import (
    MoEPrepareAndFinalizeNoEP)
from vllm.model_executor.layers.fused_moe.triton_deep_gemm_moe import (
    TritonOrDeepGemmExperts)
from vllm.utils import has_deep_ep, has_pplx

if has_deep_ep():
    from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
        DeepEPHTPrepareAndFinalize)
    from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
        DeepEPLLPrepareAndFinalize)

if has_pplx():
    from vllm.model_executor.layers.fused_moe.pplx_prepare_finalize import (
        PplxPrepareAndFinalize)

MK_MULTI_GPU_PREPARE_FINALIZE_TYPES = []
if has_pplx():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [PplxPrepareAndFinalize]
if has_deep_ep():
    MK_MULTI_GPU_PREPARE_FINALIZE_TYPES += [
        DeepEPHTPrepareAndFinalize, DeepEPLLPrepareAndFinalize
    ]

MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES = [MoEPrepareAndFinalizeNoEP]

MK_ALL_PREPARE_FINALIZE_TYPES = (MK_MULTI_GPU_PREPARE_FINALIZE_TYPES +
                                 MK_SINGLE_GPU_PREPARE_FINALIZE_TYPES)

MK_FUSED_EXPERT_TYPES = [
    BatchedDeepGemmExperts,
    BatchedTritonExperts,
    NaiveBatchedExperts,
    BatchedTritonOrDeepGemmExperts,
    CutlassExpertsFp8,
    DeepGemmExperts,
    TritonOrDeepGemmExperts,
    TritonExperts,
]

MK_QUANT_CONFIGS = [
    None,
    # per-channel / per-column weights and per-tensor activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=True,
                        per_act_token_quant=False,
                        block_shape=None),
    # per-channel / per-column weights and per-token activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=True,
                        per_act_token_quant=True,
                        block_shape=None),
    # per-tensor weights and per-tensor activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=False,
                        per_act_token_quant=False,
                        block_shape=None),
    # per-tensor weights and per-token activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=False,
                        per_act_token_quant=True,
                        block_shape=None),
    # block-quantized weights and 128 block per-token activations
    FusedMoEQuantConfig(quant_dtype=torch.float8_e4m3fn,
                        per_out_ch_quant=False,
                        per_act_token_quant=False,
                        block_shape=[128, 128]),
    # TODO (varun) : Should we test the following combinations ?
    # block-quantized weights and per-token activations
    # block-quantized weights and per-tensor activations
]
|
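These lists are what the test and feature-matrix drivers iterate over. A minimal sketch of enumerating the combinations that Config.is_valid() accepts (note that Config's helper methods reference the pplx/DeepEP classes directly, so both optional backends need to be importable for this to run):

# Sketch: count the (prepare-finalize, fused-experts, quant) combinations the
# harness considers valid for a 2-GPU run.
from itertools import product

import torch

from tests.kernels.moe.modular_kernel_tools.common import Config
from tests.kernels.moe.modular_kernel_tools.mk_objects import (
    MK_ALL_PREPARE_FINALIZE_TYPES, MK_FUSED_EXPERT_TYPES, MK_QUANT_CONFIGS)

valid = []
for pf, fe, quant in product(MK_ALL_PREPARE_FINALIZE_TYPES,
                             MK_FUSED_EXPERT_TYPES, MK_QUANT_CONFIGS):
    cfg = Config(Ms=64, K=7168, N=2048, E=32, topks=4, dtype=torch.bfloat16,
                 quant_config=quant, prepare_finalize_type=pf,
                 fused_experts_type=fe, fused_moe_chunk_size=None,
                 world_size=2)
    if cfg.is_valid():
        valid.append((pf.__name__, fe.__name__))
print(f"{len(valid)} valid combinations")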
tests/kernels/moe/modular_kernel_tools/parallel_utils.py (new file, 138 lines)
@@ -0,0 +1,138 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import dataclasses
|
||||
import os
|
||||
import traceback
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
from torch.multiprocessing import (
|
||||
spawn) # pyright: ignore[reportPrivateImportUsage]
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from vllm.config import VllmConfig, set_current_vllm_config
|
||||
from vllm.distributed import (init_distributed_environment,
|
||||
initialize_model_parallel)
|
||||
from vllm.utils import get_open_port
|
||||
|
||||
## Parallel Processes Utils
|
||||
|
||||
P = ParamSpec("P")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProcessGroupInfo:
|
||||
world_size: int
|
||||
world_local_size: int
|
||||
rank: int
|
||||
node_rank: int
|
||||
local_rank: int
|
||||
device: torch.device
|
||||
|
||||
|
||||
def _set_vllm_config(vllm_config: VllmConfig, world_size: int, rank: int,
|
||||
local_rank: int):
|
||||
|
||||
import tempfile
|
||||
temp_file = tempfile.mkstemp()[1]
|
||||
|
||||
|
||||
with set_current_vllm_config(vllm_config):
|
||||
init_distributed_environment(
|
||||
world_size=world_size,
|
||||
rank=rank,
|
||||
distributed_init_method=f"file://{temp_file}",
|
||||
local_rank=local_rank,
|
||||
backend="nccl",
|
||||
)
|
||||
|
||||
initialize_model_parallel(
|
||||
tensor_model_parallel_size=vllm_config.parallel_config.
|
||||
tensor_parallel_size,
|
||||
pipeline_model_parallel_size=vllm_config.parallel_config.
|
||||
pipeline_parallel_size,
|
||||
)
|
||||
cpu_group = torch.distributed.new_group(list(range(world_size)),
|
||||
backend="gloo")
|
||||
return cpu_group
|
||||
|
||||
|
||||
def _worker_parallel_launch(
|
||||
local_rank: int,
|
||||
world_size: int,
|
||||
world_local_size: int,
|
||||
node_rank: int,
|
||||
init_method: str,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, Optional[VllmConfig], Any,
|
||||
P], None],
|
||||
vllm_config: Optional[VllmConfig],
|
||||
env_dict: Optional[dict],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
rank = node_rank * world_local_size + local_rank
|
||||
torch.cuda.set_device(local_rank)
|
||||
device = torch.device("cuda", local_rank)
|
||||
torch.distributed.init_process_group(
|
||||
backend="cpu:gloo,cuda:nccl",
|
||||
init_method=init_method,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
device_id=device,
|
||||
)
|
||||
barrier = torch.tensor([rank], device=device)
|
||||
torch.distributed.all_reduce(barrier)
|
||||
|
||||
if env_dict is not None:
|
||||
os.environ.update(env_dict)
|
||||
|
||||
cpu_group = None
|
||||
if vllm_config is not None:
|
||||
cpu_group = _set_vllm_config(vllm_config, world_size, rank, local_rank)
|
||||
|
||||
try:
|
||||
worker(
|
||||
ProcessGroupInfo(
|
||||
world_size=world_size,
|
||||
world_local_size=world_local_size,
|
||||
rank=rank,
|
||||
node_rank=node_rank,
|
||||
local_rank=local_rank,
|
||||
device=device,
|
||||
),
|
||||
vllm_config,
|
||||
cpu_group,
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as ex:
|
||||
print(ex)
|
||||
traceback.print_exc()
|
||||
raise
|
||||
finally:
|
||||
torch.distributed.destroy_process_group()
|
||||
|
||||
|
||||
def parallel_launch_with_config(
|
||||
world_size: int,
|
||||
worker: Callable[Concatenate[ProcessGroupInfo, VllmConfig, Any, P], None],
|
||||
vllm_config: VllmConfig,
|
||||
env_dict: dict[Any, Any],
|
||||
*args: P.args,
|
||||
**kwargs: P.kwargs,
|
||||
) -> None:
|
||||
assert not kwargs
|
||||
spawn(
|
||||
_worker_parallel_launch,
|
||||
args=(
|
||||
world_size,
|
||||
world_size,
|
||||
0,
|
||||
f"tcp://{os.getenv('LOCALHOST', 'localhost')}:{get_open_port()}",
|
||||
worker,
|
||||
vllm_config,
|
||||
env_dict,
|
||||
) + args,
|
||||
nprocs=world_size,
|
||||
join=True,
|
||||
)
|
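For reference, a minimal sketch of driving this launcher with a trivial worker (assumes two visible CUDA devices; the VllmConfig fields mirror Config.make_env_data() above):

# Sketch: spawn a 2-rank job through parallel_launch_with_config and print
# each rank's device. The worker must be a module-level (picklable) function.
import torch

from vllm.config import VllmConfig

from tests.kernels.moe.modular_kernel_tools.parallel_utils import (
    ProcessGroupInfo, parallel_launch_with_config)


def _worker(pgi: ProcessGroupInfo, vllm_config: VllmConfig, cpu_group) -> None:
    # The NCCL/gloo groups and vLLM model-parallel state are already set up.
    print(f"rank {pgi.rank}/{pgi.world_size} on {pgi.device}")
    torch.distributed.barrier()


if __name__ == "__main__":
    vllm_config = VllmConfig()
    vllm_config.parallel_config.data_parallel_size = 2
    parallel_launch_with_config(2, _worker, vllm_config, env_dict={})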
tests/kernels/moe/modular_kernel_tools/profile_modular_kernel.py (new file, 127 lines)
@@ -0,0 +1,127 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import copy
|
||||
from itertools import product
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
from .common import Config, RankTensors, WeightTensors, make_modular_kernel
|
||||
from .parallel_utils import ProcessGroupInfo, parallel_launch_with_config
|
||||
|
||||
|
||||
def do_profile(fn: Callable,
|
||||
fn_kwargs: dict[Any, Any],
|
||||
pgi: ProcessGroupInfo,
|
||||
config: Config,
|
||||
num_warmups: int = 5):
|
||||
for _ in range(num_warmups):
|
||||
fn(**fn_kwargs)
|
||||
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
with_stack=True,
|
||||
record_shapes=True,
|
||||
) as tprof:
|
||||
fn(**fn_kwargs)
|
||||
torch.cuda.synchronize(torch.cuda.current_device())
|
||||
|
||||
# TODO (varun): Add a descriptive trace file name
|
||||
tprof.export_chrome_trace(
|
||||
f"{config.torch_trace_dir_path}/m{config.M}_{pgi.rank}_trace.json")
|
||||
|
||||
|
||||
def profile_modular_kernel(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
rank_tensors: RankTensors,
|
||||
) -> None:
|
||||
assert isinstance(config.Ms, int)
|
||||
assert isinstance(config.topks, int)
|
||||
|
||||
# weights for rank
|
||||
rank_weights = weights.slice_weights(pgi.rank, config.num_local_experts)
|
||||
|
||||
# make modular kernel
|
||||
mk = make_modular_kernel(config, vllm_config)
|
||||
|
||||
mk_kwargs = {
|
||||
"hidden_states": rank_tensors.hidden_states,
|
||||
"w1": rank_weights.w1,
|
||||
"w2": rank_weights.w2,
|
||||
"topk_weights": rank_tensors.topk_weights,
|
||||
"topk_ids": rank_tensors.topk_ids,
|
||||
"expert_map": rank_tensors.expert_map,
|
||||
"w1_scale": rank_weights.w1_scale,
|
||||
"w2_scale": rank_weights.w2_scale,
|
||||
"a1_scale": rank_tensors.hidden_states_scale,
|
||||
"global_num_experts": config.E,
|
||||
"apply_router_weight_on_input": config.topk == 1,
|
||||
}
|
||||
|
||||
do_profile(mk.forward, mk_kwargs, pgi, config)
|
||||
|
||||
|
||||
def rank_worker(
|
||||
pgi: ProcessGroupInfo,
|
||||
vllm_config: VllmConfig,
|
||||
cpu_group,
|
||||
config: Config,
|
||||
weights: WeightTensors,
|
||||
):
|
||||
current_platform.seed_everything(pgi.rank)
|
||||
|
||||
# sanity check
|
||||
from vllm import envs
|
||||
if config.fused_moe_chunk_size is not None:
|
||||
assert (config.fused_moe_chunk_size == envs.VLLM_FUSED_MOE_CHUNK_SIZE)
|
||||
|
||||
# get weights to this device
|
||||
weights.to_current_device()
|
||||
|
||||
Ms = config.Ms
|
||||
assert isinstance(Ms, list)
|
||||
TOPKs = config.topks
|
||||
assert isinstance(TOPKs, list)
|
||||
|
||||
for m, topk in product(Ms, TOPKs):
|
||||
print(f"Running m={m}, topk={topk} ...")
|
||||
# override m and topk
|
||||
cfgx = copy.deepcopy(config)
|
||||
cfgx.Ms = m
|
||||
cfgx.topks = topk
|
||||
|
||||
# inputs for rank
|
||||
rank_tensors = RankTensors.make(cfgx, pgi)
|
||||
profile_modular_kernel(pgi, vllm_config, cfgx, weights, rank_tensors)
|
||||
|
||||
|
||||
def run(config: Config):
|
||||
weights: WeightTensors = WeightTensors.make(config)
|
||||
vllm_config, env_dict = config.make_env_data()
|
||||
parallel_launch_with_config(config.world_size, rank_worker, vllm_config,
|
||||
env_dict, config, weights)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
from .cli_args import make_config, make_config_arg_parser
|
||||
parser = make_config_arg_parser(description=(
|
||||
"Run single prepare-finalize & fused-experts combination test"
|
||||
"Example : python3 -m tests.kernels.moe.modular_kernel_tools.profile_modular_kernel " #noqa: E501
|
||||
"--pf-type PplxPrepareAndFinalize --experts-type BatchedTritonExperts"
|
||||
))
|
||||
args = parser.parse_args()
|
||||
assert args.torch_trace_dir_path is not None, (
|
||||
"Please pass in a directory to store torch traces")
|
||||
config = make_config(args)
|
||||
|
||||
run(config)
|
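The exported file is a standard Chrome trace, so besides chrome://tracing or Perfetto it can also be inspected programmatically. A small sketch (the path follows the m{M}_{rank}_trace.json pattern used by do_profile(); event category names vary across PyTorch versions):

# Sketch: count GPU kernel events in an exported torch profiler trace.
import json

with open("traces/m64_0_trace.json") as f:  # illustrative path
    events = json.load(f)["traceEvents"]
gpu_kernels = [e for e in events if "kernel" in str(e.get("cat", "")).lower()]
print(f"{len(gpu_kernels)} GPU kernel events recorded")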
tests/kernels/moe/modular_kernel_tools/utils.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import math
|
||||
|
||||
import torch
|
||||
|
||||
import vllm._custom_ops as ops
|
||||
|
||||
|
||||
def per_token_cast_to_fp8(
|
||||
x: torch.Tensor, block_size: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
pad_size = (block_size - (n % block_size)) % block_size
|
||||
x = torch.nn.functional.pad(x,
|
||||
(0, pad_size), value=0) if pad_size > 0 else x
|
||||
x_view = x.view(m, -1, block_size)
|
||||
x_amax = x_view.abs().float().amax(dim=2).view(m, -1).clamp(1e-4)
|
||||
fp8_data = (x_view * (448.0 / x_amax.unsqueeze(2))).to(torch.float8_e4m3fn)
|
||||
return fp8_data.view(m, n + pad_size)[:, :n], (x_amax / 448.0).view(m, -1)
|
||||
|
||||
|
||||
def per_block_cast_to_fp8(
|
||||
x: torch.Tensor, block_size_k: int,
|
||||
block_size_n: int) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
assert x.dim() == 2
|
||||
m, n = x.shape
|
||||
x_padded = torch.zeros(
|
||||
(
|
||||
int(math.ceil(m / block_size_k)) * block_size_k,
|
||||
int(math.ceil(n / block_size_n)) * block_size_n,
|
||||
),
|
||||
dtype=x.dtype,
|
||||
device=x.device,
|
||||
)
|
||||
x_padded[:m, :n] = x
|
||||
x_view = x_padded.view(-1, block_size_k,
|
||||
x_padded.size(1) // block_size_k, block_size_n)
|
||||
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
|
||||
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
|
||||
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
|
||||
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
|
||||
return x_scaled_sub, scales
|
||||
|
||||
|
||||
def make_non_quant_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
dtype: torch.dtype,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1, w2
|
||||
"""
|
||||
device = torch.cuda.current_device()
|
||||
w1 = torch.randn((e, 2 * n, k), device=device, dtype=dtype) / 15
|
||||
w2 = torch.randn((e, k, n), device=device, dtype=dtype) / 15
|
||||
return w1, w2
|
||||
|
||||
|
||||
def make_block_quant_fp8_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
block_size: list[int],
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return weights w1, w2, w1_scale, w2_scale
|
||||
"""
|
||||
dtype = torch.bfloat16
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
fp8_info = torch.finfo(torch.float8_e4m3fn)
|
||||
fp8_max, fp8_min = fp8_info.max, fp8_info.min
|
||||
|
||||
w1_bf16, w2_bf16 = make_non_quant_weights(e, n, k, dtype)
|
||||
w1_bf16 = w1_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
||||
w2_bf16 = w2_bf16.clamp(min=fp8_min, max=fp8_max).to(dtype=dtype)
|
||||
|
||||
block_n, block_k = block_size[0], block_size[1]
|
||||
n_tiles_w1 = ((2 * n) + block_n - 1) // block_n
|
||||
k_tiles_w1 = (k + block_k - 1) // block_k
|
||||
n_tiles_w2 = (k + block_n - 1) // block_n
|
||||
k_tiles_w2 = (n + block_k - 1) // block_k
|
||||
|
||||
w1 = torch.empty_like(w1_bf16, dtype=torch.float8_e4m3fn, device=device)
|
||||
w2 = torch.empty_like(w2_bf16, dtype=torch.float8_e4m3fn, device=device)
|
||||
|
||||
w1_s = torch.empty((e, n_tiles_w1, k_tiles_w1),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
w2_s = torch.empty((e, n_tiles_w2, k_tiles_w2),
|
||||
device=device,
|
||||
dtype=torch.float32)
|
||||
|
||||
assert w1_s.shape == (e, (2 * n + (block_n - 1)) // block_n,
|
||||
(k + (block_k - 1)) // block_k)
|
||||
assert (w2.shape[-2] + block_n - 1) // block_n == w2_s.shape[-2]
|
||||
|
||||
for i in range(e):
|
||||
w1[i], w1_s[i] = per_block_cast_to_fp8(w1_bf16[i],
|
||||
block_size_k=block_k,
|
||||
block_size_n=block_n)
|
||||
w2[i], w2_s[i] = per_block_cast_to_fp8(w2_bf16[i],
|
||||
block_size_k=block_k,
|
||||
block_size_n=block_n)
|
||||
|
||||
return w1, w2, w1_s, w2_s
|
||||
|
||||
|
||||
def make_quant_fp8_weights(
|
||||
e: int,
|
||||
n: int,
|
||||
k: int,
|
||||
per_out_channel_quant: bool,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Return w1, w2, w1_scale, w2_scale
|
||||
"""
|
||||
q_dtype = torch.float8_e4m3fn
|
||||
|
||||
w1, w2 = make_non_quant_weights(e, n, k, dtype=torch.bfloat16)
|
||||
|
||||
# w1 -> w1_q, w2 -> w2_q
|
||||
w1_q = torch.empty((e, 2 * n, k), device="cuda", dtype=q_dtype)
|
||||
w2_q = torch.empty((e, k, n), device="cuda", dtype=q_dtype)
|
||||
|
||||
n_b_scales = 2 * n if per_out_channel_quant else 1
|
||||
k_b_scales = k if per_out_channel_quant else 1
|
||||
w1_scale = torch.empty((e, n_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
w2_scale = torch.empty((e, k_b_scales, 1),
|
||||
device="cuda",
|
||||
dtype=torch.float32)
|
||||
|
||||
for expert in range(e):
|
||||
w1_q[expert], w1_scale[expert] = ops.scaled_fp8_quant(
|
||||
w1[expert], use_per_token_if_dynamic=per_out_channel_quant)
|
||||
w2_q[expert], w2_scale[expert] = ops.scaled_fp8_quant(
|
||||
w2[expert], use_per_token_if_dynamic=per_out_channel_quant)
|
||||
return w1_q, w2_q, w1_scale, w2_scale
|
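per_token_cast_to_fp8() above is the per-token/per-block activation quantization path used by RankTensors.make_hidden_states(). A tiny round-trip sketch (assumes a PyTorch build with float8_e4m3fn support; runs on CPU):

# Sketch: quantize in 128-wide blocks along the last dim, then dequantize.
import torch

from tests.kernels.moe.modular_kernel_tools.utils import per_token_cast_to_fp8

x = torch.randn(4, 256, dtype=torch.bfloat16)
x_fp8, scales = per_token_cast_to_fp8(x, block_size=128)
assert x_fp8.dtype == torch.float8_e4m3fn and scales.shape == (4, 2)

# Dequantize: scale each 128-wide block back by its per-block scale.
x_back = (x_fp8.to(torch.float32).view(4, -1, 128) *
          scales.unsqueeze(2)).view(4, 256)
print((x_back - x.float()).abs().max())  # error at the FP8 rounding level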
@@ -4,7 +4,6 @@
 DeepEP test utilities
 """
 import dataclasses
-import importlib
 import os
 import traceback
 from typing import Callable, Optional
@@ -15,10 +14,9 @@ from torch.multiprocessing import (
     spawn)  # pyright: ignore[reportPrivateImportUsage]
 from typing_extensions import Concatenate, ParamSpec
 
-from vllm.utils import get_open_port
+from vllm.utils import get_open_port, has_deep_ep
 
-has_deep_ep = importlib.util.find_spec("deep_ep") is not None
-if has_deep_ep:
+if has_deep_ep():
     from vllm.model_executor.layers.fused_moe.deepep_ht_prepare_finalize import (  # noqa: E501
         DeepEPHTPrepareAndFinalize)
     from vllm.model_executor.layers.fused_moe.deepep_ll_prepare_finalize import (  # noqa: E501
|
Some files were not shown because too many files have changed in this diff.