Mirror of https://github.com/vllm-project/vllm.git (synced 2025-11-04 17:34:34 +08:00)

Compare commits

243 Commits
| SHA1 | Author | Date | |
|---|---|---|---|
| 8fbd84bf78 | |||
| 7d2dcce175 | |||
| dc903e70ac | |||
| a9c8212895 | |||
| c20ecb6a51 | |||
| 5253edaacb | |||
| 017d9f1515 | |||
| 181b27d881 | |||
| 63e2a6419d | |||
| 264017a2bf | |||
| e433c115bc | |||
| 86fd8bb0ac | |||
| ab3a5a8259 | |||
| a61f0521b8 | |||
| 537c9755a7 | |||
| 786b7f18a5 | |||
| 8f36444c4f | |||
| 185b2c29e2 | |||
| 5f08050d8d | |||
| 64da65b322 | |||
| 5255d99dc5 | |||
| 4f2ad11135 | |||
| d7afab6d3a | |||
| 31348dff03 | |||
| 25e86b6a61 | |||
| 4efbac6d35 | |||
| 87069ccf68 | |||
| 7e45107f51 | |||
| 0c48b37c31 | |||
| 7eacffd951 | |||
| 2a543d6efe | |||
| 317b29de0f | |||
| a463c333dd | |||
| ea356004d4 | |||
| 5c976a7e1a | |||
| f964493274 | |||
| a4211a4dc3 | |||
| 563836496a | |||
| 4ca2c358b1 | |||
| 0580aab02f | |||
| 3711811b1d | |||
| 65b89d16ee | |||
| 931746bc6d | |||
| c81dddb45c | |||
| fe6d09ae61 | |||
| ed70c70ea3 | |||
| f0d4e14557 | |||
| 2ccee3def6 | |||
| b92adec8e8 | |||
| 56f738ae9b | |||
| 72d3a30c63 | |||
| c9b45adeeb | |||
| 5a6c81b051 | |||
| 51cd22ce56 | |||
| 5ed704ec8c | |||
| 4abf6336ec | |||
| 0e163fce18 | |||
| 96b6f475dd | |||
| c410f5d020 | |||
| bb8c697ee0 | |||
| b9e96b17de | |||
| 923797fea4 | |||
| cd9e60c76c | |||
| 93b38bea5d | |||
| d0d93b92b1 | |||
| 89efcf1ce5 | |||
| c664b0e683 | |||
| d69ff0cbbb | |||
| 1af090b57d | |||
| 3dad944485 | |||
| 105a40f53a | |||
| bbe9bd9684 | |||
| 4f65af0e25 | |||
| d79ced3292 | |||
| ab40644669 | |||
| 5d60def02c | |||
| ea8489fce2 | |||
| 1b20639a43 | |||
| b72af8f1ed | |||
| 9090bf02e7 | |||
| 7d648418b8 | |||
| 89be30fa7d | |||
| f8ecb84c02 | |||
| 5f036d2bcc | |||
| 380170038e | |||
| 220a47627b | |||
| beb89f68b4 | |||
| 390b495ff3 | |||
| 3a0e1fc070 | |||
| 6b7de1a030 | |||
| 5265631d15 | |||
| 2832e7b9f9 | |||
| 3a7dd7e367 | |||
| 223c19224b | |||
| f1f6cc10c7 | |||
| 3209b49033 | |||
| 1e4277d2d1 | |||
| 9b945daaf1 | |||
| 9c1352eb57 | |||
| 7a0b011dd5 | |||
| 63e835cbcc | |||
| 94b5edeb53 | |||
| ab7e6006d6 | |||
| 18bfcdd05c | |||
| 71d63ed72e | |||
| d75c40734a | |||
| 5b23c3f26f | |||
| 00efdc84ba | |||
| 91a61da9b1 | |||
| ef9b636e2d | |||
| 2709c0009a | |||
| dd7e8f5f64 | |||
| d2a68364c4 | |||
| 7e1081139d | |||
| 18473cf498 | |||
| 4df417d059 | |||
| 5d80a9178b | |||
| 8a25d3a71a | |||
| d10f8e1d43 | |||
| 14cc317ba4 | |||
| e1957c6ebd | |||
| 8cd5a992bf | |||
| 947f0b23cc | |||
| f780504d12 | |||
| bfc072addf | |||
| 2a18da257c | |||
| 6e01e8c1c8 | |||
| 9f659bf07f | |||
| 35c4bc20d9 | |||
| 218dc2ccda | |||
| 827cbcd37c | |||
| cb7a1c1cbf | |||
| 7878958c0d | |||
| ce036244c9 | |||
| 48cf1e413c | |||
| 97460585d9 | |||
| f745847ef7 | |||
| 6549aef245 | |||
| 50376faa7b | |||
| 4b61c6b669 | |||
| 79d64c4954 | |||
| 74cd5abdd1 | |||
| 28c3f12104 | |||
| c884819135 | |||
| 05921a9a7a | |||
| d0215a58e7 | |||
| 937e7b7d7c | |||
| aee8ef661a | |||
| 2e0b6e7757 | |||
| 941767127c | |||
| 74d8d77626 | |||
| fd4ea8ef5c | |||
| 1066cbd152 | |||
| 6ef00b03a2 | |||
| 9140561059 | |||
| 77af974b40 | |||
| 4934d49274 | |||
| 358c328d69 | |||
| 4aaafdd289 | |||
| 66b108d142 | |||
| e0ff920001 | |||
| face83c7ec | |||
| 1db83e31a2 | |||
| a1b9cb2a34 | |||
| 3a4fd5ca59 | |||
| c17daa9f89 | |||
| bd29cf3d3a | |||
| 31bff69151 | |||
| ba4f826738 | |||
| de60a3fb93 | |||
| 21d5daa4ac | |||
| 290e015c6c | |||
| 1b7c791d60 | |||
| bbe4466fd9 | |||
| 08133c4d1a | |||
| 76a7983b23 | |||
| 8041b7305e | |||
| 3ec8c25cd0 | |||
| 671af2b1c0 | |||
| 6f41f0e377 | |||
| 2c9b638065 | |||
| a7347d9a6d | |||
| f8c688d746 | |||
| c9fadda543 | |||
| 30fb0956df | |||
| 3a765bd5e1 | |||
| 26c52a5ea6 | |||
| c3372e87be | |||
| b0a1d667b0 | |||
| e1d5402238 | |||
| 3d1cfbfc74 | |||
| 37ca558103 | |||
| eed74a558f | |||
| 2acd76f346 | |||
| b81a6a6bb3 | |||
| 0fbfc4b81b | |||
| c06170cc8e | |||
| 614856da25 | |||
| 05bdf4eaf3 | |||
| 6774bd50b0 | |||
| 31c1f3255e | |||
| 21d93c140d | |||
| f1c8520146 | |||
| 096827c284 | |||
| 6565d9e33e | |||
| f375ec8440 | |||
| 518369d78c | |||
| 30bad5c492 | |||
| 3fefe271ec | |||
| 6428f1d051 | |||
| 7e1b21daac | |||
| cb3f30c600 | |||
| f3e024bece | |||
| 31d2ab4aff | |||
| eb17212858 | |||
| 4dd4b5c538 | |||
| 6120e5aaea | |||
| 2eaa81b236 | |||
| 81ce2a4b26 | |||
| 5dd80d3777 | |||
| beeee69bc9 | |||
| 9bf28d0b69 | |||
| c0ce15dfb2 | |||
| b9bcdc7158 | |||
| 4ff0203987 | |||
| b5f882cc98 | |||
| 2e8fc0d4c3 | |||
| dacaf5a400 | |||
| 24cde76a15 | |||
| 1aa1361510 | |||
| fe470ae5ad | |||
| 3a8c2381f7 | |||
| c85b80c2b6 | |||
| 2b981012a6 | |||
| 6ccc0bfffb | |||
| c8e7eb1eb3 | |||
| 24f60a54f4 | |||
| 42c02f5892 | |||
| ebede26ebf | |||
| d940ce497e | |||
| 05ff90b692 | |||
| 1d9b737e05 | |||
| 60dc62dc9e | 
							
								
								
									
.buildkite/run-benchmarks.sh (new file, 69 lines)

@@ -0,0 +1,69 @@
# This script is run by buildkite to run the benchmarks and upload the results to buildkite

set -ex
set -o pipefail

# cd into parent directory of this file
cd "$(dirname "${BASH_SOURCE[0]}")/.."

(which wget && which curl) || (apt-get update && apt-get install -y wget curl)

# run python-based benchmarks and upload the result to buildkite
python3 benchmarks/benchmark_latency.py 2>&1 | tee benchmark_latency.txt
bench_latency_exit_code=$?

python3 benchmarks/benchmark_throughput.py --input-len 256 --output-len 256 2>&1 | tee benchmark_throughput.txt
bench_throughput_exit_code=$?

# run server-based benchmarks and upload the result to buildkite
python3 -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf &
server_pid=$!
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

# wait for server to start, timeout after 600 seconds
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
python3 benchmarks/benchmark_serving.py \
    --backend openai \
    --dataset ./ShareGPT_V3_unfiltered_cleaned_split.json \
    --model meta-llama/Llama-2-7b-chat-hf \
    --num-prompts 20 \
    --endpoint /v1/completions \
    --tokenizer meta-llama/Llama-2-7b-chat-hf \
    --save-result \
    2>&1 | tee benchmark_serving.txt
bench_serving_exit_code=$?
kill $server_pid

# write the results into a markdown file
echo "### Latency Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_latency.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_latency.txt >> benchmark_results.md # last line

echo "### Throughput Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_throughput.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
sed -n '$p' benchmark_throughput.txt >> benchmark_results.md # last line

echo "### Serving Benchmarks" >> benchmark_results.md
sed -n '1p' benchmark_serving.txt >> benchmark_results.md # first line
echo "" >> benchmark_results.md
tail -n 13 benchmark_serving.txt >> benchmark_results.md # last 13 lines

# upload the results to buildkite
/workspace/buildkite-agent annotate --style "info" --context "benchmark-results" < benchmark_results.md

# exit with the exit code of the benchmarks
if [ $bench_latency_exit_code -ne 0 ]; then
    exit $bench_latency_exit_code
fi

if [ $bench_throughput_exit_code -ne 0 ]; then
    exit $bench_throughput_exit_code
fi

if [ $bench_serving_exit_code -ne 0 ]; then
    exit $bench_serving_exit_code
fi

/workspace/buildkite-agent artifact upload openai-*.json
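The script assumes the Buildkite agent binary at /workspace/buildkite-agent, which only exists inside the CI containers. A rough local dry run is sketched below; the stub agent is an assumption for illustration and is not part of the repository (a GPU machine, an installed vLLM, and access to the Llama-2 chat weights are also assumed).

```bash
# Sketch: dry-running .buildkite/run-benchmarks.sh outside Buildkite by stubbing
# the agent binary that the script calls for annotations and artifact uploads.
mkdir -p /workspace
cat > /workspace/buildkite-agent <<'EOF'
#!/bin/bash
echo "stub buildkite-agent: $*"
EOF
chmod +x /workspace/buildkite-agent

bash .buildkite/run-benchmarks.sh
```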
							
								
								
									
.buildkite/test-pipeline.yaml (new file, 66 lines)

@@ -0,0 +1,66 @@
# In this file, you can add more tests to run either by adding a new step or
# adding a new command to an existing step. See different options here for examples.
# This script will be feed into Jinja template in `test-template.j2` to generate
# the final pipeline yaml file.

steps:
- label: Regression Test
  command: pytest -v -s test_regression.py
  working_dir: "/vllm-workspace/tests" # optional

- label: AsyncEngine Test
  command: pytest -v -s async_engine

- label: Basic Correctness Test
  command: pytest -v -s --forked basic_correctness

- label: Distributed Comm Ops Test
  command: pytest -v -s --forked test_comm_ops.py
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.

- label: Distributed Correctness Test
  command: pytest -v -s --forked test_basic_distributed_correctness.py
  working_dir: "/vllm-workspace/tests/distributed"
  num_gpus: 2 # only support 1 or 2 for now.

- label: Engine Test
  command: pytest -v -s engine

- label: Entrypoints Test
  command: pytest -v -s entrypoints

- label: Kernels Test
  command: pytest -v -s kernels
  soft_fail: true

- label: Models Test
  commands:
    - pytest -v -s models --forked
  soft_fail: true

- label: Prefix Caching Test
  commands:
    - pytest -v -s prefix_caching

- label: Samplers Test
  command: pytest -v -s samplers --forked

- label: Worker Test
  command: pytest -v -s worker

- label: LoRA Test
  command: pytest -v -s lora

- label: Benchmarks
  working_dir: "/vllm-workspace/.buildkite"
  commands:
  - pip install aiohttp
  - bash run-benchmarks.sh

- label: Documentation Build
  working_dir: "/vllm-workspace/docs"
  no_gpu: True
  commands:
  - pip install -r requirements-docs.txt
  - SPHINXOPTS=\"-W\" make html
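These entries only declare a label, a command, and optional resources; the execution environment (Docker image, GPUs, working directory) is supplied by the Jinja template in the next file. As a rough local approximation of a single step, the command can be run by hand from the matching working_dir, for example:

```bash
# Sketch: reproducing the "Distributed Comm Ops Test" step manually, assuming the
# repository is checked out at /vllm-workspace and 2 GPUs are visible, mirroring
# the step's working_dir and num_gpus settings above.
cd /vllm-workspace/tests/distributed
pytest -v -s --forked test_comm_ops.py
```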
							
								
								
									
.buildkite/test-template.j2 (new file, 56 lines)

@@ -0,0 +1,56 @@
{% set docker_image = "us-central1-docker.pkg.dev/vllm-405802/vllm-ci-test-repo/vllm-test:$BUILDKITE_COMMIT" %}
{% set default_num_gpu = 1 %}
{% set default_working_dir = "/vllm-workspace/tests" %}

steps:
  - label: ":docker: build image"
    commands:
      - "docker build --build-arg max_jobs=16 --tag {{ docker_image }} --target test --progress plain ."
      - "docker push {{ docker_image }}"
    env:
      DOCKER_BUILDKIT: "1"
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
  - wait

  {% for step in steps %}
  - label: "{{ step.label }}"
    agents:
      queue: kubernetes
    soft_fail: {{ step.soft_fail or false }}
    retry:
      automatic:
        - exit_status: -1  # Agent was lost
          limit: 5
    plugins:
      - kubernetes:
          podSpec:
            volumes:
              - name: dshm
                emptyDir:
                  medium: Memory
            containers:
              - image: "{{ docker_image }}"
                command: ["bash"]
                args:
                - '-c'
                - "'cd {{ (step.working_dir or default_working_dir) | safe  }} && {{ step.command  or (step.commands | join(' && ')) | safe }}'"
                {% if not step.no_gpu %}
                resources:
                  requests:
                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
                  limits:
                    nvidia.com/gpu: "{{ step.num_gpus or default_num_gpu }}"
                {% endif %}
                env:
                  - name: HF_TOKEN
                    valueFrom:
                      secretKeyRef:
                        name: hf-token-secret
                        key: token
                volumeMounts:
                  - mountPath: /dev/shm
                    name: dshm
  {% endfor %}
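The pipeline file above is rendered through this template to produce the final Buildkite YAML (each step becomes a Kubernetes pod running the freshly built test image). One way to preview that rendering locally is sketched below; j2cli is an assumption used only for illustration, not something the CI itself installs.

```bash
# Sketch: previewing the generated pipeline with j2cli (assumed tool).
pip install "j2cli[yaml]"
cd .buildkite
# test-pipeline.yaml supplies the `steps` variable the template iterates over;
# $BUILDKITE_COMMIT is only substituted later, on the Buildkite agent.
j2 test-template.j2 test-pipeline.yaml > pipeline.generated.yaml
```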
							
								
								
									
.dockerignore (new file, 1 line)

@@ -0,0 +1 @@
vllm/*.so
							
								
								
									
.github/workflows/publish.yml (vendored, 2 lines changed)

@@ -49,7 +49,7 @@ jobs:
      matrix:
          os: ['ubuntu-20.04']
          python-version: ['3.8', '3.9', '3.10', '3.11']
          pytorch-version: ['2.1.0']
          pytorch-version: ['2.1.2']  # Must be the most recent version that meets requirements.txt.
          cuda-version: ['11.8', '12.1']

    steps:
							
								
								
									
.github/workflows/scripts/build.sh (vendored, 2 lines changed)

@@ -13,6 +13,8 @@ $python_executable -m pip install -r requirements.txt

# Limit the number of parallel jobs to avoid OOM
export MAX_JOBS=1
# Make sure punica is built for the release (for LoRA)
export VLLM_INSTALL_PUNICA_KERNELS=1

# Build
$python_executable setup.py bdist_wheel --dist-dir=dist
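Outside the workflow, where $python_executable is set by the surrounding release job, roughly the same wheel build can be reproduced as follows (a sketch assuming a CUDA toolchain and the requirements.txt dependencies are already installed):

```bash
# Sketch: building a wheel with the same knobs the release script exports.
export MAX_JOBS=1                     # cap parallel compile jobs to avoid OOM
export VLLM_INSTALL_PUNICA_KERNELS=1  # include the punica (LoRA) kernels
python3 setup.py bdist_wheel --dist-dir=dist
ls dist/*.whl
```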
							
								
								
									
.github/workflows/yapf.yml (vendored, 2 lines changed)

@@ -28,4 +28,4 @@ jobs:
        pip install toml==0.10.2
    - name: Running yapf
      run: |
        yapf --diff --recursive vllm tests
        yapf --diff --recursive .
							
								
								
									
.gitignore (vendored, 7 lines changed)

@@ -177,3 +177,10 @@ _build/
# vim swap files
*.swo
*.swp

# hip files generated by PyTorch
*.hip
*_hip*

# Benchmark dataset
*.json
							
								
								
									
Dockerfile (64 lines changed)

@@ -1,7 +1,17 @@
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
# to run the OpenAI compatible server.

#################### BASE BUILD IMAGE ####################
FROM nvidia/cuda:12.1.0-devel-ubuntu22.04 AS dev

RUN apt-get update -y \
    && apt-get install -y python3-pip
    && apt-get install -y python3-pip git

# Workaround for https://github.com/openai/triton/issues/2507 and
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
# this won't be needed for future versions of this docker image
# or future versions of triton.
RUN ldconfig /usr/local/cuda-12.1/compat/

WORKDIR /workspace

@@ -14,8 +24,10 @@ RUN --mount=type=cache,target=/root/.cache/pip \
COPY requirements-dev.txt requirements-dev.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements-dev.txt
#################### BASE BUILD IMAGE ####################

# image to build pytorch extensions

#################### EXTENSION BUILD IMAGE ####################
FROM dev AS build

# install build dependencies
@@ -30,23 +42,43 @@ COPY requirements.txt requirements.txt
COPY pyproject.toml pyproject.toml
COPY vllm/__init__.py vllm/__init__.py

# cuda arch list used by torch
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.6 8.9 9.0+PTX'
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
# max jobs used by Ninja to build extensions
ENV MAX_JOBS=$max_jobs
RUN python3 setup.py build_ext --inplace
ARG max_jobs=2
ENV MAX_JOBS=${max_jobs}
# number of threads used by nvcc
ARG nvcc_threads=8
ENV NVCC_THREADS=$nvcc_threads
# make sure punica kernels are built (for LoRA)
ENV VLLM_INSTALL_PUNICA_KERNELS=1

RUN python3 setup.py build_ext --inplace
#################### EXTENSION Build IMAGE ####################


#################### TEST IMAGE ####################
# image to run unit testing suite
FROM dev AS test

# copy pytorch extensions separately to avoid having to rebuild
# when python code changes
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY tests tests
COPY vllm vllm
WORKDIR /vllm-workspace
# ADD is used to preserve directory structure
ADD . /vllm-workspace/
COPY --from=build /workspace/vllm/*.so /vllm-workspace/vllm/
# ignore build dependencies installation because we are using pre-complied extensions
RUN rm pyproject.toml
RUN --mount=type=cache,target=/root/.cache/pip VLLM_USE_PRECOMPILED=1 pip install . --verbose
#################### TEST IMAGE ####################

ENTRYPOINT ["python3", "-m", "pytest", "tests"]

# use CUDA base as CUDA runtime dependencies are already installed via pip
FROM nvidia/cuda:12.1.0-base-ubuntu22.04 AS vllm-base
#################### RUNTIME BASE IMAGE ####################
# We used base cuda image because pytorch installs its own cuda libraries.
# However cupy depends on cuda libraries so we had to switch to the runtime image
# In the future it would be nice to get a container with pytorch and cuda without duplicating cuda
FROM nvidia/cuda:12.1.0-runtime-ubuntu22.04 AS vllm-base

# libnccl required for ray
RUN apt-get update -y \
@@ -56,22 +88,18 @@ WORKDIR /workspace
COPY requirements.txt requirements.txt
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install -r requirements.txt
#################### RUNTIME BASE IMAGE ####################

FROM vllm-base AS vllm
COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

EXPOSE 8000
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.api_server"]

#################### OPENAI API SERVER ####################
# openai api server alternative
FROM vllm-base AS vllm-openai
# install additional dependencies for openai api server
RUN --mount=type=cache,target=/root/.cache/pip \
    pip install accelerate fschat
    pip install accelerate

COPY --from=build /workspace/vllm/*.so /workspace/vllm/
COPY vllm vllm

ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]

#################### OPENAI API SERVER ####################
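The stages above are selected with --target at build time; max_jobs and nvcc_threads are the only build args this Dockerfile defines, and the HF_TOKEN variable mirrors the secret the CI template injects. A typical build and run of the OpenAI-server stage might look like the following sketch (image tag and model are placeholders):

```bash
# Sketch: building the vllm-openai stage and serving a model with it.
DOCKER_BUILDKIT=1 docker build . \
    --target vllm-openai \
    --build-arg max_jobs=8 \
    --build-arg nvcc_threads=4 \
    -t vllm-openai:dev

docker run --gpus all -p 8000:8000 -e HF_TOKEN="$HF_TOKEN" \
    vllm-openai:dev --model meta-llama/Llama-2-7b-chat-hf
```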
							
								
								
									
Dockerfile.rocm (new file, 95 lines)

@@ -0,0 +1,95 @@
# default base image
ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

FROM $BASE_IMAGE

ARG BASE_IMAGE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"

RUN echo "Base image is $BASE_IMAGE"

# BASE_IMAGE for ROCm_5.7: "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1"
# BASE_IMAGE for ROCm_6.0: "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1"


ARG FA_GFX_ARCHS="gfx90a;gfx942"
RUN echo "FA_GFX_ARCHS is $FA_GFX_ARCHS"

ARG FA_BRANCH="3d2b6f5"
RUN echo "FA_BRANCH is $FA_BRANCH"

# whether to build flash-attention
# if 0, will not build flash attention
# this is useful for gfx target where flash-attention is not supported
# In that case, we need to use the python reference attention implementation in vllm
ARG BUILD_FA="1"

# Install some basic utilities
RUN apt-get update && apt-get install python3 python3-pip -y

# Install some basic utilities
RUN apt-get update && apt-get install -y \
    curl \
    ca-certificates \
    sudo \
    git \
    bzip2 \
    libx11-6 \
    build-essential \
    wget \
    unzip \
    nvidia-cuda-toolkit \
    tmux \
 && rm -rf /var/lib/apt/lists/*

### Mount Point ###
# When launching the container, mount the code directory to /app
ARG APP_MOUNT=/app
VOLUME [ ${APP_MOUNT} ]
WORKDIR ${APP_MOUNT}

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir fastapi ninja tokenizers pandas

ENV LLVM_SYMBOLIZER_PATH=/opt/rocm/llvm/bin/llvm-symbolizer
ENV PATH=$PATH:/opt/rocm/bin:/libtorch/bin:
ENV LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm/lib/:/libtorch/lib:
ENV CPLUS_INCLUDE_PATH=$CPLUS_INCLUDE_PATH:/libtorch/include:/libtorch/include/torch/csrc/api/include/:/opt/rocm/include/:

# Install ROCm flash-attention
RUN if [ "$BUILD_FA" = "1" ]; then \
    mkdir libs \
    && cd libs \
    && git clone https://github.com/ROCm/flash-attention.git \
    && cd flash-attention \
    && git checkout ${FA_BRANCH} \
    && git submodule update --init \
    && export GPU_ARCHS=${FA_GFX_ARCHS} \
    && if [ "$BASE_IMAGE" = "rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" ]; then \
        patch /opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/utils/hipify/hipify_python.py hipify_patch.patch; fi \
    && python3 setup.py install \
    && cd ..; \
    fi

COPY ./ /app/vllm

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install xformers==0.0.23 --no-deps

# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
# Manually removed it so that later steps of numpy upgrade can continue
RUN if [ "$BASE_IMAGE" = "rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" ]; then \
    rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/; fi

RUN cd /app \
    && cd vllm \
    && pip install -U -r requirements-rocm.txt \
    && if [ "$BUILD_FA" = "1" ]; then \
       bash patch_xformers.rocm.sh; fi \
    && patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h /app/vllm/rocm_patch/rocm_bf16.patch \
    && python3 setup.py install \
    && cd ..

RUN python3 -m pip install --upgrade pip
RUN python3 -m pip install --no-cache-dir ray[all]

CMD ["/bin/bash"]
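BASE_IMAGE, FA_GFX_ARCHS, FA_BRANCH and BUILD_FA are all overridable at build time. A hypothetical ROCm 5.7 build that skips the flash-attention compile might look like the sketch below (the tag and the device pass-through flags are standard ROCm container usage rather than anything this file prescribes):

```bash
# Sketch: ROCm image build that overrides the ARGs declared above.
docker build -f Dockerfile.rocm . \
    --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
    --build-arg BUILD_FA="0" \
    -t vllm-rocm:dev

# ROCm containers typically need the GPU device nodes passed through.
docker run -it --device=/dev/kfd --device=/dev/dri --group-add video vllm-rocm:dev
```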
							
								
								
									
README.md (19 lines changed)

@@ -17,6 +17,9 @@ Easy, fast, and cheap LLM serving for everyone
---

*Latest News* 🔥
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
- [2024/01] Added ROCm 6.0 support to vLLM.
- [2023/12] Added ROCm 5.7 support to vLLM.
- [2023/10] We hosted [the first vLLM meetup](https://lu.ma/first-vllm-meetup) in SF! Please find the meetup slides [here](https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?usp=sharing).
- [2023/09] We created our [Discord server](https://discord.gg/jz7wjKhh6g)! Join us to discuss vLLM and LLM serving! We will also post the latest announcements and updates there.
- [2023/09] We released our [PagedAttention paper](https://arxiv.org/abs/2309.06180) on arXiv!
@@ -26,7 +29,7 @@ Easy, fast, and cheap LLM serving for everyone
- [2023/06] We officially released vLLM! FastChat-vLLM integration has powered [LMSYS Vicuna and Chatbot Arena](https://chat.lmsys.org) since mid-April. Check out our [blog post](https://vllm.ai).

---

## About
vLLM is a fast and easy-to-use library for LLM inference and serving.

vLLM is fast with:
@@ -34,6 +37,8 @@ vLLM is fast with:
- State-of-the-art serving throughput
- Efficient management of attention key and value memory with **PagedAttention**
- Continuous batching of incoming requests
- Fast model execution with CUDA/HIP graph
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
- Optimized CUDA kernels

vLLM is flexible and easy to use with:
@@ -43,6 +48,9 @@ vLLM is flexible and easy to use with:
- Tensor parallelism support for distributed inference
- Streaming outputs
- OpenAI-compatible API server
- Support NVIDIA GPUs and AMD GPUs
- (Experimental) Prefix caching support
- (Experimental) Multi-lora support

vLLM seamlessly supports many Hugging Face models, including the following architectures:

@@ -50,18 +58,25 @@ vLLM seamlessly supports many Hugging Face models, including the following architectures:
- Baichuan & Baichuan2 (`baichuan-inc/Baichuan2-13B-Chat`, `baichuan-inc/Baichuan-7B`, etc.)
- BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
- ChatGLM (`THUDM/chatglm2-6b`, `THUDM/chatglm3-6b`, etc.)
- DeciLM (`Deci/DeciLM-7B`, `Deci/DeciLM-7B-instruct`, etc.)
- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
- Gemma (`google/gemma-2b`, `google/gemma-7b`, etc.)
- GPT-2 (`gpt2`, `gpt2-xl`, etc.)
- GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
- GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
- GPT-NeoX (`EleutherAI/gpt-neox-20b`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc.)
- InternLM (`internlm/internlm-7b`, `internlm/internlm-chat-7b`, etc.)
- InternLM2 (`internlm/internlm2-7b`, `internlm/internlm2-chat-7b`, etc.)
- LLaMA & LLaMA-2 (`meta-llama/Llama-2-70b-hf`, `lmsys/vicuna-13b-v1.3`, `young-geng/koala`, `openlm-research/open_llama_13b`, etc.)
- Mistral (`mistralai/Mistral-7B-v0.1`, `mistralai/Mistral-7B-Instruct-v0.1`, etc.)
- Mixtral (`mistralai/Mixtral-8x7B-v0.1`, `mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.)
- MPT (`mosaicml/mpt-7b`, `mosaicml/mpt-30b`, etc.)
- OLMo (`allenai/OLMo-1B`, `allenai/OLMo-7B`, etc.)
- OPT (`facebook/opt-66b`, `facebook/opt-iml-max-30b`, etc.)
- Phi-1.5 (`microsoft/phi-1_5`, etc.)
- Phi (`microsoft/phi-1_5`, `microsoft/phi-2`, etc.)
- Qwen (`Qwen/Qwen-7B`, `Qwen/Qwen-7B-Chat`, etc.)
- Qwen2 (`Qwen/Qwen2-7B-beta`, `Qwen/Qwen-7B-Chat-beta`, etc.)
- StableLM(`stabilityai/stablelm-3b-4e1t`, `stabilityai/stablelm-base-alpha-7b-v2`, etc.)
- Yi (`01-ai/Yi-6B`, `01-ai/Yi-34B`, etc.)

Install vLLM with pip or [from source](https://vllm.readthedocs.io/en/latest/getting_started/installation.html#build-from-source):
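The listing stops at the installation line; a minimal sketch of what install-and-serve looks like in practice (the model name is only an example) is:

```bash
# Sketch: install from PyPI and start the OpenAI-compatible server.
pip install vllm
python3 -m vllm.entrypoints.openai.api_server --model mistralai/Mistral-7B-v0.1

# In another shell, confirm the server is up (the same endpoint the CI benchmark polls).
curl http://localhost:8000/v1/models
```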
							
								
								
									
benchmarks/backend_request_func.py (new file, 284 lines)

@@ -0,0 +1,284 @@
import json
import os
import time
from dataclasses import dataclass
from typing import Optional

import aiohttp
from tqdm.asyncio import tqdm

AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)


@dataclass
class RequestFuncInput:
    prompt: str
    api_url: str
    prompt_len: int
    output_len: int
    model: str
    best_of: int = 1
    use_beam_search: bool = False


@dataclass
class RequestFuncOutput:
    generated_text: str = ""
    success: bool = False
    latency: float = 0
    ttft: float = 0
    prompt_len: int = 0


async def async_request_tgi(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        params = {
            "best_of": request_func_input.best_of,
            "max_new_tokens": request_func_input.output_len,
            "do_sample": True,
            "temperature": 0.01,  # TGI does not accept 0.0 temperature.
            "top_p": 0.99,  # TGI does not accept 1.0 top_p.
        }
        payload = {
            "inputs": request_func_input.prompt,
            "parameters": params,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for data in response.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                    output.latency = time.perf_counter() - st

                    body = data.decode("utf-8").lstrip("data:")
                    output.generated_text = json.loads(body)["generated_text"]
                    output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


async def async_request_vllm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        payload = {
            "prompt": request_func_input.prompt,
            "n": 1,
            "best_of": request_func_input.best_of,
            "use_beam_search": request_func_input.use_beam_search,
            "temperature": 0.0 if request_func_input.use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "ignore_eos": True,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as response:
                if response.status == 200:
                    async for data in response.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                    output.latency = time.perf_counter() - st

                    # When streaming, '\0' is appended to the end of the response.
                    body = data.decode("utf-8").strip("\0")
                    output.generated_text = json.loads(
                        body)["text"][0][len(request_func_input.prompt):]
                    output.success = True

                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


async def async_request_trt_llm(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("generate_stream")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        assert request_func_input.best_of == 1
        payload = {
            "accumulate_tokens": True,
            "text_input": request_func_input.prompt,
            "temperature": 0.0,
            "top_p": 1.0,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len
        ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload) as resp:
                if resp.status == 200:
                    async for data in resp.content.iter_any():
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft
                    output.latency = time.perf_counter() - st

                    body = data.decode("utf-8").lstrip("data:")
                    output.generated_text = json.loads(body)["text_output"]
                    output.success = True

                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


async def async_request_deepspeed_mii(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert request_func_input.best_of == 1
        assert not request_func_input.use_beam_search

        payload = {
            "prompts": request_func_input.prompt,
            "max_new_tokens": request_func_input.output_len,
            "ignore_eos": True,
            "do_sample": True,
            "temperature":
            0.01,  # deepspeed-mii does not accept 0.0 temperature.
            "top_p": 1.0,
        }
        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        # DeepSpeed-MII doesn't support streaming as of Jan 28 2024, will use 0 as placeholder.
        # https://github.com/microsoft/DeepSpeed-MII/pull/311
        output.ttft = 0

        st = time.perf_counter()
        try:
            async with session.post(url=request_func_input.api_url,
                                    json=payload) as resp:
                if resp.status == 200:
                    parsed_resp = await resp.json()
                    output.latency = time.perf_counter() - st
                    output.generated_text = parsed_resp[0]["generated_text"]
                    output.success = True
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

        if pbar:
            pbar.update(1)
        return output


async def async_request_openai_completions(
    request_func_input: RequestFuncInput,
    pbar: Optional[tqdm] = None,
) -> RequestFuncOutput:
    api_url = request_func_input.api_url
    assert api_url.endswith("v1/completions")

    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
        assert not request_func_input.use_beam_search
        payload = {
            "model": request_func_input.model,
            "prompt": request_func_input.prompt,
            "temperature": 0.0,
            "best_of": request_func_input.best_of,
            "max_tokens": request_func_input.output_len,
            "stream": True,
        }
        headers = {
            "Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"
        }

        output = RequestFuncOutput()
        output.prompt_len = request_func_input.prompt_len

        generated_text = ""
        ttft = 0
        st = time.perf_counter()
        try:
            async with session.post(url=api_url, json=payload,
                                    headers=headers) as response:
                if response.status == 200:
                    async for chunk in response.content:
                        if ttft == 0:
                            ttft = time.perf_counter() - st
                            output.ttft = ttft

                        chunk = chunk.strip()
                        if not chunk:
                            continue

                        chunk = chunk.decode("utf-8").lstrip("data: ")
                        if chunk == "[DONE]":
                            latency = time.perf_counter() - st
                        else:
                            body = json.loads(chunk)
                            generated_text += body["choices"][0]["text"]

                    output.generated_text = generated_text
                    output.success = True
                    output.latency = latency
                else:
                    output.success = False
        except (aiohttp.ClientOSError, aiohttp.ServerDisconnectedError):
            output.success = False

    if pbar:
        pbar.update(1)
    return output


ASYNC_REQUEST_FUNCS = {
    "tgi": async_request_tgi,
    "vllm": async_request_vllm,
    "deepspeed-mii": async_request_deepspeed_mii,
    "openai": async_request_openai_completions,
    "tensorrt-llm": async_request_trt_llm,
}
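For reference, the payload assembled by async_request_vllm above targets the native /generate endpoint of vllm.entrypoints.api_server; a hand-written equivalent of that streaming request (server address and prompt are assumptions) looks like:

```bash
# Sketch: issuing the same streaming request async_request_vllm sends, via curl.
curl -N http://localhost:8000/generate \
    -H "Content-Type: application/json" \
    -d '{
          "prompt": "San Francisco is a",
          "n": 1,
          "best_of": 1,
          "use_beam_search": false,
          "temperature": 1.0,
          "top_p": 1.0,
          "max_tokens": 16,
          "ignore_eos": true,
          "stream": true
        }'
```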
benchmarks/benchmark_latency.py

@@ -1,6 +1,8 @@
"""Benchmark the latency of processing a single batch of requests."""
import argparse
import time
from pathlib import Path
from typing import Optional

import numpy as np
import torch
@@ -21,6 +23,9 @@ def main(args: argparse.Namespace):
        tensor_parallel_size=args.tensor_parallel_size,
        trust_remote_code=args.trust_remote_code,
        dtype=args.dtype,
        enforce_eager=args.enforce_eager,
        kv_cache_dtype=args.kv_cache_dtype,
        device=args.device,
    )

    sampling_params = SamplingParams(
@@ -32,14 +37,20 @@ def main(args: argparse.Namespace):
        max_tokens=args.output_len,
    )
    print(sampling_params)
    dummy_prompt_token_ids = [[0] * args.input_len] * args.batch_size
    dummy_prompt_token_ids = np.random.randint(10000,
                                               size=(args.batch_size,
                                                     args.input_len))
    dummy_prompt_token_ids = dummy_prompt_token_ids.tolist()

    def run_to_completion(profile: bool = False):
        if profile:
            with torch.profiler.profile(activities=[
    def run_to_completion(profile_dir: Optional[str] = None):
        if profile_dir:
            with torch.profiler.profile(
                    activities=[
                        torch.profiler.ProfilerActivity.CPU,
                        torch.profiler.ProfilerActivity.CUDA,
            ]) as p:
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(
                        str(profile_dir))) as p:
                llm.generate(prompt_token_ids=dummy_prompt_token_ids,
                             sampling_params=sampling_params,
                             use_tqdm=False)
@@ -54,17 +65,22 @@ def main(args: argparse.Namespace):
            return latency

    print("Warming up...")
    run_to_completion(profile=False)
    run_to_completion(profile_dir=None)

    if args.profile:
        print("Profiling...")
        run_to_completion(profile=True)
        profile_dir = args.profile_result_dir
        if not profile_dir:
            profile_dir = Path(
                "."
            ) / "vllm_benchmark_result" / f"latency_result_{time.time()}"
        print(f"Profiling (results will be saved to '{profile_dir}')...")
        run_to_completion(profile_dir=profile_dir)
        return

    # Benchmark.
    latencies = []
    for _ in tqdm(range(args.num_iters), desc="Profiling iterations"):
        latencies.append(run_to_completion(profile=False))
        latencies.append(run_to_completion(profile_dir=None))
    print(f'Avg latency: {np.mean(latencies)} seconds')


@@ -76,7 +92,7 @@ if __name__ == '__main__':
    parser.add_argument('--tokenizer', type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'squeezellm', None],
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
    parser.add_argument('--input-len', type=int, default=32)
@@ -103,9 +119,31 @@ if __name__ == '__main__':
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument('--enforce-eager',
                        action='store_true',
                        help='enforce eager mode and disable CUDA graph')
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=['auto', 'fp8_e5m2'],
        default='auto',
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        '--profile',
        action='store_true',
        help='profile the generation process of a single batch')
    parser.add_argument(
        '--profile-result-dir',
        type=str,
        default=None,
        help=('path to save the pytorch profiler output. Can be visualized '
              'with ui.perfetto.dev or Tensorboard.'))
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
    args = parser.parse_args()
    main(args)
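With the profile_dir plumbing added above, a profiling run and trace inspection might look like the following sketch (the flag values and the TensorBoard plugin are assumptions; the help text only promises that the trace can be opened with ui.perfetto.dev or TensorBoard):

```bash
# Sketch: profile a single batch and inspect the trace written by
# torch.profiler.tensorboard_trace_handler.
python3 benchmarks/benchmark_latency.py \
    --model facebook/opt-125m \
    --batch-size 8 --input-len 32 --output-len 128 \
    --profile --profile-result-dir ./latency_profile

pip install tensorboard torch-tb-profiler   # assumed viewer setup
tensorboard --logdir ./latency_profile
```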
@@ -20,15 +20,36 @@ import asyncio
import json
import random
import time
from dataclasses import dataclass
from datetime import datetime
from typing import AsyncGenerator, List, Tuple

import aiohttp
import numpy as np
from tqdm.asyncio import tqdm
from transformers import PreTrainedTokenizerBase
from vllm.transformers_utils.tokenizer import get_tokenizer

# (prompt len, output len, latency)
REQUEST_LATENCY: List[Tuple[int, int, float]] = []
from backend_request_func import (
    ASYNC_REQUEST_FUNCS,
    RequestFuncInput,
    RequestFuncOutput,
)


@dataclass
class BenchmarkMetrics:
    completed: int
    total_input: int
    total_output: int
    request_throughput: float
    input_throughput: float
    output_throughput: float
    mean_ttft_ms: float
    median_ttft_ms: float
    p99_ttft_ms: float
    mean_tpot_ms: float
    median_tpot_ms: float
    p99_tpot_ms: float


def sample_requests(
@@ -40,15 +61,15 @@ def sample_requests(
    with open(dataset_path) as f:
        dataset = json.load(f)
    # Filter out the conversations with less than 2 turns.
    dataset = [
        data for data in dataset
        if len(data["conversations"]) >= 2
    ]
    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
    # Only keep the first two turns of each conversation.
    dataset = [
        (data["conversations"][0]["value"], data["conversations"][1]["value"])
        for data in dataset
    ]
    dataset = [(data["conversations"][0]["value"],
                data["conversations"][1]["value"]) for data in dataset]

    # some of these will be filtered out, so sample more than we need
    sampled_indices = random.sample(range(len(dataset)),
                                    int(num_requests * 1.2))
    dataset = [dataset[i] for i in sampled_indices]

    # Tokenize the prompts and completions.
    prompts = [prompt for prompt, _ in dataset]
@@ -96,79 +117,125 @@ async def get_request(
        await asyncio.sleep(interval)


async def send_request(
    backend: str,
    api_url: str,
    prompt: str,
    prompt_len: int,
    output_len: int,
    best_of: int,
    use_beam_search: bool,
) -> None:
    request_start_time = time.perf_counter()
def calculate_metrics(
    input_requests: List[Tuple[str, int, int]],
    outputs: List[RequestFuncOutput],
    dur_s: float,
    tokenizer: PreTrainedTokenizerBase,
) -> BenchmarkMetrics:
    total_output = 0
    total_input = 0
    completed = 0
    per_token_latencies = []
    ttfts = []
    for i in range(len(outputs)):
        if outputs[i].success:
            output_len = len(tokenizer.encode(outputs[i].generated_text))
            total_output += output_len
            total_input += input_requests[i][1]
            per_token_latencies.append(outputs[i].latency / output_len)
            ttfts.append(outputs[i].ttft)
            completed += 1

    headers = {"User-Agent": "Benchmark Client"}
    if backend == "vllm":
        pload = {
            "prompt": prompt,
            "n": 1,
            "best_of": best_of,
            "use_beam_search": use_beam_search,
            "temperature": 0.0 if use_beam_search else 1.0,
            "top_p": 1.0,
            "max_tokens": output_len,
            "ignore_eos": True,
            "stream": False,
        }
    elif backend == "tgi":
        assert not use_beam_search
        params = {
            "best_of": best_of,
            "max_new_tokens": output_len,
            "do_sample": True,
        }
        pload = {
            "inputs": prompt,
            "parameters": params,
        }
    else:
        raise ValueError(f"Unknown backend: {backend}")
    metrics = BenchmarkMetrics(
        completed=completed,
        total_input=total_input,
        total_output=total_output,
        request_throughput=completed / dur_s,
        input_throughput=total_input / dur_s,
        output_throughput=total_output / dur_s,
        mean_ttft_ms=np.mean(ttfts) * 1000,
        median_ttft_ms=np.median(ttfts) * 1000,
        p99_ttft_ms=np.percentile(ttfts, 99) * 1000,
        mean_tpot_ms=np.mean(per_token_latencies) * 1000,
        median_tpot_ms=np.median(per_token_latencies) * 1000,
        p99_tpot_ms=np.percentile(per_token_latencies, 99) * 1000,
    )

    timeout = aiohttp.ClientTimeout(total=3 * 3600)
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while True:
            async with session.post(api_url, headers=headers, json=pload) as response:
                chunks = []
                async for chunk, _ in response.content.iter_chunks():
                    chunks.append(chunk)
            output = b"".join(chunks).decode("utf-8")
            output = json.loads(output)

            # Re-send the request if it failed.
            if "error" not in output:
                break

    request_end_time = time.perf_counter()
    request_latency = request_end_time - request_start_time
    REQUEST_LATENCY.append((prompt_len, output_len, request_latency))
    return metrics


async def benchmark(
    backend: str,
    api_url: str,
    model_id: str,
    tokenizer: PreTrainedTokenizerBase,
    input_requests: List[Tuple[str, int, int]],
    best_of: int,
    use_beam_search: bool,
    request_rate: float,
) -> None:
    tasks: List[asyncio.Task] = []
    disable_tqdm: bool,
):
    if backend in ASYNC_REQUEST_FUNCS:
        request_func = ASYNC_REQUEST_FUNCS.get(backend)
    else:
        raise ValueError(f"Unknown backend: {backend}")

    pbar = None if disable_tqdm else tqdm(total=len(input_requests))

    print(f"Traffic request rate: {request_rate}")

    benchmark_start_time = time.perf_counter()
    tasks = []
    async for request in get_request(input_requests, request_rate):
        prompt, prompt_len, output_len = request
        task = asyncio.create_task(send_request(backend, api_url, prompt,
                                                prompt_len, output_len,
                                                best_of, use_beam_search))
        tasks.append(task)
    await asyncio.gather(*tasks)
        request_func_input = RequestFuncInput(
            model=model_id,
            prompt=prompt,
            api_url=api_url,
            prompt_len=prompt_len,
            output_len=output_len,
            best_of=best_of,
            use_beam_search=use_beam_search,
        )
        tasks.append(
            asyncio.create_task(
                request_func(request_func_input=request_func_input,
                             pbar=pbar)))
    outputs = await asyncio.gather(*tasks)

    if not disable_tqdm:
        pbar.close()

    benchmark_duration = time.perf_counter() - benchmark_start_time

    metrics = calculate_metrics(
        input_requests=input_requests,
        outputs=outputs,
        dur_s=benchmark_duration,
        tokenizer=tokenizer,
    )

    print(f"Successful requests: {metrics.completed}")
    print(f"Benchmark duration: {benchmark_duration:.2f} s")
    print(f"Total input tokens: {metrics.total_input}")
    print(f"Total generated tokens: {metrics.total_output}")
    print(f"Request throughput: {metrics.request_throughput:.2f} requests/s")
    print(f"Input token throughput: {metrics.input_throughput:.2f} tokens/s")
    print(f"Output token throughput: {metrics.output_throughput:.2f} tokens/s")
    print(f"Mean TTFT: {metrics.mean_ttft_ms:.2f} ms")
    print(f"Median TTFT: {metrics.median_ttft_ms:.2f} ms")
    print(f"P99 TTFT: {metrics.p99_ttft_ms:.2f} ms")
    print(f"Mean TPOT: {metrics.mean_tpot_ms:.2f} ms")
    print(f"Median TPOT: {metrics.median_tpot_ms:.2f} ms")
    print(f"P99 TPOT: {metrics.p99_tpot_ms:.2f} ms")

    result = {
        "duration": benchmark_duration,
        "completed": metrics.completed,
        "total_input_tokens": metrics.total_input,
        "total_output_tokens": metrics.total_output,
        "request_throughput": metrics.request_throughput,
        "input_throughput": metrics.input_throughput,
        "output_throughput": metrics.output_throughput,
        "mean_ttft_ms": metrics.mean_ttft_ms,
        "median_ttft_ms": metrics.median_ttft_ms,
        "p99_ttft_ms": metrics.p99_ttft_ms,
        "mean_tpot_ms": metrics.mean_tpot_ms,
        "median_tpot_ms": metrics.median_tpot_ms,
        "p99_tpot_ms": metrics.p99_tpot_ms
    }
    return result


def main(args: argparse.Namespace):
@@ -176,58 +243,145 @@ def main(args: argparse.Namespace):
    random.seed(args.seed)
    np.random.seed(args.seed)

    api_url = f"http://{args.host}:{args.port}/generate"
    tokenizer = get_tokenizer(args.tokenizer, trust_remote_code=args.trust_remote_code)
    backend = args.backend
    model_id = args.model
    tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model

    if args.base_url is not None:
        api_url = f"{args.base_url}{args.endpoint}"
    else:
        api_url = f"http://{args.host}:{args.port}{args.endpoint}"

    tokenizer = get_tokenizer(tokenizer_id,
                              trust_remote_code=args.trust_remote_code)
    input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer)

    benchmark_start_time = time.perf_counter()
    asyncio.run(benchmark(args.backend, api_url, input_requests, args.best_of,
                          args.use_beam_search, args.request_rate))
    benchmark_end_time = time.perf_counter()
    benchmark_time = benchmark_end_time - benchmark_start_time
    print(f"Total time: {benchmark_time:.2f} s")
    print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
    benchmark_result = asyncio.run(
        benchmark(
            backend=backend,
            api_url=api_url,
            model_id=model_id,
            tokenizer=tokenizer,
            input_requests=input_requests,
            best_of=args.best_of,
            use_beam_search=args.use_beam_search,
            request_rate=args.request_rate,
            disable_tqdm=args.disable_tqdm,
        ))

    # Compute the latency statistics.
    avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY])
    print(f"Average latency: {avg_latency:.2f} s")
    avg_per_token_latency = np.mean([
        latency / (prompt_len + output_len)
        for prompt_len, output_len, latency in REQUEST_LATENCY
    ])
    print(f"Average latency per token: {avg_per_token_latency:.2f} s")
    avg_per_output_token_latency = np.mean([
        latency / output_len
        for _, output_len, latency in REQUEST_LATENCY
    ])
    print("Average latency per output token: "
          f"{avg_per_output_token_latency:.2f} s")
    # Save config and results to json
    if args.save_result:
        result_json = {}

        # Setup
        current_dt = datetime.now().strftime("%Y%m%d-%H%M%S")
        result_json["date"] = current_dt
        result_json["backend"] = backend
        result_json["version"] = args.version
        result_json["model_id"] = model_id
        result_json["tokenizer_id"] = tokenizer_id
        result_json["best_of"] = args.best_of
        result_json["use_beam_search"] = args.use_beam_search
        result_json["num_prompts"] = args.num_prompts

        # Traffic
        result_json["request_rate"] = (
            args.request_rate if args.request_rate < float("inf") else "inf")

        # Merge with benchmark result
        result_json = {**result_json, **benchmark_result}

        # Save to file
        base_model_id = model_id.split("/")[-1]
        file_name = f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
        with open(file_name, "w") as outfile:
            json.dump(result_json, outfile)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Benchmark the online serving throughput.")
    parser.add_argument("--backend", type=str, default="vllm",
                        choices=["vllm", "tgi"])
    parser.add_argument(
        "--backend",
        type=str,
        default="vllm",
        choices=list(ASYNC_REQUEST_FUNCS.keys()),
    )
    parser.add_argument(
        "--version",
        type=str,
        default="N/A",
        help="Version of the serving backend/engine.",
    )
    parser.add_argument(
        "--base-url",
        type=str,
        default=None,
        help="Server or API base url if not using http host and port.",
    )
    parser.add_argument("--host", type=str, default="localhost")
    parser.add_argument("--port", type=int, default=8000)
    parser.add_argument("--dataset", type=str, required=True,
    parser.add_argument(
        "--endpoint",
        type=str,
        default="/generate",
        help="API endpoint.",
    )
    parser.add_argument("--dataset",
                        type=str,
                        required=True,
                        help="Path to the dataset.")
    parser.add_argument("--tokenizer", type=str, required=True,
                        help="Name or path of the tokenizer.")
    parser.add_argument("--best-of", type=int, default=1,
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Name of the model.",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        help=
        "Name or path of the tokenizer, if not using the default model tokenizer.",
    )
    parser.add_argument(
        "--best-of",
        type=int,
        default=1,
        help="Generates `best_of` sequences per prompt and "
                             "returns the best one.")
        "returns the best one.",
    )
    parser.add_argument("--use-beam-search", action="store_true")
    parser.add_argument("--num-prompts", type=int, default=1000,
                        help="Number of prompts to process.")
    parser.add_argument("--request-rate", type=float, default=float("inf"),
    parser.add_argument(
        "--num-prompts",
        type=int,
        default=1000,
        help="Number of prompts to process.",
    )
    parser.add_argument(
        "--request-rate",
        type=float,
        default=float("inf"),
        help="Number of requests per second. If this is inf, "
        "then all the requests are sent at time 0. "
        "Otherwise, we use Poisson process to synthesize "
                             "the request arrival times.")
        "the request arrival times.",
    )
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument('--trust-remote-code', action='store_true',
                        help='trust remote code from huggingface')
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Trust remote code from huggingface",
    )
    parser.add_argument(
        "--disable-tqdm",
        action="store_true",
        help="Specify to disable tqdm progress bar.",
    )
    parser.add_argument(
        "--save-result",
        action="store_true",
        help="Specify to save benchmark results to a json file",
    )

    args = parser.parse_args()
    main(args)

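Two pieces of the serving benchmark are easy to miss in the diff above: the --request-rate help text says arrivals follow a Poisson process, and BenchmarkMetrics reports mean/median/p99 TTFT in milliseconds. A self-contained sketch of both ideas, assuming nothing beyond numpy and asyncio (paced and summarize_ttft are illustrative names, not functions from the script):

import asyncio
import numpy as np

async def paced(requests, request_rate: float):
    # Poisson pacing: exponential inter-arrival gaps with mean 1/request_rate;
    # an infinite rate sends every request immediately.
    for req in requests:
        yield req
        if request_rate == float("inf"):
            continue
        await asyncio.sleep(np.random.exponential(1.0 / request_rate))

def summarize_ttft(ttfts_s):
    # Mirrors the mean/median/p99 reporting of BenchmarkMetrics, in ms.
    ttfts = np.array(ttfts_s) * 1000
    return {
        "mean_ttft_ms": float(np.mean(ttfts)),
        "median_ttft_ms": float(np.median(ttfts)),
        "p99_ttft_ms": float(np.percentile(ttfts, 99)),
    }

if __name__ == "__main__":
    print(summarize_ttft([0.12, 0.08, 0.31, 0.09]))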
@@ -69,7 +69,10 @@ def run_vllm(
    use_beam_search: bool,
    trust_remote_code: bool,
    dtype: str,
    max_model_len: Optional[int] = None,
    max_model_len: Optional[int],
    enforce_eager: bool,
    kv_cache_dtype: str,
    device: str,
) -> float:
    from vllm import LLM, SamplingParams
    llm = LLM(
@@ -81,6 +84,9 @@ def run_vllm(
        trust_remote_code=trust_remote_code,
        dtype=dtype,
        max_model_len=max_model_len,
        enforce_eager=enforce_eager,
        kv_cache_dtype=kv_cache_dtype,
        device=device,
    )

    # Add the requests to the engine.
@@ -204,7 +210,8 @@ def main(args: argparse.Namespace):
                                args.quantization, args.tensor_parallel_size,
                                args.seed, args.n, args.use_beam_search,
                                args.trust_remote_code, args.dtype,
                                args.max_model_len)
                                args.max_model_len, args.enforce_eager,
                                args.kv_cache_dtype, args.device)
    elif args.backend == "hf":
        assert args.tensor_parallel_size == 1
        elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -244,7 +251,7 @@ if __name__ == "__main__":
    parser.add_argument("--tokenizer", type=str, default=None)
    parser.add_argument('--quantization',
                        '-q',
                        choices=['awq', 'squeezellm', None],
                        choices=['awq', 'gptq', 'squeezellm', None],
                        default=None)
    parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
    parser.add_argument("--n",
@@ -279,6 +286,22 @@ if __name__ == "__main__":
        'The "auto" option will use FP16 precision '
        'for FP32 and FP16 models, and BF16 precision '
        'for BF16 models.')
    parser.add_argument("--enforce-eager",
                        action="store_true",
                        help="enforce eager execution")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8_e5m2"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        choices=["cuda"],
        help='device type for vLLM execution, supporting CUDA only currently.')
    args = parser.parse_args()
    if args.tokenizer is None:
        args.tokenizer = args.model

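The run_vllm changes above thread --enforce-eager, --kv-cache-dtype and --device straight into the LLM constructor. A minimal, hypothetical invocation along the same lines; the model name and prompt are placeholders, not values from the benchmark:

from vllm import LLM, SamplingParams

llm = LLM(
    model="facebook/opt-125m",   # placeholder model
    enforce_eager=True,          # --enforce-eager
    kv_cache_dtype="fp8_e5m2",   # --kv-cache-dtype
    device="cuda",               # --device
)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=16))
for out in outputs:
    print(out.outputs[0].text)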
@@ -1,9 +1,11 @@
from typing import Optional
import argparse
import random
import time

import torch

from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, create_kv_caches_with_random
from vllm._C import ops

NUM_BLOCKS = 1024
@@ -23,9 +25,12 @@ def main(
    dtype: torch.dtype,
    seed: int,
    do_profile: bool,
    device: str = "cuda",
    kv_cache_dtype: Optional[str] = None,
) -> None:
    random.seed(seed)
    torch.random.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)

    scale = float(1.0 / (head_size**0.5))
@@ -33,23 +38,19 @@ def main(
                        num_query_heads,
                        head_size,
                        dtype=dtype,
                        device="cuda")
                        device=device)
    query.uniform_(-scale, scale)

    assert num_query_heads % num_kv_heads == 0
    num_queries_per_kv = num_query_heads // num_kv_heads
    head_mapping = torch.repeat_interleave(
        torch.arange(num_kv_heads, dtype=torch.int32, device="cuda"),
        num_queries_per_kv)
    alibi_slopes = None
    if use_alibi:
        alibi_slopes = torch.randn(num_query_heads,
                                   dtype=torch.float,
                                   device="cuda")
                                   device=device)

    context_lens = [context_len for _ in range(num_seqs)]
    max_context_len = max(context_lens)
    context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda")
    context_lens = torch.tensor(context_lens, dtype=torch.int, device=device)

    # Create the block tables.
    max_num_blocks_per_seq = (max_context_len + block_size - 1) // block_size
@@ -60,18 +61,18 @@ def main(
            for _ in range(max_num_blocks_per_seq)
        ]
        block_tables.append(block_table)
    block_tables = torch.tensor(block_tables, dtype=torch.int, device="cuda")
    block_tables = torch.tensor(block_tables, dtype=torch.int, device=device)

    # Create the KV cache.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, block_size, x)
    key_cache = torch.empty(size=key_cache_shape, dtype=dtype, device="cuda")
    key_cache.uniform_(-scale, scale)
    value_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size, block_size)
    value_cache = torch.empty(size=value_cache_shape,
                              dtype=dtype,
                              device="cuda")
    value_cache.uniform_(-scale, scale)
    key_caches, value_caches = create_kv_caches_with_random(NUM_BLOCKS,
                                                            block_size,
                                                            1,
                                                            num_kv_heads,
                                                            head_size,
                                                            kv_cache_dtype,
                                                            dtype,
                                                            device=device)
    key_cache, value_cache = key_caches[0], value_caches[0]

    # Prepare for the paged attention kernel.
    output = torch.empty_like(query)
@@ -90,7 +91,7 @@ def main(
        )
        max_logits = torch.empty_like(exp_sums)

    def run_benchmark(num_iters: int, profile: bool = False) -> float:
    def run_cuda_benchmark(num_iters: int, profile: bool = False) -> float:
        torch.cuda.synchronize()
        if profile:
            torch.cuda.cudart().cudaProfilerStart()
@@ -103,13 +104,14 @@ def main(
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    num_kv_heads,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                    kv_cache_dtype,
                )
            elif version == "v2":
                ops.paged_attention_v2(
@@ -120,13 +122,14 @@ def main(
                    query,
                    key_cache,
                    value_cache,
                    head_mapping,
                    num_kv_heads,
                    scale,
                    block_tables,
                    context_lens,
                    block_size,
                    max_context_len,
                    alibi_slopes,
                    kv_cache_dtype,
                )
            else:
                raise ValueError(f"Invalid version: {version}")
@@ -139,6 +142,7 @@ def main(

    # Warmup.
    print("Warming up...")
    run_benchmark = run_cuda_benchmark
    run_benchmark(num_iters=3, profile=False)

    # Benchmark.
@@ -172,16 +176,19 @@ if __name__ == '__main__':
                        default="half")
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--kv-cache-dtype",
        type=str,
        choices=["auto", "fp8_e5m2"],
        default="auto",
        help=
        'Data type for kv cache storage. If "auto", will use model data type.')
    parser.add_argument("--device", type=str, choices=["cuda"], default="cuda")
    args = parser.parse_args()
    print(args)

    if args.num_query_heads % args.num_kv_heads != 0:
        raise ValueError("num_query_heads must be divisible by num_kv_heads")
    dtype_to_torch_dtype = {
        "half": torch.half,
        "bfloat16": torch.bfloat16,
        "float": torch.float,
    }
    main(
        version=args.version,
        num_seqs=args.batch_size,
@@ -191,7 +198,8 @@ if __name__ == '__main__':
        head_size=args.head_size,
        block_size=args.block_size,
        use_alibi=args.use_alibi,
        dtype=dtype_to_torch_dtype[args.dtype],
        dtype=STR_DTYPE_TO_TORCH_DTYPE[args.dtype],
        seed=args.seed,
        do_profile=args.profile,
        kv_cache_dtype=args.kv_cache_dtype,
    )

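The lines removed above spelled out the paged KV-cache layout that create_kv_caches_with_random now builds internally. A small standalone sketch of that shape arithmetic (the helper name paged_kv_cache_shapes is illustrative, not part of vLLM):

import torch

def paged_kv_cache_shapes(num_blocks: int, num_kv_heads: int, head_size: int,
                          block_size: int, dtype: torch.dtype):
    # Keys are tiled so that each 16-byte fetch covers x contiguous elements
    # of one head; values keep the plain [head_size, block_size] layout.
    x = 16 // torch.tensor([], dtype=dtype).element_size()
    key_shape = (num_blocks, num_kv_heads, head_size // x, block_size, x)
    value_shape = (num_blocks, num_kv_heads, head_size, block_size)
    return key_shape, value_shape

print(paged_kv_cache_shapes(1024, 8, 128, 16, torch.float16))
# -> ((1024, 8, 16, 16, 8), (1024, 8, 128, 16))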
@@ -6,7 +6,7 @@ TOKENS=$2

docker run --gpus all --shm-size 1g -p $PORT:80 \
           -v $PWD/data:/data \
           ghcr.io/huggingface/text-generation-inference:0.8 \
           ghcr.io/huggingface/text-generation-inference:1.4.0 \
           --model-id $MODEL \
           --sharded false  \
           --max-input-length 1024 \

@@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {
@@ -18,8 +20,8 @@ __global__ void silu_and_mul_kernel(
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * 2 * d + idx]);
    const scalar_t y = __ldg(&input[token_idx * 2 * d + d + idx]);
    const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]);
    const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]);
    out[token_idx * d + idx] = silu(x) * y;
  }
}
@@ -35,6 +37,7 @@ void silu_and_mul(

  dim3 grid(num_tokens);
  dim3 block(std::min(d, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
@@ -57,7 +60,7 @@ __global__ void activation_kernel(
  const int d) {
  const int64_t token_idx = blockIdx.x;
  for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) {
    const scalar_t x = __ldg(&input[token_idx * d + idx]);
    const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]);
    out[token_idx * d + idx] = ACT_FN(x);
  }
}
@@ -70,6 +73,7 @@ __global__ void activation_kernel(
  int64_t num_tokens = input.numel() / d;                                                 \
  dim3 grid(num_tokens);                                                                  \
  dim3 block(std::min(d, 1024));                                                          \
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));                       \
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();                           \
  VLLM_DISPATCH_FLOATING_TYPES(                                                           \
    input.scalar_type(),                                                                  \

@@ -4,3 +4,4 @@
#include "dtype_float16.cuh"
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8_e5m2.cuh"

@@ -15,15 +15,27 @@
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#endif

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
#ifdef ENABLE_FP8_E5M2
#include "../quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>

#ifndef USE_ROCM
#define WARP_SIZE 32
#else
#define WARP_SIZE warpSize
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
@@ -40,7 +52,7 @@ inline __device__ float block_sum(float* red_smem, float sum) {
  // Compute the sum per warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Warp leaders store the data to shared memory.
@@ -59,29 +71,31 @@ inline __device__ float block_sum(float* red_smem, float sum) {
  // Parallel reduction inside the warp.
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    sum += VLLM_SHFL_XOR_SYNC(sum, mask);
  }

  // Broadcast to other threads.
  return __shfl_sync(uint32_t(-1), sum, 0);
  return VLLM_SHFL_SYNC(sum, 0);
}

// TODO(woosuk): Merge the last two dimensions of the grid.
// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE = 0> // Zero means no partitioning.
__device__ void paged_attention_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
@@ -124,7 +138,8 @@ __device__ void paged_attention_kernel(

  const int head_idx = blockIdx.x;
  const int num_heads = gridDim.x;
  const int kv_head_idx = head_mapping[head_idx];
  const int num_queries_per_kv = num_heads / num_kv_heads;
  const int kv_head_idx = head_idx / num_queries_per_kv;
  const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];

  // A vector type to store a part of a key or a query.
@@ -135,6 +150,9 @@ __device__ void paged_attention_kernel(
  constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1);
  using K_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
  using Q_vec = typename Vec<scalar_t, VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using Quant_vec = typename Vec<cache_t, VEC_SIZE>::Type;
#endif

  constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE;
  constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE;
@@ -166,7 +184,7 @@ __device__ void paged_attention_kernel(

  // x == THREAD_GROUP_SIZE * VEC_SIZE
  // Each thread group fetches x elements from the key at a time.
  constexpr int x = 16 / sizeof(scalar_t);
  constexpr int x = 16 / sizeof(cache_t);
  float qk_max = -FLT_MAX;

  // Iterate over the key blocks.
@@ -192,14 +210,24 @@ __device__ void paged_attention_kernel(

#pragma unroll
      for (int j = 0; j < NUM_VECS_PER_THREAD; j++) {
        const scalar_t* k_ptr = k_cache + physical_block_number * kv_block_stride
        const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride
                                       + kv_head_idx * kv_head_stride
                                       + physical_block_offset * x;
        const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE;
        const int offset1 = (vec_idx * VEC_SIZE) / x;
        const int offset2 = (vec_idx * VEC_SIZE) % x;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
          // Vector conversion from Quant_vec to K_vec.
          k_vecs[j] = fp8_e5m2_unscaled::vec_conversion<K_vec, Quant_vec>(k_vec_quant);
#else
          assert(false);
#endif
        } else {
          k_vecs[j] = *reinterpret_cast<const K_vec*>(k_ptr + offset1 * BLOCK_SIZE * x + offset2);
        }
      }

      // Compute dot product.
      // This includes a reduction across the threads in the same thread group.
@@ -223,7 +251,7 @@ __device__ void paged_attention_kernel(
  // The 0-th thread of each thread group already has its max qk value.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = qk_max;
@@ -235,10 +263,10 @@ __device__ void paged_attention_kernel(
  qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask));
  }
  // Broadcast the max qk value to all threads.
  qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
  qk_max = VLLM_SHFL_SYNC(qk_max, 0);

  // Get the sum of the exp values.
  float exp_sum = 0.f;
@@ -272,6 +300,9 @@ __device__ void paged_attention_kernel(
  constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE);
  using V_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
  using L_vec = typename Vec<scalar_t, V_VEC_SIZE>::Type;
#ifdef ENABLE_FP8_E5M2
  using V_quant_vec = typename Vec<cache_t, V_VEC_SIZE>::Type;
#endif
  using Float_L_vec = typename FloatVec<L_vec>::Type;

  constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE;
@@ -297,14 +328,25 @@ __device__ void paged_attention_kernel(
    L_vec logits_vec;
    from_float(logits_vec, *reinterpret_cast<Float_L_vec*>(logits + token_idx - start_token_idx));

    const scalar_t* v_ptr = v_cache + physical_block_number * kv_block_stride
    const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride
                                   + kv_head_idx * kv_head_stride;
#pragma unroll
    for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) {
      const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER;
      if (row_idx < HEAD_SIZE) {
        const int offset = row_idx * BLOCK_SIZE + physical_block_offset;
        V_vec v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        V_vec v_vec;
        if constexpr (IS_FP8_E5M2_KV_CACHE) {
#ifdef ENABLE_FP8_E5M2
          V_quant_vec v_quant_vec = *reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
          // Vector conversion from V_quant_vec to V_vec.
          v_vec = fp8_e5m2_unscaled::vec_conversion<V_vec, V_quant_vec>(v_quant_vec);
#else
          assert(false);
#endif
        } else {
          v_vec = *reinterpret_cast<const V_vec*>(v_ptr + offset);
        }
        if (block_idx == num_context_blocks - 1) {
          // NOTE(woosuk): When v_vec contains the tokens that are out of the context,
          // we should explicitly zero out the values since they may contain NaNs.
@@ -326,7 +368,7 @@ __device__ void paged_attention_kernel(
    float acc = accs[i];
#pragma unroll
    for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) {
      acc += __shfl_xor_sync(uint32_t(-1), acc, mask);
      acc += VLLM_SHFL_XOR_SYNC(acc, mask);
    }
    accs[i] = acc;
  }
@@ -385,15 +427,17 @@ __device__ void paged_attention_kernel(
// Grid: (num_heads, num_seqs, 1).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS>
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE>
__global__ void paged_attention_v1_kernel(
  scalar_t* __restrict__ out,             // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
@@ -402,27 +446,29 @@ __global__ void paged_attention_v1_kernel(
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>(
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
    /* exp_sums */ nullptr, /* max_logits */ nullptr,
    out, q, k_cache, v_cache, head_mapping, scale, block_tables, context_lens,
    out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
    max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
template<
  typename scalar_t,
  typename cache_t,
  int HEAD_SIZE,
  int BLOCK_SIZE,
  int NUM_THREADS,
  bool IS_FP8_E5M2_KV_CACHE,
  int PARTITION_SIZE>
__global__ void paged_attention_v2_kernel(
  float* __restrict__ exp_sums,           // [num_seqs, num_heads, max_num_partitions]
  float* __restrict__ max_logits,         // [num_seqs, num_heads, max_num_partitions]
  scalar_t* __restrict__ tmp_out,         // [num_seqs, num_heads, max_num_partitions, head_size]
  const scalar_t* __restrict__ q,         // [num_seqs, num_heads, head_size]
  const scalar_t* __restrict__ k_cache,   // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const scalar_t* __restrict__ v_cache,   // [num_blocks, num_kv_heads, head_size, block_size]
  const int* __restrict__ head_mapping,   // [num_heads]
  const cache_t* __restrict__ k_cache,    // [num_blocks, num_kv_heads, head_size/x, block_size, x]
  const cache_t* __restrict__ v_cache,    // [num_blocks, num_kv_heads, head_size, block_size]
  const int num_kv_heads,                 // [num_heads]
  const float scale,
  const int* __restrict__ block_tables,   // [num_seqs, max_num_blocks_per_seq]
  const int* __restrict__ context_lens,   // [num_seqs]
@@ -431,8 +477,8 @@ __global__ void paged_attention_v2_kernel(
  const int q_stride,
  const int kv_block_stride,
  const int kv_head_stride) {
  paged_attention_kernel<scalar_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, head_mapping, scale,
  paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
    exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
    block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
    q_stride, kv_block_stride, kv_head_stride);
}
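The kernel changes above drop the head_mapping table and derive the KV head as head_idx / num_queries_per_kv. A quick standalone check (outside the diff) that the two schemes agree for grouped-query attention:

import torch

num_query_heads, num_kv_heads = 32, 8
num_queries_per_kv = num_query_heads // num_kv_heads

# Old scheme: an explicit head_mapping tensor passed to the kernel.
head_mapping = torch.repeat_interleave(
    torch.arange(num_kv_heads, dtype=torch.int32), num_queries_per_kv)

# New scheme: the kernel derives the KV head from the query head index.
derived = torch.arange(num_query_heads, dtype=torch.int32) // num_queries_per_kv

assert torch.equal(head_mapping, derived)  # identical mappings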
@@ -492,7 +538,7 @@ __global__ void paged_attention_v2_reduce_kernel(
  // Reduce within the warp.
#pragma unroll
  for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  if (lane == 0) {
    red_smem[warp_idx] = max_logit;
@@ -502,10 +548,10 @@ __global__ void paged_attention_v2_reduce_kernel(
  max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX;
#pragma unroll
  for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) {
    max_logit = fmaxf(max_logit, __shfl_xor_sync(uint32_t(-1), max_logit, mask));
    max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask));
  }
  // Broadcast the max value to all threads.
  max_logit = __shfl_sync(uint32_t(-1), max_logit, 0);
  max_logit = VLLM_SHFL_SYNC(max_logit, 0);

  // Load rescaled exp sums to shared memory.
  float* shared_exp_sums = reinterpret_cast<float*>(shared_mem + sizeof(float) * num_partitions);
@@ -539,16 +585,16 @@ __global__ void paged_attention_v2_reduce_kernel(
} // namespace vllm

#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE)                                                  \
  cudaFuncSetAttribute(                                                                       \
    vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>,                   \
    cudaFuncAttributeMaxDynamicSharedMemorySize, shared_mem_size);                            \
  vllm::paged_attention_v1_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS>                      \
  <<<grid, block, shared_mem_size, stream>>>(                                                 \
  VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(                                       \
    ((void*)vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,   \
      IS_FP8_E5M2_KV_CACHE>), shared_mem_size);                                               \
  vllm::paged_attention_v1_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
  IS_FP8_E5M2_KV_CACHE><<<grid, block, shared_mem_size, stream>>>(                            \
    out_ptr,                                                                                  \
    query_ptr,                                                                                \
    key_cache_ptr,                                                                            \
    value_cache_ptr,                                                                          \
    head_mapping_ptr,                                                                         \
    num_kv_heads,                                                                             \
    scale,                                                                                    \
    block_tables_ptr,                                                                         \
    context_lens_ptr,                                                                         \
@@ -561,14 +607,16 @@ __global__ void paged_attention_v2_reduce_kernel(
// TODO(woosuk): Tune NUM_THREADS.
template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128>
void paged_attention_v1_launcher(
  torch::Tensor& out,
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
@@ -592,9 +640,8 @@ void paged_attention_v1_launcher(

  T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -608,6 +655,7 @@ void paged_attention_v1_launcher(

  dim3 grid(num_heads, num_seqs, 1);
  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the
@@ -637,13 +685,13 @@ void paged_attention_v1_launcher(
  }
}

#define CALL_V1_LAUNCHER(T, BLOCK_SIZE)                             \
  paged_attention_v1_launcher<T, BLOCK_SIZE>(                       \
#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)       \
  paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>( \
    out,                                                                     \
    query,                                                                   \
    key_cache,                                                               \
    value_cache,                                                             \
    head_mapping,                                                   \
    num_kv_heads,                                                            \
    scale,                                                                   \
    block_tables,                                                            \
    context_lens,                                                            \
@@ -652,16 +700,16 @@ void paged_attention_v1_launcher(

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T)                              \
#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE) \
  switch (block_size) {                                               \
    case 8:                                                           \
      CALL_V1_LAUNCHER(T, 8);                                       \
      CALL_V1_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);          \
      break;                                                          \
    case 16:                                                          \
      CALL_V1_LAUNCHER(T, 16);                                      \
      CALL_V1_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);         \
      break;                                                          \
    case 32:                                                          \
      CALL_V1_LAUNCHER(T, 32);                                      \
      CALL_V1_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);         \
      break;                                                          \
    default:                                                          \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);     \
@@ -673,26 +721,42 @@ void paged_attention_v1(
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,    // [num_heads]
  int num_kv_heads,               // [num_heads]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
 | 
			
		||||
  const c10::optional<torch::Tensor>& alibi_slopes,
 | 
			
		||||
  const std::string& kv_cache_dtype) {
 | 
			
		||||
  if (kv_cache_dtype == "auto") {
 | 
			
		||||
    if (query.dtype() == at::ScalarType::Float) {
 | 
			
		||||
    CALL_V1_LAUNCHER_BLOCK_SIZE(float);
 | 
			
		||||
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, float, false);
 | 
			
		||||
    } else if (query.dtype() == at::ScalarType::Half) {
 | 
			
		||||
    CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t);
 | 
			
		||||
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
 | 
			
		||||
    } else if (query.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
    CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
 | 
			
		||||
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
 | 
			
		||||
    } else {
 | 
			
		||||
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
 | 
			
		||||
    }
 | 
			
		||||
  } else if (kv_cache_dtype == "fp8_e5m2") {
 | 
			
		||||
    if (query.dtype() == at::ScalarType::Float) {
 | 
			
		||||
      CALL_V1_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
 | 
			
		||||
    } else if (query.dtype() == at::ScalarType::Half) {
 | 
			
		||||
      CALL_V1_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
 | 
			
		||||
    } else if (query.dtype() == at::ScalarType::BFloat16) {
 | 
			
		||||
      CALL_V1_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
 | 
			
		||||
    } else {
 | 
			
		||||
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
 | 
			
		||||
    }
 | 
			
		||||
  } else {
 | 
			
		||||
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
 | 
			
		||||
  }
 | 
			
		||||
}
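
A minimal standalone sketch (hypothetical names, not taken from the vLLM sources) of the dispatch pattern used above: the runtime pair (kv_cache_dtype string, query dtype) is mapped exactly once to the compile-time pair (CACHE_T, IS_FP8_E5M2_KV_CACHE), so the kernel itself never branches on the cache format.

// dispatch_sketch.cpp -- illustration only; launch_paged_attention is a stand-in.
#include <cstdint>
#include <stdexcept>
#include <string>

template <typename T, typename CACHE_T, bool IS_FP8_E5M2_KV_CACHE>
void launch_paged_attention(/* tensor arguments elided */) {
  // A real implementation would instantiate the templated launcher here.
}

template <typename T>
void dispatch_cache_dtype(const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    // The cache stores the same element type as the query.
    launch_paged_attention<T, T, false>();
  } else if (kv_cache_dtype == "fp8_e5m2") {
    // The cache stores one byte per element; it is dequantized inside the kernel.
    launch_paged_attention<T, uint8_t, true>();
  } else {
    throw std::invalid_argument("Unsupported kv cache dtype: " + kv_cache_dtype);
  }
}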

#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE)                                                  \
  vllm::paged_attention_v2_kernel<T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, PARTITION_SIZE>      \
  vllm::paged_attention_v2_kernel<T, CACHE_T, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,             \
  IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>                                                       \
  <<<grid, block, shared_mem_size, stream>>>(                                                 \
    exp_sums_ptr,                                                                             \
    max_logits_ptr,                                                                           \

@@ -700,7 +764,7 @@ void paged_attention_v1(
    query_ptr,                                                                                \
    key_cache_ptr,                                                                            \
    value_cache_ptr,                                                                          \
    head_mapping_ptr,                                                                         \
    num_kv_heads,                                                                             \
    scale,                                                                                    \
    block_tables_ptr,                                                                         \
    context_lens_ptr,                                                                         \

@@ -720,7 +784,9 @@ void paged_attention_v1(

template<
  typename T,
  typename CACHE_T,
  int BLOCK_SIZE,
  bool IS_FP8_E5M2_KV_CACHE,
  int NUM_THREADS = 128,
  int PARTITION_SIZE = 512>
void paged_attention_v2_launcher(

@@ -731,7 +797,7 @@ void paged_attention_v2_launcher(
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,

@@ -758,9 +824,8 @@ void paged_attention_v2_launcher(
  float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
  T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
  T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
  T* key_cache_ptr = reinterpret_cast<T*>(key_cache.data_ptr());
  T* value_cache_ptr = reinterpret_cast<T*>(value_cache.data_ptr());
  int* head_mapping_ptr = reinterpret_cast<int*>(head_mapping.data_ptr());
  CACHE_T* key_cache_ptr = reinterpret_cast<CACHE_T*>(key_cache.data_ptr());
  CACHE_T* value_cache_ptr = reinterpret_cast<CACHE_T*>(value_cache.data_ptr());
  int* block_tables_ptr = block_tables.data_ptr<int>();
  int* context_lens_ptr = context_lens.data_ptr<int>();

@@ -777,6 +842,7 @@ void paged_attention_v2_launcher(
  int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float);

  dim3 block(NUM_THREADS);
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  switch (head_size) {
    // NOTE(woosuk): To reduce the compilation time, we only compile for the

@@ -806,8 +872,8 @@ void paged_attention_v2_launcher(
  }
}

#define CALL_V2_LAUNCHER(T, BLOCK_SIZE)                                          \
  paged_attention_v2_launcher<T, BLOCK_SIZE>(                                    \
#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE)           \
  paged_attention_v2_launcher<T, CACHE_T, BLOCK_SIZE, IS_FP8_E5M2_KV_CACHE>(     \
    out,                                                                         \
    exp_sums,                                                                    \
    max_logits,                                                                  \

@@ -815,7 +881,7 @@ void paged_attention_v2_launcher(
    query,                                                                       \
    key_cache,                                                                   \
    value_cache,                                                                 \
    head_mapping,                                                                \
    num_kv_heads,                                                                \
    scale,                                                                       \
    block_tables,                                                                \
    context_lens,                                                                \

@@ -824,16 +890,16 @@ void paged_attention_v2_launcher(

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T)                                      \
#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, IS_FP8_E5M2_KV_CACHE)       \
  switch (block_size) {                                                     \
    case 8:                                                                 \
      CALL_V2_LAUNCHER(T, 8);                                               \
      CALL_V2_LAUNCHER(T, CACHE_T, 8, IS_FP8_E5M2_KV_CACHE);                \
      break;                                                                \
    case 16:                                                                \
      CALL_V2_LAUNCHER(T, 16);                                              \
      CALL_V2_LAUNCHER(T, CACHE_T, 16, IS_FP8_E5M2_KV_CACHE);               \
      break;                                                                \
    case 32:                                                                \
      CALL_V2_LAUNCHER(T, 32);                                              \
      CALL_V2_LAUNCHER(T, CACHE_T, 32, IS_FP8_E5M2_KV_CACHE);               \
      break;                                                                \
    default:                                                                \
      TORCH_CHECK(false, "Unsupported block size: ", block_size);           \

@@ -848,22 +914,37 @@ void paged_attention_v2(
  torch::Tensor& query,           // [num_seqs, num_heads, head_size]
  torch::Tensor& key_cache,       // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,     // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& head_mapping,    // [num_heads]
  int num_kv_heads,               // [num_heads]
  float scale,
  torch::Tensor& block_tables,    // [num_seqs, max_num_blocks_per_seq]
  torch::Tensor& context_lens,    // [num_seqs]
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes) {
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype) {
  if (kv_cache_dtype == "auto") {
    if (query.dtype() == at::ScalarType::Float) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(float);
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, float, false);
    } else if (query.dtype() == at::ScalarType::Half) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t);
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint16_t, false);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
    CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16);
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, __nv_bfloat16, false);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (query.dtype() == at::ScalarType::Float) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(float, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::Half) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(uint16_t, uint8_t, true);
    } else if (query.dtype() == at::ScalarType::BFloat16) {
      CALL_V2_LAUNCHER_BLOCK_SIZE(__nv_bfloat16, uint8_t, true);
    } else {
      TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}

#undef WARP_SIZE

@@ -17,6 +17,7 @@
 */
#pragma once

#include "../cuda_compat.h"
#include "attention_dtypes.h"

#include <float.h>

@@ -39,7 +40,7 @@ inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) {
  float qk = sum(qk_vec);
#pragma unroll
  for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) {
    qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
    qk += VLLM_SHFL_XOR_SYNC(qk, mask);
  }
  return qk;
}
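
A toy CUDA kernel (not part of the diff) showing what the xor-shuffle loop above computes: after log2(GROUP) butterfly steps, every lane in a group of GROUP consecutive lanes holds the group's sum, which is exactly how the partial dot products of a thread group are combined.

// shuffle_xor_demo.cu -- illustrative only; GROUP plays the role of THREAD_GROUP_SIZE.
#include <cstdio>
#include <cuda_runtime.h>

template <int GROUP>
__global__ void group_sum(const float* in, float* out) {
  float v = in[threadIdx.x];
#pragma unroll
  for (int mask = GROUP / 2; mask >= 1; mask /= 2) {
    v += __shfl_xor_sync(0xffffffffu, v, mask);   // same primitive VLLM_SHFL_XOR_SYNC wraps
  }
  out[threadIdx.x] = v;  // identical within each GROUP-sized lane group
}

int main() {
  float h_in[32], h_out[32];
  for (int i = 0; i < 32; ++i) h_in[i] = 1.0f;
  float *d_in, *d_out;
  cudaMalloc(&d_in, sizeof(h_in));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
  group_sum<8><<<1, 32>>>(d_in, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("%f\n", h_out[0]);  // prints 8.0 for GROUP = 8
  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}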

@@ -21,8 +21,17 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  #include <cuda_fp16.h>
#else
  #include <hip/hip_bf16.h>
  #include <hip/hip_fp16.h>

  typedef __hip_bfloat162 __nv_bfloat162;
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

#include <stdint.h>

namespace vllm {

@@ -98,7 +107,11 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
  assert(false);
#else
  #ifndef USE_ROCM
    return a + b;
  #else
    return __hadd(a, b);
  #endif
#endif
}

@@ -21,6 +21,10 @@
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

#ifdef USE_ROCM
  #include <hip/hip_fp16.h>
#endif

#include <stdint.h>

namespace vllm {

@@ -63,21 +67,47 @@ struct FloatVec<uint4> {

// Utility functions for type conversions.
inline __device__ uint32_t h0_h0(uint16_t a) {
#ifndef USE_ROCM
  uint32_t b;
  asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a));
  return b;
#else
  union {
   uint32_t u32;
   uint16_t u16[2];
  } tmp;
  tmp.u16[0] = a;
  tmp.u16[1] = a;
  return tmp.u32;
#endif
}

inline __device__ float half_to_float(uint16_t h) {
  float f;
#ifndef USE_ROCM
  asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h));
#else
  asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h));
#endif
  return f;
}

inline __device__ float2 half2_to_float2(uint32_t v) {
#ifndef USE_ROCM
  uint16_t lo, hi;
  asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v));
  return make_float2(half_to_float(lo), half_to_float(hi));
#else
  union {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
  tmp.u32 = v;
  float2 ret;
  ret.x = half_to_float(tmp.u16[0]);
  ret.y = half_to_float(tmp.u16[1]);
  return ret;
#endif
}

inline __device__ uint16_t float_to_half(float f) {

@@ -85,7 +115,11 @@ inline __device__ uint16_t float_to_half(float f) {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;
#ifndef USE_ROCM
  asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f));
#else
  asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f));
#endif
  return tmp.u16[0];
}

@@ -94,26 +128,38 @@ inline __device__ uint32_t float2_to_half2(float2 f) {
    uint32_t u32;
    uint16_t u16[2];
  } tmp;

#ifndef USE_ROCM
  #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
    asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x));
  #else
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x));
    asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y));
  #endif
#else
  tmp.u16[0] = float_to_half(f.x);
  tmp.u16[1] = float_to_half(f.y);
#endif
  return tmp.u32;
}

// Vector addition.
inline __device__ uint16_t add(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

inline __device__ uint32_t add(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

@@ -158,14 +204,22 @@ inline __device__ Float8_ add(uint4 a, Float8_ fb) {
template<>
inline __device__ uint16_t mul(uint16_t a, uint16_t b) {
  uint16_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b));
#else
  asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

template<>
inline __device__ uint32_t mul(uint32_t a, uint32_t b) {
  uint32_t c;
#ifndef USE_ROCM
  asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b));
#else
  asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b));
#endif
  return c;
}

@@ -272,7 +326,11 @@ inline __device__ Float8_ mul(uint16_t a, uint4 b) {
// Vector fused multiply-add.
inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) {
  uint32_t d;
#ifndef USE_ROCM
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c));
#else
  asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" : "=v"(d) : "v"(a), "v"(b), "v"(c));
#endif
  return d;
}
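
For reference, a sketch (assuming only the standard cuda_fp16.h intrinsics, not the file's implementation) of what the raw-bit helpers above compute; the diff keeps inline asm so the uint16_t/uint32_t register layout stays explicit on both the PTX and the GCN path.

// fp16_intrinsic_equivalents.cu -- illustrative device helpers only.
#include <cstdint>
#include <cuda_fp16.h>

__device__ float half_to_float_ref(uint16_t h) {
  return __half2float(__ushort_as_half(h));        // cvt.f32.f16 / v_cvt_f32_f16
}

__device__ uint16_t float_to_half_ref(float f) {
  return __half_as_ushort(__float2half_rn(f));     // cvt.rn.f16.f32 / v_cvt_f16_f32
}

__device__ uint32_t h0_h0_ref(uint16_t a) {
  __half2 h2 = __half2half2(__ushort_as_half(a));  // broadcast one half into both lanes
  return *reinterpret_cast<uint32_t*>(&h2);        // mov.b32 {a, a}
}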

35  csrc/attention/dtype_fp8_e5m2.cuh  (new file)
@@ -0,0 +1,35 @@
#pragma once

#include "attention_generic.cuh"

#include <stdint.h>
#ifdef ENABLE_FP8_E5M2
#include <cuda_fp8.h>
#endif

namespace vllm {
#ifdef ENABLE_FP8_E5M2
// fp8 vector types for quantization of kv cache

template<>
struct Vec<uint8_t, 1> {
    using Type = uint8_t;
};

template<>
struct Vec<uint8_t, 2> {
    using Type = uint16_t;
};

template<>
struct Vec<uint8_t, 4> {
    using Type = uint32_t;
};

template<>
struct Vec<uint8_t, 8> {
    using Type = uint2;
};
#endif // ENABLE_FP8_E5M2

} // namespace vllm
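
A self-contained illustration of the same packing idea (local names, not the vLLM headers): N one-byte fp8 cache values travel as a single integer word, so a group of elements is moved with one wide load/store instead of N byte accesses.

// fp8_vec_packing.cu -- illustrative only.
#include <cstdint>
#include <cuda_runtime.h>   // for uint2

template <int N> struct Fp8Vec;                       // maps group size -> word type
template <> struct Fp8Vec<1> { using Type = uint8_t;  };
template <> struct Fp8Vec<2> { using Type = uint16_t; };
template <> struct Fp8Vec<4> { using Type = uint32_t; };
template <> struct Fp8Vec<8> { using Type = uint2;    };

template <int N>
__device__ void copy_fp8_group(const uint8_t* __restrict__ src,
                               uint8_t* __restrict__ dst) {
  using Word = typename Fp8Vec<N>::Type;
  *reinterpret_cast<Word*>(dst) = *reinterpret_cast<const Word*>(src);
}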

10  csrc/cache.h
@@ -1,3 +1,5 @@
#pragma once

#include <torch/extension.h>

#include <map>

@@ -18,7 +20,8 @@ void reshape_and_cache(
  torch::Tensor& value,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);
  torch::Tensor& slot_mapping,
  const std::string& kv_cache_dtype);

void gather_cached_kv(
  torch::Tensor& key,

@@ -26,3 +29,8 @@ void gather_cached_kv(
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& slot_mapping);

// Just for unittest
void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache);

@@ -1,13 +1,23 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"
#ifdef ENABLE_FP8_E5M2
#include "quantization/fp8_e5m2_kvcache/quant_utils.cuh"
#endif

#include <algorithm>
#include <cassert>
#include <map>
#include <vector>

#ifdef USE_ROCM
  #include <hip/hip_bf16.h>
  typedef __hip_bfloat16 __nv_bfloat16;
#endif

void swap_blocks(
  torch::Tensor& src,
  torch::Tensor& dst,

@@ -28,10 +38,11 @@ void swap_blocks(
    TORCH_CHECK(false, "Invalid device combination");
  }

  void *src_ptr = src.data_ptr();
  void *dst_ptr = dst.data_ptr();
  char *src_ptr = static_cast<char*>(src.data_ptr());
  char *dst_ptr = static_cast<char*>(dst.data_ptr());

  const int64_t block_size_in_bytes = src.element_size() * src[0].numel();
  const at::cuda::OptionalCUDAGuard device_guard(src_device.is_cuda() ? src_device : dst_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // NOTE(woosuk): This can be slow if the number of blocks is large.
  for (const auto& pair : block_mapping) {

@@ -126,8 +137,9 @@ void copy_blocks(
  const int numel_per_block = key_caches[0][0].numel();
  dim3 grid(num_layers, num_pairs);
  dim3 block(std::min(1024, numel_per_block));
  const at::cuda::OptionalCUDAGuard device_guard(cache_device);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] {
      vllm::copy_blocks_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key_cache_ptrs_tensor.data_ptr<int64_t>(),

@@ -139,12 +151,12 @@ void copy_blocks(

namespace vllm {

template<typename scalar_t>
template<typename scalar_t, typename cache_t, bool is_fp8_e5m2_kv_cache>
__global__ void reshape_and_cache_kernel(
  const scalar_t* __restrict__ key,           // [num_tokens, num_heads, head_size]
  const scalar_t* __restrict__ value,         // [num_tokens, num_heads, head_size]
  scalar_t* __restrict__ key_cache,           // [num_blocks, num_heads, head_size/x, block_size, x]
  scalar_t* __restrict__ value_cache,         // [num_blocks, num_heads, head_size, block_size]
  cache_t* __restrict__ key_cache,            // [num_blocks, num_heads, head_size/x, block_size, x]
  cache_t* __restrict__ value_cache,          // [num_blocks, num_heads, head_size, block_size]
  const int64_t* __restrict__ slot_mapping,   // [num_tokens]
  const int key_stride,
  const int value_stride,

@@ -181,19 +193,45 @@ __global__ void reshape_and_cache_kernel(
                                  + head_idx * head_size * block_size
                                  + head_offset * block_size
                                  + block_offset;
    key_cache[tgt_key_idx] = key[src_key_idx];
    value_cache[tgt_value_idx] = value[src_value_idx];
    scalar_t tgt_key = key[src_key_idx];
    scalar_t tgt_value = value[src_value_idx];
    if constexpr (is_fp8_e5m2_kv_cache) {
#ifdef ENABLE_FP8_E5M2
      key_cache[tgt_key_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_key);
      value_cache[tgt_value_idx] = fp8_e5m2_unscaled::vec_conversion<uint8_t, scalar_t>(tgt_value);
#else
      assert(false);
#endif
    } else {
      key_cache[tgt_key_idx] = tgt_key;
      value_cache[tgt_value_idx] = tgt_value;
    }
  }
}

} // namespace vllm

#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE)                                \
  vllm::reshape_and_cache_kernel<KV_T, CACHE_T, IS_FP8_E5M2_KV_CACHE><<<grid, block, 0, stream>>>( \
    reinterpret_cast<KV_T*>(key.data_ptr()),                                                       \
    reinterpret_cast<KV_T*>(value.data_ptr()),                                                     \
    reinterpret_cast<CACHE_T*>(key_cache.data_ptr()),                                              \
    reinterpret_cast<CACHE_T*>(value_cache.data_ptr()),                                            \
    slot_mapping.data_ptr<int64_t>(),                                                              \
    key_stride,                                                                                    \
    value_stride,                                                                                  \
    num_heads,                                                                                     \
    head_size,                                                                                     \
    block_size,                                                                                    \
    x);

void reshape_and_cache(
  torch::Tensor& key,           // [num_tokens, num_heads, head_size]
  torch::Tensor& value,         // [num_tokens, num_heads, head_size]
  torch::Tensor& key_cache,     // [num_blocks, num_heads, head_size/x, block_size, x]
  torch::Tensor& value_cache,   // [num_blocks, num_heads, head_size, block_size]
  torch::Tensor& slot_mapping)  // [num_tokens]
  torch::Tensor& slot_mapping,  // [num_tokens]
  const std::string& kv_cache_dtype)
{
  int num_tokens = key.size(0);
  int num_heads = key.size(1);

@@ -206,24 +244,27 @@ void reshape_and_cache(

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    key.scalar_type(),
    "reshape_and_cache_kernel",
    [&] {
      vllm::reshape_and_cache_kernel<scalar_t><<<grid, block, 0, stream>>>(
        key.data_ptr<scalar_t>(),
        value.data_ptr<scalar_t>(),
        key_cache.data_ptr<scalar_t>(),
        value_cache.data_ptr<scalar_t>(),
        slot_mapping.data_ptr<int64_t>(),
        key_stride,
        value_stride,
        num_heads,
        head_size,
        block_size,
        x);
    });
  if (kv_cache_dtype == "auto") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, float, false);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint16_t, false);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, __nv_bfloat16, false);
    }
  } else if (kv_cache_dtype == "fp8_e5m2") {
    if (key.dtype() == at::ScalarType::Float) {
      CALL_RESHAPE_AND_CACHE(float, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::Half) {
      CALL_RESHAPE_AND_CACHE(uint16_t, uint8_t, true);
    } else if (key.dtype() == at::ScalarType::BFloat16) {
      CALL_RESHAPE_AND_CACHE(__nv_bfloat16, uint8_t, true);
    }
  } else {
    TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
  }
}
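
A guess at what the scalar case of fp8_e5m2_unscaled::vec_conversion boils down to (the real helper lives in quantization/fp8_e5m2_kvcache/quant_utils.cuh and also handles packed vector types, so treat this as a sketch assuming the CUDA 11.8+ <cuda_fp8.h> conversions): an unscaled cast to e5m2 bits on the way into the cache, and a widening cast on the way out.

// fp8_e5m2_scalar_sketch.cu -- illustrative only; not the vLLM conversion helper.
#include <cstdint>
#include <cuda_fp8.h>

__device__ uint8_t float_to_fp8_e5m2(float x) {
  __nv_fp8_e5m2 q(x);                          // round to 1 sign / 5 exponent / 2 mantissa bits
  return *reinterpret_cast<uint8_t*>(&q);      // raw byte as stored in the KV cache
}

__device__ float fp8_e5m2_to_float(uint8_t byte) {
  __nv_fp8_e5m2 q = *reinterpret_cast<__nv_fp8_e5m2*>(&byte);
  return float(q);                             // widen back for attention math
}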

namespace vllm {

@@ -267,8 +308,8 @@ __global__ void gather_cached_kv_kernel(
                                + head_offset * block_size
                                + block_offset;

      key[tgt_key_idx] = __ldg(&key_cache[src_key_idx]);
      value[tgt_value_idx] = __ldg(&value_cache[src_value_idx]);
      key[tgt_key_idx] = VLLM_LDG(&key_cache[src_key_idx]);
      value[tgt_value_idx] = VLLM_LDG(&value_cache[src_value_idx]);
    }
}

@@ -333,8 +374,8 @@ __global__ void gather_cached_kv_kernel_optimized(
            src_key_indices[j] = src_key_idx;
            src_value_indices[j] = src_value_idx;

            keys_to_store[j] = __ldg(&key_cache[src_key_idx]);
            values_to_store[j] = __ldg(&value_cache[src_value_idx]);
            keys_to_store[j] = VLLM_LDG(&key_cache[src_key_idx]);
            values_to_store[j] = VLLM_LDG(&value_cache[src_value_idx]);
        }

        #pragma unroll

@@ -366,8 +407,9 @@ void gather_cached_kv(

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
    key.scalar_type(),
    "gather_cached_kv_kernel_optimized",
    [&] {

@@ -385,3 +427,55 @@ void gather_cached_kv(
        x);
    });
}

namespace vllm {

template<typename Tout, typename Tin>
__global__ void convert_fp8_e5m2_kernel(
  const Tin* __restrict__ src_cache,
  Tout* __restrict__ dst_cache,
  const int64_t block_stride) {
  const int64_t block_idx = blockIdx.x;
  for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
    int64_t idx = block_idx * block_stride + i;
#ifdef ENABLE_FP8_E5M2
    dst_cache[idx] = fp8_e5m2_unscaled::vec_conversion<Tout, Tin>(src_cache[idx]);
#else
    assert(false);
#endif
  }
}

} // namespace vllm

#define CALL_CONVERT_FP8_E5M2(Tout, Tin)                                 \
  vllm::convert_fp8_e5m2_kernel<Tout, Tin><<<grid, block, 0, stream>>>(  \
    reinterpret_cast<Tin*>(src_cache.data_ptr()),                        \
    reinterpret_cast<Tout*>(dst_cache.data_ptr()),                       \
    block_stride);

void convert_fp8_e5m2(
  torch::Tensor& src_cache,
  torch::Tensor& dst_cache)
{
  int64_t num_blocks = src_cache.size(0);
  int64_t block_stride = src_cache.stride(0);

  dim3 grid(num_blocks);
  dim3 block(std::min(block_stride, int64_t(512)));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if (src_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(uint8_t, float);
  } else if (src_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint8_t, uint16_t);
  } else if (src_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(uint8_t, __nv_bfloat16);
  } else if (dst_cache.dtype() == at::ScalarType::Float) {
    CALL_CONVERT_FP8_E5M2(float, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::Half) {
    CALL_CONVERT_FP8_E5M2(uint16_t, uint8_t);
  } else if (dst_cache.dtype() == at::ScalarType::BFloat16) {
    CALL_CONVERT_FP8_E5M2(__nv_bfloat16, uint8_t);
  }
}
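
A hypothetical unit-test style usage of the conversion entry point above (tensor names and shapes are made up): the direction is picked from the dtypes, so a half source with a uint8 destination quantizes, and a uint8 source with a half destination dequantizes.

// convert_fp8_e5m2_roundtrip.cpp -- illustrative only.
#include <torch/torch.h>

// Declared in csrc/cache.h.
void convert_fp8_e5m2(torch::Tensor& src_cache, torch::Tensor& dst_cache);

void roundtrip_example() {
  auto opts_half = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
  auto opts_byte = torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);

  auto src = torch::randn({16, 8, 16, 16}, opts_half);   // hypothetical cache block layout
  auto quantized = torch::empty(src.sizes(), opts_byte);
  auto restored = torch::empty(src.sizes(), opts_half);

  convert_fp8_e5m2(src, quantized);        // half -> fp8 e5m2 bytes
  convert_fp8_e5m2(quantized, restored);   // fp8 e5m2 bytes -> half
  // `restored` should match `src` to within e5m2 rounding error.
}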

28  csrc/cuda_compat.h  (new file)
@@ -0,0 +1,28 @@
#pragma once

#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
  #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif

#ifndef USE_ROCM
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane)
#else
  #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif

#ifndef USE_ROCM
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
#else
  #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
    hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif

@@ -1,5 +1,10 @@
#pragma once

#include <torch/extension.h>

int get_device_attribute(
    int attribute,
    int device_id);

int get_max_shared_memory_per_block_device_attribute(
    int device_id);

@@ -1,3 +1,7 @@
#ifdef USE_ROCM
  #include <hip/hip_runtime.h>
  #include <hip/hip_runtime_api.h>
#endif
int get_device_attribute(
    int attribute,
    int device_id)

@@ -12,3 +16,20 @@ int get_device_attribute(
    cudaDeviceGetAttribute(&value, static_cast<cudaDeviceAttr>(attribute), device);
    return value;
}


int get_max_shared_memory_per_block_device_attribute(
    int device_id)
{
int attribute;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74

#ifdef USE_ROCM
    attribute = hipDeviceAttributeMaxSharedMemoryPerBlock;
#else
    attribute = cudaDevAttrMaxSharedMemoryPerBlockOptin;
#endif

    return get_device_attribute(attribute, device_id);
}
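
A sketch of how such a limit is typically consumed (illustrative kernel and sizes, not code from the diff): query the opt-in per-block maximum, then raise the kernel's dynamic shared memory cap before launching with a large request, which mirrors what the VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize macro wraps.

// shared_mem_limit_demo.cu -- illustrative only.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void my_kernel() { extern __shared__ char smem[]; (void)smem; }

int main() {
  int device = 0, max_smem = 0;
  cudaGetDevice(&device);
  cudaDeviceGetAttribute(&max_smem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device);

  int shared_mem_size = 96 * 1024;   // hypothetical requirement
  if (shared_mem_size > max_smem) {
    printf("requested %d bytes > device limit %d bytes\n", shared_mem_size, max_smem);
    return 1;
  }
  cudaFuncSetAttribute(my_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
                       shared_mem_size);
  my_kernel<<<1, 128, shared_mem_size>>>();
  return (int)cudaDeviceSynchronize();
}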

148  csrc/custom_all_reduce.cu  (new file)
@@ -0,0 +1,148 @@
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>

#include "custom_all_reduce.cuh"

// fake pointer type
using fptr_t = uint64_t;
static_assert(sizeof(void *) == sizeof(fptr_t));

fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                      const std::vector<std::string> &handles,
                      const std::vector<int64_t> &offsets, int rank,
                      bool full_nvlink) {
  int world_size = offsets.size();
  if (world_size > 8)
    throw std::invalid_argument("world size > 8 is not supported");
  if (world_size % 2 != 0)
    throw std::invalid_argument("Odd num gpus is not supported for now");
  if (world_size != handles.size())
    throw std::invalid_argument(
        "handles length should equal to offsets length");
  if (rank < 0 || rank >= world_size)
    throw std::invalid_argument("invalid rank passed in");

  cudaIpcMemHandle_t ipc_handles[8];
  for (int i = 0; i < world_size; i++) {
    std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t));
  }
  return (fptr_t) new vllm::CustomAllreduce(
      reinterpret_cast<vllm::Metadata *>(meta.data_ptr()), rank_data.data_ptr(),
      rank_data.numel(), ipc_handles, offsets, rank, full_nvlink);
}

/**
 * Make sure tensor t's data lies completely within ((char)t.data_ptr()) +
 * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous()
 * because it allows transpose of contiguous slice (i.e. slicing the first
 * dimension). Currently, we require this because stride information is not
 * passed into the kernels and we treat input tensors as flat.
 *
 * Examples
 * A = torch.zeros(3, 3, 3)
 * 1. A: OK
 * 2. A[1:]: OK
 * 3. A.permute(2, 0, 1): OK
 * 4. A[1:].permute(2, 0, 1): OK
 * 5. A[None].expand(2, -1, -1, -1): Not OK
 * 6. A[:, 1:, 1:]: Not OK
 */
bool _is_weak_contiguous(torch::Tensor &t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}
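
A small host-side check (libtorch, illustrative only) of the cases listed in the comment above: A[1:] passes because its storage tail exactly covers numel()*element_size() bytes, while A[:, 1:, 1:] fails because its elements are interleaved with bytes it does not own.

// weak_contiguous_check.cpp -- illustrative only; mirrors the predicate above.
#include <torch/torch.h>
#include <iostream>

static bool is_weak_contiguous(const torch::Tensor& t) {
  return t.is_contiguous() ||
         (t.storage().nbytes() - t.storage_offset() * t.element_size() ==
          t.numel() * t.element_size());
}

int main() {
  auto A = torch::zeros({3, 3, 3});
  std::cout << is_weak_contiguous(A.slice(0, 1)) << "\n";                    // 1 (case 2)
  std::cout << is_weak_contiguous(A.slice(0, 1).permute({2, 0, 1})) << "\n"; // 1 (case 4)
  std::cout << is_weak_contiguous(A.slice(1, 1).slice(2, 1)) << "\n";        // 0 (case 6)
}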

bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink) {
  auto inp_size = inp.numel() * inp.element_size();
  // custom allreduce requires input byte size to be multiples of 16
  if (inp_size % 16 != 0) return false;
  if (!_is_weak_contiguous(inp)) return false;
  if (world_size == 2 || full_nvlink) return inp_size <= max_size;
  // 4 PCIE GPUs use 2 stage allreduce, and is only faster than NCCL when size
  // <= 512k
  return world_size <= 4 && inp_size <= 512 * 1024;
}
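
A worked example of the gating above with hypothetical numbers: a half-precision hidden state of shape [batch=8, hidden=4096] occupies 8*4096*2 = 65536 bytes, a multiple of 16, so with world_size=2 or full NVLink it is eligible whenever it fits in max_size; on a 4-GPU PCIe setup it would only be rejected if it exceeded the 512 KiB two-stage threshold.

// should_custom_ar_example.cpp -- restates the decision logic with made-up inputs.
#include <cstdint>
#include <cstdio>

int main() {
  int64_t inp_size = int64_t(8) * 4096 * 2;   // bytes for an fp16 [8, 4096] tensor
  int max_size = 8 * 1024 * 1024;             // hypothetical registered-buffer cap
  bool full_nvlink = true;
  int world_size = 4;
  bool eligible =
      (inp_size % 16 == 0) &&
      ((world_size == 2 || full_nvlink) ? inp_size <= max_size
                                        : (world_size <= 4 && inp_size <= 512 * 1024));
  printf("custom all-reduce eligible: %d\n", eligible);   // prints 1
  return 0;
}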

void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
                 cudaStream_t stream) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  TORCH_CHECK(_is_weak_contiguous(out));
  switch (out.scalar_type()) {
    case at::ScalarType::Float: {
      fa->allreduce<float>(stream, reinterpret_cast<float *>(inp.data_ptr()),
                           reinterpret_cast<float *>(out.data_ptr()),
                           out.numel());
      break;
    }
    case at::ScalarType::Half: {
      fa->allreduce<half>(stream, reinterpret_cast<half *>(inp.data_ptr()),
                          reinterpret_cast<half *>(out.data_ptr()),
                          out.numel());
      break;
    }
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
    case at::ScalarType::BFloat16: {
      fa->allreduce<nv_bfloat16>(
          stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
          reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
      break;
    }
#endif
    default:
      throw std::runtime_error(
          "custom allreduce only supports float32, float16 and bfloat16");
  }
}

void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  _all_reduce(_fa, inp, out, stream);
}

void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out) {
  const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
  auto stream = c10::cuda::getCurrentCUDAStream().stream();

  auto input_size = inp.numel() * inp.element_size();
  TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
  TORCH_CHECK_EQ(inp.numel(), out.numel());
  TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(),
              "registered buffer is too small to contain the input");
  AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(),
                                input_size, cudaMemcpyDeviceToDevice, stream));
  _all_reduce(_fa, reg_buffer, out, stream);
}

void dispose(fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  delete fa;
}

int meta_size() { return sizeof(vllm::Metadata); }

void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_buffer(handles, offsets, t.data_ptr());
}

std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
    fptr_t _fa) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  return fa->get_graph_buffer_ipc_meta();
}

void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets) {
  auto fa = reinterpret_cast<vllm::CustomAllreduce *>(_fa);
  fa->register_graph_buffers(handles, offsets);
}

562  csrc/custom_all_reduce.cuh  (new file)
@@ -0,0 +1,562 @@
#pragma once

#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

#include <iostream>
#include <limits>
#include <map>
#include <unordered_map>
#include <vector>

#define CUDACHECK(cmd)                                              \
  do {                                                              \
    cudaError_t e = cmd;                                            \
    if (e != cudaSuccess) {                                         \
      printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \
             cudaGetErrorString(e));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

namespace vllm {

struct Signal {
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } start;
  alignas(64) union {
    uint64_t flag;
    unsigned char data[8];
  } end;
};

struct Metadata {
  alignas(128) Signal sg;
  alignas(128) int counter;
};
static_assert(offsetof(Metadata, counter) == 128);
static_assert(sizeof(Metadata) == 256);

struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };

struct RankSignals {
  volatile Signal *signals[8];
};

// like std::array, but aligned
template <typename T, int sz>
struct __align__(alignof(T) * sz) array_t {
  T data[sz];
  using type = T;
  static constexpr int size = sz;
};

// use packed type to maximize memory efficiency
// goal: generate ld.128 and st.128 instructions
template <typename T>
struct packed_t {
  // the (P)acked type for load/store
  using P = array_t<T, 16 / sizeof(T)>;
  // the (A)ccumulator type for reduction
  using A = array_t<float, 16 / sizeof(T)>;
};
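
A tiny compile-time check of the trait above (it would live in a .cu that includes this header; it restates the definition and adds no new behavior): for T = half the packed type is 16 bytes, i.e. one 128-bit load/store per packet, and the accumulator keeps the same lane count but in fp32.

// packed_t_check.cu -- illustrative only.
#include <cuda_fp16.h>
#include "custom_all_reduce.cuh"

static_assert(sizeof(vllm::packed_t<half>::P) == 16, "one ld.128/st.128 per packet");
static_assert(vllm::packed_t<half>::A::size == 8, "eight fp32 lanes per packet");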

#define DINLINE __device__ __forceinline__

// scalar cast functions
DINLINE float upcast_s(half val) { return __half2float(val); }

template <typename T>
DINLINE T downcast_s(float val);
template <>
DINLINE half downcast_s(float val) {
  return __float2half(val);
}

// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE half &assign_add(half &a, half b) {
  a = __hadd(a, b);
  return a;
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE nv_bfloat16 downcast_s(float val) {
  return __float2bfloat16(val);
}
DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
  a = __hadd(a, b);
  return a;
}
#endif

template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
#pragma unroll
  for (int i = 0; i < N; i++) {
    assign_add(a.data[i], b.data[i]);
  }
  return a;
}

template <typename T, int N>
DINLINE array_t<float, N> upcast(array_t<T, N> val) {
  if constexpr (std::is_same<T, float>::value) {
    return val;
  } else {
    array_t<float, N> out;
#pragma unroll
    for (int i = 0; i < N; i++) {
      out.data[i] = upcast_s(val.data[i]);
    }
    return out;
  }
}

template <typename O>
DINLINE O downcast(array_t<float, O::size> val) {
  if constexpr (std::is_same<typename O::type, float>::value) {
    return val;
  } else {
    O out;
#pragma unroll
    for (int i = 0; i < O::size; i++) {
      out.data[i] = downcast_s<typename O::type>(val.data[i]);
    }
    return out;
  }
}

// compute flag at compile time
__host__ __device__ constexpr uint64_t compute_flag(int ngpus) {
  auto m = std::numeric_limits<uint64_t>::max();
  return m >> ((8 - ngpus) * 8);
}
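Each rank writes 255 into its own byte of the start/end flag word, so the "everyone arrived" value is simply the low ngpus bytes all set to 0xFF. Illustrative checks (not part of this diff):

static_assert(compute_flag(2) == 0xFFFFull);
static_assert(compute_flag(4) == 0xFFFFFFFFull);
static_assert(compute_flag(8) == 0xFFFFFFFFFFFFFFFFull);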

template <int ngpus>
DINLINE void start_sync(const RankSignals &sg, volatile Metadata *meta,
                        int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  if (blockIdx.x == 0) {
    if (threadIdx.x < ngpus)
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->start.data[rank] = 255;
    else if (threadIdx.x == 32)
      // reset
      meta->sg.end.flag = 0;
  }
  if (threadIdx.x == 0) {
    while (meta->sg.start.flag != FLAG)
      ;
  }
  __syncthreads();
}

template <int ngpus, bool final_sync = false>
DINLINE void end_sync(const RankSignals &sg, volatile Metadata *meta,
                      int rank) {
  constexpr auto FLAG = compute_flag(ngpus);
  __syncthreads();
  __shared__ int num;
  if (threadIdx.x == 0) num = atomicAdd((int *)&meta->counter, 1);
  __syncthreads();

  // Only the last completing block can perform the end synchronization.
  // This ensures that when the final busy wait ends, all ranks must have
  // finished reading each other's buffer.
  if (num == gridDim.x - 1) {
    if (threadIdx.x == 32) {
      // reset in a different warp
      meta->counter = 0;
      meta->sg.start.flag = 0;
    } else if (threadIdx.x < ngpus) {
      // simultaneously write to the corresponding byte to all other ranks.
      // Latency = 1 p2p write
      sg.signals[threadIdx.x]->end.data[rank] = 255;
    }
    // if this is the final sync, only one block needs it
    // because kernel exit can serve as sync
    if constexpr (final_sync) {
      if (threadIdx.x == 0) {
        while (meta->sg.end.flag != FLAG)
          ;
      }
    }
  }
  if constexpr (!final_sync) {
    if (threadIdx.x == 0) {
      while (meta->sg.end.flag != FLAG)
        ;
    }
    __syncthreads();
  }
}
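The core of end_sync is the classic "last block" handshake: every block bumps a global counter once its threads are done, and only the block that observes counter == gridDim.x - 1 performs the cross-rank signaling and the reset. A standalone sketch of just that pattern (hypothetical kernel, not from this diff):

__device__ int g_counter = 0;

__global__ void last_block_demo(int *flag) {
  // ... per-block work would go here ...
  __syncthreads();  // all threads in this block have finished their work
  __shared__ int ticket;
  if (threadIdx.x == 0) ticket = atomicAdd(&g_counter, 1);
  __syncthreads();
  if (ticket == gridDim.x - 1 && threadIdx.x == 0) {
    *flag = 1;      // only the last block to arrive runs this
    g_counter = 0;  // reset so the kernel can be launched again
  }
}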

template <typename P, int ngpus, typename A>
DINLINE P packed_reduce(const P *ptrs[], int idx) {
  A tmp = upcast(ptrs[0][idx]);
#pragma unroll
  for (int i = 1; i < ngpus; i++) {
    packed_assign_add(tmp, upcast(ptrs[i][idx]));
  }
  return downcast<P>(tmp);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  // note: we don't reorder the address so the accumulation order is the same
  // for all ranks, ensuring bitwise identical results
  auto dp = *_dp;
  start_sync<ngpus>(sg, meta, rank);
  // do the actual reduction
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    ((P *)result)[idx] =
        packed_reduce<P, ngpus, A>((const P **)&dp.ptrs[0], idx);
  }
  end_sync<ngpus, true>(sg, meta, rank);
}

template <typename P>
DINLINE P *get_tmp_buf(volatile Signal *sg) {
  return (P *)(((Metadata *)sg) + 1);
}

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
                               volatile Metadata *meta, T *__restrict__ result,
                               int rank, int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  int part = size / ngpus;
  int start = rank * part;
  int end = rank == ngpus - 1 ? size : start + part;
  const P *ptrs[ngpus];
  P *tmps[ngpus];
#pragma unroll
  for (int i = 0; i < ngpus; i++) {
    int target = (rank + i) % ngpus;
    ptrs[i] = (const P *)_dp->ptrs[target];
    tmps[i] = get_tmp_buf<P>(sg.signals[target]);
  }
  auto tmp_out = tmps[0];
  start_sync<ngpus>(sg, meta, rank);
  // stage 1: reduce scatter
  for (int idx = start + tid; idx < end; idx += stride) {
    tmp_out[idx - start] = packed_reduce<P, ngpus, A>(ptrs, idx);
  }
  // Maybe TODO: replace this with per-block release-acquire
  // can save about 1-2us (not a lot though)
  end_sync<ngpus>(sg, meta, rank);

  // stage 2: allgather
  for (int idx = tid; idx < part; idx += stride) {
#pragma unroll
    for (int i = 0; i < ngpus; i++) {
      int dst_idx = ((rank + i) % ngpus) * part + idx;
      ((P *)result)[dst_idx] = tmps[i][idx];
    }
  }
  // process the last larger partition
  int remaining = size - part * ngpus;
  if (tid < remaining) {
    int dst_idx = tid + part * ngpus;
    ((P *)result)[dst_idx] = get_tmp_buf<P>(sg.signals[ngpus - 1])[part + tid];
  }

  // faster than this
  // for (int idx = tid; idx < size; idx += stride) {
  //   int target_rank = idx / part;
  //   if (target_rank == ngpus) target_rank -= 1;
  //   ((P *)result)[idx] = tmps[target_rank][idx - target_rank * part];
  // }
}
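For reference, the two-stage schedule above is an ordinary reduce-scatter followed by an all-gather over equal partitions, with the last rank also covering the remainder. A host-side sketch of the same partitioning (plain C++, illustration only, not part of this diff):

#include <vector>

std::vector<float> two_stage_allreduce_reference(
    const std::vector<std::vector<float>> &inputs, int ngpus) {
  const int size = static_cast<int>(inputs[0].size());
  const int part = size / ngpus;
  std::vector<float> out(size, 0.f);
  for (int rank = 0; rank < ngpus; ++rank) {
    // stage 1 (reduce-scatter): rank owns [start, end); the last rank also owns
    // the remainder when size is not divisible by ngpus.
    const int start = rank * part;
    const int end = (rank == ngpus - 1) ? size : start + part;
    for (int idx = start; idx < end; ++idx)
      for (int r = 0; r < ngpus; ++r) out[idx] += inputs[r][idx];
    // stage 2 (all-gather) is a pure copy of each rank's reduced partition,
    // so in this single-process model the values are already in place.
  }
  return out;
}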

template <typename T, int ngpus>
__global__ void __launch_bounds__(512, 1)
    cross_device_reduce_half_butterfly(RankData *_dp, RankSignals sg,
                                       volatile Metadata *meta,
                                       T *__restrict__ result, int rank,
                                       int size) {
  int tid = blockIdx.x * blockDim.x + threadIdx.x;
  int stride = gridDim.x * blockDim.x;
  using P = typename packed_t<T>::P;
  using A = typename packed_t<T>::A;
  auto tmp_out = get_tmp_buf<P>(sg.signals[rank]);
  constexpr int hg = ngpus / 2;
  // Actually not quite half butterfly.
  // This is an all-to-all within each group containing half of the ranks
  // followed by cross-group add. Equivalent to half butterfly when there
  // are 4 GPUs, a common case for PCIe cards like T4 and A10.
  const P *ptrs[hg];
  {
    int start = rank - rank % hg;
#pragma unroll
    for (int i = 0; i < hg; i++) {
      ptrs[i] = (const P *)_dp->ptrs[i + start];
    }
  }
  start_sync<ngpus>(sg, meta, rank);
  for (int idx = tid; idx < size; idx += stride) {
    tmp_out[idx] = packed_reduce<P, hg, A>(ptrs, idx);
  }
  end_sync<ngpus>(sg, meta, rank);

  auto src = get_tmp_buf<P>(sg.signals[(ngpus - 1) - rank % ngpus]);
  // do the cross group reduction
  for (int idx = tid; idx < size; idx += stride) {
    auto tmp = tmp_out[idx];
    packed_assign_add(tmp, src[idx]);
    ((P *)result)[idx] = tmp;
  }
}
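The cross-group partner chosen above, sg.signals[(ngpus - 1) - rank % ngpus], is each rank's mirror image, so every pair spans both halves. Illustrative check (not part of this diff):

constexpr int cross_group_partner(int ngpus, int rank) {
  return (ngpus - 1) - rank % ngpus;
}
static_assert(cross_group_partner(4, 0) == 3 && cross_group_partner(4, 1) == 2 &&
              cross_group_partner(4, 2) == 1 && cross_group_partner(4, 3) == 0);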

using IPC_KEY = std::array<uint8_t, sizeof(cudaIpcMemHandle_t)>;
static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t));
static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t));

class CustomAllreduce {
 public:
  int rank_;
  int world_size_;
  bool full_nvlink_;

  // below are device pointers
  RankSignals sg_;
  std::unordered_map<void *, RankData *> buffers_;
  Metadata *meta_;

  // stores the registered device pointers from all ranks
  RankData *d_rank_data_base_, *d_rank_data_end_;
  std::vector<void *> graph_unreg_buffers_;
  // a map from IPC handles to opened IPC pointers
  std::map<IPC_KEY, char *> ipc_handles_;

  /**
   * meta is a pointer to device metadata and temporary buffer for allreduce.
   *
   * There's a total of sizeof(Metadata) of prefix before the actual data,
   * so meta + 1 points to actual temporary buffer.
   *
   * note: this class does not own any device memory. Any required buffers
   * are passed in from the constructor
   */
  CustomAllreduce(Metadata *meta, void *rank_data, size_t rank_data_sz,
                  const cudaIpcMemHandle_t *handles,
                  const std::vector<int64_t> &offsets, int rank,
                  bool full_nvlink = true)
      : rank_(rank),
        world_size_(offsets.size()),
        full_nvlink_(full_nvlink),
        meta_(meta),
        d_rank_data_base_(reinterpret_cast<RankData *>(rank_data)),
        d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) {
    for (int i = 0; i < world_size_; i++) {
      Metadata *rank_meta;
      if (i != rank_) {
        char *handle = open_ipc_handle(&handles[i]);
        handle += offsets[i];
        rank_meta = (Metadata *)handle;
      } else {
        rank_meta = meta_;
      }
      sg_.signals[i] = &rank_meta->sg;
    }
  }

  char *open_ipc_handle(const void *ipc_handle) {
    auto [it, new_handle] =
        ipc_handles_.insert({*((IPC_KEY *)ipc_handle), nullptr});
    if (new_handle) {
      char *ipc_ptr;
      CUDACHECK(cudaIpcOpenMemHandle((void **)&ipc_ptr,
                                     *((const cudaIpcMemHandle_t *)ipc_handle),
                                     cudaIpcMemLazyEnablePeerAccess));
      it->second = ipc_ptr;
    }
    return it->second;
  }

  std::pair<std::vector<uint8_t>, std::vector<int64_t>>
  get_graph_buffer_ipc_meta() {
    auto num_buffers = graph_unreg_buffers_.size();
    auto handle_sz = sizeof(cudaIpcMemHandle_t);
    std::vector<uint8_t> handles(handle_sz * num_buffers, 0);
    std::vector<int64_t> offsets(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto ptr = graph_unreg_buffers_[i];
      void *base_ptr;
      // note: must share the base address of each allocation, or we get wrong
      // address
      if (cuPointerGetAttribute(&base_ptr,
                                CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
                                (CUdeviceptr)ptr) != CUDA_SUCCESS)
        throw std::runtime_error("failed to get pointer attr");
      CUDACHECK(cudaIpcGetMemHandle(
          (cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
      offsets[i] = ((char *)ptr) - ((char *)base_ptr);
    }
    return std::make_pair(handles, offsets);
  }

  void check_rank_data_capacity(size_t num = 1) {
    if (d_rank_data_base_ + num > d_rank_data_end_)
      throw std::runtime_error(
          "Rank data buffer is overflowed by " +
          std::to_string(d_rank_data_base_ + num - d_rank_data_end_));
  }

  void register_buffer(const std::vector<std::string> &handles,
                       const std::vector<int64_t> &offsets, void *self) {
    check_rank_data_capacity();
    RankData data;
    for (int i = 0; i < world_size_; i++) {
      if (i != rank_) {
        char *handle = open_ipc_handle(handles[i].data());
        handle += offsets[i];
        data.ptrs[i] = handle;
      } else {
        data.ptrs[i] = self;
      }
    }
    auto d_data = d_rank_data_base_++;
    CUDACHECK(
        cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice));
    buffers_[self] = d_data;
  }

  // note: when registering graph buffers, we intentionally choose to not
  // deduplicate the addresses. That means if the allocator reuses some
  // addresses, they will be registered again. This is to account for the remote
  // possibility of different allocation patterns between ranks. For example,
  // rank 1 may get the same input address for the second allreduce, but rank 2
  // got a different address. IPC handles have an internal reference counting
  // mechanism, so the overhead should be small.
  void register_graph_buffers(
      const std::vector<std::string> &handles,
      const std::vector<std::vector<int64_t>> &offsets) {
    auto num_buffers = graph_unreg_buffers_.size();
    check_rank_data_capacity(num_buffers);
    std::vector<RankData> rank_data(num_buffers);
    for (int i = 0; i < num_buffers; i++) {
      auto self_ptr = graph_unreg_buffers_[i];
      auto &rd = rank_data[i];
      for (int j = 0; j < world_size_; j++) {
        if (j != rank_) {
          char *handle =
              open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]);
          handle += offsets[j][i];
          rd.ptrs[j] = handle;
        } else {
          rd.ptrs[j] = self_ptr;
        }
      }
    }
    CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(),
                         sizeof(RankData) * num_buffers,
                         cudaMemcpyHostToDevice));
    d_rank_data_base_ += num_buffers;
    graph_unreg_buffers_.clear();
  }

  /**
   * This is the result after careful grid search. Using 36 blocks gives the
   * best or close to the best runtime on the devices I tried: A100, A10, A30,
   * T4, V100. You'll notice that NCCL kernels also use only a small number of
   * SMs. I'm not quite sure of the underlying reason, but my guess is that too
   * many SMs will cause contention on the NVLink bus.
   */
  template <typename T>
  void allreduce(cudaStream_t stream, T *input, T *output, int size,
                 int threads = 512, int block_limit = 36) {
    auto d = packed_t<T>::P::size;
    if (size % d != 0)
      throw std::runtime_error(
          "custom allreduce currently requires input length to be multiple "
          "of " +
          std::to_string(d));

    RankData *ptrs;
    cudaStreamCaptureStatus status;
    CUDACHECK(cudaStreamIsCapturing(stream, &status));
    if (status == cudaStreamCaptureStatusActive) {
      ptrs = d_rank_data_base_ + graph_unreg_buffers_.size();
      graph_unreg_buffers_.push_back(input);
    } else {
      auto it = buffers_.find(input);
      if (it == buffers_.end())
        throw std::runtime_error(
            "buffer address " +
            std::to_string(reinterpret_cast<uint64_t>(input)) +
            " is not registered!");
      ptrs = it->second;
    }

    size /= d;
    auto bytes = size * sizeof(typename packed_t<T>::P);
    int blocks = std::min(block_limit, (size + threads - 1) / threads);
#define KL(ngpus, name) \
  name<T, ngpus>        \
      <<<blocks, threads, 0, stream>>>(ptrs, sg_, meta_, output, rank_, size);
#define REDUCE_CASE(ngpus)                            \
  case ngpus: {                                       \
    if (world_size_ == 2) {                           \
      KL(ngpus, cross_device_reduce_1stage);          \
    } else if (full_nvlink_) {                        \
      if ((world_size_ <= 4 && bytes < 512 * 1024) || \
          (world_size_ <= 8 && bytes < 256 * 1024)) { \
        KL(ngpus, cross_device_reduce_1stage);        \
      } else {                                        \
        KL(ngpus, cross_device_reduce_2stage);        \
      }                                               \
    } else {                                          \
      KL(ngpus, cross_device_reduce_half_butterfly);  \
    }                                                 \
    break;                                            \
  }

    switch (world_size_) {
      REDUCE_CASE(2)
      REDUCE_CASE(4)
      REDUCE_CASE(6)
      REDUCE_CASE(8)
      default:
        throw std::runtime_error(
            "custom allreduce only supports num gpus in (2,4,6,8). Actual num "
            "gpus = " +
            std::to_string(world_size_));
    }
#undef REDUCE_CASE
#undef KL
  }

  ~CustomAllreduce() {
    for (auto [_, ptr] : ipc_handles_) {
      CUDACHECK(cudaIpcCloseMemHandle(ptr));
    }
  }
};
/**
 * To inspect PTX/SASS, copy paste this header file to compiler explorer and add
 * a template instantiation:
 * template void CustomAllreduce::allreduce<half>(cudaStream_t, half *, half *,
 * int, int, int);
 */
}  // namespace vllm
csrc/custom_all_reduce_test.cu (new file, 284 lines)
@ -0,0 +1,284 @@
/**
 * This is a standalone test for custom allreduce.
 * To compile, make sure you have MPI and NCCL installed on your system.
 * export MPI_HOME=XXX
 * nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
 * custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
 *
 * Warning: this C++ test is not designed to be very readable and was used
 * during the rapid prototyping process.
 *
 * To run:
 * mpirun -np 8 ./custom_all_reduce_test
 */
#include <cuda.h>
#include <curand_kernel.h>
#include <stdio.h>
#include <stdlib.h>

#include <limits>
#include <vector>

#include "cuda_profiler_api.h"
#include "custom_all_reduce.cuh"
#include "mpi.h"
#include "nccl.h"

#define MPICHECK(cmd)                                                  \
  do {                                                                 \
    int e = cmd;                                                       \
    if (e != MPI_SUCCESS) {                                            \
      printf("Failed: MPI error %s:%d '%d'\n", __FILE__, __LINE__, e); \
      exit(EXIT_FAILURE);                                              \
    }                                                                  \
  } while (0)

#define NCCLCHECK(cmd)                                              \
  do {                                                              \
    ncclResult_t r = cmd;                                           \
    if (r != ncclSuccess) {                                         \
      printf("Failed, NCCL error %s:%d '%s'\n", __FILE__, __LINE__, \
             ncclGetErrorString(r));                                \
      exit(EXIT_FAILURE);                                           \
    }                                                               \
  } while (0)

__global__ void dummy_kernel() {
  for (int i = 0; i < 100; i++) __nanosleep(1000000);  // 100ms
}

template <typename T>
__global__ void set_data(T *data, int size, int myRank) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    data[idx] = myRank * 0.11f;
  }
}

template <typename T>
__global__ void convert_data(const T *data1, const T *data2, double *fdata1,
                             double *fdata2, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    fdata1[idx] = data1[idx];
    fdata2[idx] = data2[idx];
  }
}

__global__ void init_rand(curandState_t *state, int size, int nRanks) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    for (int i = 0; i < nRanks; i++) {
      curand_init(i + 1, idx, 0, &state[idx * nRanks + i]);
    }
  }
}

template <typename T>
__global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
                         int myRank, int nRanks, int size) {
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += gridDim.x * blockDim.x) {
    double sum = 0.0;
    for (int i = 0; i < nRanks; i++) {
      double val = curand_uniform_double(&state[idx * nRanks + i]) * 4;
      T hval = val;  // downcast first
      sum += static_cast<double>(hval);
      if (i == myRank) data[idx] = hval;
    }
    ground_truth[idx] = sum;
  }
}

template <typename T>
void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
         int data_size) {
  T *result;
  cudaStream_t stream;
  CUDACHECK(cudaStreamCreateWithFlags(&stream, cudaStreamNonBlocking));
  CUDACHECK(cudaMalloc(&result, data_size * sizeof(T)));
  CUDACHECK(cudaMemset(result, 0, data_size * sizeof(T)));

  cudaIpcMemHandle_t self_data_handle;
  cudaIpcMemHandle_t data_handles[8];
  vllm::Metadata *buffer;
  T *self_data_copy;
  /**
   * Allocate IPC buffer
   *
   * The first section is a temporary buffer for storing intermediate allreduce
   * results, if a particular algorithm requires it. The second section is for
   * the input to the allreduce. The actual API takes the input pointer as an
   * argument (that is, they can and usually should be allocated separately).
   * But since the input pointers and the temporary buffer all require IPC
   * registration, they are allocated and registered together in the test for
   * convenience.
   */
  CUDACHECK(
      cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMemset(buffer, 0,
                       2 * data_size * sizeof(T) + sizeof(vllm::Metadata)));
  CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
  CUDACHECK(cudaIpcGetMemHandle(&self_data_handle, buffer));

  MPICHECK(MPI_Allgather(&self_data_handle, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, data_handles, sizeof(cudaIpcMemHandle_t),
                         MPI_BYTE, MPI_COMM_WORLD));

  void *rank_data;
  size_t rank_data_sz = 16 * 1024 * 1024;
  CUDACHECK(cudaMalloc(&rank_data, rank_data_sz));
  std::vector<int64_t> offsets(nRanks, 0);
  vllm::CustomAllreduce fa(buffer, rank_data, rank_data_sz, data_handles,
                           offsets, myRank);
  auto *self_data =
      reinterpret_cast<T *>(reinterpret_cast<char *>(buffer) +
                            sizeof(vllm::Metadata) + data_size * sizeof(T));
  // hack buffer registration
  {
    std::vector<std::string> handles;
    handles.reserve(nRanks);
    for (int i = 0; i < nRanks; i++) {
      char *begin = (char *)&data_handles[i];
      char *end = (char *)&data_handles[i + 1];
      handles.emplace_back(begin, end);
    }
    std::vector<int64_t> offsets(
        nRanks, sizeof(vllm::Metadata) + data_size * sizeof(T));
    fa.register_buffer(handles, offsets, self_data);
  }

  double *ground_truth;
  CUDACHECK(cudaMallocHost(&ground_truth, data_size * sizeof(double)));
  curandState_t *states;
  CUDACHECK(cudaMalloc(&states, sizeof(curandState_t) * nRanks * data_size));
  init_rand<<<108, 1024, 0, stream>>>(states, data_size, nRanks);
  gen_data<T><<<108, 1024, 0, stream>>>(states, self_data, ground_truth, myRank,
                                        nRanks, data_size);
  CUDACHECK(cudaMemcpyAsync(self_data_copy, self_data, data_size * sizeof(T),
                            cudaMemcpyDeviceToDevice, stream));
  cudaEvent_t start, stop;
  CUDACHECK(cudaEventCreate(&start));
  CUDACHECK(cudaEventCreate(&stop));

  ncclDataType_t ncclDtype;
  if (std::is_same<T, half>::value) {
    ncclDtype = ncclFloat16;
  } else if (std::is_same<T, nv_bfloat16>::value) {
    ncclDtype = ncclBfloat16;
  } else {
    ncclDtype = ncclFloat;
  }

  dummy_kernel<<<1, 1, 0, stream>>>();
  constexpr int warmup_iters = 5;
  constexpr int num_iters = 25;
  // warmup
  for (int i = 0; i < warmup_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    NCCLCHECK(ncclAllReduce(result, result, data_size, ncclDtype, ncclSum, comm,
                            stream));
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));
  float allreduce_ms = 0;
  cudaEventElapsedTime(&allreduce_ms, start, stop);

  // if (myRank == 1) dummy_kernel<<<1, 1, 0, stream>>>();
  // set_data<T><<<16, 1024, 0, stream>>>(self_data, data_size, myRank);

  dummy_kernel<<<1, 1, 0, stream>>>();
  // warm up
  for (int i = 0; i < warmup_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(start, stream));
  for (int i = 0; i < num_iters; i++) {
    fa.allreduce<T>(stream, self_data, result, data_size, threads, block_limit);
  }
  CUDACHECK(cudaEventRecord(stop, stream));
  CUDACHECK(cudaStreamSynchronize(stream));

  float duration_ms = 0;
  cudaEventElapsedTime(&duration_ms, start, stop);
  if (myRank == 0)
    printf(
        "Rank %d done, nGPUs:%d, sz (kb): %zu, %d, %d, my time:%.2fus, nccl "
        "time:%.2fus\n",
        myRank, nRanks, data_size * sizeof(T) / 1024, threads, block_limit,
        duration_ms * 1e3 / num_iters, allreduce_ms * 1e3 / num_iters);

  // And wait for all the queued up work to complete
  CUDACHECK(cudaStreamSynchronize(stream));

  NCCLCHECK(ncclAllReduce(self_data_copy, self_data, data_size, ncclDtype,
                          ncclSum, comm, stream));

  double *nccl_result, *my_result;
  CUDACHECK(cudaMallocHost(&nccl_result, data_size * sizeof(double)));
  CUDACHECK(cudaMallocHost(&my_result, data_size * sizeof(double)));

  convert_data<T><<<108, 1024, 0, stream>>>(self_data, result, nccl_result,
                                            my_result, data_size);
  CUDACHECK(cudaStreamSynchronize(stream));

  for (unsigned long j = 0; j < data_size; j++) {
    auto diff = abs(nccl_result[j] - my_result[j]);
    if (diff >= 1e-2) {
      printf("Rank %d: Verification mismatch at %lu: %f != (my) %f, gt=%f\n",
             myRank, j, nccl_result[j], my_result[j], ground_truth[j]);
      break;
    }
  }

  long double nccl_diffs = 0.0;
  long double my_diffs = 0.0;
  for (int j = 0; j < data_size; j++) {
    nccl_diffs += abs(nccl_result[j] - ground_truth[j]);
    my_diffs += abs(my_result[j] - ground_truth[j]);
  }
  if (myRank == 0)
    std::cout << "average abs diffs: nccl: " << nccl_diffs / data_size
              << " me: " << my_diffs / data_size << std::endl;

  CUDACHECK(cudaFree(result));
  CUDACHECK(cudaFree(self_data_copy));
  CUDACHECK(cudaFree(rank_data));
  CUDACHECK(cudaFree(buffer));
  CUDACHECK(cudaFree(states));
  CUDACHECK(cudaFreeHost(ground_truth));
  CUDACHECK(cudaFreeHost(nccl_result));
  CUDACHECK(cudaFreeHost(my_result));
  CUDACHECK(cudaStreamDestroy(stream));
}

int main(int argc, char **argv) {
  int nRanks, myRank;
  MPICHECK(MPI_Init(&argc, &argv));
  MPICHECK(MPI_Comm_rank(MPI_COMM_WORLD, &myRank));
  MPICHECK(MPI_Comm_size(MPI_COMM_WORLD, &nRanks));
  CUDACHECK(cudaSetDevice(myRank));
  ncclUniqueId id;
  ncclComm_t comm;
  if (myRank == 0) ncclGetUniqueId(&id);
  MPICHECK(MPI_Bcast(static_cast<void *>(&id), sizeof(id), MPI_BYTE, 0,
                     MPI_COMM_WORLD));
  NCCLCHECK(ncclCommInitRank(&comm, nRanks, id, myRank));

  cudaProfilerStart();
  // for (int threads : {256, 512}) {
  //   for (int block_limit = 16; block_limit < 112; block_limit += 4) {
  //     run<half>(myRank, nRanks, comm, threads, block_limit, 4096 * 1024);
  //   }
  // }
  for (int sz = 512; sz <= (32 << 20); sz *= 2) {
    run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 50);
  }

  cudaProfilerStop();
  return EXIT_SUCCESS;
}
@ -2,6 +2,8 @@
 * Adapted from
 * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h
 */
#pragma once

#include <torch/extension.h>

#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
@ -12,3 +14,24 @@
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)             \
  AT_DISPATCH_SWITCH(                                             \
    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)     \
  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)   \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)           \
  AT_DISPATCH_SWITCH(                                                    \
    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)             \
  AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)      \
  AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)     \
  AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)       \
  AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)

#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...)             \
  AT_DISPATCH_SWITCH(                                             \
    TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
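Like the PyTorch AT_DISPATCH_* macros they wrap, these dispatch macros expand to a switch over a ScalarType and bind scalar_t inside the lambda. A hypothetical usage sketch (copy_kernel and copy_op are placeholders, not symbols from this diff):

#include <torch/extension.h>

template <typename scalar_t>
__global__ void copy_kernel(scalar_t *dst, const scalar_t *src, int64_t n) {
  int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];
}

void copy_op(torch::Tensor &out, const torch::Tensor &in) {
  VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(
      in.scalar_type(), "copy_op", [&] {
        copy_kernel<scalar_t><<<(in.numel() + 255) / 256, 256>>>(
            out.data_ptr<scalar_t>(), in.data_ptr<scalar_t>(), in.numel());
      });
}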

@ -1,5 +1,6 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "dispatch_utils.h"
#include "reduction_utils.cuh"
@ -76,6 +77,7 @@ void rms_norm(

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),
@ -101,6 +103,7 @@ void fused_add_rms_norm(

  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    input.scalar_type(),

csrc/moe/moe_ops.cpp (new file, 7 lines)
@ -0,0 +1,7 @@
#include "moe_ops.h"

#include <torch/extension.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("topk_softmax", &topk_softmax, "Apply topk softmax to the gating outputs.");
}

csrc/moe/moe_ops.h (new file, 9 lines)
@ -0,0 +1,9 @@
#pragma once

#include <torch/extension.h>

void topk_softmax(
  torch::Tensor& topk_weights,
  torch::Tensor& topk_indices,
  torch::Tensor& token_expert_indices,
  torch::Tensor& gating_output);
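A hypothetical host-side call of the op declared above, with the shapes this sketch assumes (gating_output is [num_tokens, num_experts]; the three output tensors are [num_tokens, topk], with int32 index tensors):

#include <torch/extension.h>

void topk_softmax_example(int64_t num_tokens, int64_t num_experts, int64_t topk) {
  auto f32 = torch::dtype(torch::kFloat32).device(torch::kCUDA);
  auto i32 = torch::dtype(torch::kInt32).device(torch::kCUDA);
  torch::Tensor gating_output = torch::randn({num_tokens, num_experts}, f32);
  torch::Tensor topk_weights = torch::empty({num_tokens, topk}, f32);
  torch::Tensor topk_indices = torch::empty({num_tokens, topk}, i32);
  torch::Tensor token_expert_indices = torch::empty({num_tokens, topk}, i32);
  topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output);
}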

csrc/moe/topk_softmax_kernels.cu (new file, 499 lines)
@ -0,0 +1,499 @@
/*
 * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu
 * Copyright (c) 2024, The vLLM team.
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 * SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <cub/cub.cuh>
#include <cub/util_type.cuh>

namespace vllm {
namespace moe {

static constexpr int WARP_SIZE = 32;

/// Aligned array type
template <
    typename T,
    /// Number of elements in the array
    int N,
    /// Alignment requirement in bytes
    int Alignment = sizeof(T) * N
>
class alignas(Alignment) AlignedArray {
    float data[N];
};

// ====================== Softmax things ===============================
// We have our own implementation of softmax here so we can support transposing the output
// in the softmax kernel when we extend this module to support expert-choice routing.
template <int TPB>
__launch_bounds__(TPB) __global__
    void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols)
{
    using BlockReduce = cub::BlockReduce<float, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    __shared__ float normalizing_factor;
    __shared__ float float_max;

    const int thread_row_offset = blockIdx.x * num_cols;

    cub::Sum sum;
    float threadData(-FLT_MAX);

    // Don't touch finished rows.
    if ((finished != nullptr) && finished[blockIdx.x])
    {
        return;
    }

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        threadData = max(static_cast<float>(input[idx]), threadData);
    }

    const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max());
    if (threadIdx.x == 0)
    {
        float_max = maxElem;
    }
    __syncthreads();

    threadData = 0;

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        threadData += exp((static_cast<float>(input[idx]) - float_max));
    }

    const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum);

    if (threadIdx.x == 0)
    {
        normalizing_factor = 1.f / Z;
    }
    __syncthreads();

    for (int ii = threadIdx.x; ii < num_cols; ii += TPB)
    {
        const int idx = thread_row_offset + ii;
        const float val = exp((static_cast<float>(input[idx]) - float_max)) * normalizing_factor;
        output[idx] = val;
    }
}

template <int TPB>
__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output,
    int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert)
{

    using cub_kvp = cub::KeyValuePair<int, float>;
    using BlockReduce = cub::BlockReduce<cub_kvp, TPB>;
    __shared__ typename BlockReduce::TempStorage tmpStorage;

    cub_kvp thread_kvp;
    cub::ArgMax arg_max;

    const int num_rows = gridDim.x;
    const int block_row = blockIdx.x;

    const bool row_is_active = finished ? !finished[block_row] : true;
    const int thread_read_offset = blockIdx.x * num_experts;
    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        thread_kvp.key = 0;
        thread_kvp.value = -1.f; // This is OK because inputs are probabilities

        cub_kvp inp_kvp;
        for (int expert = threadIdx.x; expert < num_experts; expert += TPB)
        {
            const int idx = thread_read_offset + expert;
            inp_kvp.key = expert;
            inp_kvp.value = inputs_after_softmax[idx];

            for (int prior_k = 0; prior_k < k_idx; ++prior_k)
            {
                const int prior_winning_expert = indices[k * block_row + prior_k];

                if (prior_winning_expert == expert)
                {
                    inp_kvp = thread_kvp;
                }
            }

            thread_kvp = arg_max(inp_kvp, thread_kvp);
        }

        const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max);
        if (threadIdx.x == 0)
        {
            // Ignore experts the node isn't responsible for with expert parallelism
            const int expert = result_kvp.key;
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            const int idx = k * block_row + k_idx;
            output[idx] = result_kvp.value;
            indices[idx] = should_process_row ? (expert - start_expert) : num_experts;
            assert(indices[idx] >= 0);
            source_rows[idx] = k_idx * num_rows + block_row;
        }
        __syncthreads();
    }
}

// ====================== TopK softmax things ===============================

/*
  A Top-K gating softmax written to exploit when the number of experts in the MoE layers
  is a small power of 2. This allows us to cleanly share the rows among the threads in
  a single warp and eliminate communication between warps (so no need to use shared mem).

  It fuses the softmax, max and argmax into a single kernel.

  Limitations:
  1) This implementation is intended for when the number of experts is a small power of 2.
  2) This implementation assumes k is small, but will work for any k.
*/
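To make the template parameters below concrete, here is one assumed instantiation worked through the same formulas the kernel uses (NUM_EXPERTS = 8, BYTES_PER_LDG = 16, VPT = 8 are assumptions chosen for illustration, not values taken from this diff):

constexpr int kWarpSize = 32;
constexpr int kNumExperts = 8, kBytesPerLdg = 16, kVpt = 8;
constexpr int kEltsPerLdg = kBytesPerLdg / sizeof(float);     // 4 floats per 128-bit load
constexpr int kThreadsPerRow = kNumExperts / kVpt;            // 1 thread covers a whole row
constexpr int kLdgPerThread = kVpt / kEltsPerLdg;             // 2 vector loads per thread
constexpr int kRowsPerWarp = kWarpSize * kVpt / kNumExperts;  // 32 rows handled per warp
static_assert(kEltsPerLdg == 4 && kThreadsPerRow == 1 &&
              kLdgPerThread == 2 && kRowsPerWarp == 32);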
 | 
			
		||||
 | 
			
		||||
template <int VPT, int NUM_EXPERTS, int WARPS_PER_CTA, int BYTES_PER_LDG>
 | 
			
		||||
__launch_bounds__(WARPS_PER_CTA* WARP_SIZE) __global__
 | 
			
		||||
    void topkGatingSoftmax(const float* input, const bool* finished, float* output, const int num_rows, int* indices,
 | 
			
		||||
        int* source_rows, const int k, const int start_expert, const int end_expert)
 | 
			
		||||
{
 | 
			
		||||
    // We begin by enforcing compile time assertions and setting up compile time constants.
 | 
			
		||||
    static_assert(VPT == (VPT & -VPT), "VPT must be power of 2");
 | 
			
		||||
    static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2");
 | 
			
		||||
    static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2");
 | 
			
		||||
    static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16");
 | 
			
		||||
 | 
			
		||||
    // Number of bytes each thread pulls in per load
 | 
			
		||||
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
 | 
			
		||||
    static constexpr int ELTS_PER_ROW = NUM_EXPERTS;
 | 
			
		||||
    static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT;
 | 
			
		||||
    static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG;
 | 
			
		||||
 | 
			
		||||
    // Restrictions based on previous section.
 | 
			
		||||
    static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg");
 | 
			
		||||
    static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp");
 | 
			
		||||
    static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2");
 | 
			
		||||
    static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size");
 | 
			
		||||
 | 
			
		||||
    // We have NUM_EXPERTS elements per row. We specialize for small #experts
 | 
			
		||||
    static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT;
 | 
			
		||||
    static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW;
 | 
			
		||||
    static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP;
 | 
			
		||||
 | 
			
		||||
    // Restrictions for previous section.
 | 
			
		||||
    static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp");
 | 
			
		||||
 | 
			
		||||
    // ===================== From this point, we finally start computing run-time variables. ========================
 | 
			
		||||
 | 
			
		||||
    // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps.
 | 
			
		||||
    // This, each block processes a chunk of rows. We start by computing the start row for each block.
 | 
			
		||||
    const int cta_base_row = blockIdx.x * ROWS_PER_CTA;
 | 
			
		||||
 | 
			
		||||
    // Now, using the base row per thread block, we compute the base row per warp.
 | 
			
		||||
    const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP;
 | 
			
		||||
 | 
			
		||||
    // The threads in a warp are split into sub-groups that will work on a row.
 | 
			
		||||
    // We compute row offset for each thread sub-group
 | 
			
		||||
    const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW;
 | 
			
		||||
    const int thread_row = warp_base_row + thread_row_in_warp;
 | 
			
		||||
 | 
			
		||||
    // Threads with indices out of bounds should early exit here.
 | 
			
		||||
    if (thread_row >= num_rows)
 | 
			
		||||
    {
 | 
			
		||||
        return;
 | 
			
		||||
    }
 | 
			
		||||
    const bool row_is_active = finished ? !finished[thread_row] : true;
 | 
			
		||||
 | 
			
		||||
    // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the
 | 
			
		||||
    // row it will read.
 | 
			
		||||
    const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW;
 | 
			
		||||
 | 
			
		||||
    // Now, we compute the group each thread belong to in order to determine the first column to start loads.
 | 
			
		||||
    const int thread_group_idx = threadIdx.x % THREADS_PER_ROW;
 | 
			
		||||
    const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG;
 | 
			
		||||
    const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread;
 | 
			
		||||
 | 
			
		||||
    // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory,
 | 
			
		||||
    // this can support all powers of 2 up to 16.
 | 
			
		||||
    // NOTE(woosuk): The original implementation uses CUTLASS aligned array here.
 | 
			
		||||
    // We defined our own aligned array and use it here to avoid the dependency on CUTLASS.
 | 
			
		||||
    using AccessType = AlignedArray<float, ELTS_PER_LDG>;
 | 
			
		||||
 | 
			
		||||
    // Finally, we pull in the data from global mem
 | 
			
		||||
    float row_chunk[VPT];
 | 
			
		||||
    AccessType* row_chunk_vec_ptr = reinterpret_cast<AccessType*>(&row_chunk);
 | 
			
		||||
    const AccessType* vec_thread_read_ptr = reinterpret_cast<const AccessType*>(thread_read_ptr);
 | 
			
		||||
#pragma unroll
 | 
			
		||||
    for (int ii = 0; ii < LDG_PER_THREAD; ++ii)
 | 
			
		||||
    {
 | 
			
		||||
        row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW];
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just
 | 
			
		||||
    // convert to float afterwards for the exp + sum reduction.
 | 
			
		||||
    float thread_max = row_chunk[0];
 | 
			
		||||
#pragma unroll
 | 
			
		||||
    for (int ii = 1; ii < VPT; ++ii)
 | 
			
		||||
    {
 | 
			
		||||
        thread_max = max(thread_max, row_chunk[ii]);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce.
 | 
			
		||||
#pragma unroll
 | 
			
		||||
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
 | 
			
		||||
    {
 | 
			
		||||
        thread_max = max(thread_max, __shfl_xor_sync(0xFFFFFFFF, thread_max, mask, THREADS_PER_ROW));
 | 
			
		||||
    }

    // From this point, thread_max in all the threads holds the max within the row.
    // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread-local sum.
    float row_sum = 0;
#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = expf(row_chunk[ii] - thread_max);
        row_sum += row_chunk[ii];
    }

// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a butterfly pattern.
#pragma unroll
    for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
    {
        row_sum += __shfl_xor_sync(0xFFFFFFFF, row_sum, mask, THREADS_PER_ROW);
    }

    // From this point, all threads have the max and the sum for their rows in the thread_max and row_sum variables
    // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to
    // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row.
    // However, this kernel will likely not be a bottleneck and it seems better to match torch more closely and find the
    // argmax after computing the softmax.
    const float reciprocal_row_sum = 1.f / row_sum;

#pragma unroll
    for (int ii = 0; ii < VPT; ++ii)
    {
        row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum;
    }

    // Now, row_chunk contains the softmax of the row chunk. Next, we want to find the top-k elements in each row,
    // along with the max index.
    int start_col = first_elt_read_by_thread;
    static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW;

    for (int k_idx = 0; k_idx < k; ++k_idx)
    {
        // First, each thread does the local argmax.
        float max_val = row_chunk[0];
        int expert = start_col;
#pragma unroll
        for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG)
        {
#pragma unroll
            for (int ii = 0; ii < ELTS_PER_LDG; ++ii)
            {
                float val = row_chunk[ldg * ELTS_PER_LDG + ii];

                // No check on the experts here since columns with the smallest index are processed first and only
                // updated if > (not >=).
                if (val > max_val)
                {
                    max_val = val;
                    expert = col + ii;
                }
            }
        }

// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max.
// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can
// then blank out their max with -inf and the warp can run more iterations...
#pragma unroll
        for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2)
        {
            float other_max = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, THREADS_PER_ROW);
            int other_expert = __shfl_xor_sync(0xFFFFFFFF, expert, mask, THREADS_PER_ROW);

            // We want lower indices to "win" in every thread so we break ties this way.
            if (other_max > max_val || (other_max == max_val && other_expert < expert))
            {
                max_val = other_max;
                expert = other_expert;
            }
        }

        // Write the max for this k iteration to global memory.
        if (thread_group_idx == 0)
        {
            // Add a guard to ignore experts not included by this node.
            const bool node_uses_expert = expert >= start_expert && expert < end_expert;
            const bool should_process_row = row_is_active && node_uses_expert;

            // The lead thread from each sub-group will write out the final results to global memory.
            // (This will be a single thread per row of the input/output matrices.)
            const int idx = k * thread_row + k_idx;
            output[idx] = max_val;
            indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS;
            source_rows[idx] = k_idx * num_rows + thread_row;
        }

        // Finally, we clear the value in the thread with the current max if there is another iteration to run.
        if (k_idx + 1 < k)
        {
            const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG;
            const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW;

            // Only the thread in the group which produced the max will reset the "winning" value to -inf.
            if (thread_group_idx == thread_to_clear_in_group)
            {
                const int offset_for_expert = expert % ELTS_PER_LDG;
                // Safe to set to any negative value since row_chunk values must be between 0 and 1.
                row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f;
            }
        }
    }
}

namespace detail
{
// Constructs some constants needed to partition the work across threads at compile time.
template <int EXPERTS, int BYTES_PER_LDG>
struct TopkConstants
{
    static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float);
    static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, "");
    static constexpr int VECs_PER_THREAD = std::max(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE));
    static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG;
    static constexpr int THREADS_PER_ROW = EXPERTS / VPT;
    static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW;
};
} // namespace detail
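
As a hedged illustration (added, not part of the original file), here is how these compile-time constants work out for one concrete instantiation, assuming WARP_SIZE is 32:

// Example: EXPERTS = 8, so BYTES_PER_LDG = min(16, sizeof(float) * 8) = 16 (see the launcher below).
//   ELTS_PER_LDG    = 16 / 4 = 4             -> each load pulls in 4 gating logits (one 16-byte vector)
//   VECs_PER_THREAD = max(1, 8 / (4 * 32)) = 1
//   VPT             = 1 * 4 = 4              -> 4 logits live in registers per thread
//   THREADS_PER_ROW = 8 / 4 = 2              -> 2 threads cooperate on one token's row
//   ROWS_PER_WARP   = 32 / 2 = 16            -> a single warp covers 16 token rows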

template <int EXPERTS, int WARPS_PER_TB>
void topkGatingSoftmaxLauncherHelper(const float* input, const bool* finished, float* output, int* indices,
    int* source_row, const int num_rows, const int k, const int start_expert, const int end_expert, cudaStream_t stream)
{
    static constexpr std::size_t MAX_BYTES_PER_LDG = 16;

    static constexpr int BYTES_PER_LDG = std::min(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS);
    using Constants = detail::TopkConstants<EXPERTS, BYTES_PER_LDG>;
    static constexpr int VPT = Constants::VPT;
    static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP;
    const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP;
    const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

    dim3 block_dim(WARP_SIZE, WARPS_PER_TB);
    topkGatingSoftmax<VPT, EXPERTS, WARPS_PER_TB, BYTES_PER_LDG><<<num_blocks, block_dim, 0, stream>>>(
        input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert);
}

#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB)                       \
    topkGatingSoftmaxLauncherHelper<NUM_EXPERTS, WARPS_PER_TB>(         \
        gating_output, nullptr, topk_weights, topk_indicies,            \
        token_expert_indices, num_tokens, topk, 0, num_experts,         \
        stream);

void topkGatingSoftmaxKernelLauncher(
    const float* gating_output,
    float* topk_weights,
    int* topk_indicies,
    int* token_expert_indices,
    float* softmax_workspace,
    const int num_tokens,
    const int num_experts,
    const int topk,
    cudaStream_t stream) {
    static constexpr int WARPS_PER_TB = 4;
    switch (num_experts) {
        case 1:
            LAUNCH_SOFTMAX(1, WARPS_PER_TB);
            break;
        case 2:
            LAUNCH_SOFTMAX(2, WARPS_PER_TB);
            break;
        case 4:
            LAUNCH_SOFTMAX(4, WARPS_PER_TB);
            break;
        case 8:
            LAUNCH_SOFTMAX(8, WARPS_PER_TB);
            break;
        case 16:
            LAUNCH_SOFTMAX(16, WARPS_PER_TB);
            break;
        case 32:
            LAUNCH_SOFTMAX(32, WARPS_PER_TB);
            break;
        case 64:
            LAUNCH_SOFTMAX(64, WARPS_PER_TB);
            break;
        case 128:
            LAUNCH_SOFTMAX(128, WARPS_PER_TB);
            break;
        case 256:
            LAUNCH_SOFTMAX(256, WARPS_PER_TB);
            break;
        default: {
            TORCH_CHECK(softmax_workspace != nullptr,
                "softmax_workspace must be provided for num_experts that are not a power of 2.");
            static constexpr int TPB = 256;
            moeSoftmax<TPB><<<num_tokens, TPB, 0, stream>>>(
                gating_output, nullptr, softmax_workspace, num_experts);
            moeTopK<TPB><<<num_tokens, TPB, 0, stream>>>(
                softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices,
                num_experts, topk, 0, num_experts);
        }
    }
}

} // namespace moe
} // namespace vllm

void topk_softmax(
    torch::Tensor& topk_weights,                // [num_tokens, topk]
    torch::Tensor& topk_indices,                // [num_tokens, topk]
    torch::Tensor& token_expert_indices,        // [num_tokens, topk]
    torch::Tensor& gating_output)               // [num_tokens, num_experts]
{
    const int num_experts = gating_output.size(-1);
    const int num_tokens = gating_output.numel() / num_experts;
    const int topk = topk_weights.size(-1);

    const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
    const bool needs_workspace = !is_pow_2 || num_experts > 256;
    const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0;

    const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options());
    vllm::moe::topkGatingSoftmaxKernelLauncher(
        gating_output.data_ptr<float>(),
        topk_weights.data_ptr<float>(),
        topk_indices.data_ptr<int>(),
        token_expert_indices.data_ptr<int>(),
        softmax_workspace.data_ptr<float>(),
        num_tokens,
        num_experts,
        topk,
        stream);
}
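
For orientation, a minimal host-side sketch of how this entry point is meant to be fed (illustrative only and not part of this diff; tensor names are made up, but shapes and dtypes follow the comments on topk_softmax above):

// Hypothetical usage sketch.
// Gating logits: float32, [num_tokens, num_experts]; outputs: [num_tokens, topk].
torch::Tensor gating_output = torch::randn({num_tokens, num_experts},
    torch::dtype(torch::kFloat32).device(torch::kCUDA));
torch::Tensor topk_weights = torch::empty({num_tokens, topk}, gating_output.options());
torch::Tensor topk_indices = torch::empty({num_tokens, topk},
    gating_output.options().dtype(torch::kInt32));
torch::Tensor token_expert_indices = torch::empty_like(topk_indices);
topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output);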

csrc/moe_align_block_size_kernels.cu (new file, 108 lines)
@@ -0,0 +1,108 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include <ATen/ATen.h>
#include <THC/THCAtomics.cuh>

#include "cuda_compat.h"
#include "dispatch_utils.h"

const static size_t NUM_MAX_EXPERTS = 64;
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))

namespace vllm {
template <typename scalar_t>
__global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
                                int32_t *sorted_token_ids,
                                int32_t *expert_ids,
                                int32_t *total_tokens_post_pad,
                                int32_t num_experts,
                                int32_t block_size,
                                size_t numel) {
    const size_t tokens_per_thread = CEILDIV(numel, blockDim.x);
    const size_t start_idx = threadIdx.x * tokens_per_thread;
    __shared__ int32_t tokens_cnts[NUM_MAX_EXPERTS + 1][NUM_MAX_EXPERTS];
    __shared__ int32_t cumsum[NUM_MAX_EXPERTS + 1];
    for (int i = 0; i < num_experts; ++i) {
        tokens_cnts[threadIdx.x + 1][i] = 0;
    }

    /**
    * In the first step we compute tokens_cnts[thread_index + 1][expert_index],
    * which counts how many tokens in the token shard of thread_index are assigned
    * to expert expert_index.
    */
    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
        ++tokens_cnts[threadIdx.x + 1][topk_ids[i]];
    }

    __syncthreads();

    // For each expert we accumulate the token counts from the different threads.
    tokens_cnts[0][threadIdx.x] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
        tokens_cnts[i][threadIdx.x] += tokens_cnts[i-1][threadIdx.x];
    }

    __syncthreads();

    // We accumulate the token counts of all experts in thread 0.
    if (threadIdx.x == 0) {
        cumsum[0] = 0;
        for (int i = 1; i <= num_experts; ++i) {
            cumsum[i] = cumsum[i-1] + CEILDIV(tokens_cnts[blockDim.x][i - 1], block_size) * block_size;
        }
        *total_tokens_post_pad = cumsum[num_experts];
    }

    __syncthreads();

    /**
    * For each expert, each thread processes the tokens of the corresponding blocks
    * and stores the corresponding expert_id for each block.
    */
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i += block_size) {
        expert_ids[i / block_size] = threadIdx.x;
    }

    /**
    * Each thread processes a token shard, calculating the index of each token after
    * sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
    * block_size = 4, the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
    * where * represents a padding value (preset in Python).
    */
    for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
        int32_t expert_id = topk_ids[i];
        /** cumsum[expert_id] stores the starting index of the tokens that the
        * expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
        * counts how many tokens assigned to expert_id have already been placed from
        * the current thread's token shard.
        */
        int32_t rank_post_pad = tokens_cnts[threadIdx.x][expert_id] + cumsum[expert_id];
        sorted_token_ids[rank_post_pad] = i;
        ++tokens_cnts[threadIdx.x][expert_id];
    }
}
}

void moe_align_block_size(
    torch::Tensor topk_ids,
    int num_experts,
    int block_size,
    torch::Tensor sorted_token_ids,
    torch::Tensor experts_ids,
    torch::Tensor num_tokens_post_pad) {
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    assert(num_experts <= NUM_MAX_EXPERTS);
    VLLM_DISPATCH_INTEGRAL_TYPES(
        topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
        vllm::moe_align_block_size_kernel<scalar_t><<<1, num_experts, 0, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            sorted_token_ids.data_ptr<int32_t>(),
            experts_ids.data_ptr<int32_t>(),
            num_tokens_post_pad.data_ptr<int32_t>(),
            num_experts,
            block_size,
            topk_ids.numel());
    });
}
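
As a hedged, worked continuation of the example in the kernel comment above (added here, not in the original file): with topk_ids = [0,1,2,1,2,3,0,3,4], five experts in use and block_size = 4, the per-expert token counts are [2, 2, 2, 2, 1]. Each count is rounded up to a multiple of block_size, so cumsum becomes [0, 4, 8, 12, 16, 20] and num_tokens_post_pad is 20. expert_ids then holds one entry per block, [0, 1, 2, 3, 4], and sorted_token_ids is exactly the padded layout shown in the comment.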

csrc/ops.h (63 changed lines)
@@ -1,3 +1,5 @@
#pragma once

#include <torch/extension.h>

void paged_attention_v1(
@@ -5,13 +7,14 @@ void paged_attention_v1(
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype);

void paged_attention_v2(
  torch::Tensor& out,
@@ -21,13 +24,14 @@ void paged_attention_v2(
  torch::Tensor& query,
  torch::Tensor& key_cache,
  torch::Tensor& value_cache,
  torch::Tensor& head_mapping,
  int num_kv_heads,
  float scale,
  torch::Tensor& block_tables,
  torch::Tensor& context_lens,
  int block_size,
  int max_context_len,
  const c10::optional<torch::Tensor>& alibi_slopes);
  const c10::optional<torch::Tensor>& alibi_slopes,
  const std::string& kv_cache_dtype);

void rms_norm(
  torch::Tensor& out,
@@ -61,6 +65,7 @@ void gelu_fast(
  torch::Tensor& out,
  torch::Tensor& input);

#ifndef USE_ROCM
torch::Tensor awq_gemm(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
@@ -68,8 +73,58 @@ torch::Tensor awq_gemm(
  torch::Tensor _zeros,
  int split_k_iters);

torch::Tensor awq_dequantize(
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int split_k_iters,
    int thx,
    int thy);
#endif

void squeezellm_gemm(
  torch::Tensor vec,
  torch::Tensor mat,
  torch::Tensor mul,
  torch::Tensor lookup_table);

torch::Tensor gptq_gemm(
  torch::Tensor a,
  torch::Tensor b_q_weight,
  torch::Tensor b_gptq_qzeros,
  torch::Tensor b_gptq_scales,
  torch::Tensor b_g_idx,
  bool use_exllama);

void gptq_shuffle(
  torch::Tensor q_weight,
  torch::Tensor q_perm);

void moe_align_block_size(
  torch::Tensor topk_ids,
  int num_experts,
  int block_size,
  torch::Tensor sorted_token_ids,
  torch::Tensor experts_ids,
  torch::Tensor num_tokens_post_pad);

#ifndef USE_ROCM
using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
                    const std::vector<std::string> &handles,
                    const std::vector<int64_t> &offsets, int rank,
                    bool full_nvlink);
bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
                      bool full_nvlink);
void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out);
void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &reg_buffer,
                      torch::Tensor &out);
void dispose(fptr_t _fa);
int meta_size();
void register_buffer(fptr_t _fa, torch::Tensor &t,
                     const std::vector<std::string> &handles,
                     const std::vector<int64_t> &offsets);
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
                            const std::vector<std::vector<int64_t>> &offsets);
#endif

@@ -1,6 +1,8 @@
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include "cuda_compat.h"
#include "dispatch_utils.h"

namespace vllm {
@@ -19,14 +21,14 @@ inline __device__ void apply_rotary_embedding(
    // GPT-NeoX style rotary embedding.
    x_index = rot_offset;
    y_index = embed_dim + rot_offset;
    cos = __ldg(cos_ptr + x_index);
    sin = __ldg(sin_ptr + x_index);
    cos = VLLM_LDG(cos_ptr + x_index);
    sin = VLLM_LDG(sin_ptr + x_index);
  } else {
    // GPT-J style rotary embedding.
    x_index = 2 * rot_offset;
    y_index = 2 * rot_offset + 1;
    cos = __ldg(cos_ptr + x_index / 2);
    sin = __ldg(sin_ptr + x_index / 2);
    cos = VLLM_LDG(cos_ptr + x_index / 2);
    sin = VLLM_LDG(sin_ptr + x_index / 2);
  }

  const scalar_t x = arr[x_index];
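
The __ldg-to-VLLM_LDG change in the hunk above routes read-only loads through a portability macro from "cuda_compat.h". That header is not part of this excerpt; a plausible sketch of such a macro (an assumption for illustration, not the verbatim vLLM definition) is:

// Hypothetical compatibility shim: __ldg is CUDA-only, so fall back to a plain
// dereference when building for ROCm.
#ifndef USE_ROCM
  #define VLLM_LDG(arg) __ldg(arg)
#else
  #define VLLM_LDG(arg) *(arg)
#endif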
@@ -42,8 +44,8 @@ __global__ void rotary_embedding_kernel(
  scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
  const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
  const int rot_dim,
  const int query_stride,
  const int key_stride,
  const int64_t query_stride,
  const int64_t key_stride,
  const int num_heads,
  const int num_kv_heads,
  const int head_size) {
@@ -59,7 +61,7 @@ __global__ void rotary_embedding_kernel(
  const int nq = num_heads * embed_dim;
  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * query_stride + head_idx * head_size;
    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
@@ -68,7 +70,7 @@ __global__ void rotary_embedding_kernel(
  const int nk = num_kv_heads * embed_dim;
  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
    const int head_idx = i / embed_dim;
    const int token_head = token_idx * key_stride + head_idx * head_size;
    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
    const int rot_offset = i % embed_dim;
    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                              sin_ptr, rot_offset, embed_dim);
@@ -88,11 +90,12 @@ void rotary_embedding(
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;
  int query_stride = query.stride(-2);
  int key_stride = key.stride(-2);
  int64_t query_stride = query.stride(-2);
  int64_t key_stride = key.stride(-2);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * rot_dim / 2, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(
    query.scalar_type(),

csrc/punica/LICENSE (new file, 217 lines)
@@ -0,0 +1,217 @@
Contains code from https://github.com/punica-ai/punica

                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "{}"
      replaced with your own identifying information. (Don't include
      the brackets!)  The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright {yyyy} {name of copyright owner}

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.

------------------------------------------------------------------------------------

This product bundles various third-party components under other open source licenses.
This section summarizes those components and their licenses. See licenses/
for text of these licenses.


Apache-2.0
* third_party/nvbench (with LLVM exception)
* third_party/flashinfer

BSD-3-Clause:
* third_party/cutlass

csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu (new file, 4 lines)
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_bf16_fp16.cu (new file, 4 lines)
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_half)

csrc/punica/bgmv/bgmv_bf16_fp16_bf16.cu (new file, 4 lines)
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_fp16_fp16.cu (new file, 4 lines)
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_half, nv_half)

csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu (new file, 4 lines)
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu (new file, 4 lines)
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)

csrc/punica/bgmv/bgmv_config.h (new file, 59 lines)
@@ -0,0 +1,59 @@
#pragma once

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale);

// clang-format off

#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
    f(in_T, out_T, W_T, narrow, 128) \
    f(in_T, out_T, W_T, narrow, 256) \
    f(in_T, out_T, W_T, narrow, 512) \
    f(in_T, out_T, W_T, narrow, 1024) \
    f(in_T, out_T, W_T, narrow, 1280) \
    f(in_T, out_T, W_T, narrow, 1728) \
    f(in_T, out_T, W_T, narrow, 1792) \
    f(in_T, out_T, W_T, narrow, 2048) \
    f(in_T, out_T, W_T, narrow, 2560) \
    f(in_T, out_T, W_T, narrow, 2752) \
    f(in_T, out_T, W_T, narrow, 3072) \
    f(in_T, out_T, W_T, narrow, 3456) \
    f(in_T, out_T, W_T, narrow, 3584) \
    f(in_T, out_T, W_T, narrow, 4096) \
    f(in_T, out_T, W_T, narrow, 5120) \
    f(in_T, out_T, W_T, narrow, 5504) \
    f(in_T, out_T, W_T, narrow, 5632) \
    f(in_T, out_T, W_T, narrow, 6912) \
    f(in_T, out_T, W_T, narrow, 7168) \
    f(in_T, out_T, W_T, narrow, 8192) \
    f(in_T, out_T, W_T, narrow, 9216) \
    f(in_T, out_T, W_T, narrow, 10240) \
    f(in_T, out_T, W_T, narrow, 11008) \
    f(in_T, out_T, W_T, narrow, 12288) \
    f(in_T, out_T, W_T, narrow, 13824) \
    f(in_T, out_T, W_T, narrow, 14336) \
    f(in_T, out_T, W_T, narrow, 16384) \
    f(in_T, out_T, W_T, narrow, 20480) \
    f(in_T, out_T, W_T, narrow, 28672) \
    f(in_T, out_T, W_T, narrow, 32000) \
    f(in_T, out_T, W_T, narrow, 32256) \
    f(in_T, out_T, W_T, narrow, 32512) \
    f(in_T, out_T, W_T, narrow, 32768) \
    f(in_T, out_T, W_T, narrow, 33024) \
    f(in_T, out_T, W_T, narrow, 36864) \
    f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA

// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8)  \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
    FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)

// clang-format on
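
To make the macro machinery concrete: FOR_BGMV_WIDE_NARROW applies its argument f once per (narrow, wide) pair, with narrow in {8, 16, 32, 64} and wide ranging over the 36 feat_out sizes listed above, i.e. 144 expansions per call. A sketch of the expansions for one of the instantiation files follows (INST_BGMV_TWOSIDE itself is defined in bgmv_impl.cuh, which is only partially shown below, so treat this as illustrative):

// FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) expands to
//   INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 8, 128)
//   INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 8, 256)
//   ...
//   INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 64, 49152)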
 | 
			
		||||
							
								
								
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
			
		||||
#include "bgmv_config.h"
 | 
			
		||||
#include "bgmv_impl.cuh"
 | 
			
		||||
 | 
			
		||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
 | 
			
		||||
							
								
								
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
			
		||||
#include "bgmv_config.h"
 | 
			
		||||
#include "bgmv_impl.cuh"
 | 
			
		||||
 | 
			
		||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
 | 
			
		||||
							
								
								
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
			
		||||
#include "bgmv_config.h"
 | 
			
		||||
#include "bgmv_impl.cuh"
 | 
			
		||||
 | 
			
		||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
 | 
			
		||||
							
								
								
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu
									
									
									
									
									
										Normal file
									
								
							@ -0,0 +1,4 @@
 | 
			
		||||
#include "bgmv_config.h"
 | 
			
		||||
#include "bgmv_impl.cuh"
 | 
			
		||||
 | 
			
		||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
 | 
			
		||||
							
								
								
									
4  csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

4  csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)

4  csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)

4  csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)

4  csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)

4  csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)

4  csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)

4  csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu  Normal file
@@ -0,0 +1,4 @@
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
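Each of the bgmv_*.cu stubs above (and the nv_half/nv_half/nv_half one just before them) is a four-line translation unit that instantiates the templated BGMV kernels for a single (input dtype, output dtype, weight dtype) combination encoded in its filename; they are emitted by csrc/punica/bgmv/generator.py, which appears later in this diff.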
							
								
								
									
294  csrc/punica/bgmv/bgmv_impl.cuh  Normal file
@@ -0,0 +1,294 @@
#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}

// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;

  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
  }
}

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }
}

#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)                         \
  template void bgmv_kernel<feat_in, feat_out>(                                \
      out_T * __restrict__ Y, const in_T *__restrict__ X,                      \
      const W_T *__restrict__ W, const int64_t *__restrict__ indicies,         \
      int64_t y_offset, int64_t full_y_size, int64_t batch_size,               \
      int64_t num_layers, int64_t layer_idx, float scale);

#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide)                      \
  INST_BGMV(narrow, wide, in_T, out_T, W_T)                                    \
  INST_BGMV(wide, narrow, in_T, out_T, W_T)
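bgmv_kernel() above selects between the expand and shrink kernels and derives the whole launch configuration from the compile-time pair (feat_in, feat_out). The following host-only sketch (illustrative, not part of this commit) replays that selection for a hypothetical rank-16 LoRA against a 4096-wide hidden dimension, modelling only the first tile choice of each branch and reusing the same constants vec_size = 8 and tz = 4:

// Sketch only: mirrors the tile selection in bgmv_kernel(), not the real launcher.
#include <cstdio>
#include <utility>

struct LaunchCfg {
  int grid_x, grid_y;  // dim3 nblks
  int bx, by, bz;      // dim3 nthrs
  const char *kernel;
};

LaunchCfg pick_bgmv_cfg(int feat_in, int feat_out, int batch_size) {
  const int vec_size = 8, tz = 4;
  if (feat_in < feat_out) {
    // "expand" direction (rank -> hidden): first branch, 32 lanes split as tx * ty
    int tx = feat_in / vec_size;  // 16 / 8 = 2
    int ty = 32 / tx;             // 32 / 2 = 16
    return {feat_out / (ty * tz), batch_size, tx, ty, tz, "bgmv_expand_kernel"};
  }
  // "shrink" direction (hidden -> rank): first branch, one block column per output feature
  return {feat_out, batch_size, 32, 4, 1, "bgmv_shrink_kernel"};
}

int main() {
  for (auto [fi, fo] : {std::pair{16, 4096}, std::pair{4096, 16}}) {
    LaunchCfg c = pick_bgmv_cfg(fi, fo, /*batch_size=*/8);
    std::printf("%4d -> %4d : %-18s grid=(%d, %d)  block=(%d, %d, %d)\n",
                fi, fo, c.kernel, c.grid_x, c.grid_y, c.bx, c.by, c.bz);
  }
  return 0;
}

For 16 -> 4096 this reproduces the nthrs = (2, 16, 4) noted above bgmv_expand_kernel, and for 4096 -> 16 the nthrs = (32, 4) noted above bgmv_shrink_kernel (tz remains a template parameter on the shrink side, but the block itself is two-dimensional).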
							
								
								
									
27  csrc/punica/bgmv/generator.py  Normal file
@@ -0,0 +1,27 @@
DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
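Presumably run from csrc/punica/bgmv/ (the script writes into its working directory), this generator emits one stub per dtype combination: three input dtypes times three output dtypes times two weight dtypes (fp32 weights are skipped) gives 18 bgmv_*.cu files, matching the four-line stubs added elsewhere in this diff, each rendered from TEMPLATE as exactly the two includes plus one FOR_BGMV_WIDE_NARROW line.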
							
								
								
									
1324  csrc/punica/bgmv/vec_dtypes.cuh  Normal file
(File diff suppressed because it is too large)
563  csrc/punica/punica_ops.cc  Normal file
@@ -0,0 +1,563 @@
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>

#include <cstdint>

#include "bgmv/bgmv_config.h"

namespace {

//====== utils ======

inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
                        const char *a_name, const char *b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
              a.dim(), " vs ", b.dim());
  for (int i = 0; i < a.dim(); ++i) {
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
                ".size(", i, ")");
  }
}

inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")

#define CHECK_CONTIGUOUS(x)                                                    \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define CHECK_INPUT(x)                                                         \
  CHECK_CUDA(x);                                                               \
  CHECK_CONTIGUOUS(x)

#define CHECK_DIM(d, x)                                                        \
  TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")

#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)

#define CHECK_EQ(a, b)                                                         \
  TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)

//====== bgmv ======

template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
                               const int64_t *lora_indices,
                               uint16_t in_features, uint16_t out_features,
                               int64_t y_offset, int64_t full_y_size,
                               int64_t batch_size, int64_t num_layers,
                               int64_t layer_idx, float scale) {
  switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out)                   \
  case pack_u16(feat_in, feat_out):                                            \
    bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset,            \
                                   full_y_size, batch_size, num_layers,        \
                                   layer_idx, scale);                          \
    break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide)                                \
  CASE_ONESIDE(in_T, out_T, W_T, narrow, wide)                                 \
  CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)

    FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE
  default:
    return false;
  }

  return true;
}

void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t h_in = x.size(1);
  int64_t h_out = y.size(1);
  int64_t num_layers = w.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
    case at::ScalarType::Half:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::BFloat16:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::Float:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out, 0,
                                  h_out, B, num_layers, layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    default:
      break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t num_layers = w.size(1);
  int64_t full_y_size = y.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
    case at::ScalarType::Half:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_half *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::BFloat16:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<nv_bfloat16 *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    case at::ScalarType::Float:
      switch (y.scalar_type()) {
      case at::ScalarType::Half:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      case at::ScalarType::Float:
        switch (w.scalar_type()) {
        case at::ScalarType::Half:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_half *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        case at::ScalarType::BFloat16:
          ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
                                  static_cast<float *>(x.data_ptr()),
                                  static_cast<nv_bfloat16 *>(w.data_ptr()),
                                  indicies.data_ptr<int64_t>(), h_in, h_out,
                                  y_offset, full_y_size, B, num_layers,
                                  layer_idx, scale);
          break;
        default:
          break;
        }
        break;
      default:
        break;
      }
      break;
    default:
      break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}

} // namespace

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
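launch_bgmv_kernel() above avoids a two-level dispatch on the feature sizes by folding (in_features, out_features) into a single 32-bit key with pack_u16 and switching on it; the case labels are generated by FOR_BGMV_WIDE_NARROW, and the h_in < 65536 && h_out < 65536 guard in both dispatchers is what keeps the packing lossless. A small standalone sketch of the same trick (illustrative, not part of this commit; the 16/4096 shape pair is only an example):

// Sketch only: demonstrates the pack_u16 switch-dispatch idea, not the real dispatcher.
#include <cstdint>
#include <cstdio>

constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}

// Hypothetical shape pair used only for illustration.
const char *pick_kernel(uint16_t in_features, uint16_t out_features) {
  switch (pack_u16(in_features, out_features)) {
  case pack_u16(16, 4096):
    return "bgmv_kernel<16, 4096>";  // "expand" side of a rank-16 LoRA
  case pack_u16(4096, 16):
    return "bgmv_kernel<4096, 16>";  // matching "shrink" side
  default:
    return "unsupported shape (launch_bgmv_kernel would return false)";
  }
}

int main() {
  std::printf("0x%08x -> %s\n", pack_u16(16, 4096), pick_kernel(16, 4096));
  std::printf("0x%08x -> %s\n", pack_u16(4096, 16), pick_kernel(4096, 16));
  std::printf("0x%08x -> %s\n", pack_u16(7, 7), pick_kernel(7, 7));
  return 0;
}

The two exported entry points differ only in how they address the output: dispatch_bgmv always writes a full row (it passes y_offset = 0 and full_y_size = h_out), while dispatch_bgmv_low_level lets the caller accumulate into a slice of a wider row through the explicit h_in, h_out and y_offset arguments.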
@@ -49,8 +49,17 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Quantization ops
#ifndef USE_ROCM
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def(
    "moe_align_block_size",
    &moe_align_block_size,
    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
@@ -70,6 +79,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "gather_cached_kv",
    &gather_cached_kv,
    "Gather key and value from the cache into contiguous QKV tensors");
  cache_ops.def(
    "convert_fp8_e5m2",
    &convert_fp8_e5m2,
    "Convert the key and value cache to fp8_e5m2 data type");

  // Cuda utils
  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
@@ -77,4 +90,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "get_device_attribute",
    &get_device_attribute,
    "Gets the specified device attribute.");

  cuda_utils.def(
    "get_max_shared_memory_per_block_device_attribute",
    &get_max_shared_memory_per_block_device_attribute,
    "Gets the maximum shared memory per block device attribute.");

#ifndef USE_ROCM
  // Custom all-reduce kernels
  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
  custom_ar.def("dispose", &dispose, "dispose");
  custom_ar.def("meta_size", &meta_size, "meta_size");
  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
                "get_graph_buffer_ipc_meta");
  custom_ar.def("register_graph_buffers", &register_graph_buffers,
                "register_graph_buffers");
#endif

}
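Taken together, these three hunks extend the main extension module: the quantization ops gain a moe_align_block_size entry point (with the AWQ registrations sitting behind an #ifndef USE_ROCM guard), cache_ops gains convert_fp8_e5m2 for converting the KV cache to fp8_e5m2, cuda_utils gains get_max_shared_memory_per_block_device_attribute, and a new CUDA-only custom_ar submodule registers the custom all-reduce kernels.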
@@ -27,35 +27,48 @@ __pack_half2(const half x, const half y) {
  return (v1 << 16) | v0;
}

-__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C)
+template<int N>
+__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
+  int G,
+  int split_k_iters,
+  half* __restrict__ A,
+  int* __restrict__ B,
+  half* __restrict__ scaling_factors,
+  int* __restrict__ zeros,
+  int M,
+  int IC,
+  int OC,
+  half* __restrict__ C)
{
+  // Only support matrix n = 64 or 128
+  assert(N == 64 || N == 128);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
  assert(false);
#else
  static constexpr uint32_t ZERO = 0x0;
  float C_warp[32];
  __shared__ half A_shared[16 * (32 + 8)];
-  __shared__ half B_shared[32 * (128 + 8)];
+  __shared__ half B_shared[32 * (N + 8)];

-  __shared__ half scaling_factors_shared[128];
-  __shared__ half zeros_shared[128];
+  __shared__ half scaling_factors_shared[N];
+  __shared__ half zeros_shared[N];

-  int j_factors1 = ((OC + 128 - 1) / 128);
+  int j_factors1 = ((OC + N - 1) / N);
  int blockIdx_x = 0;
  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);

  half A_shared_warp[8];
-  half B_shared_warp[32];
-  for (int j_0_4_init = 0; j_0_4_init < 4; ++j_0_4_init) {
+  half B_shared_warp[N / 4];
+  for (int j_0_4_init = 0; j_0_4_init < N / 32; ++j_0_4_init) {
    for (int i = 0; i < 8; ++i) {
      C_warp[(j_0_4_init * 8) + i] = 0.0;
    }
  }

  static constexpr int row_stride_warp = 32 * 8 / 32;
-  static constexpr int row_stride = 2 * 32 * 8 / 128;
-  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 128;
+  static constexpr int row_stride = 2 * 32 * 8 / N;
+  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
  // bool wb_C_flag = (threadIdx.x / 4) < M;
@@ -65,10 +78,10 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 | 
			
		||||
                + (((int)threadIdx.x) % (32 / 8)) * 8;
 | 
			
		||||
 | 
			
		||||
  int* B_ptr = B
 | 
			
		||||
            + ((int)threadIdx.y) * (OC / 8) * 2
 | 
			
		||||
            + (((int)threadIdx.x) / (128 / 8)) * (OC / 8)
 | 
			
		||||
            + (((int)blockIdx_y) % j_factors1) * (128 / 8)
 | 
			
		||||
            + (((int)threadIdx.x) % (128 / 8)) * 1;
 | 
			
		||||
            + ((int)threadIdx.y) * (OC / 8) * (256 / N)
 | 
			
		||||
            + (((int)threadIdx.x) / (N / 8)) * (OC / 8)
 | 
			
		||||
            + (((int)blockIdx_y) % j_factors1) * (N / 8)
 | 
			
		||||
            + (((int)threadIdx.x) % (N / 8)) * 1;
 | 
			
		||||
// Why * 1 in the above line?
 | 
			
		||||
 | 
			
		||||
  half* A_shared_ptr = A_shared
 | 
			
		||||
@@ -77,22 +90,22 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 | 
			
		||||
                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
 | 
			
		||||
 | 
			
		||||
  half* B_shared_ptr = B_shared
 | 
			
		||||
                    + ((int)threadIdx.y) * (row_stride / 2) * (128 + 8)
 | 
			
		||||
                    + (((int)threadIdx.x) / (128 / 8)) * (128 + 8)
 | 
			
		||||
                    + (((int)threadIdx.x) % (128 / 8)) * 8;
 | 
			
		||||
                    + ((int)threadIdx.y) * (row_stride / 2) * (N + 8)
 | 
			
		||||
                    + (((int)threadIdx.x) / (N / 8)) * (N + 8)
 | 
			
		||||
                    + (((int)threadIdx.x) % (N / 8)) * 8;
 | 
			
		||||
 | 
			
		||||
  int* zeros_ptr = zeros
 | 
			
		||||
                + (((int)blockIdx_y) % j_factors1) * (128 / 8)
 | 
			
		||||
                + ((int)threadIdx.x) % (128 / 8);
 | 
			
		||||
                + (((int)blockIdx_y) % j_factors1) * (N / 8)
 | 
			
		||||
                + ((int)threadIdx.x) % (N / 8);
 | 
			
		||||
 | 
			
		||||
  half* scaling_factors_ptr = scaling_factors
 | 
			
		||||
                            + (((int)blockIdx_y) % j_factors1) * (128) 
 | 
			
		||||
                            + (((int)threadIdx.x) % (128 / 8)) * 8;
 | 
			
		||||
                            + (((int)blockIdx_y) % j_factors1) * N
 | 
			
		||||
                            + (((int)threadIdx.x) % (N / 8)) * 8;
 | 
			
		||||
 | 
			
		||||
  half* C_ptr = C
 | 
			
		||||
              + static_cast<long long>(blockIdx_z) * M * OC        // blockIdz.x -> split_k dim
 | 
			
		||||
              + (((int)blockIdx_y) % j_factors1) * 128
 | 
			
		||||
              + ((int)threadIdx.y) * 64
 | 
			
		||||
              + (((int)blockIdx_y) % j_factors1) * N
 | 
			
		||||
              + ((int)threadIdx.y) * (N / 2)
 | 
			
		||||
              + (((int)threadIdx.x) % 4) * 2;
 | 
			
		||||
 | 
			
		||||
  // preload s.f. and zeros
 | 
			
		||||
@@ -123,7 +136,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 | 
			
		||||
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
 | 
			
		||||
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
 | 
			
		||||
 | 
			
		||||
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 8; ++ax0_ax1_fused_0) {
 | 
			
		||||
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < N / 16; ++ax0_ax1_fused_0) {
 | 
			
		||||
 | 
			
		||||
      // B: 32 x 136 (128+8) float16
 | 
			
		||||
      // each warp: 32 x 4
 | 
			
		||||
@@ -152,7 +165,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 | 
			
		||||
      */
 | 
			
		||||
 | 
			
		||||
      // write back
 | 
			
		||||
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (128 + 8)) = B_loaded_fp16;
 | 
			
		||||
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (N + 8)) = B_loaded_fp16;
 | 
			
		||||
    }
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
 | 
			
		||||
@@ -174,13 +187,13 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 | 
			
		||||
        );
 | 
			
		||||
      }
 | 
			
		||||
 | 
			
		||||
      for (int ax1_0 = 0; ax1_0 < 4; ++ax1_0) {
 | 
			
		||||
      for (int ax1_0 = 0; ax1_0 < N / 32; ++ax1_0) {
 | 
			
		||||
        {
 | 
			
		||||
          unsigned int addr;
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
 | 
			
		||||
            : "=r"(addr)
 | 
			
		||||
            : "l"((void *)((&(B_shared[(((k_0_1 * 2176) + (((int)threadIdx.y) * 64)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 136) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
			
		||||
            : "l"((void *)((&(B_shared[(((k_0_1 * (N * 16 + 128)) + (((int)threadIdx.y) * (N / 2))) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * (N + 8)) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
			
		||||
          );
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
 | 
			
		||||
@@ -190,7 +203,7 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 | 
			
		||||
          );
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      for (int j_0_4 = 0; j_0_4 < 4; ++j_0_4) {
 | 
			
		||||
      for (int j_0_4 = 0; j_0_4 < N / 32; ++j_0_4) {
 | 
			
		||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
 | 
			
		||||
        {
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
@@ -258,118 +271,45 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n128k32(int G, i
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
__global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, int split_k_iters, half* __restrict__ A, int* __restrict__ B, half* __restrict__ scaling_factors, int* __restrict__ zeros, int M, int IC, int OC, half* __restrict__ C) 
 | 
			
		||||
__global__ void __launch_bounds__(64) dequantize_weights(
 | 
			
		||||
    int* __restrict__ B,
 | 
			
		||||
    half* __restrict__ scaling_factors,
 | 
			
		||||
    int* __restrict__ zeros,
 | 
			
		||||
    half* __restrict__ C,
 | 
			
		||||
    int G
 | 
			
		||||
)
 | 
			
		||||
{
 | 
			
		||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
 | 
			
		||||
  assert(false);
 | 
			
		||||
#else
 | 
			
		||||
  int j_factors1 = 4;
 | 
			
		||||
  int row_stride2 = 4;
 | 
			
		||||
  int split_k_iters = 1;
 | 
			
		||||
  static constexpr uint32_t ZERO = 0x0;
 | 
			
		||||
  float C_warp[32];
 | 
			
		||||
  __shared__ half A_shared[16 * (32 + 8)];
 | 
			
		||||
  __shared__ half B_shared[32 * (64 + 8)];
 | 
			
		||||
  half B_shared[32 * (128 + 8)];
 | 
			
		||||
 | 
			
		||||
  __shared__ half scaling_factors_shared[64];
 | 
			
		||||
  __shared__ half zeros_shared[64];
 | 
			
		||||
  half* B_shared_ptr2 = B_shared;
 | 
			
		||||
 | 
			
		||||
  int j_factors1 = ((OC + 64 - 1) / 64);
 | 
			
		||||
  half B_shared_warp[32];
 | 
			
		||||
  int OC = 512;
 | 
			
		||||
 | 
			
		||||
  int blockIdx_x = 0;
 | 
			
		||||
  int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
 | 
			
		||||
  int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
 | 
			
		||||
  int N = blockDim.x * gridDim.x;  // 2
 | 
			
		||||
  int col = (blockIdx.x * blockDim.x + threadIdx.x);
 | 
			
		||||
  int row = blockIdx.y * blockDim.y + threadIdx.y;
 | 
			
		||||
  int index1 = 8 * col + 8 * row * N;
 | 
			
		||||
  half* C_ptr2 = C + index1;
 | 
			
		||||
 | 
			
		||||
  half A_shared_warp[8];
 | 
			
		||||
  half B_shared_warp[16];
 | 
			
		||||
  for (int j_0_4_init = 0; j_0_4_init < 2; ++j_0_4_init) {
 | 
			
		||||
    for (int i = 0; i < 8; ++i) {
 | 
			
		||||
      C_warp[(j_0_4_init * 8) + i] = 0.0;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
  int index2 = col + row * N;
 | 
			
		||||
  int* B_ptr2 = B + index2;
 | 
			
		||||
 | 
			
		||||
  static constexpr int row_stride_warp = 32 * 8 / 32;
 | 
			
		||||
  static constexpr int row_stride = 2 * 32 * 8 / 64;
 | 
			
		||||
  bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < 64;
 | 
			
		||||
  // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
 | 
			
		||||
  bool ld_A_flag = (blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp + threadIdx.x * 8 / 32) < M;     // threadIdx.y is warp_id
 | 
			
		||||
  // bool wb_C_flag = (threadIdx.x / 4) < M;
 | 
			
		||||
  int index3 = col + (int)(row / G) * N;
 | 
			
		||||
  int* zeros_ptr2 = zeros + index3;
 | 
			
		||||
  int index4 = 8 * col + (int)(row / G) * N * 8;
 | 
			
		||||
  half* scaling_factors_ptr2 = scaling_factors + index4;
 | 
			
		||||
 | 
			
		||||
  half* A_ptr = A 
 | 
			
		||||
                + (((int)blockIdx_y) / j_factors1 * 16 + (((int)threadIdx.y) * row_stride_warp) + ((int)threadIdx.x) / (32 / 8)) * IC
 | 
			
		||||
                + (((int)threadIdx.x) % (32 / 8)) * 8;
 | 
			
		||||
  
 | 
			
		||||
  int* B_ptr = B
 | 
			
		||||
            + ((int)threadIdx.y) * (OC / 8) * 4
 | 
			
		||||
            + (((int)threadIdx.x) / (64 / 8)) * (OC / 8)
 | 
			
		||||
            + (((int)blockIdx_y) % j_factors1) * (64 / 8)
 | 
			
		||||
            + (((int)threadIdx.x) % (64 / 8)) * 1;
 | 
			
		||||
// Why * 1 in the above line?
 | 
			
		||||
                        
 | 
			
		||||
  half* A_shared_ptr = A_shared 
 | 
			
		||||
                    + ((int)threadIdx.y) * row_stride_warp * (32 + 8) 
 | 
			
		||||
                    + (((int)threadIdx.x) / (32 / 8)) * (32 + 8)
 | 
			
		||||
                    + (((int)threadIdx.x) % (32 / 8) ) * 8;
 | 
			
		||||
 | 
			
		||||
  half* B_shared_ptr = B_shared
 | 
			
		||||
                    + ((int)threadIdx.y) * (row_stride / 2) * (64 + 8)
 | 
			
		||||
                    + (((int)threadIdx.x) / (64 / 8)) * (64 + 8)
 | 
			
		||||
                    + (((int)threadIdx.x) % (64 / 8)) * 8;
 | 
			
		||||
  
 | 
			
		||||
  int* zeros_ptr = zeros
 | 
			
		||||
                + (((int)blockIdx_y) % j_factors1) * (64 / 8)
 | 
			
		||||
                + ((int)threadIdx.x) % (64 / 8);
 | 
			
		||||
  
 | 
			
		||||
  half* scaling_factors_ptr = scaling_factors
 | 
			
		||||
                            + (((int)blockIdx_y) % j_factors1) * (64) 
 | 
			
		||||
                            + (((int)threadIdx.x) % (64 / 8)) * 8;
 | 
			
		||||
 | 
			
		||||
  half* C_ptr = C 
 | 
			
		||||
              + static_cast<long long>(blockIdx_z) * M * OC        // blockIdz.x -> split_k dim
 | 
			
		||||
              + (((int)blockIdx_y) % j_factors1) * 64
 | 
			
		||||
              + ((int)threadIdx.y) * 32
 | 
			
		||||
              + (((int)threadIdx.x) % 4) * 2;
 | 
			
		||||
 | 
			
		||||
  // preload s.f. and zeros
 | 
			
		||||
  int k_bound = (IC / 32 + split_k_iters - 1) / split_k_iters;
 | 
			
		||||
  if ((k_bound - 1) * split_k_iters * 32 + blockIdx_z * 32 >= IC) k_bound -= 1;
 | 
			
		||||
  for (int _k_0_0 = 0; _k_0_0 < k_bound; ++_k_0_0) {
 | 
			
		||||
    int k_0_0 = _k_0_0 * split_k_iters + blockIdx_z;
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
    // TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
 | 
			
		||||
    if (ld_A_flag)
 | 
			
		||||
    {
 | 
			
		||||
      *(uint4*)(A_shared_ptr) = *(uint4*)(A_ptr + (k_0_0 * 32));
 | 
			
		||||
    }
 | 
			
		||||
    else
 | 
			
		||||
    {
 | 
			
		||||
      *(uint4*)(A_shared_ptr) = make_uint4(0, 0, 0, 0);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    // for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
 | 
			
		||||
    uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr + k_0_0 * 32 / G * (OC / 8));
 | 
			
		||||
  uint32_t zeros_loaded = *(uint32_t*)(zeros_ptr2);
 | 
			
		||||
  uint4 B_loaded_zero = dequantize_s4_to_fp16x2(zeros_loaded);
 | 
			
		||||
    uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr + k_0_0 * 32 / G * (OC));
 | 
			
		||||
    /*
 | 
			
		||||
    if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
 | 
			
		||||
      printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
 | 
			
		||||
    }
 | 
			
		||||
    */
 | 
			
		||||
    // uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
 | 
			
		||||
    int* B_ptr_local = B_ptr + k_0_0 * 32 * (OC / 8);
 | 
			
		||||
  uint4 B_loaded_scale = *(uint4*)(scaling_factors_ptr2);
 | 
			
		||||
 | 
			
		||||
    for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 4; ++ax0_ax1_fused_0) {
 | 
			
		||||
 | 
			
		||||
      // B: 32 x 136 (128+8) float16
 | 
			
		||||
      // each warp: 32 x 4
 | 
			
		||||
      // each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
 | 
			
		||||
      // *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
 | 
			
		||||
      // row stride in shared memory: (NWARPS * 32 * 8 / cta_N) 
 | 
			
		||||
      uint32_t B_loaded = *(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
 | 
			
		||||
  uint32_t B_loaded = *(uint32_t*)B_ptr2;
 | 
			
		||||
  uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
 | 
			
		||||
      //uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
 | 
			
		||||
 | 
			
		||||
      // uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
 | 
			
		||||
      // - zero and * scale
 | 
			
		||||
      // TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
 | 
			
		||||
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_zero.x));
 | 
			
		||||
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.x) : "r"(B_loaded_fp16.x), "r"(B_loaded_scale.x), "r"(ZERO));
 | 
			
		||||
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.y) : "r"(B_loaded_fp16.y), "r"(B_loaded_zero.y));
 | 
			
		||||
@@ -378,124 +318,68 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16n64k32(int G, in
 | 
			
		||||
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.z) : "r"(B_loaded_fp16.z), "r"(B_loaded_scale.z), "r"(ZERO));
 | 
			
		||||
  asm volatile("sub.f16x2 %0, %1, %2;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_zero.w));
 | 
			
		||||
  asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(B_loaded_fp16.w) : "r"(B_loaded_fp16.w), "r"(B_loaded_scale.w), "r"(ZERO));
 | 
			
		||||
      /*
 | 
			
		||||
      if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
 | 
			
		||||
        printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
 | 
			
		||||
      }
 | 
			
		||||
      */
 | 
			
		||||
 | 
			
		||||
      // write back
 | 
			
		||||
      *(uint4*)(B_shared_ptr + ax0_ax1_fused_0 * row_stride * (64 + 8)) = B_loaded_fp16;
 | 
			
		||||
    }
 | 
			
		||||
    __syncthreads();
 | 
			
		||||
  *(uint4*)B_shared_ptr2 = B_loaded_fp16;
 | 
			
		||||
 | 
			
		||||
    for (int k_0_1 = 0; k_0_1 < 2; ++k_0_1) 
 | 
			
		||||
    {
 | 
			
		||||
      {
 | 
			
		||||
        unsigned int addr;
 | 
			
		||||
        __asm__ __volatile__(
 | 
			
		||||
          "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
 | 
			
		||||
          : "=r"(addr)
 | 
			
		||||
          : "l"((void *)((&(A_shared[(k_0_1 * 16)])) + (((((int)threadIdx.x) & 15) * 40) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
			
		||||
        );
 | 
			
		||||
        __asm__ __volatile__(
 | 
			
		||||
          "ldmatrix.sync.aligned.m8n8.x4.shared.b16"
 | 
			
		||||
          "{%0, %1, %2, %3}, [%4];\n"
 | 
			
		||||
          : "=r"(((unsigned *)(A_shared_warp + 0))[0]), "=r"(((unsigned *)(A_shared_warp + 0))[1]), "=r"(((unsigned *)(A_shared_warp + 0))[2]), "=r"(((unsigned *)(A_shared_warp + 0))[3])
 | 
			
		||||
          : "r"(addr)
 | 
			
		||||
        );
 | 
			
		||||
  for (int i = 0; i < 8; ++i) {
 | 
			
		||||
    *(C_ptr2 + i) = B_shared[i];
 | 
			
		||||
  }
 | 
			
		||||
        
 | 
			
		||||
 | 
			
		||||
      for (int ax1_0 = 0; ax1_0 < 2; ++ax1_0) 
 | 
			
		||||
      {
 | 
			
		||||
        {
 | 
			
		||||
          unsigned int addr;
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }\n"
 | 
			
		||||
            : "=r"(addr)
 | 
			
		||||
            : "l"((void *)((&(B_shared[(((k_0_1 * 1152) + (((int)threadIdx.y) * 32)) + (ax1_0 * 16))])) + (((((int)threadIdx.x) & 15) * 72) + ((((int)threadIdx.x) >> 4) * 8))))
 | 
			
		||||
          );
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
 | 
			
		||||
            "{%0, %1, %2, %3}, [%4];\n"
 | 
			
		||||
            : "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[0]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[1]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[2]), "=r"(((unsigned *)(B_shared_warp + (ax1_0 * 8)))[3])
 | 
			
		||||
            : "r"(addr)
 | 
			
		||||
          );
 | 
			
		||||
        }
 | 
			
		||||
      }
 | 
			
		||||
      
 | 
			
		||||
      for (int j_0_4 = 0; j_0_4 < 2; ++j_0_4) 
 | 
			
		||||
      {
 | 
			
		||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
 | 
			
		||||
        {
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
			
		||||
            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
			
		||||
            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
			
		||||
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
			
		||||
            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
			
		||||
            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
			
		||||
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
			
		||||
            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
			
		||||
            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
			
		||||
            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
 | 
			
		||||
            "{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n"
 | 
			
		||||
            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
			
		||||
            : "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
			
		||||
        }
 | 
			
		||||
#else
 | 
			
		||||
        {
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
 | 
			
		||||
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
 | 
			
		||||
            :  "=f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "=f"(((float *)(C_warp + (j_0_4 * 8)))[3])
 | 
			
		||||
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[0]), "r"(((unsigned *)(B_shared_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[0]), "f"(((float *)(C_warp + (j_0_4 * 8)))[1]), "f"(((float *)(C_warp + (j_0_4 * 8)))[2]), "f"(((float *)(C_warp + (j_0_4 * 8)))[3]));
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        {
 | 
			
		||||
          __asm__ __volatile__(
 | 
			
		||||
            "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
 | 
			
		||||
            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
 | 
			
		||||
            :  "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "=f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3])
 | 
			
		||||
            : "r"(((unsigned *)(A_shared_warp + 0))[0]), "r"(((unsigned *)(A_shared_warp + 0))[1]), "r"(((unsigned *)(A_shared_warp + 0))[2]), "r"(((unsigned *)(A_shared_warp + 0))[3]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[0]), "r"(((unsigned *)(B_shared_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[0]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[1]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[2]), "f"(((float *)(C_warp + ((j_0_4 * 8) + 4)))[3]));
 | 
			
		||||
        }
 | 
			
		||||
#endif
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
// TODO: Shang: Hoist loop invariance.
 | 
			
		||||
  for (int ax1_0_1 = 0; ax1_0_1 < 2; ++ax1_0_1) {
 | 
			
		||||
    for (int local_id = 0; local_id < 8; ++local_id) {
 | 
			
		||||
      int row_offset = (((int)blockIdx_y) / j_factors1) * 16 + ((int)threadIdx.x) / 4 + (local_id % 4) / 2 * 8;
 | 
			
		||||
      if (row_offset < M)
 | 
			
		||||
      {
 | 
			
		||||
        *(C_ptr + ax1_0_1 * 16 + row_offset * OC + (local_id / 4) * 8 + local_id % 2) = __float2half(C_warp[(ax1_0_1 * 8) + local_id]);
 | 
			
		||||
      }
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace awq
 | 
			
		||||
} // namespace vllm
 | 
			
		||||
 | 
			
		||||
torch::Tensor awq_dequantize(
 | 
			
		||||
    torch::Tensor _kernel,
 | 
			
		||||
    torch::Tensor _scaling_factors,
 | 
			
		||||
    torch::Tensor _zeros,
 | 
			
		||||
    int split_k_iters,
 | 
			
		||||
    int thx,
 | 
			
		||||
    int thy)
 | 
			
		||||
{
 | 
			
		||||
    int in_c = _kernel.size(0);
 | 
			
		||||
    int qout_c = _kernel.size(1);
 | 
			
		||||
    int out_c = qout_c * 8;
 | 
			
		||||
    int G = in_c / _scaling_factors.size(0);
 | 
			
		||||
 | 
			
		||||
    int x_thread = thx;
 | 
			
		||||
    int y_thread = thy;
 | 
			
		||||
 | 
			
		||||
    int x_blocks = 1;
 | 
			
		||||
    int y_blocks = 1;
 | 
			
		||||
    if (thx==0) {
 | 
			
		||||
      x_thread = qout_c;
 | 
			
		||||
    }
 | 
			
		||||
    if (thy==0) {
 | 
			
		||||
      y_thread = in_c;
 | 
			
		||||
    }
 | 
			
		||||
    if (thx==0 && thy==0) {
 | 
			
		||||
      x_thread = 8;
 | 
			
		||||
      y_thread = 8;
 | 
			
		||||
      x_blocks = (int)(qout_c / 8);
 | 
			
		||||
      y_blocks = (int)(in_c / 8);
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_scaling_factors));
 | 
			
		||||
 | 
			
		||||
    auto options = torch::TensorOptions().dtype(_scaling_factors.dtype()).device(_scaling_factors.device());
 | 
			
		||||
    at::Tensor _de_kernel = torch::empty({in_c, out_c}, options);
 | 
			
		||||
 | 
			
		||||
    auto kernel = reinterpret_cast<int*>(_kernel.data_ptr<int>());
 | 
			
		||||
    auto de_kernel = reinterpret_cast<half*>(_de_kernel.data_ptr<at::Half>());
 | 
			
		||||
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
 | 
			
		||||
    auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
 | 
			
		||||
 | 
			
		||||
    dim3 num_blocks(x_blocks, y_blocks);
 | 
			
		||||
    dim3 threads_per_block(x_thread, y_thread);
 | 
			
		||||
 | 
			
		||||
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 | 
			
		||||
    vllm::awq::dequantize_weights<<<num_blocks, threads_per_block, 0, stream>>>(
 | 
			
		||||
        kernel, scaling_factors, zeros, de_kernel, G);
 | 
			
		||||
 | 
			
		||||
    return _de_kernel;
 | 
			
		||||
}
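When awq_dequantize is called with thx == 0 and thy == 0, the code above picks 8x8 thread blocks over a (qout_c / 8, in_c / 8) grid, so each thread owns exactly one packed int32 of the quantized kernel and emits 8 dequantized halves. A small host-side sketch of that launch arithmetic; the tensor sizes are made-up examples, not values from the PR:

#include <cstdio>

int main() {
  const int in_c = 4096;    // assumed: rows of the packed kernel (input channels)
  const int qout_c = 512;   // assumed: packed columns (output channels / 8)

  const int x_thread = 8, y_thread = 8;   // threads_per_block used above
  const int x_blocks = qout_c / 8;        // num_blocks.x
  const int y_blocks = in_c / 8;          // num_blocks.y

  // One thread per packed int32 (8 x 4-bit weights -> 8 halves), so the
  // thread count matches the number of packed words exactly.
  long long threads = 1LL * x_blocks * y_blocks * x_thread * y_thread;
  long long packed_words = 1LL * in_c * qout_c;
  printf("threads=%lld packed_words=%lld\n", threads, packed_words);
  return 0;
}
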
 | 
			
		||||
 | 
			
		||||
// in_feats: M, IC [float16]
 | 
			
		||||
// kernel: IC, OC // 8 [int32] -> cast to IC, OC [uint4b]
 | 
			
		||||
// scaling_factors: IC // G, OC [float16]
 | 
			
		||||
@@ -542,8 +426,9 @@ torch::Tensor awq_gemm(
 | 
			
		||||
        // threadIdx.x: 32
 | 
			
		||||
        // threadIdx.y: i_factors[2] * j_factors[2]
 | 
			
		||||
        dim3 threads_per_block(32, 2);
 | 
			
		||||
        vllm::awq::gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
 | 
			
		||||
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
 | 
			
		||||
        vllm::awq::gemm_forward_4bit_cuda_m16nXk32<128><<<num_blocks, threads_per_block, 0, stream>>>(
 | 
			
		||||
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
 | 
			
		||||
            num_out_channels, out_feats);
 | 
			
		||||
    }
 | 
			
		||||
    else if (num_out_channels % 64 == 0)
 | 
			
		||||
    {
 | 
			
		||||
@@ -553,8 +438,9 @@ torch::Tensor awq_gemm(
 | 
			
		||||
        // threadIdx.x: 32
 | 
			
		||||
        // threadIdx.y: i_factors[2] * j_factors[2]
 | 
			
		||||
        dim3 threads_per_block(32, 2);
 | 
			
		||||
        vllm::awq::gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
 | 
			
		||||
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels, num_out_channels, out_feats);
 | 
			
		||||
        vllm::awq::gemm_forward_4bit_cuda_m16nXk32<64><<<num_blocks, threads_per_block, 0, stream>>>(
 | 
			
		||||
            group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros, num_in_feats, num_in_channels,
 | 
			
		||||
            num_out_channels, out_feats);
 | 
			
		||||
    }
 | 
			
		||||
    return _out_feats.sum(0);
 | 
			
		||||
}
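The awq_gemm changes above collapse the two fixed-width kernels into one gemm_forward_4bit_cuda_m16nXk32<N> template (N = 128 or 64, chosen from num_out_channels), while keeping the split-k scheme: each blockIdx_z slice accumulates a strided share of the input channels into its own [M, OC] plane, and the host reduces with _out_feats.sum(0). A toy, CPU-only sketch of that split-k bookkeeping, with invented sizes; it mirrors the layout (the real kernel strides in 32-channel chunks), not the tensor-core math:

#include <cstdio>
#include <vector>

int main() {
  const int M = 2, K = 8, OC = 1, split_k_iters = 4;   // toy sizes (assumed)
  std::vector<float> A(M * K, 1.0f), B(K * OC, 1.0f);
  std::vector<float> partial(split_k_iters * M * OC, 0.0f);  // the [split_k, M, OC] buffer

  // Slice z walks k = z, z + split_k_iters, z + 2*split_k_iters, ... and writes
  // only into its own plane, much like blockIdx_z does in the kernel.
  for (int z = 0; z < split_k_iters; ++z)
    for (int m = 0; m < M; ++m)
      for (int k = z; k < K; k += split_k_iters)
        partial[z * M * OC + m * OC] += A[m * K + k] * B[k * OC];

  // Host-side reduction, i.e. the _out_feats.sum(0) step.
  for (int m = 0; m < M; ++m) {
    float out = 0.0f;
    for (int z = 0; z < split_k_iters; ++z) out += partial[z * M * OC + m * OC];
    printf("C[%d][0] = %.1f\n", m, out);   // 8.0 with the all-ones toy inputs
  }
  return 0;
}
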

csrc/quantization/fp8_e5m2_kvcache/quant_utils.cuh (new file, 277 lines)
@@ -0,0 +1,277 @@
#pragma once
 | 
			
		||||
 | 
			
		||||
#include <assert.h>
 | 
			
		||||
#include <stdint.h>
 | 
			
		||||
#include <float.h>
 | 
			
		||||
#include <type_traits>
 | 
			
		||||
#include "../../attention/attention_dtypes.h"
 | 
			
		||||
#include "../../attention/dtype_float32.cuh"
 | 
			
		||||
#include "../../attention/dtype_float16.cuh"
 | 
			
		||||
#include "../../attention/dtype_bfloat16.cuh"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
namespace vllm {
 | 
			
		||||
#ifdef ENABLE_FP8_E5M2
 | 
			
		||||
namespace fp8_e5m2_unscaled {
 | 
			
		||||
 | 
			
		||||
template<typename Tout, typename Tin>
 | 
			
		||||
__inline__ __device__ Tout vec_conversion(const Tin& x)
 | 
			
		||||
{
 | 
			
		||||
    return x;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8 -> half
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint16_t vec_conversion<uint16_t, uint8_t>(const uint8_t& a)
 | 
			
		||||
{
 | 
			
		||||
    __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
 | 
			
		||||
    return res.x;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x2 -> half2
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, uint16_t>(const uint16_t& a)
 | 
			
		||||
{
 | 
			
		||||
    union {
 | 
			
		||||
        uint16_t u16[2];
 | 
			
		||||
        uint32_t u32;
 | 
			
		||||
    } tmp;
 | 
			
		||||
    __half2_raw res = __nv_cvt_fp8x2_to_halfraw2(a, __NV_E5M2);
 | 
			
		||||
    tmp.u16[0] = res.x;
 | 
			
		||||
    tmp.u16[1] = res.y;
 | 
			
		||||
    return tmp.u32;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x4 -> half2x2
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint2 vec_conversion<uint2, uint32_t>(const uint32_t& a)
 | 
			
		||||
{
 | 
			
		||||
    union {
 | 
			
		||||
        uint2    u32x2;
 | 
			
		||||
        uint32_t u32[2];
 | 
			
		||||
    } tmp;
 | 
			
		||||
    tmp.u32[0] = vec_conversion<uint32_t, uint16_t>((uint16_t)a);
 | 
			
		||||
    tmp.u32[1] = vec_conversion<uint32_t, uint16_t>((uint16_t)(a >> 16U));
 | 
			
		||||
    return tmp.u32x2;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x8 -> half2x4
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint4 vec_conversion<uint4, uint2>(const uint2& a)
 | 
			
		||||
{
 | 
			
		||||
    union {
 | 
			
		||||
        uint4 u64x2;
 | 
			
		||||
        uint2 u64[2];
 | 
			
		||||
    } tmp;
 | 
			
		||||
    tmp.u64[0] = vec_conversion<uint2, uint32_t>(a.x);
 | 
			
		||||
    tmp.u64[1] = vec_conversion<uint2, uint32_t>(a.y);
 | 
			
		||||
    return tmp.u64x2;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8 -> __nv_bfloat16
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ __nv_bfloat16 vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a)
 | 
			
		||||
{
 | 
			
		||||
    // Note there is no direct convert function from fp8 to bf16.
 | 
			
		||||
    // fp8 -> half
 | 
			
		||||
    __half_raw res = __nv_cvt_fp8_to_halfraw(a, __NV_E5M2);
 | 
			
		||||
    // half -> float -> bf16
 | 
			
		||||
    float tmp = half_to_float(res.x);
 | 
			
		||||
    return __float2bfloat16(tmp);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x2 -> __nv_bfloat162
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a)
 | 
			
		||||
{
 | 
			
		||||
    __nv_bfloat162 res;
 | 
			
		||||
    res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a);
 | 
			
		||||
    res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U));
 | 
			
		||||
    return res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x4 -> bf16_4_t
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, uint32_t>(const uint32_t& a)
 | 
			
		||||
{
 | 
			
		||||
    bf16_4_t res;
 | 
			
		||||
    res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a);
 | 
			
		||||
    res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U));
 | 
			
		||||
    return res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x8 -> bf16_8_t
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, uint2>(const uint2& a)
 | 
			
		||||
{
 | 
			
		||||
    bf16_4_t tmp1, tmp2;
 | 
			
		||||
    tmp1 = vec_conversion<bf16_4_t, uint32_t>(a.x);
 | 
			
		||||
    tmp2 = vec_conversion<bf16_4_t, uint32_t>(a.y);
 | 
			
		||||
    bf16_8_t res;
 | 
			
		||||
    res.x = tmp1.x;
 | 
			
		||||
    res.y = tmp1.y;
 | 
			
		||||
    res.z = tmp2.x;
 | 
			
		||||
    res.w = tmp2.y;
 | 
			
		||||
    return res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8 -> float
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ float vec_conversion<float, uint8_t>(const uint8_t& a)
 | 
			
		||||
{
 | 
			
		||||
    // fp8 -> half
 | 
			
		||||
    uint16_t tmp = vec_conversion<uint16_t, uint8_t>(a);
 | 
			
		||||
    // half -> float
 | 
			
		||||
    return half_to_float(tmp);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x2 -> float2
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ float2 vec_conversion<float2, uint16_t>(const uint16_t& a)
 | 
			
		||||
{
 | 
			
		||||
    // fp8x2 -> half2
 | 
			
		||||
    uint32_t tmp = vec_conversion<uint32_t, uint16_t>(a);
 | 
			
		||||
    // half2 -> float2
 | 
			
		||||
    return half2_to_float2(tmp);
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x4 -> float4
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ Float4_ vec_conversion<Float4_, uint32_t>(const uint32_t& a)
 | 
			
		||||
{
 | 
			
		||||
    Float4_ res;
 | 
			
		||||
    res.x = vec_conversion<float2, uint16_t>((uint16_t)a);
 | 
			
		||||
    res.y = vec_conversion<float2, uint16_t>((uint16_t)(a >> 16U));
 | 
			
		||||
    return res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x8 -> float8
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ Float8_ vec_conversion<Float8_, uint2>(const uint2& a)
 | 
			
		||||
{
 | 
			
		||||
    Float4_ tmp1, tmp2;
 | 
			
		||||
    tmp1 = vec_conversion<Float4_, uint32_t>(a.x);
 | 
			
		||||
    tmp2 = vec_conversion<Float4_, uint32_t>(a.y);
 | 
			
		||||
    Float8_ res;
 | 
			
		||||
    res.x = tmp1.x;
 | 
			
		||||
    res.y = tmp1.y;
 | 
			
		||||
    res.z = tmp2.x;
 | 
			
		||||
    res.w = tmp2.y;
 | 
			
		||||
    return res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
// half -> fp8
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, uint16_t>(const uint16_t& a)
 | 
			
		||||
{
 | 
			
		||||
    __half_raw tmp;
 | 
			
		||||
    tmp.x = a;
 | 
			
		||||
    __nv_fp8_storage_t res = __nv_cvt_halfraw_to_fp8(tmp, __NV_SATFINITE, __NV_E5M2);
 | 
			
		||||
    return (uint8_t)res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// bf16 -> fp8
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, __nv_bfloat16>(const __nv_bfloat16& a)
 | 
			
		||||
{
 | 
			
		||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
 | 
			
		||||
    assert(false);
 | 
			
		||||
#else
 | 
			
		||||
    __nv_fp8_storage_t res = __nv_cvt_bfloat16raw_to_fp8(__nv_bfloat16_raw(a), __NV_SATFINITE, __NV_E5M2);
 | 
			
		||||
    return (uint8_t)res;
 | 
			
		||||
#endif
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// float -> fp8
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint8_t vec_conversion<uint8_t, float>(const float& a)
 | 
			
		||||
{
 | 
			
		||||
    __nv_fp8_storage_t res = __nv_cvt_float_to_fp8(a, __NV_SATFINITE, __NV_E5M2);
 | 
			
		||||
    return (uint8_t)res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// fp8x4 -> float4
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ float4 vec_conversion<float4, uint32_t>(const uint32_t& a)
 | 
			
		||||
{
 | 
			
		||||
    Float4_ tmp = vec_conversion<Float4_, uint32_t>(a);
 | 
			
		||||
    float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y);
 | 
			
		||||
    return res;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint32_t vec_conversion<uint32_t, float2>(const float2& a)
 | 
			
		||||
{
 | 
			
		||||
    union {
 | 
			
		||||
        half2    float16;
 | 
			
		||||
        uint32_t uint32;
 | 
			
		||||
    };
 | 
			
		||||
 | 
			
		||||
    float16 = __float22half2_rn(a);
 | 
			
		||||
    return uint32;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint2 vec_conversion<uint2, Float4_>(const Float4_& a)
 | 
			
		||||
{
 | 
			
		||||
    uint2  b;
 | 
			
		||||
    float2 val;
 | 
			
		||||
    val.x = a.x.x;
 | 
			
		||||
    val.y = a.x.y;
 | 
			
		||||
    b.x   = vec_conversion<uint32_t, float2>(val);
 | 
			
		||||
 | 
			
		||||
    val.x = a.y.x;
 | 
			
		||||
    val.y = a.y.y;
 | 
			
		||||
    b.y   = vec_conversion<uint32_t, float2>(val);
 | 
			
		||||
 | 
			
		||||
    return b;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ float4 vec_conversion<float4, Float4_>(const Float4_& a)
 | 
			
		||||
{
 | 
			
		||||
    float4 b;
 | 
			
		||||
    b.x = a.x.x;
 | 
			
		||||
    b.y = a.x.y;
 | 
			
		||||
    b.z = a.y.x;
 | 
			
		||||
    b.w = a.y.y;
 | 
			
		||||
    return b;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ uint4 vec_conversion<uint4, Float8_>(const Float8_& a)
 | 
			
		||||
{
 | 
			
		||||
    uint4 b;
 | 
			
		||||
    b.x = vec_conversion<uint32_t, float2>(a.x);
 | 
			
		||||
    b.y = vec_conversion<uint32_t, float2>(a.y);
 | 
			
		||||
    b.z = vec_conversion<uint32_t, float2>(a.z);
 | 
			
		||||
    b.w = vec_conversion<uint32_t, float2>(a.w);
 | 
			
		||||
    return b;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ __nv_bfloat162 vec_conversion<__nv_bfloat162, float2>(const float2 &a) {
 | 
			
		||||
    __nv_bfloat162 b;
 | 
			
		||||
    from_float(b, a);
 | 
			
		||||
    return b;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ bf16_4_t vec_conversion<bf16_4_t, Float4_>(const Float4_ &a) {
 | 
			
		||||
    bf16_4_t b;
 | 
			
		||||
    from_float(b, a);
 | 
			
		||||
    return b;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
template<>
 | 
			
		||||
__inline__ __device__ bf16_8_t vec_conversion<bf16_8_t, Float8_>(const Float8_ &a) {
 | 
			
		||||
    bf16_8_t b;
 | 
			
		||||
    from_float(b, a);
 | 
			
		||||
    return b;
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
} // namespace fp8_e5m2_unscaled
 | 
			
		||||
#endif // ENABLE_FP8_E5M2
 | 
			
		||||
} // namespace vllm
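The helpers in this new header are register-level converters only: fp8 E5M2 values travel as raw uint8/uint16/uint32 words and are widened through vec_conversion<Tout, Tin>. A minimal device-side sketch of one way a caller could use them; the kernel name and buffers are hypothetical, and it assumes ENABLE_FP8_E5M2 is defined and the toolchain ships the __nv_cvt fp8 intrinsics (CUDA 11.8+):

#include <cstdint>
#include <cuda_fp16.h>
#include "quant_utils.cuh"   // the header above

// Hypothetical example kernel: widen packed fp8x2 cache words to packed half2 words.
__global__ void widen_fp8x2_to_half2(const uint16_t* __restrict__ src,
                                     uint32_t* __restrict__ dst,
                                     int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    // fp8x2 (E5M2) -> half2, kept in the raw uint32 form the attention kernels expect.
    dst[i] = vllm::fp8_e5m2_unscaled::vec_conversion<uint32_t, uint16_t>(src[i]);
  }
}
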

csrc/quantization/gptq/compat.cuh (new file, 64 lines)
@@ -0,0 +1,64 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _compat_cuh
#define _compat_cuh

namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x

__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
    unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
    unsigned int old = *address_as_ui;
    unsigned int assumed;

    do
    {
        assumed = old;
        __half_raw hsum;
        hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
        half tmpres = __hadd(hsum, val);
        hsum = __half_raw(tmpres);
        old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
        old = atomicCAS(address_as_ui, assumed, old);
    }
    while (assumed != old);
}

// atomicAdd for half2 types

__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
    unsigned int* address_as_ui = (unsigned int*)address;
    unsigned int old = *address_as_ui;
    unsigned int assumed;
    do
    {
        assumed = old;
        half2 old_val = *((half2*)&old);
        half2 new_val = __hadd2(old_val, val);
        old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
    }
    while (assumed != old);
}

//

#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)

__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }

#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif

#endif
#endif

}  // namespace gptq
}  // namespace vllm
#endif
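compat.cuh only has to provide atomicAdd overloads where the hardware lacks them: CUDA exposes a native half atomicAdd from compute capability 7.0 and a native half2 atomicAdd from 6.0, so below those thresholds (and on ROCm) the CAS loops above stand in. A small hedged usage sketch with an invented kernel; it sits inside vllm::gptq so that, on old architectures, unqualified lookup resolves to the fallback overloads defined in this header:

#include <cuda_fp16.h>
#include "compat.cuh"   // the header above

namespace vllm {
namespace gptq {

// Hypothetical kernel: accumulate per-thread fp16 contributions into one bucket.
// On CC >= 7.0 this compiles to the hardware half atomic; on older parts or ROCm
// it falls back to the atomicAdd_half CAS loop above.
__global__ void accumulate_half(const half* __restrict__ src, half* __restrict__ acc, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    atomicAdd(acc, src[i]);
  }
}

}  // namespace gptq
}  // namespace vllm
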
							
								
								
									
csrc/quantization/gptq/matrix_view.cuh (new file, 151 lines)
@@ -0,0 +1,151 @@
/*
 | 
			
		||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
#ifndef _matrix_view_cuh
 | 
			
		||||
#define _matrix_view_cuh
 | 
			
		||||
 | 
			
		||||
#include <cuda_runtime.h>
 | 
			
		||||
#include <cuda_fp16.h>
 | 
			
		||||
 | 
			
		||||
#include "qdq_util.cuh"
 | 
			
		||||
 | 
			
		||||
namespace vllm {
 | 
			
		||||
namespace gptq {
 | 
			
		||||
 | 
			
		||||
class MatrixView_half
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    const half* data;
 | 
			
		||||
    const int height;
 | 
			
		||||
    const int width;
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
 | 
			
		||||
        : data(data), height(height), width(width)
 | 
			
		||||
    { }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
 | 
			
		||||
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
 | 
			
		||||
    __device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
 | 
			
		||||
    __device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
 | 
			
		||||
    {
 | 
			
		||||
        half2* ptr = (half2*) item_ptr(row, column);
 | 
			
		||||
        half2 i01 = ptr[0];
 | 
			
		||||
        half2 i23 = ptr[1];
 | 
			
		||||
        items[0] = __low2half(i01);
 | 
			
		||||
        items[1] = __high2half(i01);
 | 
			
		||||
        items[2] = __low2half(i23);
 | 
			
		||||
        items[3] = __high2half(i23);
 | 
			
		||||
    }
 | 
			
		||||
    __device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
 | 
			
		||||
    {
 | 
			
		||||
        half2* ptr = (half2*)item_ptr(row, column);
 | 
			
		||||
        half2 i01 = ptr[0];
 | 
			
		||||
        half2 i23 = ptr[1];
 | 
			
		||||
        items[0] = __half2float(__low2half(i01));
 | 
			
		||||
        items[1] = __half2float(__high2half(i01));
 | 
			
		||||
        items[2] = __half2float(__low2half(i23));
 | 
			
		||||
        items[3] = __half2float(__high2half(i23));
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
 | 
			
		||||
    {
 | 
			
		||||
        half2* ptr = (half2*)item_ptr(row, column);
 | 
			
		||||
        half2 i01 = ptr[0];
 | 
			
		||||
        half2 i23 = ptr[1];
 | 
			
		||||
        items[0] = __half2half2(__low2half(i01));
 | 
			
		||||
        items[1] = __half2half2(__high2half(i01));
 | 
			
		||||
        items[2] = __half2half2(__low2half(i23));
 | 
			
		||||
        items[3] = __half2half2(__high2half(i23));
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class MatrixView_half_rw
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    half* data;
 | 
			
		||||
    const int height;
 | 
			
		||||
    const int width;
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
 | 
			
		||||
        : data(data), height(height), width(width)
 | 
			
		||||
    { }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
 | 
			
		||||
    __device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
 | 
			
		||||
    __device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
 | 
			
		||||
    __device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
 | 
			
		||||
    __device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
 | 
			
		||||
    {
 | 
			
		||||
        half2 v01 = __halves2half2(v0, v1);
 | 
			
		||||
        half2 v23 = __halves2half2(v2, v3);
 | 
			
		||||
        half2* ptr = (half2*) item_ptr(row, column);
 | 
			
		||||
        ptr[0] = v01;
 | 
			
		||||
        ptr[1] = v23;
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class MatrixView_q4_row
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    const uint32_t* data;
 | 
			
		||||
    const int height;
 | 
			
		||||
    const int width;
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
 | 
			
		||||
        : data(data), height(height), width(width)
 | 
			
		||||
    { }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ int item(int row, int column) const
 | 
			
		||||
    {
 | 
			
		||||
        int shift = (column & 0x07) * 4;
 | 
			
		||||
        return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
 | 
			
		||||
    {
 | 
			
		||||
        int shift = (column & 0x07) * 4;
 | 
			
		||||
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
 | 
			
		||||
        items[0] = d & 0x0f;
 | 
			
		||||
        items[1] = (d >> 4) & 0x0f;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
 | 
			
		||||
    {
 | 
			
		||||
        int shift = (column & 0x07) * 4;
 | 
			
		||||
        uint32_t d = data[row * width / 8 + column / 8] >> shift;
 | 
			
		||||
        items[0] = d & 0x0f;
 | 
			
		||||
        items[1] = (d >> 4) & 0x0f;
 | 
			
		||||
        items[2] = (d >> 8) & 0x0f;
 | 
			
		||||
        items[3] = (d >> 12) & 0x0f;
 | 
			
		||||
    }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
class MatrixView_q4_column
 | 
			
		||||
{
 | 
			
		||||
public:
 | 
			
		||||
    const uint32_t* data;
 | 
			
		||||
    const int height;
 | 
			
		||||
    const int width;
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
 | 
			
		||||
        : data(data), height(height), width(width)
 | 
			
		||||
    { }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ int item(int row, int column) const
 | 
			
		||||
    {
 | 
			
		||||
        int shift = (row & 0x07) * 4;
 | 
			
		||||
        return (data[row / 8 * width + column] >> shift) & 0x0f;
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    __device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
 | 
			
		||||
    __device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
 | 
			
		||||
};
 | 
			
		||||
 | 
			
		||||
}  // namespace gptq
 | 
			
		||||
}  // namespace vllm
 | 
			
		||||
#endif
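MatrixView_q4_row stores eight consecutive 4-bit columns of a row in one uint32, low nibble first, which is why item() indexes data[row * width / 8 + column / 8] and shifts by (column & 0x07) * 4. A host-side sketch of the same unpacking rule on one packed word (the word value is invented for illustration):

#include <cstdint>
#include <cstdio>

int main() {
  const uint32_t word = 0x76543210u;  // columns 0..7 of one row, packed low-nibble-first
  for (int column = 0; column < 8; ++column) {
    int shift = (column & 0x07) * 4;          // same shift as MatrixView_q4_row::item()
    int q = (word >> shift) & 0x0f;
    printf("column %d -> %d\n", column, q);   // prints 0, 1, 2, ..., 7
  }
  return 0;
}
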

csrc/quantization/gptq/q_gemm.cu (new file, 875 lines)
@@ -0,0 +1,875 @@
/*
 | 
			
		||||
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/qwopqwop200/GPTQ-for-LLaMa
 | 
			
		||||
*/
 | 
			
		||||
 | 
			
		||||
#include <cstdint>
 | 
			
		||||
#include <cstdio>
 | 
			
		||||
 | 
			
		||||
#include <torch/extension.h>
 | 
			
		||||
#include <c10/cuda/CUDAGuard.h>
 | 
			
		||||
#include <ATen/cuda/CUDAContext.h>
 | 
			
		||||
#include <cuda_runtime.h>
 | 
			
		||||
#include <cuda_fp16.h>
 | 
			
		||||
 | 
			
		||||
#include "compat.cuh"
 | 
			
		||||
#include "matrix_view.cuh"
 | 
			
		||||
#include "qdq_4.cuh"
 | 
			
		||||
 | 
			
		||||
namespace vllm {
 | 
			
		||||
namespace gptq {
 | 
			
		||||
 | 
			
		||||
#define BLOCK_KN_SIZE 128
 | 
			
		||||
#define BLOCK_M_SIZE_MAX 8
 | 
			
		||||
#define MAX_GROUPS_IN_BLOCK (BLOCK_KN_SIZE / 32)
 | 
			
		||||
#define MAX_Q_GEMM_ROWS 50
 | 
			
		||||
#define MAX_ALT_GEMM_ROWS 8
 | 
			
		||||
#define THREADS_X 32
 | 
			
		||||
#define THREADS_Y 32
 | 
			
		||||
#define DIVIDE(x, size) (((x) + (size) - 1) / (size))

#if defined(USE_ROCM)
#include <hipblas/hipblas.h>
__host__ __forceinline__ hipblasStatus_t __compat_hipblasHgemm(hipblasHandle_t    handle,
                                                               hipblasOperation_t transA,
                                                               hipblasOperation_t transB,
                                                               int                m,
                                                               int                n,
                                                               int                k,
                                                               const half*        alpha,
                                                               const half*        AP,
                                                               int                lda,
                                                               const half*        BP,
                                                               int                ldb,
                                                               const half*        beta,
                                                               half*              CP,
                                                               int                ldc) {
    return hipblasHgemm(handle, transA, transB, m, n, k,
                        reinterpret_cast<const hipblasHalf *>(alpha),
                        reinterpret_cast<const hipblasHalf *>(AP), lda,
                        reinterpret_cast<const hipblasHalf *>(BP), ldb,
                        reinterpret_cast<const hipblasHalf *>(beta),
                        reinterpret_cast<hipblasHalf *>(CP), ldc);
}
#define hipblasHgemm __compat_hipblasHgemm

// Previous versions of PyTorch converted to rocBLAS instead of hipBLAS.
#define rocblas_operation_none HIPBLAS_OP_N
#define rocblas_hgemm __compat_hipblasHgemm
#endif

__forceinline__ __device__ half2 dot22_8(half2(&dq)[4], const half* a_ptr, const half2 g_result)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    return __hadd2(result, g_result);
}

__forceinline__ __device__ float dot22_8_f(half2(&dq)[4], const half* a_ptr)
{
    half2 result = {};
    const half2* a2_ptr = (const half2*)a_ptr;
    #pragma unroll
    for (int i = 0; i < 4; i++) result = __hfma2(dq[i], *a2_ptr++, result);
    return __half2float(__low2half(result)) + __half2float(__high2half(result));
}
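// dot22_8 and dot22_8_f both compute a dot product between 8 dequantized weights
// (packed as four half2 values in dq) and 8 activations read from a_ptr, using
// half2 FMAs; dot22_8 accumulates into a half2, while dot22_8_f reduces to a float
// so the caller can accumulate the per-column result in fp32.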

typedef void (*fp_gemm_half_q_half_gptq_kernel)
(
    const half*,
    const uint32_t*,
    const uint32_t*,
    const half*,
    half*,
    const int,
    const int,
    const int,
    const int,
    const int*
);

template <bool first_block, int m_count>
__global__ void gemm_half_q_half_gptq_kernel
(
    const half* __restrict__ a,
    const uint32_t* __restrict__ b_q_weight,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    half* __restrict__ c,
    const int size_m,
    const int size_n,
    const int size_k,
    const int groups,
    const int* __restrict__ b_q_perm
)
{
    MatrixView_half a_(a, size_m, size_k);
    MatrixView_half_rw c_(c, size_m, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    int t = threadIdx.x;

    // Block
    int offset_n = blockIdx.x * BLOCK_KN_SIZE * 4;
    int offset_m = blockIdx.y * m_count;
    int offset_k = blockIdx.z * BLOCK_KN_SIZE;

    int end_n = min(offset_n + BLOCK_KN_SIZE * 4, size_n);
    int end_m = min(offset_m + m_count, size_m);
    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    int n = offset_n + t * 4;

    // Preload block_a
    __shared__ half block_a[m_count][BLOCK_KN_SIZE];

    if (offset_k + t < end_k)
    {
        for (int m = 0; m < m_count; ++m)
        {
            const half* a_ptr = a_.item_ptr(offset_m + m, 0);
            half* block_a_ptr = block_a[m];

            half a0;
            if (b_q_perm) a0 = a_ptr[b_q_perm[offset_k + t]];
            else a0 = a_ptr[offset_k + t];
            block_a_ptr[t] = a0;
        }
    }

    // Zero output
    if (n >= size_n) return;

    if (blockIdx.z == 0)
    {
        for (int m = 0; m < m_count; m++)
            *((uint64_t*)c_.item_ptr(offset_m + m, n)) = 0;
    }

    __syncthreads();

    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // a, b offset
    int qk = offset_k / (32 / 4);

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;
    const half* a_ptr = &block_a[0][0];
    int a_stride = BLOCK_KN_SIZE;

    // Initial group
    int zeros[4];
    float scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_f(scales, group, n);
    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);

    // Column result
    float block_c[m_count][4] = {};

    // Dequantize and multiply
    int k = offset_k;
    while (k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_f(scales, group, n);
            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
        }

        #pragma unroll
        for (int j = 0; j < 4; j++)
        {
            const int4* b_ptr4 = (int4*) b_ptr;
            int4 load_int4 = *b_ptr4;

            half2 dq[4][4];
            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);

            #pragma unroll
            for (int m = 0; m < m_count; m++)
            {
                block_c[m][0] = fma(dot22_8_f(dq[0], a_ptr + m * a_stride), scales[0], block_c[m][0]);
                block_c[m][1] = fma(dot22_8_f(dq[1], a_ptr + m * a_stride), scales[1], block_c[m][1]);
                block_c[m][2] = fma(dot22_8_f(dq[2], a_ptr + m * a_stride), scales[2], block_c[m][2]);
                block_c[m][3] = fma(dot22_8_f(dq[3], a_ptr + m * a_stride), scales[3], block_c[m][3]);
            }

            b_ptr += size_n;
            a_ptr += 8;
        }

        k += 32;
    }

    for (int m = 0; m < m_count; m++)
    {
        half2 *out = (half2*) c_.item_ptr(offset_m + m, n);
        half2 result01 = __halves2half2(__float2half_rn(block_c[m][0]), __float2half_rn(block_c[m][1]));
        half2 result23 = __halves2half2(__float2half_rn(block_c[m][2]), __float2half_rn(block_c[m][3]));
        atomicAdd(out    , result01);
        atomicAdd(out + 1, result23);
    }
}
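// Each thread block above handles m_count rows of A, one BLOCK_KN_SIZE slice of k
// (blockIdx.z) and a 4 * BLOCK_KN_SIZE-wide strip of C columns, with each thread
// owning 4 adjacent columns. Because k is split across blockIdx.z, per-slice partial
// sums are combined into C with atomicAdd on half2 pairs, and the blockIdx.z == 0
// blocks zero their output columns first.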


fp_gemm_half_q_half_gptq_kernel pick_gemm_half_q_half_gptq_kernel(bool first_block, const int m_count)
{
    #if BLOCK_M_SIZE_MAX >= 1
    if (m_count == 1) return gemm_half_q_half_gptq_kernel<true, 1>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 2
    if (m_count == 2) return gemm_half_q_half_gptq_kernel<true, 2>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 3
    if (m_count == 3) return gemm_half_q_half_gptq_kernel<true, 3>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 4
    if (m_count == 4) return gemm_half_q_half_gptq_kernel<true, 4>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 5
    if (m_count == 5) return gemm_half_q_half_gptq_kernel<true, 5>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 6
    if (m_count == 6) return gemm_half_q_half_gptq_kernel<true, 6>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 7
    if (m_count == 7) return gemm_half_q_half_gptq_kernel<true, 7>;
    #endif
    #if BLOCK_M_SIZE_MAX >= 8
    if (m_count == 8) return gemm_half_q_half_gptq_kernel<true, 8>;
    #endif
    return NULL;
}


void gemm_half_q_half_cuda_part
(
    const half* a,
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_q_perm,
    half* c,
    int size_m,
    int size_n,
    int size_k,
    int m_count,
    int groups
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    blockDim.z = 1;
    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE * 4);
    gridDim.y = DIVIDE(size_m, m_count);
    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

    fp_gemm_half_q_half_gptq_kernel kernel = pick_gemm_half_q_half_gptq_kernel(true, m_count);

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    kernel<<<gridDim, blockDim, 0, stream>>>
    (
        a,
        b_q_weight,
        b_gptq_qzeros,
        b_gptq_scales,
        c,
        size_m,
        size_n,
        size_k,
        groups,
        b_q_perm
    );
}
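// Launch shape for the fused kernel: gridDim.x tiles the n dimension in steps of
// BLOCK_KN_SIZE * 4 columns, gridDim.y tiles the m dimension in steps of m_count
// rows, and gridDim.z tiles k in steps of BLOCK_KN_SIZE; the template instantiation
// is selected at runtime by pick_gemm_half_q_half_gptq_kernel based on m_count.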


__global__ void reconstruct_exllama_kernel
(
    const uint32_t* __restrict__ b_q_weight,
    const int* __restrict__ b_q_perm,
    const uint32_t* __restrict__ b_gptq_qzeros,
    const half* __restrict__ b_gptq_scales,
    const int size_k,
    const int size_n,
    const int groups,
    half* __restrict__ b
)
{
    MatrixView_half_rw b_(b, size_k, size_n);
    MatrixView_q4_row b_gptq_qzeros_(b_gptq_qzeros, groups, size_n);
    MatrixView_half b_gptq_scales_(b_gptq_scales, groups, size_n);

    int offset_k = BLOCK_KN_SIZE * blockIdx.y;
    int offset_n = BLOCK_KN_SIZE * blockIdx.x * 4;

    int end_k = min(offset_k + BLOCK_KN_SIZE, size_k);

    // Preload remapping table
    __shared__ int perm[BLOCK_KN_SIZE];
    int t = threadIdx.x;

    if (b_q_perm)
    {
        if (offset_k + t < size_k)
            perm[t] = b_q_perm[offset_k + t];
    }

    // Column
    int n = offset_n + t * 4;
    if (n >= size_n) return;

    // Find initial group
    int groupsize = size_k / groups;
    int group = offset_k / groupsize;
    int nextgroup = offset_k + groupsize;

    // b offset
    int qk = offset_k / (32 / 4);

    const uint32_t* b_ptr = b_q_weight + qk * size_n + n;

    // Initial zeros/scale
    int zeros[4];
    half2 scales[4];
    half2 z1z16[4][2];
    half2 y1y16[4][2];
    b_gptq_qzeros_.item4(zeros, group, n);
    b_gptq_scales_.item4_h2(scales, group, n);
    dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
    dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
    dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
    dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);

    __syncthreads();

    int k = offset_k;
    int lk = 0;

    while (k < end_k)
    {
        if (k == nextgroup)
        {
            group++;
            nextgroup += groupsize;
            b_gptq_qzeros_.item4(zeros, group, n);
            b_gptq_scales_.item4_h2(scales, group, n);
            dequant_4bit_8_prep_zero(zeros[0] + 1, z1z16[0], y1y16[0]);
            dequant_4bit_8_prep_zero(zeros[1] + 1, z1z16[1], y1y16[1]);
            dequant_4bit_8_prep_zero(zeros[2] + 1, z1z16[2], y1y16[2]);
            dequant_4bit_8_prep_zero(zeros[3] + 1, z1z16[3], y1y16[3]);
        }

        for (int p = 0; p < 4; p++)
        {
            half2 dq[4][4];
            const int4* b_ptr4 = (int4*) b_ptr;
            int4 load_int4 = *b_ptr4;

            dequant_4bit_8_gptq(load_int4.x, dq[0], z1z16[0], y1y16[0], size_n, false);
            dequant_4bit_8_gptq(load_int4.y, dq[1], z1z16[1], y1y16[1], size_n, false);
            dequant_4bit_8_gptq(load_int4.z, dq[2], z1z16[2], y1y16[2], size_n, false);
            dequant_4bit_8_gptq(load_int4.w, dq[3], z1z16[3], y1y16[3], size_n, false);

            b_ptr += size_n;
            //half* dqh = (half*)dq;
            if (b_q_perm)
            {
                for (int j = 0; j < 4; j++)
                {
                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                    b_.set4(perm[lk++], n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
                    b_.set4(perm[lk++], n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
                }
            }
            else
            {
                for (int j = 0; j < 4; j++)
                {
                    for (int v = 0; v < 4; v++) dq[v][j] = __hmul2(scales[v], dq[v][j]);
                    b_.set4(offset_k + lk++, n, __low2half(dq[0][j]), __low2half(dq[1][j]), __low2half(dq[2][j]), __low2half(dq[3][j]));
                    b_.set4(offset_k + lk++, n, __high2half(dq[0][j]), __high2half(dq[1][j]), __high2half(dq[2][j]), __high2half(dq[3][j]));
                }
            }
        }
        k += 32;
    }
}
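// reconstruct_exllama_kernel dequantizes the exllama-shuffled weight back into a
// dense fp16 matrix b, applying the group scales/zeros and, when b_q_perm is given,
// scattering rows back to their original (act-order) positions via the cached
// perm[] table. This path is used when the batch is large enough that a plain
// cuBLAS GEMM on the reconstructed matrix is expected to be faster than the fused
// kernel.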


void reconstruct_exllama
(
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_q_perm,
    half* out,
    int height,
    int width,
    int groups
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    gridDim.y = DIVIDE(height, BLOCK_KN_SIZE);
    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    reconstruct_exllama_kernel<<<gridDim, blockDim, 0, stream>>>
    (
        b_q_weight,
        b_q_perm,
        b_gptq_qzeros,
        b_gptq_scales,
        height,
        width,
        groups,
        out
    );
}


__global__ void gemm_half_q_half_alt_kernel(
    const half2* __restrict__ vec,
    const uint32_t* __restrict__ mat,
    half* __restrict__ mul,
    const half* __restrict__ scales,
    const uint32_t* __restrict__ zeros,
    const int* __restrict__ g_idx,
    int batch,
    int height,
    int width
)
{
    int zero_width = width / 8;
    int vec_height = height * 4;
    const int blockwidth2 = BLOCK_KN_SIZE / 2;
    int b = blockIdx.y * BLOCK_M_SIZE_MAX;
    int b_end = min(BLOCK_M_SIZE_MAX, batch - b);
    int h = BLOCK_KN_SIZE * blockIdx.z / 8;
    int h_end = min(BLOCK_KN_SIZE / 8, height - h) * 4;
    int w = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;

    __shared__ half2 blockvec[BLOCK_M_SIZE_MAX][blockwidth2];
    if (threadIdx.x < h_end) {
        for (int m = 0; m < b_end; ++m) {
          blockvec[m][threadIdx.x] =
              vec[(m + b) * vec_height + blockIdx.z * BLOCK_KN_SIZE / 2 +
                  threadIdx.x];
        }
    }

    __shared__ half2 deq2[256][8];
    int val = threadIdx.x / 8;
    int off = threadIdx.x % 8;
    for (; val < 256; val += BLOCK_KN_SIZE / 8) {
        deq2[val][off] = __halves2half2(
            __int2half_rn(val & 0xF), __int2half_rn(val >> 4)
        );
    }

    if (blockIdx.z == 0)
    {
        for (int m = 0; m < b_end; m++)
            mul[(b + m) * width + w] = __int2half_rn(0);
    }
    __syncthreads();

    int i = width * h + w;
    int g_h = h * 8;
    int k = 0;
    int z_w = w / 8;
    int z_mod = (w % 8) * 4;
    half2 res2;
    half res[BLOCK_M_SIZE_MAX] = {};

    unsigned int tmp;
    while (k < h_end) {
        tmp = mat[i];
        half2 scales_tmp[4];
        half2 zeros_tmp[4];
        for (int tmp_k = 0; tmp_k < 4; tmp_k++) {
            int g = g_idx[g_h + (k + tmp_k) * 2];
            int g2 = g_idx[g_h + (k + tmp_k) * 2 + 1];
            half scale_f = scales[g * width + w];
            half scale_f2 = scales[g2 * width + w];
            half2 scale = __halves2half2(scale_f, scale_f2);
            half2 zero = __halves2half2(
                __hmul(scale_f, __int2half_rn(-((zeros[g * zero_width + z_w] >> z_mod) & 0xF) - 1)),
                __hmul(scale_f2, __int2half_rn(-((zeros[g2 * zero_width + z_w] >> z_mod) & 0xF) - 1))
            );
            scales_tmp[tmp_k] = scale;
            zeros_tmp[tmp_k] = zero;
        }
        for (int m = 0; m < b_end; m++) {
#ifndef USE_ROCM
            res2 = {};
#else
            res2.x = __half_as_ushort(__float2half(0));
            res2.y = __half_as_ushort(__float2half(0));
#endif
            res2 = __hfma2(__hfma2(deq2[(tmp >>  0) & 0xff][off], scales_tmp[0], zeros_tmp[0]), blockvec[m][k + 0], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >>  8) & 0xff][off], scales_tmp[1], zeros_tmp[1]), blockvec[m][k + 1], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 16) & 0xff][off], scales_tmp[2], zeros_tmp[2]), blockvec[m][k + 2], res2);
            res2 = __hfma2(__hfma2(deq2[(tmp >> 24) & 0xff][off], scales_tmp[3], zeros_tmp[3]), blockvec[m][k + 3], res2);
#ifndef USE_ROCM
            res[m] = __hadd(res[m], __hadd(res2.x, res2.y));
#else
            res[m] = __hadd(res[m], __hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)));
#endif
        }
        i += width;
        k += 4;
    }
    for (int m = 0; m < b_end; m++) {
        atomicAdd(&mul[(b + m) * width + w], res[m]);
    }
}
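// gemm_half_q_half_alt_kernel is the fallback ("alt") matmul for GPTQ checkpoints
// that are not preprocessed for the exllama path. It dequantizes through a shared
// 256-entry lookup table (deq2) that expands one byte into two 4-bit values at
// once, applies the per-group scale/zero selected through g_idx, and accumulates
// per-row partial results that are finally added to the output with atomicAdd.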


void gemm_half_q_half_alt
(
    const half* a,
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_g_idx,
    half* c,
    int size_m,
    int size_n,
    int size_k
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    blockDim.z = 1;
    gridDim.x = DIVIDE(size_n, BLOCK_KN_SIZE);
    gridDim.y = DIVIDE(size_m, BLOCK_M_SIZE_MAX);
    gridDim.z = DIVIDE(size_k, BLOCK_KN_SIZE);

    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    gemm_half_q_half_alt_kernel<<<gridDim, blockDim, 0, stream>>>
    (
        (const half2*) a,
        b_q_weight,
        c,
        b_gptq_scales,
        b_gptq_qzeros,
        b_g_idx,
        size_m,
        size_k / 8,
        size_n
    );
}


__global__ void reconstruct_gptq_kernel
(
    const uint32_t* __restrict__ w,
    const half* __restrict__ w_scales,
    const uint32_t* __restrict__ w_zeros,
    const int* __restrict__ g_idx,
    const int height,
    const int width,
    const int group,
    half* __restrict__ out
)
{
    // Start of block

    int column = BLOCK_KN_SIZE * blockIdx.x + threadIdx.x;
    int row = blockIdx.y * 8;
    if (column >= width) return;

    // Views

    MatrixView_q4_column w_(w, height, width);
    MatrixView_half_rw out_(out, height, width);
    MatrixView_half w_scales_(w_scales, group, width);
    MatrixView_q4_row w_zeros_(w_zeros, group, width);

    uint32_t w_read = w_.item_uint32_t(row, column);
    half* out_ptr = out_.item_ptr(row, column);

    #pragma unroll
    for (int s = 0; s < 32; s += 4)
    {
        int group = g_idx[row + s / 4];
        half w_scale = w_scales_.item(group, column);
        uint32_t w_zero = w_zeros_.item(group, column) + 1;
        half w_item = __hmul(__int2half_rn((int)((w_read >> s) & 0x0f) - w_zero), w_scale);
        *out_ptr = w_item; out_ptr += out_.width;
    }
}


void reconstruct_gptq
(
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_g_idx,
    half* out,
    int height,
    int width,
    int groups
)
{
    dim3 blockDim, gridDim;
    blockDim.x = BLOCK_KN_SIZE;
    blockDim.y = 1;
    gridDim.y = DIVIDE(height, 8);
    gridDim.x = DIVIDE(width, BLOCK_KN_SIZE);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    reconstruct_gptq_kernel<<<gridDim, blockDim, 0, stream>>>
    (
        b_q_weight,
        b_gptq_scales,
        b_gptq_qzeros,
        b_g_idx,
        height,
        width,
        groups,
        out
    );
}


void gemm_half_q_half_cuda
(
    cublasHandle_t cublas_handle,
    const half* a,
    const uint32_t* b_q_weight,
    const uint32_t* b_gptq_qzeros,
    const half* b_gptq_scales,
    const int* b_g_idx,
    half* c,
    half* temp_dq,
    int size_m,
    int size_n,
    int size_k,
    int groups,
    bool use_exllama
)
{
    if ((use_exllama && size_m > MAX_Q_GEMM_ROWS) || (!use_exllama && size_m > MAX_ALT_GEMM_ROWS)) {
        // Reconstruct FP16 matrix, then cuBLAS
        if (use_exllama) {
            reconstruct_exllama(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, temp_dq,
                                size_k, size_n, groups);
        }
        else
        {
            reconstruct_gptq(b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                             temp_dq, size_k, size_n, groups);
        }

        const half alpha = __float2half(1.0f);
        const half beta = __float2half(0.0f);
        cublasHgemm(cublas_handle,
                    CUBLAS_OP_N,
                    CUBLAS_OP_N,
                    size_n, size_m, size_k,
                    &alpha, temp_dq, size_n,
                            a,       size_k,
                    &beta,  c,       size_n);
    }
    else if (use_exllama)
    {
        // Quantized matmul
        int max_chunks = size_m / BLOCK_M_SIZE_MAX;
        int last_chunk = max_chunks * BLOCK_M_SIZE_MAX;
        int last_chunk_size = size_m - last_chunk;

        if (max_chunks)
        {
            gemm_half_q_half_cuda_part(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                                        c, last_chunk, size_n, size_k, BLOCK_M_SIZE_MAX,
                                        groups);
        }

        if (last_chunk_size)
        {
            gemm_half_q_half_cuda_part(a + last_chunk * size_k, b_q_weight, b_gptq_qzeros,
                                        b_gptq_scales, b_g_idx, c + last_chunk * size_n,
                                        last_chunk_size, size_n, size_k, last_chunk_size,
                                        groups);
        }
    }
    else
    {
        gemm_half_q_half_alt(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx,
                             c, size_m, size_n, size_k);
    }
}
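// Dispatch summary: for inputs with many rows (size_m above MAX_Q_GEMM_ROWS for the
// exllama path, MAX_ALT_GEMM_ROWS otherwise) the weight is first reconstructed to
// fp16 in temp_dq and multiplied with cublasHgemm; otherwise the exllama path runs
// the fused kernel in chunks of BLOCK_M_SIZE_MAX rows (plus a remainder chunk), and
// the non-exllama path uses gemm_half_q_half_alt.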


__global__ void shuffle_kernel
(
    uint32_t* __restrict__ b_q_weight,
    const int size_k,
    const int size_n
)
{
    int n = blockIdx.x * THREADS_X + threadIdx.x;
    if (n >= size_n) return;
    int k = 0;
    uint32_t* b_ptr = b_q_weight + n;
    while (k < size_k) { shuffle_4bit_8 (b_ptr, size_n); b_ptr += 1 * size_n; k +=  8; }
}


__global__ void make_sequential_kernel
(
    const uint32_t* __restrict__ w,
    uint32_t* __restrict__ w_new,
    const int* __restrict__ q_perm,
    const int w_height,
    const int w_width
)
{
    const uint64_t* w2 = (uint64_t*) w;
    uint64_t* w_new2 = (uint64_t*) w_new;
    int w2_stride = w_width >> 1;
    int w2_column = THREADS_X * blockIdx.x + threadIdx.x;
    if (w2_column >= w2_stride) return;
    int w_new2_row = blockIdx.y;
    int q_perm_idx = w_new2_row << 3;
    uint64_t dst = 0;

    #pragma unroll
    for (int i = 0; i < 8; i++)
    {
        int source_row = q_perm[q_perm_idx++];

        int w2_row = source_row >> 3;
        int w2_subrow = source_row & 0x07;
        int w2_row_shift = w2_subrow << 2;
        int wnew2_row_shift = i << 2;

        uint64_t src = w2[w2_row * w2_stride + w2_column];
        src >>= w2_row_shift;
        src &= 0x0000000f0000000f;
        src <<= wnew2_row_shift;
        dst |= src;
    }
    w_new2[w_new2_row * w2_stride + w2_column] = dst;
}
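// make_sequential_kernel repacks the 4-bit weights so that rows appear in the order
// given by q_perm: each thread rebuilds one 64-bit word of the new layout (8 nibble
// pairs spanning two adjacent columns) by gathering the corresponding nibbles from
// the permuted source rows. Afterwards the packed weight is sequential in k, and
// the act-order permutation only has to be applied to the activations.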


void shuffle_exllama_weight
(
    uint32_t* q_weight,
    int* q_perm,
    int height,
    int width
)
{
    if (q_perm)
    {
        uint32_t* new_qweight = NULL;
        cudaMalloc(&new_qweight, height / 8 * width * sizeof(uint32_t));

        dim3 blockDim, gridDim;
        blockDim.x = THREADS_X;
        blockDim.y = 1;
        gridDim.x = DIVIDE(width, THREADS_X);
        gridDim.y = height / 8;

        const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
        make_sequential_kernel<<<gridDim, blockDim, 0, stream>>>
        (
            q_weight,
            new_qweight,
            q_perm,
            height / 8,
            width
        );
        // Replace qweights
        cudaMemcpyAsync(q_weight, new_qweight, height / 8 * width * sizeof(uint32_t), cudaMemcpyDeviceToDevice);
        // Cleanup
        cudaDeviceSynchronize();
        cudaFree(new_qweight);
    }
    dim3 blockDim, gridDim;
    blockDim.x = THREADS_X;
    blockDim.y = 1;
    gridDim.x = DIVIDE(width, THREADS_X);
    gridDim.y = 1;
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    shuffle_kernel<<<gridDim, blockDim, 0, stream>>>(q_weight, height, width);
}

}  // namespace gptq
}  // namespace vllm

torch::Tensor gptq_gemm
(
    torch::Tensor a,
    torch::Tensor b_q_weight,
    torch::Tensor b_gptq_qzeros,
    torch::Tensor b_gptq_scales,
    torch::Tensor b_g_idx,
    bool use_exllama
)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
    auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
    at::Tensor c = torch::empty({a.size(0), b_q_weight.size(1)}, options);
    at::Tensor temp_dq = torch::empty({b_q_weight.size(0) * 8, b_q_weight.size(1)}, options);

    vllm::gptq::gemm_half_q_half_cuda
    (
        at::cuda::getCurrentCUDABlasHandle(),
        (const half*) a.data_ptr(),
        (const uint32_t*) b_q_weight.data_ptr(),
        (const uint32_t*)b_gptq_qzeros.data_ptr(),
        (const half*) b_gptq_scales.data_ptr(),
        b_g_idx.device().is_meta() ? NULL : (const int*) b_g_idx.data_ptr(),
        (half*) c.data_ptr(),
        (half*) temp_dq.data_ptr(),
        c.size(0),  // m
        c.size(1),  // n
        a.size(1),  // k
        b_gptq_qzeros.size(0),  // group number
        use_exllama
    );
    return c;
}

void gptq_shuffle
(
    torch::Tensor q_weight,
    torch::Tensor q_perm
)
{
    const at::cuda::OptionalCUDAGuard device_guard(device_of(q_weight));
    vllm::gptq::shuffle_exllama_weight(
        (uint32_t*) q_weight.data_ptr(),
        q_perm.device().is_meta() ? NULL : (int*) q_perm.data_ptr(),
        q_weight.size(0) * 8,
        q_weight.size(1)
    );
}
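// Usage note (inferred from the call sites in this file, not a binding contract):
// gptq_shuffle is intended to run once per weight tensor after loading, with a
// meta-device q_perm when the checkpoint has no act-order permutation, while
// gptq_gemm is the per-forward entry point that chooses between the fused kernels
// and the reconstruct + cuBLAS path.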
csrc/quantization/gptq/qdq_4.cuh  (new file, 235 lines)
@@ -0,0 +1,235 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_4_cuh
#define _qdq_4_cuh

#include "qdq_util.cuh"

namespace vllm {
namespace gptq {
// Permutation:
//
// 77775555 33331111  66664444 22220000

__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
    uint32_t qa = q[0];
    uint32_t qb = 0;

    #pragma unroll
    for (int i = 0; i < 4; i++)
    {
        uint32_t qa0 = qa & 0x0f;
        uint32_t qa1 = (qa & 0xf0) >> 4;
        qa >>= 8;
        qb |= (qa1 << (i * 4 + 16));
        qb |= (qa0 << (i * 4));
    }
    q[0] = qb;
}
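// shuffle_4bit_8 rearranges the eight 4-bit values packed in one uint32 so that the
// even-numbered nibbles (0,2,4,6) end up in the low 16 bits and the odd-numbered
// ones (1,3,5,7) in the high 16 bits, matching the "77775555 33331111 66664444
// 22220000" layout documented above. With that layout, the dequant routines can
// peel off value pairs with a single mask per half2 instead of shifting each
// nibble individually.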

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    const uint32_t c0 = 0x64006400;
    const half y16_ = __float2half_rn(1.0f / 16.0f);
    const half2 y16 = __halves2half2(y16_, y16_);
    const half z1_  = __float2half_rn(-1024.0f         - 8.0f);
    const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
    const half2 z1  = __halves2half2(z1_,  z1_);
    const half2 z16 = __halves2half2(z16_, z16_);

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1])      + 1024
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5])      + 1024
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024

    dq[0] = __hadd2(q0.as_half2, z1);
    dq[1] = __hfma2(q1.as_half2, y16, z16);
    dq[2] = __hadd2(q2.as_half2, z1);
    dq[3] = __hfma2(q3.as_half2, y16, z16);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1z16)[2],
    half2 (&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    half2 scale2 = __half2half2(scale);

    z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
    z1z16[1] = __hmul2(scale2, __half2half2(z16));

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __hmul2(scale2, __half2half2(y1));
    y1y16[1] = __hmul2(scale2, __half2half2(y16));
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1z16)[2],
    half2(&y1y16)[2]
)
{
    half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
    half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));

    z1z16[0] = __half2half2(z1.as_half);
    z1z16[1] = __half2half2(z16);

    const half y1 = __float2half_rn(1.0f);
    const half y16 = __float2half_rn(1.0f / 16.0f);

    y1y16[0] = __half2half2(y1);
    y1y16[1] = __half2half2(y16);
}


__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1z16)[2],
    half2 (&y1y16)[2],
    int stride,
    bool scaled
)
{
    const uint32_t c0 = 0x64006400;

    uint32_t qa = q_0;
    half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0]      + 1024, q[1]      + 1024 )
    half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
    qa >>= 8;
    half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4]      + 1024, q[5]      + 1024 )
    half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )

    if (scaled)
    {
        dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]);  // half2( q[0] * s - z * s, q[1] * s - z * s)
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] * s - z * s, q[3] * s - z * s)
        dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
    }
    else
    {
        dq[0] = __hadd2(q0.as_half2,           z1z16[0]);  // half2( q[0] - z, q[1] - z )
        dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]);  // half2( q[2] - z, q[3] - z )
        dq[2] = __hadd2(q2.as_half2,           z1z16[0]);  // half2( q[4] - z, q[5] - z )
        dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);  // half2( q[6] - z, q[7] - z )
    }
}
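// The dequant helpers above rely on a standard fp16 bit trick: OR-ing a 4-bit value
// q into the constant 0x6400 yields the half-precision number 1024 + q (and the
// 0x00f000f0 positions yield 1024 + 16*q), so subtracting the precomputed z1z16
// terms, or fused-multiply-adding with the y1y16 (1 and 1/16, optionally times the
// scale) factors, recovers (q - zero) or (q - zero) * scale without any
// integer-to-float conversion instructions.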
}  // namespace gptq
}  // namespace vllm

#else

namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
    uint32_t* q,
    int stride
)
{
}

__forceinline__ __device__ void dequant_4bit_8
(
    const uint32_t q_0,
    half2 (&dq)[4],
    int stride
)
{
    half dqh[8];
    for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);

    for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
    const uint32_t zero,
    const half scale,
    half2 (&z1)[2],
    half2 (&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z = __hmul(z, scale);
    z1[0] = __half2half2(z);
    y1[0] = __half2half2(scale);
}

__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
    const uint32_t zero,
    half2(&z1)[2],
    half2(&y1)[2]
)
{
    half z = __int2half_rn(-((int)zero));
    z1[0] = __half2half2(z);
}

__forceinline__ __device__ void dequant_4bit_8_gptq
(
    const uint32_t q_0,
    half2 (&dq)[4],
    half2 (&z1)[2],
    half2 (&y1)[2],
    int stride,
    bool scaled
)
{
    half2 dqh2[8];

    uint32_t qa = q_0;
    for (int i = 0; i < 4; i++)
    {
        half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
        half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
        dqh2[i] = __halves2half2(d0, d1);
    }

    if (scaled)
    {
        dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
        dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
        dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
        dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
    }
    else
    {
        dq[0] = __hadd2(dqh2[0], z1[0]);
        dq[1] = __hadd2(dqh2[1], z1[0]);
        dq[2] = __hadd2(dqh2[2], z1[0]);
        dq[3] = __hadd2(dqh2[3], z1[0]);
    }
}

}  // namespace gptq
}  // namespace vllm

#endif
csrc/quantization/gptq/qdq_util.cuh  (new file, 60 lines)
@@ -0,0 +1,60 @@
/*
Copied from https://github.com/turboderp/exllamav2
*/

#ifndef _qdq_util_cuh
#define _qdq_util_cuh

namespace vllm {
namespace gptq {

union half2_uint32
{
    uint32_t as_uint32;
    half2 as_half2;
    __device__ half2_uint32(uint32_t val) : as_uint32(val) {}
    __device__ half2_uint32(half2 val) : as_half2(val) {}
};

union half_uint16
{
    uint16_t as_uint16;
    half as_half;
    __device__ half_uint16(uint16_t val) : as_uint16(val) {}
    __device__ half_uint16(half val) : as_half(val) {}
};
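// These unions are bit-level reinterpretation helpers: they let the packed integer
// patterns produced by the 0x6400-style tricks in qdq_4.cuh be read back as
// half/half2 values (and vice versa) without going through conversion intrinsics.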

// Max_scale premultiplied by 1/256

__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
    int qs_i = qs + 1;
    half qs_h = __int2half_rn(qs_i * qs_i);
    qs_h = __hmul(qs_h, max_scale);
    return qs_h;
}

__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
    return __hmul(__int2half_rn(q - qzero), scale);
}

__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
    //return __hsub(__int2half_rn(q), __int2half_rn(qzero));
    return __int2half_rn(q - qzero);
}

__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
    return (int)((q >> shift) & mask);
}

__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
    return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}

}  // namespace gptq
}  // namespace vllm
#endif
@@ -7,6 +7,7 @@
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>

#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
@@ -20,9 +21,17 @@ __device__ inline unsigned int as_unsigned(int i) {

// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
    const  half2* __restrict__ vec,
#else
    const  __half2* __restrict__ vec,
#endif
    const    int* __restrict__ mat,
#ifndef USE_ROCM
           half2* __restrict__ mul,
#else
          float2* __restrict__ mul,
#endif
    const  __half* __restrict__ lookup_table,
    int height,
    int width,
@@ -35,7 +44,11 @@ __global__ void NUQ4MatMulKernel(
  int row = BLOCKHEIGHT4 * blockIdx.x;
  int col =  BLOCKWIDTH * blockIdx.y + threadIdx.x;

#ifndef USE_ROCM
  __shared__ half2 blockvec[blockwidth2];
#else
  __shared__ __half2 blockvec[blockwidth2];
#endif

  __shared__ __half deq2[16][BLOCKWIDTH];
  int off = threadIdx.x;
@@ -46,8 +59,13 @@ __global__ void NUQ4MatMulKernel(
  }

  __half res;
#ifndef USE_ROCM
  half2 res2;
  half2 tmp2;
#else
  __half2 res2;
  __half2 tmp2;
#endif

  int i;
  int k;
@@ -68,48 +86,96 @@ __global__ void NUQ4MatMulKernel(
    while (k < blockwidth2) {
      tmp1 = as_unsigned(mat[i]);

#ifndef USE_ROCM
      res2 = {};
      tmp2 = {};
#else
      res2.x = __half_as_ushort(__float2half(0));
      res2.y = __half_as_ushort(__float2half(0));
      tmp2.x = __half_as_ushort(__float2half(0));
      tmp2.y = __half_as_ushort(__float2half(0));
#endif

      lut_index1 = tmp1 & 0xF;
      lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 0], res2);

      lut_index1 = (tmp1 >> 8) & 0xF;
      lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 1], res2);

      lut_index1 = (tmp1 >> 16) & 0xF;
      lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 2], res2);

      lut_index1 = (tmp1 >> 24) & 0xF;
      lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
      tmp2.x = deq2[lut_index1][off];
      tmp2.y = deq2[lut_index2][off];
#else
      tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
      tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
      res2 = __hfma2(tmp2, blockvec[k + 3], res2);

#ifndef USE_ROCM
      res = __hadd(__hadd(res2.x, res2.y), res);
#else
      res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)), res);
#endif

      i += width;
      k += 4;
    }

    // col%2 -> only set one of the two values
#ifndef USE_ROCM
    half2 res3 = {};
    if (col % 2 == 0) {
      res3.x = res;
    } else {
      res3.y = res;
    }
#else
    __half2 res3;
    res3.x = __half_as_ushort(__float2half(0));
    res3.y = __half_as_ushort(__float2half(0));
    if (col % 2 == 0) {
      res3.x = __half_as_ushort(res);
    } else {
      res3.y = __half_as_ushort(res);
    }
#endif

#ifndef USE_ROCM
    atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
    int tmp_addr = b * width / 2 + col / 2;
    atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
    atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
  }
}

@@ -135,11 +201,22 @@ void squeezellm_gemm(
  );
  dim3 threads(BLOCKWIDTH);

  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads>>>(
  const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
    (half2*) vec.data<at::Half>(),
#else
    (__half2*) vec.data_ptr<at::Half>(),
#endif
    mat.data_ptr<int>(),
#ifndef USE_ROCM
    (half2*) mul.data<at::Half>(),
    (__half*) lookup_table.data<at::Half>(),
#else
    (float2*) mul.data_ptr<float>(),
    (__half*) lookup_table.data_ptr<at::Half>(),
#endif
    height, width, batch, vec_height
  );
}

@@ -17,13 +17,15 @@
 */
#pragma once

#include "cuda_compat.h"

namespace vllm {

template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
  for (int mask = 16; mask > 0; mask >>= 1)
    val += __shfl_xor_sync(0xffffffff, val, mask, 32);
    val += VLLM_SHFL_XOR_SYNC(val, mask);
  return val;
}


@@ -9,11 +9,15 @@
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
# import os
# import sys
# sys.path.insert(0, os.path.abspath('.'))

import os
import sys
from sphinx.ext import autodoc
import logging

sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))

logger = logging.getLogger(__name__)

# -- Project information -----------------------------------------------------

@@ -21,7 +25,6 @@ project = 'vLLM'
copyright = '2023, vLLM Team'
author = 'the vLLM Team'


# -- General configuration ---------------------------------------------------

# Add any Sphinx extension module names here, as strings. They can be
@@ -32,6 +35,8 @@ extensions = [
    "sphinx.ext.viewcode",
    "sphinx.ext.intersphinx",
    "sphinx_copybutton",
    "sphinx.ext.autodoc",
    "sphinx.ext.autosummary",
]

# Add any paths that contain templates here, relative to this directory.
@@ -55,7 +60,6 @@ html_title = project
html_theme = 'sphinx_book_theme'
html_logo = 'assets/logos/vllm-logo-text-light.png'
html_theme_options = {
    'logo_only': True,
    'path_to_docs': 'docs/source',
    'repository_url': 'https://github.com/vllm-project/vllm',
    'use_repository_button': True,
@@ -64,4 +68,31 @@ html_theme_options = {
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['_static']
# html_static_path = ['_static']

# Mock out external dependencies here.
autodoc_mock_imports = [
    "torch", "transformers", "psutil", "aioprometheus", "sentencepiece",
    "vllm.cuda_utils", "vllm._C"
]

for mock_target in autodoc_mock_imports:
    if mock_target in sys.modules:
        logger.info(
            f"Potentially problematic mock target ({mock_target}) found; "
            "autodoc_mock_imports cannot mock modules that have already "
            "been loaded into sys.modules when the sphinx build starts.")


class MockedClassDocumenter(autodoc.ClassDocumenter):
    """Remove note about base class when a class is derived from object."""

    def add_line(self, line: str, source: str, *lineno: int) -> None:
        if line == "   Bases: :py:class:`object`":
            return
        super().add_line(line, source, *lineno)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
autodoc.ClassDocumenter = MockedClassDocumenter
 | 
			
		||||
 | 
			
		||||
navigation_with_keys = False
 | 
			
		||||
 | 
			
		||||
							
								
								
									
docs/source/dev/engine/async_llm_engine.rst (new file) @ -0,0 +1,7 @@
 | 
			
		||||
 | 
			
		||||
AsyncLLMEngine
 | 
			
		||||
=================================
 | 
			
		||||
 | 
			
		||||
.. autoclass:: vllm.engine.async_llm_engine.AsyncLLMEngine
 | 
			
		||||
    :members: generate, abort
 | 
			
		||||
    :show-inheritance:
 | 
			
		||||
							
								
								
									
docs/source/dev/engine/engine_index.rst (new file) @ -0,0 +1,13 @@
 | 
			
		||||
vLLM Engine
 | 
			
		||||
=================================
 | 
			
		||||
 | 
			
		||||
.. automodule:: vllm.engine
 | 
			
		||||
.. currentmodule:: vllm.engine
 | 
			
		||||
 | 
			
		||||
.. toctree::
 | 
			
		||||
   :maxdepth: 2
 | 
			
		||||
   :caption: Engines
 | 
			
		||||
 | 
			
		||||
   llm_engine
 | 
			
		||||
   async_llm_engine
 | 
			
		||||
 | 
			
		||||
							
								
								
									
docs/source/dev/engine/llm_engine.rst (new file) @ -0,0 +1,6 @@
 | 
			
		||||
LLMEngine
 | 
			
		||||
=================================
 | 
			
		||||
 | 
			
		||||
.. autoclass:: vllm.engine.llm_engine.LLMEngine
 | 
			
		||||
    :members: add_request, abort_request, step, _init_cache
 | 
			
		||||
    :show-inheritance:
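
A minimal sketch of driving the engine with ``add_request`` and ``step`` (the model name is only a placeholder):

.. code-block:: python

    from vllm import EngineArgs, LLMEngine, SamplingParams

    # Build an engine from engine arguments.
    engine = LLMEngine.from_engine_args(EngineArgs(model="facebook/opt-125m"))

    # add_request enqueues a prompt; step advances the engine by one iteration.
    engine.add_request("0", "Hello, my name is", SamplingParams(max_tokens=16))
    while engine.has_unfinished_requests():
        for request_output in engine.step():
            if request_output.finished:
                print(request_output.outputs[0].text)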
 | 
			
		||||
							
								
								
									
docs/source/getting_started/amd-installation.rst (new file) @ -0,0 +1,172 @@
 | 
			
		||||
.. _installation_rocm:
 | 
			
		||||
 | 
			
		||||
Installation with ROCm
 | 
			
		||||
======================
 | 
			
		||||
 | 
			
		||||
vLLM 0.2.4 onwards supports model inference and serving on AMD GPUs with ROCm.
 | 
			
		||||
At the moment AWQ quantization is not supported in ROCm, but SqueezeLLM quantization has been ported.
 | 
			
		||||
Data types currently supported in ROCm are FP16 and BF16.
 | 
			
		||||
 | 
			
		||||
Requirements
 | 
			
		||||
------------
 | 
			
		||||
 | 
			
		||||
* OS: Linux
 | 
			
		||||
* Python: 3.8 -- 3.11
 | 
			
		||||
* GPU: MI200s (gfx90a), MI300 (gfx942), Radeon RX 7900 series (gfx1100)
 | 
			
		||||
* Pytorch 2.0.1/2.1.1/2.2
 | 
			
		||||
* ROCm 5.7 (Verified on python 3.10) or ROCm 6.0 (Verified on python 3.9)
 | 
			
		||||
 | 
			
		||||
Installation options:
 | 
			
		||||
 | 
			
		||||
#. :ref:`(Recommended) Quick start with vLLM pre-installed in Docker Image <quick_start_docker_rocm>`
 | 
			
		||||
#. :ref:`Build from source <build_from_source_rocm>`
 | 
			
		||||
#. :ref:`Build from source with docker <build_from_source_docker_rocm>`
 | 
			
		||||
 | 
			
		||||
.. _quick_start_docker_rocm:
 | 
			
		||||
 | 
			
		||||
(Recommended) Option 1: Quick start with vLLM pre-installed in Docker Image
 | 
			
		||||
---------------------------------------------------------------------------
 | 
			
		||||
 | 
			
		||||
This option is for ROCm 5.7 only:
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ docker pull embeddedllminfo/vllm-rocm:vllm-v0.2.4
 | 
			
		||||
    $ docker run -it \
 | 
			
		||||
       --network=host \
 | 
			
		||||
       --group-add=video \
 | 
			
		||||
       --ipc=host \
 | 
			
		||||
       --cap-add=SYS_PTRACE \
 | 
			
		||||
       --security-opt seccomp=unconfined \
 | 
			
		||||
       --device /dev/kfd \
 | 
			
		||||
       --device /dev/dri \
 | 
			
		||||
       -v <path/to/model>:/app/model \
 | 
			
		||||
       embeddedllminfo/vllm-rocm \
 | 
			
		||||
       bash
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
.. _build_from_source_rocm:
 | 
			
		||||
 | 
			
		||||
Option 2: Build from source
 | 
			
		||||
---------------------------
 | 
			
		||||
 | 
			
		||||
You can build and install vLLM from source:
 | 
			
		||||
 | 
			
		||||
The instructions below are for ROCm 5.7 only.
At the time of this documentation update, the PyTorch wheel for ROCm 6.0 is not yet available on the PyTorch website.
 | 
			
		||||
 | 
			
		||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
 | 
			
		||||
 | 
			
		||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
 | 
			
		||||
- `Pytorch <https://pytorch.org/>`_
 | 
			
		||||
 | 
			
		||||
    .. code-block:: console
 | 
			
		||||
 | 
			
		||||
        $ pip install torch==2.2.0.dev20231206+rocm5.7 --index-url https://download.pytorch.org/whl/nightly/rocm5.7 # tested version
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
 | 
			
		||||
 | 
			
		||||
    Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
    - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
 | 
			
		||||
    - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
 | 
			
		||||
    - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
 | 
			
		||||
    - You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 | 
			
		||||
 | 
			
		||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
 | 
			
		||||
 | 
			
		||||
    .. code-block:: console
 | 
			
		||||
 | 
			
		||||
        $ pip install xformers==0.0.23 --no-deps
 | 
			
		||||
        $ bash patch_xformers.rocm.sh
 | 
			
		||||
 | 
			
		||||
3. Build vLLM.
 | 
			
		||||
 | 
			
		||||
    .. code-block:: console
 | 
			
		||||
 | 
			
		||||
        $ cd vllm
 | 
			
		||||
        $ pip install -U -r requirements-rocm.txt
 | 
			
		||||
        $ python setup.py install # This may take 5-10 minutes. Currently, `pip install .` does not work for ROCm installation.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
.. _build_from_source_docker_rocm:
 | 
			
		||||
 | 
			
		||||
Option 3: Build from source with docker
 | 
			
		||||
-----------------------------------------------------
 | 
			
		||||
 | 
			
		||||
You can build and install vLLM from source:
 | 
			
		||||
 | 
			
		||||
Build a docker image from `Dockerfile.rocm`, and launch a docker container.
 | 
			
		||||
 | 
			
		||||
The `Dockerfile.rocm` is designed to support both ROCm 5.7 and ROCm 6.0 and later versions. It provides flexibility to customize the build of the docker image using the following arguments:
 | 
			
		||||
 | 
			
		||||
* `BASE_IMAGE`: specifies the base image used when running ``docker build``, specifically the PyTorch on ROCm base image. We have tested ROCm 5.7 and ROCm 6.0. The default is `rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1`
 | 
			
		||||
* `FX_GFX_ARCHS`: specifies the GFX architecture that is used to build flash-attention, for example, `gfx90a;gfx942` for MI200 and MI300. The default is `gfx90a;gfx942`
 | 
			
		||||
* `FA_BRANCH`: specifies the branch used to build the flash-attention in `ROCmSoftwarePlatform's flash-attention repo <https://github.com/ROCmSoftwarePlatform/flash-attention>`_. The default is `3d2b6f5`
 | 
			
		||||
* `BUILD_FA`: specifies whether to build flash-attention. For `Radeon RX 7900 series (gfx1100) <https://rocm.docs.amd.com/projects/radeon/en/latest/index.html>`_, this should be set to 0 until flash-attention supports this target.
 | 
			
		||||
 | 
			
		||||
Their values can be passed in when running ``docker build`` with ``--build-arg`` options.
 | 
			
		||||
 | 
			
		||||
For example, to build docker image for vllm on ROCm 5.7, you can run:
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ docker build --build-arg BASE_IMAGE="rocm/pytorch:rocm5.7_ubuntu22.04_py3.10_pytorch_2.0.1" \
 | 
			
		||||
       -f Dockerfile.rocm -t vllm-rocm . 
 | 
			
		||||
 | 
			
		||||
To build vLLM on ROCm 6.0, you can use the default:
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ docker build -f Dockerfile.rocm -t vllm-rocm . 
 | 
			
		||||
    $ docker run -it \
 | 
			
		||||
       --network=host \
 | 
			
		||||
       --group-add=video \
 | 
			
		||||
       --ipc=host \
 | 
			
		||||
       --cap-add=SYS_PTRACE \
 | 
			
		||||
       --security-opt seccomp=unconfined \
 | 
			
		||||
       --device /dev/kfd \
 | 
			
		||||
       --device /dev/dri \
 | 
			
		||||
       -v <path/to/model>:/app/model \
 | 
			
		||||
       vllm-rocm \
 | 
			
		||||
       bash
 | 
			
		||||
 | 
			
		||||
Alternatively, if you plan to install vLLM-ROCm on a local machine or start from a fresh docker image (e.g. rocm/pytorch), you can follow the steps below:
 | 
			
		||||
 | 
			
		||||
0. Install prerequisites (skip if you are already in an environment/docker with the following installed):
 | 
			
		||||
 | 
			
		||||
- `ROCm <https://rocm.docs.amd.com/en/latest/deploy/linux/index.html>`_
 | 
			
		||||
- `Pytorch <https://pytorch.org/>`_
 | 
			
		||||
- `hipBLAS <https://rocm.docs.amd.com/projects/hipBLAS/en/latest/install.html>`_
 | 
			
		||||
 | 
			
		||||
1. Install `flash attention for ROCm <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm>`_
 | 
			
		||||
 | 
			
		||||
    Install ROCm's flash attention (v2.0.4) following the instructions from `ROCmSoftwarePlatform/flash-attention <https://github.com/ROCmSoftwarePlatform/flash-attention/tree/flash_attention_for_rocm#amd-gpurocm-support>`_
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
    - If you are using rocm5.7 with pytorch 2.1.0 onwards, you don't need to apply the `hipify_python.patch`. You can build the ROCm flash attention directly.
 | 
			
		||||
    - If you fail to install `ROCmSoftwarePlatform/flash-attention`, try cloning from the commit `6fd2f8e572805681cd67ef8596c7e2ce521ed3c6`.
 | 
			
		||||
    - ROCm's Flash-attention-2 (v2.0.4) does not support sliding windows attention.
 | 
			
		||||
    - You might need to downgrade the "ninja" version to 1.10, as it is not used when compiling flash-attention-2 (e.g. `pip install ninja==1.10.2.4`)
 | 
			
		||||
 | 
			
		||||
2. Setup `xformers==0.0.23` without dependencies, and apply patches to adapt for ROCm flash attention
 | 
			
		||||
 | 
			
		||||
    .. code-block:: console
 | 
			
		||||
 | 
			
		||||
        $ pip install xformers==0.0.23 --no-deps
 | 
			
		||||
        $ bash patch_xformers.rocm.sh
 | 
			
		||||
 | 
			
		||||
3. Build vLLM.
 | 
			
		||||
 | 
			
		||||
    .. code-block:: console
 | 
			
		||||
 | 
			
		||||
        $ cd vllm
 | 
			
		||||
        $ pip install -U -r requirements-rocm.txt
 | 
			
		||||
        $ python setup.py install # This may take 5-10 minutes.
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
 | 
			
		||||
    - You may need to turn on the ``--enforce-eager`` flag if you experience a process hang when running the `benchmark_throughput.py` script to test your installation.
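
    A rough Python-API illustration of the same setting (``enforce_eager`` mirrors the CLI flag and disables CUDA/HIP graph capture; the model name is just a placeholder):

    .. code-block:: python

        from vllm import LLM

        # Skip CUDA/HIP graph capture, which is what --enforce-eager does on the CLI.
        llm = LLM(model="facebook/opt-125m", enforce_eager=True)
        print(llm.generate("Hello, my name is"))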
 | 
			
		||||
 | 
			
		||||
@ -20,7 +20,7 @@ You can install vLLM using pip:
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ # (Optional) Create a new conda environment.
 | 
			
		||||
    $ conda create -n myenv python=3.8 -y
 | 
			
		||||
    $ conda create -n myenv python=3.9 -y
 | 
			
		||||
    $ conda activate myenv
 | 
			
		||||
 | 
			
		||||
    $ # Install vLLM with CUDA 12.1.
 | 
			
		||||
@ -34,13 +34,18 @@ You can install vLLM using pip:
 | 
			
		||||
    .. code-block:: console
 | 
			
		||||
 | 
			
		||||
        $ # Install vLLM with CUDA 11.8.
 | 
			
		||||
        $ # Replace `cp310` with your Python version (e.g., `cp38`, `cp39`, `cp311`).
 | 
			
		||||
        $ pip install https://github.com/vllm-project/vllm/releases/download/v0.2.2/vllm-0.2.2+cu118-cp310-cp310-manylinux1_x86_64.whl
 | 
			
		||||
        $ export VLLM_VERSION=0.2.4
 | 
			
		||||
        $ export PYTHON_VERSION=39
 | 
			
		||||
        $ pip install https://github.com/vllm-project/vllm/releases/download/v${VLLM_VERSION}/vllm-${VLLM_VERSION}+cu118-cp${PYTHON_VERSION}-cp${PYTHON_VERSION}-manylinux1_x86_64.whl
 | 
			
		||||
 | 
			
		||||
        $ # Re-install PyTorch with CUDA 11.8.
 | 
			
		||||
        $ pip uninstall torch -y
 | 
			
		||||
        $ pip install torch --upgrade --index-url https://download.pytorch.org/whl/cu118
 | 
			
		||||
 | 
			
		||||
        $ # Re-install xFormers with CUDA 11.8.
 | 
			
		||||
        $ pip uninstall xformers -y
 | 
			
		||||
        $ pip install --upgrade xformers --index-url https://download.pytorch.org/whl/cu118
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
.. _build_from_source:
 | 
			
		||||
 | 
			
		||||
@ -62,3 +67,13 @@ You can also build and install vLLM from source:
 | 
			
		||||
 | 
			
		||||
        $ # Use `--ipc=host` to make sure the shared memory is large enough.
 | 
			
		||||
        $ docker run --gpus all -it --rm --ipc=host nvcr.io/nvidia/pytorch:23.10-py3
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
    If you are developing the C++ backend of vLLM, consider building vLLM with
 | 
			
		||||
 | 
			
		||||
    .. code-block:: console
 | 
			
		||||
 | 
			
		||||
        $ python setup.py develop
 | 
			
		||||
 | 
			
		||||
    since it will give you incremental builds. The downside is that this method
 | 
			
		||||
    is `deprecated by setuptools <https://github.com/pypa/setuptools/issues/917>`_.
 | 
			
		||||
 | 
			
		||||
@ -11,6 +11,14 @@ This guide shows how to use vLLM to:
 | 
			
		||||
 | 
			
		||||
Be sure to complete the :ref:`installation instructions <installation>` before continuing with this guide.
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
 | 
			
		||||
    By default, vLLM downloads model from `HuggingFace <https://huggingface.co/>`_. If you would like to use models from `ModelScope <https://www.modelscope.cn>`_ in the following examples, please set the environment variable:
 | 
			
		||||
 | 
			
		||||
    .. code-block:: shell
 | 
			
		||||
 | 
			
		||||
        export VLLM_USE_MODELSCOPE=True
 | 
			
		||||
 | 
			
		||||
Offline Batched Inference
 | 
			
		||||
-------------------------
 | 
			
		||||
 | 
			
		||||
@ -40,16 +48,6 @@ Initialize vLLM's engine for offline inference with the ``LLM`` class and the `O
 | 
			
		||||
 | 
			
		||||
    llm = LLM(model="facebook/opt-125m")
 | 
			
		||||
 | 
			
		||||
Use model from www.modelscope.cn
 | 
			
		||||
 | 
			
		||||
.. code-block:: shell
 | 
			
		||||
 | 
			
		||||
    export VLLM_USE_MODELSCOPE=True
 | 
			
		||||
 | 
			
		||||
.. code-block:: python
 | 
			
		||||
 | 
			
		||||
    llm = LLM(model="qwen/Qwen-7B-Chat", revision="v1.1.8", trust_remote_code=True)
 | 
			
		||||
 | 
			
		||||
Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM engine's waiting queue and executes the vLLM engine to generate the outputs with high throughput. The outputs are returned as a list of ``RequestOutput`` objects, which include all the output tokens.
 | 
			
		||||
 | 
			
		||||
.. code-block:: python
 | 
			
		||||
@ -65,49 +63,11 @@ Call ``llm.generate`` to generate the outputs. It adds the input prompts to vLLM
 | 
			
		||||
 | 
			
		||||
The code example can also be found in `examples/offline_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/offline_inference.py>`_.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
API Server
 | 
			
		||||
----------
 | 
			
		||||
 | 
			
		||||
vLLM can be deployed as an LLM service. We provide an example `FastAPI <https://fastapi.tiangolo.com/>`_ server. Check `vllm/entrypoints/api_server.py <https://github.com/vllm-project/vllm/blob/main/vllm/entrypoints/api_server.py>`_ for the server implementation. The server uses ``AsyncLLMEngine`` class to support asynchronous processing of incoming requests.
 | 
			
		||||
 | 
			
		||||
Start the server:
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ python -m vllm.entrypoints.api_server
 | 
			
		||||
 | 
			
		||||
Use model from www.modelscope.cn
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.api_server \
 | 
			
		||||
    $    --model="qwen/Qwen-7B-Chat" \
 | 
			
		||||
    $    --revision="v1.1.8" \
 | 
			
		||||
    $    --trust-remote-code
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
By default, this command starts the server at ``http://localhost:8000`` with the OPT-125M model.
 | 
			
		||||
 | 
			
		||||
Query the model in shell:
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ curl http://localhost:8000/generate \
 | 
			
		||||
    $     -d '{
 | 
			
		||||
    $         "prompt": "San Francisco is a",
 | 
			
		||||
    $         "use_beam_search": true,
 | 
			
		||||
    $         "n": 4,
 | 
			
		||||
    $         "temperature": 0
 | 
			
		||||
    $     }'
 | 
			
		||||
 | 
			
		||||
See `examples/api_client.py <https://github.com/vllm-project/vllm/blob/main/examples/api_client.py>`_ for a more detailed client example.
 | 
			
		||||
 | 
			
		||||
OpenAI-Compatible Server
 | 
			
		||||
------------------------
 | 
			
		||||
 | 
			
		||||
vLLM can be deployed as a server that mimics the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using OpenAI API.
 | 
			
		||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the above command) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
 | 
			
		||||
vLLM can be deployed as a server that implements the OpenAI API protocol. This allows vLLM to be used as a drop-in replacement for applications using the OpenAI API.
 | 
			
		||||
By default, it starts the server at ``http://localhost:8000``. You can specify the address with ``--host`` and ``--port`` arguments. The server currently hosts one model at a time (OPT-125M in the command below) and implements `list models <https://platform.openai.com/docs/api-reference/models/list>`_, `create chat completion <https://platform.openai.com/docs/api-reference/chat/completions/create>`_, and `create completion <https://platform.openai.com/docs/api-reference/completions/create>`_ endpoints. We are actively adding support for more endpoints.
 | 
			
		||||
 | 
			
		||||
Start the server:
 | 
			
		||||
 | 
			
		||||
@ -116,20 +76,13 @@ Start the server:
 | 
			
		||||
    $ python -m vllm.entrypoints.openai.api_server \
 | 
			
		||||
    $     --model facebook/opt-125m
 | 
			
		||||
 | 
			
		||||
Use model from www.modelscope.cn
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ VLLM_USE_MODELSCOPE=True python -m vllm.entrypoints.openai.api_server \
 | 
			
		||||
    $     --model="qwen/Qwen-7B-Chat" --revision="v1.1.8" --trust-remote-code
 | 
			
		||||
 | 
			
		||||
By default, the server uses a predefined chat template stored in the tokenizer. You can override this template by using the ``--chat-template`` argument:
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
   $ python -m vllm.entrypoints.openai.api_server \
 | 
			
		||||
   $     --model facebook/opt-125m \
 | 
			
		||||
   $     --chat-template ./examples/template_chatml.json
 | 
			
		||||
   $     --chat-template ./examples/template_chatml.jinja
 | 
			
		||||
 | 
			
		||||
This server can be queried in the same format as the OpenAI API. For example, list the models:
 | 
			
		||||
 | 
			
		||||
@ -137,6 +90,8 @@ This server can be queried in the same format as OpenAI API. For example, list t
 | 
			
		||||
 | 
			
		||||
    $ curl http://localhost:8000/v1/models
 | 
			
		||||
 | 
			
		||||
You can pass in the argument ``--api-key`` or environment variable ``VLLM_API_KEY`` to enable the server to check for an API key in the request header.
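
For example, a client would then attach the key on every request (a sketch; ``token-abc123`` is only a placeholder that must match the value the server was started with):

.. code-block:: python

    from openai import OpenAI

    # api_key must match the --api-key / VLLM_API_KEY value configured on the server.
    client = OpenAI(base_url="http://localhost:8000/v1", api_key="token-abc123")
    print(client.models.list())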
 | 
			
		||||
 | 
			
		||||
Using OpenAI Completions API with vLLM
 | 
			
		||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -30,6 +30,8 @@ vLLM is fast with:
 | 
			
		||||
* State-of-the-art serving throughput
 | 
			
		||||
* Efficient management of attention key and value memory with **PagedAttention**
 | 
			
		||||
* Continuous batching of incoming requests
 | 
			
		||||
* Fast model execution with CUDA/HIP graph
 | 
			
		||||
* Quantization: `GPTQ <https://arxiv.org/abs/2210.17323>`_, `AWQ <https://arxiv.org/abs/2306.00978>`_, `SqueezeLLM <https://arxiv.org/abs/2306.07629>`_, FP8 KV Cache
 | 
			
		||||
* Optimized CUDA kernels
 | 
			
		||||
 | 
			
		||||
vLLM is flexible and easy to use with:
 | 
			
		||||
@ -39,6 +41,9 @@ vLLM is flexible and easy to use with:
 | 
			
		||||
* Tensor parallelism support for distributed inference
 | 
			
		||||
* Streaming outputs
 | 
			
		||||
* OpenAI-compatible API server
 | 
			
		||||
* Support for NVIDIA GPUs and AMD GPUs
* (Experimental) Prefix caching support
* (Experimental) Multi-LoRA support
 | 
			
		||||
 | 
			
		||||
For more information, check out the following:
 | 
			
		||||
 | 
			
		||||
@ -56,6 +61,7 @@ Documentation
 | 
			
		||||
   :caption: Getting Started
 | 
			
		||||
 | 
			
		||||
   getting_started/installation
 | 
			
		||||
   getting_started/amd-installation
 | 
			
		||||
   getting_started/quickstart
 | 
			
		||||
 | 
			
		||||
.. toctree::
 | 
			
		||||
@ -76,9 +82,23 @@ Documentation
 | 
			
		||||
   models/supported_models
 | 
			
		||||
   models/adding_model
 | 
			
		||||
   models/engine_args
 | 
			
		||||
   models/lora
 | 
			
		||||
 | 
			
		||||
.. toctree::
 | 
			
		||||
   :maxdepth: 1
 | 
			
		||||
   :caption: Quantization
 | 
			
		||||
 | 
			
		||||
   quantization/auto_awq
 | 
			
		||||
   quantization/fp8_e5m2_kv_cache
 | 
			
		||||
 | 
			
		||||
.. toctree::
 | 
			
		||||
   :maxdepth: 2
 | 
			
		||||
   :caption: Developer Documentation
 | 
			
		||||
 | 
			
		||||
   dev/engine/engine_index
 | 
			
		||||
 | 
			
		||||
Indices and tables
 | 
			
		||||
==================
 | 
			
		||||
 | 
			
		||||
* :ref:`genindex`
 | 
			
		||||
* :ref:`modindex`
 | 
			
		||||
 | 
			
		||||
@ -26,7 +26,7 @@ This gives you the ability to modify the codebase and test your model.
 | 
			
		||||
------------------------
 | 
			
		||||
 | 
			
		||||
Clone the PyTorch model code from the HuggingFace Transformers repository and put it into the `vllm/model_executor/models <https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/models>`_ directory.
 | 
			
		||||
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adpated from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
 | 
			
		||||
For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/opt.py>`_ was adapted from the HuggingFace's `modeling_opt.py <https://github.com/huggingface/transformers/blob/main/src/transformers/models/opt/modeling_opt.py>`_ file.
 | 
			
		||||
 | 
			
		||||
.. warning::
 | 
			
		||||
    When copying the model code, make sure to review and adhere to the code's copyright and licensing terms.
 | 
			
		||||
@ -58,11 +58,10 @@ Next, you need to rewrite the :code:`forward` methods of your model by following
 | 
			
		||||
    +    positions: torch.Tensor,
 | 
			
		||||
    +    kv_caches: List[KVCache],
 | 
			
		||||
    +    input_metadata: InputMetadata,
 | 
			
		||||
    +    cache_events: Optional[List[torch.cuda.Event]],
 | 
			
		||||
    +) -> SamplerOutput:
 | 
			
		||||
    +) -> Optional[SamplerOutput]:
 | 
			
		||||
 | 
			
		||||
3. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
 | 
			
		||||
4. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
 | 
			
		||||
1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors.
 | 
			
		||||
2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture.
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
    Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings.
 | 
			
		||||
 | 
			
		||||
@ -89,9 +89,11 @@ Below, you can find an explanation of every engine argument for vLLM:
 | 
			
		||||
 | 
			
		||||
    CPU swap space size (GiB) per GPU.
 | 
			
		||||
 | 
			
		||||
.. option:: --gpu-memory-utilization <percentage>
 | 
			
		||||
.. option:: --gpu-memory-utilization <fraction>
 | 
			
		||||
 | 
			
		||||
    The percentage of GPU memory to be used for the model executor.
 | 
			
		||||
    The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. 
 | 
			
		||||
    For example, a value of 0.5 would imply 50% GPU memory utilization.
 | 
			
		||||
    If unspecified, will use the default value of 0.9.
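
    A minimal sketch of the same setting through the Python API (the constructor argument mirrors this CLI option; the model is just a placeholder):

    .. code-block:: python

        from vllm import LLM

        # Cap the model executor at half of the available GPU memory.
        llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5)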
 | 
			
		||||
 | 
			
		||||
.. option:: --max-num-batched-tokens <tokens>
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
							
								
								
									
docs/source/models/lora.rst (new file) @ -0,0 +1,91 @@
 | 
			
		||||
.. _lora:
 | 
			
		||||
 | 
			
		||||
Using LoRA adapters
 | 
			
		||||
===================
 | 
			
		||||
 | 
			
		||||
This document shows you how to use `LoRA adapters <https://arxiv.org/abs/2106.09685>`_ with vLLM on top of a base model.
 | 
			
		||||
Adapters can be efficiently served on a per-request basis with minimal overhead. First we download the adapter(s) and save
 | 
			
		||||
them locally with
 | 
			
		||||
 | 
			
		||||
.. code-block:: python
 | 
			
		||||
 | 
			
		||||
    from huggingface_hub import snapshot_download
 | 
			
		||||
 | 
			
		||||
    sql_lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Then we instantiate the base model and pass in the ``enable_lora=True`` flag:
 | 
			
		||||
 | 
			
		||||
.. code-block:: python
 | 
			
		||||
 | 
			
		||||
    from vllm import LLM, SamplingParams
 | 
			
		||||
    from vllm.lora.request import LoRARequest
 | 
			
		||||
 | 
			
		||||
    llm = LLM(model="meta-llama/Llama-2-7b-hf", enable_lora=True)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
We can now submit the prompts and call ``llm.generate`` with the ``lora_request`` parameter. The first parameter
of ``LoRARequest`` is a human-identifiable name, the second parameter is a globally unique ID for the adapter, and
the third parameter is the path to the LoRA adapter.
 | 
			
		||||
 | 
			
		||||
.. code-block:: python
 | 
			
		||||
 | 
			
		||||
    sampling_params = SamplingParams(
 | 
			
		||||
        temperature=0,
 | 
			
		||||
        max_tokens=256,
 | 
			
		||||
        stop=["[/assistant]"]
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    prompts = [
 | 
			
		||||
         "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
 | 
			
		||||
         "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
    outputs = llm.generate(
 | 
			
		||||
        prompts,
 | 
			
		||||
        sampling_params,
 | 
			
		||||
        lora_request=LoRARequest("sql_adapter", 1, sql_lora_path)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
Check out `examples/multilora_inference.py <https://github.com/vllm-project/vllm/blob/main/examples/multilora_inference.py>`_
 | 
			
		||||
for an example of how to use LoRA adapters with the async engine and how to use more advanced configuration options.
 | 
			
		||||
 | 
			
		||||
Serving LoRA Adapters
 | 
			
		||||
---------------------
 | 
			
		||||
LoRA-adapted models can also be served with the OpenAI-compatible vLLM server. To do so, we use
``--lora-modules {name}={path} {name}={path}`` to specify each LoRA module when we kick off the server:
 | 
			
		||||
 | 
			
		||||
.. code-block:: bash
 | 
			
		||||
 | 
			
		||||
    python -m vllm.entrypoints.openai.api_server \
 | 
			
		||||
        --model meta-llama/Llama-2-7b-hf \
 | 
			
		||||
        --enable-lora \
 | 
			
		||||
        --lora-modules sql-lora=~/.cache/huggingface/hub/models--yard1--llama-2-7b-sql-lora-test/
 | 
			
		||||
 | 
			
		||||
The server entrypoint accepts all other LoRA configuration parameters (``max_loras``, ``max_lora_rank``, ``max_cpu_loras``,
 | 
			
		||||
etc.), which will apply to all forthcoming requests. Upon querying the ``/models`` endpoint, we should see our LoRA along
 | 
			
		||||
with its base model:
 | 
			
		||||
 | 
			
		||||
.. code-block:: bash
 | 
			
		||||
 | 
			
		||||
    curl localhost:8000/v1/models | jq .
 | 
			
		||||
    {
 | 
			
		||||
        "object": "list",
 | 
			
		||||
        "data": [
 | 
			
		||||
            {
 | 
			
		||||
                "id": "meta-llama/Llama-2-7b-hf",
 | 
			
		||||
                "object": "model",
 | 
			
		||||
                ...
 | 
			
		||||
            },
 | 
			
		||||
            {
 | 
			
		||||
                "id": "sql-lora",
 | 
			
		||||
                "object": "model",
 | 
			
		||||
                ...
 | 
			
		||||
            }
 | 
			
		||||
        ]
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
Requests can specify the LoRA adapter as if it were any other model via the ``model`` request parameter. The requests will be
 | 
			
		||||
processed according to the server-wide LoRA configuration (i.e. in parallel with base model requests, and potentially other
 | 
			
		||||
LoRA adapter requests if they were provided and ``max_loras`` is set high enough).
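
For instance, a completion request against the adapter could look like the following sketch (the prompt and token limit are arbitrary):

.. code-block:: python

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    # "sql-lora" is the adapter name registered via --lora-modules above.
    completion = client.completions.create(
        model="sql-lora",
        prompt="[user] Write a SQL query to list all airports. [/user] [assistant]",
        max_tokens=64,
    )
    print(completion.choices[0].text)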
 | 
			
		||||
@ -23,12 +23,18 @@ Alongside each architecture, we include some popular models that use it.
 | 
			
		||||
  * - :code:`ChatGLMModel`
 | 
			
		||||
    - ChatGLM
 | 
			
		||||
    - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc.
 | 
			
		||||
  * - :code:`DeciLMForCausalLM`
 | 
			
		||||
    - DeciLM
 | 
			
		||||
    - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc.
 | 
			
		||||
  * - :code:`BloomForCausalLM`
 | 
			
		||||
    - BLOOM, BLOOMZ, BLOOMChat
 | 
			
		||||
    - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
 | 
			
		||||
  * - :code:`FalconForCausalLM`
 | 
			
		||||
    - Falcon
 | 
			
		||||
    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
 | 
			
		||||
  * - :code:`GemmaForCausalLM`
 | 
			
		||||
    - Gemma
 | 
			
		||||
    - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc.
 | 
			
		||||
  * - :code:`GPT2LMHeadModel`
 | 
			
		||||
    - GPT-2
 | 
			
		||||
    - :code:`gpt2`, :code:`gpt2-xl`, etc.
 | 
			
		||||
@ -44,32 +50,47 @@ Alongside each architecture, we include some popular models that use it.
 | 
			
		||||
  * - :code:`InternLMForCausalLM`
 | 
			
		||||
    - InternLM
 | 
			
		||||
    - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc.
 | 
			
		||||
  * - :code:`InternLM2ForCausalLM`
 | 
			
		||||
    - InternLM2
 | 
			
		||||
    - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc.
 | 
			
		||||
  * - :code:`LlamaForCausalLM`
 | 
			
		||||
    - LLaMA, LLaMA-2, Vicuna, Alpaca, Koala, Guanaco
 | 
			
		||||
    - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`young-geng/koala`, etc.
 | 
			
		||||
    - LLaMA, LLaMA-2, Vicuna, Alpaca, Yi
 | 
			
		||||
    - :code:`meta-llama/Llama-2-13b-hf`, :code:`meta-llama/Llama-2-70b-hf`, :code:`openlm-research/open_llama_13b`, :code:`lmsys/vicuna-13b-v1.3`, :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
 | 
			
		||||
  * - :code:`MistralForCausalLM`
 | 
			
		||||
    - Mistral, Mistral-Instruct
 | 
			
		||||
    - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc.
 | 
			
		||||
  * - :code:`MixtralForCausalLM`
 | 
			
		||||
    - Mixtral-8x7B, Mixtral-8x7B-Instruct
 | 
			
		||||
    - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, etc.
 | 
			
		||||
  * - :code:`MPTForCausalLM`
 | 
			
		||||
    - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter
 | 
			
		||||
    - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc.
 | 
			
		||||
  * - :code:`OLMoForCausalLM`
 | 
			
		||||
    - OLMo
 | 
			
		||||
    - :code:`allenai/OLMo-1B`, :code:`allenai/OLMo-7B`, etc.
 | 
			
		||||
  * - :code:`OPTForCausalLM`
 | 
			
		||||
    - OPT, OPT-IML
 | 
			
		||||
    - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc.
 | 
			
		||||
  * - :code:`PhiForCausalLM`
 | 
			
		||||
    - Phi-1.5
 | 
			
		||||
    - :code:`microsoft/phi-1_5`, etc.
 | 
			
		||||
    - Phi
 | 
			
		||||
    - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc.
 | 
			
		||||
  * - :code:`QWenLMHeadModel`
 | 
			
		||||
    - Qwen
 | 
			
		||||
    - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc.
 | 
			
		||||
  * - :code:`YiForCausalLM`
 | 
			
		||||
    - Yi
 | 
			
		||||
    - :code:`01-ai/Yi-6B`, :code:`01-ai/Yi-34B`, etc.
 | 
			
		||||
  * - :code:`Qwen2ForCausalLM`
 | 
			
		||||
    - Qwen2
 | 
			
		||||
    - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc.
 | 
			
		||||
  * - :code:`StableLMEpochForCausalLM`
 | 
			
		||||
    - StableLM
 | 
			
		||||
    - :code:`stabilityai/stablelm-3b-4e1t`, :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc.
 | 
			
		||||
 | 
			
		||||
If your model uses one of the above model architectures, you can seamlessly run your model with vLLM.
 | 
			
		||||
Otherwise, please refer to :ref:`Adding a New Model <adding_a_new_model>` for instructions on how to implement support for your model.
 | 
			
		||||
Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-project/vllm/issues>`_ project.
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
    Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096.
 | 
			
		||||
 | 
			
		||||
.. tip::
 | 
			
		||||
    The easiest way to check if your model is supported is to run the program below:
 | 
			
		||||
 | 
			
		||||
@ -81,12 +102,17 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
 | 
			
		||||
        output = llm.generate("Hello, my name is")
 | 
			
		||||
        print(output)
 | 
			
		||||
 | 
			
		||||
    To use model from www.modelscope.cn
 | 
			
		||||
    If vLLM successfully generates text, it indicates that your model is supported.
 | 
			
		||||
 | 
			
		||||
.. tip::
 | 
			
		||||
    To use models from `ModelScope <https://www.modelscope.cn>`_ instead of HuggingFace Hub, set an environment variable:
 | 
			
		||||
 | 
			
		||||
    .. code-block:: shell
 | 
			
		||||
 | 
			
		||||
       $ export VLLM_USE_MODELSCOPE=True
 | 
			
		||||
 | 
			
		||||
    And use with :code:`trust_remote_code=True`.
 | 
			
		||||
 | 
			
		||||
    .. code-block:: python
 | 
			
		||||
 | 
			
		||||
        from vllm import LLM
 | 
			
		||||
@ -94,5 +120,3 @@ Alternatively, you can raise an issue on our `GitHub <https://github.com/vllm-pr
 | 
			
		||||
        llm = LLM(model=..., revision=..., trust_remote_code=True)  # Name or path of your model
 | 
			
		||||
        output = llm.generate("Hello, my name is")
 | 
			
		||||
        print(output)
 | 
			
		||||
 | 
			
		||||
    If vLLM successfully generates text, it indicates that your model is supported.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
docs/source/quantization/fp8_e5m2_kv_cache.rst (new file) @ -0,0 +1,33 @@
 | 
			
		||||
.. _fp8_e5m2_kv_cache:
 | 
			
		||||
 | 
			
		||||
FP8 E5M2 KV Cache
 | 
			
		||||
==================
 | 
			
		||||
 | 
			
		||||
The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits.
 | 
			
		||||
The FP8 data format retains 2~3 mantissa bits and supports conversion between float/fp16/bfloat16 and fp8.
 | 
			
		||||
 | 
			
		||||
Here is an example of how to enable this feature:
 | 
			
		||||
 | 
			
		||||
.. code-block:: python
 | 
			
		||||
 | 
			
		||||
    from vllm import LLM, SamplingParams
 | 
			
		||||
    # Sample prompts.
 | 
			
		||||
    prompts = [
 | 
			
		||||
        "Hello, my name is",
 | 
			
		||||
        "The president of the United States is",
 | 
			
		||||
        "The capital of France is",
 | 
			
		||||
        "The future of AI is",
 | 
			
		||||
    ]
 | 
			
		||||
    # Create a sampling params object.
 | 
			
		||||
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
 | 
			
		||||
    # Create an LLM.
 | 
			
		||||
    llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8_e5m2")
 | 
			
		||||
    # Generate texts from the prompts. The output is a list of RequestOutput objects
 | 
			
		||||
    # that contain the prompt, generated text, and other information.
 | 
			
		||||
    outputs = llm.generate(prompts, sampling_params)
 | 
			
		||||
    # Print the outputs.
 | 
			
		||||
    for output in outputs:
 | 
			
		||||
        prompt = output.prompt
 | 
			
		||||
        generated_text = output.outputs[0].text
 | 
			
		||||
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 | 
			
		||||
 | 
			
		||||
@ -29,7 +29,15 @@ You can build and run vLLM from source via the provided dockerfile. To build vLL
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai --build-arg max_jobs=8
 | 
			
		||||
    $ DOCKER_BUILDKIT=1 docker build . --target vllm-openai --tag vllm/vllm-openai # optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
.. note::
 | 
			
		||||
 | 
			
		||||
        By default, vLLM builds for all GPU types for the widest distribution. If you are only building for the
        current GPU type of the machine it is running on, you can add the argument ``--build-arg torch_cuda_arch_list=""``
        so that vLLM detects the current GPU type and builds for it.
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
To run vLLM:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -55,7 +55,7 @@ Start the serving the LLaMA-13B model on an A100 GPU:
 | 
			
		||||
 | 
			
		||||
    $ sky launch serving.yaml
 | 
			
		||||
 | 
			
		||||
Check the output of the command. There will be a sharable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
 | 
			
		||||
Check the output of the command. There will be a shareable gradio link (like the last line of the following). Open it in your browser to use the LLaMA model to do the text completion.
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -9,13 +9,13 @@ To install langchain, run
 | 
			
		||||
 | 
			
		||||
.. code-block:: console
 | 
			
		||||
 | 
			
		||||
    $ pip install langchain -q
 | 
			
		||||
    $ pip install langchain langchain_community -q
 | 
			
		||||
 | 
			
		||||
To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langchain``.
 | 
			
		||||
 | 
			
		||||
.. code-block:: python
 | 
			
		||||
 | 
			
		||||
    from langchain.llms import VLLM
 | 
			
		||||
    from langchain_community.llms import VLLM
 | 
			
		||||
 | 
			
		||||
    llm = VLLM(model="mosaicml/mpt-7b",
 | 
			
		||||
               trust_remote_code=True,  # mandatory for hf models
 | 
			
		||||
@ -28,4 +28,4 @@ To run inference on a single or multiple GPUs, use ``VLLM`` class from ``langcha
 | 
			
		||||
 | 
			
		||||
    print(llm("What is the capital of France ?"))
 | 
			
		||||
 | 
			
		||||
Please refer to this `Tutorial <https://github.com/langchain-ai/langchain/blob/master/docs/extras/integrations/llms/vllm.ipynb>`_ for more details.
 | 
			
		||||
Please refer to this `Tutorial <https://python.langchain.com/docs/integrations/llms/vllm>`_ for more details.
 | 
			
		||||
 | 
			
		||||
							
								
								
									
examples/gradio_openai_chatbot_webserver.py (new file) @ -0,0 +1,81 @@
 | 
			
		||||
import argparse
 | 
			
		||||
from openai import OpenAI
 | 
			
		||||
import gradio as gr
 | 
			
		||||
 | 
			
		||||
# Argument parser setup
 | 
			
		||||
parser = argparse.ArgumentParser(
 | 
			
		||||
    description='Chatbot Interface with Customizable Parameters')
 | 
			
		||||
parser.add_argument('--model-url',
 | 
			
		||||
                    type=str,
 | 
			
		||||
                    default='http://localhost:8000/v1',
 | 
			
		||||
                    help='Model URL')
 | 
			
		||||
parser.add_argument('-m',
 | 
			
		||||
                    '--model',
 | 
			
		||||
                    type=str,
 | 
			
		||||
                    required=True,
 | 
			
		||||
                    help='Model name for the chatbot')
 | 
			
		||||
parser.add_argument('--temp',
 | 
			
		||||
                    type=float,
 | 
			
		||||
                    default=0.8,
 | 
			
		||||
                    help='Temperature for text generation')
 | 
			
		||||
parser.add_argument('--stop-token-ids',
 | 
			
		||||
                    type=str,
 | 
			
		||||
                    default='',
 | 
			
		||||
                    help='Comma-separated stop token IDs')
 | 
			
		||||
parser.add_argument("--host", type=str, default=None)
 | 
			
		||||
parser.add_argument("--port", type=int, default=8001)
 | 
			
		||||
 | 
			
		||||
# Parse the arguments
 | 
			
		||||
args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
# Set OpenAI's API key and API base to use vLLM's API server.
 | 
			
		||||
openai_api_key = "EMPTY"
 | 
			
		||||
openai_api_base = args.model_url
 | 
			
		||||
 | 
			
		||||
# Create an OpenAI client to interact with the API server
 | 
			
		||||
client = OpenAI(
 | 
			
		||||
    api_key=openai_api_key,
 | 
			
		||||
    base_url=openai_api_base,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def predict(message, history):
 | 
			
		||||
    # Convert chat history to OpenAI format
 | 
			
		||||
    history_openai_format = [{
 | 
			
		||||
        "role": "system",
 | 
			
		||||
        "content": "You are a great ai assistant."
 | 
			
		||||
    }]
 | 
			
		||||
    for human, assistant in history:
 | 
			
		||||
        history_openai_format.append({"role": "user", "content": human})
 | 
			
		||||
        history_openai_format.append({
 | 
			
		||||
            "role": "assistant",
 | 
			
		||||
            "content": assistant
 | 
			
		||||
        })
 | 
			
		||||
    history_openai_format.append({"role": "user", "content": message})
 | 
			
		||||
 | 
			
		||||
    # Create a chat completion request and send it to the API server
 | 
			
		||||
    stream = client.chat.completions.create(
 | 
			
		||||
        model=args.model,  # Model name to use
 | 
			
		||||
        messages=history_openai_format,  # Chat history
 | 
			
		||||
        temperature=args.temp,  # Temperature for text generation
 | 
			
		||||
        stream=True,  # Stream response
 | 
			
		||||
        extra_body={
 | 
			
		||||
            'repetition_penalty':
 | 
			
		||||
            1,
 | 
			
		||||
            'stop_token_ids': [
 | 
			
		||||
                int(id.strip()) for id in args.stop_token_ids.split(',')
 | 
			
		||||
                if id.strip()
 | 
			
		||||
            ] if args.stop_token_ids else []
 | 
			
		||||
        })
 | 
			
		||||
 | 
			
		||||
    # Read and return generated text from response stream
 | 
			
		||||
    partial_message = ""
 | 
			
		||||
    for chunk in stream:
 | 
			
		||||
        partial_message += (chunk.choices[0].delta.content or "")
 | 
			
		||||
        yield partial_message
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# Create and launch a chat interface with Gradio
 | 
			
		||||
gr.ChatInterface(predict).queue().launch(server_name=args.host,
 | 
			
		||||
                                         server_port=args.port,
 | 
			
		||||
                                         share=True)
 | 
			
		||||
@ -47,6 +47,6 @@ if __name__ == "__main__":
 | 
			
		||||
    args = parser.parse_args()
 | 
			
		||||
 | 
			
		||||
    demo = build_demo()
 | 
			
		||||
    demo.queue(concurrency_count=100).launch(server_name=args.host,
 | 
			
		||||
    demo.queue().launch(server_name=args.host,
 | 
			
		||||
                        server_port=args.port,
 | 
			
		||||
                        share=True)
 | 
			
		||||
 | 
			
		||||
							
								
								
									
examples/multilora_inference.py (new file) @ -0,0 +1,119 @@
 | 
			
		||||
"""
 | 
			
		||||
This example shows how to use the multi-LoRA functionality for offline inference.
 | 
			
		||||
 | 
			
		||||
Requires HuggingFace credentials for access to Llama2.
 | 
			
		||||
"""
 | 
			
		||||
 | 
			
		||||
from typing import Optional, List, Tuple
 | 
			
		||||
 | 
			
		||||
from huggingface_hub import snapshot_download
 | 
			
		||||
 | 
			
		||||
from vllm import EngineArgs, LLMEngine, SamplingParams, RequestOutput
 | 
			
		||||
from vllm.lora.request import LoRARequest
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_test_prompts(
 | 
			
		||||
        lora_path: str
 | 
			
		||||
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
 | 
			
		||||
    """Create a list of test prompts with their sampling parameters.
 | 
			
		||||
    
 | 
			
		||||
    2 requests for base model, 4 requests for the LoRA. We define 2
 | 
			
		||||
    different LoRA adapters (using the same model for demo purposes).
 | 
			
		||||
    Since we also set `max_loras=1`, the expectation is that the requests
 | 
			
		||||
    with the second LoRA adapter will be run after all requests with the
 | 
			
		||||
    first adapter have finished.
 | 
			
		||||
    """
 | 
			
		||||
    return [
 | 
			
		||||
        ("A robot may not injure a human being",
 | 
			
		||||
         SamplingParams(temperature=0.0,
 | 
			
		||||
                        logprobs=1,
 | 
			
		||||
                        prompt_logprobs=1,
 | 
			
		||||
                        max_tokens=128), None),
 | 
			
		||||
        ("To be or not to be,",
 | 
			
		||||
         SamplingParams(temperature=0.8,
 | 
			
		||||
                        top_k=5,
 | 
			
		||||
                        presence_penalty=0.2,
 | 
			
		||||
                        max_tokens=128), None),
 | 
			
		||||
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
 | 
			
		||||
         SamplingParams(temperature=0.0,
 | 
			
		||||
                        logprobs=1,
 | 
			
		||||
                        prompt_logprobs=1,
 | 
			
		||||
                        max_tokens=128,
 | 
			
		||||
                        stop_token_ids=[32003]),
 | 
			
		||||
         LoRARequest("sql-lora", 1, lora_path)),
 | 
			
		||||
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
 | 
			
		||||
         SamplingParams(n=3,
 | 
			
		||||
                        best_of=3,
 | 
			
		||||
                        use_beam_search=True,
 | 
			
		||||
                        temperature=0,
 | 
			
		||||
                        max_tokens=128,
 | 
			
		||||
                        stop_token_ids=[32003]),
 | 
			
		||||
         LoRARequest("sql-lora", 1, lora_path)),
 | 
			
		||||
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]",
 | 
			
		||||
         SamplingParams(temperature=0.0,
 | 
			
		||||
                        logprobs=1,
 | 
			
		||||
                        prompt_logprobs=1,
 | 
			
		||||
                        max_tokens=128,
 | 
			
		||||
                        stop_token_ids=[32003]),
 | 
			
		||||
         LoRARequest("sql-lora2", 2, lora_path)),
 | 
			
		||||
        ("[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]",
 | 
			
		||||
         SamplingParams(n=3,
 | 
			
		||||
                        best_of=3,
 | 
			
		||||
                        use_beam_search=True,
 | 
			
		||||
                        temperature=0,
 | 
			
		||||
                        max_tokens=128,
 | 
			
		||||
                        stop_token_ids=[32003]),
 | 
			
		||||
         LoRARequest("sql-lora", 1, lora_path)),
 | 
			
		||||
    ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def process_requests(engine: LLMEngine,
 | 
			
		||||
                     test_prompts: List[Tuple[str, SamplingParams,
 | 
			
		||||
                                              Optional[LoRARequest]]]):
 | 
			
		||||
    """Continuously process a list of prompts and handle the outputs."""
 | 
			
		||||
    request_id = 0
 | 
			
		||||
 | 
			
		||||
    while test_prompts or engine.has_unfinished_requests():
 | 
			
		||||
        if test_prompts:
 | 
			
		||||
            prompt, sampling_params, lora_request = test_prompts.pop(0)
 | 
			
		||||
            engine.add_request(str(request_id),
 | 
			
		||||
                               prompt,
 | 
			
		||||
                               sampling_params,
 | 
			
		||||
                               lora_request=lora_request)
 | 
			
		||||
            request_id += 1
 | 
			
		||||
 | 
			
		||||
        request_outputs: List[RequestOutput] = engine.step()
 | 
			
		||||
 | 
			
		||||
        for request_output in request_outputs:
 | 
			
		||||
            if request_output.finished:
 | 
			
		||||
                print(request_output)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def initialize_engine() -> LLMEngine:
 | 
			
		||||
    """Initialize the LLMEngine."""
 | 
			
		||||
    # max_loras: controls the number of LoRAs that can be used in the same
 | 
			
		||||
    #   batch. Larger numbers will cause higher memory usage, as each LoRA
 | 
			
		||||
    #   slot requires its own preallocated tensor.
 | 
			
		||||
    # max_lora_rank: controls the maximum supported rank of all LoRAs. Larger
 | 
			
		||||
    #   numbers will cause higher memory usage. If you know that all LoRAs will
 | 
			
		||||
    #   use the same rank, it is recommended to set this as low as possible.
 | 
			
		||||
    # max_cpu_loras: controls the size of the CPU LoRA cache.
 | 
			
		||||
    engine_args = EngineArgs(model="meta-llama/Llama-2-7b-hf",
 | 
			
		||||
                             enable_lora=True,
 | 
			
		||||
                             max_loras=1,
 | 
			
		||||
                             max_lora_rank=8,
 | 
			
		||||
                             max_cpu_loras=2,
 | 
			
		||||
                             max_num_seqs=256)
 | 
			
		||||
    return LLMEngine.from_engine_args(engine_args)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def main():
 | 
			
		||||
    """Main function that sets up and runs the prompt processing."""
 | 
			
		||||
    engine = initialize_engine()
 | 
			
		||||
    lora_path = snapshot_download(repo_id="yard1/llama-2-7b-sql-lora-test")
 | 
			
		||||
    test_prompts = create_test_prompts(lora_path)
 | 
			
		||||
    process_requests(engine, test_prompts)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
    main()
 | 
			
		||||
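If you would rather collect the generated text per adapter than print the raw `RequestOutput` objects, a small variant of `process_requests` could look like the sketch below. This is a hypothetical helper built only from the engine API already used in this example (`add_request`, `step`, `finished`); the imports are just what the sketch itself needs.

```python
from collections import defaultdict
from typing import Dict, List, Optional, Tuple

from vllm import LLMEngine, SamplingParams
from vllm.lora.request import LoRARequest


def collect_outputs_by_adapter(
        engine: LLMEngine,
        test_prompts: List[Tuple[str, SamplingParams, Optional[LoRARequest]]]
) -> Dict[str, List[str]]:
    """Run all prompts to completion and group generations by adapter name."""
    results: Dict[str, List[str]] = defaultdict(list)
    adapter_of: Dict[str, str] = {}  # request_id -> adapter name (or "base")
    request_id = 0
    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            adapter_of[str(request_id)] = (lora_request.lora_name
                                           if lora_request else "base")
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1
        for request_output in engine.step():
            if request_output.finished:
                name = adapter_of[request_output.request_id]
                results[name].append(request_output.outputs[0].text)
    return results
```

Called with the engine and prompts from `main()`, this would return a dict keyed by `"base"`, `"sql-lora"`, and `"sql-lora2"`.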
							
								
								
									
70 examples/offline_inference_distributed.py Normal file
@@ -0,0 +1,70 @@
"""
This example shows how to use Ray Data for running offline batch inference
distributed across a multi-node cluster.

Learn more about Ray Data at https://docs.ray.io/en/latest/data/data.html
"""

from vllm import LLM, SamplingParams
from typing import Dict
import numpy as np
import ray

# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)


# Create a class to do batch inference.
class LLMPredictor:

    def __init__(self):
        # Create an LLM.
        self.llm = LLM(model="meta-llama/Llama-2-7b-chat-hf")

    def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, list]:
        # Generate texts from the prompts.
        # The output is a list of RequestOutput objects that contain the prompt,
        # generated text, and other information.
        outputs = self.llm.generate(batch["text"], sampling_params)
        prompt = []
        generated_text = []
        for output in outputs:
            prompt.append(output.prompt)
            generated_text.append(' '.join([o.text for o in output.outputs]))
        return {
            "prompt": prompt,
            "generated_text": generated_text,
        }


# Read one text file from S3. Ray Data supports reading multiple files
# from cloud storage (such as JSONL, Parquet, CSV, binary format).
ds = ray.data.read_text("s3://anonymous@air-example-data/prompts.txt")

# Apply batch inference for all input data.
ds = ds.map_batches(
    LLMPredictor,
    # Set the concurrency to the number of LLM instances.
    concurrency=10,
    # Specify the number of GPUs required per LLM instance.
    # NOTE: Do NOT set `num_gpus` when using vLLM with tensor-parallelism
    # (i.e., `tensor_parallel_size`).
    num_gpus=1,
    # Specify the batch size for inference.
    batch_size=32,
)

# Peek at the first 10 results.
# NOTE: This is for local testing and debugging. For production use cases,
# write the full results out as shown below.
outputs = ds.take(limit=10)
for output in outputs:
    prompt = output["prompt"]
    generated_text = output["generated_text"]
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

# Write inference output data out as Parquet files to S3.
# Multiple files will be written to the output destination,
# and each task will write one or more files separately.
#
# ds.write_parquet("s3://<your-output-bucket>")
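To try this script without S3 access, a minimal local variant of the data-loading and output steps might look like the following sketch. Assumptions: a local `prompts.txt` with one prompt per line, a single node with one GPU, and the `LLMPredictor` class defined above.

```python
import ray

# Read prompts from a local text file; each line becomes a row in the "text" column.
ds = ray.data.read_text("prompts.txt")

# Run a single LLMPredictor replica on one GPU.
ds = ds.map_batches(
    LLMPredictor,
    concurrency=1,  # one LLM instance
    num_gpus=1,     # one GPU for that instance
    batch_size=8,
)

# Write the results to a local directory instead of S3.
ds.write_parquet("/tmp/vllm_batch_outputs")
```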
							
								
								
									
59 examples/offline_inference_with_prefix.py Normal file
@@ -0,0 +1,59 @@
from vllm import LLM, SamplingParams

prefix = (
    "You are an expert school principal, skilled in effectively managing "
    "faculty and staff. Draft 10-15 questions for a potential first grade "
    "Head Teacher for my K-12, all-girls', independent school that emphasizes "
    "community, joyful discovery, and life-long learning. The candidate is "
    "coming in for a first-round panel interview for an 8th grade Math "
    "teaching role. They have 5 years of previous teaching experience "
    "as an assistant teacher at a co-ed, public school with experience "
    "in middle school math teaching. Based on this information, fulfill "
    "the following paragraph: ")

# Sample prompts.
prompts = [
    "Hello, my name is",
    "The president of the United States is",
    "The capital of France is",
    "The future of AI is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
llm = LLM(model="facebook/opt-125m")

generating_prompts = [prefix + prompt for prompt in prompts]

# Generate texts from the prompts. The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
outputs = llm.generate(generating_prompts, sampling_params)
# Print the outputs.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

print("-" * 80)

# -1 since the last token can change when concatenating prompts.
prefix_pos = len(llm.llm_engine.tokenizer.encode(prefix)) - 1

# The llm.generate call will batch all prompts and send the batch at once if resources allow.
# The prefix will only be cached after the first batch is processed, so we need to call
# generate once to compute the prefix and cache it.
outputs = llm.generate(generating_prompts[0],
                       sampling_params,
                       prefix_pos=[prefix_pos])

# Subsequent batches can leverage the cached prefix.
outputs = llm.generate(generating_prompts,
                       sampling_params,
                       prefix_pos=[prefix_pos] * len(generating_prompts))

# Print the outputs. You should see the same outputs as before.
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
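To get a feel for what the cached prefix buys you, one option is to time the batched call with and without `prefix_pos`. The sketch below reuses `llm`, `generating_prompts`, `sampling_params`, and `prefix_pos` from the script above; the printed numbers are illustrative only and will vary with hardware and model.

```python
import time

# Baseline: generate without telling vLLM about the shared prefix.
start = time.perf_counter()
llm.generate(generating_prompts, sampling_params)
baseline_s = time.perf_counter() - start

# With the prefix marked, the already-computed prefix KV cache can be reused.
start = time.perf_counter()
llm.generate(generating_prompts,
             sampling_params,
             prefix_pos=[prefix_pos] * len(generating_prompts))
cached_s = time.perf_counter() - start

print(f"No prefix caching: {baseline_s:.2f}s, with prefix caching: {cached_s:.2f}s")
```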
@@ -32,6 +32,5 @@ chat_completion = client.chat.completions.create(
     model=model,
 )
 
-
 print("Chat completion results:")
 print(chat_completion)

@@ -21,8 +21,7 @@ completion = client.completions.create(
     echo=False,
     n=2,
     stream=stream,
-    logprobs=3
-)
+    logprobs=3)
 
 print("Completion results:")
 if stream:
54 examples/production_monitoring/README.md Normal file
@@ -0,0 +1,54 @@
# vLLM + Prometheus/Grafana

This is a simple example that shows you how to connect vLLM metric logging to the Prometheus/Grafana stack. For this example, we launch Prometheus and Grafana via Docker. You can check out other setups on the [Prometheus](https://prometheus.io/) and [Grafana](https://grafana.com/) websites.

Install:
- [`docker`](https://docs.docker.com/engine/install/)
- [`docker compose`](https://docs.docker.com/compose/install/linux/#install-using-the-repository)

### Launch

Prometheus metric logging is enabled by default in the OpenAI-compatible server. Launch via the entrypoint:
```bash
python3 -m vllm.entrypoints.openai.api_server \
    --model mistralai/Mistral-7B-v0.1 \
    --max-model-len 2048 \
    --disable-log-requests
```

Launch Prometheus and Grafana servers with `docker compose`:
```bash
docker compose up
```

Submit some sample requests to the server:
```bash
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json

python3 ../../benchmarks/benchmark_serving.py \
    --model mistralai/Mistral-7B-v0.1 \
    --tokenizer mistralai/Mistral-7B-v0.1 \
    --endpoint /v1/completions \
    --dataset ShareGPT_V3_unfiltered_cleaned_split.json \
    --request-rate 3.0
```

Navigating to [`http://localhost:8000/metrics`](http://localhost:8000/metrics) will show the raw Prometheus metrics being exposed by vLLM.
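As a quick sanity check from Python, you can also fetch the metrics endpoint directly and filter for the vLLM series. This is a minimal sketch; it assumes the server launched above is running locally and that the `requests` package is installed.

```python
import requests

# Fetch the raw Prometheus text exposition from the vLLM server.
resp = requests.get("http://localhost:8000/metrics", timeout=5)
resp.raise_for_status()

# vLLM-specific series are prefixed with "vllm:".
for line in resp.text.splitlines():
    if line.startswith("vllm:"):
        print(line)
```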
### Grafana Dashboard

Navigate to [`http://localhost:3000`](http://localhost:3000). Log in with the default username (`admin`) and password (`admin`).

#### Add Prometheus Data Source

Navigate to [`http://localhost:3000/connections/datasources/new`](http://localhost:3000/connections/datasources/new) and select Prometheus.

On the Prometheus configuration page, add the `Prometheus Server URL` under `Connection`. For this setup, Grafana and Prometheus run in separate containers, but Docker creates a DNS name for each container, so you can simply use `http://prometheus:9090`.

Click `Save & Test`. You should get a green check saying "Successfully queried the Prometheus API."

#### Import Dashboard

Navigate to [`http://localhost:3000/dashboard/import`](http://localhost:3000/dashboard/import), upload `grafana.json`, and select the `prometheus` datasource. You should see a screen that looks like the following:

(Screenshot of the imported vLLM dashboard.)
							
								
								
									
19 examples/production_monitoring/docker-compose.yaml Normal file
@@ -0,0 +1,19 @@
# docker-compose.yaml
version: "3"

services:
  prometheus:
    image: prom/prometheus:latest
    extra_hosts:
      - "host.docker.internal:host-gateway"   # allow a direct connection from container to the local machine
    ports:
      - "9090:9090"   # the default port used by Prometheus
    volumes:
      - ${PWD}/prometheus.yaml:/etc/prometheus/prometheus.yml   # mount Prometheus config file

  grafana:
    image: grafana/grafana:latest
    depends_on:
      - prometheus
    ports:
      - "3000:3000"   # the default port used by Grafana
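Once both containers are up, you can confirm that Prometheus is scraping the vLLM server by querying its HTTP API for one of the vLLM metrics used by the dashboard below. A minimal sketch, assuming the compose file above is running and `requests` is installed:

```python
import requests

# Query Prometheus (exposed on localhost:9090 by the compose file above)
# for the current number of running requests reported by vLLM.
resp = requests.get(
    "http://localhost:9090/api/v1/query",
    params={"query": "vllm:num_requests_running"},
    timeout=5,
)
resp.raise_for_status()

for result in resp.json()["data"]["result"]:
    # Each result carries the metric labels and the latest (timestamp, value) sample.
    print(result["metric"], result["value"])
```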
							
								
								
									
931 examples/production_monitoring/grafana.json Normal file
@@ -0,0 +1,931 @@
 | 
			
		||||
{
 | 
			
		||||
  "__inputs": [
 | 
			
		||||
    {
 | 
			
		||||
      "name": "DS_PROMETHEUS",
 | 
			
		||||
      "label": "prometheus",
 | 
			
		||||
      "description": "",
 | 
			
		||||
      "type": "datasource",
 | 
			
		||||
      "pluginId": "prometheus",
 | 
			
		||||
      "pluginName": "Prometheus"
 | 
			
		||||
    }
 | 
			
		||||
  ],
 | 
			
		||||
  "__elements": {},
 | 
			
		||||
  "__requires": [
 | 
			
		||||
    {
 | 
			
		||||
      "type": "grafana",
 | 
			
		||||
      "id": "grafana",
 | 
			
		||||
      "name": "Grafana",
 | 
			
		||||
      "version": "10.2.3"
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "type": "datasource",
 | 
			
		||||
      "id": "prometheus",
 | 
			
		||||
      "name": "Prometheus",
 | 
			
		||||
      "version": "1.0.0"
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "type": "panel",
 | 
			
		||||
      "id": "timeseries",
 | 
			
		||||
      "name": "Time series",
 | 
			
		||||
      "version": ""
 | 
			
		||||
    }
 | 
			
		||||
  ],
 | 
			
		||||
  "annotations": {
 | 
			
		||||
    "list": [
 | 
			
		||||
      {
 | 
			
		||||
        "builtIn": 1,
 | 
			
		||||
        "datasource": {
 | 
			
		||||
          "type": "grafana",
 | 
			
		||||
          "uid": "-- Grafana --"
 | 
			
		||||
        },
 | 
			
		||||
        "enable": true,
 | 
			
		||||
        "hide": true,
 | 
			
		||||
        "iconColor": "rgba(0, 211, 255, 1)",
 | 
			
		||||
        "name": "Annotations & Alerts",
 | 
			
		||||
        "type": "dashboard"
 | 
			
		||||
      }
 | 
			
		||||
    ]
 | 
			
		||||
  },
 | 
			
		||||
  "description": "Monitoring vLLM Inference Server",
 | 
			
		||||
  "editable": true,
 | 
			
		||||
  "fiscalYearStartMonth": 0,
 | 
			
		||||
  "graphTooltip": 0,
 | 
			
		||||
  "id": null,
 | 
			
		||||
  "links": [],
 | 
			
		||||
  "liveNow": false,
 | 
			
		||||
  "panels": [
 | 
			
		||||
    {
 | 
			
		||||
      "datasource": {
 | 
			
		||||
        "type": "prometheus",
 | 
			
		||||
        "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
      },
 | 
			
		||||
      "description": "End to end request latency measured in seconds.",
 | 
			
		||||
      "fieldConfig": {
 | 
			
		||||
        "defaults": {
 | 
			
		||||
          "color": {
 | 
			
		||||
            "mode": "palette-classic"
 | 
			
		||||
          },
 | 
			
		||||
          "custom": {
 | 
			
		||||
            "axisBorderShow": false,
 | 
			
		||||
            "axisCenteredZero": false,
 | 
			
		||||
            "axisColorMode": "text",
 | 
			
		||||
            "axisLabel": "",
 | 
			
		||||
            "axisPlacement": "auto",
 | 
			
		||||
            "barAlignment": 0,
 | 
			
		||||
            "drawStyle": "line",
 | 
			
		||||
            "fillOpacity": 0,
 | 
			
		||||
            "gradientMode": "none",
 | 
			
		||||
            "hideFrom": {
 | 
			
		||||
              "legend": false,
 | 
			
		||||
              "tooltip": false,
 | 
			
		||||
              "viz": false
 | 
			
		||||
            },
 | 
			
		||||
            "insertNulls": false,
 | 
			
		||||
            "lineInterpolation": "linear",
 | 
			
		||||
            "lineWidth": 1,
 | 
			
		||||
            "pointSize": 5,
 | 
			
		||||
            "scaleDistribution": {
 | 
			
		||||
              "type": "linear"
 | 
			
		||||
            },
 | 
			
		||||
            "showPoints": "auto",
 | 
			
		||||
            "spanNulls": false,
 | 
			
		||||
            "stacking": {
 | 
			
		||||
              "group": "A",
 | 
			
		||||
              "mode": "none"
 | 
			
		||||
            },
 | 
			
		||||
            "thresholdsStyle": {
 | 
			
		||||
              "mode": "off"
 | 
			
		||||
            }
 | 
			
		||||
          },
 | 
			
		||||
          "mappings": [],
 | 
			
		||||
          "thresholds": {
 | 
			
		||||
            "mode": "absolute",
 | 
			
		||||
            "steps": [
 | 
			
		||||
              {
 | 
			
		||||
                "color": "green",
 | 
			
		||||
                "value": null
 | 
			
		||||
              },
 | 
			
		||||
              {
 | 
			
		||||
                "color": "red",
 | 
			
		||||
                "value": 80
 | 
			
		||||
              }
 | 
			
		||||
            ]
 | 
			
		||||
          },
 | 
			
		||||
          "unit": "s"
 | 
			
		||||
        },
 | 
			
		||||
        "overrides": []
 | 
			
		||||
      },
 | 
			
		||||
      "gridPos": {
 | 
			
		||||
        "h": 8,
 | 
			
		||||
        "w": 12,
 | 
			
		||||
        "x": 0,
 | 
			
		||||
        "y": 0
 | 
			
		||||
      },
 | 
			
		||||
      "id": 9,
 | 
			
		||||
      "options": {
 | 
			
		||||
        "legend": {
 | 
			
		||||
          "calcs": [],
 | 
			
		||||
          "displayMode": "list",
 | 
			
		||||
          "placement": "bottom",
 | 
			
		||||
          "showLegend": true
 | 
			
		||||
        },
 | 
			
		||||
        "tooltip": {
 | 
			
		||||
          "mode": "single",
 | 
			
		||||
          "sort": "none"
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "targets": [
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P99",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "A",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P95",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "B",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P90",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "C",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:e2e_request_latency_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P50",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "D",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "editorMode": "code",
 | 
			
		||||
          "expr": "rate(vllm:e2e_request_latency_seconds_sum[$__rate_interval])\n/\nrate(vllm:e2e_request_latency_seconds_count[$__rate_interval])",
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Average",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "E"
 | 
			
		||||
        }
 | 
			
		||||
      ],
 | 
			
		||||
      "title": "E2E Request Latency",
 | 
			
		||||
      "type": "timeseries"
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "datasource": {
 | 
			
		||||
        "type": "prometheus",
 | 
			
		||||
        "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
      },
 | 
			
		||||
      "description": "Number of tokens processed per second",
 | 
			
		||||
      "fieldConfig": {
 | 
			
		||||
        "defaults": {
 | 
			
		||||
          "color": {
 | 
			
		||||
            "mode": "palette-classic"
 | 
			
		||||
          },
 | 
			
		||||
          "custom": {
 | 
			
		||||
            "axisBorderShow": false,
 | 
			
		||||
            "axisCenteredZero": false,
 | 
			
		||||
            "axisColorMode": "text",
 | 
			
		||||
            "axisLabel": "",
 | 
			
		||||
            "axisPlacement": "auto",
 | 
			
		||||
            "barAlignment": 0,
 | 
			
		||||
            "drawStyle": "line",
 | 
			
		||||
            "fillOpacity": 0,
 | 
			
		||||
            "gradientMode": "none",
 | 
			
		||||
            "hideFrom": {
 | 
			
		||||
              "legend": false,
 | 
			
		||||
              "tooltip": false,
 | 
			
		||||
              "viz": false
 | 
			
		||||
            },
 | 
			
		||||
            "insertNulls": false,
 | 
			
		||||
            "lineInterpolation": "linear",
 | 
			
		||||
            "lineWidth": 1,
 | 
			
		||||
            "pointSize": 5,
 | 
			
		||||
            "scaleDistribution": {
 | 
			
		||||
              "type": "linear"
 | 
			
		||||
            },
 | 
			
		||||
            "showPoints": "auto",
 | 
			
		||||
            "spanNulls": false,
 | 
			
		||||
            "stacking": {
 | 
			
		||||
              "group": "A",
 | 
			
		||||
              "mode": "none"
 | 
			
		||||
            },
 | 
			
		||||
            "thresholdsStyle": {
 | 
			
		||||
              "mode": "off"
 | 
			
		||||
            }
 | 
			
		||||
          },
 | 
			
		||||
          "mappings": [],
 | 
			
		||||
          "thresholds": {
 | 
			
		||||
            "mode": "absolute",
 | 
			
		||||
            "steps": [
 | 
			
		||||
              {
 | 
			
		||||
                "color": "green",
 | 
			
		||||
                "value": null
 | 
			
		||||
              },
 | 
			
		||||
              {
 | 
			
		||||
                "color": "red",
 | 
			
		||||
                "value": 80
 | 
			
		||||
              }
 | 
			
		||||
            ]
 | 
			
		||||
          }
 | 
			
		||||
        },
 | 
			
		||||
        "overrides": []
 | 
			
		||||
      },
 | 
			
		||||
      "gridPos": {
 | 
			
		||||
        "h": 8,
 | 
			
		||||
        "w": 12,
 | 
			
		||||
        "x": 12,
 | 
			
		||||
        "y": 0
 | 
			
		||||
      },
 | 
			
		||||
      "id": 8,
 | 
			
		||||
      "options": {
 | 
			
		||||
        "legend": {
 | 
			
		||||
          "calcs": [],
 | 
			
		||||
          "displayMode": "list",
 | 
			
		||||
          "placement": "bottom",
 | 
			
		||||
          "showLegend": true
 | 
			
		||||
        },
 | 
			
		||||
        "tooltip": {
 | 
			
		||||
          "mode": "single",
 | 
			
		||||
          "sort": "none"
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "targets": [
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "rate(vllm:prompt_tokens_total[$__rate_interval])",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Prompt Tokens/Sec",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "A",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "rate(vllm:generation_tokens_total[$__rate_interval])",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Generation Tokens/Sec",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "B",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        }
 | 
			
		||||
      ],
 | 
			
		||||
      "title": "Token Throughput",
 | 
			
		||||
      "type": "timeseries"
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "datasource": {
 | 
			
		||||
        "type": "prometheus",
 | 
			
		||||
        "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
      },
 | 
			
		||||
      "description": "Inter token latency in seconds.",
 | 
			
		||||
      "fieldConfig": {
 | 
			
		||||
        "defaults": {
 | 
			
		||||
          "color": {
 | 
			
		||||
            "mode": "palette-classic"
 | 
			
		||||
          },
 | 
			
		||||
          "custom": {
 | 
			
		||||
            "axisBorderShow": false,
 | 
			
		||||
            "axisCenteredZero": false,
 | 
			
		||||
            "axisColorMode": "text",
 | 
			
		||||
            "axisLabel": "",
 | 
			
		||||
            "axisPlacement": "auto",
 | 
			
		||||
            "barAlignment": 0,
 | 
			
		||||
            "drawStyle": "line",
 | 
			
		||||
            "fillOpacity": 0,
 | 
			
		||||
            "gradientMode": "none",
 | 
			
		||||
            "hideFrom": {
 | 
			
		||||
              "legend": false,
 | 
			
		||||
              "tooltip": false,
 | 
			
		||||
              "viz": false
 | 
			
		||||
            },
 | 
			
		||||
            "insertNulls": false,
 | 
			
		||||
            "lineInterpolation": "linear",
 | 
			
		||||
            "lineWidth": 1,
 | 
			
		||||
            "pointSize": 5,
 | 
			
		||||
            "scaleDistribution": {
 | 
			
		||||
              "type": "linear"
 | 
			
		||||
            },
 | 
			
		||||
            "showPoints": "auto",
 | 
			
		||||
            "spanNulls": false,
 | 
			
		||||
            "stacking": {
 | 
			
		||||
              "group": "A",
 | 
			
		||||
              "mode": "none"
 | 
			
		||||
            },
 | 
			
		||||
            "thresholdsStyle": {
 | 
			
		||||
              "mode": "off"
 | 
			
		||||
            }
 | 
			
		||||
          },
 | 
			
		||||
          "mappings": [],
 | 
			
		||||
          "thresholds": {
 | 
			
		||||
            "mode": "absolute",
 | 
			
		||||
            "steps": [
 | 
			
		||||
              {
 | 
			
		||||
                "color": "green",
 | 
			
		||||
                "value": null
 | 
			
		||||
              },
 | 
			
		||||
              {
 | 
			
		||||
                "color": "red",
 | 
			
		||||
                "value": 80
 | 
			
		||||
              }
 | 
			
		||||
            ]
 | 
			
		||||
          },
 | 
			
		||||
          "unit": "s"
 | 
			
		||||
        },
 | 
			
		||||
        "overrides": []
 | 
			
		||||
      },
 | 
			
		||||
      "gridPos": {
 | 
			
		||||
        "h": 8,
 | 
			
		||||
        "w": 12,
 | 
			
		||||
        "x": 0,
 | 
			
		||||
        "y": 8
 | 
			
		||||
      },
 | 
			
		||||
      "id": 10,
 | 
			
		||||
      "options": {
 | 
			
		||||
        "legend": {
 | 
			
		||||
          "calcs": [],
 | 
			
		||||
          "displayMode": "list",
 | 
			
		||||
          "placement": "bottom",
 | 
			
		||||
          "showLegend": true
 | 
			
		||||
        },
 | 
			
		||||
        "tooltip": {
 | 
			
		||||
          "mode": "single",
 | 
			
		||||
          "sort": "none"
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "targets": [
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P99",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "A",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P95",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "B",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P90",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "C",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:time_per_output_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P50",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "D",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "editorMode": "code",
 | 
			
		||||
          "expr": "rate(vllm:time_per_output_token_seconds_sum[$__rate_interval])\n/\nrate(vllm:time_per_output_token_seconds_count[$__rate_interval])",
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Mean",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "E"
 | 
			
		||||
        }
 | 
			
		||||
      ],
 | 
			
		||||
      "title": "Time Per Output Token Latency",
 | 
			
		||||
      "type": "timeseries"
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "datasource": {
 | 
			
		||||
        "type": "prometheus",
 | 
			
		||||
        "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
      },
 | 
			
		||||
      "description": "Number of requests in RUNNING, WAITING, and SWAPPED state",
 | 
			
		||||
      "fieldConfig": {
 | 
			
		||||
        "defaults": {
 | 
			
		||||
          "color": {
 | 
			
		||||
            "mode": "palette-classic"
 | 
			
		||||
          },
 | 
			
		||||
          "custom": {
 | 
			
		||||
            "axisBorderShow": false,
 | 
			
		||||
            "axisCenteredZero": false,
 | 
			
		||||
            "axisColorMode": "text",
 | 
			
		||||
            "axisLabel": "",
 | 
			
		||||
            "axisPlacement": "auto",
 | 
			
		||||
            "barAlignment": 0,
 | 
			
		||||
            "drawStyle": "line",
 | 
			
		||||
            "fillOpacity": 0,
 | 
			
		||||
            "gradientMode": "none",
 | 
			
		||||
            "hideFrom": {
 | 
			
		||||
              "legend": false,
 | 
			
		||||
              "tooltip": false,
 | 
			
		||||
              "viz": false
 | 
			
		||||
            },
 | 
			
		||||
            "insertNulls": false,
 | 
			
		||||
            "lineInterpolation": "linear",
 | 
			
		||||
            "lineWidth": 1,
 | 
			
		||||
            "pointSize": 5,
 | 
			
		||||
            "scaleDistribution": {
 | 
			
		||||
              "type": "linear"
 | 
			
		||||
            },
 | 
			
		||||
            "showPoints": "auto",
 | 
			
		||||
            "spanNulls": false,
 | 
			
		||||
            "stacking": {
 | 
			
		||||
              "group": "A",
 | 
			
		||||
              "mode": "none"
 | 
			
		||||
            },
 | 
			
		||||
            "thresholdsStyle": {
 | 
			
		||||
              "mode": "off"
 | 
			
		||||
            }
 | 
			
		||||
          },
 | 
			
		||||
          "mappings": [],
 | 
			
		||||
          "thresholds": {
 | 
			
		||||
            "mode": "absolute",
 | 
			
		||||
            "steps": [
 | 
			
		||||
              {
 | 
			
		||||
                "color": "green",
 | 
			
		||||
                "value": null
 | 
			
		||||
              },
 | 
			
		||||
              {
 | 
			
		||||
                "color": "red",
 | 
			
		||||
                "value": 80
 | 
			
		||||
              }
 | 
			
		||||
            ]
 | 
			
		||||
          },
 | 
			
		||||
          "unit": "none"
 | 
			
		||||
        },
 | 
			
		||||
        "overrides": []
 | 
			
		||||
      },
 | 
			
		||||
      "gridPos": {
 | 
			
		||||
        "h": 8,
 | 
			
		||||
        "w": 12,
 | 
			
		||||
        "x": 12,
 | 
			
		||||
        "y": 8
 | 
			
		||||
      },
 | 
			
		||||
      "id": 3,
 | 
			
		||||
      "options": {
 | 
			
		||||
        "legend": {
 | 
			
		||||
          "calcs": [],
 | 
			
		||||
          "displayMode": "list",
 | 
			
		||||
          "placement": "bottom",
 | 
			
		||||
          "showLegend": true
 | 
			
		||||
        },
 | 
			
		||||
        "tooltip": {
 | 
			
		||||
          "mode": "single",
 | 
			
		||||
          "sort": "none"
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "targets": [
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "vllm:num_requests_running",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "includeNullMetadata": true,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Num Running",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "A",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "vllm:num_requests_swapped",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": true,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Num Swapped",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "B",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "vllm:num_requests_waiting",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": true,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Num Waiting",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "C",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        }
 | 
			
		||||
      ],
 | 
			
		||||
      "title": "Scheduler State",
 | 
			
		||||
      "type": "timeseries"
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "datasource": {
 | 
			
		||||
        "type": "prometheus",
 | 
			
		||||
        "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
      },
 | 
			
		||||
      "description": "P50, P90, P95, and P99 TTFT latency in seconds.",
 | 
			
		||||
      "fieldConfig": {
 | 
			
		||||
        "defaults": {
 | 
			
		||||
          "color": {
 | 
			
		||||
            "mode": "palette-classic"
 | 
			
		||||
          },
 | 
			
		||||
          "custom": {
 | 
			
		||||
            "axisBorderShow": false,
 | 
			
		||||
            "axisCenteredZero": false,
 | 
			
		||||
            "axisColorMode": "text",
 | 
			
		||||
            "axisLabel": "",
 | 
			
		||||
            "axisPlacement": "auto",
 | 
			
		||||
            "barAlignment": 0,
 | 
			
		||||
            "drawStyle": "line",
 | 
			
		||||
            "fillOpacity": 0,
 | 
			
		||||
            "gradientMode": "none",
 | 
			
		||||
            "hideFrom": {
 | 
			
		||||
              "legend": false,
 | 
			
		||||
              "tooltip": false,
 | 
			
		||||
              "viz": false
 | 
			
		||||
            },
 | 
			
		||||
            "insertNulls": false,
 | 
			
		||||
            "lineInterpolation": "linear",
 | 
			
		||||
            "lineWidth": 1,
 | 
			
		||||
            "pointSize": 5,
 | 
			
		||||
            "scaleDistribution": {
 | 
			
		||||
              "type": "linear"
 | 
			
		||||
            },
 | 
			
		||||
            "showPoints": "auto",
 | 
			
		||||
            "spanNulls": false,
 | 
			
		||||
            "stacking": {
 | 
			
		||||
              "group": "A",
 | 
			
		||||
              "mode": "none"
 | 
			
		||||
            },
 | 
			
		||||
            "thresholdsStyle": {
 | 
			
		||||
              "mode": "off"
 | 
			
		||||
            }
 | 
			
		||||
          },
 | 
			
		||||
          "mappings": [],
 | 
			
		||||
          "thresholds": {
 | 
			
		||||
            "mode": "absolute",
 | 
			
		||||
            "steps": [
 | 
			
		||||
              {
 | 
			
		||||
                "color": "green",
 | 
			
		||||
                "value": null
 | 
			
		||||
              },
 | 
			
		||||
              {
 | 
			
		||||
                "color": "red",
 | 
			
		||||
                "value": 80
 | 
			
		||||
              }
 | 
			
		||||
            ]
 | 
			
		||||
          },
 | 
			
		||||
          "unit": "s"
 | 
			
		||||
        },
 | 
			
		||||
        "overrides": []
 | 
			
		||||
      },
 | 
			
		||||
      "gridPos": {
 | 
			
		||||
        "h": 8,
 | 
			
		||||
        "w": 12,
 | 
			
		||||
        "x": 0,
 | 
			
		||||
        "y": 16
 | 
			
		||||
      },
 | 
			
		||||
      "id": 5,
 | 
			
		||||
      "options": {
 | 
			
		||||
        "legend": {
 | 
			
		||||
          "calcs": [],
 | 
			
		||||
          "displayMode": "list",
 | 
			
		||||
          "placement": "bottom",
 | 
			
		||||
          "showLegend": true
 | 
			
		||||
        },
 | 
			
		||||
        "tooltip": {
 | 
			
		||||
          "mode": "single",
 | 
			
		||||
          "sort": "none"
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "targets": [
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.99, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P99",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "A",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.95, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P95",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "B",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.9, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P90",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "C",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "disableTextWrap": false,
 | 
			
		||||
          "editorMode": "builder",
 | 
			
		||||
          "expr": "histogram_quantile(0.5, sum by(le) (rate(vllm:time_to_first_token_seconds_bucket[$__rate_interval])))",
 | 
			
		||||
          "fullMetaSearch": false,
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "includeNullMetadata": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "P50",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "D",
 | 
			
		||||
          "useBackend": false
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "editorMode": "code",
 | 
			
		||||
          "expr": "rate(vllm:time_to_first_token_seconds_sum[$__rate_interval])\n/\nrate(vllm:time_to_first_token_seconds_count[$__rate_interval])",
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "Average",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "E"
 | 
			
		||||
        }
 | 
			
		||||
      ],
 | 
			
		||||
      "title": "Time To First Token Latency",
 | 
			
		||||
      "type": "timeseries"
 | 
			
		||||
    },
 | 
			
		||||
    {
 | 
			
		||||
      "datasource": {
 | 
			
		||||
        "type": "prometheus",
 | 
			
		||||
        "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
      },
 | 
			
		||||
      "description": "Percentage of used cache blocks by vLLM.",
 | 
			
		||||
      "fieldConfig": {
 | 
			
		||||
        "defaults": {
 | 
			
		||||
          "color": {
 | 
			
		||||
            "mode": "palette-classic"
 | 
			
		||||
          },
 | 
			
		||||
          "custom": {
 | 
			
		||||
            "axisBorderShow": false,
 | 
			
		||||
            "axisCenteredZero": false,
 | 
			
		||||
            "axisColorMode": "text",
 | 
			
		||||
            "axisLabel": "",
 | 
			
		||||
            "axisPlacement": "auto",
 | 
			
		||||
            "barAlignment": 0,
 | 
			
		||||
            "drawStyle": "line",
 | 
			
		||||
            "fillOpacity": 0,
 | 
			
		||||
            "gradientMode": "none",
 | 
			
		||||
            "hideFrom": {
 | 
			
		||||
              "legend": false,
 | 
			
		||||
              "tooltip": false,
 | 
			
		||||
              "viz": false
 | 
			
		||||
            },
 | 
			
		||||
            "insertNulls": false,
 | 
			
		||||
            "lineInterpolation": "linear",
 | 
			
		||||
            "lineWidth": 1,
 | 
			
		||||
            "pointSize": 5,
 | 
			
		||||
            "scaleDistribution": {
 | 
			
		||||
              "type": "linear"
 | 
			
		||||
            },
 | 
			
		||||
            "showPoints": "auto",
 | 
			
		||||
            "spanNulls": false,
 | 
			
		||||
            "stacking": {
 | 
			
		||||
              "group": "A",
 | 
			
		||||
              "mode": "none"
 | 
			
		||||
            },
 | 
			
		||||
            "thresholdsStyle": {
 | 
			
		||||
              "mode": "off"
 | 
			
		||||
            }
 | 
			
		||||
          },
 | 
			
		||||
          "mappings": [],
 | 
			
		||||
          "thresholds": {
 | 
			
		||||
            "mode": "absolute",
 | 
			
		||||
            "steps": [
 | 
			
		||||
              {
 | 
			
		||||
                "color": "green",
 | 
			
		||||
                "value": null
 | 
			
		||||
              },
 | 
			
		||||
              {
 | 
			
		||||
                "color": "red",
 | 
			
		||||
                "value": 80
 | 
			
		||||
              }
 | 
			
		||||
            ]
 | 
			
		||||
          },
 | 
			
		||||
          "unit": "percentunit"
 | 
			
		||||
        },
 | 
			
		||||
        "overrides": []
 | 
			
		||||
      },
 | 
			
		||||
      "gridPos": {
 | 
			
		||||
        "h": 8,
 | 
			
		||||
        "w": 12,
 | 
			
		||||
        "x": 12,
 | 
			
		||||
        "y": 16
 | 
			
		||||
      },
 | 
			
		||||
      "id": 4,
 | 
			
		||||
      "options": {
 | 
			
		||||
        "legend": {
 | 
			
		||||
          "calcs": [],
 | 
			
		||||
          "displayMode": "list",
 | 
			
		||||
          "placement": "bottom",
 | 
			
		||||
          "showLegend": true
 | 
			
		||||
        },
 | 
			
		||||
        "tooltip": {
 | 
			
		||||
          "mode": "single",
 | 
			
		||||
          "sort": "none"
 | 
			
		||||
        }
 | 
			
		||||
      },
 | 
			
		||||
      "targets": [
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "editorMode": "code",
 | 
			
		||||
          "expr": "vllm:gpu_cache_usage_perc",
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "GPU Cache Usage",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "A"
 | 
			
		||||
        },
 | 
			
		||||
        {
 | 
			
		||||
          "datasource": {
 | 
			
		||||
            "type": "prometheus",
 | 
			
		||||
            "uid": "${DS_PROMETHEUS}"
 | 
			
		||||
          },
 | 
			
		||||
          "editorMode": "code",
 | 
			
		||||
          "expr": "vllm:cpu_cache_usage_perc",
 | 
			
		||||
          "hide": false,
 | 
			
		||||
          "instant": false,
 | 
			
		||||
          "legendFormat": "CPU Cache Usage",
 | 
			
		||||
          "range": true,
 | 
			
		||||
          "refId": "B"
 | 
			
		||||
        }
 | 
			
		||||
      ],
 | 
			
		||||
      "title": "Cache Utilization",
 | 
			
		||||
      "type": "timeseries"
 | 
			
		||||
    }
 | 
			
		||||
  ],
 | 
			
		||||
  "refresh": "",
 | 
			
		||||
  "schemaVersion": 39,
 | 
			
		||||
  "tags": [],
 | 
			
		||||
  "templating": {
 | 
			
		||||
    "list": []
 | 
			
		||||
  },
 | 
			
		||||
  "time": {
 | 
			
		||||
    "from": "now-5m",
 | 
			
		||||
    "to": "now"
 | 
			
		||||
  },
 | 
			
		||||
  "timepicker": {},
 | 
			
		||||
  "timezone": "",
 | 
			
		||||
  "title": "vLLM",
 | 
			
		||||
  "uid": "b281712d-8bff-41ef-9f3f-71ad43c05e9b",
 | 
			
		||||
  "version": 2,
 | 
			
		||||
  "weekStart": ""
 | 
			
		||||
}
 | 
			
		||||
Some files were not shown because too many files have changed in this diff.