mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 23:03:52 +08:00
Compare commits
43 Commits
Author | SHA1 | Date | |
---|---|---|---|
31c1f3255e | |||
21d93c140d | |||
f1c8520146 | |||
096827c284 | |||
6565d9e33e | |||
f375ec8440 | |||
518369d78c | |||
30bad5c492 | |||
3fefe271ec | |||
6428f1d051 | |||
7e1b21daac | |||
cb3f30c600 | |||
f3e024bece | |||
31d2ab4aff | |||
eb17212858 | |||
4dd4b5c538 | |||
6120e5aaea | |||
2eaa81b236 | |||
81ce2a4b26 | |||
5dd80d3777 | |||
beeee69bc9 | |||
9bf28d0b69 | |||
c0ce15dfb2 | |||
b9bcdc7158 | |||
4ff0203987 | |||
b5f882cc98 | |||
2e8fc0d4c3 | |||
dacaf5a400 | |||
24cde76a15 | |||
1aa1361510 | |||
fe470ae5ad | |||
3a8c2381f7 | |||
c85b80c2b6 | |||
2b981012a6 | |||
6ccc0bfffb | |||
c8e7eb1eb3 | |||
24f60a54f4 | |||
42c02f5892 | |||
ebede26ebf | |||
d940ce497e | |||
05ff90b692 | |||
1d9b737e05 | |||
60dc62dc9e |
2
.github/workflows/publish.yml
vendored
2
.github/workflows/publish.yml
vendored
@ -49,7 +49,7 @@ jobs:
|
||||
matrix:
|
||||
os: ['ubuntu-20.04']
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
pytorch-version: ['2.1.0']
|
||||
pytorch-version: ['2.1.1']
|
||||
cuda-version: ['11.8', '12.1']
|
||||
|
||||
steps:
|
||||
|
4
.gitignore
vendored
4
.gitignore
vendored
@ -177,3 +177,7 @@ _build/
|
||||
# vim swap files
|
||||
*.swo
|
||||
*.swp
|
||||
|
||||
# hip files generated by PyTorch
|
||||
*.hip
|
||||
*_hip*
|
||||
|
11
Dockerfile
11
Dockerfile
@ -30,8 +30,15 @@ COPY requirements.txt requirements.txt
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY vllm/__init__.py vllm/__init__.py
|
||||
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# max jobs used by Ninja to build extensions
|
||||
ENV MAX_JOBS=$max_jobs
|
||||
ARG max_jobs=2
|
||||
ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
|
||||
RUN python3 setup.py build_ext --inplace
|
||||
|
||||
# image to run unit testing suite
|
||||
@ -68,7 +75,7 @@ ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]
|
||||
FROM vllm-base AS vllm-openai
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate fschat
|
||||
pip install accelerate
|
||||
|
||||
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
|
||||
COPY vllm vllm
|
||||
|
62
Dockerfile.rocm
Normal file
62
Dockerfile.rocm
Normal file
@ -0,0 +1,62 @@
|
||||
FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install python3 python3-pip -y
|
||||
|
||||
# Install some basic utilities
|
||||
RUN apt-get update && apt-get install -y \
|
||||
curl \
|
||||
ca-certificates \
|
||||
sudo \
|
||||
git \
|
||||
bzip2 \
|
||||
libx11-6 \
|
||||
build-essential \
|
||||
wget \
|
||||
unzip \
|
||||
nvidia-cuda-toolkit \
|
||||
tmux \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
### Mount Point ###
|
||||
# When launching the container, mount the code directory to /app
|
||||
ARG APP_MOUNT=/app
|
||||
VOLUME [ ${APP_MOUNT} ]
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas
|
||||
|
||||
ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
|
||||
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
|
||||
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
|
||||
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:
|
||||
|
||||
# Install ROCm flash-attention
|
||||
RUN mkdir libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout 3d2b6f5 \
|
||||
&& git submodule update --init \
|
||||
&& export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \
|
||||
&& patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..
|
||||
|
||||
COPY ./ /app/vllm
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN pip install xformers==0.0.23 --no-deps
|
||||
|
||||
RUN cd /app \
|
||||
&& cd vllm \
|
||||
&& pip install -U -r requirements-rocm.txt \
|
||||
&& bash patch_xformers-0.0.23.rocm.sh \
|
||||
&& python3 setup.py install \
|
||||
&& cd ..
|
||||
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
RUN python3 -m pip install --no-cache-dir ray[all]
|
||||
|
||||
CMD ["/bin/bash"]
|
@ -17,6 +17,7 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2023/12] Added ROCm support to vLLM.
|
||||
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
|
||||
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
|
||||
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
|
||||
@ -43,6 +44,7 @@ vLLM is flexible and easy to use with:
|
||||
- Tensor parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA CUDA and AMD ROCm.
|
||||
|
||||
vLLM seamlessly supports many Hugging Face models, including the following architectures:
|
||||
|
||||
@ -58,6 +60,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi
|
||||
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
|
||||
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
|
||||
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
|
||||
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
|
||||
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
|
||||
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
|
||||
- Phi-1.5 (`microsoft/phi-1_5`, etc.)
|
||||
|
@ -1,6 +1,8 @@
|
||||
"""Benchmark the latency of processing a single batch of requests."""
|
||||
import argparse
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -34,12 +36,15 @@ def main(args: argparse.Namespace):
|
||||
print(sampling_params)
|
||||
dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
|
||||
|
||||
def run_to_completion(profile: bool = False):
|
||||
if profile:
|
||||
with torch.profiler.profile(activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
]) as p:
|
||||
def run_to_completion(profile_dir: Optional[str] = None):
|
||||
if profile_dir:
|
||||
with torch.profiler.profile(
|
||||
activities=[
|
||||
torch.profiler.ProfilerActivity.CPU,
|
||||
torch.profiler.ProfilerActivity.CUDA,
|
||||
],
|
||||
on_trace_ready=torch.profiler.tensorboard_trace_handler(
|
||||
str(profile_dir))) as p:
|
||||
llm.generate(prompt_token_ids=dummy_prompt_token_ids,
|
||||
sampling_params=sampling_params,
|
||||
use_tqdm=False)
|
||||
@ -54,17 +59,20 @@ def main(args: argparse.Namespace):
|
||||
return latency
|
||||
|
||||
print("Warming up...")
|
||||
run_to_completion(profile=False)
|
||||
run_to_completion(profile_dir=None)
|
||||
|
||||
if args.profile:
|
||||
print("Profiling...")
|
||||
run_to_completion(profile=True)
|
||||
profile_dir = args.profile_result_dir
|
||||
if not profile_dir:
|
||||
profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}"
|
||||
print(f"Profiling (results will be saved to '{profile_dir}')...")
|
||||
run_to_completion(profile_dir=args.profile_result_dir)
|
||||
return
|
||||
|
||||
# Benchmark.
|
||||
latencies = []
|
||||
for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
|
||||
latencies.append(run_to_completion(profile=False))
|
||||
latencies.append(run_to_completion(profile_dir=None))
|
||||
print(f'Avg latency: {np.mean(latencies)} seconds')
|
||||
|
||||
|
||||
@ -107,5 +115,13 @@ if __name__ == '__main__':
|
||||
'--profile',
|
||||
action='store_true',
|
||||
help='profile the generation process of a single batch')
|
||||
parser.add_argument(
|
||||
'--profile-result-dir',
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
'path to save the pytorch profiler output. Can be visualized '
|
||||
'with ui.perfetto.dev or Tensorboard.'
|
||||
))
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
@ -37,10 +37,6 @@ def main(
|
||||
query.uniform_(-scale, scale)
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
@ -103,7 +99,7 @@ def main(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
@ -120,7 +116,7 @@ def main(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
@ -18,8 +19,8 @@ __global__ void silu_and_mul_kernel(
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
|
||||
const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
|
||||
out[token_idx * d + idx] = silu(x) * y;
|
||||
}
|
||||
}
|
||||
@ -57,7 +58,7 @@ __global__ void activation_kernel(
|
||||
const int d) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
|
||||
const scalar_t x = __ldg(&input[token_idx * d + idx]);
|
||||
const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
|
||||
out[token_idx * d + idx] = ACT_FN(x);
|
||||
}
|
||||
}
|
||||
|
@ -15,6 +15,10 @@
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
@ -23,7 +27,11 @@
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
@ -40,7 +48,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Compute the sum per warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Warp leaders store the data to shared memory.
|
||||
@ -59,11 +67,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
|
||||
// Parallel reduction inside the warp.
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
|
||||
sum += VLLM_SHFL_XOR_SYNC(sum, mask);
|
||||
}
|
||||
|
||||
// Broadcast to other threads.
|
||||
return __shfl_sync(uint32_t(-1), sum, 0);
|
||||
return VLLM_SHFL_SYNC(sum, 0);
|
||||
}
|
||||
|
||||
// TODO(woosuk): Merge the last two dimensions of the grid.
|
||||
@ -81,7 +89,7 @@ __device__ void paged_attention_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -124,7 +132,8 @@ __device__ void paged_attention_kernel(
|
||||
|
||||
const int head_idx = blockIdx.x;
|
||||
const int num_heads = gridDim.x;
|
||||
const int kv_head_idx = head_mapping[head_idx];
|
||||
const int num_queries_per_kv = num_heads / num_kv_heads;
|
||||
const int kv_head_idx = head_idx / num_queries_per_kv;
|
||||
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
|
||||
|
||||
// A vector type to store a part of a key or a query.
|
||||
@ -223,7 +232,7 @@ __device__ void paged_attention_kernel(
|
||||
// The 0-th thread of each thread group already has its max qk value.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = qk_max;
|
||||
@ -235,10 +244,10 @@ __device__ void paged_attention_kernel(
|
||||
qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
|
||||
qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
|
||||
}
|
||||
// Broadcast the max qk value to all threads.
|
||||
qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
|
||||
qk_max = VLLM_SHFL_SYNC(qk_max, 0);
|
||||
|
||||
// Get the sum of the exp values.
|
||||
float exp_sum = 0.f;
|
||||
@ -326,7 +335,7 @@ __device__ void paged_attention_kernel(
|
||||
float acc = accs[i];
|
||||
#pragma unroll
|
||||
for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
|
||||
acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
|
||||
acc += VLLM_SHFL_XOR_SYNC(acc, mask);
|
||||
}
|
||||
accs[i] = acc;
|
||||
}
|
||||
@ -393,7 +402,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -404,7 +413,7 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr,
|
||||
out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
|
||||
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
|
||||
@ -422,7 +431,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
|
||||
const scalar_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, head_size/x, block_size, x]
|
||||
const scalar_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, head_size, block_size]
|
||||
const int* __restrict__ head_mapping, // [num_heads]
|
||||
const int num_kv_heads, // [num_heads]
|
||||
const float scale,
|
||||
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
const int* __restrict__ context_lens, // [num_seqs]
|
||||
@ -432,7 +441,7 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int kv_block_stride,
|
||||
const int kv_head_stride) {
|
||||
paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
|
||||
q_stride, kv_block_stride, kv_head_stride);
|
||||
}
|
||||
@ -492,7 +501,7 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
// Reduce within the warp.
|
||||
#pragma unroll
|
||||
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
if (lane == 0) {
|
||||
red_smem[warp_idx] = max_logit;
|
||||
@ -502,10 +511,10 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
|
||||
#pragma unroll
|
||||
for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
|
||||
max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
|
||||
max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
|
||||
}
|
||||
// Broadcast the max value to all threads.
|
||||
max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
|
||||
max_logit = VLLM_SHFL_SYNC(max_logit, 0);
|
||||
|
||||
// Load rescaled exp sums to shared memory.
|
||||
float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
|
||||
@ -539,16 +548,16 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
} // namespace vllm
|
||||
|
||||
#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \
|
||||
cudaFuncSetAttribute( \
|
||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size); \
|
||||
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \
|
||||
((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>), \
|
||||
shared_mem_size); \
|
||||
vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS> \
|
||||
<<<grid, block, shared_mem_size, stream>>>( \
|
||||
out_ptr, \
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
head_mapping_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
@ -568,7 +577,7 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@ -594,7 +603,6 @@ void paged_attention_v1_launcher(
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@ -643,7 +651,7 @@ void paged_attention_v1_launcher(
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
head_mapping, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
@ -673,7 +681,7 @@ void paged_attention_v1(
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& head_mapping, // [num_heads]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
@ -700,7 +708,7 @@ void paged_attention_v1(
|
||||
query_ptr, \
|
||||
key_cache_ptr, \
|
||||
value_cache_ptr, \
|
||||
head_mapping_ptr, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables_ptr, \
|
||||
context_lens_ptr, \
|
||||
@ -731,7 +739,7 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@ -760,7 +768,6 @@ void paged_attention_v2_launcher(
|
||||
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
|
||||
T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
|
||||
T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
|
||||
int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
|
||||
int* block_tables_ptr = block_tables.data_ptr<int>();
|
||||
int* context_lens_ptr = context_lens.data_ptr<int>();
|
||||
|
||||
@ -815,7 +822,7 @@ void paged_attention_v2_launcher(
|
||||
query, \
|
||||
key_cache, \
|
||||
value_cache, \
|
||||
head_mapping, \
|
||||
num_kv_heads, \
|
||||
scale, \
|
||||
block_tables, \
|
||||
context_lens, \
|
||||
@ -848,7 +855,7 @@ void paged_attention_v2(
|
||||
torch::Tensor& query, // [num_seqs, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
|
||||
torch::Tensor& value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& head_mapping, // [num_heads]
|
||||
int num_kv_heads, // [num_heads]
|
||||
float scale,
|
||||
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
|
||||
torch::Tensor& context_lens, // [num_seqs]
|
||||
|
@ -17,6 +17,7 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "../cuda_compat.h"
|
||||
#include "attention_dtypes.h"
|
||||
|
||||
#include <float.h>
|
||||
@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
|
||||
float qk = sum(qk_vec);
|
||||
#pragma unroll
|
||||
for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
|
||||
qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
|
||||
qk += VLLM_SHFL_XOR_SYNC(qk, mask);
|
||||
}
|
||||
return qk;
|
||||
}
|
||||
|
@ -21,8 +21,17 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
typedef __hip_bfloat162 __nv_bfloat162;
|
||||
typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
assert(false);
|
||||
#else
|
||||
return a + b;
|
||||
#ifndef USE_ROCM
|
||||
return a + b;
|
||||
#else
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
}
|
||||
|
||||
|
@ -21,6 +21,10 @@
|
||||
#include "attention_generic.cuh"
|
||||
#include "dtype_float32.cuh"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_fp16.h>
|
||||
#endif
|
||||
|
||||
#include <stdint.h>
|
||||
|
||||
namespace vllm {
|
||||
@ -63,21 +67,47 @@ struct FloatVec<uint4> {
|
||||
|
||||
// Utility functions for type conversions.
|
||||
inline __device__ uint32_t h0_h0(uint16_t a) {
|
||||
#ifndef USE_ROCM
|
||||
uint32_t b;
|
||||
asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
|
||||
return b;
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u16[0] = a;
|
||||
tmp.u16[1] = a;
|
||||
return tmp.u32;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ float half_to_float(uint16_t h) {
|
||||
float f;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
|
||||
#else
|
||||
asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
|
||||
#endif
|
||||
return f;
|
||||
}
|
||||
|
||||
inline __device__ float2 half2_to_float2(uint32_t v) {
|
||||
#ifndef USE_ROCM
|
||||
uint16_t lo, hi;
|
||||
asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
|
||||
return make_float2(half_to_float(lo), half_to_float(hi));
|
||||
#else
|
||||
union {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
tmp.u32 = v;
|
||||
float2 ret;
|
||||
ret.x = half_to_float(tmp.u16[0]);
|
||||
ret.y = half_to_float(tmp.u16[1]);
|
||||
return ret;
|
||||
#endif
|
||||
}
|
||||
|
||||
inline __device__ uint16_t float_to_half(float f) {
|
||||
@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
|
||||
#else
|
||||
asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
|
||||
#endif
|
||||
return tmp.u16[0];
|
||||
}
|
||||
|
||||
@ -94,12 +128,16 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
uint32_t u32;
|
||||
uint16_t u16[2];
|
||||
} tmp;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#ifndef USE_ROCM
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
#endif
|
||||
#else
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
|
||||
asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
|
||||
tmp.u16[0] = float_to_half(f.x);
|
||||
tmp.u16[1] = float_to_half(f.y);
|
||||
#endif
|
||||
return tmp.u32;
|
||||
}
|
||||
@ -107,13 +145,21 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
|
||||
// Vector addition.
|
||||
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
inline __device__ uint32_t add(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
|
||||
template<>
|
||||
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
|
||||
uint16_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
|
||||
#else
|
||||
asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
template<>
|
||||
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
|
||||
uint32_t c;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
|
||||
#else
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
|
||||
#endif
|
||||
return c;
|
||||
}
|
||||
|
||||
@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
|
||||
// Vector fused multiply-add.
|
||||
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
|
||||
uint32_t d;
|
||||
#ifndef USE_ROCM
|
||||
asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
|
||||
#else
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
|
||||
#endif
|
||||
return d;
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
@ -28,8 +29,8 @@ void swap_blocks(
|
||||
TORCH_CHECK(false, "Invalid device combination");
|
||||
}
|
||||
|
||||
void *src_ptr = src.data_ptr();
|
||||
void *dst_ptr = dst.data_ptr();
|
||||
char *src_ptr = static_cast<char*>(src.data_ptr());
|
||||
char *dst_ptr = static_cast<char*>(dst.data_ptr());
|
||||
|
||||
const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
@ -267,8 +268,8 @@ __global__ void gather_cached_kv_kernel(
|
||||
+ head_offset * block_size
|
||||
+ block_offset;
|
||||
|
||||
key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
|
||||
key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
}
|
||||
|
||||
@ -333,8 +334,8 @@ __global__ void gather_cached_kv_kernel_optimized(
|
||||
src_key_indices[j] = src_key_idx;
|
||||
src_value_indices[j] = src_value_idx;
|
||||
|
||||
keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = __ldg(&value_cache[src_value_idx]);
|
||||
keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
|
||||
values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
|
||||
}
|
||||
|
||||
#pragma unroll
|
||||
|
28
csrc/cuda_compat.h
Normal file
28
csrc/cuda_compat.h
Normal file
@ -0,0 +1,28 @@
|
||||
#pragma once
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_LDG(arg) __ldg(arg)
|
||||
#else
|
||||
#define VLLM_LDG(arg) *(arg)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
|
||||
#else
|
||||
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
|
||||
#else
|
||||
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#else
|
||||
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
|
||||
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
|
||||
#endif
|
||||
|
@ -1,3 +1,6 @@
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
int get_device_attribute(
|
||||
int attribute,
|
||||
int device_id)
|
||||
|
@ -5,7 +5,7 @@ void paged_attention_v1(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@ -21,7 +21,7 @@ void paged_attention_v2(
|
||||
torch::Tensor& query,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& head_mapping,
|
||||
int num_kv_heads,
|
||||
float scale,
|
||||
torch::Tensor& block_tables,
|
||||
torch::Tensor& context_lens,
|
||||
@ -61,12 +61,14 @@ void gelu_fast(
|
||||
torch::Tensor& out,
|
||||
torch::Tensor& input);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor awq_gemm(
|
||||
torch::Tensor _in_feats,
|
||||
torch::Tensor _kernel,
|
||||
torch::Tensor _scaling_factors,
|
||||
torch::Tensor _zeros,
|
||||
int split_k_iters);
|
||||
#endif
|
||||
|
||||
void squeezellm_gemm(
|
||||
torch::Tensor vec,
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <torch/extension.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
namespace vllm {
|
||||
@ -19,14 +20,14 @@ inline __device__ void apply_rotary_embedding(
|
||||
// GPT-NeoX style rotary embedding.
|
||||
x_index = rot_offset;
|
||||
y_index = embed_dim + rot_offset;
|
||||
cos = __ldg(cos_ptr + x_index);
|
||||
sin = __ldg(sin_ptr + x_index);
|
||||
cos = VLLM_LDG(cos_ptr + x_index);
|
||||
sin = VLLM_LDG(sin_ptr + x_index);
|
||||
} else {
|
||||
// GPT-J style rotary embedding.
|
||||
x_index = 2 * rot_offset;
|
||||
y_index = 2 * rot_offset + 1;
|
||||
cos = __ldg(cos_ptr + x_index / 2);
|
||||
sin = __ldg(sin_ptr + x_index / 2);
|
||||
cos = VLLM_LDG(cos_ptr + x_index / 2);
|
||||
sin = VLLM_LDG(sin_ptr + x_index / 2);
|
||||
}
|
||||
|
||||
const scalar_t x = arr[x_index];
|
||||
|
@ -48,8 +48,12 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
||||
&rotary_embedding,
|
||||
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Quantization ops
|
||||
ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
|
||||
#endif
|
||||
|
||||
|
||||
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
|
||||
|
||||
// Cache ops
|
||||
|
@ -20,9 +20,17 @@ __device__ inline unsigned int as_unsigned(int i) {
|
||||
|
||||
// 4-bit matvec kernel (LUT-based)
|
||||
__global__ void NUQ4MatMulKernel(
|
||||
#ifndef USE_ROCM
|
||||
const half2* __restrict__ vec,
|
||||
#else
|
||||
const __half2* __restrict__ vec,
|
||||
#endif
|
||||
const int* __restrict__ mat,
|
||||
#ifndef USE_ROCM
|
||||
half2* __restrict__ mul,
|
||||
#else
|
||||
float2* __restrict__ mul,
|
||||
#endif
|
||||
const __half* __restrict__ lookup_table,
|
||||
int height,
|
||||
int width,
|
||||
@ -35,7 +43,11 @@ __global__ void NUQ4MatMulKernel(
|
||||
int row = BLOCKHEIGHT4 * blockIdx.x;
|
||||
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
|
||||
|
||||
#ifndef USE_ROCM
|
||||
__shared__ half2 blockvec[blockwidth2];
|
||||
#else
|
||||
__shared__ __half2 blockvec[blockwidth2];
|
||||
#endif
|
||||
|
||||
__shared__ __half deq2[16][BLOCKWIDTH];
|
||||
int off = threadIdx.x;
|
||||
@ -46,8 +58,13 @@ __global__ void NUQ4MatMulKernel(
|
||||
}
|
||||
|
||||
__half res;
|
||||
#ifndef USE_ROCM
|
||||
half2 res2;
|
||||
half2 tmp2;
|
||||
#else
|
||||
__half2 res2;
|
||||
__half2 tmp2;
|
||||
#endif
|
||||
|
||||
int i;
|
||||
int k;
|
||||
@ -68,48 +85,96 @@ __global__ void NUQ4MatMulKernel(
|
||||
while (k < blockwidth2) {
|
||||
tmp1 = as_unsigned(mat[i]);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res2 = {};
|
||||
tmp2 = {};
|
||||
#else
|
||||
res2.x = __half_as_ushort(__float2half(0));
|
||||
res2.y = __half_as_ushort(__float2half(0));
|
||||
tmp2.x = __half_as_ushort(__float2half(0));
|
||||
tmp2.y = __half_as_ushort(__float2half(0));
|
||||
#endif
|
||||
|
||||
lut_index1 = tmp1 & 0xF;
|
||||
lut_index2 = (tmp1 >> 4) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 8) & 0xF;
|
||||
lut_index2 = (tmp1 >> 12) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 16) & 0xF;
|
||||
lut_index2 = (tmp1 >> 20) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
|
||||
|
||||
lut_index1 = (tmp1 >> 24) & 0xF;
|
||||
lut_index2 = (tmp1 >> 28) & 0xF;
|
||||
#ifndef USE_ROCM
|
||||
tmp2.x = deq2[lut_index1][off];
|
||||
tmp2.y = deq2[lut_index2][off];
|
||||
#else
|
||||
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
|
||||
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
|
||||
#endif
|
||||
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
res = __hadd(__hadd(res2.x, res2.y), res);
|
||||
#else
|
||||
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
|
||||
#endif
|
||||
|
||||
i += width;
|
||||
k += 4;
|
||||
}
|
||||
|
||||
// col%2 -> only set one of the two values
|
||||
#ifndef USE_ROCM
|
||||
half2 res3 = {};
|
||||
if (col % 2 == 0) {
|
||||
res3.x = res;
|
||||
} else {
|
||||
res3.y = res;
|
||||
}
|
||||
#else
|
||||
__half2 res3;
|
||||
res3.x = __half_as_ushort(__float2half(0));
|
||||
res3.y = __half_as_ushort(__float2half(0));
|
||||
if (col % 2 == 0) {
|
||||
res3.x = __half_as_ushort(res);
|
||||
} else {
|
||||
res3.y = __half_as_ushort(res);
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
atomicAdd(&mul[b * width / 2 + col / 2], res3);
|
||||
#else
|
||||
int tmp_addr = b * width / 2 + col / 2;
|
||||
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
|
||||
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,10 +201,19 @@ void squeezellm_gemm(
|
||||
dim3 threads(BLOCKWIDTH);
|
||||
|
||||
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
|
||||
#ifndef USE_ROCM
|
||||
(half2*) vec.data<at::Half>(),
|
||||
#else
|
||||
(__half2*) vec.data_ptr<at::Half>(),
|
||||
#endif
|
||||
mat.data_ptr<int>(),
|
||||
#ifndef USE_ROCM
|
||||
(half2*) mul.data<at::Half>(),
|
||||
(__half*) lookup_table.data<at::Half>(),
|
||||
#else
|
||||
(float2*) mul.data_ptr<float>(),
|
||||
(__half*) lookup_table.data_ptr<at::Half>(),
|
||||
#endif
|
||||
height, width, batch, vec_height
|
||||
);
|
||||
}
|
||||
|
@ -17,13 +17,15 @@
|
||||
*/
|
||||
#pragma once
|
||||
|
||||
#include "cuda_compat.h"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template<typename T>
|
||||
__inline__ __device__ T warpReduceSum(T val) {
|
||||
#pragma unroll
|
||||
for (int mask = 16; mask > 0; mask >>= 1)
|
||||
val += __shfl_xor_sync(0xffffffff, val, mask, 32);
|
||||
val += VLLM_SHFL_XOR_SYNC(val, mask);
|
||||
return val;
|
||||
}
|
||||
|
||||
|
143
docs/source/getting_started/amd-installation.rst
Normal file
143
docs/source/getting_started/amd-installation.rst
Normal file
@ -0,0 +1,143 @@
|
||||
.. _installation_rocm:
|
||||
|
||||
Installation with ROCm
|
||||
======================
|
||||
|
||||
vLLM 0.2.4 onwards supports model inferencing and serving on AMD GPUs with ROCm.
|
||||
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
|
||||
Data types currently supported in ROCm are FP16 and BF16.
|
||||
|
||||
Requirements
|
||||
------------
|
||||
|
||||
* OS: Linux
|
||||
* Python: 3.8 -- 3.11 (Verified on 3.10)
|
||||
* GPU: MI200s
|
||||
* Pytorch 2.0.1/2.1.1/2.2
|
||||
* ROCm 5.7
|
||||
|
||||
Installation options:
|
||||
|
||||
#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image <quick_start_docker_rocm>`
|
||||
#. :ref:`Build from source <build_from_source_rocm>`
|
||||
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`
|
||||
|
||||
.. _quick_start_docker_rocm:
|
||||
|
||||
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
|
||||
---------------------------------------------------------------------------
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
embeddedllminfo/vllm-rocm \
|
||||
bash
|
||||
|
||||
|
||||
.. _build_from_source_rocm:
|
||||
|
||||
Option 2: Build from source
|
||||
---------------------------
|
||||
|
||||
You can build and install vLLM from source:
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||
- `Pytorch <https://pytorch.org/>`_
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version
|
||||
|
||||
|
||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||
|
||||
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||
|
||||
.. note::
|
||||
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
|
||||
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||
|
||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install xformers==0.0.23 --no-deps
|
||||
$ bash patch_xformers.rocm.sh
|
||||
|
||||
3. Build vLLM.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ cd vllm
|
||||
$ pip install -U -r requirements-rocm.txt
|
||||
$ python setup.py install # This may take 5-10 minutes. Currently, `pip install .`` does not work for ROCm installation
|
||||
|
||||
|
||||
.. _build_from_source_docker_rocm:
|
||||
|
||||
Option 3: Build from source with docker
|
||||
-----------------------------------------------------
|
||||
|
||||
You can build and install vLLM from source:
|
||||
|
||||
Build a docker image from `Dockerfile.rocm`, and launch a docker container.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ docker build -f Dockerfile.rocm -t vllm-rocm .
|
||||
$ docker run -it \
|
||||
--network=host \
|
||||
--group-add=video \
|
||||
--ipc=host \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--security-opt seccomp=unconfined \
|
||||
--device /dev/kfd \
|
||||
--device /dev/dri \
|
||||
-v <path/to/model>:/app/model \
|
||||
vllm-rocm \
|
||||
bash
|
||||
|
||||
Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:
|
||||
|
||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
|
||||
|
||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
|
||||
- `Pytorch <https://pytorch.org/>`_
|
||||
|
||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
|
||||
|
||||
Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
|
||||
|
||||
.. note::
|
||||
- If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
|
||||
- If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
|
||||
- ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
|
||||
- You might need to downgrade the "ninja" version to 1.10 it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
|
||||
|
||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ pip install xformers==0.0.23 --no-deps
|
||||
$ bash patch_xformers.rocm.sh
|
||||
|
||||
3. Build vLLM.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ cd vllm
|
||||
$ pip install -U -r requirements-rocm.txt
|
||||
$ python setup.py install # This may take 5-10 minutes.
|
@ -20,7 +20,7 @@ You can install vLLM using pip:
|
||||
.. code-block:: console
|
||||
|
||||
$ # (Optional) Create a new conda environment.
|
||||
$ conda create -n myenv python=3.8 -y
|
||||
$ conda create -n myenv python=3.9 -y
|
||||
$ conda activate myenv
|
||||
|
||||
$ # Install vLLM with CUDA 12.1.
|
||||
@ -34,8 +34,9 @@ You can install vLLM using pip:
|
||||
.. code-block:: console
|
||||
|
||||
$ # Install vLLM with CUDA 11.8.
|
||||
$ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`).
|
||||
$ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl
|
||||
$ export VLLM_VERSION=0.2.4
|
||||
$ export PYTHON_VERSION=39
|
||||
$ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
|
||||
|
||||
$ # Re-install PyTorch with CUDA 11.8.
|
||||
$ pip uninstall torch -y
|
||||
|
@ -129,7 +129,7 @@ By default, the server uses a predefined chat template stored in the tokenizer.
|
||||
|
||||
$ python -m vllm.entrypoints.openai.api_server \
|
||||
$ --model facebook/opt-125m \
|
||||
$ --chat-template ./examples/template_chatml.json
|
||||
$ --chat-template ./examples/template_chatml.jinja
|
||||
|
||||
This server can be queried in the same format as OpenAI API. For example, list the models:
|
||||
|
||||
|
@ -39,6 +39,7 @@ vLLM is flexible and easy to use with:
|
||||
* Tensor parallelism support for distributed inference
|
||||
* Streaming outputs
|
||||
* OpenAI-compatible API server
|
||||
* Support NVIDIA CUDA and AMD ROCm.
|
||||
|
||||
For more information, check out the following:
|
||||
|
||||
@ -56,6 +57,7 @@ Documentation
|
||||
:caption: Getting Started
|
||||
|
||||
getting_started/installation
|
||||
getting_started/amd-installation
|
||||
getting_started/quickstart
|
||||
|
||||
.. toctree::
|
||||
|
@ -26,7 +26,7 @@ This gives you the ability to modify the codebase and test your model.
|
||||
------------------------
|
||||
|
||||
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`_ directory.
|
||||
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
|
||||
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adapted from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
|
||||
|
||||
.. warning::
|
||||
When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
|
||||
|
@ -50,6 +50,9 @@ Alongside each architecture, we include some popular models that use it.
|
||||
* - :code:`MistralForCausalLM`
|
||||
- Mistral, Mistral-Instruct
|
||||
- :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
|
||||
* - :code:`MixtralForCausalLM`
|
||||
- Mixtral-8x7B, Mixtral-8x7B-Instruct
|
||||
- :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
|
||||
* - :code:`MPTForCausalLM`
|
||||
- MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
|
||||
- :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
|
||||
@ -70,6 +73,9 @@ If your model uses one of the above model architectures, you can seamlessly run
|
||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
|
||||
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
|
||||
|
||||
.. note::
|
||||
Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
|
||||
|
||||
.. tip::
|
||||
The easiest way to check if your model is supported is to run the program below:
|
||||
|
||||
@ -81,12 +87,17 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
|
||||
output = llm.generate("Hello, my name is")
|
||||
print(output)
|
||||
|
||||
To use model from www.modelscope.cn
|
||||
If vLLM successfully generates text, it indicates that your model is supported.
|
||||
|
||||
.. tip::
|
||||
To use models from `ModelScope <www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
|
||||
|
||||
.. code-block:: shell
|
||||
|
||||
$ export VLLM_USE_MODELSCOPE=True
|
||||
|
||||
And use with :code:`trust_remote_code=True`.
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
from vllm import LLM
|
||||
@ -94,5 +105,3 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
|
||||
llm = LLM(model=..., revision=..., trust_remote_code=True) # Name or path of your model
|
||||
output = llm.generate("Hello, my name is")
|
||||
print(output)
|
||||
|
||||
If vLLM successfully generates text, it indicates that your model is supported.
|
||||
|
@ -29,7 +29,15 @@ You can build and run vLLM from source via the provided dockerfile. To build vLL
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
$ DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai --build-arg max_jobs=8
|
||||
$ DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai # optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2
|
||||
|
||||
|
||||
.. note::
|
||||
|
||||
By default vLLM will build for all GPU types for widest distribution. If you are just building for the
|
||||
current GPU type the machine is running on, you can add the argument ``--build-arg torch_cuda_arch_list=""``
|
||||
for vLLM to find the current GPU type and build for that.
|
||||
|
||||
|
||||
To run vLLM:
|
||||
|
||||
|
@ -55,7 +55,7 @@ Start the serving the LLaMA-13B model on an A100 GPU:
|
||||
|
||||
$ sky launch serving.yaml
|
||||
|
||||
Check the output of the command. There will be a sharable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
|
||||
Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
|
||||
|
||||
.. code-block:: console
|
||||
|
||||
|
33
patch_xformers.rocm.sh
Normal file
33
patch_xformers.rocm.sh
Normal file
@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
set -e
|
||||
|
||||
XFORMERS_VERSION="0.0.23"
|
||||
|
||||
export XFORMERS_INSTALLED_VERSION=$(python -c 'import xformers; print(xformers.__version__)')
|
||||
|
||||
if [ "$XFORMERS_INSTALLED_VERSION" != "$XFORMERS_VERSION" ]; then
|
||||
echo "ERROR: xformers version must be ${XFORMERS_VERSION}. ${XFORMERS_INSTALLED_VERSION} is installed"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export XFORMERS_FMHA_FLASH_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.flash.__file__)')
|
||||
export XFORMERS_FMHA_COMMON_PATH=$(python -c 'from xformers import ops as xops; print(xops.fmha.common.__file__)')
|
||||
|
||||
echo "XFORMERS_FMHA_FLASH_PATH = ${XFORMERS_FMHA_FLASH_PATH}"
|
||||
echo "XFORMERS_FMHA_COMMON_PATH = ${XFORMERS_FMHA_COMMON_PATH}"
|
||||
|
||||
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
|
||||
echo "Applying patch to ${XFORMERS_FMHA_FLASH_PATH}"
|
||||
patch -p0 $XFORMERS_FMHA_FLASH_PATH "./rocm_patch/flashpy_xformers-${XFORMERS_VERSION}.rocm.patch"
|
||||
echo "Successfully patch ${XFORMERS_FMHA_FLASH_PATH}"
|
||||
else
|
||||
echo "${XFORMERS_FMHA_FLASH_PATH} was patched before"
|
||||
fi
|
||||
|
||||
if ! patch -R -p0 -s -f --dry-run $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"; then
|
||||
echo "Applying patch to ${XFORMERS_FMHA_COMMON_PATH}"
|
||||
patch -p0 $XFORMERS_FMHA_COMMON_PATH "./rocm_patch/commonpy_xformers-${XFORMERS_VERSION}.rocm.patch"
|
||||
echo "Successfully patch ${XFORMERS_FMHA_COMMON_PATH}"
|
||||
else
|
||||
echo "${XFORMERS_FMHA_COMMON_PATH} was patched before"
|
||||
fi
|
@ -4,7 +4,7 @@ requires = [
|
||||
"ninja",
|
||||
"packaging",
|
||||
"setuptools >= 49.4.0",
|
||||
"torch >= 2.1.0",
|
||||
"torch >= 2.1.1",
|
||||
"wheel",
|
||||
]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
15
requirements-rocm.txt
Normal file
15
requirements-rocm.txt
Normal file
@ -0,0 +1,15 @@
|
||||
ninja # For faster builds.
|
||||
typing-extensions>=4.8.0
|
||||
starlette
|
||||
psutil
|
||||
ray >= 2.5.1
|
||||
pandas # Required for Ray data.
|
||||
pyarrow # Required for Ray data.
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
tokenizers>=0.15.0
|
||||
transformers >= 4.36.0 # Required for Mixtral.
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic == 1.10.13 # Required for OpenAI server.
|
||||
aioprometheus[starlette]
|
@ -5,10 +5,9 @@ pandas # Required for Ray data.
|
||||
pyarrow # Required for Ray data.
|
||||
sentencepiece # Required for LLaMA tokenizer.
|
||||
numpy
|
||||
einops # Required for phi-1_5
|
||||
torch >= 2.1.0
|
||||
transformers >= 4.34.0 # Required for Mistral.
|
||||
xformers >= 0.0.22.post7 # Required for CUDA 12.1.
|
||||
torch >= 2.1.1
|
||||
transformers >= 4.36.0 # Required for Mixtral.
|
||||
xformers >= 0.0.23 # Required for CUDA 12.1.
|
||||
fastapi
|
||||
uvicorn[standard]
|
||||
pydantic == 1.10.13 # Required for OpenAI server.
|
||||
|
13
rocm_patch/commonpy_xformers-0.0.23.rocm.patch
Normal file
13
rocm_patch/commonpy_xformers-0.0.23.rocm.patch
Normal file
@ -0,0 +1,13 @@
|
||||
--- /opt/conda/envs/py_3.10/lib/python3.10/site-packages/xformers/ops/fmha/common.py 2023-11-29 03:17:03.930103539 +0000
|
||||
+++ common.py 2023-11-28 16:14:19.846233146 +0000
|
||||
@@ -298,8 +298,8 @@
|
||||
dtype = d.query.dtype
|
||||
if device_type not in cls.SUPPORTED_DEVICES:
|
||||
reasons.append(f"device={device_type} (supported: {cls.SUPPORTED_DEVICES})")
|
||||
- if device_type == "cuda" and not _built_with_cuda:
|
||||
- reasons.append("xFormers wasn't build with CUDA support")
|
||||
+ #if device_type == "cuda" and not _built_with_cuda:
|
||||
+ # reasons.append("xFormers wasn't build with CUDA support")
|
||||
if device_type == "cuda":
|
||||
device_capability = torch.cuda.get_device_capability(d.device)
|
||||
if device_capability < cls.CUDA_MINIMUM_COMPUTE_CAPABILITY:
|
152
rocm_patch/flashpy_xformers-0.0.23.rocm.patch
Normal file
152
rocm_patch/flashpy_xformers-0.0.23.rocm.patch
Normal file
@ -0,0 +1,152 @@
|
||||
--- flash_ori.py 2023-12-13 05:43:31.530752623 +0000
|
||||
+++ flash_patch.py 2023-12-13 06:00:45.962403104 +0000
|
||||
@@ -36,44 +36,44 @@
|
||||
|
||||
FLASH_VERSION = "0.0.0"
|
||||
try:
|
||||
- try:
|
||||
- from ... import _C_flashattention # type: ignore[attr-defined]
|
||||
- from ..._cpp_lib import _build_metadata
|
||||
-
|
||||
- if _build_metadata is not None:
|
||||
- FLASH_VERSION = _build_metadata.flash_version
|
||||
- except ImportError:
|
||||
- import flash_attn
|
||||
- from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
|
||||
-
|
||||
- FLASH_VERSION = flash_attn.__version__
|
||||
- flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
|
||||
- if (
|
||||
- flash_ver_parsed != (2, 3, 6)
|
||||
- and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
|
||||
- ):
|
||||
- raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
|
||||
+ #try:
|
||||
+ # from ... import _C_flashattention # type: ignore[attr-defined]
|
||||
+ # from ..._cpp_lib import _build_metadata
|
||||
+
|
||||
+ # if _build_metadata is not None:
|
||||
+ # FLASH_VERSION = _build_metadata.flash_version
|
||||
+ #except ImportError:
|
||||
+ import flash_attn
|
||||
+ from flash_attn.flash_attn_interface import flash_attn_cuda as _C_flashattention
|
||||
+
|
||||
+ FLASH_VERSION = flash_attn.__version__
|
||||
+ # flash_ver_parsed = tuple(int(s) for s in FLASH_VERSION.split(".")[:3])
|
||||
+ # if (
|
||||
+ # flash_ver_parsed != (2, 3, 6)
|
||||
+ # and os.environ.get("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "0") != "1"
|
||||
+ # ):
|
||||
+ # raise ImportError("Requires Flash attention 2.3.6 for varlen_fwd api")
|
||||
|
||||
# create library so that flash-attn goes through the PyTorch Dispatcher
|
||||
- _flash_lib = torch.library.Library("xformers_flash", "DEF")
|
||||
-
|
||||
- _flash_lib.define(
|
||||
- "flash_fwd(Tensor query, Tensor key, Tensor value, "
|
||||
- "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
|
||||
- "int max_seqlen_q, int max_seqlen_k, "
|
||||
- "float p, float softmax_scale, "
|
||||
- "bool is_causal, int window_left, "
|
||||
- "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
|
||||
- )
|
||||
+ #_flash_lib = torch.library.Library("xformers_flash", "DEF")
|
||||
|
||||
- _flash_lib.define(
|
||||
- "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
|
||||
- "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
|
||||
- "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
|
||||
- "int max_seqlen_q, int max_seqlen_k, "
|
||||
- "float p, float softmax_scale, bool is_causal, "
|
||||
- "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
|
||||
- )
|
||||
+ #_flash_lib.define(
|
||||
+ # "flash_fwd(Tensor query, Tensor key, Tensor value, "
|
||||
+ # "Tensor? cu_seqlens_q, Tensor? cu_seqlens_k, Tensor? seqused_k, "
|
||||
+ # "int max_seqlen_q, int max_seqlen_k, "
|
||||
+ # "float p, float softmax_scale, "
|
||||
+ # "bool is_causal, int window_left, "
|
||||
+ # "int window_right, bool return_softmax) -> (Tensor, Tensor, Tensor)"
|
||||
+ #)
|
||||
+
|
||||
+ #_flash_lib.define(
|
||||
+ # "flash_bwd(Tensor dout, Tensor query, Tensor key, Tensor value, "
|
||||
+ # "Tensor out, Tensor softmax_lse_, Tensor dq, Tensor dk, Tensor dv, "
|
||||
+ # "Tensor cu_seqlens_q, Tensor cu_seqlens_k, "
|
||||
+ # "int max_seqlen_q, int max_seqlen_k, "
|
||||
+ # "float p, float softmax_scale, bool is_causal, "
|
||||
+ # "int window_left, int window_right, Tensor rng_state) -> (Tensor, Tensor, Tensor)"
|
||||
+ #)
|
||||
|
||||
def _flash_fwd(
|
||||
query,
|
||||
@@ -111,8 +111,8 @@
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
- window_left, # window_size_left
|
||||
- window_right, # window_size_right
|
||||
+ # window_left, # window_size_left
|
||||
+ # window_right, # window_size_right
|
||||
return_softmax,
|
||||
None, # rng
|
||||
)
|
||||
@@ -134,15 +134,15 @@
|
||||
out,
|
||||
cu_seq_lens_q,
|
||||
cu_seq_lens_k,
|
||||
- seqused_k,
|
||||
+ # seqused_k,
|
||||
max_seq_len_q,
|
||||
max_seq_len_k,
|
||||
p,
|
||||
softmax_scale,
|
||||
False,
|
||||
is_causal,
|
||||
- window_left,
|
||||
- window_right,
|
||||
+ # window_left,
|
||||
+ # window_right,
|
||||
return_softmax,
|
||||
None,
|
||||
)
|
||||
@@ -184,8 +184,8 @@
|
||||
p,
|
||||
softmax_scale,
|
||||
is_causal,
|
||||
- window_left,
|
||||
- window_right,
|
||||
+ # window_left,
|
||||
+ # window_right,
|
||||
None,
|
||||
rng_state,
|
||||
)
|
||||
@@ -208,15 +208,15 @@
|
||||
softmax_scale,
|
||||
False, # zero_tensors
|
||||
is_causal,
|
||||
- window_left,
|
||||
- window_right,
|
||||
+ # window_left,
|
||||
+ # window_right,
|
||||
None,
|
||||
rng_state,
|
||||
)
|
||||
return dq, dk, dv
|
||||
|
||||
- _flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
|
||||
- _flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
|
||||
+ #_flash_lib.impl("flash_fwd", _flash_fwd, "CUDA")
|
||||
+ #_flash_lib.impl("flash_bwd", _flash_bwd, "CUDA")
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -400,7 +400,7 @@
|
||||
implementation.
|
||||
"""
|
||||
|
||||
- OPERATOR = get_operator("xformers_flash", "flash_fwd")
|
||||
+ OPERATOR = _flash_fwd # get_operator("xformers_flash", "flash_fwd")
|
||||
SUPPORTED_DEVICES: Set[str] = {"cuda"}
|
||||
CUDA_MINIMUM_COMPUTE_CAPABILITY = (8, 0)
|
||||
SUPPORTED_DTYPES: Set[torch.dtype] = {torch.half, torch.bfloat16}
|
231
setup.py
231
setup.py
@ -8,27 +8,83 @@ import warnings
|
||||
from packaging.version import parse, Version
|
||||
import setuptools
|
||||
import torch
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME
|
||||
|
||||
ROOT_DIR = os.path.dirname(__file__)
|
||||
|
||||
MAIN_CUDA_VERSION = "12.1"
|
||||
|
||||
# Supported NVIDIA GPU architectures.
|
||||
SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||
NVIDIA_SUPPORTED_ARCHS = {"7.0", "7.5", "8.0", "8.6", "8.9", "9.0"}
|
||||
ROCM_SUPPORTED_ARCHS = {"gfx90a", "gfx908", "gfx906", "gfx1030", "gfx1100"}
|
||||
# SUPPORTED_ARCHS = NVIDIA_SUPPORTED_ARCHS.union(ROCM_SUPPORTED_ARCHS)
|
||||
|
||||
|
||||
def _is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def _is_cuda() -> bool:
|
||||
return torch.version.cuda is not None
|
||||
|
||||
|
||||
# Compiler flags.
|
||||
CXX_FLAGS = ["-g", "-O2", "-std=c++17"]
|
||||
# TODO(woosuk): Should we use -O3?
|
||||
NVCC_FLAGS = ["-O2", "-std=c++17"]
|
||||
|
||||
if _is_hip():
|
||||
if ROCM_HOME is None:
|
||||
raise RuntimeError(
|
||||
"Cannot find ROCM_HOME. ROCm must be available to build the package."
|
||||
)
|
||||
NVCC_FLAGS += ["-DUSE_ROCM"]
|
||||
|
||||
if _is_cuda() and CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
|
||||
ABI = 1 if torch._C._GLIBCXX_USE_CXX11_ABI else 0
|
||||
CXX_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
NVCC_FLAGS += [f"-D_GLIBCXX_USE_CXX11_ABI={ABI}"]
|
||||
|
||||
if CUDA_HOME is None:
|
||||
raise RuntimeError(
|
||||
"Cannot find CUDA_HOME. CUDA must be available to build the package.")
|
||||
|
||||
def get_amdgpu_offload_arch():
|
||||
command = "/opt/rocm/llvm/bin/amdgpu-offload-arch"
|
||||
try:
|
||||
output = subprocess.check_output([command])
|
||||
return output.decode('utf-8').strip()
|
||||
except subprocess.CalledProcessError as e:
|
||||
error_message = f"Error: {e}"
|
||||
raise RuntimeError(error_message) from e
|
||||
except FileNotFoundError as e:
|
||||
# If the command is not found, print an error message
|
||||
error_message = f"The command {command} was not found."
|
||||
raise RuntimeError(error_message) from e
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_hipcc_rocm_version():
|
||||
# Run the hipcc --version command
|
||||
result = subprocess.run(['hipcc', '--version'],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.STDOUT,
|
||||
text=True)
|
||||
|
||||
# Check if the command was executed successfully
|
||||
if result.returncode != 0:
|
||||
print("Error running 'hipcc --version'")
|
||||
return None
|
||||
|
||||
# Extract the version using a regular expression
|
||||
match = re.search(r'HIP version: (\S+)', result.stdout)
|
||||
if match:
|
||||
# Return the version string
|
||||
return match.group(1)
|
||||
else:
|
||||
print("Could not find HIP version in the output")
|
||||
return None
|
||||
|
||||
|
||||
def get_nvcc_cuda_version(cuda_dir: str) -> Version:
|
||||
@ -61,20 +117,22 @@ def get_torch_arch_list() -> Set[str]:
|
||||
return set()
|
||||
|
||||
# Filter out the invalid architectures and print a warning.
|
||||
valid_archs = SUPPORTED_ARCHS.union({s + "+PTX" for s in SUPPORTED_ARCHS})
|
||||
valid_archs = NVIDIA_SUPPORTED_ARCHS.union(
|
||||
{s + "+PTX"
|
||||
for s in NVIDIA_SUPPORTED_ARCHS})
|
||||
arch_list = torch_arch_list.intersection(valid_archs)
|
||||
# If none of the specified architectures are valid, raise an error.
|
||||
if not arch_list:
|
||||
raise RuntimeError(
|
||||
"None of the CUDA architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
"None of the CUDA/ROCM architectures in `TORCH_CUDA_ARCH_LIST` env "
|
||||
f"variable ({env_arch_list}) is supported. "
|
||||
f"Supported CUDA architectures are: {valid_archs}.")
|
||||
f"Supported CUDA/ROCM architectures are: {valid_archs}.")
|
||||
invalid_arch_list = torch_arch_list - valid_archs
|
||||
if invalid_arch_list:
|
||||
warnings.warn(
|
||||
f"Unsupported CUDA architectures ({invalid_arch_list}) are "
|
||||
f"Unsupported CUDA/ROCM architectures ({invalid_arch_list}) are "
|
||||
"excluded from the `TORCH_CUDA_ARCH_LIST` env variable "
|
||||
f"({env_arch_list}). Supported CUDA architectures are: "
|
||||
f"({env_arch_list}). Supported CUDA/ROCM architectures are: "
|
||||
f"{valid_archs}.",
|
||||
stacklevel=2)
|
||||
return arch_list
|
||||
@ -82,7 +140,7 @@ def get_torch_arch_list() -> Set[str]:
|
||||
|
||||
# First, check the TORCH_CUDA_ARCH_LIST environment variable.
|
||||
compute_capabilities = get_torch_arch_list()
|
||||
if not compute_capabilities:
|
||||
if _is_cuda() and not compute_capabilities:
|
||||
# If TORCH_CUDA_ARCH_LIST is not defined or empty, target all available
|
||||
# GPUs on the current machine.
|
||||
device_count = torch.cuda.device_count()
|
||||
@ -93,68 +151,84 @@ if not compute_capabilities:
|
||||
"GPUs with compute capability below 7.0 are not supported.")
|
||||
compute_capabilities.add(f"{major}.{minor}")
|
||||
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if not compute_capabilities:
|
||||
# If no GPU is specified nor available, add all supported architectures
|
||||
# based on the NVCC CUDA version.
|
||||
compute_capabilities = SUPPORTED_ARCHS.copy()
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
compute_capabilities.remove("8.6")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
compute_capabilities.remove("8.9")
|
||||
compute_capabilities.remove("9.0")
|
||||
|
||||
# Validate the NVCC CUDA version.
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError("CUDA 11.0 or higher is required to build the package.")
|
||||
if (nvcc_cuda_version < Version("11.1")
|
||||
and any(cc.startswith("8.6") for cc in compute_capabilities)):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for compute capability 8.6.")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
warnings.warn(
|
||||
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||
"Targeting compute capability 8.0 instead.",
|
||||
stacklevel=2)
|
||||
compute_capabilities = set(cc for cc in compute_capabilities
|
||||
if not cc.startswith("8.9"))
|
||||
compute_capabilities.add("8.0+PTX")
|
||||
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||
if _is_cuda():
|
||||
nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME)
|
||||
if not compute_capabilities:
|
||||
# If no GPU is specified nor available, add all supported architectures
|
||||
# based on the NVCC CUDA version.
|
||||
compute_capabilities = NVIDIA_SUPPORTED_ARCHS.copy()
|
||||
if nvcc_cuda_version < Version("11.1"):
|
||||
compute_capabilities.remove("8.6")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
compute_capabilities.remove("8.9")
|
||||
compute_capabilities.remove("9.0")
|
||||
# Validate the NVCC CUDA version.
|
||||
if nvcc_cuda_version < Version("11.0"):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||
"CUDA 11.0 or higher is required to build the package.")
|
||||
if (nvcc_cuda_version < Version("11.1")
|
||||
and any(cc.startswith("8.6") for cc in compute_capabilities)):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.1 or higher is required for compute capability 8.6.")
|
||||
if nvcc_cuda_version < Version("11.8"):
|
||||
if any(cc.startswith("8.9") for cc in compute_capabilities):
|
||||
# CUDA 11.8 is required to generate the code targeting compute capability 8.9.
|
||||
# However, GPUs with compute capability 8.9 can also run the code generated by
|
||||
# the previous versions of CUDA 11 and targeting compute capability 8.0.
|
||||
# Therefore, if CUDA 11.8 is not available, we target compute capability 8.0
|
||||
# instead of 8.9.
|
||||
warnings.warn(
|
||||
"CUDA 11.8 or higher is required for compute capability 8.9. "
|
||||
"Targeting compute capability 8.0 instead.",
|
||||
stacklevel=2)
|
||||
compute_capabilities = set(cc for cc in compute_capabilities
|
||||
if not cc.startswith("8.9"))
|
||||
compute_capabilities.add("8.0+PTX")
|
||||
if any(cc.startswith("9.0") for cc in compute_capabilities):
|
||||
raise RuntimeError(
|
||||
"CUDA 11.8 or higher is required for compute capability 9.0.")
|
||||
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
num = capability[0] + capability[2]
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=compute_{num}"]
|
||||
# Add target compute capabilities to NVCC flags.
|
||||
for capability in compute_capabilities:
|
||||
num = capability[0] + capability[2]
|
||||
NVCC_FLAGS += ["-gencode", f"arch=compute_{num},code=sm_{num}"]
|
||||
if capability.endswith("+PTX"):
|
||||
NVCC_FLAGS += [
|
||||
"-gencode", f"arch=compute_{num},code=compute_{num}"
|
||||
]
|
||||
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
num_threads = min(os.cpu_count(), 8)
|
||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||
# Use NVCC threads to parallelize the build.
|
||||
if nvcc_cuda_version >= Version("11.2"):
|
||||
nvcc_threads = int(os.getenv("NVCC_THREADS", 8))
|
||||
num_threads = min(os.cpu_count(), nvcc_threads)
|
||||
NVCC_FLAGS += ["--threads", str(num_threads)]
|
||||
|
||||
elif _is_hip():
|
||||
amd_arch = get_amdgpu_offload_arch()
|
||||
if amd_arch not in ROCM_SUPPORTED_ARCHS:
|
||||
raise RuntimeError(
|
||||
f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}"
|
||||
f"amdgpu_arch_found: {amd_arch}")
|
||||
|
||||
ext_modules = []
|
||||
|
||||
vllm_extension_sources = [
|
||||
"csrc/cache_kernels.cu",
|
||||
"csrc/attention/attention_kernels.cu",
|
||||
"csrc/pos_encoding_kernels.cu",
|
||||
"csrc/activation_kernels.cu",
|
||||
"csrc/layernorm_kernels.cu",
|
||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
|
||||
"csrc/cuda_utils_kernels.cu",
|
||||
"csrc/pybind.cpp",
|
||||
]
|
||||
|
||||
if _is_cuda():
|
||||
vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu")
|
||||
|
||||
vllm_extension = CUDAExtension(
|
||||
name="vllm._C",
|
||||
sources=[
|
||||
"csrc/cache_kernels.cu",
|
||||
"csrc/attention/attention_kernels.cu",
|
||||
"csrc/pos_encoding_kernels.cu",
|
||||
"csrc/activation_kernels.cu",
|
||||
"csrc/layernorm_kernels.cu",
|
||||
"csrc/quantization/awq/gemm_kernels.cu",
|
||||
"csrc/quantization/squeezellm/quant_cuda_kernel.cu",
|
||||
"csrc/cuda_utils_kernels.cu",
|
||||
"csrc/pybind.cpp",
|
||||
],
|
||||
sources=vllm_extension_sources,
|
||||
extra_compile_args={
|
||||
"cxx": CXX_FLAGS,
|
||||
"nvcc": NVCC_FLAGS,
|
||||
@ -182,10 +256,19 @@ def find_version(filepath: str) -> str:
|
||||
|
||||
def get_vllm_version() -> str:
|
||||
version = find_version(get_path("vllm", "__init__.py"))
|
||||
cuda_version = str(nvcc_cuda_version)
|
||||
if cuda_version != MAIN_CUDA_VERSION:
|
||||
cuda_version_str = cuda_version.replace(".", "")[:3]
|
||||
version += f"+cu{cuda_version_str}"
|
||||
|
||||
if _is_hip():
|
||||
# Get the HIP version
|
||||
hipcc_version = get_hipcc_rocm_version()
|
||||
if hipcc_version != MAIN_CUDA_VERSION:
|
||||
rocm_version_str = hipcc_version.replace(".", "")[:3]
|
||||
version += f"+rocm{rocm_version_str}"
|
||||
else:
|
||||
cuda_version = str(nvcc_cuda_version)
|
||||
if cuda_version != MAIN_CUDA_VERSION:
|
||||
cuda_version_str = cuda_version.replace(".", "")[:3]
|
||||
version += f"+cu{cuda_version_str}"
|
||||
|
||||
return version
|
||||
|
||||
|
||||
@ -200,8 +283,12 @@ def read_readme() -> str:
|
||||
|
||||
def get_requirements() -> List[str]:
|
||||
"""Get Python package dependencies from requirements.txt."""
|
||||
with open(get_path("requirements.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
if _is_hip():
|
||||
with open(get_path("requirements-rocm.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
else:
|
||||
with open(get_path("requirements.txt")) as f:
|
||||
requirements = f.read().strip().split("\n")
|
||||
return requirements
|
||||
|
||||
|
||||
|
@ -1,3 +1,4 @@
|
||||
import os
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import pytest
|
||||
@ -7,21 +8,32 @@ from transformers import AutoModelForCausalLM
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
|
||||
_TEST_PROMPTS = [
|
||||
"vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
|
||||
"Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
|
||||
"Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
|
||||
"Describe the basic components of a neural network and how it can be trained.",
|
||||
"Write a short story about a robot that dreams for the first time.",
|
||||
"Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.",
|
||||
"Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.",
|
||||
"Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'",
|
||||
]
|
||||
_TEST_PROMPTS = ["prompts/example.txt"]
|
||||
_LONG_PROMPTS = ["prompts/summary.txt"]
|
||||
|
||||
|
||||
def _read_prompts(filename: str) -> str:
|
||||
prompts = []
|
||||
with open(filename, "r") as f:
|
||||
prompt = f.readline()
|
||||
prompts.append(prompt)
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_prompts() -> List[str]:
|
||||
return _TEST_PROMPTS
|
||||
prompts = []
|
||||
for filename in _TEST_PROMPTS:
|
||||
prompts += _read_prompts(os.path.join("tests", filename))
|
||||
return prompts
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def example_long_prompts() -> List[str]:
|
||||
prompts = []
|
||||
for filename in _LONG_PROMPTS:
|
||||
prompts += _read_prompts(os.path.join("tests", filename))
|
||||
return prompts
|
||||
|
||||
|
||||
_STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
|
@ -131,9 +131,6 @@ def test_paged_attention(
|
||||
|
||||
assert num_query_heads % num_kv_heads == 0
|
||||
num_queries_per_kv = num_query_heads // num_kv_heads
|
||||
head_mapping = torch.repeat_interleave(
|
||||
torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
num_queries_per_kv)
|
||||
alibi_slopes = None
|
||||
if use_alibi:
|
||||
alibi_slopes = torch.randn(num_query_heads,
|
||||
@ -170,7 +167,7 @@ def test_paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
@ -202,7 +199,7 @@ def test_paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
block_tables,
|
||||
context_lens,
|
||||
|
37
tests/models/test_mistral.py
Normal file
37
tests/models/test_mistral.py
Normal file
@ -0,0 +1,37 @@
|
||||
"""Compare the outputs of HF and vLLM for Mistral models using greedy sampling.
|
||||
|
||||
Run `pytest tests/models/test_mistral.py --forked`.
|
||||
"""
|
||||
import pytest
|
||||
|
||||
MODELS = [
|
||||
"mistralai/Mistral-7B-Instruct-v0.1",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["bfloat16"])
|
||||
@pytest.mark.parametrize("max_tokens", [128])
|
||||
def test_models(
|
||||
hf_runner,
|
||||
vllm_runner,
|
||||
example_long_prompts,
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
hf_model = hf_runner(model, dtype=dtype)
|
||||
hf_outputs = hf_model.generate_greedy(example_long_prompts, max_tokens)
|
||||
del hf_model
|
||||
|
||||
vllm_model = vllm_runner(model, dtype=dtype)
|
||||
vllm_outputs = vllm_model.generate_greedy(example_long_prompts, max_tokens)
|
||||
del vllm_model
|
||||
|
||||
for i in range(len(example_long_prompts)):
|
||||
hf_output_ids, hf_output_str = hf_outputs[i]
|
||||
vllm_output_ids, vllm_output_str = vllm_outputs[i]
|
||||
assert hf_output_str == vllm_output_str, (
|
||||
f"Test{i}:\nHF: {hf_output_str!r}\nvLLM: {vllm_output_str!r}")
|
||||
assert hf_output_ids == vllm_output_ids, (
|
||||
f"Test{i}:\nHF: {hf_output_ids}\nvLLM: {vllm_output_ids}")
|
8
tests/prompts/example.txt
Normal file
8
tests/prompts/example.txt
Normal file
@ -0,0 +1,8 @@
|
||||
vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.
|
||||
Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.
|
||||
Compare and contrast artificial intelligence with human intelligence in terms of processing information.
|
||||
Describe the basic components of a neural network and how it can be trained.
|
||||
Write a short story about a robot that dreams for the first time.
|
||||
Analyze the impact of the COVID-19 pandemic on global economic structures and future business models.
|
||||
Explain the cultural significance of the Mona Lisa painting, and how its perception might vary in Western versus Eastern societies.
|
||||
Translate the following English sentence into Japanese, French, and Swahili: 'The early bird catches the worm.'
|
1
tests/prompts/summary.txt
Normal file
1
tests/prompts/summary.txt
Normal file
File diff suppressed because one or more lines are too long
@ -8,7 +8,7 @@ from vllm.entrypoints.llm import LLM
|
||||
from vllm.outputs import CompletionOutput, RequestOutput
|
||||
from vllm.sampling_params import SamplingParams
|
||||
|
||||
__version__ = "0.2.3"
|
||||
__version__ = "0.2.5"
|
||||
|
||||
__all__ = [
|
||||
"LLM",
|
||||
|
@ -6,7 +6,7 @@ from transformers import PretrainedConfig
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.transformers_utils.config import get_config
|
||||
from vllm.utils import get_cpu_memory
|
||||
from vllm.utils import get_cpu_memory, is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -98,12 +98,39 @@ class ModelConfig:
|
||||
|
||||
def _verify_load_format(self) -> None:
|
||||
load_format = self.load_format.lower()
|
||||
if load_format not in [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]:
|
||||
supported_load_format = [
|
||||
"auto", "pt", "safetensors", "npcache", "dummy"
|
||||
]
|
||||
rocm_not_supported_load_format = ["safetensors"]
|
||||
if load_format not in supported_load_format:
|
||||
raise ValueError(
|
||||
f"Unknown load format: {self.load_format}. Must be one of "
|
||||
"'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.")
|
||||
if is_hip():
|
||||
if load_format in ["safetensors"]:
|
||||
rocm_supported_load_format = [
|
||||
f for f in supported_load_format
|
||||
if (f not in rocm_not_supported_load_format)
|
||||
]
|
||||
raise ValueError(
|
||||
f"load format \'{load_format}\' is not supported in ROCm. "
|
||||
f"Supported load format are "
|
||||
f"{rocm_supported_load_format}")
|
||||
# Force ROCm to load from pt weights if nothing specific is set
|
||||
if load_format == "auto":
|
||||
load_format = "pt"
|
||||
|
||||
# TODO: Remove this check once HF updates the pt weights of Mixtral.
|
||||
architectures = getattr(self.hf_config, "architectures", [])
|
||||
if "MixtralForCausalLM" in architectures:
|
||||
if load_format == "pt":
|
||||
raise ValueError(
|
||||
"Currently, the 'pt' format is not supported for Mixtral. "
|
||||
"Please use the 'safetensors' format instead. ")
|
||||
elif load_format == "auto":
|
||||
# Do not fall back to pt weights.
|
||||
load_format = "safetensors"
|
||||
|
||||
self.load_format = load_format
|
||||
|
||||
def _verify_tokenizer_mode(self) -> None:
|
||||
@ -116,6 +143,7 @@ class ModelConfig:
|
||||
|
||||
def _verify_quantization(self) -> None:
|
||||
supported_quantization = ["awq", "squeezellm"]
|
||||
rocm_not_supported_quantization = ["awq"]
|
||||
if self.quantization is not None:
|
||||
self.quantization = self.quantization.lower()
|
||||
|
||||
@ -137,6 +165,11 @@ class ModelConfig:
|
||||
raise ValueError(
|
||||
f"Unknown quantization method: {self.quantization}. Must "
|
||||
f"be one of {supported_quantization}.")
|
||||
if is_hip(
|
||||
) and self.quantization in rocm_not_supported_quantization:
|
||||
raise ValueError(
|
||||
f"{self.quantization} quantization is currently not supported "
|
||||
f"in ROCm.")
|
||||
logger.warning(f"{self.quantization} quantization is not fully "
|
||||
"optimized yet. The speed can be slower than "
|
||||
"non-quantized models.")
|
||||
@ -364,6 +397,8 @@ _STR_DTYPE_TO_TORCH_DTYPE = {
|
||||
"bfloat16": torch.bfloat16,
|
||||
}
|
||||
|
||||
_ROCM_NOT_SUPPORTED_DTYPE = ["float", "float32"]
|
||||
|
||||
|
||||
def _get_and_verify_dtype(
|
||||
config: PretrainedConfig,
|
||||
@ -393,6 +428,14 @@ def _get_and_verify_dtype(
|
||||
else:
|
||||
raise ValueError(f"Unknown dtype: {dtype}")
|
||||
|
||||
if is_hip() and torch_dtype == torch.float32:
|
||||
rocm_supported_dtypes = [
|
||||
k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items()
|
||||
if (k not in _ROCM_NOT_SUPPORTED_DTYPE)
|
||||
]
|
||||
raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. "
|
||||
f"Supported dtypes are {rocm_supported_dtypes}")
|
||||
|
||||
# Verify the dtype.
|
||||
if torch_dtype != config_dtype:
|
||||
if torch_dtype == torch.float32:
|
||||
|
@ -3,6 +3,7 @@ from typing import Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@ -73,7 +74,12 @@ def initialize_cluster(
|
||||
"Ray is not installed. Please install Ray to use distributed "
|
||||
"serving.")
|
||||
# Connect to a ray cluster.
|
||||
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||
if is_hip():
|
||||
ray.init(address=ray_address,
|
||||
ignore_reinit_error=True,
|
||||
num_gpus=parallel_config.world_size)
|
||||
else:
|
||||
ray.init(address=ray_address, ignore_reinit_error=True)
|
||||
|
||||
if not parallel_config.worker_use_ray:
|
||||
# Initialize cluster locally.
|
||||
|
@ -253,8 +253,10 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
n=request.n,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
max_tokens=request.max_tokens,
|
||||
@ -330,8 +332,7 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
# Send token-by-token response for each request.n
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
previous_texts[i] = output.text
|
||||
completion_tokens = len(output.token_ids)
|
||||
previous_num_tokens[i] = completion_tokens
|
||||
previous_num_tokens[i] = len(output.token_ids)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i,
|
||||
delta=DeltaMessage(content=delta_text),
|
||||
@ -349,8 +350,8 @@ async def create_chat_completion(request: ChatCompletionRequest,
|
||||
prompt_tokens = len(res.prompt_token_ids)
|
||||
final_usage = UsageInfo(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=prompt_tokens + completion_tokens,
|
||||
completion_tokens=previous_num_tokens[i],
|
||||
total_tokens=prompt_tokens + previous_num_tokens[i],
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=i, delta=[], finish_reason=output.finish_reason)
|
||||
@ -497,9 +498,11 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
best_of=request.best_of,
|
||||
presence_penalty=request.presence_penalty,
|
||||
frequency_penalty=request.frequency_penalty,
|
||||
repetition_penalty=request.repetition_penalty,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
top_k=request.top_k,
|
||||
min_p=request.min_p,
|
||||
stop=request.stop,
|
||||
stop_token_ids=request.stop_token_ids,
|
||||
ignore_eos=request.ignore_eos,
|
||||
@ -564,17 +567,22 @@ async def create_completion(request: CompletionRequest, raw_request: Request):
|
||||
i = output.index
|
||||
delta_text = output.text[len(previous_texts[i]):]
|
||||
token_ids = output.token_ids[previous_num_tokens[i]:]
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i]:]
|
||||
if request.logprobs is not None:
|
||||
top_logprobs = output.logprobs[previous_num_tokens[i]:]
|
||||
else:
|
||||
top_logprobs = None
|
||||
offsets = len(previous_texts[i])
|
||||
if request.echo and not has_echoed[i]:
|
||||
if not echo_without_generation:
|
||||
delta_text = res.prompt + delta_text
|
||||
token_ids = res.prompt_token_ids + token_ids
|
||||
top_logprobs = res.prompt_logprobs + top_logprobs
|
||||
else:
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs + top_logprobs
|
||||
else: # only just return the prompt
|
||||
delta_text = res.prompt
|
||||
token_ids = res.prompt_token_ids
|
||||
top_logprobs = res.prompt_logprobs
|
||||
if top_logprobs:
|
||||
top_logprobs = res.prompt_logprobs
|
||||
has_echoed[i] = True
|
||||
if request.logprobs is not None:
|
||||
logprobs = create_logprobs(
|
||||
|
@ -75,6 +75,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
add_generation_prompt: Optional[bool] = True
|
||||
echo: Optional[bool] = False
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
min_p: Optional[float] = 0.0
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
@ -102,6 +104,8 @@ class CompletionRequest(BaseModel):
|
||||
stop_token_ids: Optional[List[int]] = Field(default_factory=list)
|
||||
skip_special_tokens: Optional[bool] = True
|
||||
spaces_between_special_tokens: Optional[bool] = True
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
min_p: Optional[float] = 0.0
|
||||
|
||||
|
||||
class LogProbs(BaseModel):
|
||||
|
@ -10,6 +10,7 @@ from xformers.ops.fmha.attn_bias import (BlockDiagonalCausalMask,
|
||||
from vllm._C import ops
|
||||
from vllm._C import cache_ops
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.utils import is_hip
|
||||
|
||||
_SUPPORTED_HEAD_SIZES = [64, 80, 96, 112, 128, 256]
|
||||
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
|
||||
@ -53,9 +54,6 @@ class PagedAttention(nn.Module):
|
||||
|
||||
assert self.num_heads % self.num_kv_heads == 0
|
||||
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
|
||||
self.head_mapping = torch.repeat_interleave(
|
||||
torch.arange(self.num_kv_heads, dtype=torch.int32, device="cuda"),
|
||||
self.num_queries_per_kv)
|
||||
|
||||
if self.head_size not in _SUPPORTED_HEAD_SIZES:
|
||||
raise ValueError(f"head_size ({self.head_size}) is not supported. "
|
||||
@ -76,7 +74,7 @@ class PagedAttention(nn.Module):
|
||||
Args:
|
||||
query: shape = [batch_size, seq_len, num_heads * head_size]
|
||||
key: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, num_kv_heads * head_size]
|
||||
value: shape = [batch_size, seq_len, num_kv_heads * head_size]
|
||||
key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
|
||||
block_size, x]
|
||||
value_cache: shape = [num_blocks, num_kv_heads, head_size,
|
||||
@ -140,7 +138,8 @@ class PagedAttention(nn.Module):
|
||||
input_metadata.attn_bias = attn_bias
|
||||
else:
|
||||
input_metadata.attn_bias = _make_alibi_bias(
|
||||
self.alibi_slopes, batch_size, seq_len, query.dtype)
|
||||
self.alibi_slopes, self.num_kv_heads, batch_size,
|
||||
seq_len, query.dtype)
|
||||
|
||||
# TODO(woosuk): Too many view operations. Let's try to reduce them
|
||||
# in the future for code readability.
|
||||
@ -160,6 +159,8 @@ class PagedAttention(nn.Module):
|
||||
attn_bias=input_metadata.attn_bias,
|
||||
p=0.0,
|
||||
scale=self.scale,
|
||||
op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
|
||||
(is_hip()) else None,
|
||||
)
|
||||
output = out.view_as(query)
|
||||
else:
|
||||
@ -169,7 +170,7 @@ class PagedAttention(nn.Module):
|
||||
key_cache,
|
||||
value_cache,
|
||||
input_metadata,
|
||||
self.head_mapping,
|
||||
self.num_kv_heads,
|
||||
self.scale,
|
||||
self.alibi_slopes,
|
||||
)
|
||||
@ -180,31 +181,34 @@ class PagedAttention(nn.Module):
|
||||
|
||||
def _make_alibi_bias(
|
||||
alibi_slopes: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
batch_size: int,
|
||||
seq_len: int,
|
||||
dtype: torch.dtype,
|
||||
) -> LowerTriangularMaskWithTensorBias:
|
||||
bias = torch.arange(seq_len, dtype=dtype)
|
||||
bias = torch.arange(seq_len, dtype=dtype, device="cuda")
|
||||
# NOTE(zhuohan): HF uses
|
||||
# `bias = bias[None, :].repeat(prompt_len, 1)`
|
||||
# here. We find that both biases give the same results, but
|
||||
# the bias below more accurately follows the original ALiBi
|
||||
# paper.
|
||||
bias = bias[None, :] - bias[:, None]
|
||||
bias = bias.to(alibi_slopes.device)
|
||||
|
||||
# When using custom attention bias, xformers requires the bias to
|
||||
# be sliced from a tensor whose length is a multiple of 8.
|
||||
padded_len = (seq_len + 7) // 8 * 8
|
||||
num_heads = alibi_slopes.shape[0]
|
||||
bias = torch.empty(
|
||||
batch_size,
|
||||
alibi_slopes.shape[0],
|
||||
num_heads,
|
||||
seq_len,
|
||||
padded_len,
|
||||
device=alibi_slopes.device,
|
||||
dtype=dtype,
|
||||
)[:, :, :, :seq_len].copy_(bias)
|
||||
bias.mul_(alibi_slopes[:, None, None])
|
||||
if num_heads != num_kv_heads:
|
||||
bias = bias.unflatten(1, (num_kv_heads, num_heads // num_kv_heads))
|
||||
attn_bias = LowerTriangularMaskWithTensorBias(bias)
|
||||
return attn_bias
|
||||
|
||||
@ -214,7 +218,7 @@ def _paged_attention(
|
||||
key_cache: torch.Tensor,
|
||||
value_cache: torch.Tensor,
|
||||
input_metadata: InputMetadata,
|
||||
head_mapping: torch.Tensor,
|
||||
num_kv_heads: int,
|
||||
scale: float,
|
||||
alibi_slopes: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
@ -241,7 +245,7 @@ def _paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
@ -271,7 +275,7 @@ def _paged_attention(
|
||||
query,
|
||||
key_cache,
|
||||
value_cache,
|
||||
head_mapping,
|
||||
num_kv_heads,
|
||||
scale,
|
||||
input_metadata.block_tables,
|
||||
input_metadata.context_lens,
|
||||
|
@ -7,6 +7,7 @@ from vllm._C import ops
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
set_weight_attrs)
|
||||
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
||||
from vllm.utils import is_hip
|
||||
|
||||
|
||||
class SqueezeLLMConfig(QuantizationConfig):
|
||||
@ -114,9 +115,14 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
|
||||
lookup_table = weights["lookup_table"]
|
||||
out_shape = x.shape[:-1] + (qweight.shape[-1], )
|
||||
reshaped_x = x.reshape(-1, x.shape[-1])
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
|
||||
if is_hip():
|
||||
out_f = torch.zeros(out_shape, device="cuda", dtype=torch.float)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out_f, lookup_table)
|
||||
out = out_f.to(dtype=torch.float16)
|
||||
else:
|
||||
# NOTE: The output tensor should be zero-initialized.
|
||||
out = torch.zeros(out_shape, device="cuda", dtype=torch.float16)
|
||||
ops.squeezellm_gemm(reshaped_x, qweight, out, lookup_table)
|
||||
|
||||
if bias is not None:
|
||||
out = out + bias
|
||||
|
@ -7,37 +7,10 @@ import torch.nn as nn
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
from vllm.config import ModelConfig
|
||||
from vllm.model_executor.models import *
|
||||
from vllm.model_executor.models import ModelRegistry
|
||||
from vllm.model_executor.weight_utils import (get_quant_config,
|
||||
initialize_dummy_weights)
|
||||
|
||||
# TODO(woosuk): Lazy-load the model classes.
|
||||
_MODEL_REGISTRY = {
|
||||
"AquilaModel": AquilaForCausalLM,
|
||||
"AquilaForCausalLM": AquilaForCausalLM, # AquilaChat2
|
||||
"BaiChuanForCausalLM": BaiChuanForCausalLM, # baichuan-7b
|
||||
"BaichuanForCausalLM": BaichuanForCausalLM, # baichuan-13b
|
||||
"BloomForCausalLM": BloomForCausalLM,
|
||||
"ChatGLMModel": ChatGLMForCausalLM,
|
||||
"FalconForCausalLM": FalconForCausalLM,
|
||||
"GPT2LMHeadModel": GPT2LMHeadModel,
|
||||
"GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
|
||||
"GPTJForCausalLM": GPTJForCausalLM,
|
||||
"GPTNeoXForCausalLM": GPTNeoXForCausalLM,
|
||||
"InternLMForCausalLM": InternLMForCausalLM,
|
||||
"LlamaForCausalLM": LlamaForCausalLM,
|
||||
"LLaMAForCausalLM": LlamaForCausalLM, # For decapoda-research/llama-*
|
||||
"MistralForCausalLM": MistralForCausalLM,
|
||||
# transformers's mpt class has lower case
|
||||
"MptForCausalLM": MPTForCausalLM,
|
||||
"MPTForCausalLM": MPTForCausalLM,
|
||||
"OPTForCausalLM": OPTForCausalLM,
|
||||
"PhiForCausalLM": PhiForCausalLM,
|
||||
"QWenLMHeadModel": QWenLMHeadModel,
|
||||
"RWForCausalLM": FalconForCausalLM,
|
||||
"YiForCausalLM": YiForCausalLM,
|
||||
}
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
@ -51,11 +24,12 @@ def _set_default_torch_dtype(dtype: torch.dtype):
|
||||
def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]:
|
||||
architectures = getattr(config, "architectures", [])
|
||||
for arch in architectures:
|
||||
if arch in _MODEL_REGISTRY:
|
||||
return _MODEL_REGISTRY[arch]
|
||||
model_cls = ModelRegistry.load_model_cls(arch)
|
||||
if model_cls is not None:
|
||||
return model_cls
|
||||
raise ValueError(
|
||||
f"Model architectures {architectures} are not supported for now. "
|
||||
f"Supported architectures: {list(_MODEL_REGISTRY.keys())}")
|
||||
f"Supported architectures: {ModelRegistry.get_supported_archs()}")
|
||||
|
||||
|
||||
def get_model(model_config: ModelConfig) -> nn.Module:
|
||||
|
@ -1,39 +1,82 @@
|
||||
from vllm.model_executor.models.aquila import AquilaForCausalLM
|
||||
from vllm.model_executor.models.baichuan import (BaiChuanForCausalLM,
|
||||
BaichuanForCausalLM)
|
||||
from vllm.model_executor.models.bloom import BloomForCausalLM
|
||||
from vllm.model_executor.models.falcon import FalconForCausalLM
|
||||
from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
|
||||
from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
|
||||
from vllm.model_executor.models.gpt_j import GPTJForCausalLM
|
||||
from vllm.model_executor.models.gpt_neox import GPTNeoXForCausalLM
|
||||
from vllm.model_executor.models.internlm import InternLMForCausalLM
|
||||
from vllm.model_executor.models.llama import LlamaForCausalLM
|
||||
from vllm.model_executor.models.mistral import MistralForCausalLM
|
||||
from vllm.model_executor.models.mpt import MPTForCausalLM
|
||||
from vllm.model_executor.models.opt import OPTForCausalLM
|
||||
from vllm.model_executor.models.phi_1_5 import PhiForCausalLM
|
||||
from vllm.model_executor.models.qwen import QWenLMHeadModel
|
||||
from vllm.model_executor.models.chatglm import ChatGLMForCausalLM
|
||||
from vllm.model_executor.models.yi import YiForCausalLM
|
||||
import importlib
|
||||
from typing import List, Optional, Type
|
||||
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.logger import init_logger
|
||||
from vllm.utils import is_hip
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
# Architecture -> (module, class).
|
||||
_MODELS = {
|
||||
"AquilaModel": ("aquila", "AquilaForCausalLM"),
|
||||
"AquilaForCausalLM": ("aquila", "AquilaForCausalLM"), # AquilaChat2
|
||||
"BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b
|
||||
"BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b
|
||||
"BloomForCausalLM": ("bloom", "BloomForCausalLM"),
|
||||
"ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"),
|
||||
"ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"),
|
||||
"FalconForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"),
|
||||
"GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"),
|
||||
"GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"),
|
||||
"GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"),
|
||||
"InternLMForCausalLM": ("internlm", "InternLMForCausalLM"),
|
||||
"LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
# For decapoda-research/llama-*
|
||||
"LLaMAForCausalLM": ("llama", "LlamaForCausalLM"),
|
||||
"MistralForCausalLM": ("mistral", "MistralForCausalLM"),
|
||||
"MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"),
|
||||
# transformers's mpt class has lower case
|
||||
"MptForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"MPTForCausalLM": ("mpt", "MPTForCausalLM"),
|
||||
"OPTForCausalLM": ("opt", "OPTForCausalLM"),
|
||||
"PhiForCausalLM": ("phi_1_5", "PhiForCausalLM"),
|
||||
"QWenLMHeadModel": ("qwen", "QWenLMHeadModel"),
|
||||
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
|
||||
"YiForCausalLM": ("yi", "YiForCausalLM"),
|
||||
}
|
||||
|
||||
# Models not supported by ROCm.
|
||||
_ROCM_UNSUPPORTED_MODELS = []
|
||||
|
||||
# Models partially supported by ROCm.
|
||||
# Architecture -> Reason.
|
||||
_ROCM_PARTIALLY_SUPPORTED_MODELS = {
|
||||
"MistralForCausalLM":
|
||||
"Sliding window attention is not yet supported in ROCm's flash attention",
|
||||
"MixtralForCausalLM":
|
||||
"Sliding window attention is not yet supported in ROCm's flash attention",
|
||||
}
|
||||
|
||||
|
||||
class ModelRegistry:
|
||||
|
||||
@staticmethod
|
||||
def load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]:
|
||||
if model_arch not in _MODELS:
|
||||
return None
|
||||
if is_hip():
|
||||
if model_arch in _ROCM_UNSUPPORTED_MODELS:
|
||||
raise ValueError(
|
||||
f"Model architecture {model_arch} is not supported by "
|
||||
"ROCm for now.")
|
||||
if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS:
|
||||
logger.warning(
|
||||
f"Model architecture {model_arch} is partially supported "
|
||||
"by ROCm: " + _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch])
|
||||
|
||||
module_name, model_cls_name = _MODELS[model_arch]
|
||||
module = importlib.import_module(
|
||||
f"vllm.model_executor.models.{module_name}")
|
||||
return getattr(module, model_cls_name, None)
|
||||
|
||||
@staticmethod
|
||||
def get_supported_archs() -> List[str]:
|
||||
return list(_MODELS.keys())
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AquilaForCausalLM",
|
||||
"BaiChuanForCausalLM",
|
||||
"BaichuanForCausalLM",
|
||||
"BloomForCausalLM",
|
||||
"ChatGLMForCausalLM",
|
||||
"FalconForCausalLM",
|
||||
"GPT2LMHeadModel",
|
||||
"GPTBigCodeForCausalLM",
|
||||
"GPTJForCausalLM",
|
||||
"GPTNeoXForCausalLM",
|
||||
"InternLMForCausalLM",
|
||||
"LlamaForCausalLM",
|
||||
"MPTForCausalLM",
|
||||
"OPTForCausalLM",
|
||||
"PhiForCausalLM",
|
||||
"QWenLMHeadModel",
|
||||
"MistralForCausalLM",
|
||||
"YiForCausalLM",
|
||||
"ModelRegistry",
|
||||
]
|
||||
|
@ -366,15 +366,20 @@ class BaiChuanBaseForCausalLM(nn.Module):
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 13b
|
||||
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
"""Baichuan 13B and Baichuan2 7B/13B."""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
linear_method: Optional[LinearMethodBase] = None):
|
||||
super().__init__(config, "ALIBI", linear_method)
|
||||
if config.hidden_size == 4096: # baichuan2 7b
|
||||
super().__init__(config, "ROPE", linear_method)
|
||||
else: # baichuan 13b, baichuan2 13b
|
||||
super().__init__(config, "ALIBI", linear_method)
|
||||
|
||||
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): # baichuan 7b
|
||||
class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
|
||||
"""Baichuan 7B."""
|
||||
|
||||
def __init__(self,
|
||||
config,
|
||||
|
@ -1,5 +1,5 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
@ -67,6 +67,7 @@ class InternLMAttention(nn.Module):
|
||||
rope_theta: float = 10000,
|
||||
max_position_embeddings: int = 8192,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
rope_scaling: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
@ -99,6 +100,7 @@ class InternLMAttention(nn.Module):
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=self.max_position_embeddings,
|
||||
base=self.rope_theta,
|
||||
rope_scaling=rope_scaling,
|
||||
)
|
||||
self.attn = PagedAttention(self.num_heads, self.head_dim, self.scaling)
|
||||
|
||||
@ -139,6 +141,7 @@ class InternLMDecoderLayer(nn.Module):
|
||||
rope_theta=rope_theta,
|
||||
max_position_embeddings=max_position_embeddings,
|
||||
linear_method=linear_method,
|
||||
rope_scaling=getattr(config, "rope_scaling", None),
|
||||
)
|
||||
self.mlp = InternLMMLP(
|
||||
hidden_size=self.hidden_size,
|
||||
|
@ -322,6 +322,11 @@ class LlamaForCausalLM(nn.Module):
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
if ("rotary_emb.cos_cached" in name
|
||||
or "rotary_emb.sin_cached" in name):
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
|
429
vllm/model_executor/models/mixtral.py
Normal file
429
vllm/model_executor/models/mixtral.py
Normal file
@ -0,0 +1,429 @@
|
||||
# coding=utf-8
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Mixtral model."""
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from torch import nn
|
||||
from transformers import MixtralConfig
|
||||
|
||||
from vllm.model_executor.input_metadata import InputMetadata
|
||||
from vllm.model_executor.layers.attention import PagedAttention
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (LinearMethodBase,
|
||||
ReplicatedLinear,
|
||||
QKVParallelLinear,
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.model_executor.layers.sampler import Sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
VocabParallelEmbedding, ParallelLMHead)
|
||||
from vllm.model_executor.parallel_utils.communication_op import (
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
|
||||
from vllm.model_executor.sampling_metadata import SamplingMetadata
|
||||
from vllm.model_executor.weight_utils import (default_weight_loader,
|
||||
hf_model_weights_iterator)
|
||||
from vllm.model_executor.utils import set_weight_attrs
|
||||
from vllm.sequence import SamplerOutput
|
||||
|
||||
KVCache = Tuple[torch.Tensor, torch.Tensor]
|
||||
|
||||
|
||||
class MixtralMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_experts: int,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.num_experts = num_experts
|
||||
self.ffn_dim = intermediate_size
|
||||
self.hidden_dim = hidden_size
|
||||
|
||||
self.w1 = ReplicatedLinear(self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.w2 = ReplicatedLinear(self.ffn_dim,
|
||||
self.hidden_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
self.w3 = ReplicatedLinear(self.hidden_dim,
|
||||
self.ffn_dim,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
|
||||
# TODO: Use vllm's SiluAndMul
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
w1_out, _ = self.w1(hidden_states)
|
||||
w1_out = self.act_fn(w1_out)
|
||||
w3_out, _ = self.w3(hidden_states)
|
||||
current_hidden_states = w1_out * w3_out
|
||||
current_hidden_states, _ = self.w2(current_hidden_states)
|
||||
return current_hidden_states
|
||||
|
||||
|
||||
class DummyModule(nn.Module):
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.w1 = nn.Linear(0, 0, bias=False)
|
||||
self.w2 = nn.Linear(0, 0, bias=False)
|
||||
self.w3 = nn.Linear(0, 0, bias=False)
|
||||
|
||||
set_weight_attrs(self.w1.weight,
|
||||
{"weight_loader": self.dummy_weight_loader})
|
||||
set_weight_attrs(self.w2.weight,
|
||||
{"weight_loader": self.dummy_weight_loader})
|
||||
set_weight_attrs(self.w3.weight,
|
||||
{"weight_loader": self.dummy_weight_loader})
|
||||
|
||||
def forward(self, *args, **kwargs) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
def dummy_weight_loader(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
|
||||
# Noop
|
||||
return
|
||||
|
||||
|
||||
class MixtralMoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rank = get_tensor_model_parallel_rank()
|
||||
self.tp_size = get_tensor_model_parallel_world_size()
|
||||
self.num_total_experts = config.num_local_experts
|
||||
self.top_k = config.num_experts_per_tok
|
||||
if self.tp_size > self.num_total_experts:
|
||||
raise ValueError(
|
||||
f"Tensor parallel size {self.tp_size} is greater than "
|
||||
f"the number of experts {self.num_total_experts}.")
|
||||
# Split experts equally between ranks
|
||||
self.expert_indicies = np.array_split(range(
|
||||
self.num_total_experts), self.tp_size)[self.rank].tolist()
|
||||
if not self.expert_indicies:
|
||||
raise ValueError(
|
||||
f"Rank {self.rank} has no experts assigned to it.")
|
||||
|
||||
self.experts = nn.ModuleList([
|
||||
MixtralMLP(self.num_total_experts,
|
||||
config.hidden_size,
|
||||
config.intermediate_size,
|
||||
linear_method=linear_method)
|
||||
if idx in self.expert_indicies else DummyModule()
|
||||
for idx in range(self.num_total_experts)
|
||||
])
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
self.num_total_experts,
|
||||
bias=False,
|
||||
linear_method=linear_method)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
batch_size, sequence_length, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
# router_logits: (batch * sequence_length, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
|
||||
routing_weights, selected_experts = torch.topk(routing_weights,
|
||||
self.top_k,
|
||||
dim=-1)
|
||||
routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
|
||||
|
||||
final_hidden_states = None
|
||||
for expert_idx in self.expert_indicies:
|
||||
expert_layer = self.experts[expert_idx]
|
||||
expert_mask = (selected_experts == expert_idx)
|
||||
expert_weights = (routing_weights * expert_mask).sum(dim=-1,
|
||||
keepdim=True)
|
||||
|
||||
current_hidden_states = expert_layer(hidden_states).mul_(
|
||||
expert_weights)
|
||||
if final_hidden_states is None:
|
||||
final_hidden_states = current_hidden_states
|
||||
else:
|
||||
final_hidden_states.add_(current_hidden_states)
|
||||
|
||||
return tensor_model_parallel_all_reduce(final_hidden_states).view(
|
||||
batch_size, sequence_length, hidden_dim)
|
||||
|
||||
|
||||
class MixtralAttention(nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
hidden_size: int,
|
||||
num_heads: int,
|
||||
num_kv_heads: int,
|
||||
max_position: int = 4096 * 32,
|
||||
rope_theta: float = 10000,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
sliding_window: Optional[int] = None) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
self.total_num_heads = num_heads
|
||||
assert self.total_num_heads % tp_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_size
|
||||
self.total_num_kv_heads = num_kv_heads
|
||||
if self.total_num_kv_heads >= tp_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
|
||||
self.head_dim = hidden_size // self.total_num_heads
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
self.scaling = self.head_dim**-0.5
|
||||
self.rope_theta = rope_theta
|
||||
self.sliding_window = sliding_window
|
||||
|
||||
self.qkv_proj = QKVParallelLinear(
|
||||
hidden_size,
|
||||
self.head_dim,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.o_proj = RowParallelLinear(
|
||||
self.total_num_heads * self.head_dim,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
self.rotary_emb = get_rope(
|
||||
self.head_dim,
|
||||
rotary_dim=self.head_dim,
|
||||
max_position=max_position,
|
||||
base=int(self.rope_theta),
|
||||
is_neox_style=True,
|
||||
)
|
||||
self.attn = PagedAttention(
|
||||
self.num_heads,
|
||||
self.head_dim,
|
||||
self.scaling,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
sliding_window=self.sliding_window,
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
) -> torch.Tensor:
|
||||
qkv, _ = self.qkv_proj(hidden_states)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
q, k = self.rotary_emb(positions, q, k)
|
||||
k_cache, v_cache = kv_cache
|
||||
attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
|
||||
cache_event)
|
||||
output, _ = self.o_proj(attn_output)
|
||||
return output
|
||||
|
||||
|
||||
class MixtralDecoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
# Requires transformers > 4.32.0
|
||||
rope_theta = getattr(config, "rope_theta", 10000)
|
||||
self.self_attn = MixtralAttention(
|
||||
hidden_size=self.hidden_size,
|
||||
num_heads=config.num_attention_heads,
|
||||
max_position=config.max_position_embeddings,
|
||||
num_kv_heads=config.num_key_value_heads,
|
||||
rope_theta=rope_theta,
|
||||
sliding_window=config.sliding_window,
|
||||
linear_method=linear_method)
|
||||
self.block_sparse_moe = MixtralMoE(config=config,
|
||||
linear_method=linear_method)
|
||||
self.input_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
self.post_attention_layernorm = RMSNorm(config.hidden_size,
|
||||
eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
positions: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
kv_cache: KVCache,
|
||||
input_metadata: InputMetadata,
|
||||
cache_event: Optional[torch.cuda.Event],
|
||||
residual: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
# Self Attention
|
||||
if residual is None:
|
||||
residual = hidden_states
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
else:
|
||||
hidden_states, residual = self.input_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.self_attn(
|
||||
positions=positions,
|
||||
hidden_states=hidden_states,
|
||||
kv_cache=kv_cache,
|
||||
input_metadata=input_metadata,
|
||||
cache_event=cache_event,
|
||||
)
|
||||
|
||||
# Fully Connected
|
||||
hidden_states, residual = self.post_attention_layernorm(
|
||||
hidden_states, residual)
|
||||
hidden_states = self.block_sparse_moe(hidden_states)
|
||||
return hidden_states, residual
|
||||
|
||||
|
||||
class MixtralModel(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.padding_idx = config.pad_token_id
|
||||
self.vocab_size = config.vocab_size
|
||||
|
||||
self.embed_tokens = VocabParallelEmbedding(
|
||||
config.vocab_size,
|
||||
config.hidden_size,
|
||||
)
|
||||
self.layers = nn.ModuleList([
|
||||
MixtralDecoderLayer(config, linear_method=linear_method)
|
||||
for _ in range(config.num_hidden_layers)
|
||||
])
|
||||
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> SamplerOutput:
|
||||
hidden_states = self.embed_tokens(input_ids)
|
||||
residual = None
|
||||
for i in range(len(self.layers)):
|
||||
cache_event = None if cache_events is None else cache_events[i]
|
||||
layer = self.layers[i]
|
||||
hidden_states, residual = layer(positions, hidden_states,
|
||||
kv_caches[i], input_metadata,
|
||||
cache_event, residual)
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MixtralForCausalLM(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: MixtralConfig,
|
||||
linear_method: Optional[LinearMethodBase] = None,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.linear_method = linear_method
|
||||
self.model = MixtralModel(config, linear_method)
|
||||
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
|
||||
self.sampler = Sampler(config.vocab_size)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
kv_caches: List[KVCache],
|
||||
input_metadata: InputMetadata,
|
||||
cache_events: Optional[List[torch.cuda.Event]],
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.model(input_ids, positions, kv_caches,
|
||||
input_metadata, cache_events)
|
||||
return hidden_states
|
||||
|
||||
def sample(
|
||||
self,
|
||||
hidden_states: Optional[torch.Tensor],
|
||||
sampling_metadata: SamplingMetadata,
|
||||
) -> SamplerOutput:
|
||||
next_tokens = self.sampler(self.lm_head.weight, hidden_states,
|
||||
sampling_metadata)
|
||||
return next_tokens
|
||||
|
||||
def load_weights(self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None):
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision):
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for (param_name, weight_name, shard_id) in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
param = params_dict[name.replace(weight_name, param_name)]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader",
|
||||
default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
@ -50,9 +50,14 @@ class MPTAttention(nn.Module):
|
||||
super().__init__()
|
||||
self.d_model = config.d_model
|
||||
self.total_num_heads = config.n_heads
|
||||
self.head_dim = self.d_model // self.total_num_heads
|
||||
self.clip_qkv = config.attn_config["clip_qkv"]
|
||||
self.qk_ln = config.attn_config["qk_ln"]
|
||||
self.alibi_bias_max = config.attn_config["alibi_bias_max"]
|
||||
if "kv_n_heads" in config.attn_config:
|
||||
self.total_num_kv_heads = config.attn_config['kv_n_heads']
|
||||
else:
|
||||
self.total_num_kv_heads = self.total_num_heads
|
||||
assert not config.attn_config["prefix_lm"]
|
||||
assert config.attn_config["alibi"]
|
||||
|
||||
@ -61,6 +66,7 @@ class MPTAttention(nn.Module):
|
||||
self.d_model,
|
||||
self.d_model // self.total_num_heads,
|
||||
self.total_num_heads,
|
||||
self.total_num_kv_heads,
|
||||
bias=not config.no_bias,
|
||||
linear_method=linear_method,
|
||||
)
|
||||
@ -78,6 +84,17 @@ class MPTAttention(nn.Module):
|
||||
assert self.total_num_heads % tp_world_size == 0
|
||||
self.num_heads = self.total_num_heads // tp_world_size
|
||||
|
||||
if self.total_num_kv_heads >= tp_world_size:
|
||||
# Number of KV heads is greater than TP size, so we partition
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert self.total_num_kv_heads % tp_world_size == 0
|
||||
else:
|
||||
# Number of KV heads is less than TP size, so we replicate
|
||||
# the KV heads across multiple tensor parallel GPUs.
|
||||
assert tp_world_size % self.total_num_kv_heads == 0
|
||||
self.num_kv_heads = max(1, self.total_num_kv_heads // tp_world_size)
|
||||
self.q_size = self.num_heads * self.head_dim
|
||||
self.kv_size = self.num_kv_heads * self.head_dim
|
||||
# Create the alibi slopes and slice them.
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
head_start = tp_rank * self.num_heads
|
||||
@ -91,7 +108,8 @@ class MPTAttention(nn.Module):
|
||||
self.attn = PagedAttention(self.num_heads,
|
||||
self.head_dim,
|
||||
scaling,
|
||||
alibi_slopes=alibi_slopes)
|
||||
alibi_slopes=alibi_slopes,
|
||||
num_kv_heads=self.num_kv_heads)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@ -105,7 +123,7 @@ class MPTAttention(nn.Module):
|
||||
qkv, _ = self.Wqkv(hidden_states)
|
||||
if self.clip_qkv is not None:
|
||||
qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
|
||||
q, k, v = qkv.chunk(chunks=3, dim=-1)
|
||||
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
|
||||
if self.qk_ln:
|
||||
q = self.q_ln(q)
|
||||
k = self.k_ln(k)
|
||||
|
@ -149,6 +149,7 @@ class SamplingParams:
|
||||
# Zero temperature means greedy sampling.
|
||||
self.top_p = 1.0
|
||||
self.top_k = -1
|
||||
self.min_p = 0.0
|
||||
self._verify_greedy_sampling()
|
||||
|
||||
def _verify_args(self) -> None:
|
||||
|
@ -27,20 +27,19 @@ class Counter:
|
||||
self.counter = 0
|
||||
|
||||
|
||||
def is_hip() -> bool:
|
||||
return torch.version.hip is not None
|
||||
|
||||
|
||||
def get_max_shared_memory_bytes(gpu: int = 0) -> int:
|
||||
"""Returns the maximum shared memory per thread block in bytes."""
|
||||
# https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
|
||||
max_shared_mem = cuda_utils.get_device_attribute(
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, gpu)
|
||||
return int(max_shared_mem)
|
||||
|
||||
|
||||
def get_gpu_memory(gpu: int = 0) -> int:
|
||||
"""Returns the total memory of the GPU in bytes."""
|
||||
return torch.cuda.get_device_properties(gpu).total_memory
|
||||
|
||||
|
||||
def get_cpu_memory() -> int:
|
||||
"""Returns the total CPU memory of the node in bytes."""
|
||||
return psutil.virtual_memory().total
|
||||
|
@ -134,14 +134,14 @@ class ModelRunner:
|
||||
generation_token = seq_data.get_last_token_id()
|
||||
input_tokens.append([generation_token])
|
||||
|
||||
context_len = seq_data.get_len()
|
||||
if self.sliding_window is not None:
|
||||
context_len = min(context_len, self.sliding_window)
|
||||
context_lens.append(context_len)
|
||||
|
||||
position = context_len - 1
|
||||
seq_len = seq_data.get_len()
|
||||
position = seq_len - 1
|
||||
input_positions.append([position])
|
||||
|
||||
context_len = seq_len if self.sliding_window is None else min(
|
||||
seq_len, self.sliding_window)
|
||||
context_lens.append(context_len)
|
||||
|
||||
block_table = seq_group_metadata.block_tables[seq_id]
|
||||
block_number = block_table[position // self.block_size]
|
||||
block_offset = position % self.block_size
|
||||
|
@ -13,7 +13,6 @@ from vllm.model_executor.parallel_utils.parallel_state import (
|
||||
from vllm.sequence import SamplerOutput, SequenceGroupMetadata
|
||||
from vllm.worker.cache_engine import CacheEngine
|
||||
from vllm.worker.model_runner import ModelRunner
|
||||
from vllm.utils import get_gpu_memory
|
||||
|
||||
|
||||
class Worker:
|
||||
@ -81,7 +80,6 @@ class Worker:
|
||||
# Profile the memory usage of the model and get the maximum number of
|
||||
# cache blocks that can be allocated with the remaining free memory.
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
|
||||
# Execute a forward pass with dummy inputs to profile the memory usage
|
||||
# of the model.
|
||||
@ -90,8 +88,9 @@ class Worker:
|
||||
# Calculate the number of blocks that can be allocated with the
|
||||
# profiled peak memory.
|
||||
torch.cuda.synchronize()
|
||||
peak_memory = torch.cuda.max_memory_allocated()
|
||||
total_gpu_memory = get_gpu_memory()
|
||||
free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
|
||||
peak_memory = total_gpu_memory - free_gpu_memory
|
||||
|
||||
cache_block_size = CacheEngine.get_cache_block_size(
|
||||
block_size, self.model_config, self.parallel_config)
|
||||
num_gpu_blocks = int(
|
||||
|
Reference in New Issue
Block a user