Mirror of https://github.com/vllm-project/vllm.git (synced 2025-10-31 06:14:38 +08:00)

Compare commits

110 Commits
| SHA1 | Author | Date | |
|---|---|---|---|
| 1af090b57d | |||
| 3dad944485 | |||
| 105a40f53a | |||
| bbe9bd9684 | |||
| 4f65af0e25 | |||
| d79ced3292 | |||
| ab40644669 | |||
| 5d60def02c | |||
| ea8489fce2 | |||
| 1b20639a43 | |||
| b72af8f1ed | |||
| 9090bf02e7 | |||
| 7d648418b8 | |||
| 89be30fa7d | |||
| f8ecb84c02 | |||
| 5f036d2bcc | |||
| 380170038e | |||
| 220a47627b | |||
| beb89f68b4 | |||
| 390b495ff3 | |||
| 3a0e1fc070 | |||
| 6b7de1a030 | |||
| 5265631d15 | |||
| 2832e7b9f9 | |||
| 3a7dd7e367 | |||
| 223c19224b | |||
| f1f6cc10c7 | |||
| 3209b49033 | |||
| 1e4277d2d1 | |||
| 9b945daaf1 | |||
| 9c1352eb57 | |||
| 7a0b011dd5 | |||
| 63e835cbcc | |||
| 94b5edeb53 | |||
| ab7e6006d6 | |||
| 18bfcdd05c | |||
| 71d63ed72e | |||
| d75c40734a | |||
| 5b23c3f26f | |||
| 00efdc84ba | |||
| 91a61da9b1 | |||
| ef9b636e2d | |||
| 2709c0009a | |||
| dd7e8f5f64 | |||
| d2a68364c4 | |||
| 7e1081139d | |||
| 18473cf498 | |||
| 4df417d059 | |||
| 5d80a9178b | |||
| 8a25d3a71a | |||
| d10f8e1d43 | |||
| 14cc317ba4 | |||
| e1957c6ebd | |||
| 8cd5a992bf | |||
| 947f0b23cc | |||
| f780504d12 | |||
| bfc072addf | |||
| 2a18da257c | |||
| 6e01e8c1c8 | |||
| 9f659bf07f | |||
| 35c4bc20d9 | |||
| 218dc2ccda | |||
| 827cbcd37c | |||
| cb7a1c1cbf | |||
| 7878958c0d | |||
| ce036244c9 | |||
| 48cf1e413c | |||
| 97460585d9 | |||
| f745847ef7 | |||
| 6549aef245 | |||
| 50376faa7b | |||
| 4b61c6b669 | |||
| 79d64c4954 | |||
| 74cd5abdd1 | |||
| 28c3f12104 | |||
| c884819135 | |||
| 05921a9a7a | |||
| d0215a58e7 | |||
| 937e7b7d7c | |||
| aee8ef661a | |||
| 2e0b6e7757 | |||
| 941767127c | |||
| 74d8d77626 | |||
| fd4ea8ef5c | |||
| 1066cbd152 | |||
| 6ef00b03a2 | |||
| 9140561059 | |||
| 77af974b40 | |||
| 4934d49274 | |||
| 358c328d69 | |||
| 4aaafdd289 | |||
| 66b108d142 | |||
| e0ff920001 | |||
| face83c7ec | |||
| 1db83e31a2 | |||
| a1b9cb2a34 | |||
| 3a4fd5ca59 | |||
| c17daa9f89 | |||
| bd29cf3d3a | |||
| 31bff69151 | |||
| ba4f826738 | |||
| de60a3fb93 | |||
| 21d5daa4ac | |||
| 290e015c6c | |||
| 1b7c791d60 | |||
| bbe4466fd9 | |||
| 08133c4d1a | |||
| 76a7983b23 | |||
| 8041b7305e | |||
| 3ec8c25cd0 | |||

.buildkite/run-benchmarks.sh (new file, 63 lines)
							| @ -0,0 +1,63 @@ | |||||||
|  | # This script is run by buildkite to run the benchmarks and upload the results to buildkite | ||||||
|  |  | ||||||
|  | set -ex | ||||||
|  | set -o pipefail | ||||||
|  |  | ||||||
|  | # cd into parent directory of this file | ||||||
|  | cd "$(dirname "${BASH_SOURCE[0]}")/.." | ||||||
|  |  | ||||||
|  | (wget && curl) || (apt-get update && apt-get install -y wget curl) | ||||||
|  |  | ||||||
|  | # run benchmarks and upload the result to buildkite | ||||||
|  | python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt | ||||||
|  | bench_latency_exit_code=$? | ||||||
|  |  | ||||||
|  | python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt | ||||||
|  | bench_throughput_exit_code=$? | ||||||
|  |  | ||||||
|  | python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf & | ||||||
|  | server_pid=$! | ||||||
|  | wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json | ||||||
|  |  | ||||||
|  | # wait for server to start, timeout after 600 seconds | ||||||
|  | timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 | ||||||
|  | python3 benchmarks/benchmark_serving.py \ | ||||||
|  |     --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \ | ||||||
|  |     --model meta-llama/Llama-2-7b-chat-hf \ | ||||||
|  |     --num-prompts 20 \ | ||||||
|  |     --endpoint /v1/completions \ | ||||||
|  |     --tokenizer meta-llama/Llama-2-7b-chat-hf 2>&1 | tee benchmark_serving.txt | ||||||
|  | bench_serving_exit_code=$? | ||||||
|  | kill $server_pid | ||||||
|  |  | ||||||
|  | # write the results into a markdown file | ||||||
|  | echo "### Latency Benchmarks" >> benchmark_results.md | ||||||
|  | sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line | ||||||
|  | echo "" >> benchmark_results.md | ||||||
|  | sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line | ||||||
|  |  | ||||||
|  | echo "### Throughput Benchmarks" >> benchmark_results.md | ||||||
|  | sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line | ||||||
|  | echo "" >> benchmark_results.md | ||||||
|  | sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line | ||||||
|  |  | ||||||
|  | echo "### Serving Benchmarks" >> benchmark_results.md | ||||||
|  | sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line | ||||||
|  | echo "" >> benchmark_results.md | ||||||
|  | tail -n 5 benchmark_serving.txt >> benchmark_results.md # last 5 lines | ||||||
|  |  | ||||||
|  | # upload the results to buildkite | ||||||
|  | /workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md | ||||||
|  |  | ||||||
|  | # exit with the exit code of the benchmarks | ||||||
|  | if [ $bench_latency_exit_code -ne 0 ]; then | ||||||
|  |     exit $bench_latency_exit_code | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | if [ $bench_throughput_exit_code -ne 0 ]; then | ||||||
|  |     exit $bench_throughput_exit_code | ||||||
|  | fi | ||||||
|  |  | ||||||
|  | if [ $bench_serving_exit_code -ne 0 ]; then | ||||||
|  |     exit $bench_serving_exit_code | ||||||
|  | fi | ||||||
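The new CI benchmark script can also be exercised outside Buildkite. A minimal local dry run, assuming a vLLM checkout with its dependencies installed; the `/workspace/buildkite-agent` stub below is a hypothetical stand-in for the agent binary that only exists inside the CI container:

```bash
# Local dry run of .buildkite/run-benchmarks.sh (sketch). The annotation step expects
# /workspace/buildkite-agent, so stub it out; everything else runs as it does in CI.
mkdir -p /workspace
printf '#!/bin/sh\ncat > /dev/null\n' > /workspace/buildkite-agent
chmod +x /workspace/buildkite-agent

# The script cd's to the repository root itself, so it can be invoked from anywhere.
bash .buildkite/run-benchmarks.sh
cat benchmark_results.md    # the same summary the CI job attaches to the build
```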
							
								
								
									
.buildkite/test-pipeline.yaml (new file, 51 lines)
									
								
							| @ -0,0 +1,51 @@ | |||||||
|  | # In this file, you can add more tests to run either by adding a new step or | ||||||
|  | # adding a new command to an existing step. See different options here for examples. | ||||||
|  | # This script will be feed into Jinja template in `test-template.j2` to generate | ||||||
|  | # the final pipeline yaml file. | ||||||
|  |  | ||||||
|  | steps: | ||||||
|  | - label: Regression Test | ||||||
|  |   command: pytest -v -s test_regression.py | ||||||
|  |   working_dir: "/vllm-workspace/tests" # optional | ||||||
|  |  | ||||||
|  | - label: AsyncEngine Test | ||||||
|  |   command: pytest -v -s async_engine | ||||||
|  |  | ||||||
|  | - label: Distributed Test | ||||||
|  |   command: pytest -v -s test_comm_ops.py | ||||||
|  |   working_dir: "/vllm-workspace/tests/distributed" | ||||||
|  |   num_gpus: 2 # only support 1 or 2 for now. | ||||||
|  |  | ||||||
|  | - label: Engine Test | ||||||
|  |   command: pytest -v -s engine | ||||||
|  |  | ||||||
|  | - label: Entrypoints Test | ||||||
|  |   command: pytest -v -s entrypoints | ||||||
|  |  | ||||||
|  | - label: Kernels Test | ||||||
|  |   command: pytest -v -s kernels | ||||||
|  |   soft_fail: true | ||||||
|  |  | ||||||
|  | - label: Models Test | ||||||
|  |   commands: | ||||||
|  |     - pytest -v -s models --forked | ||||||
|  |   soft_fail: true | ||||||
|  |  | ||||||
|  | - label: Prefix Caching Test | ||||||
|  |   commands: | ||||||
|  |     - pytest -v -s prefix_caching | ||||||
|  |  | ||||||
|  | - label: Samplers Test | ||||||
|  |   command: pytest -v -s samplers --forked | ||||||
|  |  | ||||||
|  | - label: Worker Test | ||||||
|  |   command: pytest -v -s worker | ||||||
|  |  | ||||||
|  | - label: LoRA Test | ||||||
|  |   command: pytest -v -s lora | ||||||
|  |  | ||||||
|  | - label: Benchmarks | ||||||
|  |   working_dir: "/vllm-workspace/.buildkite" | ||||||
|  |   commands: | ||||||
|  |   - pip install aiohttp | ||||||
|  |   - bash run-benchmarks.sh | ||||||
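Each step above runs inside the CI test image built from the Dockerfile's `test` target (see the Dockerfile changes further down). A sketch of reproducing a single step locally; the `vllm-test:dev` tag is illustrative, not the image name used by CI:

```bash
# Build the test image and run one pipeline step locally (sketch).
DOCKER_BUILDKIT=1 docker build --target test -t vllm-test:dev .
# --shm-size stands in for the dshm emptyDir volume the CI pod mounts at /dev/shm.
docker run --gpus all --shm-size=4g --rm vllm-test:dev \
    bash -c 'cd /vllm-workspace/tests && pytest -v -s kernels'
```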
							
								
								
									
.buildkite/test-template.j2 (new file, 54 lines)
									
								
							| @ -0,0 +1,54 @@ | |||||||
|  | {% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %} | ||||||
|  | {% set default_num_gpu = 1 %} | ||||||
|  | {% set default_working_dir = "/vllm-workspace/tests" %} | ||||||
|  |  | ||||||
|  | steps: | ||||||
|  |   - label: ":docker: build image" | ||||||
|  |     commands: | ||||||
|  |       - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ." | ||||||
|  |       - "docker push {{ docker_image }}" | ||||||
|  |     env: | ||||||
|  |       DOCKER_BUILDKIT: "1" | ||||||
|  |     retry: | ||||||
|  |       automatic: | ||||||
|  |         - exit_status: -1  # Agent was lost | ||||||
|  |           limit: 5 | ||||||
|  |   - wait | ||||||
|  |  | ||||||
|  |   {% for step in steps %} | ||||||
|  |   - label: "{{ step.label }}" | ||||||
|  |     agents: | ||||||
|  |       queue: kubernetes | ||||||
|  |     soft_fail: {{ step.soft_fail or false }} | ||||||
|  |     retry: | ||||||
|  |       automatic: | ||||||
|  |         - exit_status: -1  # Agent was lost | ||||||
|  |           limit: 5 | ||||||
|  |     plugins: | ||||||
|  |       - kubernetes: | ||||||
|  |           podSpec: | ||||||
|  |             volumes: | ||||||
|  |               - name: dshm | ||||||
|  |                 emptyDir: | ||||||
|  |                   medium: Memory | ||||||
|  |             containers: | ||||||
|  |               - image: "{{ docker_image }}" | ||||||
|  |                 command: ["bash"] | ||||||
|  |                 args: | ||||||
|  |                 - "-c" | ||||||
|  |                 - "'cd {{ (step.working_dir or default_working_dir) | safe  }} && {{ step.command  or (step.commands | join(' && ')) | safe }}'" | ||||||
|  |                 resources: | ||||||
|  |                   requests: | ||||||
|  |                     nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" | ||||||
|  |                   limits: | ||||||
|  |                     nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}" | ||||||
|  |                 env: | ||||||
|  |                   - name: HF_TOKEN | ||||||
|  |                     valueFrom: | ||||||
|  |                       secretKeyRef: | ||||||
|  |                         name: hf-token-secret | ||||||
|  |                         key: token | ||||||
|  |                 volumeMounts: | ||||||
|  |                   - mountPath: /dev/shm | ||||||
|  |                     name: dshm | ||||||
|  |   {% endfor %} | ||||||
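As the header comment in `test-pipeline.yaml` notes, that file is fed through this Jinja template to produce the final pipeline. A rough local render, assuming `jinja2` and `PyYAML` are installed; the actual CI bootstrap command is not shown in this diff:

```bash
# Render the pipeline locally (sketch).
python3 - <<'EOF' > pipeline.yaml
import jinja2
import yaml

with open(".buildkite/test-template.j2") as f:
    template = jinja2.Template(f.read())
with open(".buildkite/test-pipeline.yaml") as f:
    steps = yaml.safe_load(f)["steps"]

print(template.render(steps=steps))
EOF
# In CI, the rendered YAML would then be uploaded, e.g.:
# buildkite-agent pipeline upload pipeline.yaml
```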
							
								
								
									
.dockerignore (new file, 1 line)
									
								
							| @ -0,0 +1 @@ | |||||||
|  | vllm/*.so | ||||||
							
								
								
									
.github/workflows/scripts/build.sh (vendored, 2 lines changed)
									
									
								
							| @ -13,6 +13,8 @@ $python_executable -m pip install -r requirements.txt | |||||||
|  |  | ||||||
| # Limit the number of parallel jobs to avoid OOM | # Limit the number of parallel jobs to avoid OOM | ||||||
| export MAX_JOBS=1 | export MAX_JOBS=1 | ||||||
|  | # Make sure punica is built for the release (for LoRA) | ||||||
|  | export VLLM_INSTALL_PUNICA_KERNELS=1 | ||||||
|  |  | ||||||
| # Build | # Build | ||||||
| $python_executable setup.py bdist_wheel --dist-dir=dist | $python_executable setup.py bdist_wheel --dist-dir=dist | ||||||
|  | |||||||
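The release build now exports `VLLM_INSTALL_PUNICA_KERNELS=1` so the published wheels include the Punica kernels needed for LoRA. The same switches apply when building a wheel by hand; a sketch mirroring the script:

```bash
# Build a wheel with the Punica (LoRA) kernels enabled (sketch; mirrors build.sh).
pip install -r requirements.txt
export MAX_JOBS=1                        # limit parallel compile jobs to avoid OOM
export VLLM_INSTALL_PUNICA_KERNELS=1     # make sure punica is built for LoRA support
python3 setup.py bdist_wheel --dist-dir=dist
```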
							
								
								
									
.github/workflows/yapf.yml (vendored, 2 lines changed)
									
									
								
							| @ -28,4 +28,4 @@ jobs: | |||||||
|         pip install toml==0.10.2 |         pip install toml==0.10.2 | ||||||
|     - name: Running yapf |     - name: Running yapf | ||||||
|       run: | |       run: | | ||||||
|         yapf --diff --recursive vllm tests |         yapf --diff --recursive . | ||||||
|  | |||||||
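The workflow now formats the entire repository instead of only `vllm` and `tests`. The same check can be run before pushing; the yapf version pin is not visible in this hunk, so it is left unpinned here:

```bash
# Run the same formatting check locally (sketch).
pip install yapf toml==0.10.2    # toml pin taken from the workflow above
yapf --diff --recursive .
```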
							
								
								
									
.gitignore (vendored, 3 lines changed)
									
									
								
							| @ -181,3 +181,6 @@ _build/ | |||||||
| # hip files generated by PyTorch | # hip files generated by PyTorch | ||||||
| *.hip | *.hip | ||||||
| *_hip* | *_hip* | ||||||
|  |  | ||||||
|  | # Benchmark dataset | ||||||
|  | *.json | ||||||
|  | |||||||
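The new `*.json` rule keeps the ShareGPT dataset downloaded by `run-benchmarks.sh` out of the working tree. A quick way to confirm which rule matches (assumes a checkout with the updated .gitignore):

```bash
# Check which ignore rule applies to the benchmark dataset (works even before downloading it).
git check-ignore -v ShareGPT_V3_unfiltered_cleaned_split.json
```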
							
								
								
									
Dockerfile (39 lines changed)
									
									
									
									
									
								
							| @ -1,7 +1,11 @@ | |||||||
|  | # The vLLM Dockerfile is used to construct vLLM image that can be directly used | ||||||
|  | # to run the OpenAI compatible server. | ||||||
|  |  | ||||||
|  | #################### BASE BUILD IMAGE #################### | ||||||
| FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev | FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev | ||||||
|  |  | ||||||
| RUN apt-get update -y \ | RUN apt-get update -y \ | ||||||
|     && apt-get install -y python3-pip |     && apt-get install -y python3-pip git | ||||||
|  |  | ||||||
| WORKDIR /workspace | WORKDIR /workspace | ||||||
|  |  | ||||||
| @ -14,8 +18,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \ | |||||||
| COPY requirements-dev.txt requirements-dev.txt | COPY requirements-dev.txt requirements-dev.txt | ||||||
| RUN --mount=type=cache,target=/root/.cache/pip \ | RUN --mount=type=cache,target=/root/.cache/pip \ | ||||||
|     pip install -r requirements-dev.txt |     pip install -r requirements-dev.txt | ||||||
|  | #################### BASE BUILD IMAGE #################### | ||||||
|  |  | ||||||
| # image to build pytorch extensions |  | ||||||
|  | #################### EXTENSION BUILD IMAGE #################### | ||||||
| FROM dev AS build | FROM dev AS build | ||||||
|  |  | ||||||
| # install build dependencies | # install build dependencies | ||||||
| @ -30,6 +36,7 @@ COPY requirements.txt requirements.txt | |||||||
| COPY pyproject.toml pyproject.toml | COPY pyproject.toml pyproject.toml | ||||||
| COPY vllm/__init__.py vllm/__init__.py | COPY vllm/__init__.py vllm/__init__.py | ||||||
|  |  | ||||||
|  | # cuda arch list used by torch | ||||||
| ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' | ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX' | ||||||
| ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} | ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list} | ||||||
| # max jobs used by Ninja to build extensions | # max jobs used by Ninja to build extensions | ||||||
| @ -38,20 +45,30 @@ ENV MAX_JOBS=${max_jobs} | |||||||
| # number of threads used by nvcc | # number of threads used by nvcc | ||||||
| ARG nvcc_threads=8 | ARG nvcc_threads=8 | ||||||
| ENV NVCC_THREADS=$nvcc_threads | ENV NVCC_THREADS=$nvcc_threads | ||||||
|  | # make sure punica kernels are built (for LoRA) | ||||||
|  | ENV VLLM_INSTALL_PUNICA_KERNELS=1 | ||||||
|  |  | ||||||
| RUN python3 setup.py build_ext --inplace | RUN python3 setup.py build_ext --inplace | ||||||
|  | #################### EXTENSION Build IMAGE #################### | ||||||
|  |  | ||||||
|  |  | ||||||
|  | #################### TEST IMAGE #################### | ||||||
| # image to run unit testing suite | # image to run unit testing suite | ||||||
| FROM dev AS test | FROM dev AS test | ||||||
|  |  | ||||||
| # copy pytorch extensions separately to avoid having to rebuild | # copy pytorch extensions separately to avoid having to rebuild | ||||||
| # when python code changes | # when python code changes | ||||||
| COPY --from=build /workspace/vllm/*.so /workspace/vllm/ | WORKDIR /vllm-workspace | ||||||
| COPY tests tests | # ADD is used to preserve directory structure | ||||||
| COPY vllm vllm | ADD . /vllm-workspace/ | ||||||
|  | COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/ | ||||||
|  | # ignore build dependencies installation because we are using pre-complied extensions | ||||||
|  | RUN rm pyproject.toml | ||||||
|  | RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose | ||||||
|  | #################### TEST IMAGE #################### | ||||||
|  |  | ||||||
| ENTRYPOINT ["python3", "-m", "pytest", "tests"] |  | ||||||
|  |  | ||||||
|  | #################### RUNTIME BASE IMAGE #################### | ||||||
| # use CUDA base as CUDA runtime dependencies are already installed via pip | # use CUDA base as CUDA runtime dependencies are already installed via pip | ||||||
| FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base | FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base | ||||||
|  |  | ||||||
| @ -63,14 +80,10 @@ WORKDIR /workspace | |||||||
| COPY requirements.txt requirements.txt | COPY requirements.txt requirements.txt | ||||||
| RUN --mount=type=cache,target=/root/.cache/pip \ | RUN --mount=type=cache,target=/root/.cache/pip \ | ||||||
|     pip install -r requirements.txt |     pip install -r requirements.txt | ||||||
|  | #################### RUNTIME BASE IMAGE #################### | ||||||
|  |  | ||||||
| FROM vllm-base AS vllm |  | ||||||
| COPY --from=build /workspace/vllm/*.so /workspace/vllm/ |  | ||||||
| COPY vllm vllm |  | ||||||
|  |  | ||||||
| EXPOSE 8000 |  | ||||||
| ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"] |  | ||||||
|  |  | ||||||
|  | #################### OPENAI API SERVER #################### | ||||||
| # openai api server alternative | # openai api server alternative | ||||||
| FROM vllm-base AS vllm-openai | FROM vllm-base AS vllm-openai | ||||||
| # install additional dependencies for openai api server | # install additional dependencies for openai api server | ||||||
| @ -81,4 +94,4 @@ COPY --from=build /workspace/vllm/*.so /workspace/vllm/ | |||||||
| COPY vllm vllm | COPY vllm vllm | ||||||
|  |  | ||||||
| ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] | ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] | ||||||
|  | #################### OPENAI API SERVER #################### | ||||||
|  | |||||||
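With the reorganized stages, `vllm-openai` is the final target of the Dockerfile. A typical build-and-run sketch; the image tag is illustrative, and `HF_TOKEN` is the same variable the CI template injects for gated models:

```bash
# Build the OpenAI-compatible server image and serve a model (sketch).
DOCKER_BUILDKIT=1 docker build --target vllm-openai -t vllm-openai:dev .
docker run --gpus all -p 8000:8000 -e HF_TOKEN=<your_token> vllm-openai:dev \
    --model meta-llama/Llama-2-7b-chat-hf
```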
| @ -1,4 +1,24 @@ | |||||||
| FROM rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1 | # default base image | ||||||
|  | ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" | ||||||
|  |  | ||||||
|  | FROM $BASE_IMAGE | ||||||
|  |  | ||||||
|  | ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" | ||||||
|  |  | ||||||
|  | RUN echo "Base image is $BASE_IMAGE" | ||||||
|  |  | ||||||
|  | # BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" | ||||||
|  | # BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" | ||||||
|  |  | ||||||
|  | # this does not always work for all rocm versions | ||||||
|  | RUN LLVM_GFX_ARCH=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) && \ | ||||||
|  |     echo "LLVM_GFX_ARCH is $LLVM_GFX_ARCH" | ||||||
|  |  | ||||||
|  | ARG FA_GFX_ARCHS="gfx90a;gfx942" | ||||||
|  | RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS" | ||||||
|  |  | ||||||
|  | ARG FA_BRANCH="3d2b6f5" | ||||||
|  | RUN echo "FA_BRANCH is $FA_BRANCH" | ||||||
|  |  | ||||||
| # Install some basic utilities | # Install some basic utilities | ||||||
| RUN apt-get update && apt-get install python3 python3-pip -y | RUN apt-get update && apt-get install python3 python3-pip -y | ||||||
| @ -37,17 +57,23 @@ RUN mkdir libs \ | |||||||
|     && cd libs \ |     && cd libs \ | ||||||
|     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ |     && git clone https://github.com/ROCmSoftwarePlatform/flash-attention.git \ | ||||||
|     && cd flash-attention \ |     && cd flash-attention \ | ||||||
|     && git checkout 3d2b6f5 \ |     && git checkout ${FA_BRANCH} \ | ||||||
|     && git submodule update --init \ |     && git submodule update --init \ | ||||||
|     && export GPU_ARCHS=$(/opt/rocm/llvm/bin/amdgpu-offload-arch) \ |     && export GPU_ARCHS=${FA_GFX_ARCHS} \ | ||||||
|     && patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch \ |     && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \ | ||||||
|  |         patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \ | ||||||
|     && python3 setup.py install \ |     && python3 setup.py install \ | ||||||
|     && cd .. |     && cd .. | ||||||
|  |  | ||||||
| COPY ./ /app/vllm | COPY ./ /app/vllm | ||||||
|  |  | ||||||
| RUN python3 -m pip install --upgrade pip | RUN python3 -m pip install --upgrade pip | ||||||
| RUN pip install xformers==0.0.23 --no-deps | RUN python3 -m pip install xformers==0.0.23 --no-deps | ||||||
|  |  | ||||||
|  | # Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt. | ||||||
|  | # Manually removed it so that later steps of numpy upgrade can continue | ||||||
|  | RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \ | ||||||
|  |     rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi | ||||||
|  |  | ||||||
| RUN cd /app \ | RUN cd /app \ | ||||||
|     && cd vllm \ |     && cd vllm \ | ||||||
|  | |||||||
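The hunks above appear to come from `Dockerfile.rocm`, which now defaults to a ROCm 6.0 base image and exposes it as a build argument. A sketch of building against the older ROCm 5.7 base instead, using the values listed in the comments above:

```bash
# Build the ROCm image with an alternate base (sketch; tag name is illustrative).
docker build -f Dockerfile.rocm \
    --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
    --build-arg FA_GFX_ARCHS="gfx90a;gfx942" \
    -t vllm-rocm:dev .
```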
							
								
								
									
README.md (21 lines changed)
									
									
									
									
									
								
							| @ -16,8 +16,18 @@ Easy, fast, and cheap LLM serving for everyone | |||||||
|  |  | ||||||
| --- | --- | ||||||
|  |  | ||||||
|  | **The Second vLLM Bay Area Meetup (Jan 31st 5pm-7:30pm PT)** | ||||||
|  |  | ||||||
|  | We are thrilled to announce our second vLLM Meetup! | ||||||
|  | The vLLM team will share recent updates and roadmap. | ||||||
|  | We will also have vLLM collaborators from IBM coming up to the stage to discuss their insights on LLM optimizations. | ||||||
|  | Please register [here](https://lu.ma/ygxbpzhl) and join us! | ||||||
|  |  | ||||||
|  | --- | ||||||
|  |  | ||||||
| *Latest News* 🔥 | *Latest News* 🔥 | ||||||
| - [2023/12] Added ROCm support to vLLM. | - [2024/01] Added ROCm 6.0 support to vLLM. | ||||||
|  | - [2023/12] Added ROCm 5.7 support to vLLM. | ||||||
| - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). | - [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing). | ||||||
| - [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there. | - [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there. | ||||||
| - [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv! | - [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv! | ||||||
| @ -27,7 +37,7 @@ Easy, fast, and cheap LLM serving for everyone | |||||||
| - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). | - [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai). | ||||||
|  |  | ||||||
| --- | --- | ||||||
|  | ## About | ||||||
| vLLM is a fast and easy-to-use library for LLM inference and serving. | vLLM is a fast and easy-to-use library for LLM inference and serving. | ||||||
|  |  | ||||||
| vLLM is fast with: | vLLM is fast with: | ||||||
| @ -36,7 +46,7 @@ vLLM is fast with: | |||||||
| - Efficient management of attention key and value memory with **PagedAttention** | - Efficient management of attention key and value memory with **PagedAttention** | ||||||
| - Continuous batching of incoming requests | - Continuous batching of incoming requests | ||||||
| - Fast model execution with CUDA/HIP graph | - Fast model execution with CUDA/HIP graph | ||||||
| - Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629) | - Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache | ||||||
| - Optimized CUDA kernels | - Optimized CUDA kernels | ||||||
|  |  | ||||||
| vLLM is flexible and easy to use with: | vLLM is flexible and easy to use with: | ||||||
| @ -47,6 +57,8 @@ vLLM is flexible and easy to use with: | |||||||
| - Streaming outputs | - Streaming outputs | ||||||
| - OpenAI-compatible API server | - OpenAI-compatible API server | ||||||
| - Support NVIDIA GPUs and AMD GPUs | - Support NVIDIA GPUs and AMD GPUs | ||||||
|  | - (Experimental) Prefix caching support | ||||||
|  | - (Experimental) Multi-lora support | ||||||
|  |  | ||||||
| vLLM seamlessly supports many Hugging Face models, including the following architectures: | vLLM seamlessly supports many Hugging Face models, including the following architectures: | ||||||
|  |  | ||||||
| @ -54,6 +66,7 @@ vLLM seamlessly supports many Hugging Face models, including the following archi | |||||||
| - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) | - Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.) | ||||||
| - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) | - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.) | ||||||
| - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) | - ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.) | ||||||
|  | - DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.) | ||||||
| - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) | - Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.) | ||||||
| - GPT-2 (`gpt2`, `gpt2-xl`, etc.) | - GPT-2 (`gpt2`, `gpt2-xl`, etc.) | ||||||
| - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) | - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.) | ||||||
| @ -67,6 +80,8 @@ vLLM seamlessly supports many Hugging Face models, including the following archi | |||||||
| - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) | - OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.) | ||||||
| - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) | - Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.) | ||||||
| - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) | - Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.) | ||||||
|  | - Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.) | ||||||
|  | - StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.) | ||||||
| - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) | - Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.) | ||||||
|  |  | ||||||
| Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): | Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source): | ||||||
|  | |||||||
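The README changes announce FP8 KV cache support, experimental prefix caching and multi-LoRA, and several newly supported model families. A minimal smoke test of the pip install path it links to; the model choice is illustrative:

```bash
# Install vLLM and query the OpenAI-compatible server (sketch).
pip install vllm
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
sleep 30    # give the server time to load the model before querying it
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "facebook/opt-125m", "prompt": "Hello, my name is", "max_tokens": 16}'
```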
| @ -24,6 +24,7 @@ def main(args: argparse.Namespace): | |||||||
|         trust_remote_code=args.trust_remote_code, |         trust_remote_code=args.trust_remote_code, | ||||||
|         dtype=args.dtype, |         dtype=args.dtype, | ||||||
|         enforce_eager=args.enforce_eager, |         enforce_eager=args.enforce_eager, | ||||||
|  |         kv_cache_dtype=args.kv_cache_dtype, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     sampling_params = SamplingParams( |     sampling_params = SamplingParams( | ||||||
| @ -65,7 +66,9 @@ def main(args: argparse.Namespace): | |||||||
|     if args.profile: |     if args.profile: | ||||||
|         profile_dir = args.profile_result_dir |         profile_dir = args.profile_result_dir | ||||||
|         if not profile_dir: |         if not profile_dir: | ||||||
|             profile_dir = Path(".") / "vllm_benchmark_result" / f"latency_result_{time.time()}" |             profile_dir = Path( | ||||||
|  |                 "." | ||||||
|  |             ) / "vllm_benchmark_result" / f"latency_result_{time.time()}" | ||||||
|         print(f"Profiling (results will be saved to '{profile_dir}')...") |         print(f"Profiling (results will be saved to '{profile_dir}')...") | ||||||
|         run_to_completion(profile_dir=args.profile_result_dir) |         run_to_completion(profile_dir=args.profile_result_dir) | ||||||
|         return |         return | ||||||
| @ -115,6 +118,13 @@ if __name__ == '__main__': | |||||||
|     parser.add_argument('--enforce-eager', |     parser.add_argument('--enforce-eager', | ||||||
|                         action='store_true', |                         action='store_true', | ||||||
|                         help='enforce eager mode and disable CUDA graph') |                         help='enforce eager mode and disable CUDA graph') | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--kv-cache-dtype", | ||||||
|  |         type=str, | ||||||
|  |         choices=['auto', 'fp8_e5m2'], | ||||||
|  |         default='auto', | ||||||
|  |         help= | ||||||
|  |         'Data type for kv cache storage. If "auto", will use model data type.') | ||||||
|     parser.add_argument( |     parser.add_argument( | ||||||
|         '--profile', |         '--profile', | ||||||
|         action='store_true', |         action='store_true', | ||||||
| @ -123,9 +133,7 @@ if __name__ == '__main__': | |||||||
|         '--profile-result-dir', |         '--profile-result-dir', | ||||||
|         type=str, |         type=str, | ||||||
|         default=None, |         default=None, | ||||||
|         help=( |         help=('path to save the pytorch profiler output. Can be visualized ' | ||||||
|             'path to save the pytorch profiler output. Can be visualized ' |               'with ui.perfetto.dev or Tensorboard.')) | ||||||
|             'with ui.perfetto.dev or Tensorboard.' |  | ||||||
|         )) |  | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     main(args) |     main(args) | ||||||
|  | |||||||
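These hunks, apparently from `benchmarks/benchmark_latency.py`, thread a new `--kv-cache-dtype` option through to the engine. For example, to compare the model's native cache dtype against the FP8 E5M2 cache, with other flags left at their defaults:

```bash
# Latency with the default KV-cache dtype vs. the new fp8_e5m2 cache (sketch).
python3 benchmarks/benchmark_latency.py --kv-cache-dtype auto
python3 benchmarks/benchmark_latency.py --kv-cache-dtype fp8_e5m2
```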
| @ -24,6 +24,7 @@ from typing import AsyncGenerator, List, Tuple | |||||||
|  |  | ||||||
| import aiohttp | import aiohttp | ||||||
| import numpy as np | import numpy as np | ||||||
|  | from tqdm.asyncio import tqdm | ||||||
| from transformers import PreTrainedTokenizerBase | from transformers import PreTrainedTokenizerBase | ||||||
| from vllm.transformers_utils.tokenizer import get_tokenizer | from vllm.transformers_utils.tokenizer import get_tokenizer | ||||||
|  |  | ||||||
| @ -40,15 +41,10 @@ def sample_requests( | |||||||
|     with open(dataset_path) as f: |     with open(dataset_path) as f: | ||||||
|         dataset = json.load(f) |         dataset = json.load(f) | ||||||
|     # Filter out the conversations with less than 2 turns. |     # Filter out the conversations with less than 2 turns. | ||||||
|     dataset = [ |     dataset = [data for data in dataset if len(data["conversations"]) >= 2] | ||||||
|         data for data in dataset |  | ||||||
|         if len(data["conversations"]) >= 2 |  | ||||||
|     ] |  | ||||||
|     # Only keep the first two turns of each conversation. |     # Only keep the first two turns of each conversation. | ||||||
|     dataset = [ |     dataset = [(data["conversations"][0]["value"], | ||||||
|         (data["conversations"][0]["value"], data["conversations"][1]["value"]) |                 data["conversations"][1]["value"]) for data in dataset] | ||||||
|         for data in dataset |  | ||||||
|     ] |  | ||||||
|  |  | ||||||
|     # Tokenize the prompts and completions. |     # Tokenize the prompts and completions. | ||||||
|     prompts = [prompt for prompt, _ in dataset] |     prompts = [prompt for prompt, _ in dataset] | ||||||
| @ -96,15 +92,9 @@ async def get_request( | |||||||
|         await asyncio.sleep(interval) |         await asyncio.sleep(interval) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def send_request( | async def send_request(backend: str, model: str, api_url: str, prompt: str, | ||||||
|     backend: str, |                        prompt_len: int, output_len: int, best_of: int, | ||||||
|     api_url: str, |                        use_beam_search: bool, pbar: tqdm) -> None: | ||||||
|     prompt: str, |  | ||||||
|     prompt_len: int, |  | ||||||
|     output_len: int, |  | ||||||
|     best_of: int, |  | ||||||
|     use_beam_search: bool, |  | ||||||
| ) -> None: |  | ||||||
|     request_start_time = time.perf_counter() |     request_start_time = time.perf_counter() | ||||||
|  |  | ||||||
|     headers = {"User-Agent": "Benchmark Client"} |     headers = {"User-Agent": "Benchmark Client"} | ||||||
| @ -120,6 +110,8 @@ async def send_request( | |||||||
|             "ignore_eos": True, |             "ignore_eos": True, | ||||||
|             "stream": False, |             "stream": False, | ||||||
|         } |         } | ||||||
|  |         if model is not None: | ||||||
|  |             pload["model"] = model | ||||||
|     elif backend == "tgi": |     elif backend == "tgi": | ||||||
|         assert not use_beam_search |         assert not use_beam_search | ||||||
|         params = { |         params = { | ||||||
| @ -137,7 +129,8 @@ async def send_request( | |||||||
|     timeout = aiohttp.ClientTimeout(total=3 * 3600) |     timeout = aiohttp.ClientTimeout(total=3 * 3600) | ||||||
|     async with aiohttp.ClientSession(timeout=timeout) as session: |     async with aiohttp.ClientSession(timeout=timeout) as session: | ||||||
|         while True: |         while True: | ||||||
|             async with session.post(api_url, headers=headers, json=pload) as response: |             async with session.post(api_url, headers=headers, | ||||||
|  |                                     json=pload) as response: | ||||||
|                 chunks = [] |                 chunks = [] | ||||||
|                 async for chunk, _ in response.content.iter_chunks(): |                 async for chunk, _ in response.content.iter_chunks(): | ||||||
|                     chunks.append(chunk) |                     chunks.append(chunk) | ||||||
| @ -151,10 +144,12 @@ async def send_request( | |||||||
|     request_end_time = time.perf_counter() |     request_end_time = time.perf_counter() | ||||||
|     request_latency = request_end_time - request_start_time |     request_latency = request_end_time - request_start_time | ||||||
|     REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) |     REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) | ||||||
|  |     pbar.update(1) | ||||||
|  |  | ||||||
|  |  | ||||||
| async def benchmark( | async def benchmark( | ||||||
|     backend: str, |     backend: str, | ||||||
|  |     model: str, | ||||||
|     api_url: str, |     api_url: str, | ||||||
|     input_requests: List[Tuple[str, int, int]], |     input_requests: List[Tuple[str, int, int]], | ||||||
|     best_of: int, |     best_of: int, | ||||||
| @ -162,13 +157,15 @@ async def benchmark( | |||||||
|     request_rate: float, |     request_rate: float, | ||||||
| ) -> None: | ) -> None: | ||||||
|     tasks: List[asyncio.Task] = [] |     tasks: List[asyncio.Task] = [] | ||||||
|  |     pbar = tqdm(total=len(input_requests)) | ||||||
|     async for request in get_request(input_requests, request_rate): |     async for request in get_request(input_requests, request_rate): | ||||||
|         prompt, prompt_len, output_len = request |         prompt, prompt_len, output_len = request | ||||||
|         task = asyncio.create_task(send_request(backend, api_url, prompt, |         task = asyncio.create_task( | ||||||
|                                                 prompt_len, output_len, |             send_request(backend, model, api_url, prompt, prompt_len, | ||||||
|                                                 best_of, use_beam_search)) |                          output_len, best_of, use_beam_search, pbar)) | ||||||
|         tasks.append(task) |         tasks.append(task) | ||||||
|     await asyncio.gather(*tasks) |     await asyncio.gather(*tasks) | ||||||
|  |     pbar.close() | ||||||
|  |  | ||||||
|  |  | ||||||
| def main(args: argparse.Namespace): | def main(args: argparse.Namespace): | ||||||
| @ -176,13 +173,15 @@ def main(args: argparse.Namespace): | |||||||
|     random.seed(args.seed) |     random.seed(args.seed) | ||||||
|     np.random.seed(args.seed) |     np.random.seed(args.seed) | ||||||
|  |  | ||||||
|     api_url = f"http://{args.host}:{args.port}/generate" |     api_url = f"{args.protocol}://{args.host}:{args.port}{args.endpoint}" | ||||||
|     tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code) |     tokenizer = get_tokenizer(args.tokenizer, | ||||||
|  |                               trust_remote_code=args.trust_remote_code) | ||||||
|     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) |     input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) | ||||||
|  |  | ||||||
|     benchmark_start_time = time.perf_counter() |     benchmark_start_time = time.perf_counter() | ||||||
|     asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of, |     asyncio.run( | ||||||
|                           args.use_beam_search, args.request_rate)) |         benchmark(args.backend, args.model, api_url, input_requests, | ||||||
|  |                   args.best_of, args.use_beam_search, args.request_rate)) | ||||||
|     benchmark_end_time = time.perf_counter() |     benchmark_end_time = time.perf_counter() | ||||||
|     benchmark_time = benchmark_end_time - benchmark_start_time |     benchmark_time = benchmark_end_time - benchmark_start_time | ||||||
|     print(f"Total time: {benchmark_time:.2f} s") |     print(f"Total time: {benchmark_time:.2f} s") | ||||||
| @ -196,10 +195,8 @@ def main(args: argparse.Namespace): | |||||||
|         for prompt_len, output_len, latency in REQUEST_LATENCY |         for prompt_len, output_len, latency in REQUEST_LATENCY | ||||||
|     ]) |     ]) | ||||||
|     print(f"Average latency per token: {avg_per_token_latency:.2f} s") |     print(f"Average latency per token: {avg_per_token_latency:.2f} s") | ||||||
|     avg_per_output_token_latency = np.mean([ |     avg_per_output_token_latency = np.mean( | ||||||
|         latency / output_len |         [latency / output_len for _, output_len, latency in REQUEST_LATENCY]) | ||||||
|         for _, output_len, latency in REQUEST_LATENCY |  | ||||||
|     ]) |  | ||||||
|     print("Average latency per output token: " |     print("Average latency per output token: " | ||||||
|           f"{avg_per_output_token_latency:.2f} s") |           f"{avg_per_output_token_latency:.2f} s") | ||||||
|  |  | ||||||
| @ -207,27 +204,46 @@ def main(args: argparse.Namespace): | |||||||
| if __name__ == "__main__": | if __name__ == "__main__": | ||||||
|     parser = argparse.ArgumentParser( |     parser = argparse.ArgumentParser( | ||||||
|         description="Benchmark the online serving throughput.") |         description="Benchmark the online serving throughput.") | ||||||
|     parser.add_argument("--backend", type=str, default="vllm", |     parser.add_argument("--backend", | ||||||
|  |                         type=str, | ||||||
|  |                         default="vllm", | ||||||
|                         choices=["vllm", "tgi"]) |                         choices=["vllm", "tgi"]) | ||||||
|  |     parser.add_argument("--protocol", | ||||||
|  |                         type=str, | ||||||
|  |                         default="http", | ||||||
|  |                         choices=["http", "https"]) | ||||||
|     parser.add_argument("--host", type=str, default="localhost") |     parser.add_argument("--host", type=str, default="localhost") | ||||||
|     parser.add_argument("--port", type=int, default=8000) |     parser.add_argument("--port", type=int, default=8000) | ||||||
|     parser.add_argument("--dataset", type=str, required=True, |     parser.add_argument("--endpoint", type=str, default="/generate") | ||||||
|  |     parser.add_argument("--model", type=str, default=None) | ||||||
|  |     parser.add_argument("--dataset", | ||||||
|  |                         type=str, | ||||||
|  |                         required=True, | ||||||
|                         help="Path to the dataset.") |                         help="Path to the dataset.") | ||||||
|     parser.add_argument("--tokenizer", type=str, required=True, |     parser.add_argument("--tokenizer", | ||||||
|  |                         type=str, | ||||||
|  |                         required=True, | ||||||
|                         help="Name or path of the tokenizer.") |                         help="Name or path of the tokenizer.") | ||||||
|     parser.add_argument("--best-of", type=int, default=1, |     parser.add_argument("--best-of", | ||||||
|  |                         type=int, | ||||||
|  |                         default=1, | ||||||
|                         help="Generates `best_of` sequences per prompt and " |                         help="Generates `best_of` sequences per prompt and " | ||||||
|                              "returns the best one.") |                         "returns the best one.") | ||||||
|     parser.add_argument("--use-beam-search", action="store_true") |     parser.add_argument("--use-beam-search", action="store_true") | ||||||
|     parser.add_argument("--num-prompts", type=int, default=1000, |     parser.add_argument("--num-prompts", | ||||||
|  |                         type=int, | ||||||
|  |                         default=1000, | ||||||
|                         help="Number of prompts to process.") |                         help="Number of prompts to process.") | ||||||
|     parser.add_argument("--request-rate", type=float, default=float("inf"), |     parser.add_argument("--request-rate", | ||||||
|  |                         type=float, | ||||||
|  |                         default=float("inf"), | ||||||
|                         help="Number of requests per second. If this is inf, " |                         help="Number of requests per second. If this is inf, " | ||||||
|                              "then all the requests are sent at time 0. " |                         "then all the requests are sent at time 0. " | ||||||
|                              "Otherwise, we use Poisson process to synthesize " |                         "Otherwise, we use Poisson process to synthesize " | ||||||
|                              "the request arrival times.") |                         "the request arrival times.") | ||||||
|     parser.add_argument("--seed", type=int, default=0) |     parser.add_argument("--seed", type=int, default=0) | ||||||
|     parser.add_argument('--trust-remote-code', action='store_true', |     parser.add_argument('--trust-remote-code', | ||||||
|  |                         action='store_true', | ||||||
|                         help='trust remote code from huggingface') |                         help='trust remote code from huggingface') | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     main(args) |     main(args) | ||||||
|  | |||||||
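These hunks, apparently from `benchmarks/benchmark_serving.py`, add `--protocol`, `--endpoint`, and `--model` flags plus a progress bar, so the script can target the OpenAI-compatible server as well as the plain `/generate` API. The CI job above invokes it roughly like this:

```bash
# Benchmark the OpenAI-compatible completions endpoint (sketch; mirrors run-benchmarks.sh).
python3 benchmarks/benchmark_serving.py \
    --backend vllm \
    --protocol http --host localhost --port 8000 \
    --endpoint /v1/completions \
    --model meta-llama/Llama-2-7b-chat-hf \
    --tokenizer meta-llama/Llama-2-7b-chat-hf \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --num-prompts 20
```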
| @ -71,6 +71,7 @@ def run_vllm( | |||||||
|     dtype: str, |     dtype: str, | ||||||
|     max_model_len: Optional[int], |     max_model_len: Optional[int], | ||||||
|     enforce_eager: bool, |     enforce_eager: bool, | ||||||
|  |     kv_cache_dtype: str, | ||||||
| ) -> float: | ) -> float: | ||||||
|     from vllm import LLM, SamplingParams |     from vllm import LLM, SamplingParams | ||||||
|     llm = LLM( |     llm = LLM( | ||||||
| @ -83,6 +84,7 @@ def run_vllm( | |||||||
|         dtype=dtype, |         dtype=dtype, | ||||||
|         max_model_len=max_model_len, |         max_model_len=max_model_len, | ||||||
|         enforce_eager=enforce_eager, |         enforce_eager=enforce_eager, | ||||||
|  |         kv_cache_dtype=kv_cache_dtype, | ||||||
|     ) |     ) | ||||||
|  |  | ||||||
|     # Add the requests to the engine. |     # Add the requests to the engine. | ||||||
| @ -206,7 +208,8 @@ def main(args: argparse.Namespace): | |||||||
|                                 args.quantization, args.tensor_parallel_size, |                                 args.quantization, args.tensor_parallel_size, | ||||||
|                                 args.seed, args.n, args.use_beam_search, |                                 args.seed, args.n, args.use_beam_search, | ||||||
|                                 args.trust_remote_code, args.dtype, |                                 args.trust_remote_code, args.dtype, | ||||||
|                                 args.max_model_len, args.enforce_eager) |                                 args.max_model_len, args.enforce_eager, | ||||||
|  |                                 args.kv_cache_dtype) | ||||||
|     elif args.backend == "hf": |     elif args.backend == "hf": | ||||||
|         assert args.tensor_parallel_size == 1 |         assert args.tensor_parallel_size == 1 | ||||||
|         elapsed_time = run_hf(requests, args.model, tokenizer, args.n, |         elapsed_time = run_hf(requests, args.model, tokenizer, args.n, | ||||||
| @ -284,6 +287,13 @@ if __name__ == "__main__": | |||||||
|     parser.add_argument("--enforce-eager", |     parser.add_argument("--enforce-eager", | ||||||
|                         action="store_true", |                         action="store_true", | ||||||
|                         help="enforce eager execution") |                         help="enforce eager execution") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--kv-cache-dtype", | ||||||
|  |         type=str, | ||||||
|  |         choices=["auto", "fp8_e5m2"], | ||||||
|  |         default="auto", | ||||||
|  |         help= | ||||||
|  |         'Data type for kv cache storage. If "auto", will use model data type.') | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     if args.tokenizer is None: |     if args.tokenizer is None: | ||||||
|         args.tokenizer = args.model |         args.tokenizer = args.model | ||||||
|  | |||||||
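These hunks, apparently from `benchmarks/benchmark_throughput.py`, plumb the same `kv_cache_dtype` option into the offline throughput benchmark. An example run with synthetic prompt lengths, matching the CI invocation:

```bash
# Offline throughput with the FP8 E5M2 KV cache (sketch).
python3 benchmarks/benchmark_throughput.py \
    --input-len 256 --output-len 256 \
    --kv-cache-dtype fp8_e5m2
```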
| @ -1,9 +1,11 @@ | |||||||
|  | from typing import Optional | ||||||
| import argparse | import argparse | ||||||
| import random | import random | ||||||
| import time | import time | ||||||
|  |  | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
|  | from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random | ||||||
| from vllm._C import ops | from vllm._C import ops | ||||||
|  |  | ||||||
| NUM_BLOCKS = 1024 | NUM_BLOCKS = 1024 | ||||||
| @ -23,6 +25,7 @@ def main( | |||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|     do_profile: bool, |     do_profile: bool, | ||||||
|  |     kv_cache_dtype: Optional[str] = None, | ||||||
| ) -> None: | ) -> None: | ||||||
|     random.seed(seed) |     random.seed(seed) | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
| @ -59,15 +62,10 @@ def main( | |||||||
|     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") |     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") | ||||||
|  |  | ||||||
|     # Create the KV cache. |     # Create the KV cache. | ||||||
|     x = 16 // torch.tensor([], dtype=dtype).element_size() |     key_caches, value_caches = create_kv_caches_with_random( | ||||||
|     key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x) |         NUM_BLOCKS, block_size, 1, num_kv_heads, head_size, kv_cache_dtype, | ||||||
|     key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda") |         dtype) | ||||||
|     key_cache.uniform_(-scale, scale) |     key_cache, value_cache = key_caches[0], value_caches[0] | ||||||
|     value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size) |  | ||||||
|     value_cache = torch.empty(size=value_cache_shape, |  | ||||||
|                               dtype=dtype, |  | ||||||
|                               device="cuda") |  | ||||||
|     value_cache.uniform_(-scale, scale) |  | ||||||
|  |  | ||||||
|     # Prepare for the paged attention kernel. |     # Prepare for the paged attention kernel. | ||||||
|     output = torch.empty_like(query) |     output = torch.empty_like(query) | ||||||
| @ -106,6 +104,7 @@ def main( | |||||||
|                     block_size, |                     block_size, | ||||||
|                     max_context_len, |                     max_context_len, | ||||||
|                     alibi_slopes, |                     alibi_slopes, | ||||||
|  |                     kv_cache_dtype, | ||||||
|                 ) |                 ) | ||||||
|             elif version == "v2": |             elif version == "v2": | ||||||
|                 ops.paged_attention_v2( |                 ops.paged_attention_v2( | ||||||
| @ -123,6 +122,7 @@ def main( | |||||||
|                     block_size, |                     block_size, | ||||||
|                     max_context_len, |                     max_context_len, | ||||||
|                     alibi_slopes, |                     alibi_slopes, | ||||||
|  |                     kv_cache_dtype, | ||||||
|                 ) |                 ) | ||||||
|             else: |             else: | ||||||
|                 raise ValueError(f"Invalid version: {version}") |                 raise ValueError(f"Invalid version: {version}") | ||||||
| @ -168,16 +168,18 @@ if __name__ == '__main__': | |||||||
|                         default="half") |                         default="half") | ||||||
|     parser.add_argument("--seed", type=int, default=0) |     parser.add_argument("--seed", type=int, default=0) | ||||||
|     parser.add_argument("--profile", action="store_true") |     parser.add_argument("--profile", action="store_true") | ||||||
|  |     parser.add_argument( | ||||||
|  |         "--kv-cache-dtype", | ||||||
|  |         type=str, | ||||||
|  |         choices=["auto", "fp8_e5m2"], | ||||||
|  |         default="auto", | ||||||
|  |         help= | ||||||
|  |         'Data type for kv cache storage. If "auto", will use model data type.') | ||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|     print(args) |     print(args) | ||||||
|  |  | ||||||
|     if args.num_query_heads % args.num_kv_heads != 0: |     if args.num_query_heads % args.num_kv_heads != 0: | ||||||
|         raise ValueError("num_query_heads must be divisible by num_kv_heads") |         raise ValueError("num_query_heads must be divisible by num_kv_heads") | ||||||
|     dtype_to_torch_dtype = { |  | ||||||
|         "half": torch.half, |  | ||||||
|         "bfloat16": torch.bfloat16, |  | ||||||
|         "float": torch.float, |  | ||||||
|     } |  | ||||||
|     main( |     main( | ||||||
|         version=args.version, |         version=args.version, | ||||||
|         num_seqs=args.batch_size, |         num_seqs=args.batch_size, | ||||||
| @ -187,7 +189,8 @@ if __name__ == '__main__': | |||||||
|         head_size=args.head_size, |         head_size=args.head_size, | ||||||
|         block_size=args.block_size, |         block_size=args.block_size, | ||||||
|         use_alibi=args.use_alibi, |         use_alibi=args.use_alibi, | ||||||
|         dtype=dtype_to_torch_dtype[args.dtype], |         dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype], | ||||||
|         seed=args.seed, |         seed=args.seed, | ||||||
|         do_profile=args.profile, |         do_profile=args.profile, | ||||||
|  |         kv_cache_dtype=args.kv_cache_dtype, | ||||||
|     ) |     ) | ||||||
|  | |||||||
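These hunks, apparently from `benchmarks/kernels/benchmark_paged_attention.py`, pass the cache dtype down to the paged attention kernels and reuse the shared `create_kv_caches_with_random` helper. A sketch of benchmarking the v2 kernel with the FP8 cache; the exact `--version` flag is assumed from the v1/v2 dispatch shown above:

```bash
# Micro-benchmark paged attention v2 with an fp8_e5m2 KV cache (sketch).
python3 benchmarks/kernels/benchmark_paged_attention.py \
    --version v2 \
    --kv-cache-dtype fp8_e5m2
```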
| @ -1,5 +1,6 @@ | |||||||
| #include <torch/extension.h> |  | ||||||
| #include <ATen/cuda/CUDAContext.h> | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <torch/extension.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  |  | ||||||
| #include "cuda_compat.h" | #include "cuda_compat.h" | ||||||
| #include "dispatch_utils.h" | #include "dispatch_utils.h" | ||||||
| @ -36,6 +37,7 @@ void silu_and_mul( | |||||||
|  |  | ||||||
|   dim3 grid(num_tokens); |   dim3 grid(num_tokens); | ||||||
|   dim3 block(std::min(d, 1024)); |   dim3 block(std::min(d, 1024)); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES( |   VLLM_DISPATCH_FLOATING_TYPES( | ||||||
|     input.scalar_type(), |     input.scalar_type(), | ||||||
| @ -71,6 +73,7 @@ __global__ void activation_kernel( | |||||||
|   int64_t num_tokens = input.numel() / d;                                                 \ |   int64_t num_tokens = input.numel() / d;                                                 \ | ||||||
|   dim3 grid(num_tokens);                                                                  \ |   dim3 grid(num_tokens);                                                                  \ | ||||||
|   dim3 block(std::min(d, 1024));                                                          \ |   dim3 block(std::min(d, 1024));                                                          \ | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \ | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \ |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \ | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES(                                                           \ |   VLLM_DISPATCH_FLOATING_TYPES(                                                           \ | ||||||
|     input.scalar_type(),                                                                  \ |     input.scalar_type(),                                                                  \ | ||||||
|  | |||||||
| @ -4,3 +4,4 @@ | |||||||
| #include "dtype_float16.cuh" | #include "dtype_float16.cuh" | ||||||
| #include "dtype_float32.cuh" | #include "dtype_float32.cuh" | ||||||
| #include "dtype_bfloat16.cuh" | #include "dtype_bfloat16.cuh" | ||||||
|  | #include "dtype_fp8_e5m2.cuh" | ||||||
|  | |||||||
| @ -21,9 +21,11 @@ | |||||||
|  |  | ||||||
| #include <torch/extension.h> | #include <torch/extension.h> | ||||||
| #include <ATen/cuda/CUDAContext.h> | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  |  | ||||||
| #include "attention_dtypes.h" | #include "attention_dtypes.h" | ||||||
| #include "attention_utils.cuh" | #include "attention_utils.cuh" | ||||||
|  | #include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh" | ||||||
|  |  | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
|  |  | ||||||
| @ -78,17 +80,19 @@ inline __device__ float block_sum(float* red_smem, float sum) { | |||||||
| // Grid: (num_heads, num_seqs, max_num_partitions). | // Grid: (num_heads, num_seqs, max_num_partitions). | ||||||
| template< | template< | ||||||
|   typename scalar_t, |   typename scalar_t, | ||||||
|  |   typename cache_t, | ||||||
|   int HEAD_SIZE, |   int HEAD_SIZE, | ||||||
|   int BLOCK_SIZE, |   int BLOCK_SIZE, | ||||||
|   int NUM_THREADS, |   int NUM_THREADS, | ||||||
|  |   bool IS_FP8_E5M2_KV_CACHE, | ||||||
|   int PARTITION_SIZE = 0> // Zero means no partitioning. |   int PARTITION_SIZE = 0> // Zero means no partitioning. | ||||||
| __device__ void paged_attention_kernel( | __device__ void paged_attention_kernel( | ||||||
|   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions] |   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions] | ||||||
|   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions] |   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions] | ||||||
|   scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size] |   scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size] | ||||||
|   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size] |   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size] | ||||||
|   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x] |   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x] | ||||||
|   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size] |   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size] | ||||||
|   const int num_kv_heads,                 // [num_heads] |   const int num_kv_heads,                 // [num_heads] | ||||||
|   const float scale, |   const float scale, | ||||||
|   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq] |   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq] | ||||||
| @ -144,6 +148,9 @@ __device__ void paged_attention_kernel( | |||||||
|   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); |   constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); | ||||||
|   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; |   using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type; | ||||||
|   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; |   using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type; | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  |   using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; |   constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; | ||||||
|   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; |   constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; | ||||||
| @ -175,7 +182,7 @@ __device__ void paged_attention_kernel( | |||||||
|  |  | ||||||
|   // x == THREAD_GROUP_SIZE * VEC_SIZE |   // x == THREAD_GROUP_SIZE * VEC_SIZE | ||||||
|   // Each thread group fetches x elements from the key at a time. |   // Each thread group fetches x elements from the key at a time. | ||||||
|   constexpr int x = 16 / sizeof(scalar_t); |   constexpr int x = 16 / sizeof(cache_t); | ||||||
|   float qk_max = -FLT_MAX; |   float qk_max = -FLT_MAX; | ||||||
|  |  | ||||||
|   // Iterate over the key blocks. |   // Iterate over the key blocks. | ||||||
| @ -201,13 +208,23 @@ __device__ void paged_attention_kernel( | |||||||
|  |  | ||||||
| #pragma unroll | #pragma unroll | ||||||
|       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { |       for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { | ||||||
|         const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride |         const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride | ||||||
|                                         + kv_head_idx * kv_head_stride |                                        + kv_head_idx * kv_head_stride | ||||||
|                                         + physical_block_offset * x; |                                        + physical_block_offset * x; | ||||||
|         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; |         const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; | ||||||
|         const int offset1 = (vec_idx * VEC_SIZE) / x; |         const int offset1 = (vec_idx * VEC_SIZE) / x; | ||||||
|         const int offset2 = (vec_idx * VEC_SIZE) % x; |         const int offset2 = (vec_idx * VEC_SIZE) % x; | ||||||
|         k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); |         if constexpr (IS_FP8_E5M2_KV_CACHE) { | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  |           Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); | ||||||
|  |           // Vector conversion from Quant_vec to K_vec. | ||||||
|  |           k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant); | ||||||
|  | #else | ||||||
|  |           assert(false); | ||||||
|  | #endif | ||||||
|  |         } else { | ||||||
|  |           k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2); | ||||||
|  |         } | ||||||
|       } |       } | ||||||
|  |  | ||||||
|       // Compute dot product. |       // Compute dot product. | ||||||
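When IS_FP8_E5M2_KV_CACHE is set, keys are stored as one byte per element and widened back to the compute type in registers via fp8_e5m2_unscaled::vec_conversion. As background only (not part of the diff), the sketch below decodes a single unscaled e5m2 byte with plain bit arithmetic: 1 sign bit, 5 exponent bits with bias 15, 2 mantissa bits.

#include <cmath>
#include <cstdint>
#include <cstdio>

// Illustrative scalar decode; the kernel uses vectorized conversions instead.
float fp8_e5m2_to_float(uint8_t v) {
  const int sign = v >> 7;
  const int exp  = (v >> 2) & 0x1f;
  const int man  = v & 0x3;
  float mag;
  if (exp == 0x1f) {
    mag = (man == 0) ? INFINITY : NAN;              // all-ones exponent: Inf/NaN
  } else if (exp == 0) {
    mag = std::ldexp(man / 4.0f, -14);              // subnormal
  } else {
    mag = std::ldexp(1.0f + man / 4.0f, exp - 15);  // normal
  }
  return sign ? -mag : mag;
}

int main() {
  printf("%g %g\n", fp8_e5m2_to_float(0x3c), fp8_e5m2_to_float(0xbc));  // 1 -1
  return 0;
}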
| @ -281,6 +298,9 @@ __device__ void paged_attention_kernel( | |||||||
|   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); |   constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); | ||||||
|   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; |   using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; | ||||||
|   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; |   using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type; | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  |   using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type; | ||||||
|  | #endif | ||||||
|   using Float_L_vec = typename FloatVec<L_vec>::Type; |   using Float_L_vec = typename FloatVec<L_vec>::Type; | ||||||
|  |  | ||||||
|   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; |   constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; | ||||||
| @ -306,14 +326,25 @@ __device__ void paged_attention_kernel( | |||||||
|     L_vec logits_vec; |     L_vec logits_vec; | ||||||
|     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); |     from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx)); | ||||||
|  |  | ||||||
|     const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride |     const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride | ||||||
|                                     + kv_head_idx * kv_head_stride; |                                    + kv_head_idx * kv_head_stride; | ||||||
| #pragma unroll | #pragma unroll | ||||||
|     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { |     for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { | ||||||
|       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; |       const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; | ||||||
|       if (row_idx < HEAD_SIZE) { |       if (row_idx < HEAD_SIZE) { | ||||||
|         const int offset = row_idx * BLOCK_SIZE + physical_block_offset; |         const int offset = row_idx * BLOCK_SIZE + physical_block_offset; | ||||||
|         V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); |         V_vec v_vec; | ||||||
|  |         if constexpr (IS_FP8_E5M2_KV_CACHE) { | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  |           V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset); | ||||||
|  |           // Vector conversion from V_quant_vec to V_vec. | ||||||
|  |           v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec); | ||||||
|  | #else | ||||||
|  |           assert(false); | ||||||
|  | #endif | ||||||
|  |         } else { | ||||||
|  |           v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset); | ||||||
|  |         } | ||||||
|         if (block_idx == num_context_blocks - 1) { |         if (block_idx == num_context_blocks - 1) { | ||||||
|           // NOTE(woosuk): When v_vec contains the tokens that are out of the context, |           // NOTE(woosuk): When v_vec contains the tokens that are out of the context, | ||||||
|           // we should explicitly zero out the values since they may contain NaNs. |           // we should explicitly zero out the values since they may contain NaNs. | ||||||
| @ -394,14 +425,16 @@ __device__ void paged_attention_kernel( | |||||||
| // Grid: (num_heads, num_seqs, 1). | // Grid: (num_heads, num_seqs, 1). | ||||||
| template< | template< | ||||||
|   typename scalar_t, |   typename scalar_t, | ||||||
|  |   typename cache_t, | ||||||
|   int HEAD_SIZE, |   int HEAD_SIZE, | ||||||
|   int BLOCK_SIZE, |   int BLOCK_SIZE, | ||||||
|   int NUM_THREADS> |   int NUM_THREADS, | ||||||
|  |   bool IS_FP8_E5M2_KV_CACHE> | ||||||
| __global__ void paged_attention_v1_kernel( | __global__ void paged_attention_v1_kernel( | ||||||
|   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size] |   scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size] | ||||||
|   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size] |   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size] | ||||||
|   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x] |   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x] | ||||||
|   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size] |   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size] | ||||||
|   const int num_kv_heads,                 // [num_heads] |   const int num_kv_heads,                 // [num_heads] | ||||||
|   const float scale, |   const float scale, | ||||||
|   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq] |   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq] | ||||||
| @ -411,7 +444,7 @@ __global__ void paged_attention_v1_kernel( | |||||||
|   const int q_stride, |   const int q_stride, | ||||||
|   const int kv_block_stride, |   const int kv_block_stride, | ||||||
|   const int kv_head_stride) { |   const int kv_head_stride) { | ||||||
|   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>( |   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>( | ||||||
|     /* exp_sums */ nullptr, /* max_logits */ nullptr, |     /* exp_sums */ nullptr, /* max_logits */ nullptr, | ||||||
|     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, |     out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens, | ||||||
|     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); |     max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride); | ||||||
| @ -420,17 +453,19 @@ __global__ void paged_attention_v1_kernel( | |||||||
| // Grid: (num_heads, num_seqs, max_num_partitions). | // Grid: (num_heads, num_seqs, max_num_partitions). | ||||||
| template< | template< | ||||||
|   typename scalar_t, |   typename scalar_t, | ||||||
|  |   typename cache_t, | ||||||
|   int HEAD_SIZE, |   int HEAD_SIZE, | ||||||
|   int BLOCK_SIZE, |   int BLOCK_SIZE, | ||||||
|   int NUM_THREADS, |   int NUM_THREADS, | ||||||
|  |   bool IS_FP8_E5M2_KV_CACHE, | ||||||
|   int PARTITION_SIZE> |   int PARTITION_SIZE> | ||||||
| __global__ void paged_attention_v2_kernel( | __global__ void paged_attention_v2_kernel( | ||||||
|   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions] |   float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions] | ||||||
|   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions] |   float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions] | ||||||
|   scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size] |   scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size] | ||||||
|   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size] |   const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size] | ||||||
|   const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x] |   const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x] | ||||||
|   const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size] |   const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size] | ||||||
|   const int num_kv_heads,                 // [num_heads] |   const int num_kv_heads,                 // [num_heads] | ||||||
|   const float scale, |   const float scale, | ||||||
|   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq] |   const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq] | ||||||
| @ -440,7 +475,7 @@ __global__ void paged_attention_v2_kernel( | |||||||
|   const int q_stride, |   const int q_stride, | ||||||
|   const int kv_block_stride, |   const int kv_block_stride, | ||||||
|   const int kv_head_stride) { |   const int kv_head_stride) { | ||||||
|   paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>( |   paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>( | ||||||
|     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, |     exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, | ||||||
|     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, |     block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes, | ||||||
|     q_stride, kv_block_stride, kv_head_stride); |     q_stride, kv_block_stride, kv_head_stride); | ||||||
| @ -549,10 +584,10 @@ __global__ void paged_attention_v2_reduce_kernel( | |||||||
|  |  | ||||||
| #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \ | #define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \ | ||||||
|   VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \ |   VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \ | ||||||
|     ((void*)vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>),          \ |     ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \ | ||||||
|     shared_mem_size);                                                                         \ |       IS_FP8_E5M2_KV_CACHE>), shared_mem_size);                                               \ | ||||||
|   vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                      \ |   vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \ | ||||||
|   <<<grid, block, shared_mem_size, stream>>>(                                                 \ |   IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                            \ | ||||||
|     out_ptr,                                                                                  \ |     out_ptr,                                                                                  \ | ||||||
|     query_ptr,                                                                                \ |     query_ptr,                                                                                \ | ||||||
|     key_cache_ptr,                                                                            \ |     key_cache_ptr,                                                                            \ | ||||||
| @ -570,7 +605,9 @@ __global__ void paged_attention_v2_reduce_kernel( | |||||||
| // TODO(woosuk): Tune NUM_THREADS. | // TODO(woosuk): Tune NUM_THREADS. | ||||||
| template< | template< | ||||||
|   typename T, |   typename T, | ||||||
|  |   typename CACHE_T, | ||||||
|   int BLOCK_SIZE, |   int BLOCK_SIZE, | ||||||
|  |   bool IS_FP8_E5M2_KV_CACHE, | ||||||
|   int NUM_THREADS = 128> |   int NUM_THREADS = 128> | ||||||
| void paged_attention_v1_launcher( | void paged_attention_v1_launcher( | ||||||
|   torch::Tensor& out, |   torch::Tensor& out, | ||||||
| @ -601,8 +638,8 @@ void paged_attention_v1_launcher( | |||||||
|  |  | ||||||
|   T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); |   T* out_ptr = reinterpret_cast<T*>(out.data_ptr()); | ||||||
|   T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); |   T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); | ||||||
|   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); |   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); | ||||||
|   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); |   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); | ||||||
|   int* block_tables_ptr = block_tables.data_ptr<int>(); |   int* block_tables_ptr = block_tables.data_ptr<int>(); | ||||||
|   int* context_lens_ptr = context_lens.data_ptr<int>(); |   int* context_lens_ptr = context_lens.data_ptr<int>(); | ||||||
|  |  | ||||||
| @ -616,6 +653,7 @@ void paged_attention_v1_launcher( | |||||||
|  |  | ||||||
|   dim3 grid(num_heads, num_seqs, 1); |   dim3 grid(num_heads, num_seqs, 1); | ||||||
|   dim3 block(NUM_THREADS); |   dim3 block(NUM_THREADS); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   switch (head_size) { |   switch (head_size) { | ||||||
|     // NOTE(woosuk): To reduce the compilation time, we only compile for the |     // NOTE(woosuk): To reduce the compilation time, we only compile for the | ||||||
| @ -645,35 +683,35 @@ void paged_attention_v1_launcher( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| #define CALL_V1_LAUNCHER(T, BLOCK_SIZE)                             \ | #define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)       \ | ||||||
|   paged_attention_v1_launcher<T, BLOCK_SIZE>(                       \ |   paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \ | ||||||
|     out,                                                            \ |     out,                                                                     \ | ||||||
|     query,                                                          \ |     query,                                                                   \ | ||||||
|     key_cache,                                                      \ |     key_cache,                                                               \ | ||||||
|     value_cache,                                                    \ |     value_cache,                                                             \ | ||||||
|     num_kv_heads,                                                   \ |     num_kv_heads,                                                            \ | ||||||
|     scale,                                                          \ |     scale,                                                                   \ | ||||||
|     block_tables,                                                   \ |     block_tables,                                                            \ | ||||||
|     context_lens,                                                   \ |     context_lens,                                                            \ | ||||||
|     max_context_len,                                                \ |     max_context_len,                                                         \ | ||||||
|     alibi_slopes); |     alibi_slopes); | ||||||
|  |  | ||||||
| // NOTE(woosuk): To reduce the compilation time, we omitted block sizes | // NOTE(woosuk): To reduce the compilation time, we omitted block sizes | ||||||
| // 1, 2, 4, 64, 128, 256. | // 1, 2, 4, 64, 128, 256. | ||||||
| #define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                              \ | #define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \ | ||||||
|   switch (block_size) {                                             \ |   switch (block_size) {                                               \ | ||||||
|     case 8:                                                         \ |     case 8:                                                           \ | ||||||
|       CALL_V1_LAUNCHER(T, 8);                                       \ |       CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);          \ | ||||||
|       break;                                                        \ |       break;                                                          \ | ||||||
|     case 16:                                                        \ |     case 16:                                                          \ | ||||||
|       CALL_V1_LAUNCHER(T, 16);                                      \ |       CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);         \ | ||||||
|       break;                                                        \ |       break;                                                          \ | ||||||
|     case 32:                                                        \ |     case 32:                                                          \ | ||||||
|       CALL_V1_LAUNCHER(T, 32);                                      \ |       CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);         \ | ||||||
|       break;                                                        \ |       break;                                                          \ | ||||||
|     default:                                                        \ |     default:                                                          \ | ||||||
|       TORCH_CHECK(false, "Unsupported block size: ", block_size);   \ |       TORCH_CHECK(false, "Unsupported block size: ", block_size);     \ | ||||||
|       break;                                                        \ |       break;                                                          \ | ||||||
|   } |   } | ||||||
|  |  | ||||||
| void paged_attention_v1( | void paged_attention_v1( | ||||||
| @ -687,20 +725,36 @@ void paged_attention_v1( | |||||||
|   torch::Tensor& context_lens,    // [num_seqs] |   torch::Tensor& context_lens,    // [num_seqs] | ||||||
|   int block_size, |   int block_size, | ||||||
|   int max_context_len, |   int max_context_len, | ||||||
|   const c10::optional<torch::Tensor>& alibi_slopes) { |   const c10::optional<torch::Tensor>& alibi_slopes, | ||||||
|   if (query.dtype() == at::ScalarType::Float) { |   const std::string& kv_cache_dtype) { | ||||||
|     CALL_V1_LAUNCHER_BLOCK_SIZE(float); |   if (kv_cache_dtype == "auto") { | ||||||
|   } else if (query.dtype() == at::ScalarType::Half) { |     if (query.dtype() == at::ScalarType::Float) { | ||||||
|     CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t); |       CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false); | ||||||
|   } else if (query.dtype() == at::ScalarType::BFloat16) { |     } else if (query.dtype() == at::ScalarType::Half) { | ||||||
|     CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); |       CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); | ||||||
|  |     } else if (query.dtype() == at::ScalarType::BFloat16) { | ||||||
|  |       CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); | ||||||
|  |     } else { | ||||||
|  |       TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); | ||||||
|  |     } | ||||||
|  |   } else if (kv_cache_dtype == "fp8_e5m2") { | ||||||
|  |     if (query.dtype() == at::ScalarType::Float) { | ||||||
|  |       CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); | ||||||
|  |     } else if (query.dtype() == at::ScalarType::Half) { | ||||||
|  |       CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); | ||||||
|  |     } else if (query.dtype() == at::ScalarType::BFloat16) { | ||||||
|  |       CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); | ||||||
|  |     } else { | ||||||
|  |       TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); | ||||||
|  |     } | ||||||
|   } else { |   } else { | ||||||
|     TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); |     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
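A note on the types above: Half and BFloat16 queries are dispatched as uint16_t and __nv_bfloat16 because the launchers only need a bit-compatible 16-bit storage type; the reinterpret_cast of data_ptr() is a pure reinterpretation of the same bytes. A small host-side illustration (not from the diff):

#include <torch/torch.h>
#include <cstdint>
#include <cstdio>

int main() {
  // A Half tensor holding 1.0, viewed as raw 16-bit words: 0x3c00 is 1.0 in fp16.
  auto t = torch::tensor({1.0f}, torch::kHalf);
  const uint16_t bits = *reinterpret_cast<const uint16_t*>(t.data_ptr());
  printf("0x%04x\n", bits);
  return 0;
}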
| #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \ | #define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \ | ||||||
|   vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>      \ |   vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \ | ||||||
|  |   IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>                                                       \ | ||||||
|   <<<grid, block, shared_mem_size, stream>>>(                                                 \ |   <<<grid, block, shared_mem_size, stream>>>(                                                 \ | ||||||
|     exp_sums_ptr,                                                                             \ |     exp_sums_ptr,                                                                             \ | ||||||
|     max_logits_ptr,                                                                           \ |     max_logits_ptr,                                                                           \ | ||||||
| @ -728,7 +782,9 @@ void paged_attention_v1( | |||||||
|  |  | ||||||
| template< | template< | ||||||
|   typename T, |   typename T, | ||||||
|  |   typename CACHE_T, | ||||||
|   int BLOCK_SIZE, |   int BLOCK_SIZE, | ||||||
|  |   bool IS_FP8_E5M2_KV_CACHE, | ||||||
|   int NUM_THREADS = 128, |   int NUM_THREADS = 128, | ||||||
|   int PARTITION_SIZE = 512> |   int PARTITION_SIZE = 512> | ||||||
| void paged_attention_v2_launcher( | void paged_attention_v2_launcher( | ||||||
| @ -766,8 +822,8 @@ void paged_attention_v2_launcher( | |||||||
|   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); |   float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr()); | ||||||
|   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr()); |   T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr()); | ||||||
|   T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); |   T* query_ptr = reinterpret_cast<T*>(query.data_ptr()); | ||||||
|   T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr()); |   CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr()); | ||||||
|   T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr()); |   CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr()); | ||||||
|   int* block_tables_ptr = block_tables.data_ptr<int>(); |   int* block_tables_ptr = block_tables.data_ptr<int>(); | ||||||
|   int* context_lens_ptr = context_lens.data_ptr<int>(); |   int* context_lens_ptr = context_lens.data_ptr<int>(); | ||||||
|  |  | ||||||
| @ -784,6 +840,7 @@ void paged_attention_v2_launcher( | |||||||
|   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); |   int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); | ||||||
|  |  | ||||||
|   dim3 block(NUM_THREADS); |   dim3 block(NUM_THREADS); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   switch (head_size) { |   switch (head_size) { | ||||||
|     // NOTE(woosuk): To reduce the compilation time, we only compile for the |     // NOTE(woosuk): To reduce the compilation time, we only compile for the | ||||||
| @ -813,38 +870,38 @@ void paged_attention_v2_launcher( | |||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| #define CALL_V2_LAUNCHER(T, BLOCK_SIZE)                             \ | #define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)           \ | ||||||
|   paged_attention_v2_launcher<T, BLOCK_SIZE>(                       \ |   paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>(     \ | ||||||
|     out,                                                            \ |     out,                                                                         \ | ||||||
|     exp_sums,                                                       \ |     exp_sums,                                                                    \ | ||||||
|     max_logits,                                                     \ |     max_logits,                                                                  \ | ||||||
|     tmp_out,                                                        \ |     tmp_out,                                                                     \ | ||||||
|     query,                                                          \ |     query,                                                                       \ | ||||||
|     key_cache,                                                      \ |     key_cache,                                                                   \ | ||||||
|     value_cache,                                                    \ |     value_cache,                                                                 \ | ||||||
|     num_kv_heads,                                                   \ |     num_kv_heads,                                                                \ | ||||||
|     scale,                                                          \ |     scale,                                                                       \ | ||||||
|     block_tables,                                                   \ |     block_tables,                                                                \ | ||||||
|     context_lens,                                                   \ |     context_lens,                                                                \ | ||||||
|     max_context_len,                                                \ |     max_context_len,                                                             \ | ||||||
|     alibi_slopes); |     alibi_slopes); | ||||||
|  |  | ||||||
| // NOTE(woosuk): To reduce the compilation time, we omitted block sizes | // NOTE(woosuk): To reduce the compilation time, we omitted block sizes | ||||||
| // 1, 2, 4, 64, 128, 256. | // 1, 2, 4, 64, 128, 256. | ||||||
| #define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                              \ | #define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \ | ||||||
|   switch (block_size) {                                             \ |   switch (block_size) {                                                     \ | ||||||
|     case 8:                                                         \ |     case 8:                                                                 \ | ||||||
|       CALL_V2_LAUNCHER(T, 8);                                       \ |       CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);                \ | ||||||
|       break;                                                        \ |       break;                                                                \ | ||||||
|     case 16:                                                        \ |     case 16:                                                                \ | ||||||
|       CALL_V2_LAUNCHER(T, 16);                                      \ |       CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);               \ | ||||||
|       break;                                                        \ |       break;                                                                \ | ||||||
|     case 32:                                                        \ |     case 32:                                                                \ | ||||||
|       CALL_V2_LAUNCHER(T, 32);                                      \ |       CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);               \ | ||||||
|       break;                                                        \ |       break;                                                                \ | ||||||
|     default:                                                        \ |     default:                                                                \ | ||||||
|       TORCH_CHECK(false, "Unsupported block size: ", block_size);   \ |       TORCH_CHECK(false, "Unsupported block size: ", block_size);           \ | ||||||
|       break;                                                        \ |       break;                                                                \ | ||||||
|   } |   } | ||||||
|  |  | ||||||
| void paged_attention_v2( | void paged_attention_v2( | ||||||
| @ -861,15 +918,30 @@ void paged_attention_v2( | |||||||
|   torch::Tensor& context_lens,    // [num_seqs] |   torch::Tensor& context_lens,    // [num_seqs] | ||||||
|   int block_size, |   int block_size, | ||||||
|   int max_context_len, |   int max_context_len, | ||||||
|   const c10::optional<torch::Tensor>& alibi_slopes) { |   const c10::optional<torch::Tensor>& alibi_slopes, | ||||||
|   if (query.dtype() == at::ScalarType::Float) { |   const std::string& kv_cache_dtype) { | ||||||
|     CALL_V2_LAUNCHER_BLOCK_SIZE(float); |   if (kv_cache_dtype == "auto") { | ||||||
|   } else if (query.dtype() == at::ScalarType::Half) { |     if (query.dtype() == at::ScalarType::Float) { | ||||||
|     CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t); |       CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false); | ||||||
|   } else if (query.dtype() == at::ScalarType::BFloat16) { |     } else if (query.dtype() == at::ScalarType::Half) { | ||||||
|     CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16); |       CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false); | ||||||
|  |     } else if (query.dtype() == at::ScalarType::BFloat16) { | ||||||
|  |       CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false); | ||||||
|  |     } else { | ||||||
|  |       TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); | ||||||
|  |     } | ||||||
|  |   } else if (kv_cache_dtype == "fp8_e5m2") { | ||||||
|  |     if (query.dtype() == at::ScalarType::Float) { | ||||||
|  |       CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true); | ||||||
|  |     } else if (query.dtype() == at::ScalarType::Half) { | ||||||
|  |       CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true); | ||||||
|  |     } else if (query.dtype() == at::ScalarType::BFloat16) { | ||||||
|  |       CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true); | ||||||
|  |     } else { | ||||||
|  |       TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); | ||||||
|  |     } | ||||||
|   } else { |   } else { | ||||||
|     TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); |     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
|  | |||||||
35  csrc/attention/dtype_fp8_e5m2.cuh  Normal file
							| @ -0,0 +1,35 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include "attention_generic.cuh" | ||||||
|  |  | ||||||
|  | #include <stdint.h> | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  | #include <cuda_fp8.h> | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | namespace vllm { | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  | // fp8 vector types for quantization of kv cache | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | struct Vec<uint8_t, 1> { | ||||||
|  |     using Type = uint8_t; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | struct Vec<uint8_t, 2> { | ||||||
|  |     using Type = uint16_t; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | struct Vec<uint8_t, 4> { | ||||||
|  |     using Type = uint32_t; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | struct Vec<uint8_t, 8> { | ||||||
|  |     using Type = uint2; | ||||||
|  | }; | ||||||
|  | #endif // ENABLE_FP8_E5M2 | ||||||
|  |  | ||||||
|  | } // namespace vllm | ||||||
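These Vec specializations pick an integer type whose width equals N packed fp8 bytes, so a vectorized load moves N quantized elements in one instruction. A small host-side sketch (illustrative only, little-endian packing) of how two fp8 bytes occupy one Vec<uint8_t, 2>::Type (uint16_t):

#include <cstdint>
#include <cstdio>

int main() {
  const uint8_t lo = 0x3c, hi = 0xbc;  // two arbitrary e5m2 bit patterns
  const uint16_t packed = static_cast<uint16_t>(lo) | (static_cast<uint16_t>(hi) << 8);
  printf("packed=0x%04x unpacked lo=0x%02x hi=0x%02x\n",
         packed, packed & 0xff, (packed >> 8) & 0xff);
  return 0;
}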
| @ -20,7 +20,8 @@ void reshape_and_cache( | |||||||
|   torch::Tensor& value, |   torch::Tensor& value, | ||||||
|   torch::Tensor& key_cache, |   torch::Tensor& key_cache, | ||||||
|   torch::Tensor& value_cache, |   torch::Tensor& value_cache, | ||||||
|   torch::Tensor& slot_mapping); |   torch::Tensor& slot_mapping, | ||||||
|  |   const std::string& kv_cache_dtype); | ||||||
|  |  | ||||||
| void gather_cached_kv( | void gather_cached_kv( | ||||||
|   torch::Tensor& key, |   torch::Tensor& key, | ||||||
| @ -28,3 +29,8 @@ void gather_cached_kv( | |||||||
|   torch::Tensor& key_cache, |   torch::Tensor& key_cache, | ||||||
|   torch::Tensor& value_cache, |   torch::Tensor& value_cache, | ||||||
|   torch::Tensor& slot_mapping); |   torch::Tensor& slot_mapping); | ||||||
|  |  | ||||||
|  | // Just for unittest | ||||||
|  | void convert_fp8_e5m2( | ||||||
|  |   torch::Tensor& src_cache, | ||||||
|  |   torch::Tensor& dst_cache); | ||||||
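Since convert_fp8_e5m2 is exposed for unit tests, a round-trip through it is the natural check. A hedged sketch of such a test, assuming the extension is built with fp8 support and linked in; the shapes and the printed error are illustrative only:

#include <torch/torch.h>
#include <iostream>

// Provided by the vLLM extension (declared above in csrc/cache.h).
void convert_fp8_e5m2(torch::Tensor& src_cache, torch::Tensor& dst_cache);

int main() {
  auto src   = torch::randn({16, 1024}, torch::dtype(torch::kHalf).device(torch::kCUDA));
  auto quant = torch::empty_like(src, torch::dtype(torch::kUInt8));
  auto back  = torch::empty_like(src);
  convert_fp8_e5m2(src, quant);   // half  -> fp8 bytes
  convert_fp8_e5m2(quant, back);  // fp8 bytes -> half
  // e5m2 keeps only 2 mantissa bits, so expect coarse agreement.
  std::cout << (back - src).abs().max() << std::endl;
  return 0;
}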
|  | |||||||
| @ -1,8 +1,10 @@ | |||||||
| #include <torch/extension.h> | #include <torch/extension.h> | ||||||
| #include <ATen/cuda/CUDAContext.h> | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  |  | ||||||
| #include "cuda_compat.h" | #include "cuda_compat.h" | ||||||
| #include "dispatch_utils.h" | #include "dispatch_utils.h" | ||||||
|  | #include "quantization/fp8_e5m2_kvcache/quant_utils.cuh" | ||||||
|  |  | ||||||
| #include <algorithm> | #include <algorithm> | ||||||
| #include <cassert> | #include <cassert> | ||||||
| @ -33,6 +35,7 @@ void swap_blocks( | |||||||
|   char *dst_ptr = static_cast<char*>(dst.data_ptr()); |   char *dst_ptr = static_cast<char*>(dst.data_ptr()); | ||||||
|  |  | ||||||
|   const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); |   const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   // NOTE(woosuk): This can be slow if the number of blocks is large. |   // NOTE(woosuk): This can be slow if the number of blocks is large. | ||||||
|   for (const auto& pair : block_mapping) { |   for (const auto& pair : block_mapping) { | ||||||
| @ -127,8 +130,9 @@ void copy_blocks( | |||||||
|   const int numel_per_block = key_caches[0][0].numel(); |   const int numel_per_block = key_caches[0][0].numel(); | ||||||
|   dim3 grid(num_layers, num_pairs); |   dim3 grid(num_layers, num_pairs); | ||||||
|   dim3 block(std::min(1024, numel_per_block)); |   dim3 block(std::min(1024, numel_per_block)); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(cache_device); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES( |   VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( | ||||||
|     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { |     key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { | ||||||
|       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( |       vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>( | ||||||
|         key_cache_ptrs_tensor.data_ptr<int64_t>(), |         key_cache_ptrs_tensor.data_ptr<int64_t>(), | ||||||
| @ -140,12 +144,12 @@ void copy_blocks( | |||||||
|  |  | ||||||
| namespace vllm { | namespace vllm { | ||||||
|  |  | ||||||
| template<typename scalar_t> | template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache> | ||||||
| __global__ void reshape_and_cache_kernel( | __global__ void reshape_and_cache_kernel( | ||||||
|   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size] |   const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size] | ||||||
|   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size] |   const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size] | ||||||
|   scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x] |   cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x] | ||||||
|   scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size] |   cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size] | ||||||
|   const int64_t* __restrict__ slot_mapping,   // [num_tokens] |   const int64_t* __restrict__ slot_mapping,   // [num_tokens] | ||||||
|   const int key_stride, |   const int key_stride, | ||||||
|   const int value_stride, |   const int value_stride, | ||||||
| @ -182,19 +186,45 @@ __global__ void reshape_and_cache_kernel( | |||||||
|                                   + head_idx * head_size * block_size |                                   + head_idx * head_size * block_size | ||||||
|                                   + head_offset * block_size |                                   + head_offset * block_size | ||||||
|                                   + block_offset; |                                   + block_offset; | ||||||
|     key_cache[tgt_key_idx] = key[src_key_idx]; |     scalar_t tgt_key = key[src_key_idx]; | ||||||
|     value_cache[tgt_value_idx] = value[src_value_idx]; |     scalar_t tgt_value = value[src_value_idx]; | ||||||
|  |     if constexpr (is_fp8_e5m2_kv_cache) { | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  |       key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key); | ||||||
|  |       value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value); | ||||||
|  | #else | ||||||
|  |       assert(false); | ||||||
|  | #endif | ||||||
|  |     } else { | ||||||
|  |       key_cache[tgt_key_idx] = tgt_key; | ||||||
|  |       value_cache[tgt_value_idx] = tgt_value; | ||||||
|  |     } | ||||||
|   } |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| } // namespace vllm | } // namespace vllm | ||||||
|  |  | ||||||
|  | #define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                \ | ||||||
|  |   vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \ | ||||||
|  |     reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \ | ||||||
|  |     reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \ | ||||||
|  |     reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \ | ||||||
|  |     reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \ | ||||||
|  |     slot_mapping.data_ptr<int64_t>(),                                                              \ | ||||||
|  |     key_stride,                                                                                    \ | ||||||
|  |     value_stride,                                                                                  \ | ||||||
|  |     num_heads,                                                                                     \ | ||||||
|  |     head_size,                                                                                     \ | ||||||
|  |     block_size,                                                                                    \ | ||||||
|  |     x); | ||||||
|  |  | ||||||
| void reshape_and_cache( | void reshape_and_cache( | ||||||
|   torch::Tensor& key,           // [num_tokens, num_heads, head_size] |   torch::Tensor& key,           // [num_tokens, num_heads, head_size] | ||||||
|   torch::Tensor& value,         // [num_tokens, num_heads, head_size] |   torch::Tensor& value,         // [num_tokens, num_heads, head_size] | ||||||
|   torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x] |   torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x] | ||||||
|   torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size] |   torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size] | ||||||
|   torch::Tensor& slot_mapping)  // [num_tokens] |   torch::Tensor& slot_mapping,  // [num_tokens] | ||||||
|  |   const std::string& kv_cache_dtype) | ||||||
| { | { | ||||||
|   int num_tokens = key.size(0); |   int num_tokens = key.size(0); | ||||||
|   int num_heads = key.size(1); |   int num_heads = key.size(1); | ||||||
| @ -207,24 +237,27 @@ void reshape_and_cache( | |||||||
|  |  | ||||||
|   dim3 grid(num_tokens); |   dim3 grid(num_tokens); | ||||||
|   dim3 block(std::min(num_heads * head_size, 512)); |   dim3 block(std::min(num_heads * head_size, 512)); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES( |   if (kv_cache_dtype == "auto") { | ||||||
|     key.scalar_type(), |     if (key.dtype() == at::ScalarType::Float) { | ||||||
|     "reshape_and_cache_kernel", |       CALL_RESHAPE_AND_CACHE(float, float, false); | ||||||
|     [&] { |     } else if (key.dtype() == at::ScalarType::Half) { | ||||||
|       vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>( |       CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false); | ||||||
|         key.data_ptr<scalar_t>(), |     } else if (key.dtype() == at::ScalarType::BFloat16) { | ||||||
|         value.data_ptr<scalar_t>(), |       CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false); | ||||||
|         key_cache.data_ptr<scalar_t>(), |     } | ||||||
|         value_cache.data_ptr<scalar_t>(), |   } else if (kv_cache_dtype == "fp8_e5m2") { | ||||||
|         slot_mapping.data_ptr<int64_t>(), |     if (key.dtype() == at::ScalarType::Float) { | ||||||
|         key_stride, |       CALL_RESHAPE_AND_CACHE(float, uint8_t, true); | ||||||
|         value_stride, |     } else if (key.dtype() == at::ScalarType::Half) { | ||||||
|         num_heads, |       CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true); | ||||||
|         head_size, |     } else if (key.dtype() == at::ScalarType::BFloat16) { | ||||||
|         block_size, |       CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true); | ||||||
|         x); |     } | ||||||
|     }); |   } else { | ||||||
|  |     TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype); | ||||||
|  |   } | ||||||
| } | } | ||||||
|  |  | ||||||
| namespace vllm { | namespace vllm { | ||||||
| @ -252,12 +285,12 @@ __global__ void gather_cached_kv_kernel( | |||||||
|     for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { |     for (int i = threadIdx.x; i < num_tokens; i += blockDim.x) { | ||||||
|       const int tgt_key_idx = token_idx * key_stride + i; |       const int tgt_key_idx = token_idx * key_stride + i; | ||||||
|       const int tgt_value_idx = token_idx * value_stride + i; |       const int tgt_value_idx = token_idx * value_stride + i; | ||||||
|    |  | ||||||
|       const int head_idx = i / head_size; |       const int head_idx = i / head_size; | ||||||
|       const int head_offset = i % head_size; |       const int head_offset = i % head_size; | ||||||
|       const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension |       const int x_idx = head_offset / x;  // the offset of the [head_size/x] dimension | ||||||
|       const int x_offset = head_offset % x; |       const int x_offset = head_offset % x; | ||||||
|    |  | ||||||
|       const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x |       const int src_key_idx = block_idx * num_heads * (head_size / x) * block_size * x | ||||||
|                               + head_idx * (head_size / x) * block_size * x |                               + head_idx * (head_size / x) * block_size * x | ||||||
|                               + x_idx * block_size * x |                               + x_idx * block_size * x | ||||||
| @ -367,8 +400,9 @@ void gather_cached_kv( | |||||||
|  |  | ||||||
|   dim3 grid(num_tokens); |   dim3 grid(num_tokens); | ||||||
|   dim3 block(std::min(num_heads * head_size, 512)); |   dim3 block(std::min(num_heads * head_size, 512)); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES( |   VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( | ||||||
|     key.scalar_type(), |     key.scalar_type(), | ||||||
|     "gather_cached_kv_kernel_optimized", |     "gather_cached_kv_kernel_optimized", | ||||||
|     [&] { |     [&] { | ||||||
| @ -386,3 +420,55 @@ void gather_cached_kv( | |||||||
|         x); |         x); | ||||||
|     }); |     }); | ||||||
| } | } | ||||||
|  |  | ||||||
|  | namespace vllm { | ||||||
|  |  | ||||||
|  | template<typename Tout, typename Tin> | ||||||
|  | __global__ void convert_fp8_e5m2_kernel( | ||||||
|  |   const Tin* __restrict__ src_cache, | ||||||
|  |   Tout* __restrict__ dst_cache, | ||||||
|  |   const int64_t block_stride) { | ||||||
|  |   const int64_t block_idx = blockIdx.x; | ||||||
|  |   for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { | ||||||
|  |     int64_t idx = block_idx * block_stride + i; | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  |     dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]); | ||||||
|  | #else | ||||||
|  |     assert(false); | ||||||
|  | #endif | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace vllm | ||||||
|  |  | ||||||
|  | #define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \ | ||||||
|  |   vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \ | ||||||
|  |     reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \ | ||||||
|  |     reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \ | ||||||
|  |     block_stride); | ||||||
|  |  | ||||||
|  | void convert_fp8_e5m2( | ||||||
|  |   torch::Tensor& src_cache, | ||||||
|  |   torch::Tensor& dst_cache) | ||||||
|  | { | ||||||
|  |   int64_t num_blocks = src_cache.size(0); | ||||||
|  |   int64_t block_stride = src_cache.stride(0); | ||||||
|  |  | ||||||
|  |   dim3 grid(num_blocks); | ||||||
|  |   dim3 block(std::min(block_stride, int64_t(512))); | ||||||
|  |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |  | ||||||
|  |   if (src_cache.dtype() == at::ScalarType::Float) { | ||||||
|  |     CALL_CONVERT_FP8_E5M2(uint8_t, float); | ||||||
|  |   } else if (src_cache.dtype() == at::ScalarType::Half) { | ||||||
|  |     CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t); | ||||||
|  |   } else if (src_cache.dtype() == at::ScalarType::BFloat16) { | ||||||
|  |     CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16); | ||||||
|  |   } else if (dst_cache.dtype() == at::ScalarType::Float) { | ||||||
|  |     CALL_CONVERT_FP8_E5M2(float, uint8_t); | ||||||
|  |   } else if (dst_cache.dtype() == at::ScalarType::Half) { | ||||||
|  |     CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t); | ||||||
|  |   } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { | ||||||
|  |     CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  | |||||||
| @ -5,3 +5,6 @@ | |||||||
| int get_device_attribute( | int get_device_attribute( | ||||||
|     int attribute, |     int attribute, | ||||||
|     int device_id); |     int device_id); | ||||||
|  |  | ||||||
|  | int get_max_shared_memory_per_block_device_attribute( | ||||||
|  |     int device_id); | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| #ifdef USE_ROCM | #ifdef USE_ROCM | ||||||
|   #include <hip/hip_runtime.h> |   #include <hip/hip_runtime.h> | ||||||
|  |   #include <hip/hip_runtime_api.h> | ||||||
| #endif | #endif | ||||||
| int get_device_attribute( | int get_device_attribute( | ||||||
|     int attribute, |     int attribute, | ||||||
| @ -15,3 +16,20 @@ int get_device_attribute( | |||||||
|     cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device); |     cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device); | ||||||
|     return value; |     return value; | ||||||
| } | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | int get_max_shared_memory_per_block_device_attribute( | ||||||
|  |     int device_id) | ||||||
|  | { | ||||||
|  |     int attribute; | ||||||
|  |     // https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html | ||||||
|  |     // cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74 | ||||||
|  |  | ||||||
|  | #ifdef USE_ROCM | ||||||
|  |     attribute = hipDeviceAttributeMaxSharedMemoryPerBlock; | ||||||
|  | #else | ||||||
|  |     attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin; | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  |     return get_device_attribute(attribute, device_id); | ||||||
|  | } | ||||||
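On the CUDA side the new helper simply reads cudaDevAttrMaxSharedMemoryPerBlockOptin; the wrapper exists so callers do not have to hard-code the enum value, which differs under ROCm. For reference, the equivalent direct query (CUDA path only):

#include <cuda_runtime.h>
#include <cstdio>

int main() {
  int smem = 0;
  cudaDeviceGetAttribute(&smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, /*device=*/0);
  printf("max opt-in shared memory per block: %d bytes\n", smem);
  return 0;
}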
|  | |||||||
148  csrc/custom_all_reduce.cu  Normal file
							| @ -0,0 +1,148 @@ | |||||||
|  | #include <ATen/cuda/Exceptions.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  | #include <c10/cuda/CUDAStream.h> | ||||||
|  | #include <torch/extension.h> | ||||||
|  |  | ||||||
|  | #include "custom_all_reduce.cuh" | ||||||
|  |  | ||||||
|  | // fake pointer type | ||||||
|  | using fptr_t = uint64_t; | ||||||
|  | static_assert(sizeof(void *) == sizeof(fptr_t)); | ||||||
|  |  | ||||||
|  | fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, | ||||||
|  |                       const std::vector<std::string> &handles, | ||||||
|  |                       const std::vector<int64_t> &offsets, int rank, | ||||||
|  |                       bool full_nvlink) { | ||||||
|  |   int world_size = offsets.size(); | ||||||
|  |   if (world_size > 8) | ||||||
|  |     throw std::invalid_argument("world size > 8 is not supported"); | ||||||
|  |   if (world_size % 2 != 0) | ||||||
|  |     throw std::invalid_argument("Odd num gpus is not supported for now"); | ||||||
|  |   if (world_size != handles.size()) | ||||||
|  |     throw std::invalid_argument( | ||||||
|  |         "handles length should equal to offsets length"); | ||||||
|  |   if (rank < 0 || rank >= world_size) | ||||||
|  |     throw std::invalid_argument("invalid rank passed in"); | ||||||
|  |  | ||||||
|  |   cudaIpcMemHandle_t ipc_handles[8]; | ||||||
|  |   for (int i = 0; i < world_size; i++) { | ||||||
|  |     std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); | ||||||
|  |   } | ||||||
|  |   return (fptr_t) new vllm::CustomAllreduce( | ||||||
|  |       reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(), | ||||||
|  |       rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | /** | ||||||
 |  |  * Make sure tensor t's data lies completely within ((char *)t.data_ptr()) + | ||||||
 |  |  * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() | ||||||
 |  |  * because it allows a transpose of a contiguous slice (i.e. slicing the first | ||||||
 |  |  * dimension). Currently, we require this because stride information is not | ||||||
|  |  * passed into the kernels and we treat input tensors as flat. | ||||||
|  |  * | ||||||
|  |  * Examples | ||||||
|  |  * A = torch.zeros(3, 3, 3) | ||||||
|  |  * 1. A: OK | ||||||
|  |  * 2. A[1:]: OK | ||||||
|  |  * 3. A.permute(2, 0, 1): OK | ||||||
|  |  * 4. A[1:].permute(2, 0, 1): OK | ||||||
|  |  * 5. A[None].expand(2, -1, -1, -1): Not OK | ||||||
|  |  * 6. A[:, 1:, 1:]: Not OK | ||||||
|  |  */ | ||||||
|  | bool _is_weak_contiguous(torch::Tensor &t) { | ||||||
|  |   return t.is_contiguous() || | ||||||
|  |          (t.storage().nbytes() - t.storage_offset() * t.element_size() == | ||||||
|  |           t.numel() * t.element_size()); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, | ||||||
|  |                       bool full_nvlink) { | ||||||
|  |   auto inp_size = inp.numel() * inp.element_size(); | ||||||
|  |   // custom allreduce requires input byte size to be multiples of 16 | ||||||
|  |   if (inp_size % 16 != 0) return false; | ||||||
|  |   if (!_is_weak_contiguous(inp)) return false; | ||||||
|  |   if (world_size == 2 || full_nvlink) return inp_size <= max_size; | ||||||
 |  |   // 4 PCIe GPUs use the 2-stage allreduce, which is only faster than NCCL | ||||||
 |  |   // when size <= 512 KB | ||||||
|  |   return world_size <= 4 && inp_size <= 512 * 1024; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out, | ||||||
|  |                  cudaStream_t stream) { | ||||||
|  |   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||||
|  |   TORCH_CHECK(_is_weak_contiguous(out)); | ||||||
|  |   switch (out.scalar_type()) { | ||||||
|  |     case at::ScalarType::Float: { | ||||||
|  |       fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()), | ||||||
|  |                            reinterpret_cast<float *>(out.data_ptr()), | ||||||
|  |                            out.numel()); | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |     case at::ScalarType::Half: { | ||||||
|  |       fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()), | ||||||
|  |                           reinterpret_cast<half *>(out.data_ptr()), | ||||||
|  |                           out.numel()); | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  | #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) | ||||||
|  |     case at::ScalarType::BFloat16: { | ||||||
|  |       fa->allreduce<nv_bfloat16>( | ||||||
|  |           stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()), | ||||||
|  |           reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel()); | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  | #endif | ||||||
|  |     default: | ||||||
|  |       throw std::runtime_error( | ||||||
|  |           "custom allreduce only supports float32, float16 and bfloat16"); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) { | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); | ||||||
|  |   auto stream = c10::cuda::getCurrentCUDAStream().stream(); | ||||||
|  |   TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); | ||||||
|  |   TORCH_CHECK_EQ(inp.numel(), out.numel()); | ||||||
|  |   _all_reduce(_fa, inp, out, stream); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, | ||||||
|  |                       torch::Tensor &out) { | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); | ||||||
|  |   auto stream = c10::cuda::getCurrentCUDAStream().stream(); | ||||||
|  |  | ||||||
|  |   auto input_size = inp.numel() * inp.element_size(); | ||||||
|  |   TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); | ||||||
|  |   TORCH_CHECK_EQ(inp.numel(), out.numel()); | ||||||
|  |   TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), | ||||||
|  |               "registered buffer is too small to contain the input"); | ||||||
|  |   AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), | ||||||
|  |                                 input_size, cudaMemcpyDeviceToDevice, stream)); | ||||||
|  |   _all_reduce(_fa, reg_buffer, out, stream); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void dispose(fptr_t _fa) { | ||||||
|  |   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||||
|  |   delete fa; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int meta_size() { return sizeof(vllm::Metadata); } | ||||||
|  |  | ||||||
|  | void register_buffer(fptr_t _fa, torch::Tensor &t, | ||||||
|  |                      const std::vector<std::string> &handles, | ||||||
|  |                      const std::vector<int64_t> &offsets) { | ||||||
|  |   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||||
|  |   fa->register_buffer(handles, offsets, t.data_ptr()); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta( | ||||||
|  |     fptr_t _fa) { | ||||||
|  |   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||||
|  |   return fa->get_graph_buffer_ipc_meta(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, | ||||||
|  |                             const std::vector<std::vector<int64_t>> &offsets) { | ||||||
|  |   auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa); | ||||||
|  |   fa->register_graph_buffers(handles, offsets); | ||||||
|  | } | ||||||
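To illustrate the weak-contiguity condition that `_is_weak_contiguous` enforces, the following is a small libtorch host sketch, not part of the diff. It assumes a libtorch install and re-declares the check locally instead of linking against the extension; the printed 1/0 values follow the examples in the comment above.

    #include <torch/torch.h>
    #include <iostream>

    // same condition as _is_weak_contiguous: either contiguous, or the data
    // occupies the entire tail of the underlying storage
    static bool weak_contiguous(const torch::Tensor &t) {
      return t.is_contiguous() ||
             (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
              t.numel() * t.element_size());
    }

    int main() {
      auto A = torch::zeros({3, 3, 3});
      std::cout << weak_contiguous(A) << "\n";                         // 1: contiguous
      std::cout << weak_contiguous(A.slice(0, 1)) << "\n";             // 1: A[1:]
      std::cout << weak_contiguous(A.permute({2, 0, 1})) << "\n";      // 1: transpose of full storage
      std::cout << weak_contiguous(A.slice(1, 1).slice(2, 1)) << "\n"; // 0: A[:, 1:, 1:]
      return 0;
    }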
							
								
								
									
csrc/custom_all_reduce.cuh | 562 lines added | new file
							| @ -0,0 +1,562 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <cuda.h> | ||||||
|  | #include <cuda_bf16.h> | ||||||
|  | #include <cuda_fp16.h> | ||||||
|  | #include <cuda_runtime.h> | ||||||
|  |  | ||||||
|  | #include <iostream> | ||||||
|  | #include <limits> | ||||||
|  | #include <map> | ||||||
|  | #include <unordered_map> | ||||||
|  | #include <vector> | ||||||
|  |  | ||||||
|  | #define CUDACHECK(cmd)                                              \ | ||||||
|  |   do {                                                              \ | ||||||
|  |     cudaError_t e = cmd;                                            \ | ||||||
|  |     if (e != cudaSuccess) {                                         \ | ||||||
|  |       printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ | ||||||
|  |              cudaGetErrorString(e));                                \ | ||||||
|  |       exit(EXIT_FAILURE);                                           \ | ||||||
|  |     }                                                               \ | ||||||
|  |   } while (0) | ||||||
|  |  | ||||||
|  | namespace vllm { | ||||||
|  |  | ||||||
|  | struct Signal { | ||||||
|  |   alignas(64) union { | ||||||
|  |     uint64_t flag; | ||||||
|  |     unsigned char data[8]; | ||||||
|  |   } start; | ||||||
|  |   alignas(64) union { | ||||||
|  |     uint64_t flag; | ||||||
|  |     unsigned char data[8]; | ||||||
|  |   } end; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | struct Metadata { | ||||||
|  |   alignas(128) Signal sg; | ||||||
|  |   alignas(128) int counter; | ||||||
|  | }; | ||||||
|  | static_assert(offsetof(Metadata, counter) == 128); | ||||||
|  | static_assert(sizeof(Metadata) == 256); | ||||||
|  |  | ||||||
|  | struct __align__(16) RankData { const void *__restrict__ ptrs[8]; }; | ||||||
|  |  | ||||||
|  | struct RankSignals { | ||||||
|  |   volatile Signal *signals[8]; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // like std::array, but aligned | ||||||
|  | template <typename T, int sz> | ||||||
|  | struct __align__(alignof(T) * sz) array_t { | ||||||
|  |   T data[sz]; | ||||||
|  |   using type = T; | ||||||
|  |   static constexpr int size = sz; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | // use packed type to maximize memory efficiency | ||||||
|  | // goal: generate ld.128 and st.128 instructions | ||||||
|  | template <typename T> | ||||||
|  | struct packed_t { | ||||||
|  |   // the (P)acked type for load/store | ||||||
|  |   using P = array_t<T, 16 / sizeof(T)>; | ||||||
|  |   // the (A)ccumulator type for reduction | ||||||
|  |   using A = array_t<float, 16 / sizeof(T)>; | ||||||
|  | }; | ||||||
|  |  | ||||||
|  | #define DINLINE __device__ __forceinline__ | ||||||
|  |  | ||||||
|  | // scalar cast functions | ||||||
|  | DINLINE float upcast_s(half val) { return __half2float(val); } | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | DINLINE T downcast_s(float val); | ||||||
|  | template <> | ||||||
|  | DINLINE half downcast_s(float val) { | ||||||
|  |   return __float2half(val); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // scalar add functions | ||||||
|  | // for some reason when compiling with Pytorch, the + operator for half and | ||||||
|  | // bfloat is disabled so we call the intrinsics directly | ||||||
|  | DINLINE half &assign_add(half &a, half b) { | ||||||
|  |   a = __hadd(a, b); | ||||||
|  |   return a; | ||||||
|  | } | ||||||
|  | DINLINE float &assign_add(float &a, float b) { return a += b; } | ||||||
|  |  | ||||||
|  | #if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) | ||||||
|  | DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } | ||||||
|  | template <> | ||||||
|  | DINLINE nv_bfloat16 downcast_s(float val) { | ||||||
|  |   return __float2bfloat16(val); | ||||||
|  | } | ||||||
|  | DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) { | ||||||
|  |   a = __hadd(a, b); | ||||||
|  |   return a; | ||||||
|  | } | ||||||
|  | #endif | ||||||
|  |  | ||||||
|  | template <typename T, int N> | ||||||
|  | DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) { | ||||||
|  | #pragma unroll | ||||||
|  |   for (int i = 0; i < N; i++) { | ||||||
|  |     assign_add(a.data[i], b.data[i]); | ||||||
|  |   } | ||||||
|  |   return a; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T, int N> | ||||||
|  | DINLINE array_t<float, N> upcast(array_t<T, N> val) { | ||||||
|  |   if constexpr (std::is_same<T, float>::value) { | ||||||
|  |     return val; | ||||||
|  |   } else { | ||||||
|  |     array_t<float, N> out; | ||||||
|  | #pragma unroll | ||||||
|  |     for (int i = 0; i < N; i++) { | ||||||
|  |       out.data[i] = upcast_s(val.data[i]); | ||||||
|  |     } | ||||||
|  |     return out; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename O> | ||||||
|  | DINLINE O downcast(array_t<float, O::size> val) { | ||||||
|  |   if constexpr (std::is_same<typename O::type, float>::value) { | ||||||
|  |     return val; | ||||||
|  |   } else { | ||||||
|  |     O out; | ||||||
|  | #pragma unroll | ||||||
|  |     for (int i = 0; i < O::size; i++) { | ||||||
|  |       out.data[i] = downcast_s<typename O::type>(val.data[i]); | ||||||
|  |     } | ||||||
|  |     return out; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // compute flag at compile time | ||||||
|  | __host__ __device__ constexpr uint64_t compute_flag(int ngpus) { | ||||||
|  |   auto m = std::numeric_limits<uint64_t>::max(); | ||||||
|  |   return m >> ((8 - ngpus) * 8); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <int ngpus> | ||||||
|  | DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta, | ||||||
|  |                         int rank) { | ||||||
|  |   constexpr auto FLAG = compute_flag(ngpus); | ||||||
|  |   if (blockIdx.x == 0) { | ||||||
|  |     if (threadIdx.x < ngpus) | ||||||
 |  |       // simultaneously write to the corresponding byte of every other rank. | ||||||
|  |       // Latency = 1 p2p write | ||||||
|  |       sg.signals[threadIdx.x]->start.data[rank] = 255; | ||||||
|  |     else if (threadIdx.x == 32) | ||||||
|  |       // reset | ||||||
|  |       meta->sg.end.flag = 0; | ||||||
|  |   } | ||||||
|  |   if (threadIdx.x == 0) { | ||||||
|  |     while (meta->sg.start.flag != FLAG) | ||||||
|  |       ; | ||||||
|  |   } | ||||||
|  |   __syncthreads(); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <int ngpus, bool final_sync = false> | ||||||
|  | DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta, | ||||||
|  |                       int rank) { | ||||||
|  |   constexpr auto FLAG = compute_flag(ngpus); | ||||||
|  |   __syncthreads(); | ||||||
|  |   __shared__ int num; | ||||||
|  |   if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1); | ||||||
|  |   __syncthreads(); | ||||||
|  |  | ||||||
 |  |   // Only the last completing block can perform the end synchronization. | ||||||
 |  |   // This ensures that when the final busy wait ends, all ranks have | ||||||
 |  |   // finished reading each other's buffer. | ||||||
|  |   if (num == gridDim.x - 1) { | ||||||
|  |     if (threadIdx.x == 32) { | ||||||
|  |       // reset in a different warp | ||||||
|  |       meta->counter = 0; | ||||||
|  |       meta->sg.start.flag = 0; | ||||||
|  |     } else if (threadIdx.x < ngpus) { | ||||||
 |  |       // simultaneously write to the corresponding byte of every other rank. | ||||||
|  |       // Latency = 1 p2p write | ||||||
|  |       sg.signals[threadIdx.x]->end.data[rank] = 255; | ||||||
|  |     } | ||||||
|  |     // if this is the final sync, only one block needs it | ||||||
|  |     // because kernel exit can serve as sync | ||||||
|  |     if constexpr (final_sync) { | ||||||
|  |       if (threadIdx.x == 0) { | ||||||
|  |         while (meta->sg.end.flag != FLAG) | ||||||
|  |           ; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   if constexpr (!final_sync) { | ||||||
|  |     if (threadIdx.x == 0) { | ||||||
|  |       while (meta->sg.end.flag != FLAG) | ||||||
|  |         ; | ||||||
|  |     } | ||||||
|  |     __syncthreads(); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename P, int ngpus, typename A> | ||||||
|  | DINLINE P packed_reduce(const P *ptrs[], int idx) { | ||||||
|  |   A tmp = upcast(ptrs[0][idx]); | ||||||
|  | #pragma unroll | ||||||
|  |   for (int i = 1; i < ngpus; i++) { | ||||||
|  |     packed_assign_add(tmp, upcast(ptrs[i][idx])); | ||||||
|  |   } | ||||||
|  |   return downcast<P>(tmp); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T, int ngpus> | ||||||
|  | __global__ void __launch_bounds__(512, 1) | ||||||
|  |     cross_device_reduce_1stage(RankData *_dp, RankSignals sg, | ||||||
|  |                                volatile Metadata *meta, T *__restrict__ result, | ||||||
|  |                                int rank, int size) { | ||||||
|  |   using P = typename packed_t<T>::P; | ||||||
|  |   using A = typename packed_t<T>::A; | ||||||
|  |   // note: we don't reorder the address so the accumulation order is the same | ||||||
|  |   // for all ranks, ensuring bitwise identical results | ||||||
|  |   auto dp = *_dp; | ||||||
|  |   start_sync<ngpus>(sg, meta, rank); | ||||||
|  |   // do the actual reduction | ||||||
|  |   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||||
|  |        idx += gridDim.x * blockDim.x) { | ||||||
|  |     ((P *)result)[idx] = | ||||||
|  |         packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx); | ||||||
|  |   } | ||||||
|  |   end_sync<ngpus, true>(sg, meta, rank); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename P> | ||||||
|  | DINLINE P *get_tmp_buf(volatile Signal *sg) { | ||||||
|  |   return (P *)(((Metadata *)sg) + 1); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T, int ngpus> | ||||||
|  | __global__ void __launch_bounds__(512, 1) | ||||||
|  |     cross_device_reduce_2stage(RankData *_dp, RankSignals sg, | ||||||
|  |                                volatile Metadata *meta, T *__restrict__ result, | ||||||
|  |                                int rank, int size) { | ||||||
|  |   int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |   int stride = gridDim.x * blockDim.x; | ||||||
|  |   using P = typename packed_t<T>::P; | ||||||
|  |   using A = typename packed_t<T>::A; | ||||||
|  |   int part = size / ngpus; | ||||||
|  |   int start = rank * part; | ||||||
|  |   int end = rank == ngpus - 1 ? size : start + part; | ||||||
|  |   const P *ptrs[ngpus]; | ||||||
|  |   P *tmps[ngpus]; | ||||||
|  | #pragma unroll | ||||||
|  |   for (int i = 0; i < ngpus; i++) { | ||||||
|  |     int target = (rank + i) % ngpus; | ||||||
|  |     ptrs[i] = (const P *)_dp->ptrs[target]; | ||||||
|  |     tmps[i] = get_tmp_buf<P>(sg.signals[target]); | ||||||
|  |   } | ||||||
|  |   auto tmp_out = tmps[0]; | ||||||
|  |   start_sync<ngpus>(sg, meta, rank); | ||||||
|  |   // stage 1: reduce scatter | ||||||
|  |   for (int idx = start + tid; idx < end; idx += stride) { | ||||||
|  |     tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx); | ||||||
|  |   } | ||||||
|  |   // Maybe TODO: replace this with per-block release-acquire | ||||||
|  |   // can save about 1-2us (not a lot though) | ||||||
|  |   end_sync<ngpus>(sg, meta, rank); | ||||||
|  |  | ||||||
|  |   // stage 2: allgather | ||||||
|  |   for (int idx = tid; idx < part; idx += stride) { | ||||||
|  | #pragma unroll | ||||||
|  |     for (int i = 0; i < ngpus; i++) { | ||||||
|  |       int dst_idx = ((rank + i) % ngpus) * part + idx; | ||||||
|  |       ((P *)result)[dst_idx] = tmps[i][idx]; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   // process the last larger partition | ||||||
|  |   int remaining = size - part * ngpus; | ||||||
|  |   if (tid < remaining) { | ||||||
|  |     int dst_idx = tid + part * ngpus; | ||||||
|  |     ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid]; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // faster than this | ||||||
|  |   // for (int idx = tid; idx < size; idx += stride) { | ||||||
|  |   //   int target_rank = idx / part; | ||||||
|  |   //   if (target_rank == ngpus) target_rank -= 1; | ||||||
|  |   //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part]; | ||||||
|  |   // } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T, int ngpus> | ||||||
|  | __global__ void __launch_bounds__(512, 1) | ||||||
|  |     cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg, | ||||||
|  |                                        volatile Metadata *meta, | ||||||
|  |                                        T *__restrict__ result, int rank, | ||||||
|  |                                        int size) { | ||||||
|  |   int tid = blockIdx.x * blockDim.x + threadIdx.x; | ||||||
|  |   int stride = gridDim.x * blockDim.x; | ||||||
|  |   using P = typename packed_t<T>::P; | ||||||
|  |   using A = typename packed_t<T>::A; | ||||||
|  |   auto tmp_out = get_tmp_buf<P>(sg.signals[rank]); | ||||||
|  |   constexpr int hg = ngpus / 2; | ||||||
|  |   // Actually not quite half butterfly. | ||||||
|  |   // This is an all-to-all within each group containing half of the ranks | ||||||
|  |   // followed by cross-group add. Equivalent to half butterfly when there | ||||||
|  |   // are 4 GPUs, a common case for PCIe cards like T4 and A10. | ||||||
|  |   const P *ptrs[hg]; | ||||||
|  |   { | ||||||
|  |     int start = rank - rank % hg; | ||||||
|  | #pragma unroll | ||||||
|  |     for (int i = 0; i < hg; i++) { | ||||||
|  |       ptrs[i] = (const P *)_dp->ptrs[i + start]; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   start_sync<ngpus>(sg, meta, rank); | ||||||
|  |   for (int idx = tid; idx < size; idx += stride) { | ||||||
|  |     tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx); | ||||||
|  |   } | ||||||
|  |   end_sync<ngpus>(sg, meta, rank); | ||||||
|  |  | ||||||
|  |   auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]); | ||||||
|  |   // do the cross group reduction | ||||||
|  |   for (int idx = tid; idx < size; idx += stride) { | ||||||
|  |     auto tmp = tmp_out[idx]; | ||||||
|  |     packed_assign_add(tmp, src[idx]); | ||||||
|  |     ((P *)result)[idx] = tmp; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>; | ||||||
|  | static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); | ||||||
|  | static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); | ||||||
|  |  | ||||||
|  | class CustomAllreduce { | ||||||
|  |  public: | ||||||
|  |   int rank_; | ||||||
|  |   int world_size_; | ||||||
|  |   bool full_nvlink_; | ||||||
|  |  | ||||||
|  |   // below are device pointers | ||||||
|  |   RankSignals sg_; | ||||||
|  |   std::unordered_map<void *, RankData *> buffers_; | ||||||
|  |   Metadata *meta_; | ||||||
|  |  | ||||||
|  |   // stores the registered device pointers from all ranks | ||||||
|  |   RankData *d_rank_data_base_, *d_rank_data_end_; | ||||||
|  |   std::vector<void *> graph_unreg_buffers_; | ||||||
|  |   // a map from IPC handles to opened IPC pointers | ||||||
|  |   std::map<IPC_KEY, char *> ipc_handles_; | ||||||
|  |  | ||||||
|  |   /** | ||||||
|  |    * meta is a pointer to device metadata and temporary buffer for allreduce. | ||||||
|  |    * | ||||||
|  |    * There's a total of sizeof(Metadata) of prefix before the actual data, | ||||||
|  |    * so meta + 1 points to actual temporary buffer. | ||||||
|  |    * | ||||||
|  |    * note: this class does not own any device memory. Any required buffers | ||||||
|  |    * are passed in from the constructor | ||||||
|  |    */ | ||||||
|  |   CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz, | ||||||
|  |                   const cudaIpcMemHandle_t *handles, | ||||||
|  |                   const std::vector<int64_t> &offsets, int rank, | ||||||
|  |                   bool full_nvlink = true) | ||||||
|  |       : rank_(rank), | ||||||
|  |         world_size_(offsets.size()), | ||||||
|  |         full_nvlink_(full_nvlink), | ||||||
|  |         meta_(meta), | ||||||
|  |         d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)), | ||||||
|  |         d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { | ||||||
|  |     for (int i = 0; i < world_size_; i++) { | ||||||
|  |       Metadata *rank_meta; | ||||||
|  |       if (i != rank_) { | ||||||
|  |         char *handle = open_ipc_handle(&handles[i]); | ||||||
|  |         handle += offsets[i]; | ||||||
|  |         rank_meta = (Metadata *)handle; | ||||||
|  |       } else { | ||||||
|  |         rank_meta = meta_; | ||||||
|  |       } | ||||||
|  |       sg_.signals[i] = &rank_meta->sg; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   char *open_ipc_handle(const void *ipc_handle) { | ||||||
|  |     auto [it, new_handle] = | ||||||
|  |         ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr}); | ||||||
|  |     if (new_handle) { | ||||||
|  |       char *ipc_ptr; | ||||||
|  |       CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr, | ||||||
|  |                                      *((const cudaIpcMemHandle_t *)ipc_handle), | ||||||
|  |                                      cudaIpcMemLazyEnablePeerAccess)); | ||||||
|  |       it->second = ipc_ptr; | ||||||
|  |     } | ||||||
|  |     return it->second; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   std::pair<std::vector<uint8_t>, std::vector<int64_t>> | ||||||
|  |   get_graph_buffer_ipc_meta() { | ||||||
|  |     auto num_buffers = graph_unreg_buffers_.size(); | ||||||
|  |     auto handle_sz = sizeof(cudaIpcMemHandle_t); | ||||||
|  |     std::vector<uint8_t> handles(handle_sz * num_buffers, 0); | ||||||
|  |     std::vector<int64_t> offsets(num_buffers); | ||||||
|  |     for (int i = 0; i < num_buffers; i++) { | ||||||
|  |       auto ptr = graph_unreg_buffers_[i]; | ||||||
|  |       void *base_ptr; | ||||||
 |  |       // note: must share the base address of each allocation, or we get the | ||||||
 |  |       // wrong address | ||||||
|  |       if (cuPointerGetAttribute(&base_ptr, | ||||||
|  |                                 CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, | ||||||
|  |                                 (CUdeviceptr)ptr) != CUDA_SUCCESS) | ||||||
|  |         throw std::runtime_error("failed to get pointer attr"); | ||||||
|  |       CUDACHECK(cudaIpcGetMemHandle( | ||||||
|  |           (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr)); | ||||||
|  |       offsets[i] = ((char *)ptr) - ((char *)base_ptr); | ||||||
|  |     } | ||||||
|  |     return std::make_pair(handles, offsets); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void check_rank_data_capacity(size_t num = 1) { | ||||||
|  |     if (d_rank_data_base_ + num > d_rank_data_end_) | ||||||
|  |       throw std::runtime_error( | ||||||
|  |           "Rank data buffer is overflowed by " + | ||||||
|  |           std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   void register_buffer(const std::vector<std::string> &handles, | ||||||
|  |                        const std::vector<int64_t> &offsets, void *self) { | ||||||
|  |     check_rank_data_capacity(); | ||||||
|  |     RankData data; | ||||||
|  |     for (int i = 0; i < world_size_; i++) { | ||||||
|  |       if (i != rank_) { | ||||||
|  |         char *handle = open_ipc_handle(handles[i].data()); | ||||||
|  |         handle += offsets[i]; | ||||||
|  |         data.ptrs[i] = handle; | ||||||
|  |       } else { | ||||||
|  |         data.ptrs[i] = self; | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     auto d_data = d_rank_data_base_++; | ||||||
|  |     CUDACHECK( | ||||||
|  |         cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); | ||||||
|  |     buffers_[self] = d_data; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   // note: when registering graph buffers, we intentionally choose to not | ||||||
|  |   // deduplicate the addresses. That means if the allocator reuses some | ||||||
|  |   // addresses, they will be registered again. This is to account for the remote | ||||||
|  |   // possibility of different allocation patterns between ranks. For example, | ||||||
|  |   // rank 1 may get the same input address for the second allreduce, but rank 2 | ||||||
|  |   // got a different address. IPC handles have internal reference counting | ||||||
|  |   // mechanism so overhead should be small. | ||||||
|  |   void register_graph_buffers( | ||||||
|  |       const std::vector<std::string> &handles, | ||||||
|  |       const std::vector<std::vector<int64_t>> &offsets) { | ||||||
|  |     auto num_buffers = graph_unreg_buffers_.size(); | ||||||
|  |     check_rank_data_capacity(num_buffers); | ||||||
|  |     std::vector<RankData> rank_data(num_buffers); | ||||||
|  |     for (int i = 0; i < num_buffers; i++) { | ||||||
|  |       auto self_ptr = graph_unreg_buffers_[i]; | ||||||
|  |       auto &rd = rank_data[i]; | ||||||
|  |       for (int j = 0; j < world_size_; j++) { | ||||||
|  |         if (j != rank_) { | ||||||
|  |           char *handle = | ||||||
|  |               open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); | ||||||
|  |           handle += offsets[j][i]; | ||||||
|  |           rd.ptrs[j] = handle; | ||||||
|  |         } else { | ||||||
|  |           rd.ptrs[j] = self_ptr; | ||||||
|  |         } | ||||||
|  |       } | ||||||
|  |     } | ||||||
|  |     CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), | ||||||
|  |                          sizeof(RankData) * num_buffers, | ||||||
|  |                          cudaMemcpyHostToDevice)); | ||||||
|  |     d_rank_data_base_ += num_buffers; | ||||||
|  |     graph_unreg_buffers_.clear(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   /** | ||||||
 |  |    * These defaults are the result of a careful grid search. Using 36 blocks | ||||||
 |  |    * gives the best or close to the best runtime on the devices I tried: A100, | ||||||
 |  |    * A10, A30, T4, V100. You'll notice that NCCL kernels also only use a small | ||||||
 |  |    * number of SMs. Not quite sure of the underlying reason, but my guess is | ||||||
 |  |    * that too many SMs will cause contention on the NVLink bus. | ||||||
|  |    */ | ||||||
|  |   template <typename T> | ||||||
|  |   void allreduce(cudaStream_t stream, T *input, T *output, int size, | ||||||
|  |                  int threads = 512, int block_limit = 36) { | ||||||
|  |     auto d = packed_t<T>::P::size; | ||||||
|  |     if (size % d != 0) | ||||||
|  |       throw std::runtime_error( | ||||||
|  |           "custom allreduce currently requires input length to be multiple " | ||||||
|  |           "of " + | ||||||
|  |           std::to_string(d)); | ||||||
|  |  | ||||||
|  |     RankData *ptrs; | ||||||
|  |     cudaStreamCaptureStatus status; | ||||||
|  |     CUDACHECK(cudaStreamIsCapturing(stream, &status)); | ||||||
|  |     if (status == cudaStreamCaptureStatusActive) { | ||||||
|  |       ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); | ||||||
|  |       graph_unreg_buffers_.push_back(input); | ||||||
|  |     } else { | ||||||
|  |       auto it = buffers_.find(input); | ||||||
|  |       if (it == buffers_.end()) | ||||||
|  |         throw std::runtime_error( | ||||||
|  |             "buffer address " + | ||||||
|  |             std::to_string(reinterpret_cast<uint64_t>(input)) + | ||||||
|  |             " is not registered!"); | ||||||
|  |       ptrs = it->second; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     size /= d; | ||||||
|  |     auto bytes = size * sizeof(typename packed_t<T>::P); | ||||||
|  |     int blocks = std::min(block_limit, (size + threads - 1) / threads); | ||||||
|  | #define KL(ngpus, name) \ | ||||||
|  |   name<T, ngpus>        \ | ||||||
|  |       <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size); | ||||||
|  | #define REDUCE_CASE(ngpus)                            \ | ||||||
|  |   case ngpus: {                                       \ | ||||||
|  |     if (world_size_ == 2) {                           \ | ||||||
|  |       KL(ngpus, cross_device_reduce_1stage);          \ | ||||||
|  |     } else if (full_nvlink_) {                        \ | ||||||
|  |       if ((world_size_ <= 4 && bytes < 512 * 1024) || \ | ||||||
|  |           (world_size_ <= 8 && bytes < 256 * 1024)) { \ | ||||||
|  |         KL(ngpus, cross_device_reduce_1stage);        \ | ||||||
|  |       } else {                                        \ | ||||||
|  |         KL(ngpus, cross_device_reduce_2stage);        \ | ||||||
|  |       }                                               \ | ||||||
|  |     } else {                                          \ | ||||||
|  |       KL(ngpus, cross_device_reduce_half_butterfly);  \ | ||||||
|  |     }                                                 \ | ||||||
|  |     break;                                            \ | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |     switch (world_size_) { | ||||||
|  |       REDUCE_CASE(2) | ||||||
|  |       REDUCE_CASE(4) | ||||||
|  |       REDUCE_CASE(6) | ||||||
|  |       REDUCE_CASE(8) | ||||||
|  |       default: | ||||||
|  |         throw std::runtime_error( | ||||||
|  |             "custom allreduce only supports num gpus in (2,4,6,8). Actual num " | ||||||
|  |             "gpus = " + | ||||||
|  |             std::to_string(world_size_)); | ||||||
|  |     } | ||||||
|  | #undef REDUCE_CASE | ||||||
|  | #undef KL | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   ~CustomAllreduce() { | ||||||
|  |     for (auto [_, ptr] : ipc_handles_) { | ||||||
|  |       CUDACHECK(cudaIpcCloseMemHandle(ptr)); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | }; | ||||||
|  | /** | ||||||
|  |  * To inspect PTX/SASS, copy paste this header file to compiler explorer and add | ||||||
|  |  a template instantiation: | ||||||
|  |  * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *, | ||||||
|  |  int, int, int); | ||||||
|  | */ | ||||||
|  | }  // namespace vllm | ||||||
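start_sync and end_sync above busy-wait on a 64-bit word in which each participating rank owns one byte; compute_flag gives the value of that word once all ranks have written 0xff into their byte. A plain C++17 host sketch (no CUDA required; the function body is copied from the header, the rest is illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <limits>

    // copied from custom_all_reduce.cuh: ngpus low-order bytes set to 0xff
    constexpr uint64_t compute_flag(int ngpus) {
      auto m = std::numeric_limits<uint64_t>::max();
      return m >> ((8 - ngpus) * 8);
    }

    int main() {
      for (int ngpus : {2, 4, 6, 8}) {
        // e.g. ngpus=2 -> 0x000000000000ffff: the wait loop exits only after
        // both ranks have written their byte
        std::printf("ngpus=%d  flag=0x%016llx\n", ngpus,
                    static_cast<unsigned long long>(compute_flag(ngpus)));
      }
      return 0;
    }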
							
								
								
									
csrc/custom_all_reduce_test.cu | 284 lines added | new file
							| @ -0,0 +1,284 @@ | |||||||
|  | /** | ||||||
|  |  * This is a standalone test for custom allreduce. | ||||||
|  |  * To compile, make sure you have MPI and NCCL installed in your system. | ||||||
|  |  * export MPI_HOME=XXX | ||||||
|  |  * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o | ||||||
|  |  * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi | ||||||
|  |  * | ||||||
|  |  * Warning: this C++ test is not designed to be very readable and was used | ||||||
|  |  * during the rapid prototyping process. | ||||||
|  |  * | ||||||
|  |  * To run: | ||||||
|  |  * mpirun -np 8 ./custom_all_reduce_test | ||||||
|  |  */ | ||||||
|  | #include <cuda.h> | ||||||
|  | #include <curand_kernel.h> | ||||||
|  | #include <stdio.h> | ||||||
|  | #include <stdlib.h> | ||||||
|  |  | ||||||
|  | #include <limits> | ||||||
|  | #include <vector> | ||||||
|  |  | ||||||
|  | #include "cuda_profiler_api.h" | ||||||
|  | #include "custom_all_reduce.cuh" | ||||||
|  | #include "mpi.h" | ||||||
|  | #include "nccl.h" | ||||||
|  |  | ||||||
|  | #define MPICHECK(cmd)                                                  \ | ||||||
|  |   do {                                                                 \ | ||||||
|  |     int e = cmd;                                                       \ | ||||||
|  |     if (e != MPI_SUCCESS) {                                            \ | ||||||
|  |       printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \ | ||||||
|  |       exit(EXIT_FAILURE);                                              \ | ||||||
|  |     }                                                                  \ | ||||||
|  |   } while (0) | ||||||
|  |  | ||||||
|  | #define NCCLCHECK(cmd)                                              \ | ||||||
|  |   do {                                                              \ | ||||||
|  |     ncclResult_t r = cmd;                                           \ | ||||||
|  |     if (r != ncclSuccess) {                                         \ | ||||||
|  |       printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \ | ||||||
|  |              ncclGetErrorString(r));                                \ | ||||||
|  |       exit(EXIT_FAILURE);                                           \ | ||||||
|  |     }                                                               \ | ||||||
|  |   } while (0) | ||||||
|  |  | ||||||
|  | __global__ void dummy_kernel() { | ||||||
|  |   for (int i = 0; i < 100; i++) __nanosleep(1000000);  // 100ms | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | __global__ void set_data(T *data, int size, int myRank) { | ||||||
|  |   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||||
|  |        idx += gridDim.x * blockDim.x) { | ||||||
|  |     data[idx] = myRank * 0.11f; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | __global__ void convert_data(const T *data1, const T *data2, double *fdata1, | ||||||
|  |                              double *fdata2, int size) { | ||||||
|  |   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||||
|  |        idx += gridDim.x * blockDim.x) { | ||||||
|  |     fdata1[idx] = data1[idx]; | ||||||
|  |     fdata2[idx] = data2[idx]; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | __global__ void init_rand(curandState_t *state, int size, int nRanks) { | ||||||
|  |   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||||
|  |        idx += gridDim.x * blockDim.x) { | ||||||
|  |     for (int i = 0; i < nRanks; i++) { | ||||||
|  |       curand_init(i + 1, idx, 0, &state[idx * nRanks + i]); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | __global__ void gen_data(curandState_t *state, T *data, double *ground_truth, | ||||||
|  |                          int myRank, int nRanks, int size) { | ||||||
|  |   for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; | ||||||
|  |        idx += gridDim.x * blockDim.x) { | ||||||
|  |     double sum = 0.0; | ||||||
|  |     for (int i = 0; i < nRanks; i++) { | ||||||
|  |       double val = curand_uniform_double(&state[idx * nRanks + i]) * 4; | ||||||
|  |       T hval = val;  // downcast first | ||||||
|  |       sum += static_cast<double>(hval); | ||||||
|  |       if (i == myRank) data[idx] = hval; | ||||||
|  |     } | ||||||
|  |     ground_truth[idx] = sum; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <typename T> | ||||||
|  | void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit, | ||||||
|  |          int data_size) { | ||||||
|  |   T *result; | ||||||
|  |   cudaStream_t stream; | ||||||
|  |   CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking)); | ||||||
|  |   CUDACHECK(cudaMalloc(&result, data_size * sizeof(T))); | ||||||
|  |   CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T))); | ||||||
|  |  | ||||||
|  |   cudaIpcMemHandle_t self_data_handle; | ||||||
|  |   cudaIpcMemHandle_t data_handles[8]; | ||||||
|  |   vllm::Metadata *buffer; | ||||||
|  |   T *self_data_copy; | ||||||
|  |   /** | ||||||
|  |    * Allocate IPC buffer | ||||||
|  |    * | ||||||
|  |    * The first section is a temporary buffer for storing intermediate allreduce | ||||||
|  |    * results, if a particular algorithm requires it. The second section is for | ||||||
|  |    * the input to the allreduce. The actual API takes the input pointer as an | ||||||
|  |    * argument (that is, they can and usually should be allocated separately). | ||||||
|  |    * But since the input pointers and the temporary buffer all require IPC | ||||||
|  |    * registration, they are allocated and registered together in the test for | ||||||
|  |    * convenience. | ||||||
|  |    */ | ||||||
|  |   CUDACHECK( | ||||||
|  |       cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); | ||||||
|  |   CUDACHECK(cudaMemset(buffer, 0, | ||||||
|  |                        2 * data_size * sizeof(T) + sizeof(vllm::Metadata))); | ||||||
|  |   CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T))); | ||||||
|  |   CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer)); | ||||||
|  |  | ||||||
|  |   MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t), | ||||||
|  |                          MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t), | ||||||
|  |                          MPI_BYTE, MPI_COMM_WORLD)); | ||||||
|  |  | ||||||
|  |   void *rank_data; | ||||||
|  |   size_t rank_data_sz = 16 * 1024 * 1024; | ||||||
|  |   CUDACHECK(cudaMalloc(&rank_data, rank_data_sz)); | ||||||
|  |   std::vector<int64_t> offsets(nRanks, 0); | ||||||
|  |   vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles, | ||||||
|  |                            offsets, myRank); | ||||||
|  |   auto *self_data = | ||||||
|  |       reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) + | ||||||
|  |                             sizeof(vllm::Metadata) + data_size * sizeof(T)); | ||||||
|  |   // hack buffer registration | ||||||
|  |   { | ||||||
|  |     std::vector<std::string> handles; | ||||||
|  |     handles.reserve(nRanks); | ||||||
|  |     for (int i = 0; i < nRanks; i++) { | ||||||
|  |       char *begin = (char *)&data_handles[i]; | ||||||
|  |       char *end = (char *)&data_handles[i + 1]; | ||||||
|  |       handles.emplace_back(begin, end); | ||||||
|  |     } | ||||||
|  |     std::vector<int64_t> offsets( | ||||||
|  |         nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T)); | ||||||
|  |     fa.register_buffer(handles, offsets, self_data); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   double *ground_truth; | ||||||
|  |   CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double))); | ||||||
|  |   curandState_t *states; | ||||||
|  |   CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size)); | ||||||
|  |   init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks); | ||||||
|  |   gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank, | ||||||
|  |                                         nRanks, data_size); | ||||||
|  |   CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T), | ||||||
|  |                             cudaMemcpyDeviceToDevice, stream)); | ||||||
|  |   cudaEvent_t start, stop; | ||||||
|  |   CUDACHECK(cudaEventCreate(&start)); | ||||||
|  |   CUDACHECK(cudaEventCreate(&stop)); | ||||||
|  |  | ||||||
|  |   ncclDataType_t ncclDtype; | ||||||
|  |   if (std::is_same<T, half>::value) { | ||||||
|  |     ncclDtype = ncclFloat16; | ||||||
|  |   } else if (std::is_same<T, nv_bfloat16>::value) { | ||||||
|  |     ncclDtype = ncclBfloat16; | ||||||
|  |   } else { | ||||||
|  |     ncclDtype = ncclFloat; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   dummy_kernel<<<1, 1, 0, stream>>>(); | ||||||
|  |   constexpr int warmup_iters = 5; | ||||||
|  |   constexpr int num_iters = 25; | ||||||
|  |   // warmup | ||||||
|  |   for (int i = 0; i < warmup_iters; i++) { | ||||||
|  |     NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, | ||||||
|  |                             stream)); | ||||||
|  |   } | ||||||
|  |   CUDACHECK(cudaEventRecord(start, stream)); | ||||||
|  |   for (int i = 0; i < num_iters; i++) { | ||||||
|  |     NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm, | ||||||
|  |                             stream)); | ||||||
|  |   } | ||||||
|  |   CUDACHECK(cudaEventRecord(stop, stream)); | ||||||
|  |   CUDACHECK(cudaStreamSynchronize(stream)); | ||||||
|  |   float allreduce_ms = 0; | ||||||
|  |   cudaEventElapsedTime(&allreduce_ms, start, stop); | ||||||
|  |  | ||||||
|  |   // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>(); | ||||||
|  |   // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank); | ||||||
|  |  | ||||||
|  |   dummy_kernel<<<1, 1, 0, stream>>>(); | ||||||
|  |   // warm up | ||||||
|  |   for (int i = 0; i < warmup_iters; i++) { | ||||||
|  |     fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit); | ||||||
|  |   } | ||||||
|  |   CUDACHECK(cudaEventRecord(start, stream)); | ||||||
|  |   for (int i = 0; i < num_iters; i++) { | ||||||
|  |     fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit); | ||||||
|  |   } | ||||||
|  |   CUDACHECK(cudaEventRecord(stop, stream)); | ||||||
|  |   CUDACHECK(cudaStreamSynchronize(stream)); | ||||||
|  |  | ||||||
|  |   float duration_ms = 0; | ||||||
|  |   cudaEventElapsedTime(&duration_ms, start, stop); | ||||||
|  |   if (myRank == 0) | ||||||
|  |     printf( | ||||||
 |  |         "Rank %d done, nGPUs:%d, sz (kb): %zu, %d, %d, my time:%.2fus, nccl " | ||||||
|  |         "time:%.2fus\n", | ||||||
|  |         myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit, | ||||||
|  |         duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters); | ||||||
|  |  | ||||||
|  |   // And wait for all the queued up work to complete | ||||||
|  |   CUDACHECK(cudaStreamSynchronize(stream)); | ||||||
|  |  | ||||||
|  |   NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype, | ||||||
|  |                           ncclSum, comm, stream)); | ||||||
|  |  | ||||||
|  |   double *nccl_result, *my_result; | ||||||
|  |   CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double))); | ||||||
|  |   CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double))); | ||||||
|  |  | ||||||
|  |   convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result, | ||||||
|  |                                             my_result, data_size); | ||||||
|  |   CUDACHECK(cudaStreamSynchronize(stream)); | ||||||
|  |  | ||||||
|  |   for (unsigned long j = 0; j < data_size; j++) { | ||||||
|  |     auto diff = abs(nccl_result[j] - my_result[j]); | ||||||
|  |     if (diff >= 1e-2) { | ||||||
 |  |       printf("Rank %d: Verification mismatch at %lu: %f != (my) %f, gt=%f\n", | ||||||
|  |              myRank, j, nccl_result[j], my_result[j], ground_truth[j]); | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   long double nccl_diffs = 0.0; | ||||||
|  |   long double my_diffs = 0.0; | ||||||
|  |   for (int j = 0; j < data_size; j++) { | ||||||
|  |     nccl_diffs += abs(nccl_result[j] - ground_truth[j]); | ||||||
|  |     my_diffs += abs(my_result[j] - ground_truth[j]); | ||||||
|  |   } | ||||||
|  |   if (myRank == 0) | ||||||
|  |     std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size | ||||||
|  |               << " me: " << my_diffs / data_size << std::endl; | ||||||
|  |  | ||||||
|  |   CUDACHECK(cudaFree(result)); | ||||||
|  |   CUDACHECK(cudaFree(self_data_copy)); | ||||||
|  |   CUDACHECK(cudaFree(rank_data)); | ||||||
|  |   CUDACHECK(cudaFree(buffer)); | ||||||
|  |   CUDACHECK(cudaFree(states)); | ||||||
|  |   CUDACHECK(cudaFreeHost(ground_truth)); | ||||||
|  |   CUDACHECK(cudaFreeHost(nccl_result)); | ||||||
|  |   CUDACHECK(cudaFreeHost(my_result)); | ||||||
|  |   CUDACHECK(cudaStreamDestroy(stream)); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | int main(int argc, char **argv) { | ||||||
|  |   int nRanks, myRank; | ||||||
|  |   MPICHECK(MPI_Init(&argc, &argv)); | ||||||
|  |   MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank)); | ||||||
|  |   MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks)); | ||||||
|  |   CUDACHECK(cudaSetDevice(myRank)); | ||||||
|  |   ncclUniqueId id; | ||||||
|  |   ncclComm_t comm; | ||||||
|  |   if (myRank == 0) ncclGetUniqueId(&id); | ||||||
|  |   MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0, | ||||||
|  |                      MPI_COMM_WORLD)); | ||||||
|  |   NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank)); | ||||||
|  |  | ||||||
|  |   cudaProfilerStart(); | ||||||
|  |   // for (int threads : {256, 512}) { | ||||||
|  |   //   for (int block_limit = 16; block_limit < 112; block_limit += 4) { | ||||||
|  |   //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024); | ||||||
|  |   //   } | ||||||
|  |   // } | ||||||
|  |   for (int sz = 512; sz <= (32 << 20); sz *= 2) { | ||||||
|  |     run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   cudaProfilerStop(); | ||||||
|  |   return EXIT_SUCCESS; | ||||||
|  | } | ||||||
| @ -14,3 +14,24 @@ | |||||||
| #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \ | #define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \ | ||||||
|   AT_DISPATCH_SWITCH(                                             \ |   AT_DISPATCH_SWITCH(                                             \ | ||||||
|     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) |     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) | ||||||
|  |  | ||||||
|  | #define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)     \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) | ||||||
|  |  | ||||||
|  | #define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)           \ | ||||||
|  |   AT_DISPATCH_SWITCH(                                                    \ | ||||||
|  |     TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) | ||||||
|  |      | ||||||
|  | #define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)             \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)      \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)      \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)     \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)       \ | ||||||
|  |   AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) | ||||||
|  |  | ||||||
|  | #define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)             \ | ||||||
|  |   AT_DISPATCH_SWITCH(                                             \ | ||||||
|  |     TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) | ||||||
|  | |||||||
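A hedged usage sketch of the new VLLM_DISPATCH_INTEGRAL_TYPES macro. The kernel and function names below are hypothetical, and the snippet assumes it is compiled as part of the usual torch CUDA extension build (so dispatch_utils.h and the ATen headers are on the include path); it mirrors how moe_align_block_size dispatches on the integer dtype of topk_ids.

    #include <torch/extension.h>
    #include <ATen/cuda/CUDAContext.h>

    #include "dispatch_utils.h"

    // hypothetical kernel: element-wise copy of an integer id tensor
    template <typename scalar_t>
    __global__ void copy_ids_kernel(const scalar_t *src, scalar_t *dst,
                                    int64_t n) {
      int64_t i = blockIdx.x * (int64_t)blockDim.x + threadIdx.x;
      if (i < n) dst[i] = src[i];
    }

    void copy_ids(torch::Tensor src, torch::Tensor dst) {
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      int64_t n = src.numel();
      int blocks = static_cast<int>((n + 255) / 256);
      // expands to one case per integral dtype: Byte, Char, Short, Int, Long
      VLLM_DISPATCH_INTEGRAL_TYPES(src.scalar_type(), "copy_ids_kernel", [&] {
        copy_ids_kernel<scalar_t><<<blocks, 256, 0, stream>>>(
            src.data_ptr<scalar_t>(), dst.data_ptr<scalar_t>(), n);
      });
    }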
| @ -1,5 +1,6 @@ | |||||||
| #include <torch/extension.h> | #include <torch/extension.h> | ||||||
| #include <ATen/cuda/CUDAContext.h> | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  |  | ||||||
| #include "dispatch_utils.h" | #include "dispatch_utils.h" | ||||||
| #include "reduction_utils.cuh" | #include "reduction_utils.cuh" | ||||||
| @ -76,6 +77,7 @@ void rms_norm( | |||||||
|  |  | ||||||
|   dim3 grid(num_tokens); |   dim3 grid(num_tokens); | ||||||
|   dim3 block(std::min(hidden_size, 1024)); |   dim3 block(std::min(hidden_size, 1024)); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES( |   VLLM_DISPATCH_FLOATING_TYPES( | ||||||
|     input.scalar_type(), |     input.scalar_type(), | ||||||
| @ -101,6 +103,7 @@ void fused_add_rms_norm( | |||||||
|  |  | ||||||
|   dim3 grid(num_tokens); |   dim3 grid(num_tokens); | ||||||
|   dim3 block(std::min(hidden_size, 1024)); |   dim3 block(std::min(hidden_size, 1024)); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES( |   VLLM_DISPATCH_FLOATING_TYPES( | ||||||
|     input.scalar_type(), |     input.scalar_type(), | ||||||
|  | |||||||
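The guard lines added to rms_norm and fused_add_rms_norm follow a common ATen pattern: make the current CUDA device match the input tensor before looking up the stream and launching kernels. A minimal sketch of that pattern (assumes a CUDA build of libtorch; launch_on_inputs_device is an illustrative name, not part of the diff):

    #include <ATen/cuda/CUDAContext.h>
    #include <c10/cuda/CUDAGuard.h>
    #include <torch/torch.h>

    void launch_on_inputs_device(const torch::Tensor &input) {
      // switches the current device to input's device for this scope, so the
      // stream lookup (and any kernel launches on it) target the right GPU
      // even if the caller's current device is a different one
      const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
      const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
      (void)stream;  // ... launch kernels on `stream` here ...
    }  // the guard restores the previous device when it goes out of scope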
							
								
								
									
csrc/moe_align_block_size_kernels.cu | 108 lines added | new file
							| @ -0,0 +1,108 @@ | |||||||
|  | #include <torch/extension.h> | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  |  | ||||||
|  | #include <ATen/ATen.h> | ||||||
|  | #include <THC/THCAtomics.cuh> | ||||||
|  |  | ||||||
|  | #include "cuda_compat.h" | ||||||
|  | #include "dispatch_utils.h" | ||||||
|  |  | ||||||
|  | const static size_t NUM_MAX_EXPERTS = 64; | ||||||
|  | #define CEILDIV(x,y) (((x) + (y) - 1) / (y)) | ||||||
|  |  | ||||||
|  | namespace vllm { | ||||||
|  | template <typename scalar_t> | ||||||
|  | __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,  | ||||||
|  |                                 int32_t *sorted_token_ids,  | ||||||
|  |                                 int32_t *expert_ids,  | ||||||
|  |                                 int32_t *total_tokens_post_pad, | ||||||
|  |                                 int32_t num_experts,  | ||||||
|  |                                 int32_t block_size,  | ||||||
|  |                                 size_t numel) { | ||||||
|  |     const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); | ||||||
|  |     const size_t start_idx = threadIdx.x * tokens_per_thread; | ||||||
|  |     __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS]; | ||||||
|  |     __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1]; | ||||||
|  |     for (int i = 0; i < num_experts; ++i) { | ||||||
|  |         tokens_cnts[threadIdx.x + 1][i] = 0; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |     * In the first step we compute token_cnts[thread_index + 1][expert_index], | ||||||
|  |     * which counts how many tokens in the token shard of thread_index are assigned | ||||||
|  |     * to expert expert_index. | ||||||
|  |     */ | ||||||
|  |     for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||||
|  |         ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];  | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __syncthreads(); | ||||||
|  |  | ||||||
|  |     // For each expert we accumulate the token counts from the different threads. | ||||||
|  |     tokens_cnts[0][threadIdx.x] = 0; | ||||||
|  |     for (int i = 1; i <= blockDim.x; ++i) { | ||||||
|  |         tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __syncthreads(); | ||||||
|  |      | ||||||
|  |     // We accumulate the token counts of all experts in thread 0. | ||||||
|  |     if (threadIdx.x == 0) { | ||||||
|  |         cumsum[0] = 0; | ||||||
|  |         for (int i = 1; i <= num_experts; ++i) { | ||||||
|  |             cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size; | ||||||
|  |         } | ||||||
|  |         *total_tokens_post_pad = cumsum[num_experts]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     __syncthreads(); | ||||||
|  |  | ||||||
|  |     /** | ||||||
|  |     * For each expert, each thread processes the tokens of the corresponding blocks | ||||||
|  |     * and stores the corresponding expert_id for each block. | ||||||
|  |     */ | ||||||
|  |     for (int i = cumsum[threadIdx.x];i < cumsum[threadIdx.x + 1];i += block_size) { | ||||||
|  |         expert_ids[i / block_size] = threadIdx.x; | ||||||
|  |     } | ||||||
|  |      | ||||||
|  |     /** | ||||||
|  |     * Each thread processes a token shard, calculating the index of each token after | ||||||
|  |     * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and | ||||||
|  |     * block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], | ||||||
 |  |     * where * represents a padding value (preset in Python). | ||||||
|  |     */ | ||||||
|  |     for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { | ||||||
|  |         int32_t expert_id = topk_ids[i]; | ||||||
|  |         /** The cumsum[expert_id] stores the starting index of the tokens that the | ||||||
|  |         * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id] | ||||||
|  |         * stores the indices of the tokens processed by the expert with expert_id within | ||||||
|  |         * the current thread's token shard. | ||||||
|  |         */ | ||||||
|  |         int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id]; | ||||||
|  |         sorted_token_ids[rank_post_pad] = i; | ||||||
|  |         ++tokens_cnts[threadIdx.x][expert_id]; | ||||||
|  |     } | ||||||
|  | } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void moe_align_block_size( | ||||||
|  |     torch::Tensor topk_ids, | ||||||
|  |     int num_experts, | ||||||
|  |     int block_size, | ||||||
|  |     torch::Tensor sorted_token_ids, | ||||||
|  |     torch::Tensor experts_ids, | ||||||
|  |     torch::Tensor num_tokens_post_pad) { | ||||||
|  |     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |     assert(num_experts <= NUM_MAX_EXPERTS); | ||||||
|  |     VLLM_DISPATCH_INTEGRAL_TYPES( | ||||||
|  |         topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { | ||||||
|  |         vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>( | ||||||
|  |             topk_ids.data_ptr<scalar_t>(),  | ||||||
|  |             sorted_token_ids.data_ptr<int32_t>(),  | ||||||
|  |             experts_ids.data_ptr<int32_t>(),  | ||||||
|  |             num_tokens_post_pad.data_ptr<int32_t>(),  | ||||||
|  |             num_experts, | ||||||
|  |             block_size, | ||||||
|  |             topk_ids.numel()); | ||||||
|  |     }); | ||||||
|  | } | ||||||
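For orientation, here is a minimal Python sketch of how this op could be driven; the import path, buffer sizing, and padding policy are assumptions for illustration, not part of this diff. It reuses the worked example from the kernel comment (topk_ids = [0,1,2,1,2,3,0,3,4], block_size = 4):

import torch
from vllm._C import ops  # hypothetical binding path for the compiled extension

topk_ids = torch.tensor([0, 1, 2, 1, 2, 3, 0, 3, 4],
                        dtype=torch.int32, device="cuda")
num_experts, block_size = 8, 4
numel = topk_ids.numel()

# Worst case: every expert's token count gets rounded up to a full block.
max_padded = numel + num_experts * (block_size - 1)
# Pre-fill with `numel` so untouched slots act as the padding value the comment mentions.
sorted_token_ids = torch.full((max_padded,), numel,
                              dtype=torch.int32, device="cuda")
expert_ids = torch.empty((max_padded + block_size - 1) // block_size,
                         dtype=torch.int32, device="cuda")
num_tokens_post_pad = torch.empty(1, dtype=torch.int32, device="cuda")

ops.moe_align_block_size(topk_ids, num_experts, block_size,
                         sorted_token_ids, expert_ids, num_tokens_post_pad)
# For this example the kernel reports 20 post-padding tokens: five non-empty
# experts, each padded up to one block of 4, matching the kernel comment above.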
							
								
								
									
csrc/ops.h  (43 changed lines)
							| @ -13,7 +13,8 @@ void paged_attention_v1( | |||||||
|   torch::Tensor& context_lens, |   torch::Tensor& context_lens, | ||||||
|   int block_size, |   int block_size, | ||||||
|   int max_context_len, |   int max_context_len, | ||||||
|   const c10::optional<torch::Tensor>& alibi_slopes); |   const c10::optional<torch::Tensor>& alibi_slopes, | ||||||
|  |   const std::string& kv_cache_dtype); | ||||||
|  |  | ||||||
| void paged_attention_v2( | void paged_attention_v2( | ||||||
|   torch::Tensor& out, |   torch::Tensor& out, | ||||||
| @ -29,7 +30,8 @@ void paged_attention_v2( | |||||||
|   torch::Tensor& context_lens, |   torch::Tensor& context_lens, | ||||||
|   int block_size, |   int block_size, | ||||||
|   int max_context_len, |   int max_context_len, | ||||||
|   const c10::optional<torch::Tensor>& alibi_slopes); |   const c10::optional<torch::Tensor>& alibi_slopes, | ||||||
|  |   const std::string& kv_cache_dtype); | ||||||
|  |  | ||||||
| void rms_norm( | void rms_norm( | ||||||
|   torch::Tensor& out, |   torch::Tensor& out, | ||||||
| @ -70,6 +72,14 @@ torch::Tensor awq_gemm( | |||||||
|   torch::Tensor _scaling_factors, |   torch::Tensor _scaling_factors, | ||||||
|   torch::Tensor _zeros, |   torch::Tensor _zeros, | ||||||
|   int split_k_iters); |   int split_k_iters); | ||||||
|  |  | ||||||
|  | torch::Tensor awq_dequantize( | ||||||
|  |     torch::Tensor _kernel, | ||||||
|  |     torch::Tensor _scaling_factors, | ||||||
|  |     torch::Tensor _zeros, | ||||||
|  |     int split_k_iters, | ||||||
|  |     int thx, | ||||||
|  |     int thy); | ||||||
| #endif | #endif | ||||||
|  |  | ||||||
| void squeezellm_gemm( | void squeezellm_gemm( | ||||||
| @ -89,3 +99,32 @@ torch::Tensor gptq_gemm( | |||||||
| void gptq_shuffle( | void gptq_shuffle( | ||||||
|   torch::Tensor q_weight, |   torch::Tensor q_weight, | ||||||
|   torch::Tensor q_perm); |   torch::Tensor q_perm); | ||||||
|  |  | ||||||
|  | void moe_align_block_size( | ||||||
|  |   torch::Tensor topk_ids, | ||||||
|  |   int num_experts, | ||||||
|  |   int block_size, | ||||||
|  |   torch::Tensor sorted_token_ids, | ||||||
|  |   torch::Tensor experts_ids, | ||||||
|  |   torch::Tensor num_tokens_post_pad); | ||||||
|  |  | ||||||
|  | #ifndef USE_ROCM | ||||||
|  | using fptr_t = uint64_t; | ||||||
|  | fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, | ||||||
|  |                     const std::vector<std::string> &handles, | ||||||
|  |                     const std::vector<int64_t> &offsets, int rank, | ||||||
|  |                     bool full_nvlink); | ||||||
|  | bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size, | ||||||
|  |                       bool full_nvlink); | ||||||
|  | void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); | ||||||
|  | void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, | ||||||
|  |                       torch::Tensor &out); | ||||||
|  | void dispose(fptr_t _fa); | ||||||
|  | int meta_size(); | ||||||
|  | void register_buffer(fptr_t _fa, torch::Tensor &t, | ||||||
|  |                      const std::vector<std::string> &handles, | ||||||
|  |                      const std::vector<int64_t> &offsets); | ||||||
|  | std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa); | ||||||
|  | void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles, | ||||||
|  |                             const std::vector<std::vector<int64_t>> &offsets); | ||||||
|  | #endif | ||||||
|  | |||||||
| @ -1,5 +1,6 @@ | |||||||
| #include <torch/extension.h> | #include <torch/extension.h> | ||||||
| #include <ATen/cuda/CUDAContext.h> | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  |  | ||||||
| #include "cuda_compat.h" | #include "cuda_compat.h" | ||||||
| #include "dispatch_utils.h" | #include "dispatch_utils.h" | ||||||
| @ -43,8 +44,8 @@ __global__ void rotary_embedding_kernel( | |||||||
|   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] |   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size] | ||||||
|   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2] |   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2] | ||||||
|   const int rot_dim, |   const int rot_dim, | ||||||
|   const int query_stride, |   const int64_t query_stride, | ||||||
|   const int key_stride, |   const int64_t key_stride, | ||||||
|   const int num_heads, |   const int num_heads, | ||||||
|   const int num_kv_heads, |   const int num_kv_heads, | ||||||
|   const int head_size) { |   const int head_size) { | ||||||
| @ -60,7 +61,7 @@ __global__ void rotary_embedding_kernel( | |||||||
|   const int nq = num_heads * embed_dim; |   const int nq = num_heads * embed_dim; | ||||||
|   for (int i = threadIdx.x; i < nq; i += blockDim.x) { |   for (int i = threadIdx.x; i < nq; i += blockDim.x) { | ||||||
|     const int head_idx = i / embed_dim; |     const int head_idx = i / embed_dim; | ||||||
|     const int token_head = token_idx * query_stride + head_idx * head_size; |     const int64_t token_head = token_idx * query_stride + head_idx * head_size; | ||||||
|     const int rot_offset = i % embed_dim; |     const int rot_offset = i % embed_dim; | ||||||
|     apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, |     apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, | ||||||
|                                               sin_ptr, rot_offset, embed_dim); |                                               sin_ptr, rot_offset, embed_dim); | ||||||
| @ -69,7 +70,7 @@ __global__ void rotary_embedding_kernel( | |||||||
|   const int nk = num_kv_heads * embed_dim; |   const int nk = num_kv_heads * embed_dim; | ||||||
|   for (int i = threadIdx.x; i < nk; i += blockDim.x) { |   for (int i = threadIdx.x; i < nk; i += blockDim.x) { | ||||||
|     const int head_idx = i / embed_dim; |     const int head_idx = i / embed_dim; | ||||||
|     const int token_head = token_idx * key_stride + head_idx * head_size; |     const int64_t token_head = token_idx * key_stride + head_idx * head_size; | ||||||
|     const int rot_offset = i % embed_dim; |     const int rot_offset = i % embed_dim; | ||||||
|     apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, |     apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, | ||||||
|                                               sin_ptr, rot_offset, embed_dim); |                                               sin_ptr, rot_offset, embed_dim); | ||||||
| @ -89,11 +90,12 @@ void rotary_embedding( | |||||||
|   int rot_dim = cos_sin_cache.size(1); |   int rot_dim = cos_sin_cache.size(1); | ||||||
|   int num_heads = query.size(-1) / head_size; |   int num_heads = query.size(-1) / head_size; | ||||||
|   int num_kv_heads = key.size(-1) / head_size; |   int num_kv_heads = key.size(-1) / head_size; | ||||||
|   int query_stride = query.stride(-2); |   int64_t query_stride = query.stride(-2); | ||||||
|   int key_stride = key.stride(-2); |   int64_t key_stride = key.stride(-2); | ||||||
|  |  | ||||||
|   dim3 grid(num_tokens); |   dim3 grid(num_tokens); | ||||||
|   dim3 block(std::min(num_heads * rot_dim / 2, 512)); |   dim3 block(std::min(num_heads * rot_dim / 2, 512)); | ||||||
|  |   const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); | ||||||
|   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|   VLLM_DISPATCH_FLOATING_TYPES( |   VLLM_DISPATCH_FLOATING_TYPES( | ||||||
|     query.scalar_type(), |     query.scalar_type(), | ||||||
|  | |||||||
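The int to int64_t widening above matters once token_idx * query_stride stops fitting in 32 bits. A rough back-of-envelope in Python, with assumed (illustrative) sizes:

num_tokens = 600_000      # assumed: batch_size * seq_len for a large prefill
query_stride = 4096       # assumed: num_heads * head_size of a 4096-hidden model
max_offset = (num_tokens - 1) * query_stride
print(max_offset)              # 2457595904
print(max_offset > 2**31 - 1)  # True: a 32-bit token_head would already overflow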
							
								
								
									
csrc/punica/LICENSE  (new file, 217 lines)
							| @ -0,0 +1,217 @@ | |||||||
|  | Contains code from https://github.com/punica-ai/punica | ||||||
|  |  | ||||||
|  |                                  Apache License | ||||||
|  |                            Version 2.0, January 2004 | ||||||
|  |                         http://www.apache.org/licenses/ | ||||||
|  |  | ||||||
|  |    TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION | ||||||
|  |  | ||||||
|  |    1. Definitions. | ||||||
|  |  | ||||||
|  |       "License" shall mean the terms and conditions for use, reproduction, | ||||||
|  |       and distribution as defined by Sections 1 through 9 of this document. | ||||||
|  |  | ||||||
|  |       "Licensor" shall mean the copyright owner or entity authorized by | ||||||
|  |       the copyright owner that is granting the License. | ||||||
|  |  | ||||||
|  |       "Legal Entity" shall mean the union of the acting entity and all | ||||||
|  |       other entities that control, are controlled by, or are under common | ||||||
|  |       control with that entity. For the purposes of this definition, | ||||||
|  |       "control" means (i) the power, direct or indirect, to cause the | ||||||
|  |       direction or management of such entity, whether by contract or | ||||||
|  |       otherwise, or (ii) ownership of fifty percent (50%) or more of the | ||||||
|  |       outstanding shares, or (iii) beneficial ownership of such entity. | ||||||
|  |  | ||||||
|  |       "You" (or "Your") shall mean an individual or Legal Entity | ||||||
|  |       exercising permissions granted by this License. | ||||||
|  |  | ||||||
|  |       "Source" form shall mean the preferred form for making modifications, | ||||||
|  |       including but not limited to software source code, documentation | ||||||
|  |       source, and configuration files. | ||||||
|  |  | ||||||
|  |       "Object" form shall mean any form resulting from mechanical | ||||||
|  |       transformation or translation of a Source form, including but | ||||||
|  |       not limited to compiled object code, generated documentation, | ||||||
|  |       and conversions to other media types. | ||||||
|  |  | ||||||
|  |       "Work" shall mean the work of authorship, whether in Source or | ||||||
|  |       Object form, made available under the License, as indicated by a | ||||||
|  |       copyright notice that is included in or attached to the work | ||||||
|  |       (an example is provided in the Appendix below). | ||||||
|  |  | ||||||
|  |       "Derivative Works" shall mean any work, whether in Source or Object | ||||||
|  |       form, that is based on (or derived from) the Work and for which the | ||||||
|  |       editorial revisions, annotations, elaborations, or other modifications | ||||||
|  |       represent, as a whole, an original work of authorship. For the purposes | ||||||
|  |       of this License, Derivative Works shall not include works that remain | ||||||
|  |       separable from, or merely link (or bind by name) to the interfaces of, | ||||||
|  |       the Work and Derivative Works thereof. | ||||||
|  |  | ||||||
|  |       "Contribution" shall mean any work of authorship, including | ||||||
|  |       the original version of the Work and any modifications or additions | ||||||
|  |       to that Work or Derivative Works thereof, that is intentionally | ||||||
|  |       submitted to Licensor for inclusion in the Work by the copyright owner | ||||||
|  |       or by an individual or Legal Entity authorized to submit on behalf of | ||||||
|  |       the copyright owner. For the purposes of this definition, "submitted" | ||||||
|  |       means any form of electronic, verbal, or written communication sent | ||||||
|  |       to the Licensor or its representatives, including but not limited to | ||||||
|  |       communication on electronic mailing lists, source code control systems, | ||||||
|  |       and issue tracking systems that are managed by, or on behalf of, the | ||||||
|  |       Licensor for the purpose of discussing and improving the Work, but | ||||||
|  |       excluding communication that is conspicuously marked or otherwise | ||||||
|  |       designated in writing by the copyright owner as "Not a Contribution." | ||||||
|  |  | ||||||
|  |       "Contributor" shall mean Licensor and any individual or Legal Entity | ||||||
|  |       on behalf of whom a Contribution has been received by Licensor and | ||||||
|  |       subsequently incorporated within the Work. | ||||||
|  |  | ||||||
|  |    2. Grant of Copyright License. Subject to the terms and conditions of | ||||||
|  |       this License, each Contributor hereby grants to You a perpetual, | ||||||
|  |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||||
|  |       copyright license to reproduce, prepare Derivative Works of, | ||||||
|  |       publicly display, publicly perform, sublicense, and distribute the | ||||||
|  |       Work and such Derivative Works in Source or Object form. | ||||||
|  |  | ||||||
|  |    3. Grant of Patent License. Subject to the terms and conditions of | ||||||
|  |       this License, each Contributor hereby grants to You a perpetual, | ||||||
|  |       worldwide, non-exclusive, no-charge, royalty-free, irrevocable | ||||||
|  |       (except as stated in this section) patent license to make, have made, | ||||||
|  |       use, offer to sell, sell, import, and otherwise transfer the Work, | ||||||
|  |       where such license applies only to those patent claims licensable | ||||||
|  |       by such Contributor that are necessarily infringed by their | ||||||
|  |       Contribution(s) alone or by combination of their Contribution(s) | ||||||
|  |       with the Work to which such Contribution(s) was submitted. If You | ||||||
|  |       institute patent litigation against any entity (including a | ||||||
|  |       cross-claim or counterclaim in a lawsuit) alleging that the Work | ||||||
|  |       or a Contribution incorporated within the Work constitutes direct | ||||||
|  |       or contributory patent infringement, then any patent licenses | ||||||
|  |       granted to You under this License for that Work shall terminate | ||||||
|  |       as of the date such litigation is filed. | ||||||
|  |  | ||||||
|  |    4. Redistribution. You may reproduce and distribute copies of the | ||||||
|  |       Work or Derivative Works thereof in any medium, with or without | ||||||
|  |       modifications, and in Source or Object form, provided that You | ||||||
|  |       meet the following conditions: | ||||||
|  |  | ||||||
|  |       (a) You must give any other recipients of the Work or | ||||||
|  |           Derivative Works a copy of this License; and | ||||||
|  |  | ||||||
|  |       (b) You must cause any modified files to carry prominent notices | ||||||
|  |           stating that You changed the files; and | ||||||
|  |  | ||||||
|  |       (c) You must retain, in the Source form of any Derivative Works | ||||||
|  |           that You distribute, all copyright, patent, trademark, and | ||||||
|  |           attribution notices from the Source form of the Work, | ||||||
|  |           excluding those notices that do not pertain to any part of | ||||||
|  |           the Derivative Works; and | ||||||
|  |  | ||||||
|  |       (d) If the Work includes a "NOTICE" text file as part of its | ||||||
|  |           distribution, then any Derivative Works that You distribute must | ||||||
|  |           include a readable copy of the attribution notices contained | ||||||
|  |           within such NOTICE file, excluding those notices that do not | ||||||
|  |           pertain to any part of the Derivative Works, in at least one | ||||||
|  |           of the following places: within a NOTICE text file distributed | ||||||
|  |           as part of the Derivative Works; within the Source form or | ||||||
|  |           documentation, if provided along with the Derivative Works; or, | ||||||
|  |           within a display generated by the Derivative Works, if and | ||||||
|  |           wherever such third-party notices normally appear. The contents | ||||||
|  |           of the NOTICE file are for informational purposes only and | ||||||
|  |           do not modify the License. You may add Your own attribution | ||||||
|  |           notices within Derivative Works that You distribute, alongside | ||||||
|  |           or as an addendum to the NOTICE text from the Work, provided | ||||||
|  |           that such additional attribution notices cannot be construed | ||||||
|  |           as modifying the License. | ||||||
|  |  | ||||||
|  |       You may add Your own copyright statement to Your modifications and | ||||||
|  |       may provide additional or different license terms and conditions | ||||||
|  |       for use, reproduction, or distribution of Your modifications, or | ||||||
|  |       for any such Derivative Works as a whole, provided Your use, | ||||||
|  |       reproduction, and distribution of the Work otherwise complies with | ||||||
|  |       the conditions stated in this License. | ||||||
|  |  | ||||||
|  |    5. Submission of Contributions. Unless You explicitly state otherwise, | ||||||
|  |       any Contribution intentionally submitted for inclusion in the Work | ||||||
|  |       by You to the Licensor shall be under the terms and conditions of | ||||||
|  |       this License, without any additional terms or conditions. | ||||||
|  |       Notwithstanding the above, nothing herein shall supersede or modify | ||||||
|  |       the terms of any separate license agreement you may have executed | ||||||
|  |       with Licensor regarding such Contributions. | ||||||
|  |  | ||||||
|  |    6. Trademarks. This License does not grant permission to use the trade | ||||||
|  |       names, trademarks, service marks, or product names of the Licensor, | ||||||
|  |       except as required for reasonable and customary use in describing the | ||||||
|  |       origin of the Work and reproducing the content of the NOTICE file. | ||||||
|  |  | ||||||
|  |    7. Disclaimer of Warranty. Unless required by applicable law or | ||||||
|  |       agreed to in writing, Licensor provides the Work (and each | ||||||
|  |       Contributor provides its Contributions) on an "AS IS" BASIS, | ||||||
|  |       WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | ||||||
|  |       implied, including, without limitation, any warranties or conditions | ||||||
|  |       of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A | ||||||
|  |       PARTICULAR PURPOSE. You are solely responsible for determining the | ||||||
|  |       appropriateness of using or redistributing the Work and assume any | ||||||
|  |       risks associated with Your exercise of permissions under this License. | ||||||
|  |  | ||||||
|  |    8. Limitation of Liability. In no event and under no legal theory, | ||||||
|  |       whether in tort (including negligence), contract, or otherwise, | ||||||
|  |       unless required by applicable law (such as deliberate and grossly | ||||||
|  |       negligent acts) or agreed to in writing, shall any Contributor be | ||||||
|  |       liable to You for damages, including any direct, indirect, special, | ||||||
|  |       incidental, or consequential damages of any character arising as a | ||||||
|  |       result of this License or out of the use or inability to use the | ||||||
|  |       Work (including but not limited to damages for loss of goodwill, | ||||||
|  |       work stoppage, computer failure or malfunction, or any and all | ||||||
|  |       other commercial damages or losses), even if such Contributor | ||||||
|  |       has been advised of the possibility of such damages. | ||||||
|  |  | ||||||
|  |    9. Accepting Warranty or Additional Liability. While redistributing | ||||||
|  |       the Work or Derivative Works thereof, You may choose to offer, | ||||||
|  |       and charge a fee for, acceptance of support, warranty, indemnity, | ||||||
|  |       or other liability obligations and/or rights consistent with this | ||||||
|  |       License. However, in accepting such obligations, You may act only | ||||||
|  |       on Your own behalf and on Your sole responsibility, not on behalf | ||||||
|  |       of any other Contributor, and only if You agree to indemnify, | ||||||
|  |       defend, and hold each Contributor harmless for any liability | ||||||
|  |       incurred by, or claims asserted against, such Contributor by reason | ||||||
|  |       of your accepting any such warranty or additional liability. | ||||||
|  |  | ||||||
|  |    END OF TERMS AND CONDITIONS | ||||||
|  |  | ||||||
|  |    APPENDIX: How to apply the Apache License to your work. | ||||||
|  |  | ||||||
|  |       To apply the Apache License to your work, attach the following | ||||||
|  |       boilerplate notice, with the fields enclosed by brackets "{}" | ||||||
|  |       replaced with your own identifying information. (Don't include | ||||||
|  |       the brackets!)  The text should be enclosed in the appropriate | ||||||
|  |       comment syntax for the file format. We also recommend that a | ||||||
|  |       file or class name and description of purpose be included on the | ||||||
|  |       same "printed page" as the copyright notice for easier | ||||||
|  |       identification within third-party archives. | ||||||
|  |  | ||||||
|  |    Copyright {yyyy} {name of copyright owner} | ||||||
|  |  | ||||||
|  |    Licensed under the Apache License, Version 2.0 (the "License"); | ||||||
|  |    you may not use this file except in compliance with the License. | ||||||
|  |    You may obtain a copy of the License at | ||||||
|  |  | ||||||
|  |        http://www.apache.org/licenses/LICENSE-2.0 | ||||||
|  |  | ||||||
|  |    Unless required by applicable law or agreed to in writing, software | ||||||
|  |    distributed under the License is distributed on an "AS IS" BASIS, | ||||||
|  |    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
|  |    See the License for the specific language governing permissions and | ||||||
|  |    limitations under the License. | ||||||
|  |  | ||||||
|  | ------------------------------------------------------------------------------------ | ||||||
|  |  | ||||||
|  | This product bundles various third-party components under other open source licenses. | ||||||
|  | This section summarizes those components and their licenses. See licenses/ | ||||||
|  | for text of these licenses. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Apache-2.0 | ||||||
|  | * third_party/nvbench (with LLVM exception) | ||||||
|  | * third_party/flashinfer | ||||||
|  |  | ||||||
|  | BSD-3-Clause: | ||||||
|  | * third_party/cutlass | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_config.h  (new file, 59 lines)
							| @ -0,0 +1,59 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | template <int feat_in, int feat_out, typename in_T, typename out_T, | ||||||
|  |           typename W_T> | ||||||
|  | void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||||
|  |                  const W_T *__restrict__ W, | ||||||
|  |                  const int64_t *__restrict__ indicies, int64_t y_offset, | ||||||
|  |                  int64_t full_y_size, int64_t batch_size, int64_t num_layers, | ||||||
|  |                  int64_t layer_idx, float scale); | ||||||
|  |  | ||||||
|  | // clang-format off | ||||||
|  |  | ||||||
|  | #define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 128) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 256) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 512) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 1024) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 1280) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 1728) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 1792) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 2048) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 2560) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 2752) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 3072) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 3456) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 3584) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 4096) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 5120) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 5504) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 5632) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 6912) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 7168) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 8192) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 9216) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 10240) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 11008) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 12288) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 13824) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 14336) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 16384) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 20480) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 28672) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 32000) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 32256) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 32512) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 32768) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 33024) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 36864) \ | ||||||
|  |     f(in_T, out_T, W_T, narrow, 49152) \ | ||||||
|  | // Keep above in sync with vllm/lora/layers::SamplerWithLoRA | ||||||
|  |  | ||||||
|  | // Keep this in sync with vllm/config::LoRAConfig | ||||||
|  | #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ | ||||||
|  |     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8)  \ | ||||||
|  |     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \ | ||||||
|  |     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ | ||||||
|  |     FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) | ||||||
|  |  | ||||||
|  | // clang-format on | ||||||
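To get a sense of what these macros expand to: each FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, ...) line in the bgmv_*.cu translation units instantiates bgmv_kernel for every (narrow, wide) pair in both directions. A quick count in Python, mirroring the lists above:

wide_sizes = [128, 256, 512, 1024, 1280, 1728, 1792, 2048, 2560, 2752, 3072,
              3456, 3584, 4096, 5120, 5504, 5632, 6912, 7168, 8192, 9216,
              10240, 11008, 12288, 13824, 14336, 16384, 20480, 28672, 32000,
              32256, 32512, 32768, 33024, 36864, 49152]   # FOR_BGMV_WIDE
narrow_sizes = [8, 16, 32, 64]                            # FOR_BGMV_WIDE_NARROW
print(len(wide_sizes) * len(narrow_sizes) * 2)            # 288 instantiations per dtype combination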
							
								
								
									
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu  (new file, 4 lines)
							| @ -0,0 +1,4 @@ | |||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half) | ||||||
							
								
								
									
csrc/punica/bgmv/bgmv_impl.cuh  (new file, 294 lines)
							| @ -0,0 +1,294 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <ATen/cuda/CUDAContext.h> | ||||||
|  | #include <cooperative_groups.h> | ||||||
|  | #include <cuda/pipeline> | ||||||
|  | #include <cuda_runtime.h> | ||||||
|  | #include <iostream> | ||||||
|  | #include <stdio.h> | ||||||
|  |  | ||||||
|  | #include "vec_dtypes.cuh" | ||||||
|  |  | ||||||
|  | namespace cg = cooperative_groups; | ||||||
|  |  | ||||||
|  | // nthrs = (32, 4) | ||||||
|  | template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size, | ||||||
|  |           size_t W_copy_size, int tx, int ty, int tz, typename in_T, | ||||||
|  |           typename out_T, typename W_T> | ||||||
|  | __global__ void | ||||||
|  | bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||||
|  |                    const W_T *__restrict__ W, | ||||||
|  |                    const int64_t *__restrict__ indicies, int64_t y_offset, | ||||||
|  |                    int64_t full_y_size, int64_t num_layers, int64_t layer_idx, | ||||||
|  |                    float scale) { | ||||||
|  |   size_t batch_idx = blockIdx.y; | ||||||
|  |   int64_t idx = indicies[batch_idx] * num_layers + layer_idx; | ||||||
|  |   if (idx < 0) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto block = cg::this_thread_block(); | ||||||
|  |   size_t j = blockIdx.x; | ||||||
|  |   constexpr size_t num_pipeline_stages = 2; | ||||||
|  |   constexpr size_t tile_size = tx * ty * vec_size; | ||||||
|  |   __shared__ W_T W_shared[num_pipeline_stages * tile_size]; | ||||||
|  |   __shared__ in_T X_shared[num_pipeline_stages * tile_size]; | ||||||
|  |   __shared__ float y_warpwise[ty]; | ||||||
|  |  | ||||||
|  |   size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; | ||||||
|  |   size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size}; | ||||||
|  |   auto pipe = cuda::make_pipeline(); | ||||||
|  |  | ||||||
|  |   // pipeline load W/X and compute WX; | ||||||
|  |   pipe.producer_acquire(); | ||||||
|  |   cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                      W + (idx * feat_out + j) * feat_in + | ||||||
|  |                          (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                      cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe); | ||||||
|  |   cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                      X + (batch_idx * feat_in) + | ||||||
|  |                          (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                      cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe); | ||||||
|  |   pipe.producer_commit(); | ||||||
|  |   size_t copy_idx, compute_idx; | ||||||
|  |   float y = 0.f; | ||||||
|  |   vec_t<in_T, vec_size> x_vec; | ||||||
|  |   vec_t<W_T, vec_size> w_vec; | ||||||
|  |   size_t tile_idx; | ||||||
|  |  | ||||||
|  | #pragma unroll | ||||||
|  |   for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size; | ||||||
|  |        ++tile_idx) { | ||||||
|  |     copy_idx = tile_idx % num_pipeline_stages; | ||||||
|  |     // pipeline stage: async copy W fragment | ||||||
|  |     pipe.producer_acquire(); | ||||||
|  |     if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) { | ||||||
|  |       cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] + | ||||||
|  |                              (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                          W + (idx * feat_out + j) * feat_in + | ||||||
|  |                              tile_idx * tile_size + | ||||||
|  |                              (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                          cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe); | ||||||
|  |       cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] + | ||||||
|  |                              (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                          X + (batch_idx * feat_in) + tile_idx * tile_size + | ||||||
|  |                              (threadIdx.y * tx + threadIdx.x) * vec_size, | ||||||
|  |                          cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe); | ||||||
|  |     } | ||||||
|  |     pipe.producer_commit(); | ||||||
|  |  | ||||||
|  |     compute_idx = (tile_idx - 1) % num_pipeline_stages; | ||||||
|  |     // pipeline stage: compute WX | ||||||
|  |     pipe.consumer_wait(); | ||||||
|  |     block.sync(); | ||||||
|  |     x_vec.load(X_shared + X_shared_offset[compute_idx] + | ||||||
|  |                (threadIdx.y * tx + threadIdx.x) * vec_size); | ||||||
|  |     w_vec.load(W_shared + W_shared_offset[compute_idx] + | ||||||
|  |                (threadIdx.y * tx + threadIdx.x) * vec_size); | ||||||
|  |     float sum = 0.f; | ||||||
|  | #pragma unroll | ||||||
|  |     for (size_t i = 0; i < vec_size; ++i) { | ||||||
|  |       sum += float(w_vec[i]) * float(x_vec[i]) * scale; | ||||||
|  |     } | ||||||
|  | #pragma unroll | ||||||
|  |     for (size_t offset = tx / 2; offset > 0; offset /= 2) { | ||||||
|  |       sum += __shfl_down_sync(0xffffffff, sum, offset); | ||||||
|  |     } | ||||||
|  |     y_warpwise[threadIdx.y] = sum; | ||||||
|  |     block.sync(); | ||||||
|  | #pragma unroll | ||||||
|  |     for (size_t i = 0; i < ty; ++i) { | ||||||
|  |       y += y_warpwise[i]; | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     block.sync(); | ||||||
|  |     pipe.consumer_release(); | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   compute_idx = (tile_idx - 1) % num_pipeline_stages; | ||||||
|  |   // final pipeline stage | ||||||
|  |   pipe.consumer_wait(); | ||||||
|  |   block.sync(); | ||||||
|  |   x_vec.load(X_shared + X_shared_offset[compute_idx] + | ||||||
|  |              (threadIdx.y * tx + threadIdx.x) * vec_size); | ||||||
|  |   w_vec.load(W_shared + W_shared_offset[compute_idx] + | ||||||
|  |              (threadIdx.y * tx + threadIdx.x) * vec_size); | ||||||
|  |   float sum = 0.f; | ||||||
|  | #pragma unroll | ||||||
|  |   for (size_t i = 0; i < vec_size; ++i) { | ||||||
|  |     sum += float(w_vec[i]) * float(x_vec[i]) * scale; | ||||||
|  |   } | ||||||
|  | #pragma unroll | ||||||
|  |   for (size_t offset = tx / 2; offset > 0; offset /= 2) { | ||||||
|  |     sum += __shfl_down_sync(0xffffffff, sum, offset); | ||||||
|  |   } | ||||||
|  |   y_warpwise[threadIdx.y] = | ||||||
|  |       ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in) | ||||||
|  |           ? sum | ||||||
|  |           : 0.f; | ||||||
|  |   block.sync(); | ||||||
|  | #pragma unroll | ||||||
|  |   for (size_t i = 0; i < ty; ++i) { | ||||||
|  |     y += y_warpwise[i]; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   block.sync(); | ||||||
|  |   pipe.consumer_release(); | ||||||
|  |  | ||||||
|  |   // write Y; | ||||||
|  |   if (block.thread_rank() == 0) { | ||||||
|  |     Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // nthrs = (2, 16, 4) | ||||||
|  | template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz, | ||||||
|  |           typename in_T, typename out_T, typename W_T> | ||||||
|  | __global__ void | ||||||
|  | bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||||
|  |                    const W_T *__restrict__ W, | ||||||
|  |                    const int64_t *__restrict__ indicies, int64_t y_offset, | ||||||
|  |                    int64_t full_y_size, int64_t num_layers, int64_t layer_idx, | ||||||
|  |                    float scale) { | ||||||
|  |   size_t batch_idx = blockIdx.y; | ||||||
|  |   int64_t idx = indicies[batch_idx] * num_layers + layer_idx; | ||||||
|  |  | ||||||
|  |   if (idx < 0) { | ||||||
|  |     return; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   auto block = cg::this_thread_block(); | ||||||
|  |   size_t tile_idx = blockIdx.x; | ||||||
|  |  | ||||||
|  |   // load X; | ||||||
|  |   vec_t<in_T, vec_size> x_vec; | ||||||
|  |   x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size); | ||||||
|  |  | ||||||
|  |   // load W; | ||||||
|  |   vec_t<W_T, vec_size> w_vec; | ||||||
|  |   w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in + | ||||||
|  |              block.thread_rank() * vec_size); | ||||||
|  |  | ||||||
|  |   float sum = 0.f; | ||||||
|  | #pragma unroll | ||||||
|  |   for (size_t i = 0; i < vec_size; ++i) { | ||||||
|  |     sum += float(w_vec[i]) * float(x_vec[i]) * scale; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   cg::thread_block_tile g = cg::tiled_partition<tx>(block); | ||||||
|  | #pragma unroll | ||||||
|  |   for (size_t offset = tx / 2; offset > 0; offset /= 2) { | ||||||
|  |     sum += g.shfl_down(sum, offset); | ||||||
|  |   } | ||||||
|  |   sum = g.shfl(sum, 0); | ||||||
|  |  | ||||||
|  |   if (threadIdx.x == 0) { | ||||||
|  |     Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) + | ||||||
|  |       threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template <int feat_in, int feat_out, typename in_T, typename out_T, | ||||||
|  |           typename W_T> | ||||||
|  | void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, | ||||||
|  |                  const W_T *__restrict__ W, | ||||||
|  |                  const int64_t *__restrict__ indicies, int64_t y_offset, | ||||||
|  |                  int64_t full_y_size, int64_t batch_size, int64_t num_layers, | ||||||
|  |                  int64_t layer_idx, float scale) { | ||||||
|  |   constexpr size_t vec_size = 8; | ||||||
|  |   constexpr int tz = 4; | ||||||
|  |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |  | ||||||
|  |   if constexpr (feat_in < feat_out) { | ||||||
|  |     static_assert(feat_in % vec_size == 0); | ||||||
|  |     constexpr int tx = feat_in / vec_size; | ||||||
|  |  | ||||||
|  |     static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) || | ||||||
|  |                   (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) || | ||||||
|  |                   (8 % tx == 0 && feat_out % (8 / tx * tz) == 0)); | ||||||
|  |  | ||||||
|  |     if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) { | ||||||
|  |       constexpr int ty = 32 / tx; | ||||||
|  |       dim3 nblks(feat_out / (ty * tz), batch_size); | ||||||
|  |       dim3 nthrs(tx, ty, tz); | ||||||
|  |  | ||||||
|  |       bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz> | ||||||
|  |           <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, | ||||||
|  |                                         full_y_size, num_layers, layer_idx, | ||||||
|  |                                         scale); | ||||||
|  |     } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) { | ||||||
|  |       constexpr int ty = 16 / tx; | ||||||
|  |       dim3 nblks(feat_out / (ty * tz), batch_size); | ||||||
|  |       dim3 nthrs(tx, ty, tz); | ||||||
|  |  | ||||||
|  |       bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz> | ||||||
|  |           <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, | ||||||
|  |                                         full_y_size, num_layers, layer_idx, | ||||||
|  |                                         scale); | ||||||
|  |     } else { | ||||||
|  |       constexpr int ty = 8 / tx; | ||||||
|  |       dim3 nblks(feat_out / (ty * tz), batch_size); | ||||||
|  |       dim3 nthrs(tx, ty, tz); | ||||||
|  |  | ||||||
|  |       bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz> | ||||||
|  |           <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, | ||||||
|  |                                         full_y_size, num_layers, layer_idx, | ||||||
|  |                                         scale); | ||||||
|  |     } | ||||||
|  |   } else { | ||||||
|  |     static_assert(feat_in % (vec_size * 32) == 0 || | ||||||
|  |                   feat_in % (vec_size * 16) == 0 || | ||||||
|  |                   feat_in % (vec_size * 8) == 0); | ||||||
|  |  | ||||||
|  |     if constexpr (feat_in % (vec_size * 32) == 0) { | ||||||
|  |       constexpr int tx = 32; | ||||||
|  |       constexpr int ty = 4; | ||||||
|  |  | ||||||
|  |       dim3 nblks(feat_out, batch_size); | ||||||
|  |       dim3 nthrs(tx, ty); | ||||||
|  |  | ||||||
|  |       bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T), | ||||||
|  |                          vec_size * sizeof(W_T), tx, ty, tz> | ||||||
|  |           <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, | ||||||
|  |                                         full_y_size, num_layers, layer_idx, | ||||||
|  |                                         scale); | ||||||
|  |     } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) { | ||||||
|  |       constexpr int tx = 32; | ||||||
|  |       constexpr int ty = 4; | ||||||
|  |  | ||||||
|  |       dim3 nblks(feat_out, batch_size); | ||||||
|  |       dim3 nthrs(tx, ty); | ||||||
|  |  | ||||||
|  |       bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2, | ||||||
|  |                          vec_size * sizeof(in_T) / 2, | ||||||
|  |                          vec_size * sizeof(W_T) / 2, tx, ty, tz> | ||||||
|  |           <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, | ||||||
|  |                                         full_y_size, num_layers, layer_idx, | ||||||
|  |                                         scale); | ||||||
|  |     } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) { | ||||||
|  |       constexpr int tx = 16; | ||||||
|  |       constexpr int ty = 4; | ||||||
|  |  | ||||||
|  |       dim3 nblks(feat_out, batch_size); | ||||||
|  |       dim3 nthrs(tx, ty); | ||||||
|  |  | ||||||
|  |       bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2, | ||||||
|  |                          vec_size * sizeof(in_T) / 2, | ||||||
|  |                          vec_size * sizeof(W_T) / 2, tx, ty, tz> | ||||||
|  |           <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, | ||||||
|  |                                         full_y_size, num_layers, layer_idx, | ||||||
|  |                                         scale); | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)                         \ | ||||||
|  |   template void bgmv_kernel<feat_in, feat_out>(                                \ | ||||||
|  |       out_T * __restrict__ Y, const in_T *__restrict__ X,                      \ | ||||||
|  |       const W_T *__restrict__ W, const int64_t *__restrict__ indicies,         \ | ||||||
|  |       int64_t y_offset, int64_t full_y_size, int64_t batch_size,               \ | ||||||
|  |       int64_t num_layers, int64_t layer_idx, float scale); | ||||||
|  |  | ||||||
|  | #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide)                      \ | ||||||
|  |   INST_BGMV(narrow, wide, in_T, out_T, W_T)                                    \ | ||||||
|  |   INST_BGMV(wide, narrow, in_T, out_T, W_T) | ||||||
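As a worked example of the launch-shape dispatch in bgmv_kernel above (the concrete sizes are assumed for illustration): a LoRA shrink with feat_in = 4096 and feat_out = 16 satisfies feat_in % (vec_size * 32) == 0, so the first shrink branch is taken:

vec_size, tz = 8, 4
feat_in, feat_out, batch_size = 4096, 16, 32   # assumed shapes
assert feat_in % (vec_size * 32) == 0          # -> first shrink branch above
tx, ty = 32, 4
nblks = (feat_out, batch_size)   # one thread block per (output feature, sequence)
nthrs = (tx, ty)                 # 128 threads cooperatively reduce over feat_in
tile_size = tx * ty * vec_size   # 1024 input elements staged per pipeline tile
print(nblks, nthrs, feat_in // tile_size)      # (16, 32) (32, 4) 4 tiles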
							
								
								
									
csrc/punica/bgmv/generator.py  (new file, 27 lines)
							| @ -0,0 +1,27 @@ | |||||||
|  | DTYPES = ["fp16", "bf16", "fp32"] | ||||||
|  | DTYPE_MAP = { | ||||||
|  |     "fp16": "nv_half", | ||||||
|  |     "bf16": "nv_bfloat16", | ||||||
|  |     "fp32": "float", | ||||||
|  | } | ||||||
|  |  | ||||||
|  | TEMPLATE = """ | ||||||
|  | #include "bgmv_config.h" | ||||||
|  | #include "bgmv_impl.cuh" | ||||||
|  |  | ||||||
|  | FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype}) | ||||||
|  | """.lstrip() | ||||||
|  |  | ||||||
|  | for input_dtype in DTYPES: | ||||||
|  |     for output_dtype in DTYPES: | ||||||
|  |         for weight_dtype in DTYPES: | ||||||
|  |             if weight_dtype == "fp32": | ||||||
|  |                 # FP32 weights are not supported. | ||||||
|  |                 continue | ||||||
|  |             kernel_definition = TEMPLATE.format( | ||||||
|  |                 input_dtype=DTYPE_MAP[input_dtype], | ||||||
|  |                 output_dtype=DTYPE_MAP[output_dtype], | ||||||
|  |                 weight_dtype=DTYPE_MAP[weight_dtype]) | ||||||
|  |             filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu" | ||||||
|  |             with open(filename, "w") as f: | ||||||
|  |                 f.write(kernel_definition) | ||||||
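generator.py is a small code-generation step: for every (input, output, weight) dtype combination except fp32 weights, it writes one translation unit that instantiates the bgmv kernels through FOR_BGMV_WIDE_NARROW. A minimal sketch of the file set it emits, derived only from the loops above (not from repository output):

    # 3 input dtypes x 3 output dtypes x 2 weight dtypes = 18 generated files.
    DTYPES = ["fp16", "bf16", "fp32"]
    generated = [
        f"bgmv_{i}_{o}_{w}.cu"
        for i in DTYPES
        for o in DTYPES
        for w in DTYPES
        if w != "fp32"
    ]
    print(len(generated))   # 18
    print(generated[0])     # bgmv_fp16_fp16_fp16.cu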
							
								
								
									
csrc/punica/bgmv/vec_dtypes.cuh (new file, 1324 lines; diff suppressed because the file is too large)
							
								
								
									
csrc/punica/punica_ops.cc (new file, 563 lines)
							| @ -0,0 +1,563 @@ | |||||||
|  | #include <cuda_bf16.h> | ||||||
|  | #include <cuda_fp16.h> | ||||||
|  | #include <torch/extension.h> | ||||||
|  |  | ||||||
|  | #include <cstdint> | ||||||
|  |  | ||||||
|  | #include "bgmv/bgmv_config.h" | ||||||
|  |  | ||||||
|  | namespace { | ||||||
|  |  | ||||||
|  | //====== utils ====== | ||||||
|  |  | ||||||
|  | inline void check_shape(const torch::Tensor &a, const torch::Tensor &b, | ||||||
|  |                         const char *a_name, const char *b_name) { | ||||||
|  |   TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", | ||||||
|  |               a.dim(), " vs ", b.dim()); | ||||||
|  |   for (int i = 0; i < a.dim(); ++i) { | ||||||
|  |     TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, | ||||||
|  |                 ".size(", i, ")"); | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
|  | inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) { | ||||||
|  |   return (uint32_t(a) << 16) | uint32_t(b); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | #define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") | ||||||
|  |  | ||||||
|  | #define CHECK_CONTIGUOUS(x)                                                    \ | ||||||
|  |   TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") | ||||||
|  |  | ||||||
|  | #define CHECK_INPUT(x)                                                         \ | ||||||
|  |   CHECK_CUDA(x);                                                               \ | ||||||
|  |   CHECK_CONTIGUOUS(x) | ||||||
|  |  | ||||||
|  | #define CHECK_DIM(d, x)                                                        \ | ||||||
|  |   TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") | ||||||
|  |  | ||||||
|  | #define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b) | ||||||
|  |  | ||||||
|  | #define CHECK_EQ(a, b)                                                         \ | ||||||
|  |   TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) | ||||||
|  |  | ||||||
|  | //====== bgmv ====== | ||||||
|  |  | ||||||
|  | template <typename in_T, typename out_T, typename W_T> | ||||||
|  | inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, | ||||||
|  |                                const int64_t *lora_indices, | ||||||
|  |                                uint16_t in_features, uint16_t out_features, | ||||||
|  |                                int64_t y_offset, int64_t full_y_size, | ||||||
|  |                                int64_t batch_size, int64_t num_layers, | ||||||
|  |                                int64_t layer_idx, float scale) { | ||||||
|  |   switch (pack_u16(in_features, out_features)) { | ||||||
|  | #define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out)                   \ | ||||||
|  |   case pack_u16(feat_in, feat_out):                                            \ | ||||||
|  |     bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset,            \ | ||||||
|  |                                    full_y_size, batch_size, num_layers,        \ | ||||||
|  |                                    layer_idx, scale);                          \ | ||||||
|  |     break; | ||||||
|  | #define CASE(_in_T, _out_T, _W_T, narrow, wide)                                \ | ||||||
|  |   CASE_ONESIDE(in_T, out_T, W_T, narrow, wide)                                 \ | ||||||
|  |   CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) | ||||||
|  |  | ||||||
|  |     FOR_BGMV_WIDE_NARROW(CASE, _, _, _) | ||||||
|  | #undef CASE | ||||||
|  | #undef CASE_ONESIDE | ||||||
|  |   default: | ||||||
|  |     return false; | ||||||
|  |   } | ||||||
|  |  | ||||||
|  |   return true; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w, | ||||||
|  |                    torch::Tensor indicies, int64_t layer_idx, float scale) { | ||||||
|  |   CHECK_INPUT(y); | ||||||
|  |   CHECK_INPUT(x); | ||||||
|  |   CHECK_INPUT(w); | ||||||
|  |   CHECK_INPUT(indicies); | ||||||
|  |  | ||||||
|  |   CHECK_DIM(2, y); | ||||||
|  |   CHECK_DIM(2, x); | ||||||
|  |   CHECK_DIM(4, w); | ||||||
|  |   CHECK_DIM(1, indicies); | ||||||
|  |  | ||||||
|  |   int64_t B = x.size(0); | ||||||
|  |   int64_t h_in = x.size(1); | ||||||
|  |   int64_t h_out = y.size(1); | ||||||
|  |   int64_t num_layers = w.size(1); | ||||||
|  |   CHECK_EQ(w.size(3), h_in); | ||||||
|  |   CHECK_EQ(w.size(2), h_out); | ||||||
|  |   CHECK_EQ(indicies.size(0), x.size(0)); | ||||||
|  |   CHECK_EQ(y.size(0), x.size(0)); | ||||||
|  |   bool ok = false; | ||||||
|  |   if (h_in < 65536 && h_out < 65536) { | ||||||
|  |     // TODO: See if we can get rid of this massive nested switch | ||||||
|  |     switch (x.scalar_type()) { | ||||||
|  |     case at::ScalarType::Half: | ||||||
|  |       switch (y.scalar_type()) { | ||||||
|  |       case at::ScalarType::Half: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::BFloat16: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::Float: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       break; | ||||||
|  |     case at::ScalarType::BFloat16: | ||||||
|  |       switch (y.scalar_type()) { | ||||||
|  |       case at::ScalarType::Half: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::BFloat16: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::Float: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       break; | ||||||
|  |     case at::ScalarType::Float: | ||||||
|  |       switch (y.scalar_type()) { | ||||||
|  |       case at::ScalarType::Half: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::BFloat16: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::Float: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, 0, | ||||||
|  |                                   h_out, B, num_layers, layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       break; | ||||||
|  |     default: | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, | ||||||
|  |               " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w, | ||||||
|  |                              torch::Tensor indicies, int64_t layer_idx, | ||||||
|  |                              float scale, int64_t h_in, int64_t h_out, | ||||||
|  |                              int64_t y_offset) { | ||||||
|  |   CHECK_INPUT(y); | ||||||
|  |   CHECK_INPUT(x); | ||||||
|  |   CHECK_INPUT(w); | ||||||
|  |   CHECK_INPUT(indicies); | ||||||
|  |  | ||||||
|  |   CHECK_DIM(2, y); | ||||||
|  |   CHECK_DIM(2, x); | ||||||
|  |   CHECK_DIM(4, w); | ||||||
|  |   CHECK_DIM(1, indicies); | ||||||
|  |  | ||||||
|  |   int64_t B = x.size(0); | ||||||
|  |   int64_t num_layers = w.size(1); | ||||||
|  |   int64_t full_y_size = y.size(1); | ||||||
|  |   CHECK_EQ(w.size(3), h_in); | ||||||
|  |   CHECK_EQ(w.size(2), h_out); | ||||||
|  |   CHECK_EQ(indicies.size(0), x.size(0)); | ||||||
|  |   CHECK_EQ(y.size(0), x.size(0)); | ||||||
|  |   bool ok = false; | ||||||
|  |   if (h_in < 65536 && h_out < 65536) { | ||||||
|  |     // TODO: See if we can get rid of this massive nested switch | ||||||
|  |     switch (x.scalar_type()) { | ||||||
|  |     case at::ScalarType::Half: | ||||||
|  |       switch (y.scalar_type()) { | ||||||
|  |       case at::ScalarType::Half: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::BFloat16: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::Float: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       break; | ||||||
|  |     case at::ScalarType::BFloat16: | ||||||
|  |       switch (y.scalar_type()) { | ||||||
|  |       case at::ScalarType::Half: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::BFloat16: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::Float: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       break; | ||||||
|  |     case at::ScalarType::Float: | ||||||
|  |       switch (y.scalar_type()) { | ||||||
|  |       case at::ScalarType::Half: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::BFloat16: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       case at::ScalarType::Float: | ||||||
|  |         switch (w.scalar_type()) { | ||||||
|  |         case at::ScalarType::Half: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_half *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         case at::ScalarType::BFloat16: | ||||||
|  |           ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()), | ||||||
|  |                                   static_cast<float *>(x.data_ptr()), | ||||||
|  |                                   static_cast<nv_bfloat16 *>(w.data_ptr()), | ||||||
|  |                                   indicies.data_ptr<int64_t>(), h_in, h_out, | ||||||
|  |                                   y_offset, full_y_size, B, num_layers, | ||||||
|  |                                   layer_idx, scale); | ||||||
|  |           break; | ||||||
|  |         default: | ||||||
|  |           break; | ||||||
|  |         } | ||||||
|  |         break; | ||||||
|  |       default: | ||||||
|  |         break; | ||||||
|  |       } | ||||||
|  |       break; | ||||||
|  |     default: | ||||||
|  |       break; | ||||||
|  |     } | ||||||
|  |   } | ||||||
|  |   TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out, | ||||||
|  |               " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type()); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace | ||||||
|  |  | ||||||
|  | //====== pybind ====== | ||||||
|  |  | ||||||
|  | #define DEFINE_pybind(name) m.def(#name, &name, #name); | ||||||
|  |  | ||||||
|  | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | ||||||
|  |   m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv"); | ||||||
|  |   m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level, | ||||||
|  |         "dispatch_bgmv_low_level"); | ||||||
|  | } | ||||||
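The shape checks in dispatch_bgmv imply that x is [B, h_in], y is [B, h_out], indicies holds one int64 LoRA slot per row, and w is [num_slots, num_layers, h_out, h_in] (the leading dimension is unchecked here and presumably indexes LoRA adapters). A hedged Python sketch of a call, assuming the extension is importable as vllm._punica_C and that the chosen (h_in, h_out) pair is among the feat_in/feat_out combinations instantiated by FOR_BGMV_WIDE_NARROW in bgmv_config.h, which is not shown in this diff:

    import torch
    from vllm import _punica_C as punica_kernels  # assumed module name

    B, h_in, h_out = 4, 16, 4096              # pair must be an instantiated combination
    num_slots, num_layers, layer_idx = 2, 32, 0
    x = torch.randn(B, h_in, dtype=torch.float16, device="cuda")
    y = torch.zeros(B, h_out, dtype=torch.float16, device="cuda")
    w = torch.randn(num_slots, num_layers, h_out, h_in, dtype=torch.float16, device="cuda")
    indicies = torch.zeros(B, dtype=torch.int64, device="cuda")  # LoRA slot per row
    # Arguments are positional, matching dispatch_bgmv(y, x, w, indicies, layer_idx, scale).
    punica_kernels.dispatch_bgmv(y, x, w, indicies, layer_idx, 1.0)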
| @ -51,10 +51,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |||||||
| #ifndef USE_ROCM | #ifndef USE_ROCM | ||||||
|   // Quantization ops |   // Quantization ops | ||||||
|   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); |   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); | ||||||
|  |   ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); | ||||||
| #endif | #endif | ||||||
|   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); |   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); | ||||||
|   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); |   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); | ||||||
|   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); |   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); | ||||||
|  |   ops.def( | ||||||
|  |     "moe_align_block_size", | ||||||
|  |     &moe_align_block_size, | ||||||
|  |     "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."); | ||||||
|  |  | ||||||
|   // Cache ops |   // Cache ops | ||||||
|   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); |   pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops"); | ||||||
| @ -74,6 +79,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |||||||
|     "gather_cached_kv", |     "gather_cached_kv", | ||||||
|     &gather_cached_kv, |     &gather_cached_kv, | ||||||
|     "Gather key and value from the cache into contiguous QKV tensors"); |     "Gather key and value from the cache into contiguous QKV tensors"); | ||||||
|  |   cache_ops.def( | ||||||
|  |     "convert_fp8_e5m2", | ||||||
|  |     &convert_fp8_e5m2, | ||||||
|  |     "Convert the key and value cache to fp8_e5m2 data type"); | ||||||
|  |  | ||||||
|   // Cuda utils |   // Cuda utils | ||||||
|   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); |   pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils"); | ||||||
| @ -81,4 +90,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { | |||||||
|     "get_device_attribute", |     "get_device_attribute", | ||||||
|     &get_device_attribute, |     &get_device_attribute, | ||||||
|     "Gets the specified device attribute."); |     "Gets the specified device attribute."); | ||||||
|  |  | ||||||
|  |   cuda_utils.def( | ||||||
|  |     "get_max_shared_memory_per_block_device_attribute", | ||||||
|  |     &get_max_shared_memory_per_block_device_attribute, | ||||||
|  |     "Gets the maximum shared memory per block device attribute."); | ||||||
|  |  | ||||||
|  | #ifndef USE_ROCM | ||||||
|  |   // Custom all-reduce kernels | ||||||
|  |   pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce"); | ||||||
|  |   custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar"); | ||||||
|  |   custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar"); | ||||||
|  |   custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg"); | ||||||
|  |   custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg"); | ||||||
|  |   custom_ar.def("dispose", &dispose, "dispose"); | ||||||
|  |   custom_ar.def("meta_size", &meta_size, "meta_size"); | ||||||
|  |   custom_ar.def("register_buffer", ®ister_buffer, "register_buffer"); | ||||||
|  |   custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta, | ||||||
|  |                 "get_graph_buffer_ipc_meta"); | ||||||
|  |   custom_ar.def("register_graph_buffers", ®ister_graph_buffers, | ||||||
|  |                 "register_graph_buffers"); | ||||||
|  | #endif | ||||||
|  |  | ||||||
| } | } | ||||||
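A quick smoke check that the newly registered bindings are visible after rebuilding the extension; the vllm._C import path and the ops submodule name are assumptions, while cache_ops, cuda_utils and custom_ar match the def_submodule calls above:

    from vllm import _C  # assumed extension module name

    assert hasattr(_C.ops, "awq_dequantize")
    assert hasattr(_C.ops, "moe_align_block_size")
    assert hasattr(_C.cache_ops, "convert_fp8_e5m2")
    assert hasattr(_C.cuda_utils, "get_max_shared_memory_per_block_device_attribute")
    assert hasattr(_C.custom_ar, "init_custom_ar")  # CUDA builds only (not ROCm)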
|  | |||||||
| @ -493,9 +493,117 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in | |||||||
| #endif | #endif | ||||||
| } | } | ||||||
|  |  | ||||||
|  | __global__ void __launch_bounds__(64) dequantize_weights( | ||||||
|  |     int* __restrict__ B, | ||||||
|  |     half* __restrict__ scaling_factors, | ||||||
|  |     int* __restrict__ zeros, | ||||||
|  |     half* __restrict__ C, | ||||||
|  |     int G | ||||||
|  | ) | ||||||
|  | { | ||||||
|  |   int j_factors1 = 4; | ||||||
|  |   int row_stride2 = 4; | ||||||
|  |   int split_k_iters = 1; | ||||||
|  |   static constexpr uint32_t ZERO = 0x0; | ||||||
|  |   half B_shared[32 * (128 + 8)]; | ||||||
|  |  | ||||||
|  |   half* B_shared_ptr2 = B_shared; | ||||||
|  |  | ||||||
|  |   half B_shared_warp[32]; | ||||||
|  |   int OC = 512; | ||||||
|  |  | ||||||
|  |   int N = blockDim.x * gridDim.x;  // 2 | ||||||
|  |   int col = (blockIdx.x * blockDim.x + threadIdx.x); | ||||||
|  |   int row = blockIdx.y * blockDim.y + threadIdx.y; | ||||||
|  |   int index1 = 8 * col + 8 * row * N; | ||||||
|  |   half* C_ptr2 = C + index1; | ||||||
|  |  | ||||||
|  |   int index2 = col + row * N; | ||||||
|  |   int* B_ptr2 = B + index2; | ||||||
|  |  | ||||||
|  |   int index3 = col + (int)(row / G) * N; | ||||||
|  |   int* zeros_ptr2 = zeros + index3; | ||||||
|  |   int index4 = 8 * col + (int)(row / G) * N * 8; | ||||||
|  |   half* scaling_factors_ptr2 = scaling_factors + index4; | ||||||
|  |  | ||||||
|  |   uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2); | ||||||
|  |   uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded); | ||||||
|  |   uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2); | ||||||
|  |   int j = 0; | ||||||
|  |  | ||||||
|  |   uint32_t B_loaded = *(uint32_t*)(B_ptr2 + j); | ||||||
|  |   uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded); | ||||||
|  |   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x)); | ||||||
|  |   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO)); | ||||||
|  |   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y)); | ||||||
|  |   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_scale.y), "r"(ZERO)); | ||||||
|  |   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_zero.z)); | ||||||
|  |   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO)); | ||||||
|  |   asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w)); | ||||||
|  |   asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO)); | ||||||
|  |  | ||||||
|  |   *(uint4*)(B_shared_ptr2 + j) = B_loaded_fp16; | ||||||
|  |  | ||||||
|  |   for (int i=0; i<8; ++i) { | ||||||
|  |     *(C_ptr2 + i) = B_shared[i]; | ||||||
|  |   } | ||||||
|  | } | ||||||
|  |  | ||||||
| } // namespace awq | } // namespace awq | ||||||
| } // namespace vllm | } // namespace vllm | ||||||
|  |  | ||||||
|  | torch::Tensor awq_dequantize( | ||||||
|  |     torch::Tensor _kernel, | ||||||
|  |     torch::Tensor _scaling_factors, | ||||||
|  |     torch::Tensor _zeros, | ||||||
|  |     int split_k_iters, | ||||||
|  |     int thx, | ||||||
|  |     int thy) | ||||||
|  | { | ||||||
|  |     int in_c = _kernel.size(0); | ||||||
|  |     int qout_c = _kernel.size(1); | ||||||
|  |     int out_c = qout_c * 8; | ||||||
|  |     int G = in_c / _scaling_factors.size(0); | ||||||
|  |  | ||||||
|  |     int x_thread = thx; | ||||||
|  |     int y_thread = thy; | ||||||
|  |  | ||||||
|  |     int x_blocks = 1; | ||||||
|  |     int y_blocks = 1; | ||||||
|  |     if (thx==0) { | ||||||
|  |       x_thread = qout_c; | ||||||
|  |     } | ||||||
|  |     if (thy==0) { | ||||||
|  |       y_thread = in_c; | ||||||
|  |     } | ||||||
|  |     if (thx==0 && thy==0) { | ||||||
|  |       x_thread = 8; | ||||||
|  |       y_thread = 8; | ||||||
|  |       x_blocks = (int)(qout_c / 8); | ||||||
|  |       y_blocks = (int)(in_c / 8); | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors)); | ||||||
|  |  | ||||||
|  |     auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device()); | ||||||
|  |     at::Tensor _de_kernel = torch::empty({in_c, out_c}, options); | ||||||
|  |  | ||||||
|  |     auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>()); | ||||||
|  |     auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>()); | ||||||
|  |     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>()); | ||||||
|  |     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>()); | ||||||
|  |  | ||||||
|  |     dim3 num_blocks(x_blocks, y_blocks); | ||||||
|  |     dim3 threads_per_block(x_thread, y_thread); | ||||||
|  |  | ||||||
|  |     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |     vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>( | ||||||
|  |         kernel, scaling_factors, zeros, de_kernel, G); | ||||||
|  |  | ||||||
|  |     return _de_kernel; | ||||||
|  | } | ||||||
|  |  | ||||||
| // in_feats: M, IC [float16] | // in_feats: M, IC [float16] | ||||||
| // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] | // kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b] | ||||||
| // scaling_factors: IC // G, OC [float16] | // scaling_factors: IC // G, OC [float16] | ||||||
|  | |||||||
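awq_dequantize above unpacks 4-bit AWQ weights back into an fp16 tensor of shape [in_c, qout_c * 8]; passing thx = thy = 0 selects the default 8x8 launch configuration, and split_k_iters is unused on this path. A rough usage sketch, where the vllm._C.ops import path and the group size of 128 are assumptions for illustration:

    import torch
    from vllm._C import ops  # assumed binding location of awq_dequantize

    in_c, out_c, group_size = 4096, 4096, 128
    qweight = torch.randint(0, 2**31 - 1, (in_c, out_c // 8), dtype=torch.int32, device="cuda")
    scales = torch.randn(in_c // group_size, out_c, dtype=torch.float16, device="cuda")
    qzeros = torch.randint(0, 2**31 - 1, (in_c // group_size, out_c // 8), dtype=torch.int32, device="cuda")
    weight_fp16 = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
    print(weight_fp16.shape)  # torch.Size([4096, 4096])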
							
								
								
									
csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh (new file, 278 lines)
							| @ -0,0 +1,278 @@ | |||||||
|  | #pragma once | ||||||
|  |  | ||||||
|  | #include <assert.h> | ||||||
|  | #include <stdint.h> | ||||||
|  | #include <float.h> | ||||||
|  | #include <type_traits> | ||||||
|  | #include "../../attention/attention_dtypes.h" | ||||||
|  | #include "../../attention/dtype_float32.cuh" | ||||||
|  | #include "../../attention/dtype_float16.cuh" | ||||||
|  | #include "../../attention/dtype_bfloat16.cuh" | ||||||
|  |  | ||||||
|  | namespace vllm { | ||||||
|  | #ifdef ENABLE_FP8_E5M2 | ||||||
|  | namespace fp8_e5m2_unscaled { | ||||||
|  |  | ||||||
|  | template<typename Tout, typename Tin> | ||||||
|  | __inline__ __device__ Tout vec_conversion(const Tin& x) | ||||||
|  | { | ||||||
|  |     return x; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8 -> half | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a) | ||||||
|  | { | ||||||
|  |     __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); | ||||||
|  |     return res.x; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x2 -> half2 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a) | ||||||
|  | { | ||||||
|  |     union { | ||||||
|  |         uint16_t u16[2]; | ||||||
|  |         uint32_t u32; | ||||||
|  |     } tmp; | ||||||
|  |     __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2); | ||||||
|  |     tmp.u16[0] = res.x; | ||||||
|  |     tmp.u16[1] = res.y; | ||||||
|  |     return tmp.u32; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x4 -> half2x2 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a) | ||||||
|  | { | ||||||
|  |     union { | ||||||
|  |         uint2    u32x2; | ||||||
|  |         uint32_t u32[2]; | ||||||
|  |     } tmp; | ||||||
|  |     tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a); | ||||||
|  |     tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U)); | ||||||
|  |     return tmp.u32x2; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x8 -> half2x4 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a) | ||||||
|  | { | ||||||
|  |     union { | ||||||
|  |         uint4 u64x2; | ||||||
|  |         uint2 u64[2]; | ||||||
|  |     } tmp; | ||||||
|  |     tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x); | ||||||
|  |     tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y); | ||||||
|  |     return tmp.u64x2; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8 -> __nv_bfloat16 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) | ||||||
|  | { | ||||||
|  |     // Note there is no direct convert function from fp8 to bf16. | ||||||
|  |     // fp8 -> half | ||||||
|  |     __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2); | ||||||
|  |     // half -> float -> bf16 | ||||||
|  |     float tmp = half_to_float(res.x); | ||||||
|  |     return __float2bfloat16(tmp); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x2 -> __nv_bfloat162 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) | ||||||
|  | { | ||||||
|  |     __nv_bfloat162 res; | ||||||
|  |     res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); | ||||||
|  |     res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); | ||||||
|  |     return res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x4 -> bf16_4_t | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a) | ||||||
|  | { | ||||||
|  |     bf16_4_t res; | ||||||
|  |     res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); | ||||||
|  |     res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); | ||||||
|  |     return res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x8 -> bf16_8_t | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a) | ||||||
|  | { | ||||||
|  |     bf16_4_t tmp1, tmp2; | ||||||
|  |     tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x); | ||||||
|  |     tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y); | ||||||
|  |     bf16_8_t res; | ||||||
|  |     res.x = tmp1.x; | ||||||
|  |     res.y = tmp1.y; | ||||||
|  |     res.z = tmp2.x; | ||||||
|  |     res.w = tmp2.y; | ||||||
|  |     return res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8 -> float | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a) | ||||||
|  | { | ||||||
|  |     // fp8 -> half | ||||||
|  |     uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a); | ||||||
|  |     // half -> float | ||||||
|  |     return half_to_float(tmp); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x2 -> float2 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a) | ||||||
|  | { | ||||||
|  |     // fp8x2 -> half2 | ||||||
|  |     uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a); | ||||||
|  |     // half2 -> float2 | ||||||
|  |     return half2_to_float2(tmp); | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x4 -> float4 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a) | ||||||
|  | { | ||||||
|  |     Float4_ res; | ||||||
|  |     res.x = vec_conversion<float2, uint16_t>((uint16_t)a); | ||||||
|  |     res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U)); | ||||||
|  |     return res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x8 -> float8 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a) | ||||||
|  | { | ||||||
|  |     Float4_ tmp1, tmp2; | ||||||
|  |     tmp1 = vec_conversion<Float4_, uint32_t>(a.x); | ||||||
|  |     tmp2 = vec_conversion<Float4_, uint32_t>(a.y); | ||||||
|  |     Float8_ res; | ||||||
|  |     res.x = tmp1.x; | ||||||
|  |     res.y = tmp1.y; | ||||||
|  |     res.z = tmp2.x; | ||||||
|  |     res.w = tmp2.y; | ||||||
|  |     return res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | // half -> fp8 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a) | ||||||
|  | { | ||||||
|  |     __half_raw tmp; | ||||||
|  |     tmp.x = a; | ||||||
|  |     __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2); | ||||||
|  |     return (uint8_t)res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // bf16 -> fp8 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a) | ||||||
|  | { | ||||||
|  | #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 | ||||||
|  |     assert(false); | ||||||
|  | #else | ||||||
|  |     __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2); | ||||||
|  |     return (uint8_t)res; | ||||||
|  | #endif | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // float -> fp8 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a) | ||||||
|  | { | ||||||
|  |     __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2); | ||||||
|  |     return (uint8_t)res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // fp8x4 -> float4 | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a) | ||||||
|  | { | ||||||
|  |     Float4_ tmp = vec_conversion<Float4_, uint32_t>(a); | ||||||
|  |     float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); | ||||||
|  |     return res; | ||||||
|  | } | ||||||
|  |  | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a) | ||||||
|  | { | ||||||
|  |     union { | ||||||
|  |         half2    float16; | ||||||
|  |         uint32_t uint32; | ||||||
|  |     }; | ||||||
|  |  | ||||||
|  |     float16 = __float22half2_rn(a); | ||||||
|  |     return uint32; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a) | ||||||
|  | { | ||||||
|  |     uint2  b; | ||||||
|  |     float2 val; | ||||||
|  |     val.x = a.x.x; | ||||||
|  |     val.y = a.x.y; | ||||||
|  |     b.x   = vec_conversion<uint32_t, float2>(val); | ||||||
|  |  | ||||||
|  |     val.x = a.y.x; | ||||||
|  |     val.y = a.y.y; | ||||||
|  |     b.y   = vec_conversion<uint32_t, float2>(val); | ||||||
|  |  | ||||||
|  |     return b; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a) | ||||||
|  | { | ||||||
|  |     float4 b; | ||||||
|  |     b.x = a.x.x; | ||||||
|  |     b.y = a.x.y; | ||||||
|  |     b.z = a.y.x; | ||||||
|  |     b.w = a.y.y; | ||||||
|  |     return b; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a) | ||||||
|  | { | ||||||
|  |     uint4 b; | ||||||
|  |     b.x = vec_conversion<uint32_t, float2>(a.x); | ||||||
|  |     b.y = vec_conversion<uint32_t, float2>(a.y); | ||||||
|  |     b.z = vec_conversion<uint32_t, float2>(a.z); | ||||||
|  |     b.w = vec_conversion<uint32_t, float2>(a.w); | ||||||
|  |     return b; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) { | ||||||
|  |     __nv_bfloat162 b; | ||||||
|  |     from_float(b, a); | ||||||
|  |     return b; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) { | ||||||
|  |     bf16_4_t b; | ||||||
|  |     from_float(b, a); | ||||||
|  |     return b; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | template<> | ||||||
|  | __inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) { | ||||||
|  |     bf16_8_t b; | ||||||
|  |     from_float(b, a); | ||||||
|  |     return b; | ||||||
|  | } | ||||||
|  |  | ||||||
|  | } // namespace fp8_e5m2_unscaled | ||||||
|  | #endif // ENABLE_FP8_E5M2 | ||||||
|  | } // namespace vllm | ||||||
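The unscaled conversions above amount to a plain cast between E5M2 (5 exponent bits, 2 mantissa bits) and wider floating-point types, so storing the KV cache in fp8 trades mantissa precision for memory. A host-side illustration of that precision loss using PyTorch's float8 dtype (torch >= 2.1 assumed) rather than the device intrinsics above:

    import torch

    x = torch.tensor([0.1, 1.5, 3.14159], dtype=torch.float16)
    roundtrip = x.to(torch.float8_e5m2).to(torch.float16)
    print(roundtrip)  # only two mantissa bits survive, e.g. 3.14159 becomes 3.0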
| @ -28,6 +28,7 @@ namespace gptq { | |||||||
| #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) | #define DIVIDE(x, size) (((x) + (size) - 1) / (size)) | ||||||
|  |  | ||||||
| #if defined(USE_ROCM) | #if defined(USE_ROCM) | ||||||
|  | #include <hipblas/hipblas.h> | ||||||
| __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle, | __host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle, | ||||||
|                                                                hipblasOperation_t transA, |                                                                hipblasOperation_t transA, | ||||||
|                                                                hipblasOperation_t transB, |                                                                hipblasOperation_t transB, | ||||||
| @ -286,7 +287,8 @@ void gemm_half_q_half_cuda_part | |||||||
|  |  | ||||||
|     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); |     fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count); | ||||||
|  |  | ||||||
|     kernel<<<gridDim, blockDim>>> |     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |     kernel<<<gridDim, blockDim, 0, stream>>> | ||||||
|     ( |     ( | ||||||
|         a, |         a, | ||||||
|         b_q_weight, |         b_q_weight, | ||||||
| @ -433,7 +435,8 @@ void reconstruct_exllama | |||||||
|     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); |     gridDim.y = DIVIDE(height, BLOCK_KN_SIZE); | ||||||
|     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); |     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); | ||||||
|  |  | ||||||
|     reconstruct_exllama_kernel<<<gridDim, blockDim>>> |     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |     reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>> | ||||||
|     ( |     ( | ||||||
|         b_q_weight, |         b_q_weight, | ||||||
|         b_q_perm, |         b_q_perm, | ||||||
| @ -520,12 +523,21 @@ __global__ void gemm_half_q_half_alt_kernel( | |||||||
|             zeros_tmp[tmp_k] = zero; |             zeros_tmp[tmp_k] = zero; | ||||||
|         } |         } | ||||||
|         for (int m = 0; m < b_end; m++) { |         for (int m = 0; m < b_end; m++) { | ||||||
|  | #ifndef USE_ROCM | ||||||
|             res2 = {}; |             res2 = {}; | ||||||
|  | #else | ||||||
|  |             res2.x = __half_as_ushort(__float2half(0)); | ||||||
|  |             res2.y = __half_as_ushort(__float2half(0)); | ||||||
|  | #endif | ||||||
|             res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); |             res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2); | ||||||
|             res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); |             res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2); | ||||||
|             res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); |             res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2); | ||||||
|             res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); |             res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2); | ||||||
|  | #ifndef USE_ROCM | ||||||
|             res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); |             res[m] = __hadd(res[m], __hadd(res2.x, res2.y)); | ||||||
|  | #else | ||||||
|  |             res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y))); | ||||||
|  | #endif | ||||||
|         } |         } | ||||||
|         i += width; |         i += width; | ||||||
|         k += 4; |         k += 4; | ||||||
| @ -557,7 +569,8 @@ void gemm_half_q_half_alt | |||||||
|     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); |     gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX); | ||||||
|     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); |     gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE); | ||||||
|  |  | ||||||
|     gemm_half_q_half_alt_kernel<<<gridDim, blockDim>>> |     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |     gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>> | ||||||
|     ( |     ( | ||||||
|         (const half2*) a, |         (const half2*) a, | ||||||
|         b_q_weight, |         b_q_weight, | ||||||
| @ -629,7 +642,8 @@ void reconstruct_gptq | |||||||
|     blockDim.y = 1; |     blockDim.y = 1; | ||||||
|     gridDim.y = DIVIDE(height, 8); |     gridDim.y = DIVIDE(height, 8); | ||||||
|     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); |     gridDim.x = DIVIDE(width, BLOCK_KN_SIZE); | ||||||
|     reconstruct_gptq_kernel<<<gridDim, blockDim>>> |     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |     reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>> | ||||||
|     ( |     ( | ||||||
|         b_q_weight, |         b_q_weight, | ||||||
|         b_gptq_scales, |         b_gptq_scales, | ||||||
| @ -784,7 +798,8 @@ void shuffle_exllama_weight | |||||||
|         gridDim.x = DIVIDE(width, THREADS_X); |         gridDim.x = DIVIDE(width, THREADS_X); | ||||||
|         gridDim.y = height / 8; |         gridDim.y = height / 8; | ||||||
|  |  | ||||||
|         make_sequential_kernel<<<gridDim, blockDim>>> |         const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |         make_sequential_kernel<<<gridDim, blockDim, 0, stream>>> | ||||||
|         ( |         ( | ||||||
|             q_weight, |             q_weight, | ||||||
|             new_qweight, |             new_qweight, | ||||||
| @ -803,7 +818,8 @@ void shuffle_exllama_weight | |||||||
|     blockDim.y = 1; |     blockDim.y = 1; | ||||||
|     gridDim.x = DIVIDE(width, THREADS_X); |     gridDim.x = DIVIDE(width, THREADS_X); | ||||||
|     gridDim.y = 1; |     gridDim.y = 1; | ||||||
|     shuffle_kernel<<<gridDim, blockDim>>>(q_weight, height, width); |     const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |     shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width); | ||||||
| } | } | ||||||
|  |  | ||||||
| }  // namespace gptq | }  // namespace gptq | ||||||
|  | |||||||
| @ -7,6 +7,7 @@ | |||||||
| // half-tensor | // half-tensor | ||||||
| #include <c10/cuda/CUDAStream.h> | #include <c10/cuda/CUDAStream.h> | ||||||
| #include <ATen/cuda/CUDATensorMethods.cuh> | #include <ATen/cuda/CUDATensorMethods.cuh> | ||||||
|  | #include <c10/cuda/CUDAGuard.h> | ||||||
|  |  | ||||||
| #define BLOCKWIDTH 128 | #define BLOCKWIDTH 128 | ||||||
| #define BLOCKHEIGHT4 16 | #define BLOCKHEIGHT4 16 | ||||||
| @ -200,7 +201,9 @@ void squeezellm_gemm( | |||||||
|   ); |   ); | ||||||
|   dim3 threads(BLOCKWIDTH); |   dim3 threads(BLOCKWIDTH); | ||||||
|  |  | ||||||
|   vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>( |   const at::cuda::OptionalCUDAGuard device_guard(device_of(vec)); | ||||||
|  |   const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||||||
|  |   vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>( | ||||||
| #ifndef USE_ROCM | #ifndef USE_ROCM | ||||||
|     (half2*) vec.data<at::Half>(), |     (half2*) vec.data<at::Half>(), | ||||||
| #else | #else | ||||||
|  | |||||||
| @ -9,11 +9,15 @@ | |||||||
| # If extensions (or modules to document with autodoc) are in another directory, | # If extensions (or modules to document with autodoc) are in another directory, | ||||||
| # add these directories to sys.path here. If the directory is relative to the | # add these directories to sys.path here. If the directory is relative to the | ||||||
| # documentation root, use os.path.abspath to make it absolute, like shown here. | # documentation root, use os.path.abspath to make it absolute, like shown here. | ||||||
| # |  | ||||||
| # import os |  | ||||||
| # import sys |  | ||||||
| # sys.path.insert(0, os.path.abspath('.')) |  | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import sys | ||||||
|  | from sphinx.ext import autodoc | ||||||
|  | import logging | ||||||
|  |  | ||||||
|  | sys.path.insert(0, os.path.abspath(os.path.join('..', '..'))) | ||||||
|  |  | ||||||
|  | logger = logging.getLogger(__name__) | ||||||
|  |  | ||||||
| # -- Project information ----------------------------------------------------- | # -- Project information ----------------------------------------------------- | ||||||
|  |  | ||||||
| @ -21,7 +25,6 @@ project = 'vLLM' | |||||||
| copyright = '2023, vLLM Team' | copyright = '2023, vLLM Team' | ||||||
| author = 'the vLLM Team' | author = 'the vLLM Team' | ||||||
|  |  | ||||||
|  |  | ||||||
| # -- General configuration --------------------------------------------------- | # -- General configuration --------------------------------------------------- | ||||||
|  |  | ||||||
| # Add any Sphinx extension module names here, as strings. They can be | # Add any Sphinx extension module names here, as strings. They can be | ||||||
| @ -32,6 +35,8 @@ extensions = [ | |||||||
|     "sphinx.ext.viewcode", |     "sphinx.ext.viewcode", | ||||||
|     "sphinx.ext.intersphinx", |     "sphinx.ext.intersphinx", | ||||||
|     "sphinx_copybutton", |     "sphinx_copybutton", | ||||||
|  |     "sphinx.ext.autodoc", | ||||||
|  |     "sphinx.ext.autosummary", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| # Add any paths that contain templates here, relative to this directory. | # Add any paths that contain templates here, relative to this directory. | ||||||
| @ -55,7 +60,6 @@ html_title = project | |||||||
| html_theme = 'sphinx_book_theme' | html_theme = 'sphinx_book_theme' | ||||||
| html_logo = 'assets/logos/vllm-logo-text-light.png' | html_logo = 'assets/logos/vllm-logo-text-light.png' | ||||||
| html_theme_options = { | html_theme_options = { | ||||||
|     'logo_only': True, |  | ||||||
|     'path_to_docs': 'docs/source', |     'path_to_docs': 'docs/source', | ||||||
|     'repository_url': 'https://github.com/vllm-project/vllm', |     'repository_url': 'https://github.com/vllm-project/vllm', | ||||||
|     'use_repository_button': True, |     'use_repository_button': True, | ||||||
| @ -64,4 +68,29 @@ html_theme_options = { | |||||||
| # Add any paths that contain custom static files (such as style sheets) here, | # Add any paths that contain custom static files (such as style sheets) here, | ||||||
| # relative to this directory. They are copied after the builtin static files, | # relative to this directory. They are copied after the builtin static files, | ||||||
| # so a file named "default.css" will overwrite the builtin "default.css". | # so a file named "default.css" will overwrite the builtin "default.css". | ||||||
| html_static_path = ['_static'] | # html_static_path = ['_static'] | ||||||
|  |  | ||||||
|  | # Mock out external dependencies here. | ||||||
|  | autodoc_mock_imports = [ | ||||||
|  |     "torch", "transformers", "psutil", "aioprometheus", "sentencepiece", | ||||||
|  |     "vllm.cuda_utils", "vllm._C" | ||||||
|  | ] | ||||||
|  |  | ||||||
|  | for mock_target in autodoc_mock_imports: | ||||||
|  |     if mock_target in sys.modules: | ||||||
|  |         logger.info( | ||||||
|  |             f"Potentially problematic mock target ({mock_target}) found; " | ||||||
|  |             "autodoc_mock_imports cannot mock modules that have already " | ||||||
|  |             "been loaded into sys.modules when the sphinx build starts.") | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class MockedClassDocumenter(autodoc.ClassDocumenter): | ||||||
|  |     """Remove note about base class when a class is derived from object.""" | ||||||
|  |  | ||||||
|  |     def add_line(self, line: str, source: str, *lineno: int) -> None: | ||||||
|  |         if line == "   Bases: :py:class:`object`": | ||||||
|  |             return | ||||||
|  |         super().add_line(line, source, *lineno) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | autodoc.ClassDocumenter = MockedClassDocumenter | ||||||
|  | |||||||
docs/source/dev/engine/async_llm_engine.rst (new file, 7 lines)
|  |  | ||||||
|  | AsyncLLMEngine | ||||||
|  | ================================= | ||||||
|  |  | ||||||
|  | .. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine | ||||||
|  |     :members: generate, abort | ||||||
|  |     :show-inheritance: | ||||||
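For readers of this diff, a minimal usage sketch of the two members documented above may help; it assumes the top-level ``AsyncEngineArgs``/``AsyncLLMEngine`` exports and the ``generate(prompt, sampling_params, request_id)`` signature of this release, so treat it as an illustration rather than part of the change:

.. code-block:: python

    import asyncio

    from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams

    # Build the async engine from engine arguments (model choice is illustrative).
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))


    async def run(prompt: str, request_id: str) -> None:
        final_output = None
        # generate() yields incremental RequestOutput updates for this request;
        # abort(request_id) could be awaited instead to cancel it early.
        async for request_output in engine.generate(
                prompt, SamplingParams(max_tokens=16), request_id):
            final_output = request_output
        print(final_output.outputs[0].text)


    asyncio.run(run("Hello, my name is", "request-0"))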
docs/source/dev/engine/engine_index.rst (new file, 13 lines)
							| @ -0,0 +1,13 @@ | |||||||
|  | vLLM Engine | ||||||
|  | ================================= | ||||||
|  |  | ||||||
|  | .. automodule:: vllm.engine | ||||||
|  | .. currentmodule:: vllm.engine | ||||||
|  |  | ||||||
|  | .. toctree:: | ||||||
|  |    :maxdepth: 2 | ||||||
|  |    :caption: Engines | ||||||
|  |  | ||||||
|  |    llm_engine | ||||||
|  |    async_llm_engine | ||||||
|  |  | ||||||
docs/source/dev/engine/llm_engine.rst (new file, 6 lines)
							| @ -0,0 +1,6 @@ | |||||||
|  | LLMEngine | ||||||
|  | ================================= | ||||||
|  |  | ||||||
|  | .. autoclass:: vllm.engine.llm_engine.LLMEngine | ||||||
|  |     :members: add_request, abort_request, step, _init_cache | ||||||
|  |     :show-inheritance: | ||||||
| @ -11,10 +11,10 @@ Requirements | |||||||
| ------------ | ------------ | ||||||
|  |  | ||||||
| * OS: Linux | * OS: Linux | ||||||
| * Python: 3.8 -- 3.11 (Verified on 3.10) | * Python: 3.8 -- 3.11 | ||||||
| * GPU: MI200s | * GPU: MI200s (gfx90a), MI300 (gfx942) | ||||||
| * Pytorch 2.0.1/2.1.1/2.2 | * Pytorch 2.0.1/2.1.1/2.2 | ||||||
| * ROCm 5.7 | * ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9) | ||||||
|  |  | ||||||
| Installation options: | Installation options: | ||||||
|  |  | ||||||
| @ -27,6 +27,8 @@ Installation options: | |||||||
| (Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image | (Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image | ||||||
| --------------------------------------------------------------------------- | --------------------------------------------------------------------------- | ||||||
|  |  | ||||||
|  | This option is for ROCm 5.7 only: | ||||||
|  |  | ||||||
| .. code-block:: console | .. code-block:: console | ||||||
|  |  | ||||||
|     $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4 |     $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4 | ||||||
| @ -50,6 +52,9 @@ Option 2: Build from source | |||||||
|  |  | ||||||
| You can build and install vLLM from source: | You can build and install vLLM from source: | ||||||
|  |  | ||||||
|  | The instructions below are for ROCm 5.7 only. | ||||||
|  | At the time of this documentation update, a PyTorch wheel for ROCm 6.0 is not yet available on the PyTorch website. | ||||||
|  |  | ||||||
| 0. Install prerequisites (skip if you are already in an environment/docker with the following installed): | 0. Install prerequisites (skip if you are already in an environment/docker with the following installed): | ||||||
|  |  | ||||||
| - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_ | - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_ | ||||||
| @ -95,6 +100,23 @@ You can build and install vLLM from source: | |||||||
|  |  | ||||||
| Build a docker image from `Dockerfile.rocm`, and launch a docker container. | Build a docker image from `Dockerfile.rocm`, and launch a docker container. | ||||||
|  |  | ||||||
|  | The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 (and later versions). It provides flexibility to customize the docker image build using the following arguments: | ||||||
|  |  | ||||||
|  | * `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1` | ||||||
|  | * `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942` | ||||||
|  | * `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5` | ||||||
|  |  | ||||||
|  | Their values can be passed in when running ``docker build`` with ``--build-arg`` options. | ||||||
|  |  | ||||||
|  | For example, to build a docker image for vLLM on ROCm 5.7, you can run: | ||||||
|  |  | ||||||
|  | .. code-block:: console | ||||||
|  |  | ||||||
|  |     $ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \ | ||||||
|  |        -f Dockerfile.rocm -t vllm-rocm .  | ||||||
|  |  | ||||||
|  | To build vLLM on ROCm 6.0, you can use the defaults: | ||||||
|  |  | ||||||
| .. code-block:: console | .. code-block:: console | ||||||
|  |  | ||||||
|     $ docker build -f Dockerfile.rocm -t vllm-rocm .  |     $ docker build -f Dockerfile.rocm -t vllm-rocm .  | ||||||
| @ -116,6 +138,7 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from | |||||||
|  |  | ||||||
| - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_ | - `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_ | ||||||
| - `Pytorch <https://pytorch.org/>`_ | - `Pytorch <https://pytorch.org/>`_ | ||||||
|  | - `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_ | ||||||
|  |  | ||||||
| 1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_ | 1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_ | ||||||
|  |  | ||||||
| @ -141,3 +164,8 @@ Alternatively, if you plan to install vLLM-ROCm on a local machine or start from | |||||||
|         $ cd vllm |         $ cd vllm | ||||||
|         $ pip install -U -r requirements-rocm.txt |         $ pip install -U -r requirements-rocm.txt | ||||||
|         $ python setup.py install # This may take 5-10 minutes. |         $ python setup.py install # This may take 5-10 minutes. | ||||||
|  |  | ||||||
|  | .. note:: | ||||||
|  |  | ||||||
|  |     - You may need to turn on the ``--enforce-eager`` flag if you experience a process hang when running the `benchmark_throughput.py` script to test your installation. | ||||||
|  |  | ||||||
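If the hang appears in your own scripts rather than in `benchmark_throughput.py`, the flag has an API-level counterpart; a minimal sketch, assuming the ``enforce_eager`` keyword of this release's ``LLM`` constructor (it disables CUDA/HIP graph capture):

.. code-block:: python

    from vllm import LLM

    # Run the model in eager mode instead of capturing HIP graphs.
    llm = LLM(model="facebook/opt-125m", enforce_eager=True)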
|  | |||||||
| @ -42,6 +42,10 @@ You can install vLLM using pip: | |||||||
|         $ pip uninstall torch -y |         $ pip uninstall torch -y | ||||||
|         $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118 |         $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118 | ||||||
|  |  | ||||||
|  |         $ # Re-install xFormers with CUDA 11.8. | ||||||
|  |         $ pip uninstall xformers -y | ||||||
|  |         $ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118 | ||||||
|  |  | ||||||
|  |  | ||||||
| .. _build_from_source: | .. _build_from_source: | ||||||
|  |  | ||||||
|  | |||||||
| @ -11,6 +11,14 @@ This guide shows how to use vLLM to: | |||||||
|  |  | ||||||
| Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide. | Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide. | ||||||
|  |  | ||||||
|  | .. note:: | ||||||
|  |  | ||||||
|  |     By default, vLLM downloads models from `HuggingFace <https://huggingface.co/>`_. If you would like to use models from `ModelScope <https://www.modelscope.cn>`_ in the following examples, please set the environment variable: | ||||||
|  |  | ||||||
|  |     .. code-block:: shell | ||||||
|  |  | ||||||
|  |         export VLLM_USE_MODELSCOPE=True | ||||||
|  |  | ||||||
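As an illustration of the note above (mirroring the ModelScope example removed from this page by the change below), the variable can also be set from Python before the model is loaded:

.. code-block:: python

    import os

    # Switch the model source to ModelScope before constructing the LLM.
    os.environ["VLLM_USE_MODELSCOPE"] = "True"

    from vllm import LLM

    llm = LLM(model="qwen/Qwen-7B-Chat", revision="v1.1.8", trust_remote_code=True)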
| Offline Batched Inference | Offline Batched Inference | ||||||
| ------------------------- | ------------------------- | ||||||
|  |  | ||||||
| @ -40,16 +48,6 @@ Initialize vLLM's engine for offline inference with the ``LLM`` class and the `O | |||||||
|  |  | ||||||
|     llm = LLM(model="facebook/opt-125m") |     llm = LLM(model="facebook/opt-125m") | ||||||
|  |  | ||||||
| Use model from www.modelscope.cn |  | ||||||
|  |  | ||||||
| .. code-block:: shell |  | ||||||
|  |  | ||||||
|     export VLLM_USE_MODELSCOPE=True |  | ||||||
|  |  | ||||||
| .. code-block:: python |  | ||||||
|  |  | ||||||
|     llm = LLM(model="qwen/Qwen-7B-Chat", revision="v1.1.8", trust_remote_code=True) |  | ||||||
|  |  | ||||||
| Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens. | Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens. | ||||||
|  |  | ||||||
| .. code-block:: python | .. code-block:: python | ||||||
| @ -65,49 +63,11 @@ Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM | |||||||
|  |  | ||||||
| The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_. | The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_. | ||||||
|  |  | ||||||
|  |  | ||||||
| API Server |  | ||||||
| ---------- |  | ||||||
|  |  | ||||||
| vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests. |  | ||||||
|  |  | ||||||
| Start the server: |  | ||||||
|  |  | ||||||
| .. code-block:: console |  | ||||||
|  |  | ||||||
|     $ python -m vllm.entrypoints.api_server |  | ||||||
|  |  | ||||||
| Use model from www.modelscope.cn |  | ||||||
|  |  | ||||||
| .. code-block:: console |  | ||||||
|  |  | ||||||
|     $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.api_server \ |  | ||||||
|     $    --model="qwen/Qwen-7B-Chat" \ |  | ||||||
|     $    --revision="v1.1.8" \ |  | ||||||
|     $    --trust-remote-code |  | ||||||
|  |  | ||||||
|  |  | ||||||
| By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model. |  | ||||||
|  |  | ||||||
| Query the model in shell: |  | ||||||
|  |  | ||||||
| .. code-block:: console |  | ||||||
|  |  | ||||||
|     $ curl http://localhost:8000/generate \ |  | ||||||
|     $     -d '{ |  | ||||||
|     $         "prompt": "San Francisco is a", |  | ||||||
|     $         "use_beam_search": true, |  | ||||||
|     $         "n": 4, |  | ||||||
|     $         "temperature": 0 |  | ||||||
|     $     }' |  | ||||||
|  |  | ||||||
| See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example. |  | ||||||
|  |  | ||||||
| OpenAI-Compatible Server | OpenAI-Compatible Server | ||||||
| ------------------------ | ------------------------ | ||||||
|  |  | ||||||
| vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API. | vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API. | ||||||
| By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints. | By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints. | ||||||
|  |  | ||||||
| Start the server: | Start the server: | ||||||
|  |  | ||||||
| @ -116,13 +76,6 @@ Start the server: | |||||||
|     $ python -m vllm.entrypoints.openai.api_server \ |     $ python -m vllm.entrypoints.openai.api_server \ | ||||||
|     $     --model facebook/opt-125m |     $     --model facebook/opt-125m | ||||||
|  |  | ||||||
| Use model from www.modelscope.cn |  | ||||||
|  |  | ||||||
| .. code-block:: console |  | ||||||
|  |  | ||||||
|     $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \ |  | ||||||
|     $     --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code |  | ||||||
|  |  | ||||||
| By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument: | By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument: | ||||||
|  |  | ||||||
| .. code-block:: console | .. code-block:: console | ||||||
| @ -137,6 +90,8 @@ This server can be queried in the same format as OpenAI API. For example, list t | |||||||
|  |  | ||||||
|     $ curl http://localhost:8000/v1/models |     $ curl http://localhost:8000/v1/models | ||||||
|  |  | ||||||
|  | You can pass in the argument ``--api-key`` or the environment variable ``VLLM_API_KEY`` to enable the server to check for an API key in the request header. | ||||||
|  |  | ||||||
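A minimal sketch of querying a server started with ``--api-key token-abc123`` (the token value is a placeholder), using the ``openai`` Python client:

.. code-block:: python

    from openai import OpenAI

    client = OpenAI(
        base_url="http://localhost:8000/v1",
        api_key="token-abc123",  # must match --api-key / VLLM_API_KEY
    )
    print(client.models.list())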
| Using OpenAI Completions API with vLLM | Using OpenAI Completions API with vLLM | ||||||
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ | ||||||
|  |  | ||||||
|  | |||||||
| @ -31,7 +31,7 @@ vLLM is fast with: | |||||||
| * Efficient management of attention key and value memory with **PagedAttention** | * Efficient management of attention key and value memory with **PagedAttention** | ||||||
| * Continuous batching of incoming requests | * Continuous batching of incoming requests | ||||||
| * Fast model execution with CUDA/HIP graph | * Fast model execution with CUDA/HIP graph | ||||||
| * Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_ | * Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_, FP8 KV Cache | ||||||
| * Optimized CUDA kernels | * Optimized CUDA kernels | ||||||
|  |  | ||||||
| vLLM is flexible and easy to use with: | vLLM is flexible and easy to use with: | ||||||
| @ -42,6 +42,8 @@ vLLM is flexible and easy to use with: | |||||||
| * Streaming outputs | * Streaming outputs | ||||||
| * OpenAI-compatible API server | * OpenAI-compatible API server | ||||||
| * Support NVIDIA GPUs and AMD GPUs | * Support NVIDIA GPUs and AMD GPUs | ||||||
|  | * (Experimental) Prefix caching support | ||||||
|  | * (Experimental) Multi-lora support | ||||||
|  |  | ||||||
| For more information, check out the following: | For more information, check out the following: | ||||||
|  |  | ||||||
| @ -85,4 +87,16 @@ Documentation | |||||||
|    :maxdepth: 1 |    :maxdepth: 1 | ||||||
|    :caption: Quantization |    :caption: Quantization | ||||||
|  |  | ||||||
|    quantization/auto_awq |    quantization/auto_awq | ||||||
|  |  | ||||||
|  | .. toctree:: | ||||||
|  |    :maxdepth: 2 | ||||||
|  |    :caption: Developer Documentation | ||||||
|  |  | ||||||
|  |    dev/engine/engine_index | ||||||
|  |  | ||||||
|  | Indices and tables | ||||||
|  | ================== | ||||||
|  |  | ||||||
|  | * :ref:`genindex` | ||||||
|  | * :ref:`modindex` | ||||||
|  | |||||||
| @ -58,11 +58,10 @@ Next, you need to rewrite the :code:`forward` methods of your model by following | |||||||
|     +    positions: torch.Tensor, |     +    positions: torch.Tensor, | ||||||
|     +    kv_caches: List[KVCache], |     +    kv_caches: List[KVCache], | ||||||
|     +    input_metadata: InputMetadata, |     +    input_metadata: InputMetadata, | ||||||
|     +    cache_events: Optional[List[torch.cuda.Event]], |     +) -> Optional[SamplerOutput]: | ||||||
|     +) -> SamplerOutput: |  | ||||||
|  |  | ||||||
| 3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. | 1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. | ||||||
| 4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture. | 2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture. | ||||||
|  |  | ||||||
| .. note:: | .. note:: | ||||||
|     Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. |     Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. | ||||||
|  | |||||||
| @ -89,9 +89,11 @@ Below, you can find an explanation of every engine argument for vLLM: | |||||||
|  |  | ||||||
|     CPU swap space size (GiB) per GPU. |     CPU swap space size (GiB) per GPU. | ||||||
|  |  | ||||||
| .. option:: --gpu-memory-utilization <percentage> | .. option:: --gpu-memory-utilization <fraction> | ||||||
|  |  | ||||||
|     The percentage of GPU memory to be used for the model executor. |     The fraction of GPU memory to be used for the model executor, which can range from 0 to 1.  | ||||||
|  |     For example, a value of 0.5 would imply 50% GPU memory utilization. | ||||||
|  |     If unspecified, will use the default value of 0.9. | ||||||
|  |  | ||||||
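The same limit can be set from the Python API; a minimal sketch, assuming the keyword argument mirrors the CLI option name:

.. code-block:: python

    from vllm import LLM

    # Use 50% of GPU memory for the model executor instead of the 0.9 default.
    llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5)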
| .. option:: --max-num-batched-tokens <tokens> | .. option:: --max-num-batched-tokens <tokens> | ||||||
|  |  | ||||||
|  | |||||||
| @ -23,6 +23,9 @@ Alongside each architecture, we include some popular models that use it. | |||||||
|   * - :code:`ChatGLMModel` |   * - :code:`ChatGLMModel` | ||||||
|     - ChatGLM |     - ChatGLM | ||||||
|     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. |     - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. | ||||||
|  |   * - :code:`DeciLMForCausalLM` | ||||||
|  |     - DeciLM | ||||||
|  |     - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. | ||||||
|   * - :code:`BloomForCausalLM` |   * - :code:`BloomForCausalLM` | ||||||
|     - BLOOM, BLOOMZ, BLOOMChat |     - BLOOM, BLOOMZ, BLOOMChat | ||||||
|     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. |     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. | ||||||
| @ -65,6 +68,12 @@ Alongside each architecture, we include some popular models that use it. | |||||||
|   * - :code:`QWenLMHeadModel` |   * - :code:`QWenLMHeadModel` | ||||||
|     - Qwen |     - Qwen | ||||||
|     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. |     - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. | ||||||
|  |   * - :code:`Qwen2ForCausalLM` | ||||||
|  |     - Qwen2 | ||||||
|  |     - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. | ||||||
|  |   * - :code:`StableLMEpochForCausalLM` | ||||||
|  |     - StableLM | ||||||
|  |     - :code:`stabilityai/stablelm-3b-4e1t`, :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. | ||||||
|   * - :code:`YiForCausalLM` |   * - :code:`YiForCausalLM` | ||||||
|     - Yi |     - Yi | ||||||
|     - :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. |     - :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc. | ||||||
| @ -90,7 +99,7 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr | |||||||
|     If vLLM successfully generates text, it indicates that your model is supported. |     If vLLM successfully generates text, it indicates that your model is supported. | ||||||
|  |  | ||||||
| .. tip:: | .. tip:: | ||||||
|     To use models from `ModelScope <www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable: |     To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable: | ||||||
|  |  | ||||||
|     .. code-block:: shell |     .. code-block:: shell | ||||||
|  |  | ||||||
|  | |||||||
docs/source/quantization/fp8_e5m2_kv_cache.rst (new file, 32 lines)
							| @ -0,0 +1,32 @@ | |||||||
|  | .. _fp8_e5m2_kv_cache: | ||||||
|  |  | ||||||
|  | FP8 E5M2 KV Cache | ||||||
|  | ================== | ||||||
|  |  | ||||||
|  | The int8/int4 quantization scheme requires additional GPU memory for storing the scales, which reduces the expected GPU memory benefits. | ||||||
|  | The FP8 data format retains 2~3 mantissa bits and supports conversion between float/fp16/bfloat16 and fp8. | ||||||
|  |  | ||||||
|  | Here is an example of how to enable this feature: | ||||||
|  |  | ||||||
|  | .. code-block:: python | ||||||
|  |  | ||||||
|  |     from vllm import LLM, SamplingParams | ||||||
|  |     # Sample prompts. | ||||||
|  |     prompts = [ | ||||||
|  |         "Hello, my name is", | ||||||
|  |         "The president of the United States is", | ||||||
|  |         "The capital of France is", | ||||||
|  |         "The future of AI is", | ||||||
|  |     ] | ||||||
|  |     # Create a sampling params object. | ||||||
|  |     sampling_params = SamplingParams(temperature=0.8, top_p=0.95) | ||||||
|  |     # Create an LLM. | ||||||
|  |     llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2") | ||||||
|  |     # Generate texts from the prompts. The output is a list of RequestOutput objects | ||||||
|  |     # that contain the prompt, generated text, and other information. | ||||||
|  |     outputs = llm.generate(prompts, sampling_params) | ||||||
|  |     # Print the outputs. | ||||||
|  |     for output in outputs: | ||||||
|  |         prompt = output.prompt | ||||||
|  |         generated_text = output.outputs[0].text | ||||||
|  |         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||||||
|  |  | ||||||
| @ -28,4 +28,4 @@ To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langcha | |||||||
|  |  | ||||||
|     print(llm("What is the capital of France ?")) |     print(llm("What is the capital of France ?")) | ||||||
|  |  | ||||||
| Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/extras/integrations/llms/vllm.ipynb>`_ for more details. | Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/docs/integrations/llms/vllm.ipynb>`_ for more details. | ||||||
examples/gradio_openai_chatbot_webserver.py (new file, 81 lines)
							| @ -0,0 +1,81 @@ | |||||||
|  | import argparse | ||||||
|  | from openai import OpenAI | ||||||
|  | import gradio as gr | ||||||
|  |  | ||||||
|  | # Argument parser setup | ||||||
|  | parser = argparse.ArgumentParser( | ||||||
|  |     description='Chatbot Interface with Customizable Parameters') | ||||||
|  | parser.add_argument('--model-url', | ||||||
|  |                     type=str, | ||||||
|  |                     default='http://localhost:8000/v1', | ||||||
|  |                     help='Model URL') | ||||||
|  | parser.add_argument('-m', | ||||||
|  |                     '--model', | ||||||
|  |                     type=str, | ||||||
|  |                     required=True, | ||||||
|  |                     help='Model name for the chatbot') | ||||||
|  | parser.add_argument('--temp', | ||||||
|  |                     type=float, | ||||||
|  |                     default=0.8, | ||||||
|  |                     help='Temperature for text generation') | ||||||
|  | parser.add_argument('--stop-token-ids', | ||||||
|  |                     type=str, | ||||||
|  |                     default='', | ||||||
|  |                     help='Comma-separated stop token IDs') | ||||||
|  | parser.add_argument("--host", type=str, default=None) | ||||||
|  | parser.add_argument("--port", type=int, default=8001) | ||||||
|  |  | ||||||
|  | # Parse the arguments | ||||||
|  | args = parser.parse_args() | ||||||
|  |  | ||||||
|  | # Set OpenAI's API key and API base to use vLLM's API server. | ||||||
|  | openai_api_key = "EMPTY" | ||||||
|  | openai_api_base = args.model_url | ||||||
|  |  | ||||||
|  | # Create an OpenAI client to interact with the API server | ||||||
|  | client = OpenAI( | ||||||
|  |     api_key=openai_api_key, | ||||||
|  |     base_url=openai_api_base, | ||||||
|  | ) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def predict(message, history): | ||||||
|  |     # Convert chat history to OpenAI format | ||||||
|  |     history_openai_format = [{ | ||||||
|  |         "role": "system", | ||||||
|  |         "content": "You are a great ai assistant." | ||||||
|  |     }] | ||||||
|  |     for human, assistant in history: | ||||||
|  |         history_openai_format.append({"role": "user", "content": human}) | ||||||
|  |         history_openai_format.append({ | ||||||
|  |             "role": "assistant", | ||||||
|  |             "content": assistant | ||||||
|  |         }) | ||||||
|  |     history_openai_format.append({"role": "user", "content": message}) | ||||||
|  |  | ||||||
|  |     # Create a chat completion request and send it to the API server | ||||||
|  |     stream = client.chat.completions.create( | ||||||
|  |         model=args.model,  # Model name to use | ||||||
|  |         messages=history_openai_format,  # Chat history | ||||||
|  |         temperature=args.temp,  # Temperature for text generation | ||||||
|  |         stream=True,  # Stream response | ||||||
|  |         extra_body={ | ||||||
|  |             'repetition_penalty': | ||||||
|  |             1, | ||||||
|  |             'stop_token_ids': [ | ||||||
|  |                 int(id.strip()) for id in args.stop_token_ids.split(',') | ||||||
|  |                 if id.strip() | ||||||
|  |             ] if args.stop_token_ids else [] | ||||||
|  |         }) | ||||||
|  |  | ||||||
|  |     # Read and return generated text from response stream | ||||||
|  |     partial_message = "" | ||||||
|  |     for chunk in stream: | ||||||
|  |         partial_message += (chunk.choices[0].delta.content or "") | ||||||
|  |         yield partial_message | ||||||
|  |  | ||||||
|  |  | ||||||
|  | # Create and launch a chat interface with Gradio | ||||||
|  | gr.ChatInterface(predict).queue().launch(server_name=args.host, | ||||||
|  |                                          server_port=args.port, | ||||||
|  |                                          share=True) | ||||||
| @ -47,6 +47,6 @@ if __name__ == "__main__": | |||||||
|     args = parser.parse_args() |     args = parser.parse_args() | ||||||
|  |  | ||||||
|     demo = build_demo() |     demo = build_demo() | ||||||
|     demo.queue(concurrency_count=100).launch(server_name=args.host, |     demo.queue().launch(server_name=args.host, | ||||||
|                                              server_port=args.port, |                         server_port=args.port, | ||||||
|                                              share=True) |                         share=True) | ||||||
|  | |||||||
examples/multilora_inference.py (new file, 117 lines)
							| @ -0,0 +1,117 @@ | |||||||
|  | """ | ||||||
|  | This example shows how to use the multi-LoRA functionality for offline inference. | ||||||
|  |  | ||||||
|  | Requires HuggingFace credentials for access to Llama2. | ||||||
|  | """ | ||||||
|  |  | ||||||
|  | from typing import Optional, List, Tuple | ||||||
|  |  | ||||||
|  | from huggingface_hub import snapshot_download | ||||||
|  |  | ||||||
|  | from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput | ||||||
|  | from vllm.lora.request import LoRARequest | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def create_test_prompts(lora_path: str) -> List[Tuple[str, SamplingParams]]: | ||||||
|  |     """Create a list of test prompts with their sampling parameters. | ||||||
|  |      | ||||||
|  |     2 requests for base model, 4 requests for the LoRA. We define 2 | ||||||
|  |     different LoRA adapters (using the same model for demo purposes). | ||||||
|  |     Since we also set `max_loras=1`, the expectation is that the requests | ||||||
|  |     with the second LoRA adapter will be ran after all requests with the | ||||||
|  |     first adapter have finished. | ||||||
|  |     """ | ||||||
|  |     return [ | ||||||
|  |         ("A robot may not injure a human being", | ||||||
|  |          SamplingParams(temperature=0.0, | ||||||
|  |                         logprobs=1, | ||||||
|  |                         prompt_logprobs=1, | ||||||
|  |                         max_tokens=128), None), | ||||||
|  |         ("To be or not to be,", | ||||||
|  |          SamplingParams(temperature=0.8, | ||||||
|  |                         top_k=5, | ||||||
|  |                         presence_penalty=0.2, | ||||||
|  |                         max_tokens=128), None), | ||||||
|  |         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", | ||||||
|  |          SamplingParams(temperature=0.0, | ||||||
|  |                         logprobs=1, | ||||||
|  |                         prompt_logprobs=1, | ||||||
|  |                         max_tokens=128, | ||||||
|  |                         stop_token_ids=[32003]), | ||||||
|  |          LoRARequest("sql-lora", 1, lora_path)), | ||||||
|  |         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", | ||||||
|  |          SamplingParams(n=3, | ||||||
|  |                         best_of=3, | ||||||
|  |                         use_beam_search=True, | ||||||
|  |                         temperature=0, | ||||||
|  |                         max_tokens=128, | ||||||
|  |                         stop_token_ids=[32003]), | ||||||
|  |          LoRARequest("sql-lora", 1, lora_path)), | ||||||
|  |         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", | ||||||
|  |          SamplingParams(temperature=0.0, | ||||||
|  |                         logprobs=1, | ||||||
|  |                         prompt_logprobs=1, | ||||||
|  |                         max_tokens=128, | ||||||
|  |                         stop_token_ids=[32003]), | ||||||
|  |          LoRARequest("sql-lora2", 2, lora_path)), | ||||||
|  |         ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", | ||||||
|  |          SamplingParams(n=3, | ||||||
|  |                         best_of=3, | ||||||
|  |                         use_beam_search=True, | ||||||
|  |                         temperature=0, | ||||||
|  |                         max_tokens=128, | ||||||
|  |                         stop_token_ids=[32003]), | ||||||
|  |          LoRARequest("sql-lora", 1, lora_path)), | ||||||
|  |     ] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def process_requests(engine: LLMEngine, | ||||||
|  |                      test_prompts: List[Tuple[str, SamplingParams, | ||||||
|  |                                               Optional[LoRARequest]]]): | ||||||
|  |     """Continuously process a list of prompts and handle the outputs.""" | ||||||
|  |     request_id = 0 | ||||||
|  |  | ||||||
|  |     while test_prompts or engine.has_unfinished_requests(): | ||||||
|  |         if test_prompts: | ||||||
|  |             prompt, sampling_params, lora_request = test_prompts.pop(0) | ||||||
|  |             engine.add_request(str(request_id), | ||||||
|  |                                prompt, | ||||||
|  |                                sampling_params, | ||||||
|  |                                lora_request=lora_request) | ||||||
|  |             request_id += 1 | ||||||
|  |  | ||||||
|  |         request_outputs: List[RequestOutput] = engine.step() | ||||||
|  |  | ||||||
|  |         for request_output in request_outputs: | ||||||
|  |             if request_output.finished: | ||||||
|  |                 print(request_output) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def initialize_engine() -> LLMEngine: | ||||||
|  |     """Initialize the LLMEngine.""" | ||||||
|  |     # max_loras: controls the number of LoRAs that can be used in the same | ||||||
|  |     #   batch. Larger numbers will cause higher memory usage, as each LoRA | ||||||
|  |     #   slot requires its own preallocated tensor. | ||||||
|  |     # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger | ||||||
|  |     #   numbers will cause higher memory usage. If you know that all LoRAs will | ||||||
|  |     #   use the same rank, it is recommended to set this as low as possible. | ||||||
|  |     # max_cpu_loras: controls the size of the CPU LoRA cache. | ||||||
|  |     engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf", | ||||||
|  |                              enable_lora=True, | ||||||
|  |                              max_loras=1, | ||||||
|  |                              max_lora_rank=8, | ||||||
|  |                              max_cpu_loras=2, | ||||||
|  |                              max_num_seqs=256) | ||||||
|  |     return LLMEngine.from_engine_args(engine_args) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def main(): | ||||||
|  |     """Main function that sets up and runs the prompt processing.""" | ||||||
|  |     engine = initialize_engine() | ||||||
|  |     lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test") | ||||||
|  |     test_prompts = create_test_prompts(lora_path) | ||||||
|  |     process_requests(engine, test_prompts) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == '__main__': | ||||||
|  |     main() | ||||||
examples/offline_inference_with_prefix.py (new file, 59 lines)
							| @ -0,0 +1,59 @@ | |||||||
|  | from vllm import LLM, SamplingParams | ||||||
|  |  | ||||||
|  | prefix = ( | ||||||
|  |     "You are an expert school principal, skilled in effectively managing " | ||||||
|  |     "faculty and staff. Draft 10-15 questions for a potential first grade " | ||||||
|  |     "Head Teacher for my K-12, all-girls', independent school that emphasizes " | ||||||
|  |     "community, joyful discovery, and life-long learning. The candidate is " | ||||||
|  |     "coming in for a first-round panel interview for a 8th grade Math " | ||||||
|  |     "teaching role. They have 5 years of previous teaching experience " | ||||||
|  |     "as an assistant teacher at a co-ed, public school with experience " | ||||||
|  |     "in middle school math teaching. Based on these information, fulfill " | ||||||
|  |     "the following paragraph: ") | ||||||
|  |  | ||||||
|  | # Sample prompts. | ||||||
|  | prompts = [ | ||||||
|  |     "Hello, my name is", | ||||||
|  |     "The president of the United States is", | ||||||
|  |     "The capital of France is", | ||||||
|  |     "The future of AI is", | ||||||
|  | ] | ||||||
|  | # Create a sampling params object. | ||||||
|  | sampling_params = SamplingParams(temperature=0.0) | ||||||
|  |  | ||||||
|  | # Create an LLM. | ||||||
|  | llm = LLM(model="facebook/opt-125m") | ||||||
|  |  | ||||||
|  | generating_prompts = [prefix + prompt for prompt in prompts] | ||||||
|  |  | ||||||
|  | # Generate texts from the prompts. The output is a list of RequestOutput objects | ||||||
|  | # that contain the prompt, generated text, and other information. | ||||||
|  | outputs = llm.generate(generating_prompts, sampling_params) | ||||||
|  | # Print the outputs. | ||||||
|  | for output in outputs: | ||||||
|  |     prompt = output.prompt | ||||||
|  |     generated_text = output.outputs[0].text | ||||||
|  |     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||||||
|  |  | ||||||
|  | print("-" * 80) | ||||||
|  |  | ||||||
|  | # -1 since the last token can change when concatenating prompts. | ||||||
|  | prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1 | ||||||
|  |  | ||||||
|  | # The llm.generate call will batch all prompts and send the batch at once if resources allow. | ||||||
|  | # The prefix will only be cached after the first batch is processed, so we need to call generate once | ||||||
|  | # to calculate the prefix and cache it. | ||||||
|  | outputs = llm.generate(generating_prompts[0], | ||||||
|  |                        sampling_params, | ||||||
|  |                        prefix_pos=[prefix_pos]) | ||||||
|  |  | ||||||
|  | # Subsequent batches can leverage the cached prefix | ||||||
|  | outputs = llm.generate(generating_prompts, | ||||||
|  |                        sampling_params, | ||||||
|  |                        prefix_pos=[prefix_pos] * len(generating_prompts)) | ||||||
|  |  | ||||||
|  | # Print the outputs. You should see the same outputs as before | ||||||
|  | for output in outputs: | ||||||
|  |     prompt = output.prompt | ||||||
|  |     generated_text = output.outputs[0].text | ||||||
|  |     print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") | ||||||
| @ -32,6 +32,5 @@ chat_completion = client.chat.completions.create( | |||||||
|     model=model, |     model=model, | ||||||
| ) | ) | ||||||
|  |  | ||||||
|  |  | ||||||
| print("Chat completion results:") | print("Chat completion results:") | ||||||
| print(chat_completion) | print(chat_completion) | ||||||
|  | |||||||
| @ -21,8 +21,7 @@ completion = client.completions.create( | |||||||
|     echo=False, |     echo=False, | ||||||
|     n=2, |     n=2, | ||||||
|     stream=stream, |     stream=stream, | ||||||
|     logprobs=3 |     logprobs=3) | ||||||
| ) |  | ||||||
|  |  | ||||||
| print("Completion results:") | print("Completion results:") | ||||||
| if stream: | if stream: | ||||||
|  | |||||||
examples/template_baichuan.jinja (new file, 22 lines)
							| @ -0,0 +1,22 @@ | |||||||
|  | {{ (messages|selectattr('role', 'equalto', 'system')|list|last).content|trim if (messages|selectattr('role', 'equalto', 'system')|list) else '' }} | ||||||
|  |  | ||||||
|  | {% for message in messages %} | ||||||
|  | {% if message['role'] == 'user' %} | ||||||
|  | <reserved_106> | ||||||
|  | {{ message['content']|trim -}} | ||||||
|  | {% if not loop.last %} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | {% endif %} | ||||||
|  | {% elif message['role'] == 'assistant' %} | ||||||
|  | <reserved_107> | ||||||
|  | {{ message['content']|trim -}} | ||||||
|  | {% if not loop.last %} | ||||||
|  |  | ||||||
|  |  | ||||||
|  | {% endif %} | ||||||
|  | {% endif %} | ||||||
|  | {% endfor %} | ||||||
|  | {% if add_generation_prompt and messages[-1]['role'] != 'assistant' %} | ||||||
|  | <reserved_107> | ||||||
|  | {% endif %} | ||||||
| @ -71,7 +71,7 @@ format_changed() { | |||||||
|  |  | ||||||
| # Format all files | # Format all files | ||||||
| format_all() { | format_all() { | ||||||
|     yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests |     yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" . | ||||||
| } | } | ||||||
|  |  | ||||||
| ## This flag formats individual files. --files *must* be the first command line | ## This flag formats individual files. --files *must* be the first command line | ||||||
|  | |||||||
| @ -13,4 +13,9 @@ types-setuptools | |||||||
| pytest | pytest | ||||||
| pytest-forked | pytest-forked | ||||||
| pytest-asyncio | pytest-asyncio | ||||||
|  | httpx | ||||||
|  | einops # required for MPT | ||||||
|  | flash_attn # required for HuggingFace's llama implementation | ||||||
|  | openai | ||||||
|  | requests | ||||||
|  | ray | ||||||
requirements-neuron.txt (new file, 9 lines)
							| @ -0,0 +1,9 @@ | |||||||
|  | sentencepiece  # Required for LLaMA tokenizer. | ||||||
|  | numpy | ||||||
|  | transformers-neuronx >= 0.9.0 | ||||||
|  | torch-neuronx >= 2.1.0 | ||||||
|  | neuronx-cc | ||||||
|  | fastapi | ||||||
|  | uvicorn[standard] | ||||||
|  | pydantic >= 2.0  # Required for OpenAI server. | ||||||
|  | aioprometheus[starlette] | ||||||
| @ -2,14 +2,12 @@ ninja  # For faster builds. | |||||||
| typing-extensions>=4.8.0 | typing-extensions>=4.8.0 | ||||||
| starlette | starlette | ||||||
| psutil | psutil | ||||||
| ray >= 2.5.1 | ray >= 2.9 | ||||||
| pandas  # Required for Ray data. |  | ||||||
| pyarrow  # Required for Ray data. |  | ||||||
| sentencepiece  # Required for LLaMA tokenizer. | sentencepiece  # Required for LLaMA tokenizer. | ||||||
| numpy | numpy | ||||||
| tokenizers>=0.15.0 | tokenizers>=0.15.0 | ||||||
| transformers >= 4.36.0  # Required for Mixtral. | transformers >= 4.37.0  # Required for Mixtral. | ||||||
| fastapi | fastapi | ||||||
| uvicorn[standard] | uvicorn[standard] | ||||||
| pydantic == 1.10.13  # Required for OpenAI server. | pydantic >= 2.0  # Required for OpenAI server. | ||||||
| aioprometheus[starlette] | aioprometheus[starlette] | ||||||
|  | |||||||
| @ -1,14 +1,13 @@ | |||||||
| ninja  # For faster builds. | ninja  # For faster builds. | ||||||
| psutil | psutil | ||||||
| ray >= 2.5.1 | ray >= 2.9 | ||||||
| pandas  # Required for Ray data. |  | ||||||
| pyarrow  # Required for Ray data. |  | ||||||
| sentencepiece  # Required for LLaMA tokenizer. | sentencepiece  # Required for LLaMA tokenizer. | ||||||
| numpy | numpy | ||||||
| torch == 2.1.2 | torch == 2.1.2 | ||||||
| transformers >= 4.36.0  # Required for Mixtral. | transformers >= 4.37.0 # Required for Qwen2 | ||||||
| xformers == 0.0.23.post1  # Required for CUDA 12.1. | xformers == 0.0.23.post1  # Required for CUDA 12.1. | ||||||
| fastapi | fastapi | ||||||
| uvicorn[standard] | uvicorn[standard] | ||||||
| pydantic == 1.10.13  # Required for OpenAI server. | pydantic >= 2.0  # Required for OpenAI server. | ||||||
| aioprometheus[starlette] | aioprometheus[starlette] | ||||||
|  | pynvml == 11.5.0 | ||||||
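The requirements hunks above raise several version floors at once (ray >= 2.9, transformers >= 4.37.0, pydantic >= 2.0) and add pynvml. A small, hedged sanity check for an existing environment; the package list here is illustrative, not exhaustive:

from importlib.metadata import version
from packaging.version import Version

# Minimum versions taken from the updated requirements files above.
MINIMUMS = {"ray": "2.9", "transformers": "4.37.0", "pydantic": "2.0"}
for pkg, floor in MINIMUMS.items():
    installed = Version(version(pkg))
    assert installed >= Version(floor), f"{pkg} {installed} is older than {floor}"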
|  | |||||||
setup.py (148 lines changed)
							| @ -1,13 +1,16 @@ | |||||||
|  | import contextlib | ||||||
| import io | import io | ||||||
| import os | import os | ||||||
| import re | import re | ||||||
| import subprocess | import subprocess | ||||||
| from typing import List, Set |  | ||||||
| import warnings | import warnings | ||||||
|  | from pathlib import Path | ||||||
|  | from typing import List, Set | ||||||
|  |  | ||||||
| from packaging.version import parse, Version | from packaging.version import parse, Version | ||||||
| import setuptools | import setuptools | ||||||
| import torch | import torch | ||||||
|  | import torch.utils.cpp_extension as torch_cpp_ext | ||||||
| from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME | from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CUDA_HOME, ROCM_HOME | ||||||
|  |  | ||||||
| ROOT_DIR = os.path.dirname(__file__) | ROOT_DIR = os.path.dirname(__file__) | ||||||
| @ -24,8 +27,17 @@ def _is_hip() -> bool: | |||||||
|     return torch.version.hip is not None |     return torch.version.hip is not None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _is_neuron() -> bool: | ||||||
|  |     torch_neuronx_installed = True | ||||||
|  |     try: | ||||||
|  |         subprocess.run(["neuron-ls"], capture_output=True, check=True) | ||||||
|  |     except FileNotFoundError: | ||||||
|  |         torch_neuronx_installed = False | ||||||
|  |     return torch_neuronx_installed | ||||||
|  |  | ||||||
|  |  | ||||||
| def _is_cuda() -> bool: | def _is_cuda() -> bool: | ||||||
|     return torch.version.cuda is not None |     return (torch.version.cuda is not None) and not _is_neuron() | ||||||
|  |  | ||||||
|  |  | ||||||
| # Compiler flags. | # Compiler flags. | ||||||
| @ -39,6 +51,8 @@ if _is_hip(): | |||||||
|             "Cannot find ROCM_HOME. ROCm must be available to build the package." |             "Cannot find ROCM_HOME. ROCm must be available to build the package." | ||||||
|         ) |         ) | ||||||
|     NVCC_FLAGS += ["-DUSE_ROCM"] |     NVCC_FLAGS += ["-DUSE_ROCM"] | ||||||
|  |     NVCC_FLAGS += ["-U__HIP_NO_HALF_CONVERSIONS__"] | ||||||
|  |     NVCC_FLAGS += ["-U__HIP_NO_HALF_OPERATORS__"] | ||||||
|  |  | ||||||
| if _is_cuda() and CUDA_HOME is None: | if _is_cuda() and CUDA_HOME is None: | ||||||
|     raise RuntimeError( |     raise RuntimeError( | ||||||
| @ -87,6 +101,30 @@ def get_hipcc_rocm_version(): | |||||||
|         return None |         return None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def glob(pattern: str): | ||||||
|  |     root = Path(__name__).parent | ||||||
|  |     return [str(p) for p in root.glob(pattern)] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def get_neuronxcc_version(): | ||||||
|  |     import sysconfig | ||||||
|  |     site_dir = sysconfig.get_paths()["purelib"] | ||||||
|  |     version_file = os.path.join(site_dir, "neuronxcc", "version", | ||||||
|  |                                 "__init__.py") | ||||||
|  |  | ||||||
|  |     # Read the version file shipped with the neuronxcc package | ||||||
|  |     with open(version_file, "rt") as fp: | ||||||
|  |         content = fp.read() | ||||||
|  |  | ||||||
|  |     # Extract the version using a regular expression | ||||||
|  |     match = re.search(r"__version__ = '(\S+)'", content) | ||||||
|  |     if match: | ||||||
|  |         # Return the version string | ||||||
|  |         return match.group(1) | ||||||
|  |     else: | ||||||
|  |         raise RuntimeError("Could not find the neuronxcc version") | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_nvcc_cuda_version(cuda_dir: str) -> Version: | def get_nvcc_cuda_version(cuda_dir: str) -> Version: | ||||||
|     """Get the CUDA version from nvcc. |     """Get the CUDA version from nvcc. | ||||||
|  |  | ||||||
| @ -151,6 +189,8 @@ if _is_cuda() and not compute_capabilities: | |||||||
|                 "GPUs with compute capability below 7.0 are not supported.") |                 "GPUs with compute capability below 7.0 are not supported.") | ||||||
|         compute_capabilities.add(f"{major}.{minor}") |         compute_capabilities.add(f"{major}.{minor}") | ||||||
|  |  | ||||||
|  | ext_modules = [] | ||||||
|  |  | ||||||
| if _is_cuda(): | if _is_cuda(): | ||||||
|     nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) |     nvcc_cuda_version = get_nvcc_cuda_version(CUDA_HOME) | ||||||
|     if not compute_capabilities: |     if not compute_capabilities: | ||||||
| @ -188,6 +228,8 @@ if _is_cuda(): | |||||||
|             raise RuntimeError( |             raise RuntimeError( | ||||||
|                 "CUDA 11.8 or higher is required for compute capability 9.0.") |                 "CUDA 11.8 or higher is required for compute capability 9.0.") | ||||||
|  |  | ||||||
|  |     NVCC_FLAGS_PUNICA = NVCC_FLAGS.copy() | ||||||
|  |  | ||||||
|     # Add target compute capabilities to NVCC flags. |     # Add target compute capabilities to NVCC flags. | ||||||
|     for capability in compute_capabilities: |     for capability in compute_capabilities: | ||||||
|         num = capability[0] + capability[2] |         num = capability[0] + capability[2] | ||||||
| @ -196,6 +238,14 @@ if _is_cuda(): | |||||||
|             NVCC_FLAGS += [ |             NVCC_FLAGS += [ | ||||||
|                 "-gencode", f"arch=compute_{num},code=compute_{num}" |                 "-gencode", f"arch=compute_{num},code=compute_{num}" | ||||||
|             ] |             ] | ||||||
|  |         if int(capability[0]) >= 8: | ||||||
|  |             NVCC_FLAGS_PUNICA += [ | ||||||
|  |                 "-gencode", f"arch=compute_{num},code=sm_{num}" | ||||||
|  |             ] | ||||||
|  |             if capability.endswith("+PTX"): | ||||||
|  |                 NVCC_FLAGS_PUNICA += [ | ||||||
|  |                     "-gencode", f"arch=compute_{num},code=compute_{num}" | ||||||
|  |                 ] | ||||||
|  |  | ||||||
|     # Use NVCC threads to parallelize the build. |     # Use NVCC threads to parallelize the build. | ||||||
|     if nvcc_cuda_version >= Version("11.2"): |     if nvcc_cuda_version >= Version("11.2"): | ||||||
| @ -203,14 +253,52 @@ if _is_cuda(): | |||||||
|         num_threads = min(os.cpu_count(), nvcc_threads) |         num_threads = min(os.cpu_count(), nvcc_threads) | ||||||
|         NVCC_FLAGS += ["--threads", str(num_threads)] |         NVCC_FLAGS += ["--threads", str(num_threads)] | ||||||
|  |  | ||||||
| elif _is_hip(): |     if nvcc_cuda_version >= Version("11.8"): | ||||||
|     amd_arch = get_amdgpu_offload_arch() |         NVCC_FLAGS += ["-DENABLE_FP8_E5M2"] | ||||||
|     if amd_arch not in ROCM_SUPPORTED_ARCHS: |  | ||||||
|         raise RuntimeError( |  | ||||||
|             f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}" |  | ||||||
|             f"amdgpu_arch_found: {amd_arch}") |  | ||||||
|  |  | ||||||
| ext_modules = [] |     # changes for punica kernels | ||||||
|  |     NVCC_FLAGS += torch_cpp_ext.COMMON_NVCC_FLAGS | ||||||
|  |     REMOVE_NVCC_FLAGS = [ | ||||||
|  |         '-D__CUDA_NO_HALF_OPERATORS__', | ||||||
|  |         '-D__CUDA_NO_HALF_CONVERSIONS__', | ||||||
|  |         '-D__CUDA_NO_BFLOAT16_CONVERSIONS__', | ||||||
|  |         '-D__CUDA_NO_HALF2_OPERATORS__', | ||||||
|  |     ] | ||||||
|  |     for flag in REMOVE_NVCC_FLAGS: | ||||||
|  |         with contextlib.suppress(ValueError): | ||||||
|  |             torch_cpp_ext.COMMON_NVCC_FLAGS.remove(flag) | ||||||
|  |  | ||||||
|  |     install_punica = bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))) | ||||||
|  |     device_count = torch.cuda.device_count() | ||||||
|  |     for i in range(device_count): | ||||||
|  |         major, minor = torch.cuda.get_device_capability(i) | ||||||
|  |         if major < 8: | ||||||
|  |             install_punica = False | ||||||
|  |             break | ||||||
|  |     if install_punica: | ||||||
|  |         ext_modules.append( | ||||||
|  |             CUDAExtension( | ||||||
|  |                 name="vllm._punica_C", | ||||||
|  |                 sources=["csrc/punica/punica_ops.cc"] + | ||||||
|  |                 glob("csrc/punica/bgmv/*.cu"), | ||||||
|  |                 extra_compile_args={ | ||||||
|  |                     "cxx": CXX_FLAGS, | ||||||
|  |                     "nvcc": NVCC_FLAGS_PUNICA, | ||||||
|  |                 }, | ||||||
|  |             )) | ||||||
|  | elif _is_hip(): | ||||||
|  |     amd_archs = os.getenv("GPU_ARCHS") | ||||||
|  |     if amd_archs is None: | ||||||
|  |         amd_archs = get_amdgpu_offload_arch() | ||||||
|  |     for arch in amd_archs.split(";"): | ||||||
|  |         if arch not in ROCM_SUPPORTED_ARCHS: | ||||||
|  |             raise RuntimeError( | ||||||
|  |                 f"Only the following arch is supported: {ROCM_SUPPORTED_ARCHS}" | ||||||
|  |                 f"amdgpu_arch_found: {arch}") | ||||||
|  |         NVCC_FLAGS += [f"--offload-arch={arch}"] | ||||||
|  |  | ||||||
|  | elif _is_neuron(): | ||||||
|  |     neuronxcc_version = get_neuronxcc_version() | ||||||
|  |  | ||||||
| vllm_extension_sources = [ | vllm_extension_sources = [ | ||||||
|     "csrc/cache_kernels.cu", |     "csrc/cache_kernels.cu", | ||||||
| @ -219,23 +307,27 @@ vllm_extension_sources = [ | |||||||
|     "csrc/activation_kernels.cu", |     "csrc/activation_kernels.cu", | ||||||
|     "csrc/layernorm_kernels.cu", |     "csrc/layernorm_kernels.cu", | ||||||
|     "csrc/quantization/squeezellm/quant_cuda_kernel.cu", |     "csrc/quantization/squeezellm/quant_cuda_kernel.cu", | ||||||
|  |     "csrc/quantization/gptq/q_gemm.cu", | ||||||
|     "csrc/cuda_utils_kernels.cu", |     "csrc/cuda_utils_kernels.cu", | ||||||
|  |     "csrc/moe_align_block_size_kernels.cu", | ||||||
|     "csrc/pybind.cpp", |     "csrc/pybind.cpp", | ||||||
| ] | ] | ||||||
|  |  | ||||||
| if _is_cuda(): | if _is_cuda(): | ||||||
|     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") |     vllm_extension_sources.append("csrc/quantization/awq/gemm_kernels.cu") | ||||||
|     vllm_extension_sources.append("csrc/quantization/gptq/q_gemm.cu") |     vllm_extension_sources.append("csrc/custom_all_reduce.cu") | ||||||
|  |  | ||||||
| vllm_extension = CUDAExtension( | if not _is_neuron(): | ||||||
|     name="vllm._C", |     vllm_extension = CUDAExtension( | ||||||
|     sources=vllm_extension_sources, |         name="vllm._C", | ||||||
|     extra_compile_args={ |         sources=vllm_extension_sources, | ||||||
|         "cxx": CXX_FLAGS, |         extra_compile_args={ | ||||||
|         "nvcc": NVCC_FLAGS, |             "cxx": CXX_FLAGS, | ||||||
|     }, |             "nvcc": NVCC_FLAGS, | ||||||
| ) |         }, | ||||||
| ext_modules.append(vllm_extension) |         libraries=["cuda"] if _is_cuda() else [], | ||||||
|  |     ) | ||||||
|  |     ext_modules.append(vllm_extension) | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_path(*filepath) -> str: | def get_path(*filepath) -> str: | ||||||
| @ -264,6 +356,12 @@ def get_vllm_version() -> str: | |||||||
|         if hipcc_version != MAIN_CUDA_VERSION: |         if hipcc_version != MAIN_CUDA_VERSION: | ||||||
|             rocm_version_str = hipcc_version.replace(".", "")[:3] |             rocm_version_str = hipcc_version.replace(".", "")[:3] | ||||||
|             version += f"+rocm{rocm_version_str}" |             version += f"+rocm{rocm_version_str}" | ||||||
|  |     elif _is_neuron(): | ||||||
|  |         # Get the Neuron version | ||||||
|  |         neuron_version = str(neuronxcc_version) | ||||||
|  |         if neuron_version != MAIN_CUDA_VERSION: | ||||||
|  |             neuron_version_str = neuron_version.replace(".", "")[:3] | ||||||
|  |             version += f"+neuron{neuron_version_str}" | ||||||
|     else: |     else: | ||||||
|         cuda_version = str(nvcc_cuda_version) |         cuda_version = str(nvcc_cuda_version) | ||||||
|         if cuda_version != MAIN_CUDA_VERSION: |         if cuda_version != MAIN_CUDA_VERSION: | ||||||
| @ -287,12 +385,20 @@ def get_requirements() -> List[str]: | |||||||
|     if _is_hip(): |     if _is_hip(): | ||||||
|         with open(get_path("requirements-rocm.txt")) as f: |         with open(get_path("requirements-rocm.txt")) as f: | ||||||
|             requirements = f.read().strip().split("\n") |             requirements = f.read().strip().split("\n") | ||||||
|  |     elif _is_neuron(): | ||||||
|  |         with open(get_path("requirements-neuron.txt")) as f: | ||||||
|  |             requirements = f.read().strip().split("\n") | ||||||
|     else: |     else: | ||||||
|         with open(get_path("requirements.txt")) as f: |         with open(get_path("requirements.txt")) as f: | ||||||
|             requirements = f.read().strip().split("\n") |             requirements = f.read().strip().split("\n") | ||||||
|     return requirements |     return requirements | ||||||
|  |  | ||||||
|  |  | ||||||
|  | package_data = {"vllm": ["py.typed"]} | ||||||
|  | if os.environ.get("VLLM_USE_PRECOMPILED"): | ||||||
|  |     ext_modules = [] | ||||||
|  |     package_data["vllm"].append("*.so") | ||||||
|  |  | ||||||
| setuptools.setup( | setuptools.setup( | ||||||
|     name="vllm", |     name="vllm", | ||||||
|     version=get_vllm_version(), |     version=get_vllm_version(), | ||||||
| @ -320,6 +426,6 @@ setuptools.setup( | |||||||
|     python_requires=">=3.8", |     python_requires=">=3.8", | ||||||
|     install_requires=get_requirements(), |     install_requires=get_requirements(), | ||||||
|     ext_modules=ext_modules, |     ext_modules=ext_modules, | ||||||
|     cmdclass={"build_ext": BuildExtension}, |     cmdclass={"build_ext": BuildExtension} if not _is_neuron() else {}, | ||||||
|     package_data={"vllm": ["py.typed"]}, |     package_data=package_data, | ||||||
| ) | ) | ||||||
|  | |||||||
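The setup.py changes above make the punica (multi-LoRA) kernels an opt-in build gated by an environment variable and a compute-capability check. A minimal sketch of that decision, assuming the same VLLM_INSTALL_PUNICA_KERNELS variable and the SM 8.0 floor used in the diff; the helper name is illustrative:

import os
import torch

def punica_kernels_enabled() -> bool:
    # Off by default, mirroring os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0") above.
    if not bool(int(os.getenv("VLLM_INSTALL_PUNICA_KERNELS", "0"))):
        return False
    # The diff additionally requires every visible GPU to be compute capability 8.0+.
    return all(
        torch.cuda.get_device_capability(i)[0] >= 8
        for i in range(torch.cuda.device_count()))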
| @ -8,11 +8,11 @@ import pytest | |||||||
| import requests | import requests | ||||||
|  |  | ||||||
|  |  | ||||||
| def _query_server(prompt: str) -> dict: | def _query_server(prompt: str, max_tokens: int = 5) -> dict: | ||||||
|     response = requests.post("http://localhost:8000/generate", |     response = requests.post("http://localhost:8000/generate", | ||||||
|                              json={ |                              json={ | ||||||
|                                  "prompt": prompt, |                                  "prompt": prompt, | ||||||
|                                  "max_tokens": 100, |                                  "max_tokens": max_tokens, | ||||||
|                                  "temperature": 0, |                                  "temperature": 0, | ||||||
|                                  "ignore_eos": True |                                  "ignore_eos": True | ||||||
|                              }) |                              }) | ||||||
| @ -20,13 +20,22 @@ def _query_server(prompt: str) -> dict: | |||||||
|     return response.json() |     return response.json() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def _query_server_long(prompt: str) -> dict: | ||||||
|  |     return _query_server(prompt, max_tokens=500) | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
| def api_server(): | def api_server(): | ||||||
|     script_path = Path(__file__).parent.joinpath( |     script_path = Path(__file__).parent.joinpath( | ||||||
|         "api_server_async_engine.py").absolute() |         "api_server_async_engine.py").absolute() | ||||||
|     uvicorn_process = subprocess.Popen([ |     uvicorn_process = subprocess.Popen([ | ||||||
|         sys.executable, "-u", |         sys.executable, | ||||||
|         str(script_path), "--model", "facebook/opt-125m" |         "-u", | ||||||
|  |         str(script_path), | ||||||
|  |         "--model", | ||||||
|  |         "facebook/opt-125m", | ||||||
|  |         "--host", | ||||||
|  |         "127.0.0.1", | ||||||
|     ]) |     ]) | ||||||
|     yield |     yield | ||||||
|     uvicorn_process.terminate() |     uvicorn_process.terminate() | ||||||
| @ -44,13 +53,14 @@ def test_api_server(api_server): | |||||||
|     """ |     """ | ||||||
|     with Pool(32) as pool: |     with Pool(32) as pool: | ||||||
|         # Wait until the server is ready |         # Wait until the server is ready | ||||||
|         prompts = ["Hello world"] * 1 |         prompts = ["warm up"] * 1 | ||||||
|         result = None |         result = None | ||||||
|         while not result: |         while not result: | ||||||
|             try: |             try: | ||||||
|                 for _ in pool.map(_query_server, prompts): |                 for r in pool.map(_query_server, prompts): | ||||||
|  |                     result = r | ||||||
|                     break |                     break | ||||||
|             except Exception: |             except requests.exceptions.ConnectionError: | ||||||
|                 time.sleep(1) |                 time.sleep(1) | ||||||
|  |  | ||||||
|         # Actual tests start here |         # Actual tests start here | ||||||
| @ -63,17 +73,22 @@ def test_api_server(api_server): | |||||||
|         assert num_aborted_requests == 0 |         assert num_aborted_requests == 0 | ||||||
|  |  | ||||||
|         # Try with 100 prompts |         # Try with 100 prompts | ||||||
|         prompts = ["Hello world"] * 100 |         prompts = ["test prompt"] * 100 | ||||||
|         for result in pool.map(_query_server, prompts): |         for result in pool.map(_query_server, prompts): | ||||||
|             assert result |             assert result | ||||||
|  |  | ||||||
|  |     with Pool(32) as pool: | ||||||
|         # Cancel requests |         # Cancel requests | ||||||
|         pool.map_async(_query_server, prompts) |         prompts = ["canceled requests"] * 100 | ||||||
|  |         pool.map_async(_query_server_long, prompts) | ||||||
|         time.sleep(0.01) |         time.sleep(0.01) | ||||||
|         pool.terminate() |         pool.terminate() | ||||||
|         pool.join() |         pool.join() | ||||||
|  |  | ||||||
|         # check cancellation stats |         # check cancellation stats | ||||||
|  |         # give it some times to update the stats | ||||||
|  |         time.sleep(1) | ||||||
|  |  | ||||||
|         num_aborted_requests = requests.get( |         num_aborted_requests = requests.get( | ||||||
|             "http://localhost:8000/stats").json()["num_aborted_requests"] |             "http://localhost:8000/stats").json()["num_aborted_requests"] | ||||||
|         assert num_aborted_requests > 0 |         assert num_aborted_requests > 0 | ||||||
| @ -81,6 +96,6 @@ def test_api_server(api_server): | |||||||
|     # check that server still runs after cancellations |     # check that server still runs after cancellations | ||||||
|     with Pool(32) as pool: |     with Pool(32) as pool: | ||||||
|         # Try with 100 prompts |         # Try with 100 prompts | ||||||
|         prompts = ["Hello world"] * 100 |         prompts = ["test prompt after canceled"] * 100 | ||||||
|         for result in pool.map(_query_server, prompts): |         for result in pool.map(_query_server, prompts): | ||||||
|             assert result |             assert result | ||||||
|  | |||||||
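The readiness loop above now retries only on requests.exceptions.ConnectionError instead of swallowing every exception. The same pattern as a standalone helper, with an explicit timeout added for illustration (the function name and defaults are assumptions, not part of the test):

import time
import requests

def wait_for_server(url: str = "http://localhost:8000/generate",
                    timeout_s: float = 60.0) -> None:
    deadline = time.time() + timeout_s
    while time.time() < deadline:
        try:
            # Any successful HTTP round trip means the server is up.
            requests.post(url, json={"prompt": "ping", "max_tokens": 1})
            return
        except requests.exceptions.ConnectionError:
            time.sleep(1)
    raise TimeoutError("API server did not become ready in time")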
| @ -25,6 +25,13 @@ class MockEngine: | |||||||
|         return [RequestOutput( |         return [RequestOutput( | ||||||
|             request_id=self.request_id)] if self.request_id else [] |             request_id=self.request_id)] if self.request_id else [] | ||||||
|  |  | ||||||
|  |     async def encode_request_async( | ||||||
|  |         self, | ||||||
|  |         *args, | ||||||
|  |         **kwargs, | ||||||
|  |     ): | ||||||
|  |         return [1] | ||||||
|  |  | ||||||
|     def generate(self, request_id): |     def generate(self, request_id): | ||||||
|         self.request_id = request_id |         self.request_id = request_id | ||||||
|  |  | ||||||
| @ -35,6 +42,10 @@ class MockEngine: | |||||||
|         del kwargs  # Unused |         del kwargs  # Unused | ||||||
|         self.add_request_calls += 1 |         self.add_request_calls += 1 | ||||||
|  |  | ||||||
|  |     async def add_request_async(self, **kwargs): | ||||||
|  |         del kwargs  # Unused | ||||||
|  |         self.add_request_calls += 1 | ||||||
|  |  | ||||||
|     def abort_request(self, request_id): |     def abort_request(self, request_id): | ||||||
|         del request_id  # Unused |         del request_id  # Unused | ||||||
|         self.abort_request_calls += 1 |         self.abort_request_calls += 1 | ||||||
|  | |||||||
| @ -1,10 +1,16 @@ | |||||||
| from argparse import Namespace |  | ||||||
| from dataclasses import dataclass | from dataclasses import dataclass | ||||||
|  | import os | ||||||
|  | import pathlib | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
| from fastapi.testclient import TestClient |  | ||||||
|  |  | ||||||
| from vllm.entrypoints.openai.api_server import * | from vllm.transformers_utils.tokenizer import get_tokenizer | ||||||
|  | from vllm.entrypoints.openai.serving_chat import OpenAIServingChat | ||||||
|  | from vllm.entrypoints.openai.protocol import ChatCompletionRequest | ||||||
|  |  | ||||||
|  | chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( | ||||||
|  |     __file__))).parent.parent / "examples/template_chatml.jinja" | ||||||
|  | assert chatml_jinja_path.exists() | ||||||
|  |  | ||||||
| # Define models, templates, and their corresponding expected outputs | # Define models, templates, and their corresponding expected outputs | ||||||
| MODEL_TEMPLATE_GENERATON_OUTPUT = [ | MODEL_TEMPLATE_GENERATON_OUTPUT = [ | ||||||
| @ -12,8 +18,7 @@ MODEL_TEMPLATE_GENERATON_OUTPUT = [ | |||||||
|      "Hello</s>Hi there!</s>What is the capital of</s>"), |      "Hello</s>Hi there!</s>What is the capital of</s>"), | ||||||
|     ("facebook/opt-125m", None, False, |     ("facebook/opt-125m", None, False, | ||||||
|      "Hello</s>Hi there!</s>What is the capital of</s>"), |      "Hello</s>Hi there!</s>What is the capital of</s>"), | ||||||
|     ("facebook/opt-125m", "../../examples/template_chatml.jinja", True, |     ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user | ||||||
|      """<|im_start|>user |  | ||||||
| Hello<|im_end|> | Hello<|im_end|> | ||||||
| <|im_start|>assistant | <|im_start|>assistant | ||||||
| Hi there!<|im_end|> | Hi there!<|im_end|> | ||||||
| @ -21,8 +26,7 @@ Hi there!<|im_end|> | |||||||
| What is the capital of<|im_end|> | What is the capital of<|im_end|> | ||||||
| <|im_start|>assistant | <|im_start|>assistant | ||||||
| """), | """), | ||||||
|     ("facebook/opt-125m", "../../examples/template_chatml.jinja", False, |     ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user | ||||||
|      """<|im_start|>user |  | ||||||
| Hello<|im_end|> | Hello<|im_end|> | ||||||
| <|im_start|>assistant | <|im_start|>assistant | ||||||
| Hi there!<|im_end|> | Hi there!<|im_end|> | ||||||
| @ -44,7 +48,6 @@ TEST_MESSAGES = [ | |||||||
|         'content': 'What is the capital of' |         'content': 'What is the capital of' | ||||||
|     }, |     }, | ||||||
| ] | ] | ||||||
| client = TestClient(app) |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @dataclass | @dataclass | ||||||
| @ -52,14 +55,17 @@ class MockTokenizer: | |||||||
|     chat_template = None |     chat_template = None | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @dataclass | ||||||
|  | class MockServingChat: | ||||||
|  |     tokenizer: MockTokenizer | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_load_chat_template(): | def test_load_chat_template(): | ||||||
|     # Testing chatml template |     # Testing chatml template | ||||||
|     template = "../../examples/template_chatml.jinja" |  | ||||||
|     mock_args = Namespace(chat_template=template) |  | ||||||
|     tokenizer = MockTokenizer() |     tokenizer = MockTokenizer() | ||||||
|  |     mock_serving_chat = MockServingChat(tokenizer) | ||||||
|     # Call the function with the mocked args |     OpenAIServingChat._load_chat_template(mock_serving_chat, | ||||||
|     load_chat_template(mock_args, tokenizer) |                                           chat_template=chatml_jinja_path) | ||||||
|  |  | ||||||
|     template_content = tokenizer.chat_template |     template_content = tokenizer.chat_template | ||||||
|  |  | ||||||
| @ -73,11 +79,11 @@ def test_load_chat_template(): | |||||||
| def test_no_load_chat_template(): | def test_no_load_chat_template(): | ||||||
|     # Testing chatml template |     # Testing chatml template | ||||||
|     template = "../../examples/does_not_exist" |     template = "../../examples/does_not_exist" | ||||||
|     mock_args = Namespace(chat_template=template) |  | ||||||
|     tokenizer = MockTokenizer() |     tokenizer = MockTokenizer() | ||||||
|  |  | ||||||
|     # Call the function with the mocked args |     mock_serving_chat = MockServingChat(tokenizer) | ||||||
|     load_chat_template(mock_args, tokenizer=tokenizer) |     OpenAIServingChat._load_chat_template(mock_serving_chat, | ||||||
|  |                                           chat_template=template) | ||||||
|     template_content = tokenizer.chat_template |     template_content = tokenizer.chat_template | ||||||
|  |  | ||||||
|     # Test assertions |     # Test assertions | ||||||
| @ -94,9 +100,9 @@ async def test_get_gen_prompt(model, template, add_generation_prompt, | |||||||
|                               expected_output): |                               expected_output): | ||||||
|     # Initialize the tokenizer |     # Initialize the tokenizer | ||||||
|     tokenizer = get_tokenizer(tokenizer_name=model) |     tokenizer = get_tokenizer(tokenizer_name=model) | ||||||
|  |     mock_serving_chat = MockServingChat(tokenizer) | ||||||
|     mock_args = Namespace(chat_template=template) |     OpenAIServingChat._load_chat_template(mock_serving_chat, | ||||||
|     load_chat_template(mock_args, tokenizer) |                                           chat_template=template) | ||||||
|  |  | ||||||
|     # Create a mock request object using keyword arguments |     # Create a mock request object using keyword arguments | ||||||
|     mock_request = ChatCompletionRequest( |     mock_request = ChatCompletionRequest( | ||||||
| @ -112,8 +118,3 @@ async def test_get_gen_prompt(model, template, add_generation_prompt, | |||||||
|  |  | ||||||
|     # Test assertion |     # Test assertion | ||||||
|     assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}" |     assert result == expected_output, f"The generated prompt does not match the expected output for model {model} and template {template}" | ||||||
|  |  | ||||||
|  |  | ||||||
| def test_health_endpoint(): |  | ||||||
|     response = client.get("/health") |  | ||||||
|     assert response.status_code == 200 |  | ||||||
| @ -8,8 +8,9 @@ from transformers import AutoModelForCausalLM | |||||||
| from vllm import LLM, SamplingParams | from vllm import LLM, SamplingParams | ||||||
| from vllm.transformers_utils.tokenizer import get_tokenizer | from vllm.transformers_utils.tokenizer import get_tokenizer | ||||||
|  |  | ||||||
| _TEST_PROMPTS = ["prompts/example.txt"] | _TEST_DIR = os.path.dirname(__file__) | ||||||
| _LONG_PROMPTS = ["prompts/summary.txt"] | _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] | ||||||
|  | _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] | ||||||
|  |  | ||||||
|  |  | ||||||
| def _read_prompts(filename: str) -> str: | def _read_prompts(filename: str) -> str: | ||||||
| @ -24,7 +25,7 @@ def _read_prompts(filename: str) -> str: | |||||||
| def example_prompts() -> List[str]: | def example_prompts() -> List[str]: | ||||||
|     prompts = [] |     prompts = [] | ||||||
|     for filename in _TEST_PROMPTS: |     for filename in _TEST_PROMPTS: | ||||||
|         prompts += _read_prompts(os.path.join("tests", filename)) |         prompts += _read_prompts(filename) | ||||||
|     return prompts |     return prompts | ||||||
|  |  | ||||||
|  |  | ||||||
| @ -32,7 +33,7 @@ def example_prompts() -> List[str]: | |||||||
| def example_long_prompts() -> List[str]: | def example_long_prompts() -> List[str]: | ||||||
|     prompts = [] |     prompts = [] | ||||||
|     for filename in _LONG_PROMPTS: |     for filename in _LONG_PROMPTS: | ||||||
|         prompts += _read_prompts(os.path.join("tests", filename)) |         prompts += _read_prompts(filename) | ||||||
|     return prompts |     return prompts | ||||||
|  |  | ||||||
|  |  | ||||||
|  | |||||||
| @ -2,32 +2,20 @@ | |||||||
|  |  | ||||||
| Run `pytest tests/distributed/test_comm_ops.py --forked`. | Run `pytest tests/distributed/test_comm_ops.py --forked`. | ||||||
| """ | """ | ||||||
| from multiprocessing import Process, set_start_method |  | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
| import torch | import torch | ||||||
|  | import ray | ||||||
|  |  | ||||||
| from vllm.config import ParallelConfig |  | ||||||
| from vllm.engine.ray_utils import get_open_port |  | ||||||
| from vllm.model_executor.parallel_utils.communication_op import ( | from vllm.model_executor.parallel_utils.communication_op import ( | ||||||
|     tensor_model_parallel_all_reduce, |     tensor_model_parallel_all_reduce, | ||||||
|     tensor_model_parallel_all_gather, |     tensor_model_parallel_all_gather, | ||||||
|  |     broadcast_tensor_dict, | ||||||
| ) | ) | ||||||
| from vllm.worker.worker import _init_distributed_environment | from vllm.test_utils import (init_test_distributed_environment, | ||||||
|  |                              multi_process_tensor_parallel) | ||||||
|  |  | ||||||
| def init_test_distributed_environment(pipeline_parallel_size: int, |  | ||||||
|                                       tensor_parallel_size: int, rank: int, |  | ||||||
|                                       distributed_init_port: str): |  | ||||||
|     parallel_config = ParallelConfig(pipeline_parallel_size, |  | ||||||
|                                      tensor_parallel_size, |  | ||||||
|                                      worker_use_ray=True) |  | ||||||
|     distributed_init_method = f"tcp://localhost:{distributed_init_port}" |  | ||||||
|     torch.cuda.set_device(rank) |  | ||||||
|     _init_distributed_environment(parallel_config, rank, |  | ||||||
|                                   distributed_init_method) |  | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @ray.remote(num_gpus=1, max_calls=1) | ||||||
| def all_reduce_test_worker(tensor_parallel_size: int, rank: int, | def all_reduce_test_worker(tensor_parallel_size: int, rank: int, | ||||||
|                            distributed_init_port: str): |                            distributed_init_port: str): | ||||||
|     init_test_distributed_environment(1, tensor_parallel_size, rank, |     init_test_distributed_environment(1, tensor_parallel_size, rank, | ||||||
| @ -43,6 +31,7 @@ def all_reduce_test_worker(tensor_parallel_size: int, rank: int, | |||||||
|     assert torch.allclose(t, expected) |     assert torch.allclose(t, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @ray.remote(num_gpus=1, max_calls=1) | ||||||
| def all_gather_test_worker(tensor_parallel_size: int, rank: int, | def all_gather_test_worker(tensor_parallel_size: int, rank: int, | ||||||
|                            distributed_init_port: str): |                            distributed_init_port: str): | ||||||
|     init_test_distributed_environment(1, tensor_parallel_size, rank, |     init_test_distributed_environment(1, tensor_parallel_size, rank, | ||||||
| @ -64,20 +53,40 @@ def all_gather_test_worker(tensor_parallel_size: int, rank: int, | |||||||
|         assert torch.allclose(t, expected) |         assert torch.allclose(t, expected) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @ray.remote(num_gpus=1, max_calls=1) | ||||||
|  | def broadcast_tensor_dict_test_worker(tensor_parallel_size: int, rank: int, | ||||||
|  |                                       distributed_init_port: str): | ||||||
|  |     init_test_distributed_environment(1, tensor_parallel_size, rank, | ||||||
|  |                                       distributed_init_port) | ||||||
|  |     test_dict = { | ||||||
|  |         "a": torch.arange(8, dtype=torch.float32, device="cuda"), | ||||||
|  |         "b": torch.arange(16, dtype=torch.int8, device="cuda"), | ||||||
|  |         "c": "test", | ||||||
|  |         "d": [1, 2, 3], | ||||||
|  |         "e": { | ||||||
|  |             "a": 1, | ||||||
|  |             "b": 2 | ||||||
|  |         }, | ||||||
|  |     } | ||||||
|  |  | ||||||
|  |     if rank == 0: | ||||||
|  |         broadcast_tensor_dict(test_dict, src=0) | ||||||
|  |     else: | ||||||
|  |         recv_dict = broadcast_tensor_dict(src=0) | ||||||
|  |         assert len(recv_dict) == len(test_dict) | ||||||
|  |         assert torch.allclose(recv_dict["a"], test_dict["a"]) | ||||||
|  |         assert torch.allclose(recv_dict["b"], test_dict["b"]) | ||||||
|  |         assert recv_dict["c"] == test_dict["c"] | ||||||
|  |         assert recv_dict["d"] == test_dict["d"] | ||||||
|  |         assert recv_dict["e"] == test_dict["e"] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.skipif(torch.cuda.device_count() < 2, | @pytest.mark.skipif(torch.cuda.device_count() < 2, | ||||||
|                     reason="Need at least 2 GPUs to run the test.") |                     reason="Need at least 2 GPUs to run the test.") | ||||||
| @pytest.mark.parametrize("tensor_parallel_size", [2]) | @pytest.mark.parametrize("tensor_parallel_size", [2]) | ||||||
| @pytest.mark.parametrize("test_target", | @pytest.mark.parametrize("test_target", [ | ||||||
|                          [all_reduce_test_worker, all_gather_test_worker]) |     all_reduce_test_worker, all_gather_test_worker, | ||||||
|  |     broadcast_tensor_dict_test_worker | ||||||
|  | ]) | ||||||
| def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): | def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): | ||||||
|     set_start_method("spawn", force=True) |     multi_process_tensor_parallel(tensor_parallel_size, test_target) | ||||||
|     distributed_init_port = get_open_port() |  | ||||||
|     processes = [] |  | ||||||
|     for rank in range(tensor_parallel_size): |  | ||||||
|         p = Process(target=test_target, |  | ||||||
|                     args=(tensor_parallel_size, rank, distributed_init_port)) |  | ||||||
|         p.start() |  | ||||||
|         processes.append(p) |  | ||||||
|     for p in processes: |  | ||||||
|         p.join() |  | ||||||
|     assert all(p.exitcode == 0 for p in processes) |  | ||||||
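The distributed tests above hand process management to vllm.test_utils.multi_process_tensor_parallel together with @ray.remote workers. That helper's body is not part of this diff; a rough sketch of what it is assumed to do, reusing get_open_port from the import removed above (the real implementation may differ):

import ray
from vllm.engine.ray_utils import get_open_port

def multi_process_tensor_parallel(tensor_parallel_size, test_target):
    # test_target is already wrapped with @ray.remote(num_gpus=1, max_calls=1),
    # so one Ray task per rank replaces the old multiprocessing.Process loop.
    port = str(get_open_port())
    refs = [
        test_target.remote(tensor_parallel_size, rank, port)
        for rank in range(tensor_parallel_size)
    ]
    ray.get(refs)  # re-raises any assertion failure from the workers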
|  | |||||||
tests/distributed/test_custom_all_reduce.py (new file, 85 lines)
							| @ -0,0 +1,85 @@ | |||||||
|  | import random | ||||||
|  |  | ||||||
|  | import os | ||||||
|  | import pytest | ||||||
|  | import ray | ||||||
|  | import torch | ||||||
|  | import torch.distributed as dist | ||||||
|  |  | ||||||
|  | from vllm.model_executor.parallel_utils import custom_all_reduce as custom_ar | ||||||
|  | from vllm.model_executor.parallel_utils.communication_op import ( | ||||||
|  |     tensor_model_parallel_all_reduce) | ||||||
|  | from vllm.test_utils import (init_test_distributed_environment, | ||||||
|  |                              multi_process_tensor_parallel) | ||||||
|  |  | ||||||
|  | random.seed(42) | ||||||
|  | test_sizes = [random.randint(1024, 2048 * 1024) for _ in range(8)] | ||||||
|  | for i, v in enumerate(test_sizes): | ||||||
|  |     test_sizes[i] -= v % 8 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @ray.remote(num_gpus=1, max_calls=1) | ||||||
|  | def graph_allreduce(world_size, rank, distributed_init_port): | ||||||
|  |     del os.environ["CUDA_VISIBLE_DEVICES"] | ||||||
|  |     device = torch.device(f"cuda:{rank}") | ||||||
|  |     torch.cuda.set_device(device) | ||||||
|  |     init_test_distributed_environment(1, world_size, rank, | ||||||
|  |                                       distributed_init_port) | ||||||
|  |  | ||||||
|  |     custom_ar.init_custom_ar() | ||||||
|  |     for sz in test_sizes: | ||||||
|  |         for dtype in [torch.float32, torch.float16, torch.bfloat16]: | ||||||
|  |             with custom_ar.capture(): | ||||||
|  |                 # use integers so result matches NCCL exactly | ||||||
|  |                 inp1 = torch.randint(1, | ||||||
|  |                                      16, (sz, ), | ||||||
|  |                                      dtype=dtype, | ||||||
|  |                                      device=torch.cuda.current_device()) | ||||||
|  |                 inp2 = torch.randint(1, | ||||||
|  |                                      16, (sz, ), | ||||||
|  |                                      dtype=dtype, | ||||||
|  |                                      device=torch.cuda.current_device()) | ||||||
|  |                 torch.cuda.synchronize() | ||||||
|  |                 graph = torch.cuda.CUDAGraph() | ||||||
|  |                 with torch.cuda.graph(graph): | ||||||
|  |                     out1 = tensor_model_parallel_all_reduce(inp1) | ||||||
|  |                     # the input buffer is immediately modified to test | ||||||
|  |                     # synchronization | ||||||
|  |                     dist.all_reduce(inp1) | ||||||
|  |                     out2 = tensor_model_parallel_all_reduce(inp2) | ||||||
|  |                     dist.all_reduce(inp2) | ||||||
|  |             graph.replay() | ||||||
|  |             assert torch.allclose(out1, inp1) | ||||||
|  |             assert torch.allclose(out2, inp2) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @ray.remote(num_gpus=1, max_calls=1) | ||||||
|  | def eager_allreduce(world_size, rank, distributed_init_port): | ||||||
|  |     del os.environ["CUDA_VISIBLE_DEVICES"] | ||||||
|  |     device = torch.device(f"cuda:{rank}") | ||||||
|  |     torch.cuda.set_device(device) | ||||||
|  |     init_test_distributed_environment(1, world_size, rank, | ||||||
|  |                                       distributed_init_port) | ||||||
|  |  | ||||||
|  |     sz = 1024 | ||||||
|  |     custom_ar.init_custom_ar() | ||||||
|  |     fa = custom_ar.get_handle() | ||||||
|  |     inp = torch.ones(sz, dtype=torch.float32, device=device) | ||||||
|  |     out = fa.all_reduce_unreg(inp) | ||||||
|  |     assert torch.allclose(out, inp * world_size) | ||||||
|  |  | ||||||
|  |     inp = torch.ones(sz * 4, dtype=torch.bfloat16, device=device) | ||||||
|  |     out = fa.all_reduce_unreg(inp) | ||||||
|  |     assert torch.allclose(out, inp * world_size) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.skipif(torch.cuda.device_count() < 2, | ||||||
|  |                     reason="Need at least 2 GPUs to run the test.") | ||||||
|  | @pytest.mark.parametrize("tensor_parallel_size", [2]) | ||||||
|  | @pytest.mark.parametrize("test_target", [eager_allreduce, graph_allreduce]) | ||||||
|  | def test_multi_process_tensor_parallel(tensor_parallel_size, test_target): | ||||||
|  |     multi_process_tensor_parallel(tensor_parallel_size, test_target) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     multi_process_tensor_parallel(2, graph_allreduce) | ||||||
tests/entrypoints/test_openai_server.py (new file, 254 lines)
							| @ -0,0 +1,254 @@ | |||||||
|  | import os | ||||||
|  | import subprocess | ||||||
|  | import time | ||||||
|  |  | ||||||
|  | import sys | ||||||
|  | import pytest | ||||||
|  | import requests | ||||||
|  | import ray  # using Ray for overall ease of process management, parallel requests, and debugging. | ||||||
|  | import openai  # use the official client for correctness check | ||||||
|  |  | ||||||
|  | MAX_SERVER_START_WAIT_S = 600  # wait up to 600 seconds for the server to start | ||||||
|  | MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"  # any model with a chat template should work here | ||||||
|  |  | ||||||
|  | pytestmark = pytest.mark.asyncio | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @ray.remote(num_gpus=1) | ||||||
|  | class ServerRunner: | ||||||
|  |  | ||||||
|  |     def __init__(self, args): | ||||||
|  |         env = os.environ.copy() | ||||||
|  |         env["PYTHONUNBUFFERED"] = "1" | ||||||
|  |         self.proc = subprocess.Popen( | ||||||
|  |             ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, | ||||||
|  |             env=env, | ||||||
|  |             stdout=sys.stdout, | ||||||
|  |             stderr=sys.stderr, | ||||||
|  |         ) | ||||||
|  |         self._wait_for_server() | ||||||
|  |  | ||||||
|  |     def ready(self): | ||||||
|  |         return True | ||||||
|  |  | ||||||
|  |     def _wait_for_server(self): | ||||||
|  |         # run health check | ||||||
|  |         start = time.time() | ||||||
|  |         while True: | ||||||
|  |             try: | ||||||
|  |                 if requests.get( | ||||||
|  |                         "http://localhost:8000/health").status_code == 200: | ||||||
|  |                     break | ||||||
|  |             except Exception as err: | ||||||
|  |                 if self.proc.poll() is not None: | ||||||
|  |                     raise RuntimeError("Server exited unexpectedly.") from err | ||||||
|  |  | ||||||
|  |                 time.sleep(0.5) | ||||||
|  |                 if time.time() - start > MAX_SERVER_START_WAIT_S: | ||||||
|  |                     raise RuntimeError( | ||||||
|  |                         "Server failed to start in time.") from err | ||||||
|  |  | ||||||
|  |     def __del__(self): | ||||||
|  |         if hasattr(self, "proc"): | ||||||
|  |             self.proc.terminate() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture(scope="session") | ||||||
|  | def server(): | ||||||
|  |     ray.init() | ||||||
|  |     server_runner = ServerRunner.remote([ | ||||||
|  |         "--model", | ||||||
|  |         MODEL_NAME, | ||||||
|  |         "--dtype", | ||||||
|  |         "bfloat16",  # use half precision for speed and memory savings in CI environment | ||||||
|  |         "--max-model-len", | ||||||
|  |         "8192", | ||||||
|  |         "--enforce-eager", | ||||||
|  |     ]) | ||||||
|  |     ray.get(server_runner.ready.remote()) | ||||||
|  |     yield server_runner | ||||||
|  |     ray.shutdown() | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.fixture(scope="session") | ||||||
|  | def client(): | ||||||
|  |     client = openai.AsyncOpenAI( | ||||||
|  |         base_url="http://localhost:8000/v1", | ||||||
|  |         api_key="token-abc123", | ||||||
|  |     ) | ||||||
|  |     yield client | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_single_completion(server, client: openai.AsyncOpenAI): | ||||||
|  |     completion = await client.completions.create(model=MODEL_NAME, | ||||||
|  |                                                  prompt="Hello, my name is", | ||||||
|  |                                                  max_tokens=5, | ||||||
|  |                                                  temperature=0.0) | ||||||
|  |  | ||||||
|  |     assert completion.id is not None | ||||||
|  |     assert completion.choices is not None and len(completion.choices) == 1 | ||||||
|  |     assert completion.choices[0].text is not None and len( | ||||||
|  |         completion.choices[0].text) >= 5 | ||||||
|  |     assert completion.choices[0].finish_reason == "length" | ||||||
|  |     assert completion.usage == openai.types.CompletionUsage( | ||||||
|  |         completion_tokens=5, prompt_tokens=6, total_tokens=11) | ||||||
|  |  | ||||||
|  |     # test using token IDs | ||||||
|  |     completion = await client.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         prompt=[0, 0, 0, 0, 0], | ||||||
|  |         max_tokens=5, | ||||||
|  |         temperature=0.0, | ||||||
|  |     ) | ||||||
|  |     assert completion.choices[0].text is not None and len( | ||||||
|  |         completion.choices[0].text) >= 5 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_single_chat_session(server, client: openai.AsyncOpenAI): | ||||||
|  |     messages = [{ | ||||||
|  |         "role": "system", | ||||||
|  |         "content": "you are a helpful assistant" | ||||||
|  |     }, { | ||||||
|  |         "role": "user", | ||||||
|  |         "content": "what is 1+1?" | ||||||
|  |     }] | ||||||
|  |  | ||||||
|  |     # test single completion | ||||||
|  |     chat_completion = await client.chat.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         messages=messages, | ||||||
|  |         max_tokens=10, | ||||||
|  |     ) | ||||||
|  |     assert chat_completion.id is not None | ||||||
|  |     assert chat_completion.choices is not None and len( | ||||||
|  |         chat_completion.choices) == 1 | ||||||
|  |     assert chat_completion.choices[0].message is not None | ||||||
|  |     message = chat_completion.choices[0].message | ||||||
|  |     assert message.content is not None and len(message.content) >= 10 | ||||||
|  |     assert message.role == "assistant" | ||||||
|  |     messages.append({"role": "assistant", "content": message.content}) | ||||||
|  |  | ||||||
|  |     # test multi-turn dialogue | ||||||
|  |     messages.append({"role": "user", "content": "express your result in json"}) | ||||||
|  |     chat_completion = await client.chat.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         messages=messages, | ||||||
|  |         max_tokens=10, | ||||||
|  |     ) | ||||||
|  |     message = chat_completion.choices[0].message | ||||||
|  |     assert message.content is not None and len(message.content) >= 0 | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_completion_streaming(server, client: openai.AsyncOpenAI): | ||||||
|  |     prompt = "What is an LLM?" | ||||||
|  |  | ||||||
|  |     single_completion = await client.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         prompt=prompt, | ||||||
|  |         max_tokens=5, | ||||||
|  |         temperature=0.0, | ||||||
|  |     ) | ||||||
|  |     single_output = single_completion.choices[0].text | ||||||
|  |     single_usage = single_completion.usage | ||||||
|  |  | ||||||
|  |     stream = await client.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         prompt=prompt, | ||||||
|  |         max_tokens=5, | ||||||
|  |         temperature=0.0, | ||||||
|  |         stream=True, | ||||||
|  |     ) | ||||||
|  |     chunks = [] | ||||||
|  |     async for chunk in stream: | ||||||
|  |         chunks.append(chunk.choices[0].text) | ||||||
|  |     assert chunk.choices[0].finish_reason == "length" | ||||||
|  |     assert chunk.usage == single_usage | ||||||
|  |     assert "".join(chunks) == single_output | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_chat_streaming(server, client: openai.AsyncOpenAI): | ||||||
|  |     messages = [{ | ||||||
|  |         "role": "system", | ||||||
|  |         "content": "you are a helpful assistant" | ||||||
|  |     }, { | ||||||
|  |         "role": "user", | ||||||
|  |         "content": "what is 1+1?" | ||||||
|  |     }] | ||||||
|  |  | ||||||
|  |     # test single completion | ||||||
|  |     chat_completion = await client.chat.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         messages=messages, | ||||||
|  |         max_tokens=10, | ||||||
|  |         temperature=0.0, | ||||||
|  |     ) | ||||||
|  |     output = chat_completion.choices[0].message.content | ||||||
|  |     stop_reason = chat_completion.choices[0].finish_reason | ||||||
|  |  | ||||||
|  |     # test streaming | ||||||
|  |     stream = await client.chat.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         messages=messages, | ||||||
|  |         max_tokens=10, | ||||||
|  |         temperature=0.0, | ||||||
|  |         stream=True, | ||||||
|  |     ) | ||||||
|  |     chunks = [] | ||||||
|  |     async for chunk in stream: | ||||||
|  |         delta = chunk.choices[0].delta | ||||||
|  |         if delta.role: | ||||||
|  |             assert delta.role == "assistant" | ||||||
|  |         if delta.content: | ||||||
|  |             chunks.append(delta.content) | ||||||
|  |     assert chunk.choices[0].finish_reason == stop_reason | ||||||
|  |     assert "".join(chunks) == output | ||||||
|  |  | ||||||
|  |  | ||||||
|  | async def test_batch_completions(server, client: openai.AsyncOpenAI): | ||||||
|  |     # test simple list | ||||||
|  |     batch = await client.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         prompt=["Hello, my name is", "Hello, my name is"], | ||||||
|  |         max_tokens=5, | ||||||
|  |         temperature=0.0, | ||||||
|  |     ) | ||||||
|  |     assert len(batch.choices) == 2 | ||||||
|  |     assert batch.choices[0].text == batch.choices[1].text | ||||||
|  |  | ||||||
|  |     # test n = 2 | ||||||
|  |     batch = await client.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         prompt=["Hello, my name is", "Hello, my name is"], | ||||||
|  |         n=2, | ||||||
|  |         max_tokens=5, | ||||||
|  |         temperature=0.0, | ||||||
|  |         extra_body=dict( | ||||||
|  |             # NOTE: this has to be true for n > 1 in vLLM, but not necessary for official client. | ||||||
|  |             use_beam_search=True), | ||||||
|  |     ) | ||||||
|  |     assert len(batch.choices) == 4 | ||||||
|  |     assert batch.choices[0].text != batch.choices[ | ||||||
|  |         1].text, "beam search should be different" | ||||||
|  |     assert batch.choices[0].text == batch.choices[ | ||||||
|  |         2].text, "two copies of the same prompt should be the same" | ||||||
|  |     assert batch.choices[1].text == batch.choices[ | ||||||
|  |         3].text, "two copies of the same prompt should be the same" | ||||||
|  |  | ||||||
|  |     # test streaming | ||||||
|  |     batch = await client.completions.create( | ||||||
|  |         model=MODEL_NAME, | ||||||
|  |         prompt=["Hello, my name is", "Hello, my name is"], | ||||||
|  |         max_tokens=5, | ||||||
|  |         temperature=0.0, | ||||||
|  |         stream=True, | ||||||
|  |     ) | ||||||
|  |     texts = [""] * 2 | ||||||
|  |     async for chunk in batch: | ||||||
|  |         assert len(chunk.choices) == 1 | ||||||
|  |         choice = chunk.choices[0] | ||||||
|  |         texts[choice.index] += choice.text | ||||||
|  |     assert texts[0] == texts[1] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     pytest.main([__file__]) | ||||||
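The new test file above exercises the server through the official async client. For a quick manual check against the same server, the synchronous client works the same way; the model name, URL, and dummy API key are copied from the test setup:

import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1",
                       api_key="token-abc123")
completion = client.completions.create(model="HuggingFaceH4/zephyr-7b-beta",
                                        prompt="Hello, my name is",
                                        max_tokens=5,
                                        temperature=0.0)
print(completion.choices[0].text)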
| @ -1,43 +1,7 @@ | |||||||
| from typing import List, Tuple |  | ||||||
|  |  | ||||||
| import pytest | import pytest | ||||||
| import torch | from vllm.utils import create_kv_caches_with_random | ||||||
|  |  | ||||||
|  |  | ||||||
| def create_kv_caches( |  | ||||||
|     num_blocks: int, |  | ||||||
|     block_size: int, |  | ||||||
|     num_layers: int, |  | ||||||
|     num_heads: int, |  | ||||||
|     head_size: int, |  | ||||||
|     dtype: torch.dtype, |  | ||||||
|     seed: int, |  | ||||||
| ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: |  | ||||||
|     torch.random.manual_seed(seed) |  | ||||||
|     torch.cuda.manual_seed(seed) |  | ||||||
|  |  | ||||||
|     scale = head_size**-0.5 |  | ||||||
|     x = 16 // torch.tensor([], dtype=dtype).element_size() |  | ||||||
|     key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) |  | ||||||
|     key_caches = [] |  | ||||||
|     for _ in range(num_layers): |  | ||||||
|         key_cache = torch.empty(size=key_cache_shape, |  | ||||||
|                                 dtype=dtype, |  | ||||||
|                                 device='cuda') |  | ||||||
|         key_cache.uniform_(-scale, scale) |  | ||||||
|         key_caches.append(key_cache) |  | ||||||
|  |  | ||||||
|     value_cache_shape = (num_blocks, num_heads, head_size, block_size) |  | ||||||
|     value_caches = [] |  | ||||||
|     for _ in range(num_layers): |  | ||||||
|         value_cache = torch.empty(size=value_cache_shape, |  | ||||||
|                                   dtype=dtype, |  | ||||||
|                                   device='cuda') |  | ||||||
|         value_cache.uniform_(-scale, scale) |  | ||||||
|         value_caches.append(value_cache) |  | ||||||
|     return key_caches, value_caches |  | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.fixture() | @pytest.fixture() | ||||||
| def kv_cache_factory(): | def kv_cache_factory(): | ||||||
|     return create_kv_caches |     return create_kv_caches_with_random | ||||||
|  | |||||||
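The kv_cache_factory fixture above now simply returns vllm.utils.create_kv_caches_with_random. A hedged usage sketch, assuming the function keeps the removed helper's positional parameters (num_blocks, block_size, num_layers, num_heads, head_size, dtype, seed); the real signature may differ, for example to support the fp8 KV-cache dtype parametrized further down:

import torch

def test_kv_cache_shapes(kv_cache_factory):
    # Shapes follow the removed helper: key cache is
    # (num_blocks, num_heads, head_size // x, block_size, x) and
    # value cache is (num_blocks, num_heads, head_size, block_size).
    key_caches, value_caches = kv_cache_factory(128, 16, 2, 8, 64,
                                                torch.float16, 0)
    assert len(key_caches) == len(value_caches) == 2
    assert key_caches[0].shape[0] == 128
    assert value_caches[0].shape[-1] == 16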
| @ -7,22 +7,26 @@ DTYPES = [torch.half, torch.bfloat16, torch.float] | |||||||
| NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing | NUM_TOKENS = [7, 83, 2048]  # Arbitrary values for testing | ||||||
| D = [512, 4096, 5120, 13824]  # Arbitrary values for testing | D = [512, 4096, 5120, 13824]  # Arbitrary values for testing | ||||||
| SEEDS = [0] | SEEDS = [0] | ||||||
|  | DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("num_tokens", NUM_TOKENS) | @pytest.mark.parametrize("num_tokens", NUM_TOKENS) | ||||||
| @pytest.mark.parametrize("d", D) | @pytest.mark.parametrize("d", D) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| @torch.inference_mode() | @torch.inference_mode() | ||||||
| def test_silu_and_mul( | def test_silu_and_mul( | ||||||
|     num_tokens: int, |     num_tokens: int, | ||||||
|     d: int, |     d: int, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
| ) -> None: | ) -> None: | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda") |     gpu_id = f"cuda:{device}" | ||||||
|  |     x = torch.randn(num_tokens, 2 * d, dtype=dtype, device=gpu_id) | ||||||
|     layer = SiluAndMul() |     layer = SiluAndMul() | ||||||
|     out = layer(x) |     out = layer(x) | ||||||
|     ref_out = layer._forward(x) |     ref_out = layer._forward(x) | ||||||
| @ -33,16 +37,19 @@ def test_silu_and_mul( | |||||||
| @pytest.mark.parametrize("d", D) | @pytest.mark.parametrize("d", D) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| @torch.inference_mode() | @torch.inference_mode() | ||||||
| def test_gelu_new( | def test_gelu_new( | ||||||
|     num_tokens: int, |     num_tokens: int, | ||||||
|     d: int, |     d: int, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
| ) -> None: | ) -> None: | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") |     gpu_id = f"cuda:{device}" | ||||||
|  |     x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) | ||||||
|     layer = NewGELU() |     layer = NewGELU() | ||||||
|     out = layer(x) |     out = layer(x) | ||||||
|     ref_out = layer._forward(x) |     ref_out = layer._forward(x) | ||||||
| @ -53,15 +60,18 @@ def test_gelu_new( | |||||||
| @pytest.mark.parametrize("d", D) | @pytest.mark.parametrize("d", D) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| def test_gelu_fast( | def test_gelu_fast( | ||||||
|     num_tokens: int, |     num_tokens: int, | ||||||
|     d: int, |     d: int, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
| ) -> None: | ) -> None: | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|     x = torch.randn(num_tokens, d, dtype=dtype, device="cuda") |     gpu_id = f"cuda:{device}" | ||||||
|  |     x = torch.randn(num_tokens, d, dtype=dtype, device=gpu_id) | ||||||
|     layer = FastGELU() |     layer = FastGELU() | ||||||
|     out = layer(x) |     out = layer(x) | ||||||
|     ref_out = layer._forward(x) |     ref_out = layer._forward(x) | ||||||
|  | |||||||
| @ -6,14 +6,16 @@ import torch | |||||||
| from xformers import ops as xops | from xformers import ops as xops | ||||||
| from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask | from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask | ||||||
|  |  | ||||||
| from vllm._C import ops | from vllm._C import ops, cache_ops | ||||||
| from vllm.utils import get_max_shared_memory_bytes | from vllm.utils import get_max_shared_memory_bytes | ||||||
|  |  | ||||||
| FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 | FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 | ||||||
| # This will change depending on the compute capability. | # This will change depending on the compute capability. | ||||||
| # - 512 as a buffer | # - 512 as a buffer | ||||||
| MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 | MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 | ||||||
| NUM_BLOCKS = 40000  # Arbitrary values for testing | # There may not be enough gpu memory due to large NUM_BLOCKS. | ||||||
|  | # Reduce NUM_BLOCKS when it happens. | ||||||
|  | NUM_BLOCKS = 4321  # Arbitrary values for testing | ||||||
| PARTITION_SIZE = 512 | PARTITION_SIZE = 512 | ||||||
|  |  | ||||||
| DTYPES = [torch.half, torch.bfloat16, torch.float] | DTYPES = [torch.half, torch.bfloat16, torch.float] | ||||||
| @ -23,7 +25,9 @@ NUM_HEADS = [(40, 40), (64, 8)]  # Arbitrary values for testing | |||||||
| HEAD_SIZES = [64, 80, 96, 112, 128, 256] | HEAD_SIZES = [64, 80, 96, 112, 128, 256] | ||||||
| BLOCK_SIZES = [16, 32] | BLOCK_SIZES = [16, 32] | ||||||
| USE_ALIBI = [False, True] | USE_ALIBI = [False, True] | ||||||
|  | KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] | ||||||
| SEEDS = [0] | SEEDS = [0] | ||||||
|  | DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] | ||||||
|  |  | ||||||
|  |  | ||||||
| def ref_masked_attention( | def ref_masked_attention( | ||||||
| @ -87,7 +91,7 @@ def ref_single_query_cached_kv_attention( | |||||||
|         alibi_bias = None |         alibi_bias = None | ||||||
|         if alibi_slopes is not None: |         if alibi_slopes is not None: | ||||||
|             # Create the ALiBi bias used in the paged attention kernel. |             # Create the ALiBi bias used in the paged attention kernel. | ||||||
|             position_ids = torch.arange(context_len, device="cuda").int() |             position_ids = torch.arange(context_len, device=query.device).int() | ||||||
|             alibi_bias = (position_ids - context_len + 1).float() |             alibi_bias = (position_ids - context_len + 1).float() | ||||||
|             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( |             alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( | ||||||
|                 1, 1, -1) |                 1, 1, -1) | ||||||
| @ -104,7 +108,9 @@ def ref_single_query_cached_kv_attention( | |||||||
| @pytest.mark.parametrize("use_alibi", USE_ALIBI) | @pytest.mark.parametrize("use_alibi", USE_ALIBI) | ||||||
| @pytest.mark.parametrize("block_size", BLOCK_SIZES) | @pytest.mark.parametrize("block_size", BLOCK_SIZES) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
|  | @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| def test_paged_attention( | def test_paged_attention( | ||||||
|     kv_cache_factory, |     kv_cache_factory, | ||||||
|     version: str, |     version: str, | ||||||
| @ -114,19 +120,21 @@ def test_paged_attention( | |||||||
|     use_alibi: bool, |     use_alibi: bool, | ||||||
|     block_size: int, |     block_size: int, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|  |     kv_cache_dtype: str, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
| ) -> None: | ) -> None: | ||||||
|     random.seed(seed) |     random.seed(seed) | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|  |     gpu_id = f"cuda:{device}" | ||||||
|     scale = float(1.0 / (head_size**0.5)) |     scale = float(1.0 / (head_size**0.5)) | ||||||
|     num_query_heads, num_kv_heads = num_heads |     num_query_heads, num_kv_heads = num_heads | ||||||
|     query = torch.empty(num_seqs, |     query = torch.empty(num_seqs, | ||||||
|                         num_query_heads, |                         num_query_heads, | ||||||
|                         head_size, |                         head_size, | ||||||
|                         dtype=dtype, |                         dtype=dtype, | ||||||
|                         device="cuda") |                         device=gpu_id) | ||||||
|     query.uniform_(-scale, scale) |     query.uniform_(-scale, scale) | ||||||
|  |  | ||||||
|     assert num_query_heads % num_kv_heads == 0 |     assert num_query_heads % num_kv_heads == 0 | ||||||
| @ -135,12 +143,12 @@ def test_paged_attention( | |||||||
|     if use_alibi: |     if use_alibi: | ||||||
|         alibi_slopes = torch.randn(num_query_heads, |         alibi_slopes = torch.randn(num_query_heads, | ||||||
|                                    dtype=torch.float, |                                    dtype=torch.float, | ||||||
|                                    device="cuda") |                                    device=gpu_id) | ||||||
|  |  | ||||||
|     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] |     context_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] | ||||||
|     context_lens[-1] = MAX_SEQ_LEN |     context_lens[-1] = MAX_SEQ_LEN | ||||||
|     max_context_len = max(context_lens) |     max_context_len = max(context_lens) | ||||||
|     context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") |     context_lens = torch.tensor(context_lens, dtype=torch.int, device=gpu_id) | ||||||
|  |  | ||||||
|     # Create the block tables. |     # Create the block tables. | ||||||
|     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size |     max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size | ||||||
| @ -151,12 +159,13 @@ def test_paged_attention( | |||||||
|             for _ in range(max_num_blocks_per_seq) |             for _ in range(max_num_blocks_per_seq) | ||||||
|         ] |         ] | ||||||
|         block_tables.append(block_table) |         block_tables.append(block_table) | ||||||
|     block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda") |     block_tables = torch.tensor(block_tables, dtype=torch.int, device=gpu_id) | ||||||
|  |  | ||||||
|     # Create the KV caches. |     # Create the KV caches. | ||||||
|     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, |     key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, | ||||||
|                                                 num_kv_heads, head_size, dtype, |                                                 num_kv_heads, head_size, | ||||||
|                                                 seed) |                                                 kv_cache_dtype, dtype, seed, | ||||||
|  |                                                 gpu_id) | ||||||
|     key_cache, value_cache = key_caches[0], value_caches[0] |     key_cache, value_cache = key_caches[0], value_caches[0] | ||||||
|  |  | ||||||
|     # Call the paged attention kernel. |     # Call the paged attention kernel. | ||||||
| @ -174,6 +183,7 @@ def test_paged_attention( | |||||||
|             block_size, |             block_size, | ||||||
|             max_context_len, |             max_context_len, | ||||||
|             alibi_slopes, |             alibi_slopes, | ||||||
|  |             kv_cache_dtype, | ||||||
|         ) |         ) | ||||||
|     elif version == "v2": |     elif version == "v2": | ||||||
|         num_partitions = ((max_context_len + PARTITION_SIZE - 1) // |         num_partitions = ((max_context_len + PARTITION_SIZE - 1) // | ||||||
| @ -206,11 +216,30 @@ def test_paged_attention( | |||||||
|             block_size, |             block_size, | ||||||
|             max_context_len, |             max_context_len, | ||||||
|             alibi_slopes, |             alibi_slopes, | ||||||
|  |             kv_cache_dtype, | ||||||
|         ) |         ) | ||||||
|     else: |     else: | ||||||
|         raise AssertionError(f"Unknown version: {version}") |         raise AssertionError(f"Unknown version: {version}") | ||||||
|  |  | ||||||
|     # Run the reference implementation. |     # Run the reference implementation. | ||||||
|  |     if kv_cache_dtype == "fp8_e5m2": | ||||||
|  |         # Convert cache data back to dtype. | ||||||
|  |         x = 16 // torch.tensor([], dtype=dtype).element_size() | ||||||
|  |         key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, | ||||||
|  |                            block_size, x) | ||||||
|  |         dequantized_key_cache = torch.empty(size=key_cache_shape, | ||||||
|  |                                             dtype=dtype, | ||||||
|  |                                             device=gpu_id) | ||||||
|  |         cache_ops.convert_fp8_e5m2(key_cache, dequantized_key_cache) | ||||||
|  |         key_cache = dequantized_key_cache | ||||||
|  |  | ||||||
|  |         value_cache_shape = value_cache.shape | ||||||
|  |         dequantized_value_cache = torch.empty(size=value_cache_shape, | ||||||
|  |                                               dtype=dtype, | ||||||
|  |                                               device=gpu_id) | ||||||
|  |         cache_ops.convert_fp8_e5m2(value_cache, dequantized_value_cache) | ||||||
|  |         value_cache = dequantized_value_cache | ||||||
|  |  | ||||||
|     ref_output = torch.empty_like(query) |     ref_output = torch.empty_like(query) | ||||||
|     ref_single_query_cached_kv_attention( |     ref_single_query_cached_kv_attention( | ||||||
|         ref_output, |         ref_output, | ||||||
| @ -227,7 +256,12 @@ def test_paged_attention( | |||||||
|     # NOTE(woosuk): Due to the kernel-level differences in the two |     # NOTE(woosuk): Due to the kernel-level differences in the two | ||||||
|     # implementations, there is a small numerical difference in the two |     # implementations, there is a small numerical difference in the two | ||||||
|     # outputs. Thus, we use a relaxed tolerance for the test. |     # outputs. Thus, we use a relaxed tolerance for the test. | ||||||
|     assert torch.allclose(output, ref_output, atol=1e-3, rtol=1e-5) |     # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, | ||||||
|  |     # so we use a relaxed tolerance for the test. | ||||||
|  |     atol, rtol = 1e-3, 1e-5 | ||||||
|  |     if kv_cache_dtype == "fp8_e5m2": | ||||||
|  |         atol, rtol = 1e-2, 1e-5 | ||||||
|  |     assert torch.allclose(output, ref_output, atol=atol, rtol=rtol) | ||||||
|  |  | ||||||
|  |  | ||||||
| def ref_multi_query_kv_attention( | def ref_multi_query_kv_attention( | ||||||
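The relaxed atol for the fp8_e5m2 case above can be sanity-checked with a quick round trip (a sketch, assuming a PyTorch build that exposes torch.float8_e5m2; not part of the test suite). With only two mantissa bits, e5m2 rounds each cached value with up to roughly 12.5% relative error, which surfaces in the attention output at about the 1e-2 level.

import torch

# Round-trip values of the same magnitude as the cache contents (~0.1).
x = torch.randn(4096, dtype=torch.float16) * 0.1
x_rt = x.to(torch.float8_e5m2).to(torch.float16)
print((x - x_rt).abs().max())  # typically on the order of 1e-2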
| @ -249,7 +283,7 @@ def ref_multi_query_kv_attention( | |||||||
|         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), |         attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), | ||||||
|                                diagonal=1) |                                diagonal=1) | ||||||
|         attn_mask = attn_mask * torch.finfo(dtype).min |         attn_mask = attn_mask * torch.finfo(dtype).min | ||||||
|         attn_mask = attn_mask.to(dtype=dtype, device="cuda") |         attn_mask = attn_mask.to(dtype=dtype, device=query.device) | ||||||
|  |  | ||||||
|         ref_output = ref_masked_attention( |         ref_output = ref_masked_attention( | ||||||
|             query[start_idx:end_idx], |             query[start_idx:end_idx], | ||||||
| @ -269,6 +303,7 @@ def ref_multi_query_kv_attention( | |||||||
| @pytest.mark.parametrize("head_size", HEAD_SIZES) | @pytest.mark.parametrize("head_size", HEAD_SIZES) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| @torch.inference_mode() | @torch.inference_mode() | ||||||
| def test_multi_query_kv_attention( | def test_multi_query_kv_attention( | ||||||
|     num_seqs: int, |     num_seqs: int, | ||||||
| @ -276,11 +311,12 @@ def test_multi_query_kv_attention( | |||||||
|     head_size: int, |     head_size: int, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
| ) -> None: | ) -> None: | ||||||
|     random.seed(seed) |     random.seed(seed) | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|  |     gpu_id = f"cuda:{device}" | ||||||
|     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. |     # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. | ||||||
|     # As the xformers library is already tested with its own tests, we can use |     # As the xformers library is already tested with its own tests, we can use | ||||||
|     # a smaller MAX_SEQ_LEN here. |     # a smaller MAX_SEQ_LEN here. | ||||||
| @ -294,7 +330,7 @@ def test_multi_query_kv_attention( | |||||||
|                       num_query_heads + 2 * num_kv_heads, |                       num_query_heads + 2 * num_kv_heads, | ||||||
|                       head_size, |                       head_size, | ||||||
|                       dtype=dtype, |                       dtype=dtype, | ||||||
|                       device="cuda") |                       device=gpu_id) | ||||||
|     qkv.uniform_(-scale, scale) |     qkv.uniform_(-scale, scale) | ||||||
|     query, key, value = qkv.split( |     query, key, value = qkv.split( | ||||||
|         [num_query_heads, num_kv_heads, num_kv_heads], dim=1) |         [num_query_heads, num_kv_heads, num_kv_heads], dim=1) | ||||||
|  | |||||||
| @ -3,17 +3,22 @@ import random | |||||||
| import pytest | import pytest | ||||||
| import torch | import torch | ||||||
|  |  | ||||||
|  | from typing import Tuple | ||||||
|  |  | ||||||
| from vllm._C import cache_ops | from vllm._C import cache_ops | ||||||
|  |  | ||||||
|  | COPYING_DIRECTION = [('cuda', 'cpu'), ('cuda', 'cuda'), ('cpu', 'cuda')] | ||||||
| DTYPES = [torch.half, torch.bfloat16, torch.float] | DTYPES = [torch.half, torch.bfloat16, torch.float] | ||||||
| NUM_TOKENS = [83]  # Arbitrary values for testing | NUM_TOKENS = [42]  # Arbitrary values for testing | ||||||
| NUM_LAYERS = [1]  # Arbitrary values for testing | NUM_LAYERS = [1]  # Arbitrary values for testing | ||||||
| NUM_HEADS = [8]  # Arbitrary values for testing | NUM_HEADS = [8]  # Arbitrary values for testing | ||||||
| HEAD_SIZES = [64, 80, 96, 112, 128, 256] | HEAD_SIZES = [64, 80, 96, 112, 128, 256] | ||||||
| BLOCK_SIZES = [8, 16, 32] | BLOCK_SIZES = [8, 16, 32] | ||||||
| NUM_BLOCKS = [1024, 36000]  # Arbitrary values for testing | NUM_BLOCKS = [1024, 3600]  # Arbitrary values for testing | ||||||
| NUM_MAPPINGS = [256]  # Arbitrary values for testing | NUM_MAPPINGS = [256]  # Arbitrary values for testing | ||||||
| SEEDS = [0] | SEEDS = [0] | ||||||
|  | DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] | ||||||
|  | KV_CACHE_DTYPE = ["auto", "fp8_e5m2"] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) | @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) | ||||||
| @ -24,6 +29,8 @@ SEEDS = [0] | |||||||
| @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) | @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
|  | @pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) | ||||||
| @torch.inference_mode() | @torch.inference_mode() | ||||||
| def test_copy_blocks( | def test_copy_blocks( | ||||||
|     kv_cache_factory, |     kv_cache_factory, | ||||||
| @ -35,11 +42,13 @@ def test_copy_blocks( | |||||||
|     num_blocks: int, |     num_blocks: int, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
|  |     kv_cache_dtype: str, | ||||||
| ) -> None: | ) -> None: | ||||||
|     random.seed(seed) |     random.seed(seed) | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|  |     gpu_id = f"cuda:{device}" | ||||||
|     # Generate random block mappings where each source block is mapped to two |     # Generate random block mappings where each source block is mapped to two | ||||||
|     # destination blocks. |     # destination blocks. | ||||||
|     assert 2 * num_mappings <= num_blocks |     assert 2 * num_mappings <= num_blocks | ||||||
| @ -56,7 +65,8 @@ def test_copy_blocks( | |||||||
|     # Create the KV caches. |     # Create the KV caches. | ||||||
|     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, |     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, | ||||||
|                                                 num_layers, num_heads, |                                                 num_layers, num_heads, | ||||||
|                                                 head_size, dtype, seed) |                                                 head_size, kv_cache_dtype, | ||||||
|  |                                                 dtype, seed, gpu_id) | ||||||
|  |  | ||||||
|     # Clone the KV caches. |     # Clone the KV caches. | ||||||
|     cloned_key_caches = [key_cache.clone() for key_cache in key_caches] |     cloned_key_caches = [key_cache.clone() for key_cache in key_caches] | ||||||
| @ -88,6 +98,7 @@ def test_copy_blocks( | |||||||
| @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) | @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| @torch.inference_mode() | @torch.inference_mode() | ||||||
| def test_reshape_and_cache( | def test_reshape_and_cache( | ||||||
|     kv_cache_factory, |     kv_cache_factory, | ||||||
| @ -98,28 +109,29 @@ def test_reshape_and_cache( | |||||||
|     num_blocks: int, |     num_blocks: int, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
| ) -> None: | ) -> None: | ||||||
|     random.seed(seed) |     random.seed(seed) | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|  |     gpu_id = f"cuda:{device}" | ||||||
|     # Create a random slot mapping. |     # Create a random slot mapping. | ||||||
|     num_slots = block_size * num_blocks |     num_slots = block_size * num_blocks | ||||||
|     slot_mapping = random.sample(range(num_slots), num_tokens) |     slot_mapping = random.sample(range(num_slots), num_tokens) | ||||||
|     slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device="cuda") |     slot_mapping = torch.tensor(slot_mapping, dtype=torch.long, device=gpu_id) | ||||||
|  |  | ||||||
|     qkv = torch.randn(num_tokens, |     qkv = torch.randn(num_tokens, | ||||||
|                       3, |                       3, | ||||||
|                       num_heads, |                       num_heads, | ||||||
|                       head_size, |                       head_size, | ||||||
|                       dtype=dtype, |                       dtype=dtype, | ||||||
|                       device="cuda") |                       device=gpu_id) | ||||||
|     _, key, value = qkv.unbind(dim=1) |     _, key, value = qkv.unbind(dim=1) | ||||||
|  |  | ||||||
|     # Create the KV caches. |     # Create the KV caches. | ||||||
|     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, |     key_caches, value_caches = kv_cache_factory(num_blocks, block_size, 1, | ||||||
|                                                 num_heads, head_size, dtype, |                                                 num_heads, head_size, dtype, | ||||||
|                                                 seed) |                                                 None, seed, gpu_id) | ||||||
|     key_cache, value_cache = key_caches[0], value_caches[0] |     key_cache, value_cache = key_caches[0], value_caches[0] | ||||||
|  |  | ||||||
|     # Clone the KV caches. |     # Clone the KV caches. | ||||||
| @ -128,7 +140,7 @@ def test_reshape_and_cache( | |||||||
|  |  | ||||||
|     # Call the reshape_and_cache kernel. |     # Call the reshape_and_cache kernel. | ||||||
|     cache_ops.reshape_and_cache(key, value, key_cache, value_cache, |     cache_ops.reshape_and_cache(key, value, key_cache, value_cache, | ||||||
|                                 slot_mapping) |                                 slot_mapping, "auto") | ||||||
|  |  | ||||||
|     # Run the reference implementation. |     # Run the reference implementation. | ||||||
|     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) |     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape) | ||||||
| @ -144,3 +156,68 @@ def test_reshape_and_cache( | |||||||
|  |  | ||||||
|     assert torch.allclose(key_cache, cloned_key_cache) |     assert torch.allclose(key_cache, cloned_key_cache) | ||||||
|     assert torch.allclose(value_cache, cloned_value_cache) |     assert torch.allclose(value_cache, cloned_value_cache) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("direction", COPYING_DIRECTION) | ||||||
|  | @pytest.mark.parametrize("num_mappings", NUM_MAPPINGS) | ||||||
|  | @pytest.mark.parametrize("num_heads", NUM_HEADS) | ||||||
|  | @pytest.mark.parametrize("head_size", HEAD_SIZES) | ||||||
|  | @pytest.mark.parametrize("block_size", BLOCK_SIZES) | ||||||
|  | @pytest.mark.parametrize("num_blocks", NUM_BLOCKS) | ||||||
|  | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
|  | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
|  | @torch.inference_mode() | ||||||
|  | def test_swap_blocks( | ||||||
|  |     kv_cache_factory, | ||||||
|  |     direction: Tuple[str, str], | ||||||
|  |     num_mappings: int, | ||||||
|  |     num_heads: int, | ||||||
|  |     head_size: int, | ||||||
|  |     block_size: int, | ||||||
|  |     num_blocks: int, | ||||||
|  |     dtype: torch.dtype, | ||||||
|  |     seed: int, | ||||||
|  |     device: int, | ||||||
|  | ) -> None: | ||||||
|  |     random.seed(seed) | ||||||
|  |     torch.random.manual_seed(seed) | ||||||
|  |     torch.cuda.manual_seed(seed) | ||||||
|  |     src_device = f"{direction[0]}:{device}" if direction[ | ||||||
|  |         0] == "cuda" else direction[0] | ||||||
|  |     dst_device = f"{direction[1]}:{device}" if direction[ | ||||||
|  |         1] == "cuda" else direction[1] | ||||||
|  |  | ||||||
|  |     src_blocks = random.sample(range(num_blocks), num_mappings) | ||||||
|  |     # For the same device, mapping must not overlap | ||||||
|  |     if src_device == dst_device: | ||||||
|  |         remaining_blocks = list(set(range(num_blocks)) - set(src_blocks)) | ||||||
|  |         dst_blocks = random.sample(remaining_blocks, num_mappings) | ||||||
|  |     else: | ||||||
|  |         dst_blocks = random.sample(range(num_blocks), num_mappings) | ||||||
|  |  | ||||||
|  |     block_mapping = dict(zip(src_blocks, dst_blocks)) | ||||||
|  |  | ||||||
|  |     # Create the KV caches on the first device. | ||||||
|  |     src_key_caches, src_value_caches = kv_cache_factory( | ||||||
|  |         num_blocks, block_size, 1, num_heads, head_size, dtype, seed, | ||||||
|  |         src_device) | ||||||
|  |  | ||||||
|  |     # Create the KV caches on the second device. | ||||||
|  |     dist_key_caches, dist_value_caches = kv_cache_factory( | ||||||
|  |         num_blocks, block_size, 1, num_heads, head_size, dtype, seed, | ||||||
|  |         dst_device) | ||||||
|  |  | ||||||
|  |     src_key_caches_clone = src_key_caches[0].clone() | ||||||
|  |     src_value_caches_clone = src_value_caches[0].clone() | ||||||
|  |  | ||||||
|  |     # Call the swap_blocks kernel. | ||||||
|  |     cache_ops.swap_blocks(src_key_caches[0], dist_key_caches[0], block_mapping) | ||||||
|  |     cache_ops.swap_blocks(src_value_caches[0], dist_value_caches[0], | ||||||
|  |                           block_mapping) | ||||||
|  |  | ||||||
|  |     for src, dst in block_mapping.items(): | ||||||
|  |         assert torch.allclose(src_key_caches_clone[src].cpu(), | ||||||
|  |                               dist_key_caches[0][dst].cpu()) | ||||||
|  |         assert torch.allclose(src_value_caches_clone[src].cpu(), | ||||||
|  |                               dist_value_caches[0][dst].cpu()) | ||||||
|  | |||||||
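The property test_swap_blocks asserts can be stated with plain tensor ops. A sketch of the expected semantics (swap_blocks_ref and the toy shapes are illustrative, not the cache_ops kernel): every mapped source block ends up, element for element, in its destination block, possibly on another device.

import torch

def swap_blocks_ref(src_cache, dst_cache, block_mapping):
    # Copy block `src` of src_cache into block `dst` of dst_cache.
    for src, dst in block_mapping.items():
        dst_cache[dst].copy_(src_cache[src].to(dst_cache.device))

src = torch.randn(8, 16, 4)   # 8 blocks of shape (16, 4)
dst = torch.zeros(8, 16, 4)   # destination cache (may live on another device)
swap_blocks_ref(src, dst, {0: 5, 2: 7})
assert torch.allclose(src[0], dst[5]) and torch.allclose(src[2], dst[7])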
							
								
								
									
50 tests/kernels/test_fused_moe.py Normal file
							| @ -0,0 +1,50 @@ | |||||||
|  | import pytest | ||||||
|  | import torch | ||||||
|  |  | ||||||
|  | from vllm.model_executor.layers.fused_moe import fused_moe | ||||||
|  | from vllm.model_executor.layers.activation import SiluAndMul | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def torch_moe(a, w1, w2, topk_weight, topk_ids): | ||||||
|  |     B, D = a.shape | ||||||
|  |     a = a.view(B, -1, D).repeat(1, topk_ids.shape[1], 1).reshape(-1, D) | ||||||
|  |     out = torch.zeros(B * topk_ids.shape[1], | ||||||
|  |                       w2.shape[1], | ||||||
|  |                       dtype=a.dtype, | ||||||
|  |                       device=a.device) | ||||||
|  |     topk_ids = topk_ids.view(-1) | ||||||
|  |     topk_weight = topk_weight.view(-1) | ||||||
|  |     for i in range(w1.shape[0]): | ||||||
|  |         mask = topk_ids == i | ||||||
|  |         if mask.sum(): | ||||||
|  |             out[mask] = SiluAndMul()( | ||||||
|  |                 a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose(0, 1) | ||||||
|  |     return (out.view(B, -1, w2.shape[1]) * | ||||||
|  |             topk_weight.view(B, -1, 1)).sum(dim=1) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("m", [512, 222, 33, 1]) | ||||||
|  | @pytest.mark.parametrize("n", [2048, 256, 1024]) | ||||||
|  | @pytest.mark.parametrize("k", [128, 511, 1024]) | ||||||
|  | @pytest.mark.parametrize("e", [8, 64]) | ||||||
|  | @pytest.mark.parametrize("topk", [2, 6]) | ||||||
|  | @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) | ||||||
|  | def test_fused_moe( | ||||||
|  |     m: int, | ||||||
|  |     n: int, | ||||||
|  |     k: int, | ||||||
|  |     e: int, | ||||||
|  |     topk: int, | ||||||
|  |     dtype: torch.dtype, | ||||||
|  | ): | ||||||
|  |     a = torch.randn((m, k), device='cuda', dtype=dtype) / 10 | ||||||
|  |     w1 = torch.randn((e, 2 * n, k), device='cuda', dtype=dtype) / 10 | ||||||
|  |     w2 = torch.randn((e, k, n), device='cuda', dtype=dtype) / 10 | ||||||
|  |  | ||||||
|  |     score = torch.randn((m, e), device='cuda', dtype=dtype) | ||||||
|  |     score = torch.softmax(score, dim=-1) | ||||||
|  |     topk_weight, topk_ids = torch.topk(score, topk) | ||||||
|  |  | ||||||
|  |     triton_output = fused_moe(a, w1, w2, topk_weight, topk_ids, False) | ||||||
|  |     torch_output = torch_moe(a, w1, w2, topk_weight, topk_ids) | ||||||
|  |     assert torch.allclose(triton_output, torch_output, atol=1e-2, rtol=0) | ||||||
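A small worked example of the routing that torch_moe reproduces (sketch only, names follow the test above): the expert scores are softmaxed, the top-k experts per token are selected, and each token's output is the weighted sum of those experts' MLP outputs.

import torch

torch.manual_seed(0)
m, e, topk = 3, 4, 2                     # tokens, experts, experts kept per token
score = torch.softmax(torch.randn(m, e), dim=-1)
topk_weight, topk_ids = torch.topk(score, topk)
print(topk_ids)     # which 2 of the 4 experts each token is routed to
print(topk_weight)  # the softmax mass kept for those experts (not renormalized here)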
| @ -8,6 +8,7 @@ NUM_TOKENS = [7, 83, 4096]  # Arbitrary values for testing | |||||||
| HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing | HIDDEN_SIZES = [768, 5120, 8192]  # Arbitrary values for testing | ||||||
| ADD_RESIDUAL = [False, True] | ADD_RESIDUAL = [False, True] | ||||||
| SEEDS = [0] | SEEDS = [0] | ||||||
|  | DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("num_tokens", NUM_TOKENS) | @pytest.mark.parametrize("num_tokens", NUM_TOKENS) | ||||||
| @ -15,6 +16,7 @@ SEEDS = [0] | |||||||
| @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) | @pytest.mark.parametrize("add_residual", ADD_RESIDUAL) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| @torch.inference_mode() | @torch.inference_mode() | ||||||
| def test_rms_norm( | def test_rms_norm( | ||||||
|     num_tokens: int, |     num_tokens: int, | ||||||
| @ -22,14 +24,15 @@ def test_rms_norm( | |||||||
|     add_residual: bool, |     add_residual: bool, | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
| ) -> None: | ) -> None: | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|  |     gpu_id = f"cuda:{device}" | ||||||
|     layer = RMSNorm(hidden_size).to(dtype).cuda() |     layer = RMSNorm(hidden_size).to(dtype=dtype, device=gpu_id) | ||||||
|     layer.weight.data.normal_(mean=1.0, std=0.1) |     layer.weight.data.normal_(mean=1.0, std=0.1) | ||||||
|     scale = 1 / (2 * hidden_size) |     scale = 1 / (2 * hidden_size) | ||||||
|     x = torch.randn(num_tokens, hidden_size, dtype=dtype, device="cuda") |     x = torch.randn(num_tokens, hidden_size, dtype=dtype, device=gpu_id) | ||||||
|     x *= scale |     x *= scale | ||||||
|     residual = torch.randn_like(x) * scale if add_residual else None |     residual = torch.randn_like(x) * scale if add_residual else None | ||||||
|  |  | ||||||
|  | |||||||
| @ -13,6 +13,7 @@ NUM_HEADS = [7, 17]  # Arbitrary values for testing | |||||||
| BATCH_SIZES = [1, 5]  # Arbitrary values for testing | BATCH_SIZES = [1, 5]  # Arbitrary values for testing | ||||||
| SEQ_LENS = [11, 8192]  # Arbitrary values for testing | SEQ_LENS = [11, 8192]  # Arbitrary values for testing | ||||||
| SEEDS = [0] | SEEDS = [0] | ||||||
|  | DEVICES = [i for i in range(1 if torch.cuda.device_count() == 1 else 2)] | ||||||
|  |  | ||||||
|  |  | ||||||
| @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) | @pytest.mark.parametrize("is_neox_style", IS_NEOX_STYLE) | ||||||
| @ -23,6 +24,7 @@ SEEDS = [0] | |||||||
| @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) | @pytest.mark.parametrize("rotary_dim", ROTARY_DIMS) | ||||||
| @pytest.mark.parametrize("dtype", DTYPES) | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
| @pytest.mark.parametrize("seed", SEEDS) | @pytest.mark.parametrize("seed", SEEDS) | ||||||
|  | @pytest.mark.parametrize("device", DEVICES) | ||||||
| @torch.inference_mode() | @torch.inference_mode() | ||||||
| def test_rotary_embedding( | def test_rotary_embedding( | ||||||
|     is_neox_style: bool, |     is_neox_style: bool, | ||||||
| @ -33,6 +35,7 @@ def test_rotary_embedding( | |||||||
|     rotary_dim: Optional[int], |     rotary_dim: Optional[int], | ||||||
|     dtype: torch.dtype, |     dtype: torch.dtype, | ||||||
|     seed: int, |     seed: int, | ||||||
|  |     device: int, | ||||||
|     max_position: int = 8192, |     max_position: int = 8192, | ||||||
|     base: int = 10000, |     base: int = 10000, | ||||||
| ) -> None: | ) -> None: | ||||||
| @ -40,20 +43,20 @@ def test_rotary_embedding( | |||||||
|         rotary_dim = head_size |         rotary_dim = head_size | ||||||
|     torch.random.manual_seed(seed) |     torch.random.manual_seed(seed) | ||||||
|     torch.cuda.manual_seed(seed) |     torch.cuda.manual_seed(seed) | ||||||
|  |     gpu_id = f"cuda:{device}" | ||||||
|     if rotary_dim is None: |     if rotary_dim is None: | ||||||
|         rotary_dim = head_size |         rotary_dim = head_size | ||||||
|     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) |     rope = get_rope(head_size, rotary_dim, max_position, base, is_neox_style) | ||||||
|     rope = rope.to(dtype).cuda() |     rope = rope.to(dtype=dtype, device=gpu_id) | ||||||
|  |  | ||||||
|     positions = torch.randint(0, |     positions = torch.randint(0, | ||||||
|                               max_position, (batch_size, seq_len), |                               max_position, (batch_size, seq_len), | ||||||
|                               device="cuda") |                               device=gpu_id) | ||||||
|     query = torch.randn(batch_size, |     query = torch.randn(batch_size, | ||||||
|                         seq_len, |                         seq_len, | ||||||
|                         num_heads * head_size, |                         num_heads * head_size, | ||||||
|                         dtype=dtype, |                         dtype=dtype, | ||||||
|                         device="cuda") |                         device=gpu_id) | ||||||
|     key = torch.randn_like(query) |     key = torch.randn_like(query) | ||||||
|  |  | ||||||
|     # NOTE(woosuk): The reference implementation should be executed first |     # NOTE(woosuk): The reference implementation should be executed first | ||||||
|  | |||||||
							
								
								
									
169 tests/kernels/test_prefix_prefill.py Normal file
							| @ -0,0 +1,169 @@ | |||||||
|  | import random | ||||||
|  | import pytest | ||||||
|  | import time | ||||||
|  |  | ||||||
|  | import torch | ||||||
|  | from vllm.model_executor.layers.triton_kernel.prefix_prefill import ( | ||||||
|  |     context_attention_fwd) | ||||||
|  | from xformers import ops as xops | ||||||
|  | from xformers.ops.fmha.attn_bias import BlockDiagonalCausalFromBottomRightMask | ||||||
|  |  | ||||||
|  | NUM_HEADS = [12] | ||||||
|  | HEAD_SIZES = [128] | ||||||
|  | DTYPES = [torch.float16] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | @pytest.mark.parametrize("num_heads", NUM_HEADS) | ||||||
|  | @pytest.mark.parametrize("head_size", HEAD_SIZES) | ||||||
|  | @pytest.mark.parametrize("dtype", DTYPES) | ||||||
|  | @torch.inference_mode() | ||||||
|  | def test_contexted_kv_attention( | ||||||
|  |     num_heads: int, | ||||||
|  |     head_size: int, | ||||||
|  |     dtype: torch.dtype, | ||||||
|  | ) -> None: | ||||||
|  |     random.seed(0) | ||||||
|  |     torch.manual_seed(0) | ||||||
|  |     MAX_SEQ_LEN = 1024 | ||||||
|  |     MAX_CTX_LEN = 1024 | ||||||
|  |     BS = 10 | ||||||
|  |     cache_size = 640 | ||||||
|  |     block_size = 32 | ||||||
|  |     max_block_per_request = 64 | ||||||
|  |     subquery_lens = [random.randint(16, MAX_SEQ_LEN) for _ in range(BS)] | ||||||
|  |     ctx_lens = [random.randint(16, MAX_CTX_LEN) for _ in range(BS)] | ||||||
|  |     seq_lens = [a + b for a, b in zip(subquery_lens, ctx_lens)] | ||||||
|  |  | ||||||
|  |     num_tokens = sum(subquery_lens) | ||||||
|  |     query = torch.empty(num_tokens, | ||||||
|  |                         num_heads, | ||||||
|  |                         head_size, | ||||||
|  |                         dtype=dtype, | ||||||
|  |                         device='cuda') | ||||||
|  |     query.uniform_(-1e-3, 1e-3) | ||||||
|  |     output = torch.empty(num_tokens, | ||||||
|  |                          num_heads, | ||||||
|  |                          head_size, | ||||||
|  |                          dtype=dtype, | ||||||
|  |                          device='cuda') | ||||||
|  |  | ||||||
|  |     kv = torch.empty(sum(seq_lens), | ||||||
|  |                      2, | ||||||
|  |                      num_heads, | ||||||
|  |                      head_size, | ||||||
|  |                      dtype=dtype, | ||||||
|  |                      device='cuda') | ||||||
|  |     kv.uniform_(-1e-3, 1e-3) | ||||||
|  |     key, value = kv.unbind(dim=1) | ||||||
|  |  | ||||||
|  |     k_cache = torch.zeros(cache_size, | ||||||
|  |                           block_size, | ||||||
|  |                           num_heads, | ||||||
|  |                           head_size, | ||||||
|  |                           dtype=dtype, | ||||||
|  |                           device='cuda') | ||||||
|  |     v_cache = torch.zeros(cache_size, | ||||||
|  |                           block_size, | ||||||
|  |                           num_heads, | ||||||
|  |                           head_size, | ||||||
|  |                           dtype=dtype, | ||||||
|  |                           device='cuda') | ||||||
|  |     k = torch.zeros(sum(subquery_lens), | ||||||
|  |                     num_heads, | ||||||
|  |                     head_size, | ||||||
|  |                     dtype=dtype, | ||||||
|  |                     device='cuda') | ||||||
|  |     v = torch.zeros(sum(subquery_lens), | ||||||
|  |                     num_heads, | ||||||
|  |                     head_size, | ||||||
|  |                     dtype=dtype, | ||||||
|  |                     device='cuda') | ||||||
|  |     values = torch.arange(0, cache_size, dtype=torch.long, device='cuda') | ||||||
|  |     values = values[torch.randperm(cache_size)] | ||||||
|  |     block_table = values[:BS * max_block_per_request].view( | ||||||
|  |         BS, max_block_per_request) | ||||||
|  |     b_seq_len = torch.tensor(seq_lens, dtype=torch.long, device='cuda') | ||||||
|  |     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long, device='cuda') | ||||||
|  |     b_start_loc = torch.cumsum(torch.tensor([0] + subquery_lens[:-1], | ||||||
|  |                                             dtype=torch.long, | ||||||
|  |                                             device='cuda'), | ||||||
|  |                                dim=0) | ||||||
|  |     max_input_len = MAX_SEQ_LEN | ||||||
|  |     # copy kv to cache | ||||||
|  |     b_seq_start_loc = torch.cumsum(torch.tensor([0] + seq_lens[:-1], | ||||||
|  |                                                 dtype=torch.long, | ||||||
|  |                                                 device='cuda'), | ||||||
|  |                                    dim=0) | ||||||
|  |     for i in range(BS): | ||||||
|  |         for j in range(subquery_lens[i]): | ||||||
|  |             k[b_start_loc[i] + j].copy_(key[b_seq_start_loc[i] + b_ctx_len[i] + | ||||||
|  |                                             j]) | ||||||
|  |             v[b_start_loc[i] + j].copy_(value[b_seq_start_loc[i] + | ||||||
|  |                                               b_ctx_len[i] + j]) | ||||||
|  |         cur_ctx = 0 | ||||||
|  |         block_id = 0 | ||||||
|  |         while cur_ctx < b_ctx_len[i]: | ||||||
|  |             start_loc = b_seq_start_loc[i] + cur_ctx | ||||||
|  |             if cur_ctx + block_size > b_ctx_len[i]: | ||||||
|  |                 end_loc = b_seq_start_loc[i] + b_ctx_len[i] | ||||||
|  |             else: | ||||||
|  |                 end_loc = start_loc + block_size | ||||||
|  |             start_slot = block_table[i, block_id] * block_size | ||||||
|  |             end_slot = start_slot + end_loc - start_loc | ||||||
|  |             k_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( | ||||||
|  |                 key[start_loc:end_loc]) | ||||||
|  |             v_cache.view(-1, num_heads, head_size)[start_slot:end_slot].copy_( | ||||||
|  |                 value[start_loc:end_loc]) | ||||||
|  |             cur_ctx += block_size | ||||||
|  |             block_id += 1 | ||||||
|  |     # transpose K_cache[num_blocks, block_size, num_kv_heads, head_size] | ||||||
|  |     # to K_cache[num_blocks, num_kv_heads, head_size/8, block_size, 8] | ||||||
|  |     k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8, | ||||||
|  |                            8).permute(0, 2, 3, 1, 4).contiguous() | ||||||
|  |     # transpose V_cache[num_blocks, block_size, num_kv_heads, head_size] | ||||||
|  |     # to V_cache[num_blocks, num_kv_heads, head_size, block_size] | ||||||
|  |     v_cache = v_cache.view(-1, block_size, num_heads, | ||||||
|  |                            head_size).permute(0, 2, 3, 1).contiguous() | ||||||
|  |  | ||||||
|  |     # Warm up the Triton kernel by calling it once before actually measuring generation time | ||||||
|  |     context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, | ||||||
|  |                           b_start_loc, b_seq_len, b_ctx_len, max_input_len) | ||||||
|  |     torch.cuda.synchronize() | ||||||
|  |     start_time = time.time() | ||||||
|  |     context_attention_fwd(query, k, v, output, k_cache, v_cache, block_table, | ||||||
|  |                           b_start_loc, b_seq_len, b_ctx_len, max_input_len) | ||||||
|  |     torch.cuda.synchronize() | ||||||
|  |     end_time = time.time() | ||||||
|  |     print(f"triton Time: {(end_time - start_time)*1000:.2f} ms") | ||||||
|  |  | ||||||
|  |     scale = float(1.0 / (head_size**0.5)) | ||||||
|  |  | ||||||
|  |     attn_op = xops.fmha.cutlass.FwOp() | ||||||
|  |  | ||||||
|  |     attn_bias = BlockDiagonalCausalFromBottomRightMask.from_seqlens( | ||||||
|  |         subquery_lens, seq_lens) | ||||||
|  |     output_ref = xops.memory_efficient_attention_forward( | ||||||
|  |         query.unsqueeze(0), | ||||||
|  |         key.unsqueeze(0), | ||||||
|  |         value.unsqueeze(0), | ||||||
|  |         attn_bias=attn_bias, | ||||||
|  |         p=0.0, | ||||||
|  |         scale=scale, | ||||||
|  |         op=attn_op, | ||||||
|  |     ) | ||||||
|  |     torch.cuda.synchronize() | ||||||
|  |     start_time = time.time() | ||||||
|  |     output_ref = xops.memory_efficient_attention_forward( | ||||||
|  |         query.unsqueeze(0), | ||||||
|  |         key.unsqueeze(0), | ||||||
|  |         value.unsqueeze(0), | ||||||
|  |         attn_bias=attn_bias, | ||||||
|  |         p=0.0, | ||||||
|  |         scale=scale, | ||||||
|  |         op=attn_op, | ||||||
|  |     ) | ||||||
|  |     torch.cuda.synchronize() | ||||||
|  |     end_time = time.time() | ||||||
|  |     print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") | ||||||
|  |     output_ref = output_ref.squeeze(0) | ||||||
|  |     assert torch.allclose(output_ref, output, atol=1e-6, rtol=0) | ||||||
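The two permutes near the end of the setup pack the caches into the layout the Triton kernel reads: keys as [num_blocks, num_kv_heads, head_size/8, block_size, 8] and values as [num_blocks, num_kv_heads, head_size, block_size]. A standalone shape check using the test's sizes (sketch only):

import torch

block_size, num_heads, head_size, cache_size = 32, 12, 128, 640
k_cache = torch.zeros(cache_size, block_size, num_heads, head_size)
v_cache = torch.zeros(cache_size, block_size, num_heads, head_size)
k_cache = k_cache.view(-1, block_size, num_heads, head_size // 8,
                       8).permute(0, 2, 3, 1, 4).contiguous()
v_cache = v_cache.view(-1, block_size, num_heads,
                       head_size).permute(0, 2, 3, 1).contiguous()
print(k_cache.shape)  # torch.Size([640, 12, 16, 32, 8])
print(v_cache.shape)  # torch.Size([640, 12, 128, 32])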
Some files were not shown because too many files have changed in this diff.