mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-21 07:13:52 +08:00
Compare commits
275 Commits
Author | SHA1 | Date | |
---|---|---|---|
4db5176d97 | |||
4cf1dc39be | |||
6e4852ce28 | |||
8571ac4672 | |||
997cf78308 | |||
57f560aa23 | |||
003f8ee128 | |||
e9630458c7 | |||
82a1b1a82b | |||
c0d8f1636c | |||
cc08fc7225 | |||
7b86e7c9cd | |||
f80ab3521c | |||
16a1cc9bb2 | |||
b1c9aa3daa | |||
179a6a36f2 | |||
83c644fe7e | |||
9fadc7b7a0 | |||
654bc5ca49 | |||
825b044863 | |||
44dcb52e39 | |||
67d745cc68 | |||
99d7cabd7b | |||
fb2c1c86c1 | |||
0c25435daa | |||
a0d164567c | |||
04e5583425 | |||
8c025fa703 | |||
69ea15e5cc | |||
ed812a73fa | |||
708989341e | |||
22e718ff1a | |||
05308891e2 | |||
a8d604ca2a | |||
b482b9a5b1 | |||
806949514a | |||
c16eaac500 | |||
db35186391 | |||
660dea1235 | |||
cf2a1a4d9d | |||
252357793d | |||
3bb4b1e4cd | |||
954f7305a1 | |||
6ce01f3066 | |||
6a11fdfbb8 | |||
805a8a75f2 | |||
562e580abc | |||
fc912e0886 | |||
f4fd390f5d | |||
fb3db61688 | |||
2dd34371a6 | |||
7e0861bd0b | |||
a72a424b3e | |||
c8a7e93273 | |||
3c10591ef2 | |||
0437492ea9 | |||
630dd9e0ae | |||
23993a7997 | |||
1d2e7fb73f | |||
7ecee34321 | |||
7eb0cb4a14 | |||
a0dce9383a | |||
35e9c12bfa | |||
93548eb37e | |||
460c1884e3 | |||
bd70013407 | |||
2ee8d3ba55 | |||
daed30c4a9 | |||
2f4e108f75 | |||
6512937de1 | |||
c0644cf9ce | |||
533d1932d2 | |||
9f0e69b653 | |||
f230cc2ca6 | |||
da1f7cc12a | |||
c32ab8be1a | |||
fb4f530bf5 | |||
79319cedfa | |||
40c27a7cbb | |||
6ca8031e71 | |||
d7a299edaa | |||
052b6f8ca4 | |||
5895b24677 | |||
cbbc904470 | |||
5cf9254a9c | |||
f058403683 | |||
c66c7f86ac | |||
6e063ea35b | |||
af647fb8b3 | |||
61a97c32f6 | |||
4fbf4aa128 | |||
aae6d36f7e | |||
9f69d8245a | |||
9a7e2d0534 | |||
7f8d612d24 | |||
60d1c6e584 | |||
db9e5708a9 | |||
766435e660 | |||
7cbd9ec7a9 | |||
3eeb148f46 | |||
b1366a9534 | |||
75acdaa4b6 | |||
fad5576c58 | |||
f954d0715c | |||
1ad86acf17 | |||
ecb33a28cb | |||
a57d75821c | |||
925de97e05 | |||
aa46953a20 | |||
593e79e733 | |||
c53041ae3b | |||
52f07e3dec | |||
14dbd5a767 | |||
ed94e4f427 | |||
3c3012398e | |||
ced36cd89b | |||
969d032265 | |||
55712941e5 | |||
981b0d5673 | |||
d09b94ca58 | |||
bb5494676f | |||
b5f49ee55b | |||
150a1ffbfd | |||
281977bd6e | |||
3bbb4936dc | |||
aa4867791e | |||
71734f1bf2 | |||
50704f52c4 | |||
07278c37dd | |||
85ad7e2d01 | |||
89a84b0bb7 | |||
084a01fd35 | |||
062a1d0fab | |||
2eb9f4ff26 | |||
443c7cf4cf | |||
1adddb14bf | |||
b7215de2c5 | |||
f3ff63c3f4 | |||
cd7edc4e87 | |||
6a1e25b151 | |||
95db75de64 | |||
65b1f121c8 | |||
889da130e7 | |||
b75e314fff | |||
316a41ac1d | |||
0310029a2f | |||
309aaef825 | |||
9e169a4c61 | |||
5689e256ba | |||
740374d456 | |||
d88c458f44 | |||
421e218b37 | |||
5448f67635 | |||
0e63494cf3 | |||
ee812580f7 | |||
40468b13fa | |||
2cf0df3381 | |||
545146349c | |||
f4f8a9d892 | |||
b570811706 | |||
ccc4a73257 | |||
0a740a11ba | |||
c882a7f5b3 | |||
5e8ca973eb | |||
87525fab92 | |||
2f808e69ab | |||
01c16ede6b | |||
72fc704803 | |||
1bedf210e3 | |||
507ef787d8 | |||
58f53034ad | |||
0eb0757bef | |||
38c4b7e863 | |||
a112a84aad | |||
461089a21a | |||
71950af726 | |||
cb1362a889 | |||
bb2fc08072 | |||
3eda4ec780 | |||
22fa2e35cb | |||
c5201240a4 | |||
97234be0ec | |||
c051bfe4eb | |||
9e0b558a09 | |||
e519ae097a | |||
7c2749a4fd | |||
729171ae58 | |||
c5e8330997 | |||
e0c15758b8 | |||
bdf5fd1386 | |||
5a96ee52a3 | |||
42c7f66a38 | |||
69d5ae38dc | |||
fea59c7712 | |||
739b61a348 | |||
89c1c6a196 | |||
42de2cefcb | |||
c9eef37f32 | |||
396d92d5e0 | |||
25e778aa16 | |||
b6df37f943 | |||
14f91fe67c | |||
d7f4178dd9 | |||
082ecd80d5 | |||
f952bbc8ff | |||
9364f74eee | |||
06d6c5fe9f | |||
683e3cb9c4 | |||
9042d68362 | |||
3f8d42c81f | |||
7bd82002ae | |||
2e26564259 | |||
e81522e879 | |||
45ceb85a0c | |||
4cc24f01b1 | |||
07eb6f19f3 | |||
f0bbfaf917 | |||
30efe41532 | |||
9ed82e7074 | |||
51f8aa90ad | |||
a5314e8698 | |||
a921e86392 | |||
6366efc67b | |||
dbe5588554 | |||
d4201e06d5 | |||
b5672a112c | |||
c5df56f88b | |||
1689219ebf | |||
4ffffccb7e | |||
f53b8f0d05 | |||
2d4733ba2d | |||
15c6a079b1 | |||
ecdb462c24 | |||
58ca663224 | |||
4634c8728b | |||
c8a7d51c49 | |||
e2fbaee725 | |||
8a74c68bd1 | |||
61e592747c | |||
d25877dd9b | |||
1c27d25fb5 | |||
18fecc3559 | |||
b5af8c223c | |||
b5241e41d9 | |||
e76466dde2 | |||
5f0b9933e6 | |||
a38524f338 | |||
2fa4623d9e | |||
a9a2e74d21 | |||
e09ce759aa | |||
5fa6e9876e | |||
5bf35a91e4 | |||
a19e8d3726 | |||
10383887e0 | |||
1d094fd7c0 | |||
ce37be7ba0 | |||
7f62077af5 | |||
09c2eb85dd | |||
978aed5300 | |||
160e1d8c99 | |||
94162beb9f | |||
c467dff24f | |||
9f4ccec761 | |||
38ef94888a | |||
2bb0489cb3 | |||
7508a3dc34 | |||
7a3d2a5b95 | |||
d97011512e | |||
37d776606f | |||
d92b3c5cde | |||
9ad32dacd9 | |||
d6f3b3d5c4 | |||
4552e37b55 | |||
ec9933f4a5 | |||
3dee97b05f |
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import zipfile
|
||||
|
||||
MAX_SIZE_MB = 200
|
||||
MAX_SIZE_MB = 250
|
||||
|
||||
|
||||
def print_top_10_largest_files(zip_file):
|
||||
|
@ -1,14 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -ex
|
||||
set -o pipefail
|
||||
|
||||
(which wget && which curl) || (apt-get update && apt-get install -y wget curl)
|
||||
|
||||
# aws s3 sync s3://air-example-data-2/vllm_opensource_llava/ images/
|
||||
mkdir -p images
|
||||
cd images
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/stop_sign.jpg
|
||||
wget https://air-example-data-2.s3.us-west-2.amazonaws.com/vllm_opensource_llava/cherry_blossom.jpg
|
||||
|
||||
cd -
|
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-hf-baseline.sh -m nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5
|
||||
model_name: "nm-testing/Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.905
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.905
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8A8-FP8-Channelwise-compressed-tensors"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.752
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.754
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-FBGEMM-nonuniform"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.753
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.753
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Meta-Llama-3-8B-Instruct-nonuniform-test"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.758
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.759
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
11
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
Normal file
11
.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-QQQ.yaml
Normal file
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m HandH1998/QQQ-Llama-3-8b-g128 -b 32 -l 1000 -f 5 -t 1
|
||||
model_name: "HandH1998/QQQ-Llama-3-8b-g128"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.409
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.406
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
11
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
Normal file
11
.buildkite/lm-eval-harness/configs/Minitron-4B-Base.yaml
Normal file
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nvidia/Minitron-4B-Base -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nvidia/Minitron-4B-Base"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.252
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.252
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -0,0 +1,11 @@
|
||||
# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Qwen2-1.5B-Instruct-FP8W8 -b auto -l 1000 -f 5 -t 1
|
||||
model_name: "nm-testing/Qwen2-1.5B-Instruct-FP8W8"
|
||||
tasks:
|
||||
- name: "gsm8k"
|
||||
metrics:
|
||||
- name: "exact_match,strict-match"
|
||||
value: 0.578
|
||||
- name: "exact_match,flexible-extract"
|
||||
value: 0.585
|
||||
limit: 1000
|
||||
num_fewshot: 5
|
@ -1,3 +1,4 @@
|
||||
Meta-Llama-3-70B-Instruct-FBGEMM-nonuniform.yaml
|
||||
Meta-Llama-3-70B-Instruct.yaml
|
||||
Mixtral-8x7B-Instruct-v0.1.yaml
|
||||
Qwen2-57B-A14-Instruct.yaml
|
||||
|
@ -2,4 +2,9 @@ Meta-Llama-3-8B-Instruct.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8.yaml
|
||||
Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml
|
||||
Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml
|
||||
Minitron-4B-Base.yaml
|
||||
Qwen2-1.5B-Instruct-INT8-compressed-tensors.yaml
|
||||
Qwen2-1.5B-Instruct-FP8W8.yaml
|
||||
Meta-Llama-3-8B-QQQ.yaml
|
||||
|
@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
|
||||
done
|
||||
|
||||
lm_eval --model vllm \
|
||||
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
|
||||
--model_args pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend="ray",trust_remote_code=true,max_model_len=4096 \
|
||||
--tasks gsm8k --num_fewshot $FEWSHOT --limit $LIMIT \
|
||||
--batch_size $BATCH_SIZE
|
||||
|
@ -3,30 +3,51 @@
|
||||
|
||||
## Introduction
|
||||
|
||||
This directory contains the performance benchmarking CI for vllm.
|
||||
The goal is to help developers know the impact of their PRs on the performance of vllm.
|
||||
This directory contains two sets of benchmark for vllm.
|
||||
- Performance benchmark: benchmark vllm's performance under various workload, for **developers** to gain clarity on whether their PR improves/degrades vllm's performance
|
||||
- Nightly benchmark: compare vllm's performance against alternatives (tgi, trt-llm and lmdeploy), for **the public** to know when to choose vllm.
|
||||
|
||||
This benchmark will be *triggered* upon:
|
||||
- A PR being merged into vllm.
|
||||
- Every commit for those PRs with `perf-benchmarks` label.
|
||||
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for more GPUs is comming later), with different models.
|
||||
See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performance benchmark results and [vLLM GitHub README](https://github.com/vllm-project/vllm/blob/main/README.md) for latest nightly benchmark results.
|
||||
|
||||
|
||||
## Performance benchmark quick overview
|
||||
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.
|
||||
|
||||
**Benchmarking Duration**: about 1hr.
|
||||
|
||||
**For benchmarking developers**: please try your best to constraint the duration of benchmarking to less than 1.5 hr so that it won't take forever to run.
|
||||
**For benchmarking developers**: please try your best to constraint the duration of benchmarking to about 1 hr so that it won't take forever to run.
|
||||
|
||||
|
||||
## Configuring the workload
|
||||
## Nightly benchmark quick overview
|
||||
|
||||
The benchmarking workload contains three parts:
|
||||
- Latency tests in `latency-tests.json`.
|
||||
- Throughput tests in `throughput-tests.json`.
|
||||
- Serving tests in `serving-tests.json`.
|
||||
**Benchmarking Coverage**: Fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) on Llama-3 8B, 70B and Mixtral 8x7B.
|
||||
|
||||
See [descriptions.md](tests/descriptions.md) for detailed descriptions.
|
||||
**Benchmarking engines**: vllm, TGI, trt-llm and lmdeploy.
|
||||
|
||||
### Latency test
|
||||
**Benchmarking Duration**: about 3.5hrs.
|
||||
|
||||
|
||||
|
||||
## Trigger the benchmark
|
||||
|
||||
Performance benchmark will be triggered when:
|
||||
- A PR being merged into vllm.
|
||||
- Every commit for those PRs with `perf-benchmarks` label.
|
||||
|
||||
Nightly benchmark will be triggered when:
|
||||
- Every commit for those PRs with `nightly-benchmarks` label.
|
||||
|
||||
|
||||
|
||||
|
||||
## Performance benchmark details
|
||||
|
||||
See [descriptions.md](tests/descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
|
||||
|
||||
|
||||
#### Latency test
|
||||
|
||||
Here is an example of one test inside `latency-tests.json`:
|
||||
|
||||
@ -54,12 +75,12 @@ Note that the performance numbers are highly sensitive to the value of the param
|
||||
WARNING: The benchmarking script will save json results by itself, so please do not configure `--output-json` parameter in the json file.
|
||||
|
||||
|
||||
### Throughput test
|
||||
#### Throughput test
|
||||
The tests are specified in `throughput-tests.json`. The syntax is similar to `latency-tests.json`, except for that the parameters will be fed forward to `benchmark_throughput.py`.
|
||||
|
||||
The number of this test is also stable -- a slight change on the value of this number might vary the performance numbers by a lot.
|
||||
|
||||
### Serving test
|
||||
#### Serving test
|
||||
We test the throughput by using `benchmark_serving.py` with request rate = inf to cover the online serving overhead. The corresponding parameters are in `serving-tests.json`, and here is an example:
|
||||
|
||||
```
|
||||
@ -96,9 +117,36 @@ The number of this test is less stable compared to the delay and latency benchma
|
||||
|
||||
WARNING: The benchmarking script will save json results by itself, so please do not configure `--save-results` or other results-saving-related parameters in `serving-tests.json`.
|
||||
|
||||
## Visualizing the results
|
||||
#### Visualizing the results
|
||||
The `convert-results-json-to-markdown.py` helps you put the benchmarking results inside a markdown table, by formatting [descriptions.md](tests/descriptions.md) with real benchmarking results.
|
||||
You can find the result presented as a table inside the `buildkite/performance-benchmark` job page.
|
||||
If you do not see the table, please wait till the benchmark finish running.
|
||||
The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file.
|
||||
The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking.
|
||||
|
||||
|
||||
|
||||
## Nightly test details
|
||||
|
||||
See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines.
|
||||
|
||||
|
||||
#### Workflow
|
||||
|
||||
- The [nightly-pipeline.yaml](nightly-pipeline.yaml) specifies the docker containers for different LLM serving engines.
|
||||
- Inside each container, we run [run-nightly-suite.sh](run-nightly-suite.sh), which will probe the serving engine of the current container.
|
||||
- The `run-nightly-suite.sh` will redirect the request to `tests/run-[llm serving engine name]-nightly.sh`, which parses the workload described in [nightly-tests.json](tests/nightly-tests.json) and performs the benchmark.
|
||||
- At last, we run [scripts/plot-nightly-results.py](scripts/plot-nightly-results.py) to collect and plot the final benchmarking results, and update the results to buildkite.
|
||||
|
||||
#### Nightly tests
|
||||
|
||||
In [nightly-tests.json](tests/nightly-tests.json), we include the command line arguments for benchmarking commands, together with the benchmarking test cases. The format is highly similar to performance benchmark.
|
||||
|
||||
#### Docker containers
|
||||
|
||||
The docker containers for benchmarking are specified in `nightly-pipeline.yaml`.
|
||||
|
||||
WARNING: the docker versions are HARD-CODED and SHOULD BE ALIGNED WITH `nightly-descriptions.md`. The docker versions need to be hard-coded as there are several version-specific bug fixes inside `tests/run-[llm serving engine name]-nightly.sh`.
|
||||
|
||||
WARNING: populating `trt-llm` to latest version is not easy, as it requires updating several protobuf files in [tensorrt-demo](https://github.com/neuralmagic/tensorrt-demo.git).
|
||||
|
||||
|
@ -42,20 +42,20 @@ steps:
|
||||
- name: devshm
|
||||
emptyDir:
|
||||
medium: Memory
|
||||
- label: "H100"
|
||||
agents:
|
||||
queue: H100
|
||||
plugins:
|
||||
- docker#v5.11.0:
|
||||
image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
|
||||
command:
|
||||
- bash
|
||||
- .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
|
||||
mount-buildkite-agent: true
|
||||
propagate-environment: true
|
||||
ipc: host
|
||||
gpus: all
|
||||
environment:
|
||||
- VLLM_USAGE_SOURCE
|
||||
- HF_TOKEN
|
||||
# - label: "H100"
|
||||
# agents:
|
||||
# queue: H100
|
||||
# plugins:
|
||||
# - docker#v5.11.0:
|
||||
# image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT
|
||||
# command:
|
||||
# - bash
|
||||
# - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh
|
||||
# mount-buildkite-agent: true
|
||||
# propagate-environment: true
|
||||
# ipc: host
|
||||
# gpus: all
|
||||
# environment:
|
||||
# - VLLM_USAGE_SOURCE
|
||||
# - HF_TOKEN
|
||||
|
||||
|
@ -34,6 +34,15 @@ check_hf_token() {
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_sharegpt_downloaded() {
|
||||
local FILE=ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
if [ ! -f "$FILE" ]; then
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/$FILE
|
||||
else
|
||||
echo "$FILE already exists."
|
||||
fi
|
||||
}
|
||||
|
||||
json2args() {
|
||||
# transforms the JSON string to command line args, and '_' is replaced to '-'
|
||||
# example:
|
||||
@ -73,11 +82,6 @@ kill_gpu_processes() {
|
||||
echo "All GPU processes have been killed."
|
||||
fi
|
||||
|
||||
# Sometimes kill with pid doesn't work properly, we can also kill all process running python or python3
|
||||
# since we are in container anyway
|
||||
pkill -9 -f python
|
||||
pkill -9 -f python3
|
||||
|
||||
# waiting for GPU processes to be fully killed
|
||||
# loop while nvidia-smi returns any processes
|
||||
while [ -n "$(nvidia-smi --query-compute-apps=pid --format=csv,noheader)" ]; do
|
||||
@ -355,7 +359,7 @@ main() {
|
||||
|
||||
# prepare for benchmarking
|
||||
cd benchmarks || exit 1
|
||||
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
ensure_sharegpt_downloaded
|
||||
declare -g RESULTS_FOLDER=results/
|
||||
mkdir -p $RESULTS_FOLDER
|
||||
QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/
|
||||
|
@ -55,5 +55,26 @@
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama70B_tp4_sharegpt_specdecode",
|
||||
"qps_list": [2],
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
"disable_log_requests": "",
|
||||
"tensor_parallel_size": 4,
|
||||
"swap_space": 16,
|
||||
"speculative_model": "turboderp/Qwama-0.5B-Instruct",
|
||||
"num_speculative_tokens": 4,
|
||||
"speculative_draft_tensor_parallel_size": 1,
|
||||
"use_v2_block_manager": ""
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3-70B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200
|
||||
}
|
||||
}
|
||||
]
|
||||
]
|
||||
|
@ -3,13 +3,15 @@ steps:
|
||||
agents:
|
||||
queue: cpu_queue
|
||||
commands:
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg buildkite_commit=$BUILDKITE_COMMIT --build-arg USE_SCCACHE=1 --build-arg CUDA_VERSION={{matrix.cuda_version}} --tag vllm-ci:build-image --target build --progress plain ."
|
||||
- "mkdir artifacts"
|
||||
- "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'"
|
||||
# rename the files to change linux -> manylinux1
|
||||
- "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done"
|
||||
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/"
|
||||
- "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
matrix:
|
||||
setup:
|
||||
cuda_version:
|
||||
|
@ -55,7 +55,7 @@ while true; do
|
||||
done
|
||||
|
||||
echo "--- Pulling container"
|
||||
image_name="rocmshared/vllm-ci:${BUILDKITE_COMMIT}"
|
||||
image_name="rocm/vllm-ci:${BUILDKITE_COMMIT}"
|
||||
container_name="rocm_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head -c 10; echo)"
|
||||
docker pull ${image_name}
|
||||
|
||||
@ -66,11 +66,18 @@ trap remove_docker_container EXIT
|
||||
|
||||
echo "--- Running container"
|
||||
|
||||
HF_CACHE="$(realpath ~)/huggingface"
|
||||
mkdir -p ${HF_CACHE}
|
||||
HF_MOUNT="/root/.cache/huggingface"
|
||||
|
||||
docker run \
|
||||
--device /dev/kfd --device /dev/dri \
|
||||
--network host \
|
||||
--shm-size=16gb \
|
||||
--rm \
|
||||
-e HF_TOKEN \
|
||||
-v ${HF_CACHE}:${HF_MOUNT} \
|
||||
-e HF_HOME=${HF_MOUNT} \
|
||||
--name ${container_name} \
|
||||
${image_name} \
|
||||
/bin/bash -c "${@}"
|
||||
|
@ -3,26 +3,38 @@
|
||||
set -ex
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t cpu-test -f Dockerfile.cpu .
|
||||
docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
|
||||
numactl -C 48-95 -N 1 docker build -t cpu-test -f Dockerfile.cpu .
|
||||
numactl -C 48-95 -N 1 docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" -t cpu-test-avx2 -f Dockerfile.cpu .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() { docker rm -f cpu-test cpu-test-avx2 || true; }
|
||||
trap remove_docker_container EXIT
|
||||
remove_docker_container
|
||||
|
||||
# Run the image
|
||||
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
|
||||
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test cpu-test
|
||||
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test
|
||||
docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --cpuset-cpus=48-95 \
|
||||
--cpuset-mems=1 --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --name cpu-test-avx2 cpu-test-avx2
|
||||
--cpuset-mems=1 --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test-avx2 cpu-test-avx2
|
||||
|
||||
# offline inference
|
||||
docker exec cpu-test bash -c "python3 examples/offline_inference.py"
|
||||
docker exec cpu-test-avx2 bash -c "python3 examples/offline_inference.py"
|
||||
|
||||
# Run basic model test
|
||||
docker exec cpu-test bash -c "cd tests;
|
||||
docker exec cpu-test bash -c "
|
||||
pip install pytest Pillow protobuf
|
||||
cd ../
|
||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py" # Mamba on CPU is not supported
|
||||
pytest -v -s tests/models -m \"not vlm\" --ignore=tests/models/test_embedding.py --ignore=tests/models/test_registry.py --ignore=tests/models/test_jamba.py --ignore=tests/models/test_danube3_4b.py" # Mamba and Danube3-4B on CPU is not supported
|
||||
|
||||
# online inference
|
||||
docker exec cpu-test bash -c "
|
||||
export VLLM_CPU_KVCACHE_SPACE=10
|
||||
export VLLM_CPU_OMP_THREADS_BIND=48-92
|
||||
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m &
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--dataset-name random \
|
||||
--model facebook/opt-125m \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer facebook/opt-125m"
|
||||
|
16
.buildkite/run-tpu-test.sh
Normal file
16
.buildkite/run-tpu-test.sh
Normal file
@ -0,0 +1,16 @@
|
||||
set -e
|
||||
|
||||
# Build the docker image.
|
||||
docker build -f Dockerfile.tpu -t vllm-tpu .
|
||||
|
||||
# Set up cleanup.
|
||||
remove_docker_container() { docker rm -f tpu-test || true; }
|
||||
trap remove_docker_container EXIT
|
||||
# Remove the container that might not be cleaned up in the previous run.
|
||||
remove_docker_container
|
||||
|
||||
# For HF_TOKEN.
|
||||
source /etc/environment
|
||||
# Run a simple end-to-end example.
|
||||
docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
|
||||
python3 /workspace/vllm/examples/offline_inference_tpu.py
|
@ -12,17 +12,15 @@ steps:
|
||||
fast_check_only: true
|
||||
commands:
|
||||
- pytest -v -s async_engine # Async Engine
|
||||
- bash ../.buildkite/download-images.sh # Inputs
|
||||
- pytest -v -s test_inputs.py
|
||||
- pytest -v -s multimodal
|
||||
- pytest -v -s test_utils.py # Utils
|
||||
- pytest -v -s worker # Worker
|
||||
|
||||
- label: Tensorizer, Metrics, Tracing Test
|
||||
- label: Metrics, Tracing Test
|
||||
fast_check: true
|
||||
fast_check_only: true
|
||||
commands:
|
||||
- apt-get install curl libsodium23 && pytest -v -s tensorizer_loader # Tensorizer
|
||||
- pytest -v -s metrics # Metrics
|
||||
- "pip install \
|
||||
opentelemetry-sdk \
|
||||
@ -45,8 +43,10 @@ steps:
|
||||
mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
commands:
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
# This flashinfer installation will fail on AMD ROCm, so it is set as optional.
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl || true
|
||||
- pytest -v -s basic_correctness/test_basic_correctness.py
|
||||
- pytest -v -s basic_correctness/test_cpu_offload.py
|
||||
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
@ -54,9 +54,8 @@ steps:
|
||||
- label: Core Test
|
||||
mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
commands:
|
||||
commands:
|
||||
- pytest -v -s core
|
||||
- pytest -v -s distributed/test_parallel_state.py
|
||||
|
||||
- label: Distributed Comm Ops Test
|
||||
#mirror_hardwares: [amd]
|
||||
@ -73,7 +72,7 @@ steps:
|
||||
commands:
|
||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
||||
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py
|
||||
|
||||
@ -82,20 +81,10 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
commands:
|
||||
- bash ../.buildkite/download-images.sh
|
||||
- VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=meta-llama/Llama-2-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- TEST_DIST_MODEL=llava-hf/llava-1.5-7b-hf DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TEST_DIST_MODEL=microsoft/Phi-3-vision-128k-instruct DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- TARGET_TEST_SUITE=L4 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pytest -v -s distributed/test_chunked_prefill_distributed.py
|
||||
- pytest -v -s distributed/test_multimodal_broadcast.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py
|
||||
@ -107,26 +96,17 @@ steps:
|
||||
fast_check: true
|
||||
commands:
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
# We want to test that models which use 2 GPUs work with 4 GPUs, which is why we duplicate them here.
|
||||
# See https://github.com/vllm-project/vllm/pull/5473#issuecomment-2166601837 for context.
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
|
||||
- label: Pipeline Parallelism Test
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
commands:
|
||||
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- TP_SIZE=2 PP_SIZE=2 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- TP_SIZE=1 PP_SIZE=3 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- PP_SIZE=4 EAGER_MODE=1 CHUNKED_PREFILL=0 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
|
||||
- pytest -v -s distributed/test_pipeline_parallel.py
|
||||
|
||||
- label: Engine Test
|
||||
mirror_hardwares: [amd]
|
||||
commands:
|
||||
commands:
|
||||
- pytest -v -s engine test_sequence.py test_config.py test_logger.py
|
||||
# OOM in the CI unless we run this separately
|
||||
- pytest -v -s tokenization
|
||||
@ -143,39 +123,37 @@ steps:
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
mirror_hardwares: [amd]
|
||||
commands:
|
||||
# install aws cli for llava_example.py
|
||||
# install tensorizer for tensorize_vllm_model.py
|
||||
- pip install awscli tensorizer
|
||||
- python3 offline_inference.py
|
||||
- python3 cpu_offload.py
|
||||
- python3 offline_inference_with_prefix.py
|
||||
- python3 llm_engine_example.py
|
||||
- python3 llava_example.py
|
||||
- python3 offline_inference_vision_language.py
|
||||
- python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
|
||||
- label: Inputs Test
|
||||
#mirror_hardwares: [amd]
|
||||
commands:
|
||||
- bash ../.buildkite/download-images.sh
|
||||
- pytest -v -s test_inputs.py
|
||||
- pytest -v -s multimodal
|
||||
|
||||
- label: Kernels Test %N
|
||||
#mirror_hardwares: [amd]
|
||||
commands:
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
- pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
parallelism: 4
|
||||
# - label: Kernels Test %N
|
||||
# #mirror_hardwares: [amd]
|
||||
# commands:
|
||||
# - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
# - pytest -v -s kernels --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT
|
||||
# parallelism: 4
|
||||
|
||||
- label: Models Test
|
||||
#mirror_hardwares: [amd]
|
||||
commands:
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
- pytest -v -s models -m \"not vlm\"
|
||||
|
||||
- label: Vision Language Models Test
|
||||
mirror_hardwares: [amd]
|
||||
commands:
|
||||
- bash ../.buildkite/download-images.sh
|
||||
- pytest -v -s models -m vlm
|
||||
|
||||
- label: Prefix Caching Test
|
||||
@ -207,25 +185,26 @@ steps:
|
||||
- export VLLM_ATTENTION_BACKEND=XFORMERS
|
||||
- pytest -v -s spec_decode
|
||||
|
||||
- label: LoRA Test %N
|
||||
#mirror_hardwares: [amd]
|
||||
command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
|
||||
parallelism: 4
|
||||
# - label: LoRA Test %N
|
||||
# #mirror_hardwares: [amd]
|
||||
# command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py
|
||||
# parallelism: 4
|
||||
|
||||
- label: LoRA Long Context (Distributed)
|
||||
#mirror_hardwares: [amd]
|
||||
num_gpus: 4
|
||||
# This test runs llama 13B, so it is required to run on 4 GPUs.
|
||||
commands:
|
||||
# FIXIT: find out which code initialize cuda before running the test
|
||||
# before the fix, we need to use spawn to test it
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s -x lora/test_long_context.py
|
||||
# - label: LoRA Long Context (Distributed)
|
||||
# #mirror_hardwares: [amd]
|
||||
# num_gpus: 4
|
||||
# # This test runs llama 13B, so it is required to run on 4 GPUs.
|
||||
# commands:
|
||||
# # FIXIT: find out which code initialize cuda before running the test
|
||||
# # before the fix, we need to use spawn to test it
|
||||
# - export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
# - pytest -v -s -x lora/test_long_context.py
|
||||
|
||||
- label: Tensorizer Test
|
||||
#mirror_hardwares: [amd]
|
||||
fast_check: true
|
||||
commands:
|
||||
- apt-get install curl libsodium23
|
||||
- apt-get install -y curl libsodium23
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s tensorizer_loader
|
||||
|
||||
@ -284,9 +263,6 @@ steps:
|
||||
# NOTE: don't test llama model here, it seems hf implementation is buggy
|
||||
# see https://github.com/vllm-project/vllm/pull/5689 for details
|
||||
- pytest -v -s distributed/test_custom_all_reduce.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
- TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py
|
||||
- pytest -v -s -x lora/test_mixtral.py
|
||||
|
6
.github/workflows/clang-format.yml
vendored
6
.github/workflows/clang-format.yml
vendored
@ -30,12 +30,6 @@ jobs:
|
||||
run: |
|
||||
EXCLUDES=(
|
||||
'csrc/moe/topk_softmax_kernels.cu'
|
||||
'csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu'
|
||||
'csrc/punica/bgmv/bgmv_config.h'
|
||||
'csrc/punica/bgmv/bgmv_impl.cuh'
|
||||
'csrc/punica/bgmv/vec_dtypes.cuh'
|
||||
'csrc/punica/punica_ops.cu'
|
||||
'csrc/punica/type_convert.h'
|
||||
)
|
||||
find csrc/ \( -name '*.h' -o -name '*.cpp' -o -name '*.cu' -o -name '*.cuh' \) -print \
|
||||
| grep -vFf <(printf "%s\n" "${EXCLUDES[@]}") \
|
||||
|
33
.github/workflows/mypy.yaml
vendored
33
.github/workflows/mypy.yaml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
@ -32,22 +32,17 @@ jobs:
|
||||
pip install types-setuptools
|
||||
- name: Mypy
|
||||
run: |
|
||||
mypy tests --config-file pyproject.toml
|
||||
mypy vllm/*.py --config-file pyproject.toml
|
||||
mypy vllm/attention --config-file pyproject.toml
|
||||
mypy vllm/core --config-file pyproject.toml
|
||||
mypy vllm/distributed --config-file pyproject.toml
|
||||
mypy vllm/engine --config-file pyproject.toml
|
||||
mypy vllm/entrypoints --config-file pyproject.toml
|
||||
mypy vllm/executor --config-file pyproject.toml
|
||||
mypy vllm/inputs --config-file pyproject.toml
|
||||
mypy vllm/logging --config-file pyproject.toml
|
||||
mypy vllm/lora --config-file pyproject.toml
|
||||
mypy vllm/model_executor --config-file pyproject.toml
|
||||
mypy vllm/multimodal --config-file pyproject.toml
|
||||
mypy vllm/platforms --config-file pyproject.toml
|
||||
mypy vllm/spec_decode --config-file pyproject.toml
|
||||
mypy vllm/transformers_utils --config-file pyproject.toml
|
||||
mypy vllm/usage --config-file pyproject.toml
|
||||
mypy vllm/worker --config-file pyproject.toml
|
||||
mypy
|
||||
mypy tests --follow-imports skip
|
||||
mypy vllm/attention --follow-imports skip
|
||||
mypy vllm/core --follow-imports skip
|
||||
mypy vllm/distributed --follow-imports skip
|
||||
mypy vllm/engine --follow-imports skip
|
||||
mypy vllm/entrypoints --follow-imports skip
|
||||
mypy vllm/executor --follow-imports skip
|
||||
mypy vllm/lora --follow-imports skip
|
||||
mypy vllm/model_executor --follow-imports skip
|
||||
mypy vllm/prompt_adapter --follow-imports skip
|
||||
mypy vllm/spec_decode --follow-imports skip
|
||||
mypy vllm/worker --follow-imports skip
|
||||
|
||||
|
4
.github/workflows/publish.yml
vendored
4
.github/workflows/publish.yml
vendored
@ -48,8 +48,8 @@ jobs:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
os: ['ubuntu-20.04']
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11']
|
||||
pytorch-version: ['2.3.1'] # Must be the most recent version that meets requirements-cuda.txt.
|
||||
python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
|
||||
pytorch-version: ['2.4.0'] # Must be the most recent version that meets requirements-cuda.txt.
|
||||
cuda-version: ['11.8', '12.1']
|
||||
|
||||
steps:
|
||||
|
2
.github/workflows/reminder_comment.yml
vendored
2
.github/workflows/reminder_comment.yml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only trigger `fastcheck` CI to run, which consists only a small and essential subset of tests to quickly catch errors with the flexibility to run extra individual tests on top (you can do this by unblocking test steps in the Buildkite run). \n\nFull CI run is still required to merge this PR so once the PR is ready to go, please make sure to run it. If you need all test signals in between PR commits, you can trigger full CI as well.\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
||||
body: '👋 Hi! Thank you for contributing to the vLLM project.\n Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run `fastcheck` CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your `fast-check` build on Buildkite UI. \n\nOnce the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).\n\n To run full CI, you can do one of these:\n- Comment `/ready` on the PR\n- Add `ready` label to the PR\n- Enable auto-merge.\n\n🚀'
|
||||
})
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
23
.github/workflows/remove_label_not_ready_comment.yml
vendored
Normal file
23
.github/workflows/remove_label_not_ready_comment.yml
vendored
Normal file
@ -0,0 +1,23 @@
|
||||
name: Remove ready Label on notready Comment
|
||||
|
||||
on:
|
||||
issue_comment:
|
||||
types: [created]
|
||||
|
||||
jobs:
|
||||
add-ready-label:
|
||||
runs-on: ubuntu-latest
|
||||
if: github.event.issue.pull_request && contains(github.event.comment.body, '/notready')
|
||||
steps:
|
||||
- name: Remove ready label
|
||||
uses: actions/github-script@v5
|
||||
with:
|
||||
script: |
|
||||
github.rest.issues.removeLabel({
|
||||
owner: context.repo.owner,
|
||||
repo: context.repo.repo,
|
||||
issue_number: context.issue.number,
|
||||
name: 'ready'
|
||||
})
|
||||
env:
|
||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
2
.github/workflows/ruff.yml
vendored
2
.github/workflows/ruff.yml
vendored
@ -15,7 +15,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
2
.github/workflows/scripts/build.sh
vendored
2
.github/workflows/scripts/build.sh
vendored
@ -13,8 +13,6 @@ $python_executable -m pip install -r requirements-cuda.txt
|
||||
|
||||
# Limit the number of parallel jobs to avoid OOM
|
||||
export MAX_JOBS=1
|
||||
# Make sure punica is built for the release (for LoRA)
|
||||
export VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
# Make sure release wheels are built for the following architectures
|
||||
export TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 8.9 9.0+PTX"
|
||||
# Build
|
||||
|
2
.github/workflows/yapf.yml
vendored
2
.github/workflows/yapf.yml
vendored
@ -14,7 +14,7 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11"]
|
||||
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
- name: Set up Python ${{ matrix.python-version }}
|
||||
|
@ -10,6 +10,7 @@ build:
|
||||
|
||||
sphinx:
|
||||
configuration: docs/source/conf.py
|
||||
fail_on_warning: true
|
||||
|
||||
# If using Sphinx, optionally build your docs in additional formats such as PDF
|
||||
formats:
|
||||
|
134
CMakeLists.txt
134
CMakeLists.txt
@ -14,7 +14,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
|
||||
# Supported python versions. These versions will be searched in order, the
|
||||
# first match will be selected. These should be kept in sync with setup.py.
|
||||
#
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")
|
||||
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11" "3.12")
|
||||
|
||||
# Supported NVIDIA architectures.
|
||||
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
|
||||
@ -32,8 +32,8 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.3.1")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.4.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.4.0")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.5.0")
|
||||
|
||||
#
|
||||
# Try to find python package with an executable that exactly matches
|
||||
@ -66,6 +66,39 @@ endif()
|
||||
#
|
||||
find_package(Torch REQUIRED)
|
||||
|
||||
#
|
||||
# Add the `default` target which detects which extensions should be
|
||||
# built based on platform/architecture. This is the same logic that
|
||||
# setup.py uses to select which extensions should be built and should
|
||||
# be kept in sync.
|
||||
#
|
||||
# The `default` target makes direct use of cmake easier since knowledge
|
||||
# of which extensions are supported has been factored in, e.g.
|
||||
#
|
||||
# mkdir build && cd build
|
||||
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
|
||||
# cmake --build . --target default
|
||||
#
|
||||
add_custom_target(default)
|
||||
message(STATUS "Enabling core extension.")
|
||||
|
||||
# Define _core_C extension
|
||||
# built for (almost) every target platform, (excludes TPU and Neuron)
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/core/torch_bindings.cpp")
|
||||
|
||||
define_gpu_extension_target(
|
||||
_core_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
add_dependencies(default _core_C)
|
||||
|
||||
#
|
||||
# Forward the non-CUDA device extensions to external CMake scripts.
|
||||
#
|
||||
@ -74,7 +107,7 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda" AND
|
||||
if (VLLM_TARGET_DEVICE STREQUAL "cpu")
|
||||
include(${CMAKE_CURRENT_LIST_DIR}/cmake/cpu_extension.cmake)
|
||||
else()
|
||||
message(FATAL_ERROR "Unsupported vLLM target device: ${VLLM_TARGET_DEVICE}")
|
||||
return()
|
||||
endif()
|
||||
return()
|
||||
endif()
|
||||
@ -101,7 +134,7 @@ elseif(HIP_FOUND)
|
||||
# ROCm 5.X and 6.X
|
||||
if (ROCM_VERSION_DEV_MAJOR GREATER_EQUAL 5 AND
|
||||
NOT Torch_VERSION VERSION_EQUAL ${TORCH_SUPPORTED_VERSION_ROCM})
|
||||
message(WARNING "Pytorch version ${TORCH_SUPPORTED_VERSION_ROCM} "
|
||||
message(WARNING "Pytorch version >= ${TORCH_SUPPORTED_VERSION_ROCM} "
|
||||
"expected for ROCm build, saw ${Torch_VERSION} instead.")
|
||||
endif()
|
||||
else()
|
||||
@ -132,7 +165,7 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
# Define other extension targets
|
||||
#
|
||||
|
||||
#
|
||||
@ -151,16 +184,18 @@ set(VLLM_EXT_SRC
|
||||
"csrc/quantization/fp8/common.cu"
|
||||
"csrc/cuda_utils_kernels.cu"
|
||||
"csrc/moe_align_block_size_kernels.cu"
|
||||
"csrc/prepare_inputs/advance_step.cu"
|
||||
"csrc/torch_bindings.cpp")
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
include(FetchContent)
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY=ON)
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
FetchContent_Declare(
|
||||
cutlass
|
||||
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
|
||||
# CUTLASS 3.5.0
|
||||
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
|
||||
# CUTLASS 3.5.1
|
||||
GIT_TAG 06b21349bcf6ddf6a1686a47a137ad1446579db9
|
||||
GIT_PROGRESS TRUE
|
||||
)
|
||||
FetchContent_MakeAvailable(cutlass)
|
||||
|
||||
@ -169,8 +204,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/awq/gemm_kernels.cu"
|
||||
"csrc/quantization/marlin/dense/marlin_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
|
||||
"csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
|
||||
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
|
||||
"csrc/quantization/gptq_marlin/awq_marlin_repack.cu"
|
||||
"csrc/quantization/fp8/fp8_marlin.cu"
|
||||
"csrc/custom_all_reduce.cu"
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu"
|
||||
@ -198,7 +235,7 @@ define_gpu_extension_target(
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_GPU_ARCHES}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR};${CUTLASS_TOOLS_UTIL_INCLUDE_DIR}
|
||||
INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
@ -220,76 +257,7 @@ define_gpu_extension_target(
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
|
||||
#
|
||||
# _punica_C extension
|
||||
#
|
||||
|
||||
set(VLLM_PUNICA_EXT_SRC
|
||||
"csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu"
|
||||
"csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu"
|
||||
"csrc/punica/punica_ops.cu"
|
||||
"csrc/punica/torch_bindings.cpp")
|
||||
|
||||
#
|
||||
# Copy GPU compilation flags+update for punica
|
||||
#
|
||||
set(VLLM_PUNICA_GPU_FLAGS ${VLLM_GPU_FLAGS})
|
||||
list(REMOVE_ITEM VLLM_PUNICA_GPU_FLAGS
|
||||
"-D__CUDA_NO_HALF_OPERATORS__"
|
||||
"-D__CUDA_NO_HALF_CONVERSIONS__"
|
||||
"-D__CUDA_NO_BFLOAT16_CONVERSIONS__"
|
||||
"-D__CUDA_NO_HALF2_OPERATORS__")
|
||||
|
||||
#
|
||||
# Filter out CUDA architectures < 8.0 for punica.
|
||||
#
|
||||
if (${VLLM_GPU_LANG} STREQUAL "CUDA")
|
||||
set(VLLM_PUNICA_GPU_ARCHES)
|
||||
foreach(ARCH ${VLLM_GPU_ARCHES})
|
||||
string_to_ver(CODE_VER ${ARCH})
|
||||
if (CODE_VER GREATER_EQUAL 8.0)
|
||||
list(APPEND VLLM_PUNICA_GPU_ARCHES ${ARCH})
|
||||
endif()
|
||||
endforeach()
|
||||
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
||||
elseif(${VLLM_GPU_LANG} STREQUAL "HIP")
|
||||
set(VLLM_PUNICA_GPU_ARCHES ${VLLM_GPU_ARCHES})
|
||||
message(STATUS "Punica target arches: ${VLLM_PUNICA_GPU_ARCHES}")
|
||||
endif()
|
||||
|
||||
if (VLLM_PUNICA_GPU_ARCHES)
|
||||
define_gpu_extension_target(
|
||||
_punica_C
|
||||
DESTINATION vllm
|
||||
LANGUAGE ${VLLM_GPU_LANG}
|
||||
SOURCES ${VLLM_PUNICA_EXT_SRC}
|
||||
COMPILE_FLAGS ${VLLM_PUNICA_GPU_FLAGS}
|
||||
ARCHITECTURES ${VLLM_PUNICA_GPU_ARCHES}
|
||||
USE_SABI 3
|
||||
WITH_SOABI)
|
||||
else()
|
||||
message(WARNING "Unable to create _punica_C target because none of the "
|
||||
"requested architectures (${VLLM_GPU_ARCHES}) are supported, i.e. >= 8.0")
|
||||
endif()
|
||||
|
||||
#
|
||||
# Add the `default` target which detects which extensions should be
|
||||
# built based on platform/architecture. This is the same logic that
|
||||
# setup.py uses to select which extensions should be built and should
|
||||
# be kept in sync.
|
||||
#
|
||||
# The `default` target makes direct use of cmake easier since knowledge
|
||||
# of which extensions are supported has been factored in, e.g.
|
||||
#
|
||||
# mkdir build && cd build
|
||||
# cmake -G Ninja -DVLLM_PYTHON_EXECUTABLE=`which python3` -DCMAKE_LIBRARY_OUTPUT_DIRECTORY=../vllm ..
|
||||
# cmake --build . --target default
|
||||
#
|
||||
add_custom_target(default)
|
||||
|
||||
if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
message(STATUS "Enabling C extension.")
|
||||
@ -298,12 +266,4 @@ if(VLLM_GPU_LANG STREQUAL "CUDA" OR VLLM_GPU_LANG STREQUAL "HIP")
|
||||
message(STATUS "Enabling moe extension.")
|
||||
add_dependencies(default _moe_C)
|
||||
|
||||
# Enable punica if -DVLLM_INSTALL_PUNICA_KERNELS=ON or
|
||||
# VLLM_INSTALL_PUNICA_KERNELS is set in the environment and
|
||||
# there are supported target arches.
|
||||
if (VLLM_PUNICA_GPU_ARCHES AND
|
||||
(ENV{VLLM_INSTALL_PUNICA_KERNELS} OR VLLM_INSTALL_PUNICA_KERNELS))
|
||||
message(STATUS "Enabling punica extension.")
|
||||
add_dependencies(default _punica_C)
|
||||
endif()
|
||||
endif()
|
||||
|
48
Dockerfile
48
Dockerfile
@ -8,10 +8,10 @@
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS base
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG PYTHON_VERSION=3
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
@ -21,13 +21,16 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& apt-get install -y ccache software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv python3-pip \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
|
||||
&& python3 --version \
|
||||
&& python3 -m pip --version
|
||||
&& python3 --version
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip git curl sudo
|
||||
&& apt-get install -y git curl sudo
|
||||
|
||||
# Install pip s.t. it will be compatible with our PYTHON_VERSION
|
||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
|
||||
RUN python3 -m pip --version
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
@ -39,6 +42,7 @@ WORKDIR /workspace
|
||||
|
||||
# install build and runtime dependencies
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-adag.txt requirements-adag.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install -r requirements-cuda.txt
|
||||
@ -58,7 +62,7 @@ ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
#################### WHEEL BUILD IMAGE ####################
|
||||
FROM base AS build
|
||||
|
||||
ARG PYTHON_VERSION=3
|
||||
ARG PYTHON_VERSION=3.10
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements-build.txt requirements-build.txt
|
||||
@ -75,6 +79,7 @@ COPY setup.py setup.py
|
||||
COPY cmake cmake
|
||||
COPY CMakeLists.txt CMakeLists.txt
|
||||
COPY requirements-common.txt requirements-common.txt
|
||||
COPY requirements-adag.txt requirements-adag.txt
|
||||
COPY requirements-cuda.txt requirements-cuda.txt
|
||||
COPY pyproject.toml pyproject.toml
|
||||
COPY vllm vllm
|
||||
@ -85,8 +90,6 @@ ENV MAX_JOBS=${max_jobs}
|
||||
# number of threads used by nvcc
|
||||
ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
# make sure punica kernels are built (for LoRA)
|
||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
|
||||
ARG buildkite_commit
|
||||
ENV BUILDKITE_COMMIT=${buildkite_commit}
|
||||
@ -100,7 +103,11 @@ RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& export SCCACHE_BUCKET=vllm-build-sccache \
|
||||
&& if [ "$CUDA_VERSION" = "11.8.0" ]; then \
|
||||
export SCCACHE_BUCKET=vllm-build-sccache-2; \
|
||||
else \
|
||||
export SCCACHE_BUCKET=vllm-build-sccache; \
|
||||
fi \
|
||||
&& export SCCACHE_REGION=us-west-2 \
|
||||
&& export CMAKE_BUILD_TYPE=Release \
|
||||
&& sccache --show-stats \
|
||||
@ -149,12 +156,27 @@ RUN pip --verbose wheel -r requirements-mamba.txt \
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu22.04 AS vllm-base
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu20.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.4.1
|
||||
ARG PYTHON_VERSION=3.10
|
||||
WORKDIR /vllm-workspace
|
||||
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common \
|
||||
&& add-apt-repository ppa:deadsnakes/ppa \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& if [ "${PYTHON_VERSION}" != "3" ]; then update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1; fi \
|
||||
&& python3 --version
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y python3-pip git vim
|
||||
&& apt-get install -y python3-pip git vim curl libibverbs-dev
|
||||
|
||||
# Install pip s.t. it will be compatible with our PYTHON_VERSION
|
||||
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION}
|
||||
RUN python3 -m pip --version
|
||||
|
||||
# Workaround for https://github.com/openai/triton/issues/2507 and
|
||||
# https://github.com/pytorch/pytorch/issues/107960 -- hopefully
|
||||
@ -172,7 +194,7 @@ RUN --mount=type=bind,from=mamba-builder,src=/usr/src/mamba,target=/usr/src/mamb
|
||||
python3 -m pip install /usr/src/mamba/*.whl --no-cache-dir
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.9/flashinfer-0.0.9+cu121torch2.3-cp310-cp310-linux_x86_64.whl
|
||||
python3 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.4-cp310-cp310-linux_x86_64.whl
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
|
||||
|
@ -2,8 +2,8 @@
|
||||
|
||||
FROM ubuntu:22.04 AS cpu-test-1
|
||||
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 \
|
||||
RUN apt-get update -y \
|
||||
&& apt-get install -y curl git wget vim numactl gcc-12 g++-12 python3 python3-pip libtcmalloc-minimal4 libnuma-dev \
|
||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12
|
||||
|
||||
# https://intel.github.io/intel-extension-for-pytorch/cpu/latest/tutorials/performance_tuning/tuning_guide.html
|
||||
@ -13,8 +13,9 @@ RUN pip install intel-openmp
|
||||
|
||||
ENV LD_PRELOAD="/usr/lib/x86_64-linux-gnu/libtcmalloc_minimal.so.4:/usr/local/lib/libiomp5.so:$LD_PRELOAD"
|
||||
|
||||
RUN echo 'ulimit -c 0' >> ~/.bashrc
|
||||
|
||||
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.3.100%2Bgit0eb3473-cp310-cp310-linux_x86_64.whl
|
||||
RUN pip install https://intel-extension-for-pytorch.s3.amazonaws.com/ipex_dev/cpu/intel_extension_for_pytorch-2.4.0%2Bgitfbaa4bc-cp310-cp310-linux_x86_64.whl
|
||||
|
||||
RUN pip install --upgrade pip \
|
||||
&& pip install wheel packaging ninja "setuptools>=49.4.0" numpy
|
||||
|
@ -13,12 +13,15 @@ COPY requirements-common.txt /workspace/vllm/
|
||||
COPY requirements-openvino.txt /workspace/vllm/
|
||||
|
||||
COPY vllm/ /workspace/vllm/vllm
|
||||
COPY csrc/core /workspace/vllm/csrc/core
|
||||
COPY cmake/utils.cmake /workspace/vllm/cmake/
|
||||
COPY CMakeLists.txt /workspace/vllm/
|
||||
COPY setup.py /workspace/vllm/
|
||||
|
||||
# install build requirements
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" python3 -m pip install -r /workspace/vllm/requirements-build.txt
|
||||
# build vLLM with OpenVINO backend
|
||||
RUN PIP_PRE=1 PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/nightly/" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
|
||||
RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu https://storage.openvinotoolkit.org/simple/wheels/pre-release" VLLM_TARGET_DEVICE="openvino" python3 -m pip install /workspace/vllm/
|
||||
|
||||
COPY examples/ /workspace/vllm/examples
|
||||
COPY benchmarks/ /workspace/vllm/benchmarks
|
||||
|
101
Dockerfile.rocm
101
Dockerfile.rocm
@ -1,26 +1,24 @@
|
||||
# Default ROCm 6.1 base image
|
||||
ARG BASE_IMAGE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
|
||||
|
||||
# Tested and supported base rocm/pytorch images
|
||||
ARG ROCm_5_7_BASE="rocm/pytorch:rocm5.7_ubuntu20.04_py3.9_pytorch_2.0.1" \
|
||||
ROCm_6_0_BASE="rocm/pytorch:rocm6.0_ubuntu20.04_py3.9_pytorch_2.1.1" \
|
||||
ROCM_6_1_BASE="rocm/pytorch:rocm6.1.2_ubuntu20.04_py3.9_pytorch_staging"
|
||||
|
||||
# Default ROCm ARCHes to build vLLM for.
|
||||
ARG PYTORCH_ROCM_ARCH="gfx908;gfx90a;gfx942;gfx1100"
|
||||
|
||||
# Whether to build CK-based flash-attention
|
||||
# If 0, will not build flash attention
|
||||
# This is useful for gfx target where flash-attention is not supported
|
||||
# (i.e. those that do not appear in `FA_GFX_ARCHS`)
|
||||
# Triton FA is used by default on ROCm now so this is unnecessary.
|
||||
# Whether to install CK-based flash-attention
|
||||
# If 0, will not install flash-attention
|
||||
ARG BUILD_FA="1"
|
||||
# If `TRY_FA_WHEEL=1`, we will try installing flash-attention from `FA_WHEEL_URL`
|
||||
# If this succeeds, we use the downloaded wheel and skip building flash-attention.
|
||||
# Otherwise, ROCm flash-attention from `FA_BRANCH` will be built for the
|
||||
# architectures specified in `FA_GFX_ARCHS`
|
||||
ARG TRY_FA_WHEEL="1"
|
||||
ARG FA_WHEEL_URL="https://github.com/ROCm/flash-attention/releases/download/v2.5.9post1-cktile-vllm/flash_attn-2.5.9.post1-cp39-cp39-linux_x86_64.whl"
|
||||
ARG FA_GFX_ARCHS="gfx90a;gfx942"
|
||||
ARG FA_BRANCH="ae7928c"
|
||||
ARG FA_BRANCH="23a2b1c2"
|
||||
|
||||
# Whether to build triton on rocm
|
||||
ARG BUILD_TRITON="1"
|
||||
ARG TRITON_BRANCH="0ef1848"
|
||||
ARG TRITON_BRANCH="e0fc12c"
|
||||
|
||||
### Base image build stage
|
||||
FROM $BASE_IMAGE AS base
|
||||
@ -48,29 +46,17 @@ RUN apt-get update && apt-get install -y \
|
||||
ARG APP_MOUNT=/vllm-workspace
|
||||
WORKDIR ${APP_MOUNT}
|
||||
|
||||
RUN pip install --upgrade pip
|
||||
RUN python3 -m pip install --upgrade pip
|
||||
# Remove sccache so it doesn't interfere with ccache
|
||||
# TODO: implement sccache support across components
|
||||
RUN apt-get purge -y sccache; pip uninstall -y sccache; rm -f "$(which sccache)"
|
||||
RUN apt-get purge -y sccache; python3 -m pip uninstall -y sccache; rm -f "$(which sccache)"
|
||||
# Install torch == 2.5.0 on ROCm
|
||||
RUN case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
|
||||
*"rocm-5.7"*) \
|
||||
pip uninstall -y torch torchaudio torchvision \
|
||||
&& pip install --no-cache-dir --pre \
|
||||
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
|
||||
torchvision==0.20.0.dev20240710 \
|
||||
--index-url https://download.pytorch.org/whl/nightly/rocm5.7;; \
|
||||
*"rocm-6.0"*) \
|
||||
pip uninstall -y torch torchaudio torchvision \
|
||||
&& pip install --no-cache-dir --pre \
|
||||
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
|
||||
torchvision==0.20.0.dev20240710 \
|
||||
--index-url https://download.pytorch.org/whl/nightly/rocm6.0;; \
|
||||
*"rocm-6.1"*) \
|
||||
pip uninstall -y torch torchaudio torchvision \
|
||||
&& pip install --no-cache-dir --pre \
|
||||
torch==2.5.0.dev20240710 torchaudio==2.4.0.dev20240710 \
|
||||
torchvision==0.20.0.dev20240710 \
|
||||
python3 -m pip uninstall -y torch torchvision \
|
||||
&& python3 -m pip install --no-cache-dir --pre \
|
||||
torch==2.5.0.dev20240726 \
|
||||
torchvision==0.20.0.dev20240726 \
|
||||
--index-url https://download.pytorch.org/whl/nightly/rocm6.1;; \
|
||||
*) ;; esac
|
||||
|
||||
@ -87,29 +73,31 @@ ENV CCACHE_DIR=/root/.cache/ccache
|
||||
FROM base AS build_amdsmi
|
||||
# Build amdsmi wheel always
|
||||
RUN cd /opt/rocm/share/amd_smi \
|
||||
&& pip wheel . --wheel-dir=/install
|
||||
&& python3 -m pip wheel . --wheel-dir=/install
|
||||
|
||||
|
||||
### Flash-Attention wheel build stage
|
||||
FROM base AS build_fa
|
||||
ARG BUILD_FA
|
||||
ARG TRY_FA_WHEEL
|
||||
ARG FA_WHEEL_URL
|
||||
ARG FA_GFX_ARCHS
|
||||
ARG FA_BRANCH
|
||||
# Build ROCm flash-attention wheel if `BUILD_FA = 1`
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
if [ "$BUILD_FA" = "1" ]; then \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout "${FA_BRANCH}" \
|
||||
&& git submodule update --init \
|
||||
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
|
||||
*"rocm-5.7"*) \
|
||||
export VLLM_TORCH_PATH="$(python3 -c 'import torch; print(torch.__path__[0])')" \
|
||||
&& patch "${VLLM_TORCH_PATH}"/utils/hipify/hipify_python.py hipify_patch.patch;; \
|
||||
*) ;; esac \
|
||||
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
if [ "${TRY_FA_WHEEL}" = "1" ] && python3 -m pip install "${FA_WHEEL_URL}"; then \
|
||||
# If a suitable wheel exists, we download it instead of building FA
|
||||
mkdir -p /install && wget -N "${FA_WHEEL_URL}" -P /install; \
|
||||
else \
|
||||
mkdir -p libs \
|
||||
&& cd libs \
|
||||
&& git clone https://github.com/ROCm/flash-attention.git \
|
||||
&& cd flash-attention \
|
||||
&& git checkout "${FA_BRANCH}" \
|
||||
&& git submodule update --init \
|
||||
&& GPU_ARCHS="${FA_GFX_ARCHS}" python3 setup.py bdist_wheel --dist-dir=/install; \
|
||||
fi; \
|
||||
# Create an empty directory otherwise as later build stages expect one
|
||||
else mkdir -p /install; \
|
||||
fi
|
||||
@ -139,19 +127,11 @@ FROM base AS final
|
||||
# Import the vLLM development directory from the build context
|
||||
COPY . .
|
||||
|
||||
# Error related to odd state for numpy 1.20.3 where there is no METADATA etc, but an extra LICENSES_bundled.txt.
|
||||
# Manually remove it so that later steps of numpy upgrade can continue
|
||||
RUN case "$(which python3)" in \
|
||||
*"/opt/conda/envs/py_3.9"*) \
|
||||
rm -rf /opt/conda/envs/py_3.9/lib/python3.9/site-packages/numpy-1.20.3.dist-info/;; \
|
||||
*) ;; esac
|
||||
|
||||
# Package upgrades for useful functionality or to avoid dependency issues
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install --upgrade numba scipy huggingface-hub[cli]
|
||||
python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
|
||||
|
||||
|
||||
# Make sure punica kernels are built (for LoRA)
|
||||
ENV VLLM_INSTALL_PUNICA_KERNELS=1
|
||||
# Workaround for ray >= 2.10.0
|
||||
ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
|
||||
# Silences the HF Tokenizers warning
|
||||
@ -159,14 +139,11 @@ ENV TOKENIZERS_PARALLELISM=false
|
||||
|
||||
RUN --mount=type=cache,target=${CCACHE_DIR} \
|
||||
--mount=type=cache,target=/root/.cache/pip \
|
||||
pip install -U -r requirements-rocm.txt \
|
||||
python3 -m pip install -Ur requirements-rocm.txt \
|
||||
&& case "$(ls /opt | grep -Po 'rocm-[0-9]\.[0-9]')" in \
|
||||
*"rocm-6.0"*) \
|
||||
patch /opt/rocm/include/hip/amd_detail/amd_hip_bf16.h rocm_patch/rocm_bf16.patch;; \
|
||||
*"rocm-6.1"*) \
|
||||
# Bring in upgrades to HIP graph earlier than ROCm 6.2 for vLLM
|
||||
wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P rocm_patch \
|
||||
&& cp rocm_patch/libamdhip64.so.6 /opt/rocm/lib/libamdhip64.so.6 \
|
||||
wget -N https://github.com/ROCm/vllm/raw/fa78403/rocm_patch/libamdhip64.so.6 -P /opt/rocm/lib \
|
||||
# Prevent interference if torch bundles its own HIP runtime
|
||||
&& rm -f "$(python3 -c 'import torch; print(torch.__path__[0])')"/lib/libamdhip64.so* || true;; \
|
||||
*) ;; esac \
|
||||
@ -178,7 +155,7 @@ RUN --mount=type=bind,from=build_amdsmi,src=/install,target=/install \
|
||||
mkdir -p libs \
|
||||
&& cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& pip uninstall -y amdsmi;
|
||||
&& python3 -m pip uninstall -y amdsmi;
|
||||
|
||||
# Copy triton wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
|
||||
@ -186,7 +163,7 @@ RUN --mount=type=bind,from=build_triton,src=/install,target=/install \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& pip uninstall -y triton; fi
|
||||
&& python3 -m pip uninstall -y triton; fi
|
||||
|
||||
# Copy flash-attn wheel(s) into final image if they were built
|
||||
RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
|
||||
@ -194,11 +171,11 @@ RUN --mount=type=bind,from=build_fa,src=/install,target=/install \
|
||||
&& if ls /install/*.whl; then \
|
||||
cp /install/*.whl libs \
|
||||
# Preemptively uninstall to avoid same-version no-installs
|
||||
&& pip uninstall -y flash-attn; fi
|
||||
&& python3 -m pip uninstall -y flash-attn; fi
|
||||
|
||||
# Install wheels that were built to the final image
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
if ls libs/*.whl; then \
|
||||
pip install libs/*.whl; fi
|
||||
python3 -m pip install libs/*.whl; fi
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
ARG NIGHTLY_DATE="20240601"
|
||||
ARG NIGHTLY_DATE="20240726"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
@ -6,18 +6,18 @@ WORKDIR /workspace
|
||||
|
||||
# Install aiohttp separately to avoid build errors.
|
||||
RUN pip install aiohttp
|
||||
# Install NumPy 1 instead of NumPy 2.
|
||||
RUN pip install "numpy<2"
|
||||
# Install the TPU and Pallas dependencies.
|
||||
RUN pip install torch_xla[tpu] -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
RUN pip install torch_xla[pallas] -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
|
||||
|
||||
# Fix FastAPI dependence
|
||||
RUN pip install "starlette<0.38.0"
|
||||
|
||||
# Build vLLM.
|
||||
COPY . /workspace/vllm
|
||||
ENV VLLM_TARGET_DEVICE="tpu"
|
||||
RUN cd /workspace/vllm && python setup.py develop
|
||||
|
||||
# Re-install outlines to avoid dependency errors.
|
||||
# The outlines version must follow requirements-common.txt.
|
||||
RUN pip uninstall outlines -y
|
||||
RUN pip install "outlines>=0.0.43"
|
||||
|
||||
CMD ["/bin/bash"]
|
||||
|
@ -1,4 +1,4 @@
|
||||
FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu22.04
|
||||
FROM intel/oneapi-basekit:2024.1.0-devel-ubuntu20.04
|
||||
|
||||
RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/intel-oneapi-archive-keyring.gpg > /dev/null && \
|
||||
echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \
|
||||
|
@ -1,4 +1,5 @@
|
||||
include LICENSE
|
||||
include requirements-adag.txt
|
||||
include requirements-common.txt
|
||||
include requirements-cuda.txt
|
||||
include requirements-rocm.txt
|
||||
|
@ -17,6 +17,8 @@ Easy, fast, and cheap LLM serving for everyone
|
||||
---
|
||||
|
||||
*Latest News* 🔥
|
||||
- [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing).
|
||||
- [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html).
|
||||
- [2024/06] We hosted [the fourth vLLM meetup](https://lu.ma/agivllm) with Cloudflare and BentoML! Please find the meetup slides [here](https://docs.google.com/presentation/d/1iJ8o7V2bQEi0BFEljLTwc5G1S10_Rhv3beed5oB0NJ4/edit?usp=sharing).
|
||||
- [2024/04] We hosted [the third vLLM meetup](https://robloxandvllmmeetup2024.splashthat.com/) with Roblox! Please find the meetup slides [here](https://docs.google.com/presentation/d/1A--47JAK4BJ39t954HyTkvtfwn0fkqtsL8NGFuslReM/edit?usp=sharing).
|
||||
- [2024/01] We hosted [the second vLLM meetup](https://lu.ma/ygxbpzhl) with IBM! Please find the meetup slides [here](https://docs.google.com/presentation/d/12mI2sKABnUw5RBWXDYY-HtHth4iMSNcEoQ10jDQbxgA/edit?usp=sharing).
|
||||
@ -37,7 +39,7 @@ vLLM is fast with:
|
||||
- Quantization: [GPTQ](https://arxiv.org/abs/2210.17323), [AWQ](https://arxiv.org/abs/2306.00978), [SqueezeLLM](https://arxiv.org/abs/2306.07629), FP8 KV Cache
|
||||
- Optimized CUDA kernels
|
||||
|
||||
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/3924) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
|
||||
**Performance benchmark**: We include a [performance benchmark](https://buildkite.com/vllm/performance-benchmark/builds/4068) that compares the performance of vllm against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [text-generation-inference](https://github.com/huggingface/text-generation-inference) and [lmdeploy](https://github.com/InternLM/lmdeploy)).
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
@ -90,6 +92,7 @@ vLLM is a community project. Our compute resources for development and testing a
|
||||
- Databricks
|
||||
- DeepInfra
|
||||
- Dropbox
|
||||
- Google Cloud
|
||||
- Lambda Lab
|
||||
- NVIDIA
|
||||
- Replicate
|
||||
|
@ -11,7 +11,7 @@ from tqdm import tqdm
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.engine.arg_utils import EngineArgs
|
||||
from vllm.inputs import PromptStrictInputs
|
||||
from vllm.inputs import PromptInputs
|
||||
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
@ -61,7 +61,7 @@ def main(args: argparse.Namespace):
|
||||
dummy_prompt_token_ids = np.random.randint(10000,
|
||||
size=(args.batch_size,
|
||||
args.input_len))
|
||||
dummy_inputs: List[PromptStrictInputs] = [{
|
||||
dummy_inputs: List[PromptInputs] = [{
|
||||
"prompt_token_ids": batch
|
||||
} for batch in dummy_prompt_token_ids.tolist()]
|
||||
|
||||
|
@ -13,25 +13,25 @@ from weight_shapes import WEIGHT_SHAPES
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())[1:]
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
DEFAULT_TP_SIZES = [1]
|
||||
|
||||
# helpers
|
||||
|
||||
|
||||
def to_fp8(tensor: torch.tensor) -> torch.tensor:
|
||||
def to_fp8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
finfo = torch.finfo(torch.float8_e4m3fn)
|
||||
return torch.round(tensor.clamp(
|
||||
min=finfo.min, max=finfo.max)).to(dtype=torch.float8_e4m3fn)
|
||||
|
||||
|
||||
def to_int8(tensor: torch.tensor) -> torch.tensor:
|
||||
def to_int8(tensor: torch.Tensor) -> torch.Tensor:
|
||||
return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)
|
||||
|
||||
|
||||
def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
k: int) -> Tuple[torch.tensor, torch.tensor]:
|
||||
k: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
|
||||
a = torch.randn((m, k), device='cuda') * 5
|
||||
b = torch.randn((n, k), device='cuda').t() * 5
|
||||
@ -47,15 +47,15 @@ def make_rand_tensors(dtype: torch.dtype, m: int, n: int,
|
||||
# impl
|
||||
|
||||
|
||||
def pytorch_mm_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
def pytorch_mm_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch.mm(a, b)
|
||||
|
||||
|
||||
def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
def pytorch_fp8_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
@ -63,9 +63,9 @@ def pytorch_fp8_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
out_dtype=out_dtype)
|
||||
|
||||
|
||||
def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
|
||||
scale_a: torch.tensor, scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
def pytorch_fp8_impl_fast_accum(a: torch.Tensor, b: torch.Tensor,
|
||||
scale_a: torch.Tensor, scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return torch._scaled_mm(a,
|
||||
b,
|
||||
scale_a=scale_a,
|
||||
@ -74,15 +74,15 @@ def pytorch_fp8_impl_fast_accum(a: torch.tensor, b: torch.tensor,
|
||||
use_fast_accum=True)
|
||||
|
||||
|
||||
def cutlass_impl(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor,
|
||||
out_dtype: torch.dtype) -> torch.tensor:
|
||||
def cutlass_impl(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor,
|
||||
out_dtype: torch.dtype) -> torch.Tensor:
|
||||
return ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype=out_dtype)
|
||||
|
||||
|
||||
# bench
|
||||
def bench_fn(a: torch.tensor, b: torch.tensor, scale_a: torch.tensor,
|
||||
scale_b: torch.tensor, out_dtype: torch.dtype, label: str,
|
||||
def bench_fn(a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor,
|
||||
scale_b: torch.Tensor, out_dtype: torch.dtype, label: str,
|
||||
sub_label: str, fn: Callable, description: str) -> TMeasurement:
|
||||
|
||||
min_run_time = 1
|
||||
@ -112,13 +112,20 @@ def bench_int8(dtype: torch.dtype, m: int, k: int, n: int, label: str,
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
|
||||
timers = []
|
||||
# pytorch impl
|
||||
# pytorch impl - bfloat16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.bfloat16, device="cuda"),
|
||||
b.to(dtype=torch.bfloat16, device="cuda"), scale_a, scale_b,
|
||||
torch.bfloat16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_bf16_bf16_bf16_matmul-no-scales"))
|
||||
|
||||
# pytorch impl - float16
|
||||
timers.append(
|
||||
bench_fn(a.to(dtype=torch.float16, device="cuda"),
|
||||
b.to(dtype=torch.float16, device="cuda"), scale_a, scale_b,
|
||||
torch.float16, label, sub_label, pytorch_mm_impl,
|
||||
"pytorch_fp16_fp16_fp16_matmul-no-scales"))
|
||||
|
||||
# cutlass impl
|
||||
timers.append(
|
||||
bench_fn(a, b, scale_a, scale_b, torch.bfloat16, label, sub_label,
|
||||
|
@ -7,16 +7,17 @@ from benchmark_shapes import WEIGHT_SHAPES
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_NUM_BITS)
|
||||
GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES, GPTQ_MARLIN_SUPPORTED_NUM_BITS)
|
||||
MARLIN_SUPPORTED_GROUP_SIZES, query_marlin_supported_quant_types)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace, marlin_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
marlin_24_quantize)
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
gptq_pack, quantize_weights, sort_weights)
|
||||
gptq_pack, gptq_quantize_weights, sort_weights)
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
@ -27,13 +28,14 @@ K_FULL_OPTS = [False, True]
|
||||
|
||||
|
||||
def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
act_order: bool, is_k_full: bool, num_bits: int, group_size: int,
|
||||
size_m: int, size_k: int, size_n: int):
|
||||
act_order: bool, is_k_full: bool, quant_type: ScalarType,
|
||||
group_size: int, size_m: int, size_k: int, size_n: int):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = ("{}, act={} k_full={}, b={}, g={}, "
|
||||
"MKN=({}x{}x{})".format(model, act_order, is_k_full, num_bits,
|
||||
group_size, size_m, size_k, size_n))
|
||||
sub_label = ("{}, act={} k_full={}, q={}, g={}, "
|
||||
"MKN=({}x{}x{})".format(model, act_order, is_k_full,
|
||||
str(quant_type), group_size, size_m,
|
||||
size_k, size_n))
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
@ -50,16 +52,18 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
marlin_rand_perm,
|
||||
) = marlin_quantize(b, num_bits, group_size, act_order)
|
||||
) = marlin_quantize(b, quant_type, group_size, act_order)
|
||||
|
||||
# Marlin_24 quant
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta,
|
||||
marlin_24_s) = marlin_24_quantize(b, num_bits, group_size)
|
||||
marlin_24_s) = marlin_24_quantize(b, quant_type, group_size)
|
||||
|
||||
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
|
||||
# GPTQ quant
|
||||
(w_ref, q_w, s, g_idx,
|
||||
rand_perm) = quantize_weights(b, num_bits, group_size, act_order)
|
||||
q_w_gptq = gptq_pack(q_w, num_bits, size_k, size_n)
|
||||
rand_perm) = gptq_quantize_weights(b, quant_type, group_size, act_order)
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
# so that group ids are increasing
|
||||
@ -73,10 +77,11 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
|
||||
marlin_24_workspace = MarlinWorkspace(size_n, GPTQ_MARLIN_24_MIN_THREAD_N,
|
||||
GPTQ_MARLIN_24_MAX_PARALLEL)
|
||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||
|
||||
globals = {
|
||||
# Gen params
|
||||
"num_bits": num_bits,
|
||||
"quant_type": quant_type,
|
||||
"group_size": group_size,
|
||||
"size_m": size_m,
|
||||
"size_n": size_n,
|
||||
@ -87,6 +92,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
"marlin_w_ref": marlin_w_ref,
|
||||
"marlin_q_w": marlin_q_w,
|
||||
"marlin_s": marlin_s,
|
||||
"marlin_zp": marlin_zp,
|
||||
"marlin_g_idx": marlin_g_idx,
|
||||
"marlin_sort_indices": marlin_sort_indices,
|
||||
"marlin_rand_perm": marlin_rand_perm,
|
||||
@ -125,19 +131,29 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, num_bits, size_m, size_n, size_k, is_k_full)", # noqa: E501
|
||||
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm",
|
||||
description="gptq_marlin_gemm_fp16",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
if (num_bits in GPTQ_MARLIN_24_SUPPORTED_NUM_BITS
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm_fp32",
|
||||
).blocked_autorange(min_run_time=min_run_time))
|
||||
|
||||
if (quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES):
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, num_bits, size_m, size_n, size_k)", # noqa: E501
|
||||
"output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
@ -147,7 +163,7 @@ def bench_run(results: List[benchmark.Measurement], model: str,
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt=
|
||||
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, num_bits)", # noqa: E501
|
||||
"q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
@ -183,12 +199,13 @@ def main(args):
|
||||
) > 0 and is_k_full not in args.limit_k_full:
|
||||
continue
|
||||
|
||||
for num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
|
||||
if len(args.limit_num_bits
|
||||
) > 0 and num_bits not in args.limit_num_bits:
|
||||
for quant_type in query_marlin_supported_quant_types(
|
||||
False):
|
||||
if len(args.limit_num_bits) > 0 and \
|
||||
quant_type.size_bits not in args.limit_num_bits:
|
||||
continue
|
||||
|
||||
for group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
if len(
|
||||
args.limit_group_size
|
||||
) > 0 and group_size not in args.limit_group_size:
|
||||
@ -202,8 +219,8 @@ def main(args):
|
||||
|
||||
for size_m in args.batch_sizes:
|
||||
bench_run(results, model, act_order, is_k_full,
|
||||
num_bits, group_size, size_m, size_k,
|
||||
size_n)
|
||||
quant_type, group_size, size_m,
|
||||
size_k, size_n)
|
||||
|
||||
compare = benchmark.Compare(results)
|
||||
compare.print()
|
||||
|
@ -100,7 +100,7 @@ def main(
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Using default kv_scale
|
||||
kv_scale = 1.0
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
for _ in range(num_iters):
|
||||
if version == "v1":
|
||||
@ -117,7 +117,8 @@ def main(
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
elif version == "v2":
|
||||
ops.paged_attention_v2(
|
||||
@ -136,7 +137,8 @@ def main(
|
||||
max_seq_len,
|
||||
alibi_slopes,
|
||||
kv_cache_dtype,
|
||||
kv_scale,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid version: {version}")
|
||||
@ -173,7 +175,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--num-kv-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 128, 192, 256],
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128)
|
||||
parser.add_argument("--block-size", type=int, choices=[16, 32], default=16)
|
||||
parser.add_argument("--use-alibi", action="store_true")
|
||||
|
@ -94,7 +94,7 @@ if __name__ == '__main__':
|
||||
parser.add_argument("--num-heads", type=int, default=8)
|
||||
parser.add_argument("--head-size",
|
||||
type=int,
|
||||
choices=[64, 80, 96, 112, 128, 192, 256],
|
||||
choices=[64, 80, 96, 112, 120, 128, 192, 256],
|
||||
default=128)
|
||||
parser.add_argument("--rotary-dim", type=int, choices=[16, 32], default=32)
|
||||
parser.add_argument("--dtype",
|
||||
|
@ -83,6 +83,8 @@ endif()
|
||||
|
||||
message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}")
|
||||
|
||||
list(APPEND LIBS "numa")
|
||||
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
@ -95,6 +97,7 @@ set(VLLM_EXT_SRC
|
||||
"csrc/cpu/activation.cpp"
|
||||
"csrc/cpu/attention.cpp"
|
||||
"csrc/cpu/cache.cpp"
|
||||
"csrc/cpu/utils.cpp"
|
||||
"csrc/cpu/layernorm.cpp"
|
||||
"csrc/cpu/pos_encoding.cpp"
|
||||
"csrc/cpu/torch_bindings.cpp")
|
||||
@ -104,11 +107,11 @@ define_gpu_extension_target(
|
||||
DESTINATION vllm
|
||||
LANGUAGE CXX
|
||||
SOURCES ${VLLM_EXT_SRC}
|
||||
LIBRARIES ${LIBS}
|
||||
COMPILE_FLAGS ${CXX_COMPILE_FLAGS}
|
||||
USE_SABI 3
|
||||
WITH_SOABI
|
||||
)
|
||||
|
||||
add_custom_target(default)
|
||||
message(STATUS "Enabling C extension.")
|
||||
add_dependencies(default _C)
|
||||
|
@ -181,7 +181,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
|
||||
#
|
||||
# The torch cmake setup hardcodes the detected architecture flags in
|
||||
# `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it
|
||||
# can't modified on a per-target basis, e.g. for the `punica` extension.
|
||||
# can't modified on a per-target basis.
|
||||
# So, all the `-gencode` flags need to be extracted and removed from
|
||||
# `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method.
|
||||
# Since it's not possible to use `target_compiler_options` for adding target
|
||||
|
@ -65,6 +65,7 @@ DEFAULT_CONDA_PATTERNS = {
|
||||
"optree",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
}
|
||||
|
||||
DEFAULT_PIP_PATTERNS = {
|
||||
@ -77,6 +78,7 @@ DEFAULT_PIP_PATTERNS = {
|
||||
"onnx",
|
||||
"nccl",
|
||||
"transformers",
|
||||
"zmq",
|
||||
}
|
||||
|
||||
|
||||
|
@ -105,9 +105,9 @@ __device__ void paged_attention_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
const int seq_idx = blockIdx.y;
|
||||
const int partition_idx = blockIdx.z;
|
||||
const int max_num_partitions = gridDim.z;
|
||||
@ -285,7 +285,7 @@ __device__ void paged_attention_kernel(
|
||||
Quant_vec k_vec_quant = *reinterpret_cast<const Quant_vec*>(
|
||||
k_ptr + offset1 * BLOCK_SIZE * x + offset2);
|
||||
k_vecs[j] = fp8::scaled_convert<K_vec, Quant_vec, KV_DTYPE>(
|
||||
k_vec_quant, kv_scale);
|
||||
k_vec_quant, k_scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -415,7 +415,7 @@ __device__ void paged_attention_kernel(
|
||||
*reinterpret_cast<const V_quant_vec*>(v_ptr + offset);
|
||||
// Vector conversion from V_quant_vec to V_vec.
|
||||
v_vec = fp8::scaled_convert<V_vec, V_quant_vec, KV_DTYPE>(v_quant_vec,
|
||||
kv_scale);
|
||||
v_scale);
|
||||
}
|
||||
if (block_idx == num_seq_blocks - 1) {
|
||||
// NOTE(woosuk): When v_vec contains the tokens that are out of the
|
||||
@ -513,15 +513,15 @@ __global__ void paged_attention_v1_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
KV_DTYPE, IS_BLOCK_SPARSE>(
|
||||
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
|
||||
v_cache, num_kv_heads, scale, block_tables, seq_lens,
|
||||
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride,
|
||||
kv_head_stride, kv_scale, tp_rank, blocksparse_local_blocks,
|
||||
kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks,
|
||||
blocksparse_vert_stride, blocksparse_block_size,
|
||||
blocksparse_head_sliding_step);
|
||||
}
|
||||
@ -549,14 +549,14 @@ __global__ void paged_attention_v2_kernel(
|
||||
const int max_num_blocks_per_seq,
|
||||
const float* __restrict__ alibi_slopes, // [num_heads]
|
||||
const int q_stride, const int kv_block_stride, const int kv_head_stride,
|
||||
const float kv_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
const float k_scale, const float v_scale, const int tp_rank,
|
||||
const int blocksparse_local_blocks, const int blocksparse_vert_stride,
|
||||
const int blocksparse_block_size, const int blocksparse_head_sliding_step) {
|
||||
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
|
||||
KV_DTYPE, IS_BLOCK_SPARSE, PARTITION_SIZE>(
|
||||
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
|
||||
block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride,
|
||||
kv_block_stride, kv_head_stride, kv_scale, tp_rank,
|
||||
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank,
|
||||
blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size,
|
||||
blocksparse_head_sliding_step);
|
||||
}
|
||||
@ -682,7 +682,7 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \
|
||||
scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \
|
||||
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
|
||||
kv_scale, tp_rank, blocksparse_local_blocks, \
|
||||
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||
blocksparse_vert_stride, blocksparse_block_size, \
|
||||
blocksparse_head_sliding_step);
|
||||
|
||||
@ -694,8 +694,8 @@ void paged_attention_v1_launcher(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
|
||||
const int tp_rank, const int blocksparse_local_blocks,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
@ -706,7 +706,7 @@ void paged_attention_v1_launcher(
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
@ -751,6 +751,9 @@ void paged_attention_v1_launcher(
|
||||
case 112:
|
||||
LAUNCH_PAGED_ATTENTION_V1(112);
|
||||
break;
|
||||
case 120:
|
||||
LAUNCH_PAGED_ATTENTION_V1(120);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_PAGED_ATTENTION_V1(128);
|
||||
break;
|
||||
@ -770,7 +773,7 @@ void paged_attention_v1_launcher(
|
||||
paged_attention_v1_launcher<T, CACHE_T, BLOCK_SIZE, KV_DTYPE, \
|
||||
IS_BLOCK_SPARSE>( \
|
||||
out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \
|
||||
seq_lens, max_seq_len, alibi_slopes, kv_scale, tp_rank, \
|
||||
seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \
|
||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||
blocksparse_block_size, blocksparse_head_sliding_step);
|
||||
|
||||
@ -815,8 +818,8 @@ void paged_attention_v1(
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
@ -833,7 +836,7 @@ void paged_attention_v1(
|
||||
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \
|
||||
value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \
|
||||
seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \
|
||||
kv_block_stride, kv_head_stride, kv_scale, tp_rank, \
|
||||
kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \
|
||||
blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||
blocksparse_block_size, blocksparse_head_sliding_step); \
|
||||
vllm::paged_attention_v2_reduce_kernel<T, HEAD_SIZE, NUM_THREADS, \
|
||||
@ -850,8 +853,8 @@ void paged_attention_v2_launcher(
|
||||
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int num_kv_heads, float scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float kv_scale,
|
||||
const int tp_rank, const int blocksparse_local_blocks,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes, float k_scale,
|
||||
float v_scale, const int tp_rank, const int blocksparse_local_blocks,
|
||||
const int blocksparse_vert_stride, const int blocksparse_block_size,
|
||||
const int blocksparse_head_sliding_step) {
|
||||
int num_seqs = query.size(0);
|
||||
@ -862,7 +865,7 @@ void paged_attention_v2_launcher(
|
||||
int kv_block_stride = key_cache.stride(0);
|
||||
int kv_head_stride = key_cache.stride(1);
|
||||
|
||||
int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
[[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1);
|
||||
assert(head_size % thread_group_size == 0);
|
||||
|
||||
// NOTE: alibi_slopes is optional.
|
||||
@ -912,6 +915,9 @@ void paged_attention_v2_launcher(
|
||||
case 112:
|
||||
LAUNCH_PAGED_ATTENTION_V2(112);
|
||||
break;
|
||||
case 120:
|
||||
LAUNCH_PAGED_ATTENTION_V2(120);
|
||||
break;
|
||||
case 128:
|
||||
LAUNCH_PAGED_ATTENTION_V2(128);
|
||||
break;
|
||||
@ -932,8 +938,9 @@ void paged_attention_v2_launcher(
|
||||
IS_BLOCK_SPARSE>( \
|
||||
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
|
||||
num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \
|
||||
kv_scale, tp_rank, blocksparse_local_blocks, blocksparse_vert_stride, \
|
||||
blocksparse_block_size, blocksparse_head_sliding_step);
|
||||
k_scale, v_scale, tp_rank, blocksparse_local_blocks, \
|
||||
blocksparse_vert_stride, blocksparse_block_size, \
|
||||
blocksparse_head_sliding_step);
|
||||
|
||||
#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \
|
||||
switch (is_block_sparse) { \
|
||||
@ -980,8 +987,8 @@ void paged_attention_v2(
|
||||
torch::Tensor& seq_lens, // [num_seqs]
|
||||
int64_t block_size, int64_t max_seq_len,
|
||||
const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
const bool is_block_sparse = (blocksparse_vert_stride > 1);
|
||||
|
@ -94,6 +94,7 @@ inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
|
||||
#else
|
||||
return __bfloat1622float2(val);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
|
||||
@ -102,6 +103,7 @@ inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
|
||||
#else
|
||||
return __bfloat162bfloat162(val);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// Vector addition.
|
||||
@ -115,6 +117,7 @@ inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
return __hadd(a, b);
|
||||
#endif
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
@ -123,6 +126,7 @@ inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#else
|
||||
return __hadd2(a, b);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) {
|
||||
@ -170,6 +174,7 @@ inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) {
|
||||
#else
|
||||
return __hmul(a, b);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -179,6 +184,7 @@ inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
|
||||
#else
|
||||
return __hmul2(a, b);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <>
|
||||
@ -289,6 +295,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b,
|
||||
#else
|
||||
return __hfma2(a, b, c);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
||||
@ -298,6 +305,7 @@ inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b,
|
||||
#else
|
||||
return __hfma2(bf162bf162(a), b, c);
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) {
|
||||
|
@ -18,14 +18,15 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype,
|
||||
const double kv_scale);
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale);
|
||||
|
||||
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype);
|
||||
const std::string& kv_cache_dtype,
|
||||
const double k_scale, const double v_scale);
|
||||
|
||||
// Just for unittest
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
|
@ -159,8 +159,8 @@ __global__ void reshape_and_cache_kernel(
|
||||
// block_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int key_stride, const int value_stride, const int num_heads,
|
||||
const int head_size, const int block_size, const int x,
|
||||
const float kv_scale) {
|
||||
const int head_size, const int block_size, const int x, const float k_scale,
|
||||
const float v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
if (slot_idx < 0) {
|
||||
@ -196,24 +196,25 @@ __global__ void reshape_and_cache_kernel(
|
||||
value_cache[tgt_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, kv_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
value_cache[tgt_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, kv_scale);
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void reshape_and_cache_flash_kernel(
|
||||
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
|
||||
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
|
||||
scalar_t* __restrict__ k_cache, // [num_blocks, block_size, num_heads,
|
||||
cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
scalar_t* __restrict__ v_cache, // [num_blocks, block_size, num_heads,
|
||||
cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads,
|
||||
// head_size]
|
||||
const int64_t* __restrict__ slot_mapping, // [num_tokens]
|
||||
const int block_stride, const int key_stride, const int value_stride,
|
||||
const int num_heads, const int head_size, const int block_size) {
|
||||
const int num_heads, const int head_size, const int block_size,
|
||||
const float k_scale, const float v_scale) {
|
||||
const int64_t token_idx = blockIdx.x;
|
||||
const int64_t slot_idx = slot_mapping[token_idx];
|
||||
// NOTE: slot_idx can be -1 if the token is padded
|
||||
@ -228,11 +229,20 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
const int64_t src_value_idx = token_idx * value_stride + i;
|
||||
const int head_idx = i / head_size;
|
||||
const int head_offset = i % head_size;
|
||||
const int64_t tgt_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
k_cache[tgt_value_idx] = key[src_key_idx];
|
||||
v_cache[tgt_value_idx] = value[src_value_idx];
|
||||
const int64_t tgt_key_value_idx = block_idx * block_stride +
|
||||
block_offset * num_heads * head_size +
|
||||
head_idx * head_size + head_offset;
|
||||
scalar_t tgt_key = key[src_key_idx];
|
||||
scalar_t tgt_value = value[src_value_idx];
|
||||
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
|
||||
key_cache[tgt_key_value_idx] = tgt_key;
|
||||
value_cache[tgt_key_value_idx] = tgt_value;
|
||||
} else {
|
||||
key_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, k_scale);
|
||||
value_cache[tgt_key_value_idx] =
|
||||
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, v_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace vllm
|
||||
@ -248,7 +258,7 @@ __global__ void reshape_and_cache_flash_kernel(
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
|
||||
num_heads, head_size, block_size, x, kv_scale);
|
||||
num_heads, head_size, block_size, x, k_scale, v_scale);
|
||||
|
||||
void reshape_and_cache(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
@ -258,7 +268,8 @@ void reshape_and_cache(
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, num_heads, head_size, block_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype, const double kv_scale) {
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
@ -277,40 +288,45 @@ void reshape_and_cache(
|
||||
CALL_RESHAPE_AND_CACHE)
|
||||
}
|
||||
|
||||
// KV_T is the stored data type of kv-cache.
|
||||
// CACHE_T is the data type of key and value tensors.
|
||||
// KV_DTYPE is the real data type of kv-cache.
|
||||
#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \
|
||||
vllm::reshape_and_cache_flash_kernel<KV_T, CACHE_T, KV_DTYPE> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<KV_T*>(key.data_ptr()), \
|
||||
reinterpret_cast<KV_T*>(value.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
|
||||
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride, \
|
||||
value_stride, num_heads, head_size, block_size, k_scale, v_scale);
|
||||
|
||||
void reshape_and_cache_flash(
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& k_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& v_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& key, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& value, // [num_tokens, num_heads, head_size]
|
||||
torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor&
|
||||
value_cache, // [num_blocks, block_size, num_heads, head_size]
|
||||
torch::Tensor& slot_mapping, // [num_tokens]
|
||||
const std::string& kv_cache_dtype) {
|
||||
// FIXME: only support auto datatype, does not support fp8
|
||||
if (kv_cache_dtype != "auto") {
|
||||
TORCH_CHECK(false, "Unsupported data type of kv cache: ", kv_cache_dtype);
|
||||
}
|
||||
const std::string& kv_cache_dtype, const double k_scale,
|
||||
const double v_scale) {
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
int head_size = key.size(2);
|
||||
int block_size = k_cache.size(1);
|
||||
int block_size = key_cache.size(1);
|
||||
|
||||
int key_stride = key.stride(0);
|
||||
int value_stride = value.stride(0);
|
||||
int block_stride = k_cache.stride(0);
|
||||
TORCH_CHECK(k_cache.stride(0) == v_cache.stride(0));
|
||||
int block_stride = key_cache.stride(0);
|
||||
TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0));
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(num_heads * head_size, 512));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
key.scalar_type(), "reshape_and_cache_flash", [&] {
|
||||
vllm::reshape_and_cache_flash_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
key.data_ptr<scalar_t>(), value.data_ptr<scalar_t>(),
|
||||
k_cache.data_ptr<scalar_t>(), v_cache.data_ptr<scalar_t>(),
|
||||
slot_mapping.data_ptr<int64_t>(), block_stride, key_stride,
|
||||
value_stride, num_heads, head_size, block_size);
|
||||
});
|
||||
|
||||
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
|
||||
CALL_RESHAPE_AND_CACHE_FLASH);
|
||||
}
|
||||
|
||||
namespace vllm {
|
||||
@ -318,13 +334,13 @@ namespace vllm {
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
||||
Tout* __restrict__ dst_cache,
|
||||
const float kv_scale,
|
||||
const float scale,
|
||||
const int64_t block_stride) {
|
||||
const int64_t block_idx = blockIdx.x;
|
||||
for (int i = threadIdx.x; i < block_stride; i += blockDim.x) {
|
||||
int64_t idx = block_idx * block_stride + i;
|
||||
dst_cache[idx] =
|
||||
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], kv_scale);
|
||||
fp8::scaled_convert<Tout, Tin, kv_dt>(src_cache[idx], scale);
|
||||
}
|
||||
}
|
||||
|
||||
@ -333,11 +349,11 @@ __global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache,
|
||||
#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \
|
||||
vllm::convert_fp8_kernel<Tout, Tin, KV_DTYPE><<<grid, block, 0, stream>>>( \
|
||||
reinterpret_cast<Tin*>(src_cache.data_ptr()), \
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), kv_scale, block_stride);
|
||||
reinterpret_cast<Tout*>(dst_cache.data_ptr()), scale, block_stride);
|
||||
|
||||
// Only for testing.
|
||||
void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
|
||||
const double kv_scale, const std::string& kv_cache_dtype) {
|
||||
const double scale, const std::string& kv_cache_dtype) {
|
||||
torch::Device src_device = src_cache.device();
|
||||
torch::Device dst_device = dst_cache.device();
|
||||
TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU")
|
||||
|
382
csrc/core/scalar_type.hpp
Normal file
382
csrc/core/scalar_type.hpp
Normal file
@ -0,0 +1,382 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/custom_class.h>
|
||||
|
||||
namespace vllm {
|
||||
|
||||
//
|
||||
// ScalarType can represent a wide range of floating point and integer types,
|
||||
// in particular it can be used to represent sub-byte data types (something
|
||||
// that torch.dtype currently does not support).
|
||||
//
|
||||
// ScalarTypeTorch is a subclass of ScalarType that is compatible with
|
||||
// TORCH_LIBRARY, making it accessible from Python as well meaning this class
|
||||
// can be used as a argument for custom operators, helping to simplify these
|
||||
// interfaces.
|
||||
//
|
||||
// The type definitions on the Python side can be found in: vllm/_core_ext.pyi
|
||||
// these type definitions should be kept up to date with any Python API changes
|
||||
// here.
|
||||
//
|
||||
class ScalarType {
|
||||
public:
|
||||
enum NanRepr : int64_t {
|
||||
NAN_NONE = 0, // nans are not supported
|
||||
NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s
|
||||
NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s
|
||||
|
||||
NAN_REPR_ID_MAX
|
||||
};
|
||||
|
||||
constexpr ScalarType(bool signed_, int64_t exponent, int64_t mantissa,
|
||||
int64_t bias, bool finite_values_only = false,
|
||||
NanRepr nan_repr = NAN_IEEE_754)
|
||||
: exponent(exponent),
|
||||
mantissa(mantissa),
|
||||
bias(bias),
|
||||
signed_(signed_),
|
||||
finite_values_only(finite_values_only),
|
||||
nan_repr(nan_repr){};
|
||||
|
||||
static constexpr ScalarType int_(int64_t size_bits, int64_t bias = 0) {
|
||||
return ScalarType(true, 0, size_bits - 1, bias);
|
||||
}
|
||||
|
||||
static constexpr ScalarType uint(int64_t size_bits, int64_t bias = 0) {
|
||||
return ScalarType(false, 0, size_bits, bias);
|
||||
}
|
||||
|
||||
// IEEE 754 compliant floating point type
|
||||
static constexpr ScalarType float_IEEE754(int64_t exponent,
|
||||
int64_t mantissa) {
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
return ScalarType(true, exponent, mantissa, 0, false, NAN_IEEE_754);
|
||||
}
|
||||
|
||||
// IEEE 754 non-compliant floating point type
|
||||
static constexpr ScalarType float_(int64_t exponent, int64_t mantissa,
|
||||
bool finite_values_only,
|
||||
NanRepr nan_repr) {
|
||||
TORCH_CHECK(nan_repr < NAN_REPR_ID_MAX, "Invalid NanRepr");
|
||||
TORCH_CHECK(mantissa > 0 && exponent > 0);
|
||||
TORCH_CHECK(nan_repr != NAN_IEEE_754,
|
||||
"use `float_IEEE754` constructor for floating point types that "
|
||||
"follow IEEE 754 conventions");
|
||||
return ScalarType(true, exponent, mantissa, 0, finite_values_only,
|
||||
nan_repr);
|
||||
}
|
||||
|
||||
int64_t const exponent; // size of the exponent field (0 for integer types)
|
||||
int64_t const mantissa; // size of the mantissa field (size of the integer
|
||||
// excluding the sign bit for integer types)
|
||||
int64_t const bias; // stored values equal value + bias,
|
||||
// used for quantized type
|
||||
bool const signed_; // flag if the type supports negative numbers (i.e. has a
|
||||
// sign bit)
|
||||
|
||||
// Extra Floating point info
|
||||
bool const finite_values_only; // i.e. no +/-inf if true
|
||||
NanRepr const nan_repr; // how NaNs are represented
|
||||
// (not applicable for integer types)
|
||||
|
||||
int64_t size_bits() const { return mantissa + exponent + is_signed(); }
|
||||
bool is_signed() const { return signed_; }
|
||||
bool is_integer() const { return exponent == 0; }
|
||||
bool is_floating_point() const { return exponent > 0; }
|
||||
bool is_ieee_754() const {
|
||||
return is_floating_point() && finite_values_only == false &&
|
||||
nan_repr == NAN_IEEE_754;
|
||||
}
|
||||
bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; }
|
||||
bool has_infs() const {
|
||||
return is_floating_point() && finite_values_only == false;
|
||||
}
|
||||
bool has_bias() const { return bias != 0; }
|
||||
|
||||
private:
|
||||
double _floating_point_max() const {
|
||||
TORCH_CHECK(mantissa <= 52 && exponent <= 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
|
||||
uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) {
|
||||
max_mantissa -= 1;
|
||||
}
|
||||
|
||||
uint64_t max_exponent = (uint64_t(1) << exponent) - 2;
|
||||
if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) {
|
||||
TORCH_CHECK(exponent < 11,
|
||||
"Cannot represent max/min as a double for type ", str());
|
||||
max_exponent += 1;
|
||||
}
|
||||
|
||||
// adjust the exponent to match that of a double
|
||||
// for now we assume the exponent bias is the standard 2^(e-1) -1, (where e
|
||||
// is the exponent bits), there is some precedent for non-standard biases,
|
||||
// example `float8_e4m3b11fnuz` here: https://github.com/jax-ml/ml_dtypes
|
||||
// but to avoid premature over complication we are just assuming the
|
||||
// standard exponent bias until there is a need to support non-standard
|
||||
// biases
|
||||
uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1;
|
||||
uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11
|
||||
|
||||
uint64_t max_exponent_double =
|
||||
max_exponent - exponent_bias + exponent_bias_double;
|
||||
|
||||
// shift the mantissa into the position for a double and
|
||||
// the exponent
|
||||
uint64_t double_raw =
|
||||
(max_mantissa << (52 - mantissa)) | (max_exponent_double << 52);
|
||||
|
||||
return *reinterpret_cast<double*>(&double_raw);
|
||||
}
|
||||
|
||||
std::variant<int64_t, double> _raw_max() const {
|
||||
if (is_floating_point()) {
|
||||
return {_floating_point_max()};
|
||||
} else {
|
||||
TORCH_CHECK(size_bits() < 64 || size_bits() == 64 && is_signed(),
|
||||
"Cannot represent max as a int64_t");
|
||||
return {(int64_t(1) << mantissa) - 1};
|
||||
}
|
||||
}
|
||||
|
||||
std::variant<int64_t, double> _raw_min() const {
|
||||
if (is_floating_point()) {
|
||||
TORCH_CHECK(is_signed(),
|
||||
"We currently assume all floating point types are signed");
|
||||
constexpr uint64_t sign_bit_double = (uint64_t(1) << 63);
|
||||
|
||||
double max = _floating_point_max();
|
||||
uint64_t max_raw = *reinterpret_cast<uint64_t*>(&max);
|
||||
uint64_t min_raw = max_raw | sign_bit_double;
|
||||
return {*reinterpret_cast<double*>(&min_raw)};
|
||||
} else {
|
||||
TORCH_CHECK(!is_signed() || size_bits() <= 64,
|
||||
"Cannot represent min as a int64_t");
|
||||
if (is_signed()) {
|
||||
// set the top bit to 1 (i.e. INT64_MIN) and the rest to 0
|
||||
// then perform an arithmetic shift right to set all the bits above
|
||||
// (size_bits() - 1) to 1
|
||||
return {INT64_MIN >> (64 - size_bits())};
|
||||
} else {
|
||||
return {int64_t(0)};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
// Max representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
std::variant<int64_t, double> max() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_max());
|
||||
}
|
||||
|
||||
// Min representable value for this scalar type.
|
||||
// (accounting for bias if there is one)
|
||||
std::variant<int64_t, double> min() const {
|
||||
return std::visit(
|
||||
[this](auto x) -> std::variant<int64_t, double> { return {x - bias}; },
|
||||
_raw_min());
|
||||
}
|
||||
|
||||
std::string str() const {
|
||||
/* naming generally follows: https://github.com/jax-ml/ml_dtypes
|
||||
* for floating point types (leading f) the scheme is:
|
||||
* `float<size_bits>_e<exponent_bits>m<mantissa_bits>[flags]`
|
||||
* flags:
|
||||
* - no-flags: means it follows IEEE 754 conventions
|
||||
* - f: means finite values only (no infinities)
|
||||
* - n: means nans are supported (non-standard encoding)
|
||||
* for integer types the scheme is:
|
||||
* `[u]int<size_bits>[b<bias>]`
|
||||
* - if bias is not present it means its zero
|
||||
*/
|
||||
if (is_floating_point()) {
|
||||
auto ret = "float" + std::to_string(size_bits()) + "_e" +
|
||||
std::to_string(exponent) + "m" + std::to_string(mantissa);
|
||||
if (!is_ieee_754()) {
|
||||
if (finite_values_only) {
|
||||
ret += "f";
|
||||
}
|
||||
if (nan_repr != NAN_NONE) {
|
||||
ret += "n";
|
||||
}
|
||||
}
|
||||
return ret;
|
||||
} else {
|
||||
auto ret = ((is_signed()) ? "int" : "uint") + std::to_string(size_bits());
|
||||
if (has_bias()) {
|
||||
ret += "b" + std::to_string(bias);
|
||||
}
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
bool operator==(ScalarType const& other) const {
|
||||
return mantissa == other.mantissa && exponent == other.exponent &&
|
||||
bias == other.bias && signed_ == other.signed_ &&
|
||||
finite_values_only == other.finite_values_only &&
|
||||
nan_repr == other.nan_repr;
|
||||
}
|
||||
};
|
||||
|
||||
// Create a TORCH_LIBRARY compatible version of ScalarType (i.e. inherit from
|
||||
// torch::CustomClassHolder), we use multiple inheritance here since we cannot
|
||||
// have ScalarType inherit from torch::CustomClassHolder and have a constexpr
|
||||
// constructor at the same time (torch::CustomClassHolder does not have a
|
||||
// constexpr destructor)
|
||||
class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
|
||||
public:
|
||||
ScalarTypeTorch(int64_t exponent, int64_t mantissa, int64_t bias,
|
||||
bool _signed)
|
||||
: ScalarType(exponent, mantissa, bias, _signed){};
|
||||
|
||||
ScalarTypeTorch(ScalarType type) : ScalarType(type){};
|
||||
|
||||
using Base = ScalarType;
|
||||
using Self = ScalarTypeTorch;
|
||||
using SelfPtr = c10::intrusive_ptr<Self>;
|
||||
|
||||
static SelfPtr int_(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::int_(size_bits, bias.value_or(0)));
|
||||
}
|
||||
|
||||
static SelfPtr uint(int64_t size_bits, c10::optional<int64_t> bias) {
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::uint(size_bits, bias.value_or(0)));
|
||||
}
|
||||
|
||||
static SelfPtr float_IEEE754(int64_t exponent, int64_t mantissa) {
|
||||
return c10::make_intrusive<Self>(
|
||||
ScalarType::float_IEEE754(exponent, mantissa));
|
||||
}
|
||||
|
||||
static SelfPtr float_(int64_t exponent, int64_t mantissa,
|
||||
bool finite_values_only, int64_t nan_repr) {
|
||||
return c10::make_intrusive<Self>(ScalarType::float_(
|
||||
exponent, mantissa, finite_values_only, NanRepr(nan_repr)));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
static void bind_readonly_property(torch::class_<Self>& cls,
|
||||
std::string const& name, T Base::*field) {
|
||||
auto getter_func = [field = std::move(field)](SelfPtr const& self) {
|
||||
if constexpr (std::is_member_function_pointer_v<decltype(field)>) {
|
||||
return (self.get()->*field)();
|
||||
} else {
|
||||
return self.get()->*field;
|
||||
}
|
||||
};
|
||||
|
||||
cls.def_property(name, getter_func);
|
||||
}
|
||||
|
||||
template <typename MemberFunc, typename Cls>
|
||||
static void bind_function(torch::class_<Self>& cls, const std::string& name,
|
||||
MemberFunc Cls::*member) {
|
||||
cls.def(name, [member = std::move(member)](SelfPtr const& self) {
|
||||
return (self.get()->*member)();
|
||||
});
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static void bind_function(torch::class_<Self>& cls, const std::string& name,
|
||||
Func func) {
|
||||
cls.def(name, func);
|
||||
}
|
||||
|
||||
template <typename Func>
|
||||
static void bind_static_function(torch::class_<Self>& cls,
|
||||
const std::string& name, Func func) {
|
||||
cls.def_static(name, func);
|
||||
}
|
||||
|
||||
static void bind_class(torch::Library& lib) {
|
||||
auto cls = lib.class_<ScalarTypeTorch>("ScalarType")
|
||||
.def(torch::init<int64_t, int64_t, int64_t, bool>());
|
||||
|
||||
// Bind Properties
|
||||
bind_readonly_property(cls, "mantissa", &Base::mantissa);
|
||||
bind_readonly_property(cls, "exponent", &Base::exponent);
|
||||
bind_readonly_property(cls, "bias", &Base::bias);
|
||||
bind_readonly_property(cls, "signed", &Base::is_signed);
|
||||
bind_readonly_property(cls, "size_bits", &Base::size_bits);
|
||||
|
||||
// Bind member functions
|
||||
bind_function(cls, "is_signed", &Base::is_signed);
|
||||
bind_function(cls, "is_integer", &Base::is_integer);
|
||||
bind_function(cls, "is_floating_point", &Base::is_floating_point);
|
||||
bind_function(cls, "is_ieee_754", &Base::is_ieee_754);
|
||||
bind_function(cls, "has_nans", &Base::has_nans);
|
||||
bind_function(cls, "has_infs", &Base::has_infs);
|
||||
bind_function(cls, "has_bias", &Base::has_bias);
|
||||
|
||||
bind_function(cls, "max", [](SelfPtr const& self) {
|
||||
return std::visit([](auto arg) { return c10::IValue(arg); },
|
||||
self.get()->max());
|
||||
});
|
||||
bind_function(cls, "min", [](SelfPtr const& self) {
|
||||
return std::visit([](auto arg) { return c10::IValue(arg); },
|
||||
self.get()->min());
|
||||
});
|
||||
|
||||
bind_function(cls, "__str__", &Base::str);
|
||||
bind_function(cls, "__eq__", [](SelfPtr const& self, SelfPtr const& other) {
|
||||
return *self == *other;
|
||||
});
|
||||
bind_function(cls, "__repr__", [](SelfPtr const& self) {
|
||||
return "ScalarType." + self.get()->str();
|
||||
});
|
||||
|
||||
// Bind static functions (convenience constructors)
|
||||
bind_static_function(cls, "int_", &ScalarTypeTorch::int_);
|
||||
bind_static_function(cls, "uint", &ScalarTypeTorch::uint);
|
||||
bind_static_function(cls, "float_IEEE754", &ScalarTypeTorch::float_IEEE754);
|
||||
bind_static_function(cls, "float_", &ScalarTypeTorch::float_);
|
||||
}
|
||||
};
|
||||
|
||||
using ScalarTypeTorchPtr = c10::intrusive_ptr<ScalarTypeTorch>;
|
||||
|
||||
// "rust style" names generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L60-L70
|
||||
static inline constexpr auto kS4 = ScalarType::int_(4);
|
||||
static inline constexpr auto kU4 = ScalarType::uint(4);
|
||||
static inline constexpr auto kU4B8 = ScalarType::uint(4, 8);
|
||||
static inline constexpr auto kS8 = ScalarType::int_(8);
|
||||
static inline constexpr auto kU8 = ScalarType::uint(8);
|
||||
static inline constexpr auto kU8B128 = ScalarType::uint(8, 128);
|
||||
|
||||
static inline constexpr auto kFE3M2f =
|
||||
ScalarType::float_(3, 2, true, ScalarType::NAN_NONE);
|
||||
static inline constexpr auto kFE4M3fn =
|
||||
ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN);
|
||||
static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2);
|
||||
static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7);
|
||||
static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10);
|
||||
|
||||
// Fixed width style names, generally following:
|
||||
// https://github.com/pytorch/pytorch/blob/6d9f74f0af54751311f0dd71f7e5c01a93260ab3/torch/csrc/api/include/torch/types.h#L47-L57
|
||||
static inline constexpr auto kInt4 = kS4;
|
||||
static inline constexpr auto kUint4 = kU4;
|
||||
static inline constexpr auto kUint4b8 = kU4B8;
|
||||
static inline constexpr auto kInt8 = kS8;
|
||||
static inline constexpr auto kUint8 = kU8;
|
||||
static inline constexpr auto kUint8b128 = kU8B128;
|
||||
|
||||
static inline constexpr auto kFloat6_e3m2f = kFE3M2f;
|
||||
static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn;
|
||||
static inline constexpr auto kFloat8_e5m2 = kFE5M2;
|
||||
static inline constexpr auto kFloat16_e8m7 = kFE8M7;
|
||||
static inline constexpr auto kFloat16_e5m10 = kFE5M10;
|
||||
|
||||
// colloquial names
|
||||
static inline constexpr auto kHalf = kFE5M10;
|
||||
static inline constexpr auto kFloat16 = kHalf;
|
||||
static inline constexpr auto kBFloat16 = kFE8M7;
|
||||
|
||||
}; // namespace vllm
|
16
csrc/core/torch_bindings.cpp
Normal file
16
csrc/core/torch_bindings.cpp
Normal file
@ -0,0 +1,16 @@
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "scalar_type.hpp"
|
||||
#include "registration.h"
|
||||
|
||||
// Note the CORE exstension will be built for (almost) all hardware targets so
|
||||
// new additions must account for this. (currently not built for TPU and Neuron)
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, lib) {
|
||||
// ScalarType, a custom class for representing data types that supports
|
||||
// quantized types, declared here so it can be used when creating interfaces
|
||||
// for custom ops.
|
||||
vllm::ScalarTypeTorch::bind_class(lib);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
@ -423,11 +423,11 @@ void paged_attention_v1(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v1_impl",
|
||||
@ -742,11 +742,11 @@ void paged_attention_v2(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
TORCH_CHECK(blocksparse_vert_stride <= 1,
|
||||
"CPU backend does not support blocksparse attention yet.");
|
||||
VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "paged_attention_v2_impl",
|
||||
|
@ -107,8 +107,9 @@ void copy_blocks(std::vector<torch::Tensor> const& key_caches,
|
||||
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
|
||||
torch::Tensor& key_cache, torch::Tensor& value_cache,
|
||||
torch::Tensor& slot_mapping,
|
||||
const std::string& kv_cache_dtype, double kv_scale) {
|
||||
TORCH_CHECK(kv_scale == 1.0f);
|
||||
const std::string& kv_cache_dtype, double k_scale,
|
||||
double v_scale) {
|
||||
TORCH_CHECK(k_scale == 1.0f && v_scale == 1.0f);
|
||||
|
||||
int num_tokens = key.size(0);
|
||||
int num_heads = key.size(1);
|
||||
|
@ -1,9 +1,11 @@
|
||||
#include "cache.h"
|
||||
#include "ops.h"
|
||||
#include "registration.h"
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/library.h>
|
||||
|
||||
void init_cpu_threads_env(const std::string& cpu_ids);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -16,8 +18,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
||||
" int blocksparse_local_blocks,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v1", torch::kCPU, &paged_attention_v1);
|
||||
@ -30,8 +32,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor value_cache, int num_kv_heads, float scale,"
|
||||
" Tensor block_tables, Tensor seq_lens, int block_size,"
|
||||
" int max_seq_len, Tensor? alibi_slopes,"
|
||||
" str kv_cache_dtype, float kv_scale, int tp_rank,"
|
||||
" int blocksparse_local_blocks,"
|
||||
" str kv_cache_dtype, float k_scale, float v_scale,"
|
||||
" int tp_rank, int blocksparse_local_blocks,"
|
||||
" int blocksparse_vert_stride, int blocksparse_block_size,"
|
||||
" int blocksparse_head_sliding_step) -> ()");
|
||||
ops.impl("paged_attention_v2", torch::kCPU, &paged_attention_v2);
|
||||
@ -103,8 +105,13 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
" Tensor! key_cache, Tensor! value_cache,"
|
||||
" Tensor slot_mapping,"
|
||||
" str kv_cache_dtype,"
|
||||
" float kv_scale) -> ()");
|
||||
" float k_scale, float v_scale) -> ()");
|
||||
cache_ops.impl("reshape_and_cache", torch::kCPU, &reshape_and_cache);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) {
|
||||
// CPU utils
|
||||
utils.def("init_cpu_threads_env(str cpu_ids) -> ()", &init_cpu_threads_env);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
65
csrc/cpu/utils.cpp
Normal file
65
csrc/cpu/utils.cpp
Normal file
@ -0,0 +1,65 @@
|
||||
#include <numa.h>
|
||||
#include <unistd.h>
|
||||
#include <string>
|
||||
#include <sched.h>
|
||||
|
||||
#include "cpu_types.hpp"
|
||||
|
||||
void init_cpu_threads_env(const std::string& cpu_ids) {
|
||||
bitmask* omp_cpu_mask = numa_parse_cpustring(cpu_ids.c_str());
|
||||
TORCH_CHECK(omp_cpu_mask->size > 0);
|
||||
std::vector<int> omp_cpu_ids;
|
||||
omp_cpu_ids.reserve(omp_cpu_mask->size);
|
||||
|
||||
constexpr int group_size = 8 * sizeof(*omp_cpu_mask->maskp);
|
||||
|
||||
for (int offset = 0; offset < omp_cpu_mask->size; offset += group_size) {
|
||||
unsigned long group_mask = omp_cpu_mask->maskp[offset / group_size];
|
||||
int i = 0;
|
||||
while (group_mask) {
|
||||
if (group_mask & 1) {
|
||||
omp_cpu_ids.emplace_back(offset + i);
|
||||
}
|
||||
++i;
|
||||
group_mask >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Memory node binding
|
||||
if (numa_available() != -1) {
|
||||
int mem_node_id = numa_node_of_cpu(omp_cpu_ids.front());
|
||||
bitmask* mask = numa_parse_nodestring(std::to_string(mem_node_id).c_str());
|
||||
bitmask* src_mask = numa_get_membind();
|
||||
|
||||
int pid = getpid();
|
||||
|
||||
// move all existing pages to the specified numa node.
|
||||
*(src_mask->maskp) = *(src_mask->maskp) ^ *(mask->maskp);
|
||||
int page_num = numa_migrate_pages(pid, src_mask, mask);
|
||||
if (page_num == -1) {
|
||||
TORCH_CHECK(false,
|
||||
"numa_migrate_pages failed. errno: " + std::to_string(errno));
|
||||
}
|
||||
|
||||
// restrict memory allocation node.
|
||||
numa_set_membind(mask);
|
||||
numa_set_strict(1);
|
||||
}
|
||||
|
||||
// OMP threads binding
|
||||
omp_set_num_threads((int)omp_cpu_ids.size());
|
||||
torch::set_num_threads((int)omp_cpu_ids.size());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), torch::get_num_threads());
|
||||
TORCH_CHECK_EQ(omp_cpu_ids.size(), omp_get_max_threads());
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
for (size_t i = 0; i < omp_cpu_ids.size(); ++i) {
|
||||
cpu_set_t* mask = CPU_ALLOC(omp_cpu_mask->size);
|
||||
size_t size = CPU_ALLOC_SIZE(omp_cpu_mask->size);
|
||||
CPU_ZERO_S(size, mask);
|
||||
CPU_SET_S(omp_cpu_ids[i], size, mask);
|
||||
sched_setaffinity(0, sizeof(cpu_set_t), mask);
|
||||
CPU_FREE(mask);
|
||||
}
|
||||
|
||||
numa_free_nodemask(omp_cpu_mask);
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
#include "registration.h"
|
||||
#include "core/registration.h"
|
||||
#include "moe_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
|
49
csrc/ops.h
49
csrc/ops.h
@ -3,13 +3,15 @@
|
||||
#include <optional>
|
||||
#include <torch/library.h>
|
||||
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
void paged_attention_v1(
|
||||
torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache,
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
@ -19,8 +21,8 @@ void paged_attention_v2(
|
||||
torch::Tensor& value_cache, int64_t num_kv_heads, double scale,
|
||||
torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size,
|
||||
int64_t max_seq_len, const c10::optional<torch::Tensor>& alibi_slopes,
|
||||
const std::string& kv_cache_dtype, double kv_scale, const int64_t tp_rank,
|
||||
const int64_t blocksparse_local_blocks,
|
||||
const std::string& kv_cache_dtype, double k_scale, double v_scale,
|
||||
const int64_t tp_rank, const int64_t blocksparse_local_blocks,
|
||||
const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size,
|
||||
const int64_t blocksparse_head_sliding_step);
|
||||
|
||||
@ -52,6 +54,11 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void gelu_quick(torch::Tensor& out, torch::Tensor& input);
|
||||
|
||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
|
||||
const torch::Tensor& codebooks,
|
||||
@ -79,20 +86,27 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_meta,
|
||||
torch::Tensor& b_scales,
|
||||
torch::Tensor& workspace, int64_t num_bits,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n,
|
||||
int64_t size_k);
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& g_idx,
|
||||
torch::Tensor& perm, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full);
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce);
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits);
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits);
|
||||
|
||||
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
@ -105,6 +119,13 @@ void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
c10::optional<torch::Tensor> const& bias);
|
||||
|
||||
torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
|
||||
torch::Tensor const& b_q_weight,
|
||||
torch::Tensor const& s_tok,
|
||||
torch::Tensor const& s_ch,
|
||||
torch::Tensor const& s_group,
|
||||
torch::Tensor& workspace, int64_t size_m,
|
||||
int64_t size_n, int64_t size_k);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
@ -123,12 +144,16 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
|
||||
|
||||
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
torch::Tensor& scale);
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor const& scale);
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
torch::Tensor& scale);
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,
|
||||
c10::optional<torch::Tensor> const& scale_ub);
|
||||
|
||||
void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int64_t block_size, torch::Tensor sorted_token_ids,
|
||||
torch::Tensor experts_ids,
|
||||
|
131
csrc/prepare_inputs/advance_step.cu
Normal file
131
csrc/prepare_inputs/advance_step.cu
Normal file
@ -0,0 +1,131 @@
|
||||
/*
|
||||
* The goal of this GPU kernel is to advance input tensors on the GPU directly
|
||||
* PR: https://github.com/vllm-project/vllm/pull/6338
|
||||
* Current restrictions:
|
||||
* 1. Specialized for DraftModelRunner
|
||||
* 2. Supports flash_attn only
|
||||
*/
|
||||
|
||||
#include "advance_step.cuh"
|
||||
|
||||
namespace prepare_inputs {
|
||||
|
||||
//
|
||||
template <int const num_threads>
|
||||
__global__ void advance_step_kernel(int num_seqs, int num_queries,
|
||||
int block_size, long* input_tokens_ptr,
|
||||
long const* sampled_token_ids_ptr,
|
||||
long* input_positions_ptr,
|
||||
int* seq_lens_ptr, long* slot_mapping_ptr,
|
||||
int const* block_tables_ptr,
|
||||
int64_t const block_tables_stride) {
|
||||
int num_query_blocks = div_ceil(num_queries, num_threads);
|
||||
|
||||
if (blockIdx.x >= num_query_blocks) {
|
||||
return;
|
||||
}
|
||||
|
||||
int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
|
||||
|
||||
if (cur_query_id >= num_queries) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Update input_tokens
|
||||
input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
|
||||
|
||||
int seq_len = seq_lens_ptr[cur_query_id];
|
||||
int next_seq_len = seq_len + 1;
|
||||
int next_input_pos = next_seq_len - 1;
|
||||
|
||||
// Update seq_lens
|
||||
seq_lens_ptr[cur_query_id] = next_seq_len;
|
||||
// Update input_positions
|
||||
input_positions_ptr[cur_query_id] = next_input_pos;
|
||||
|
||||
int const* seq_block_tables_ptr =
|
||||
block_tables_ptr + block_tables_stride * cur_query_id;
|
||||
|
||||
int block_index = next_input_pos / block_size;
|
||||
int block_offset = next_input_pos % block_size;
|
||||
|
||||
int slot_num = seq_block_tables_ptr[block_index] * block_size + block_offset;
|
||||
// Update slot_mapping
|
||||
slot_mapping_ptr[cur_query_id] = slot_num;
|
||||
}
|
||||
|
||||
inline void verify_tensor(std::string const& name, torch::Tensor& t,
|
||||
int64_t const size_0, int64_t const size_1,
|
||||
c10::ScalarType const type) {
|
||||
bool size_0_cond = true;
|
||||
if (size_0 != -1) {
|
||||
size_0_cond = t.size(0) == size_0;
|
||||
}
|
||||
|
||||
bool size_1_cond = true;
|
||||
if (size_1 != -1) {
|
||||
size_1_cond = t.size(1) == size_1;
|
||||
}
|
||||
|
||||
bool is_contiguous = t.is_contiguous();
|
||||
bool same_type = t.dtype() == type;
|
||||
|
||||
bool pass = size_0_cond && size_1_cond && is_contiguous && same_type;
|
||||
if (!pass) {
|
||||
TORCH_CHECK(false, "tensor: name = ", name, ", shape = ", t.sizes(),
|
||||
" is_cont = ", t.is_contiguous(), ", type = ", t.dtype(),
|
||||
" is not as expected: shape = [", size_0, ", ", size_1,
|
||||
"], type = ", type);
|
||||
}
|
||||
}
|
||||
|
||||
void advance_step(int num_seqs, int num_queries, int block_size,
|
||||
torch::Tensor& input_tokens, // type: long
|
||||
torch::Tensor& sampled_token_ids, // type: long
|
||||
torch::Tensor& input_positions, // type: long
|
||||
torch::Tensor& seq_lens, // type: int
|
||||
torch::Tensor& slot_mapping, // type: long
|
||||
torch::Tensor& block_tables) { // type: int
|
||||
|
||||
if (logging) {
|
||||
printf("advance_step:\n");
|
||||
printf(" num_seqs = %d\n", num_seqs);
|
||||
printf(" num_queries = %d\n", num_queries);
|
||||
printf(" block_size = %d\n", block_size);
|
||||
}
|
||||
// Verify all tensors
|
||||
verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
|
||||
verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
|
||||
at::kLong);
|
||||
verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
|
||||
verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
|
||||
verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
|
||||
verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
|
||||
|
||||
int dev = sampled_token_ids.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
|
||||
num_seqs, num_queries, block_size,
|
||||
reinterpret_cast<long*>(input_tokens.data_ptr()),
|
||||
reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
|
||||
reinterpret_cast<long*>(input_positions.data_ptr()),
|
||||
reinterpret_cast<int*>(seq_lens.data_ptr()),
|
||||
reinterpret_cast<long*>(slot_mapping.data_ptr()),
|
||||
reinterpret_cast<int const*>(block_tables.data_ptr()),
|
||||
block_tables.stride(0));
|
||||
}
|
||||
|
||||
} // namespace prepare_inputs
|
||||
|
||||
void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
|
||||
torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
|
||||
torch::Tensor& input_positions, torch::Tensor& seq_lens,
|
||||
torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
|
||||
prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
|
||||
sampled_token_ids, input_positions, seq_lens,
|
||||
slot_mapping, block_tables);
|
||||
}
|
19
csrc/prepare_inputs/advance_step.cuh
Normal file
19
csrc/prepare_inputs/advance_step.cuh
Normal file
@ -0,0 +1,19 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cuda.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace prepare_inputs {
|
||||
|
||||
static constexpr int max_threads = 256;
|
||||
static constexpr bool logging = false;
|
||||
|
||||
constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
} // namespace prepare_inputs
|
@ -1,217 +0,0 @@
|
||||
Contains code from https://github.com/punica-ai/punica
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "{}"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright {yyyy} {name of copyright owner}
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
|
||||
This product bundles various third-party components under other open source licenses.
|
||||
This section summarizes those components and their licenses. See licenses/
|
||||
for text of these licenses.
|
||||
|
||||
|
||||
Apache-2.0
|
||||
* third_party/nvbench (with LLVM exception)
|
||||
* third_party/flashinfer
|
||||
|
||||
BSD-3-Clause:
|
||||
* third_party/cutlass
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16)
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16)
|
@ -1,218 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale);
|
||||
|
||||
// clang-format off
|
||||
|
||||
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, narrow, 128) \
|
||||
f(in_T, out_T, W_T, narrow, 256) \
|
||||
f(in_T, out_T, W_T, narrow, 512) \
|
||||
f(in_T, out_T, W_T, narrow, 640) \
|
||||
f(in_T, out_T, W_T, narrow, 768) \
|
||||
f(in_T, out_T, W_T, narrow, 896) \
|
||||
f(in_T, out_T, W_T, narrow, 1024) \
|
||||
f(in_T, out_T, W_T, narrow, 1152) \
|
||||
f(in_T, out_T, W_T, narrow, 1216) \
|
||||
f(in_T, out_T, W_T, narrow, 1280) \
|
||||
f(in_T, out_T, W_T, narrow, 1536) \
|
||||
f(in_T, out_T, W_T, narrow, 1664) \
|
||||
f(in_T, out_T, W_T, narrow, 1728) \
|
||||
f(in_T, out_T, W_T, narrow, 1792) \
|
||||
f(in_T, out_T, W_T, narrow, 2048) \
|
||||
f(in_T, out_T, W_T, narrow, 2240) \
|
||||
f(in_T, out_T, W_T, narrow, 2304) \
|
||||
f(in_T, out_T, W_T, narrow, 2368) \
|
||||
f(in_T, out_T, W_T, narrow, 2432) \
|
||||
f(in_T, out_T, W_T, narrow, 2560) \
|
||||
f(in_T, out_T, W_T, narrow, 2752) \
|
||||
f(in_T, out_T, W_T, narrow, 2816) \
|
||||
f(in_T, out_T, W_T, narrow, 3072) \
|
||||
f(in_T, out_T, W_T, narrow, 3328) \
|
||||
f(in_T, out_T, W_T, narrow, 3456) \
|
||||
f(in_T, out_T, W_T, narrow, 3584) \
|
||||
f(in_T, out_T, W_T, narrow, 3712) \
|
||||
f(in_T, out_T, W_T, narrow, 4096) \
|
||||
f(in_T, out_T, W_T, narrow, 4480) \
|
||||
f(in_T, out_T, W_T, narrow, 4608) \
|
||||
f(in_T, out_T, W_T, narrow, 4736) \
|
||||
f(in_T, out_T, W_T, narrow, 4864) \
|
||||
f(in_T, out_T, W_T, narrow, 5120) \
|
||||
f(in_T, out_T, W_T, narrow, 5504) \
|
||||
f(in_T, out_T, W_T, narrow, 5632) \
|
||||
f(in_T, out_T, W_T, narrow, 5888) \
|
||||
f(in_T, out_T, W_T, narrow, 6144) \
|
||||
f(in_T, out_T, W_T, narrow, 6400) \
|
||||
f(in_T, out_T, W_T, narrow, 6848) \
|
||||
f(in_T, out_T, W_T, narrow, 6912) \
|
||||
f(in_T, out_T, W_T, narrow, 7168) \
|
||||
f(in_T, out_T, W_T, narrow, 7424) \
|
||||
f(in_T, out_T, W_T, narrow, 8192) \
|
||||
f(in_T, out_T, W_T, narrow, 8960) \
|
||||
f(in_T, out_T, W_T, narrow, 9216) \
|
||||
f(in_T, out_T, W_T, narrow, 9472) \
|
||||
f(in_T, out_T, W_T, narrow, 10240) \
|
||||
f(in_T, out_T, W_T, narrow, 11008) \
|
||||
f(in_T, out_T, W_T, narrow, 11264) \
|
||||
f(in_T, out_T, W_T, narrow, 12288) \
|
||||
f(in_T, out_T, W_T, narrow, 13696) \
|
||||
f(in_T, out_T, W_T, narrow, 13824) \
|
||||
f(in_T, out_T, W_T, narrow, 14336) \
|
||||
f(in_T, out_T, W_T, narrow, 14784) \
|
||||
f(in_T, out_T, W_T, narrow, 14848) \
|
||||
f(in_T, out_T, W_T, narrow, 15360) \
|
||||
f(in_T, out_T, W_T, narrow, 16384) \
|
||||
f(in_T, out_T, W_T, narrow, 18944) \
|
||||
f(in_T, out_T, W_T, narrow, 20480) \
|
||||
f(in_T, out_T, W_T, narrow, 22016) \
|
||||
f(in_T, out_T, W_T, narrow, 22528) \
|
||||
f(in_T, out_T, W_T, narrow, 24576) \
|
||||
f(in_T, out_T, W_T, narrow, 27392) \
|
||||
f(in_T, out_T, W_T, narrow, 27648) \
|
||||
f(in_T, out_T, W_T, narrow, 28672) \
|
||||
f(in_T, out_T, W_T, narrow, 29568) \
|
||||
f(in_T, out_T, W_T, narrow, 29696) \
|
||||
f(in_T, out_T, W_T, narrow, 32000) \
|
||||
f(in_T, out_T, W_T, narrow, 32256) \
|
||||
f(in_T, out_T, W_T, narrow, 32512) \
|
||||
f(in_T, out_T, W_T, narrow, 32768) \
|
||||
f(in_T, out_T, W_T, narrow, 33024) \
|
||||
f(in_T, out_T, W_T, narrow, 36864) \
|
||||
f(in_T, out_T, W_T, narrow, 43264) \
|
||||
f(in_T, out_T, W_T, narrow, 49152) \
|
||||
f(in_T, out_T, W_T, narrow, 49408) \
|
||||
f(in_T, out_T, W_T, narrow, 60544) \
|
||||
f(in_T, out_T, W_T, narrow, 60672) \
|
||||
f(in_T, out_T, W_T, narrow, 64000) \
|
||||
f(in_T, out_T, W_T, narrow, 64256) \
|
||||
f(in_T, out_T, W_T, narrow, 64512) \
|
||||
f(in_T, out_T, W_T, narrow, 102400) \
|
||||
f(in_T, out_T, W_T, narrow, 102656) \
|
||||
f(in_T, out_T, W_T, narrow, 102912) \
|
||||
f(in_T, out_T, W_T, narrow, 128000) \
|
||||
f(in_T, out_T, W_T, narrow, 128256) \
|
||||
f(in_T, out_T, W_T, narrow, 128512) \
|
||||
|
||||
|
||||
// Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA
|
||||
// and vllm/tests/lora/test_punica.py
|
||||
|
||||
// Used for defining kernels going from the variety of
|
||||
// dim in to the narrow dim out
|
||||
// Using it for the fully sharded column
|
||||
// parallel LoRA A which splits the rank dim
|
||||
#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \
|
||||
f(in_T, out_T, W_T, 128, narrow) \
|
||||
f(in_T, out_T, W_T, 256, narrow) \
|
||||
f(in_T, out_T, W_T, 512, narrow) \
|
||||
f(in_T, out_T, W_T, 640, narrow) \
|
||||
f(in_T, out_T, W_T, 768, narrow) \
|
||||
f(in_T, out_T, W_T, 896, narrow) \
|
||||
f(in_T, out_T, W_T, 1024, narrow) \
|
||||
f(in_T, out_T, W_T, 1152, narrow) \
|
||||
f(in_T, out_T, W_T, 1216, narrow) \
|
||||
f(in_T, out_T, W_T, 1280, narrow) \
|
||||
f(in_T, out_T, W_T, 1536, narrow) \
|
||||
f(in_T, out_T, W_T, 1664, narrow) \
|
||||
f(in_T, out_T, W_T, 1728, narrow) \
|
||||
f(in_T, out_T, W_T, 1792, narrow) \
|
||||
f(in_T, out_T, W_T, 2048, narrow) \
|
||||
f(in_T, out_T, W_T, 2240, narrow) \
|
||||
f(in_T, out_T, W_T, 2304, narrow) \
|
||||
f(in_T, out_T, W_T, 2368, narrow) \
|
||||
f(in_T, out_T, W_T, 2432, narrow) \
|
||||
f(in_T, out_T, W_T, 2560, narrow) \
|
||||
f(in_T, out_T, W_T, 2752, narrow) \
|
||||
f(in_T, out_T, W_T, 2816, narrow) \
|
||||
f(in_T, out_T, W_T, 3072, narrow) \
|
||||
f(in_T, out_T, W_T, 3328, narrow) \
|
||||
f(in_T, out_T, W_T, 3456, narrow) \
|
||||
f(in_T, out_T, W_T, 3584, narrow) \
|
||||
f(in_T, out_T, W_T, 3712, narrow) \
|
||||
f(in_T, out_T, W_T, 4096, narrow) \
|
||||
f(in_T, out_T, W_T, 4480, narrow) \
|
||||
f(in_T, out_T, W_T, 4608, narrow) \
|
||||
f(in_T, out_T, W_T, 4736, narrow) \
|
||||
f(in_T, out_T, W_T, 4864, narrow) \
|
||||
f(in_T, out_T, W_T, 5120, narrow) \
|
||||
f(in_T, out_T, W_T, 5504, narrow) \
|
||||
f(in_T, out_T, W_T, 5632, narrow) \
|
||||
f(in_T, out_T, W_T, 5888, narrow) \
|
||||
f(in_T, out_T, W_T, 6144, narrow) \
|
||||
f(in_T, out_T, W_T, 6400, narrow) \
|
||||
f(in_T, out_T, W_T, 6848, narrow) \
|
||||
f(in_T, out_T, W_T, 6912, narrow) \
|
||||
f(in_T, out_T, W_T, 7168, narrow) \
|
||||
f(in_T, out_T, W_T, 7424, narrow) \
|
||||
f(in_T, out_T, W_T, 8192, narrow) \
|
||||
f(in_T, out_T, W_T, 8960, narrow) \
|
||||
f(in_T, out_T, W_T, 9216, narrow) \
|
||||
f(in_T, out_T, W_T, 9472, narrow) \
|
||||
f(in_T, out_T, W_T, 10240, narrow) \
|
||||
f(in_T, out_T, W_T, 11008, narrow) \
|
||||
f(in_T, out_T, W_T, 11264, narrow) \
|
||||
f(in_T, out_T, W_T, 12288, narrow) \
|
||||
f(in_T, out_T, W_T, 13696, narrow) \
|
||||
f(in_T, out_T, W_T, 13824, narrow) \
|
||||
f(in_T, out_T, W_T, 14336, narrow) \
|
||||
f(in_T, out_T, W_T, 14784, narrow) \
|
||||
f(in_T, out_T, W_T, 14848, narrow) \
|
||||
f(in_T, out_T, W_T, 15360, narrow) \
|
||||
f(in_T, out_T, W_T, 16384, narrow) \
|
||||
f(in_T, out_T, W_T, 18944, narrow) \
|
||||
f(in_T, out_T, W_T, 20480, narrow) \
|
||||
f(in_T, out_T, W_T, 22016, narrow) \
|
||||
f(in_T, out_T, W_T, 22528, narrow) \
|
||||
f(in_T, out_T, W_T, 24576, narrow) \
|
||||
f(in_T, out_T, W_T, 27392, narrow) \
|
||||
f(in_T, out_T, W_T, 27648, narrow) \
|
||||
f(in_T, out_T, W_T, 28672, narrow) \
|
||||
f(in_T, out_T, W_T, 29568, narrow) \
|
||||
f(in_T, out_T, W_T, 29696, narrow) \
|
||||
f(in_T, out_T, W_T, 32000, narrow) \
|
||||
f(in_T, out_T, W_T, 32256, narrow) \
|
||||
f(in_T, out_T, W_T, 32512, narrow) \
|
||||
f(in_T, out_T, W_T, 32768, narrow) \
|
||||
f(in_T, out_T, W_T, 33024, narrow) \
|
||||
f(in_T, out_T, W_T, 36864, narrow) \
|
||||
f(in_T, out_T, W_T, 43264, narrow) \
|
||||
f(in_T, out_T, W_T, 49152, narrow) \
|
||||
f(in_T, out_T, W_T, 49408, narrow) \
|
||||
f(in_T, out_T, W_T, 60544, narrow) \
|
||||
f(in_T, out_T, W_T, 60672, narrow) \
|
||||
f(in_T, out_T, W_T, 64000, narrow) \
|
||||
f(in_T, out_T, W_T, 64256, narrow) \
|
||||
f(in_T, out_T, W_T, 64512, narrow) \
|
||||
f(in_T, out_T, W_T, 102400, narrow) \
|
||||
f(in_T, out_T, W_T, 102656, narrow) \
|
||||
f(in_T, out_T, W_T, 102912, narrow) \
|
||||
f(in_T, out_T, W_T, 128000, narrow) \
|
||||
f(in_T, out_T, W_T, 128256, narrow) \
|
||||
f(in_T, out_T, W_T, 128512, narrow) \
|
||||
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
|
||||
|
||||
|
||||
// Keep this in sync with vllm/config::LoRAConfig
|
||||
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
|
||||
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
|
||||
|
||||
|
||||
#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \
|
||||
FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \
|
||||
f(in_T, out_T, W_T, 8, 64) \
|
||||
f(in_T, out_T, W_T, 16, 64) \
|
||||
f(in_T, out_T, W_T, 32, 64) \
|
||||
f(in_T, out_T, W_T, 64, 64)
|
||||
|
||||
// clang-format on
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half)
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half)
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16)
|
@ -1,5 +0,0 @@
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half)
|
@ -1,451 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#ifndef USE_ROCM
|
||||
#include <cooperative_groups.h>
|
||||
#else
|
||||
#include <hip/hip_cooperative_groups.h>
|
||||
#endif
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda/pipeline>
|
||||
#endif
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
#include <stdio.h>
|
||||
|
||||
#include "vec_dtypes.cuh"
|
||||
|
||||
namespace cg = cooperative_groups;
|
||||
|
||||
#ifdef USE_ROCM
|
||||
template <size_t len>
|
||||
__host__ __device__
|
||||
inline void* memcpy_blocking(void *dst, const void *src) {
|
||||
// Does not handle the case of long datatypes
|
||||
char *d = reinterpret_cast<char *>(dst);
|
||||
const char *s = reinterpret_cast<const char *>(src);
|
||||
size_t i = 0;
|
||||
#pragma unroll
|
||||
for (i = 0; i < len; ++i) {
|
||||
d[i] = s[i];
|
||||
}
|
||||
return dst;
|
||||
}
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
// nthrs = (32, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t num_pipeline_stages = 2;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
__shared__ W_T W_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ in_T X_shared[num_pipeline_stages * tile_size];
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
|
||||
auto pipe = cuda::make_pipeline();
|
||||
|
||||
// pipeline load W/X and compute WX;
|
||||
pipe.producer_acquire();
|
||||
cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
pipe.producer_commit();
|
||||
size_t copy_idx, compute_idx;
|
||||
float y = 0.f;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
|
||||
++tile_idx) {
|
||||
copy_idx = tile_idx % num_pipeline_stages;
|
||||
// pipeline stage: async copy W fragment
|
||||
pipe.producer_acquire();
|
||||
if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
|
||||
cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
|
||||
cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
X + (batch_idx * feat_in) + tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size,
|
||||
cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
|
||||
}
|
||||
pipe.producer_commit();
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// pipeline stage: compute WX
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] = sum;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
}
|
||||
|
||||
compute_idx = (tile_idx - 1) % num_pipeline_stages;
|
||||
// final pipeline stage
|
||||
pipe.consumer_wait();
|
||||
block.sync();
|
||||
x_vec.load(X_shared + X_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W_shared + W_shared_offset[compute_idx] +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += __shfl_down_sync(0xffffffff, sum, offset);
|
||||
}
|
||||
y_warpwise[threadIdx.y] =
|
||||
((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
|
||||
? sum
|
||||
: 0.f;
|
||||
block.sync();
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y += y_warpwise[i];
|
||||
}
|
||||
|
||||
block.sync();
|
||||
pipe.consumer_release();
|
||||
|
||||
// write Y;
|
||||
if (block.thread_rank() == 0) {
|
||||
Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
|
||||
size_t W_copy_size, int tx, int ty, int tz, typename in_T,
|
||||
typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
size_t j = blockIdx.x;
|
||||
constexpr size_t tile_size = tx * ty * vec_size;
|
||||
constexpr size_t num_tiles = (feat_in + tile_size - 1) / tile_size;
|
||||
__shared__ float y_warpwise[ty];
|
||||
|
||||
float y = 0;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
size_t tile_idx;
|
||||
|
||||
#pragma unroll
|
||||
for (tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
|
||||
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
|
||||
x_vec.load(X + (batch_idx * feat_in) +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
w_vec.load(W + (idx * feat_out + j) * feat_in +
|
||||
tile_idx * tile_size +
|
||||
(threadIdx.y * tx + threadIdx.x) * vec_size);
|
||||
}
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
|
||||
}
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += VLLM_SHFL_DOWN_SYNC(sum, offset);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (tile_idx * tile_size + (threadIdx.y * tx + threadIdx.x + 1) * vec_size - 1 < feat_in) {
|
||||
y += sum;
|
||||
}
|
||||
}
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
y_warpwise[threadIdx.y] = y;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float y_write = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < ty; ++i) {
|
||||
y_write += y_warpwise[i];
|
||||
}
|
||||
|
||||
// write Y;
|
||||
if (threadIdx.x == 0 && threadIdx.y == 0) {
|
||||
size_t y_idx = batch_idx * full_y_size + y_offset + j;
|
||||
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(y_write));
|
||||
}
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// nthrs = (2, 16, 4)
|
||||
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
|
||||
typename in_T, typename out_T, typename W_T>
|
||||
__global__ void
|
||||
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
|
||||
float scale) {
|
||||
size_t batch_idx = blockIdx.y;
|
||||
int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
|
||||
|
||||
if (idx < 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
auto block = cg::this_thread_block();
|
||||
size_t tile_idx = blockIdx.x;
|
||||
|
||||
// load X;
|
||||
vec_t<in_T, vec_size> x_vec;
|
||||
x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);
|
||||
|
||||
// load W;
|
||||
vec_t<W_T, vec_size> w_vec;
|
||||
w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
|
||||
block.thread_rank() * vec_size);
|
||||
|
||||
float sum = 0.f;
|
||||
#pragma unroll
|
||||
for (size_t i = 0; i < vec_size; ++i) {
|
||||
#ifndef USE_ROCM
|
||||
sum += float(w_vec[i]) * float(x_vec[i]) * scale;
|
||||
#else
|
||||
sum += convert_type<W_T, float>(w_vec[i]) * convert_type<in_T, float>(x_vec[i]) * scale;
|
||||
#endif
|
||||
}
|
||||
|
||||
cg::thread_block_tile g = cg::tiled_partition<tx>(block);
|
||||
#pragma unroll
|
||||
for (size_t offset = tx / 2; offset > 0; offset /= 2) {
|
||||
sum += g.shfl_down(sum, offset);
|
||||
}
|
||||
sum = g.shfl(sum, 0);
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
#ifndef USE_ROCM
|
||||
Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
|
||||
#else
|
||||
size_t y_idx = batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
|
||||
threadIdx.z * ty + threadIdx.y;
|
||||
Y[y_idx] = vllm_add<out_T>(Y[y_idx], convert_type<float, out_T>(sum));
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
template <int feat_in, int feat_out, typename in_T, typename out_T,
|
||||
typename W_T>
|
||||
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
|
||||
const W_T *__restrict__ W,
|
||||
const int64_t *__restrict__ indicies, int64_t y_offset,
|
||||
int64_t full_y_size, int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
constexpr size_t vec_size = 8;
|
||||
constexpr int tz = 4;
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
if constexpr (feat_in <= feat_out) {
|
||||
static_assert(feat_in % vec_size == 0);
|
||||
constexpr int tx = feat_in / vec_size;
|
||||
|
||||
static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
|
||||
(16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
|
||||
(8 % tx == 0 && feat_out % (8 / tx * tz) == 0));
|
||||
|
||||
if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
|
||||
constexpr int ty = 32 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
|
||||
constexpr int ty = 16 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else {
|
||||
constexpr int ty = 8 / tx;
|
||||
dim3 nblks(feat_out / (ty * tz), batch_size);
|
||||
dim3 nthrs(tx, ty, tz);
|
||||
|
||||
bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
} else {
|
||||
#ifndef USE_ROCM
|
||||
static_assert(feat_in % (vec_size * 32) == 0 ||
|
||||
feat_in % (vec_size * 16) == 0 ||
|
||||
feat_in % (vec_size * 8) == 0);
|
||||
|
||||
if constexpr (feat_in % (vec_size * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
|
||||
vec_size * sizeof(W_T), tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
|
||||
constexpr int tx = 32;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
} else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
|
||||
constexpr int tx = 16;
|
||||
constexpr int ty = 4;
|
||||
|
||||
dim3 nblks(feat_out, batch_size);
|
||||
dim3 nthrs(tx, ty);
|
||||
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
|
||||
vec_size * sizeof(in_T) / 2,
|
||||
vec_size * sizeof(W_T) / 2, tx, ty, tz>
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
|
||||
full_y_size, num_layers, layer_idx,
|
||||
scale);
|
||||
}
|
||||
#else
|
||||
constexpr size_t rocm_warp_size = warpSize;
|
||||
|
||||
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
|
||||
feat_in % (rocm_warp_size * vec_size_) == 0
|
||||
|
||||
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
|
||||
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
|
||||
constexpr size_t vec_size_shrink = vec_size_; \
|
||||
constexpr int tx = tx_; \
|
||||
constexpr int ty = ty_; \
|
||||
dim3 nblks(feat_out, batch_size); \
|
||||
dim3 nthrs(tx, ty); \
|
||||
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
|
||||
vec_size_shrink * sizeof(in_T), \
|
||||
vec_size_shrink * sizeof(W_T), \
|
||||
tx, ty, tz> \
|
||||
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
|
||||
full_y_size, num_layers, layer_idx, \
|
||||
scale); \
|
||||
}
|
||||
|
||||
static_assert(CHECK_INPUT_TILEABLE_BY(32) ||
|
||||
CHECK_INPUT_TILEABLE_BY(16) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 8) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 4) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 2) ||
|
||||
CHECK_INPUT_TILEABLE_BY( 1));
|
||||
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(32, vec_size, rocm_warp_size, 32/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM(16, vec_size, rocm_warp_size, 16/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 8, vec_size, rocm_warp_size, 8/vec_size)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 4, vec_size, rocm_warp_size/(vec_size/4), vec_size/4)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 2, vec_size, rocm_warp_size/(vec_size/2), vec_size/2)
|
||||
else
|
||||
LAUNCH_BGMV_SHRINK_KERNELS_ROCM( 1, vec_size, rocm_warp_size/(vec_size/1), vec_size/1)
|
||||
|
||||
#undef CHECK_INPUT_TILEABLE_BY
|
||||
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
|
||||
template void bgmv_kernel<feat_in, feat_out>( \
|
||||
out_T * __restrict__ Y, const in_T *__restrict__ X, \
|
||||
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
|
||||
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
|
||||
int64_t num_layers, int64_t layer_idx, float scale);
|
||||
|
||||
#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \
|
||||
INST_BGMV(feat_in, feat_out, in_T, out_T, W_T)
|
||||
|
||||
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
|
||||
INST_BGMV(wide, narrow, in_T, out_T, W_T)
|
@ -1,48 +0,0 @@
|
||||
DTYPES = ["fp16", "bf16", "fp32"]
|
||||
DTYPE_MAP = {
|
||||
"fp16": "nv_half",
|
||||
"bf16": "nv_bfloat16",
|
||||
"fp32": "float",
|
||||
}
|
||||
|
||||
TEMPLATE = """
|
||||
#include "bgmv_config.h"
|
||||
#include "bgmv_impl.cuh"
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
|
||||
FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype})
|
||||
""".lstrip() # noqa: E501
|
||||
|
||||
for input_dtype in DTYPES:
|
||||
for output_dtype in DTYPES:
|
||||
for weight_dtype in DTYPES:
|
||||
if weight_dtype == "fp32":
|
||||
# FP32 weights are not supported.
|
||||
continue
|
||||
if output_dtype == "fp32":
|
||||
# LoRA A matrix.
|
||||
if input_dtype != weight_dtype:
|
||||
# NOTE(woosuk): While Punica supports the case where the
|
||||
# input and weight dtypes are different, we only generate
|
||||
# the kernels the same dtypes to reduce the binary size.
|
||||
continue
|
||||
elif input_dtype == "fp32":
|
||||
# LoRA B matrix.
|
||||
if output_dtype != weight_dtype:
|
||||
# NOTE(woosuk): While Punica supports the case where the
|
||||
# output and weight dtypes are different, we only generate
|
||||
# the kernels the same dtypes to reduce the binary size.
|
||||
continue
|
||||
elif not (input_dtype == output_dtype == weight_dtype):
|
||||
# NOTE(woosuk): While Punica supports mixed data types for
|
||||
# input, output, and weight, we only generate the kernels with
|
||||
# the same data types to reduce the binary size.
|
||||
continue
|
||||
|
||||
kernel_definition = TEMPLATE.format(
|
||||
input_dtype=DTYPE_MAP[input_dtype],
|
||||
output_dtype=DTYPE_MAP[output_dtype],
|
||||
weight_dtype=DTYPE_MAP[weight_dtype])
|
||||
filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
|
||||
with open(filename, "w") as f:
|
||||
f.write(kernel_definition)
|
File diff suppressed because it is too large
Load Diff
@ -1,569 +0,0 @@
|
||||
#include <torch/all.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cstdint>
|
||||
|
||||
#include "type_convert.h"
|
||||
#include "../cuda_compat.h"
|
||||
#include "bgmv/bgmv_config.h"
|
||||
|
||||
|
||||
//====== utils ======
|
||||
|
||||
inline void check_shape(const torch::Tensor &a, const torch::Tensor &b,
|
||||
const char *a_name, const char *b_name) {
|
||||
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
|
||||
a.dim(), " vs ", b.dim());
|
||||
for (int i = 0; i < a.dim(); ++i) {
|
||||
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
|
||||
".size(", i, ")");
|
||||
}
|
||||
}
|
||||
|
||||
inline constexpr uint64_t pack_u32(uint32_t a, uint32_t b) {
|
||||
return (uint64_t(a) << 32) | uint64_t(b);
|
||||
}
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) \
|
||||
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
|
||||
|
||||
#define CHECK_EQ(a, b) \
|
||||
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
//====== bgmv ======
|
||||
|
||||
template <typename in_T, typename out_T, typename W_T>
|
||||
inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W,
|
||||
const int64_t *lora_indices,
|
||||
uint32_t in_features, uint32_t out_features,
|
||||
int64_t y_offset, int64_t full_y_size,
|
||||
int64_t batch_size, int64_t num_layers,
|
||||
int64_t layer_idx, float scale) {
|
||||
// NOTE(woosuk): While Punica supports various combinations of input/output
|
||||
// data types, we limit the supported data types to reduce the binary size.
|
||||
constexpr bool is_input_float = std::is_same<in_T, float>::value;
|
||||
constexpr bool is_output_float = std::is_same<out_T, float>::value;
|
||||
if (is_input_float) {
|
||||
if (!std::is_same<out_T, W_T>::value) {
|
||||
return false;
|
||||
}
|
||||
} else if (is_output_float) {
|
||||
if (!std::is_same<in_T, W_T>::value) {
|
||||
return false;
|
||||
}
|
||||
} else if (!(std::is_same<in_T, W_T>::value &&
|
||||
std::is_same<out_T, W_T>::value)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
switch (pack_u32(in_features, out_features)) {
|
||||
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
|
||||
case pack_u32(feat_in, feat_out): \
|
||||
bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset, \
|
||||
full_y_size, batch_size, num_layers, \
|
||||
layer_idx, scale); \
|
||||
break;
|
||||
#define CASE(_in_T, _out_T, _W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, narrow, wide) \
|
||||
CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)
|
||||
|
||||
FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
|
||||
FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _)
|
||||
#undef CASE
|
||||
#undef CASE_ONESIDE
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx, double scale) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t h_in = x.size(1);
|
||||
int64_t h_out = y.size(1);
|
||||
int64_t num_layers = w.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out, 0,
|
||||
h_out, B, num_layers, layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
||||
|
||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx,
|
||||
double scale, int64_t h_in, int64_t h_out,
|
||||
int64_t y_offset) {
|
||||
CHECK_INPUT(y);
|
||||
CHECK_INPUT(x);
|
||||
CHECK_INPUT(w);
|
||||
CHECK_INPUT(indicies);
|
||||
|
||||
CHECK_DIM(2, y);
|
||||
CHECK_DIM(2, x);
|
||||
CHECK_DIM(4, w);
|
||||
CHECK_DIM(1, indicies);
|
||||
|
||||
int64_t B = x.size(0);
|
||||
int64_t num_layers = w.size(1);
|
||||
int64_t full_y_size = y.size(1);
|
||||
CHECK_EQ(w.size(3), h_in);
|
||||
CHECK_EQ(w.size(2), h_out);
|
||||
CHECK_EQ(indicies.size(0), x.size(0));
|
||||
CHECK_EQ(y.size(0), x.size(0));
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
bool ok = false;
|
||||
if (h_in <= 128512 && h_out <= 128512) {
|
||||
// TODO: See if we can get rid of this massive nested switch
|
||||
switch (x.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_half *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (y.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_half *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<nv_bfloat16 *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
case at::ScalarType::Float:
|
||||
switch (w.scalar_type()) {
|
||||
case at::ScalarType::Half:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_half *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
case at::ScalarType::BFloat16:
|
||||
ok = launch_bgmv_kernel(static_cast<float *>(y.data_ptr()),
|
||||
static_cast<float *>(x.data_ptr()),
|
||||
static_cast<nv_bfloat16 *>(w.data_ptr()),
|
||||
indicies.data_ptr<int64_t>(), h_in, h_out,
|
||||
y_offset, full_y_size, B, num_layers,
|
||||
layer_idx, scale);
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
break;
|
||||
default:
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
|
||||
" dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
|
||||
}
|
@ -1,11 +0,0 @@
|
||||
#pragma once
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx, double scale);
|
||||
|
||||
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
|
||||
torch::Tensor indicies, int64_t layer_idx,
|
||||
double scale, int64_t h_in, int64_t h_out,
|
||||
int64_t y_offset);
|
@ -1,18 +0,0 @@
|
||||
#include "registration.h"
|
||||
#include "punica_ops.h"
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
|
||||
m.def(
|
||||
"dispatch_bgmv(Tensor! y, Tensor x, Tensor w, Tensor indicies, int "
|
||||
"layer_idx, float scale) -> ()");
|
||||
m.impl("dispatch_bgmv", torch::kCUDA, &dispatch_bgmv);
|
||||
|
||||
m.def(
|
||||
"dispatch_bgmv_low_level(Tensor! y, Tensor x, Tensor w,"
|
||||
"Tensor indicies, int layer_idx,"
|
||||
"float scale, int h_in, int h_out,"
|
||||
"int y_offset) -> ()");
|
||||
m.impl("dispatch_bgmv_low_level", torch::kCUDA, &dispatch_bgmv_low_level);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
@ -1,82 +0,0 @@
|
||||
#ifndef CSRC__PUNICA__TYPE_CONVERT_H__
|
||||
#define CSRC__PUNICA__TYPE_CONVERT_H__
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
#else
|
||||
|
||||
#include <hip/hip_bf16.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
|
||||
#define __TYPE_CONVERT__HOST_DEVICE__ __host__ __device__
|
||||
|
||||
typedef __half nv_half;
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 val) {
|
||||
return __hip_bfloat162{val, val};
|
||||
}
|
||||
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat162 make_bfloat162(__hip_bfloat16 vall, __hip_bfloat16 valr) {
|
||||
return __hip_bfloat162{vall, valr};
|
||||
}
|
||||
|
||||
template <typename T_src, typename T_dst>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline T_dst convert_type(T_src val) {
|
||||
return static_cast<T_dst>(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline float convert_type<__half, float>(__half val) {
|
||||
return __half2float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __half convert_type<float, __half>(float val) {
|
||||
return __float2half(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline float convert_type<__hip_bfloat16, float>(__hip_bfloat16 val) {
|
||||
return __bfloat162float(val);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat16 convert_type<float, __hip_bfloat16>(float val) {
|
||||
return __float2bfloat16(val);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline T vllm_add(T a, T b) {
|
||||
return a + b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __half vllm_add<__half>(__half a, __half b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
template <>
|
||||
__TYPE_CONVERT__HOST_DEVICE__
|
||||
inline __hip_bfloat16 vllm_add<__hip_bfloat16>(__hip_bfloat16 a, __hip_bfloat16 b) {
|
||||
return __hadd(a, b);
|
||||
}
|
||||
|
||||
#undef __TYPE_CONVERT__HOST_DEVICE__
|
||||
|
||||
#endif // USE_ROCM
|
||||
|
||||
#endif // CSRC__PUNICA__TYPE_CONVERT_H__
|
@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
float res = 0;
|
||||
|
||||
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
|
||||
while (iters--) {
|
||||
if (pred && a_gl_rd < a_gl_end) {
|
||||
|
@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
|
||||
|
||||
return result;
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
} // namespace awq
|
||||
|
@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
|
||||
namespace vllm {
|
||||
namespace awq {
|
||||
|
||||
// Pack two half values.
|
||||
static inline __device__ __host__ unsigned __pack_half2(const half x,
|
||||
const half y) {
|
||||
unsigned v0 = *((unsigned short*)&x);
|
||||
unsigned v1 = *((unsigned short*)&y);
|
||||
return (v1 << 16) | v0;
|
||||
}
|
||||
|
||||
template <int N>
|
||||
__global__ void __launch_bounds__(64)
|
||||
gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
|
||||
@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
|
||||
__shared__ half A_shared[16 * (32 + 8)];
|
||||
__shared__ half B_shared[32 * (N + 8)];
|
||||
|
||||
__shared__ half scaling_factors_shared[N];
|
||||
__shared__ half zeros_shared[N];
|
||||
|
||||
int j_factors1 = ((OC + N - 1) / N);
|
||||
int blockIdx_x = 0;
|
||||
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
|
||||
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
|
||||
|
||||
@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
|
||||
|
||||
static constexpr int row_stride_warp = 32 * 8 / 32;
|
||||
static constexpr int row_stride = 2 * 32 * 8 / N;
|
||||
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
|
||||
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
|
||||
bool ld_A_flag =
|
||||
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
|
||||
@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
|
||||
uint32_t B_loaded =
|
||||
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
|
||||
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
|
||||
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
|
||||
// 8)) * 8);
|
||||
|
||||
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
|
||||
// % (cta_N / 8)) * 8);
|
||||
// - zero and * scale
|
||||
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
|
||||
// q * scale - zero * scale.
|
||||
@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
|
||||
__global__ void __launch_bounds__(64)
|
||||
dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
|
||||
int* __restrict__ zeros, half* __restrict__ C, int G) {
|
||||
int j_factors1 = 4;
|
||||
int row_stride2 = 4;
|
||||
int split_k_iters = 1;
|
||||
static constexpr uint32_t ZERO = 0x0;
|
||||
half B_shared[32 * (128 + 8)];
|
||||
|
||||
half* B_shared_ptr2 = B_shared;
|
||||
|
||||
half B_shared_warp[32];
|
||||
int OC = 512;
|
||||
|
||||
int N = blockDim.x * gridDim.x; // 2
|
||||
int col = (blockIdx.x * blockDim.x + threadIdx.x);
|
||||
int row = blockIdx.y * blockDim.y + threadIdx.y;
|
||||
|
@ -64,8 +64,6 @@ using namespace detail;
|
||||
|
||||
// Row vector broadcast
|
||||
template<
|
||||
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
|
||||
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
|
||||
int Stages,
|
||||
class CtaTileShapeMNK,
|
||||
class Element,
|
||||
@ -73,14 +71,12 @@ template<
|
||||
int Alignment = 128 / sizeof_bits_v<Element>
|
||||
>
|
||||
struct Sm90RowOrScalarBroadcast {
|
||||
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
|
||||
static_assert(
|
||||
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
|
||||
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
|
||||
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
|
||||
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
|
||||
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
|
||||
|
||||
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
|
||||
struct SharedStorage {
|
||||
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
|
||||
struct SharedStorage {
|
||||
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
|
||||
};
|
||||
|
||||
// This struct has been modified to have a bool indicating that ptr_row is a
|
||||
@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
|
||||
|
||||
CUTLASS_HOST_DEVICE
|
||||
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
|
||||
: params(params),
|
||||
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
|
||||
: params(params)
|
||||
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
|
||||
|
||||
Params params;
|
||||
Element* smem_row;
|
||||
Element *smem = nullptr;
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
is_producer_load_needed() const {
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE bool
|
||||
@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
|
||||
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
|
||||
}
|
||||
|
||||
template <int EpiTiles, class GTensor, class STensor>
|
||||
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
|
||||
: gRow(cute::forward<GTensor>(gRow)),
|
||||
sRow(cute::forward<STensor>(sRow)),
|
||||
params(params) {}
|
||||
|
||||
GTensor gRow; // (CTA_M,CTA_N)
|
||||
STensor sRow; // (CTA_M,CTA_N,PIPE)
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
|
||||
if (!params.row_broadcast) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (issue_tma_load) {
|
||||
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
|
||||
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
|
||||
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
|
||||
// Issue the TMA bulk copy
|
||||
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <class... Args>
|
||||
CUTLASS_DEVICE auto
|
||||
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
|
||||
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||
|
||||
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
|
||||
cute::move(gRow), cute::move(sRow), params);
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template <int EpiTiles, class RTensor, class STensor>
|
||||
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
|
||||
: tCrRow(cute::forward<RTensor>(tCrRow)),
|
||||
tCsRow(cute::forward<STensor>(tCsRow)),
|
||||
params(params) {}
|
||||
ConsumerStoreCallbacks(
|
||||
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
|
||||
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
|
||||
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
|
||||
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
|
||||
: tGS_gRow(tGS_gRow_)
|
||||
, tGS_sRow(tGS_sRow_)
|
||||
, tGS_cRow(tGS_cRow_)
|
||||
, tiled_G2S(tiled_g2s_)
|
||||
, tSR_sRow(tSR_sRow_)
|
||||
, tSR_rRow(tSR_rRow_)
|
||||
, tCcRow(tCcRow_)
|
||||
, residue_tCcRow(residue_tCcRow_)
|
||||
, params(params_) {}
|
||||
|
||||
RTensor tCrRow; // (CPY,CPY_M,CPY_N)
|
||||
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
|
||||
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
|
||||
Tiled_G2S tiled_G2S;
|
||||
|
||||
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
ThrResidue residue_tCcRow; // (m, n)
|
||||
ThrNum thr_num;
|
||||
Params const& params;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
|
||||
begin() {
|
||||
if (!params.row_broadcast) {
|
||||
fill(tCrRow, *(params.ptr_row));
|
||||
fill(tSR_rRow, *(params.ptr_row));
|
||||
return;
|
||||
}
|
||||
|
||||
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
|
||||
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
|
||||
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
|
||||
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
|
||||
|
||||
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
|
||||
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
|
||||
continue; // OOB of SMEM,
|
||||
}
|
||||
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
|
||||
tGS_sRow_flt(i) = tGS_gRow_flt(i);
|
||||
}
|
||||
else {
|
||||
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
|
||||
}
|
||||
}
|
||||
synchronize();
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin_loop(int epi_m, int epi_n) {
|
||||
if (epi_m == 0) { // Assumes M-major subtile loop
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
|
||||
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
|
||||
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
|
||||
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
|
||||
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
|
||||
copy(tSR_sRow_flt, tSR_rRow_flt);
|
||||
}
|
||||
}
|
||||
|
||||
@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < FragmentSize; ++i) {
|
||||
frg_row[i] = tCrRow(epi_v * FragmentSize + i);
|
||||
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
|
||||
}
|
||||
|
||||
return frg_row;
|
||||
@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
|
||||
>
|
||||
CUTLASS_DEVICE auto
|
||||
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
|
||||
auto [M, N, K, L] = args.problem_shape_mnkl;
|
||||
auto [m, n, k, l] = args.tile_coord_mnkl;
|
||||
using ThreadCount = decltype(size(args.tiled_copy));
|
||||
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
|
||||
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
|
||||
sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)
|
||||
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
|
||||
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
|
||||
Tensor sRow = make_tensor(make_smem_ptr(smem),
|
||||
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
|
||||
//// G2S: Gmem to Smem
|
||||
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
|
||||
Layout< Shape<_1, ThreadCount>,
|
||||
Stride<_0, _1>>{},
|
||||
Layout<_1>{});
|
||||
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
|
||||
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
|
||||
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
|
||||
|
||||
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
|
||||
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
|
||||
cute::move(tCrRow), cute::move(tCsRow), params);
|
||||
//// G2S: Coord
|
||||
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
|
||||
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
|
||||
|
||||
//// S2R: Smem to Reg
|
||||
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
|
||||
tGS_gRow,
|
||||
tGS_sRow,
|
||||
tGS_cRow, tiled_g2s,
|
||||
tSR_sRow,
|
||||
tSR_rRow,
|
||||
args.tCcD,
|
||||
args.residue_cD,
|
||||
ThreadCount{},
|
||||
params);
|
||||
}
|
||||
};
|
||||
|
||||
@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
|
||||
return args;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static bool
|
||||
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <class ProblemShape>
|
||||
static size_t
|
||||
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
|
||||
@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
|
||||
return EmptyProducerLoadCallbacks{};
|
||||
}
|
||||
|
||||
template<class GTensor, class RTensor>
|
||||
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
|
||||
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
|
||||
CUTLASS_DEVICE
|
||||
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
|
||||
: tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
params(params) {}
|
||||
ConsumerStoreCallbacks(
|
||||
GTensor&& tCgCol,
|
||||
RTensor&& tCrCol,
|
||||
CTensor&& tCcCol,
|
||||
ProblemShape problem_shape,
|
||||
Params const& params
|
||||
):
|
||||
tCgCol(cute::forward<GTensor>(tCgCol)),
|
||||
tCrCol(cute::forward<RTensor>(tCrCol)),
|
||||
tCcCol(cute::forward<CTensor>(tCcCol)),
|
||||
m(get<0>(problem_shape)),
|
||||
params(params) {}
|
||||
|
||||
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
RTensor tCrCol;
|
||||
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
Params const& params;
|
||||
int m;
|
||||
|
||||
CUTLASS_DEVICE void
|
||||
begin() {
|
||||
Tensor pred = make_tensor<bool>(shape(tCgCol));
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < size(pred); ++i) {
|
||||
pred(i) = get<0>(tCcCol(i)) < m;
|
||||
}
|
||||
|
||||
if (!params.col_broadcast) {
|
||||
fill(tCrCol, *(params.ptr_col));
|
||||
return;
|
||||
@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
|
||||
|
||||
// Filter so we don't issue redundant copies over stride-0 modes
|
||||
// (only works if 0-strides are in same location, which is by construction)
|
||||
copy_aligned(filter(tCgCol), filter(tCrCol));
|
||||
copy_if(pred, filter(tCgCol), filter(tCrCol));
|
||||
}
|
||||
|
||||
template <typename ElementAccumulator, int FragmentSize>
|
||||
@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
|
||||
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
|
||||
return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
|
||||
cute::move(tCgCol), cute::move(tCrCol), params);
|
||||
// Generate an identity tensor matching the shape of the global tensor and
|
||||
// partition the same way, this will be used to generate the predicate
|
||||
// tensor for loading
|
||||
Tensor cCol = make_identity_tensor(mCol.shape());
|
||||
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
|
||||
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
|
||||
|
||||
return ConsumerStoreCallbacks(
|
||||
cute::move(tCgCol),
|
||||
cute::move(tCrCol),
|
||||
cute::move(tCcCol),
|
||||
args.problem_shape_mnkl,
|
||||
params
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -1,470 +1,18 @@
|
||||
#include <stddef.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "broadcast_load_epilogue_c2x.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "scaled_mm_c2x_sm75_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm80_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
|
||||
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
|
||||
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
|
||||
|
||||
Epilogue functions can be defined to post-process the output before it is
|
||||
written to GPU memory.
|
||||
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace {
|
||||
|
||||
// Wrappers for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm75_to_sm80 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm80_to_sm89 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm89_to_sm90 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common ScaleA and ScaleB descriptors for the
|
||||
* ScaledEpilogue and ScaledEpilogueBias classes.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch._scaled_mm.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
|
||||
return evt_compute_args;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
using BiasArgs = typename Bias::Arguments;
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
|
||||
bias_args};
|
||||
return evt_compute_args;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Arch, template <typename> typename ArchGuard,
|
||||
typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename> typename Epilogue_, typename TileShape,
|
||||
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
|
||||
struct cutlass_2x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Operator =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
cutlass::arch::OpMultiplyAdd>::type;
|
||||
|
||||
using OutputTileThreadMap =
|
||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||
>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||
Stride<int64_t, Int<1>, Int<0>>>;
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
|
||||
|
||||
// clang-format off
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using KernelType =
|
||||
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
float, cutlass::layout::RowMajor, 4,
|
||||
ElementAcc, float, cutlass::arch::OpClassTensorOp,
|
||||
Arch,
|
||||
TileShape, WarpShape, InstructionShape,
|
||||
EVTD,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||
MainLoopStages, Operator,
|
||||
1 /* epilogue stages */
|
||||
>::GemmKernel>;
|
||||
// clang-format on
|
||||
|
||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
cutlass::gemm::GemmCoord problem_size{m, n, k};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
||||
|
||||
using Epilogue = typename Gemm::Epilogue;
|
||||
auto evt_args =
|
||||
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
|
||||
|
||||
typename Gemm::EVTD::Arguments epilogue_args{
|
||||
evt_args,
|
||||
d_args,
|
||||
};
|
||||
|
||||
typename Gemm::Op::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
|
||||
problem_size, // problem size
|
||||
1, // batch count
|
||||
epilogue_args,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldc};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.get(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
||||
void fallback_cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
// In some cases, the GPU isn't able to accommodate the
|
||||
// shared memory requirements of the Gemm. In such cases, use
|
||||
// the FallbackGemm instead.
|
||||
static const int max_shared_mem_per_block_opt_in =
|
||||
get_cuda_max_shared_memory_per_block_opt_in(0);
|
||||
|
||||
size_t const gemm_shared_mem_size =
|
||||
sizeof(typename Gemm::KernelType::SharedStorage);
|
||||
size_t const fallback_gemm_shared_mem_size =
|
||||
sizeof(typename FallbackGemm::KernelType::SharedStorage);
|
||||
|
||||
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
|
||||
return cutlass_gemm_caller<Gemm>(out, a, b,
|
||||
std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
return cutlass_gemm_caller<FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_default {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (128, inf)
|
||||
// - M in (64, 128] and N >= 8192
|
||||
// Shared Memory required by this Gemm - 81920 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M64 {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (32, 64]
|
||||
// - M in (64, 128] and N < 8192
|
||||
// Shared Memory required by this Gemm - 122880 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M32 {
|
||||
// M in (16, 32]
|
||||
// Shared Memory required by this Gemm - 61440 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M16 {
|
||||
// M in [1, 16]
|
||||
// Shared Memory required by this Gemm - 51200 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_gemm_sm80_dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128BigN =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128SmallN =
|
||||
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM64 =
|
||||
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM32 =
|
||||
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM16 =
|
||||
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
// Due to shared memory requirements, some Gemms may fail to run on some
|
||||
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
|
||||
// in such cases.
|
||||
// sm80_config_M16 has the least shared-memory requirement. However,
|
||||
// based on some profiling, we select sm80_config_M32 as a better alternative
|
||||
// performance wise.
|
||||
using FallbackGemm =
|
||||
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
uint32_t const n = out.size(1);
|
||||
bool const small_n = n < 8192;
|
||||
if (small_n) {
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
|
||||
FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else {
|
||||
// M in (128, inf)
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
template <template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
@ -473,20 +21,13 @@ void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
|
||||
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>>(
|
||||
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
@ -501,11 +42,11 @@ void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogueBias>(
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm75_epilogue<ScaledEpilogue>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
@ -518,11 +59,12 @@ void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t, Epilogue>(
|
||||
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
@ -537,11 +79,11 @@ void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogueBias>(
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm80_epilogue<ScaledEpilogue>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
@ -550,23 +92,17 @@ template <template <typename, typename> typename Epilogue,
|
||||
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
if (a.dtype() == torch::kInt8) {
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
assert(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_caller<cutlass_2x_gemm<
|
||||
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>>(
|
||||
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
|
||||
Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
} else {
|
||||
@ -574,17 +110,13 @@ void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_caller<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
return vllm::cutlass_gemm_sm89_fp8_dispatch<
|
||||
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_caller<
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90,
|
||||
cutlass::float_e4m3_t, cutlass::half_t, Epilogue,
|
||||
TileShape, WarpShape, InstructionShape, 5>>(
|
||||
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
@ -600,10 +132,10 @@ void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogueBias>(
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm89_epilogue<ScaledEpilogue>(out, a, b, a_scales,
|
||||
b_scales);
|
||||
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
340
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
Normal file
340
csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cuh
Normal file
@ -0,0 +1,340 @@
|
||||
#pragma once
|
||||
#include <stddef.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
// clang-format will break include orders
|
||||
// clang-format off
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/gemm_coord.h"
|
||||
#include "cutlass/arch/mma_sm75.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cutlass/arch/mma.h"
|
||||
#include "cutlass/gemm/device/gemm.h"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
|
||||
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
|
||||
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
|
||||
|
||||
#include "broadcast_load_epilogue_c2x.hpp"
|
||||
#include "common.hpp"
|
||||
// clang-format on
|
||||
|
||||
using namespace cute;
|
||||
|
||||
/*
|
||||
Epilogue functions can be defined to post-process the output before it is
|
||||
written to GPU memory.
|
||||
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
|
||||
as well as a static prepare_args function that constructs an
|
||||
EVTCompute::Arguments struct.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
// Wrappers for the GEMM kernel that is used to guard against compilation on
|
||||
// architectures that will never use the kernel. The purpose of this is to
|
||||
// reduce the size of the compiled binary.
|
||||
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
|
||||
// into code that will be executed on the device where it is defined.
|
||||
template <typename Kernel>
|
||||
struct enable_sm75_to_sm80 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm80_to_sm89 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Kernel>
|
||||
struct enable_sm89_to_sm90 : Kernel {
|
||||
template <typename... Args>
|
||||
CUTLASS_DEVICE static void invoke(Args&&... args) {
|
||||
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
|
||||
Kernel::invoke(std::forward<Args>(args)...);
|
||||
#endif
|
||||
}
|
||||
};
|
||||
|
||||
/*
|
||||
* This class provides the common ScaleA and ScaleB descriptors for the
|
||||
* ScaledEpilogue and ScaledEpilogueBias classes.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBase {
|
||||
protected:
|
||||
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
|
||||
|
||||
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
|
||||
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
};
|
||||
|
||||
/*
|
||||
This epilogue function defines a quantized GEMM operation similar to
|
||||
torch._scaled_mm.
|
||||
|
||||
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
|
||||
per-row. B can be quantized per-tensor or per-column.
|
||||
Any combination of per-tensor and per-row or column is supported.
|
||||
A and B must have symmetric quantization (zero point == 0).
|
||||
|
||||
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
|
||||
scales are applied elementwise with numpy-style broadcasting.
|
||||
|
||||
ScaleA and ScaleB define the epilogue functions that apply the scales for
|
||||
the A and B operands respectively. These scales may be either per-tensor or
|
||||
per row or column.
|
||||
*/
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogue
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
using EVTCompute =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
|
||||
return evt_compute_args;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename ElementD, typename OutputTileThreadMap>
|
||||
struct ScaledEpilogueBias
|
||||
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
|
||||
private:
|
||||
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
|
||||
using Accum = typename SUPER::Accum;
|
||||
using ScaleA = typename SUPER::ScaleA;
|
||||
using ScaleB = typename SUPER::ScaleB;
|
||||
|
||||
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiplies, float, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using EVTCompute0 =
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
|
||||
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
|
||||
public:
|
||||
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
|
||||
EVTCompute0, Bias>;
|
||||
using ArgumentType = typename EVTCompute::Arguments;
|
||||
|
||||
static ArgumentType prepare_args(torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
torch::Tensor const& bias) {
|
||||
using ScaleAArgs = typename ScaleA::Arguments;
|
||||
using ScaleBArgs = typename ScaleB::Arguments;
|
||||
using BiasArgs = typename Bias::Arguments;
|
||||
|
||||
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
|
||||
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
|
||||
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
|
||||
|
||||
typename EVTCompute0::Arguments evt0_compute_args{b_args};
|
||||
|
||||
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
|
||||
bias_args};
|
||||
return evt_compute_args;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Arch, template <typename> typename ArchGuard,
|
||||
typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename> typename Epilogue_, typename TileShape,
|
||||
typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
|
||||
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
|
||||
struct cutlass_2x_gemm {
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementD = ElementD_;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
|
||||
using Operator =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
|
||||
cutlass::arch::OpMultiplyAddSaturate,
|
||||
FP8MathOperator>::type;
|
||||
|
||||
using OutputTileThreadMap =
|
||||
cutlass::epilogue::threadblock::OutputTileThreadLayout<
|
||||
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
|
||||
>;
|
||||
|
||||
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
|
||||
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
|
||||
Stride<int64_t, Int<1>, Int<0>>>;
|
||||
|
||||
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
|
||||
|
||||
// clang-format off
|
||||
using RowMajor = typename cutlass::layout::RowMajor;
|
||||
using ColumnMajor = typename cutlass::layout::ColumnMajor;
|
||||
using KernelType =
|
||||
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
|
||||
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
|
||||
float, cutlass::layout::RowMajor, 4,
|
||||
ElementAcc, float, cutlass::arch::OpClassTensorOp,
|
||||
Arch,
|
||||
TileShape, WarpShape, InstructionShape,
|
||||
EVTD,
|
||||
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
|
||||
MainLoopStages, Operator,
|
||||
1 /* epilogue stages */
|
||||
>::GemmKernel>;
|
||||
// clang-format on
|
||||
|
||||
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
|
||||
};
|
||||
|
||||
template <typename Gemm, typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_params) {
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int32_t m = a.size(0);
|
||||
int32_t n = b.size(1);
|
||||
int32_t k = a.size(1);
|
||||
cutlass::gemm::GemmCoord problem_size{m, n, k};
|
||||
|
||||
int64_t lda = a.stride(0);
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
|
||||
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
|
||||
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
|
||||
|
||||
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
|
||||
|
||||
using Epilogue = typename Gemm::Epilogue;
|
||||
auto evt_args =
|
||||
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
|
||||
|
||||
typename Gemm::EVTD::Arguments epilogue_args{
|
||||
evt_args,
|
||||
d_args,
|
||||
};
|
||||
|
||||
typename Gemm::Op::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
|
||||
problem_size, // problem size
|
||||
1, // batch count
|
||||
epilogue_args,
|
||||
a_ptr,
|
||||
b_ptr,
|
||||
nullptr,
|
||||
nullptr,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
0,
|
||||
lda,
|
||||
ldb,
|
||||
ldc,
|
||||
ldc};
|
||||
|
||||
// Launch the CUTLASS GEMM kernel.
|
||||
typename Gemm::Op gemm_op;
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
|
||||
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
// In some cases, the GPU isn't able to accommodate the
|
||||
// shared memory requirements of the Gemm. In such cases, use
|
||||
// the FallbackGemm instead.
|
||||
static const int max_shared_mem_per_block_opt_in =
|
||||
get_cuda_max_shared_memory_per_block_opt_in(0);
|
||||
|
||||
size_t const gemm_shared_mem_size =
|
||||
sizeof(typename Gemm::KernelType::SharedStorage);
|
||||
size_t const fallback_gemm_shared_mem_size =
|
||||
sizeof(typename FallbackGemm::KernelType::SharedStorage);
|
||||
|
||||
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
|
||||
return cutlass_gemm_caller<Gemm>(out, a, b,
|
||||
std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
TORCH_CHECK(fallback_gemm_shared_mem_size <=
|
||||
max_shared_mem_per_block_opt_in);
|
||||
return cutlass_gemm_caller<FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
123
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
Normal file
123
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm75_dispatch.cuh
Normal file
@ -0,0 +1,123 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM75 based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_default {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (256, inf]
|
||||
// - M in (64, 128]
|
||||
// Shared memory required by this Gemm 32768
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_M256 {
|
||||
// M in (128, 256]
|
||||
// Shared memory required by this Gemm 65536
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_M64 {
|
||||
// M in (32, 64]
|
||||
// Shared memory required by this Gemm 49152
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm75_config_M32 {
|
||||
// M in [1, 32]
|
||||
// Shared memory required by this Gemm 49152
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM256 =
|
||||
typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
|
||||
using Cutlass2xGemmM64 =
|
||||
typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM32 =
|
||||
typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
// Due to shared memory requirements, some Gemms may fail to run on some
|
||||
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
|
||||
// in such cases.
|
||||
// sm75_config_default has the least shared-memory requirements.
|
||||
using FallbackGemm = Cutlass2xGemmDefault;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
if (mp2 <= 32) {
|
||||
// M in [1, 32]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
139
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
Normal file
139
csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm80_dispatch.cuh
Normal file
@ -0,0 +1,139 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM80 based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_default {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (128, inf)
|
||||
// - M in (64, 128] and N >= 8192
|
||||
// Shared Memory required by this Gemm - 81920 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M64 {
|
||||
// This config is used in 2 cases,
|
||||
// - M in (32, 64]
|
||||
// - M in (64, 128] and N < 8192
|
||||
// Shared Memory required by this Gemm - 122880 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M32 {
|
||||
// M in (16, 32]
|
||||
// Shared Memory required by this Gemm - 61440 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm80_config_M16 {
|
||||
// M in [1, 16]
|
||||
// Shared Memory required by this Gemm - 51200 bytes
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
using Cutlass2xGemmDefault =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128BigN =
|
||||
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM128SmallN =
|
||||
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM64 =
|
||||
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM32 =
|
||||
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
using Cutlass2xGemmM16 =
|
||||
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
// Due to shared memory requirements, some Gemms may fail to run on some
|
||||
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
|
||||
// in such cases.
|
||||
// sm80_config_M16 has the least shared-memory requirement. However,
|
||||
// based on some profiling, we select sm80_config_M32 as a better alternative
|
||||
// performance wise.
|
||||
using FallbackGemm =
|
||||
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
uint32_t const n = out.size(1);
|
||||
bool const small_n = n < 8192;
|
||||
if (small_n) {
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
|
||||
FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
} else {
|
||||
// M in (128, inf)
|
||||
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
@ -0,0 +1,368 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
#include "cutlass/float8.h"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
|
||||
* shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm89_fp8_fallback_gemm {
|
||||
// Shared Memory required by this Gemm - 61440 bytes
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5,
|
||||
FP8MathOperator>;
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_default {
|
||||
// M in (256, inf)
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M256 {
|
||||
// M in (128, 256]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8196) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M32 {
|
||||
// M in (16, 32]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5, FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
|
||||
static const int32_t MainLoopStages = 5;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_fp8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 24576) {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, MainLoopStages,
|
||||
FP8MathOperator>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
@ -0,0 +1,353 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm_c2x.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM89 (int8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue>
|
||||
struct sm89_int8_fallback_gemm {
|
||||
// Shared mem requirement : 61440
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
static int32_t const MainLoopStages = 5;
|
||||
|
||||
using Cutlass2xGemm =
|
||||
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
|
||||
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
|
||||
};
|
||||
|
||||
struct sm89_int8_config_default {
|
||||
// M in (256, inf)
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M256 {
|
||||
// M in (128, 256]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 4096) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M128 {
|
||||
// M in (64, 128]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (np2 <= 16384) {
|
||||
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M64 {
|
||||
// M in (32, 64]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 3>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M32 {
|
||||
// M in (16, 32]
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct sm89_int8_config_M16 {
|
||||
// M in [1, 16]
|
||||
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
|
||||
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b, EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
|
||||
using FallbackGemm =
|
||||
typename sm89_int8_fallback_gemm<InType, OutType,
|
||||
Epilogue>::Cutlass2xGemm;
|
||||
|
||||
uint32_t const n = out.size(1);
|
||||
uint32_t const np2 = next_pow_2(n);
|
||||
|
||||
if (np2 <= 8192) {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 5>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
|
||||
|
||||
return vllm::fallback_cutlass_gemm_caller<
|
||||
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
|
||||
InType, OutType, Epilogue, TileShape, WarpShape,
|
||||
InstructionShape, 4>,
|
||||
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, int8_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kInt8);
|
||||
TORCH_CHECK(b.dtype() == torch::kInt8);
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// M in [1, 16]
|
||||
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 32) {
|
||||
// M in (16, 32]
|
||||
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// M in (32, 64]
|
||||
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// M in (64, 128]
|
||||
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// M in (128, 256]
|
||||
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
// M in (256, inf)
|
||||
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
@ -18,8 +18,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "cutlass/util/device_memory.h"
|
||||
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
@ -72,13 +70,9 @@ struct ScaledEpilogueBase {
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<1>, Int<0>, Int<0>>>;
|
||||
|
||||
using ScaleBDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, float>;
|
||||
|
||||
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
|
||||
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
|
||||
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
|
||||
Stride<Int<0>, Int<1>, Int<0>>>;
|
||||
};
|
||||
|
||||
/*
|
||||
@ -154,12 +148,8 @@ struct ScaledEpilogueBias
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
using BiasDescriptor =
|
||||
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
|
||||
EpilogueDescriptor, ElementD>;
|
||||
|
||||
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
|
||||
BiasDescriptor::Stages, typename EpilogueDescriptor::TileShape, ElementD,
|
||||
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
|
||||
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
|
||||
|
||||
public:
|
||||
@ -251,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
int64_t ldb = b.stride(1);
|
||||
int64_t ldc = out.stride(0);
|
||||
|
||||
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
|
||||
using StrideA = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideB = Stride<int64_t, Int<1>, int64_t>;
|
||||
using StrideC = typename Gemm::StrideC;
|
||||
|
||||
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
|
||||
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
|
||||
StrideA a_stride{lda, Int<1>{}, 0};
|
||||
StrideB b_stride{ldb, Int<1>{}, 0};
|
||||
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
|
||||
|
||||
using GemmKernel = typename Gemm::GemmKernel;
|
||||
@ -282,11 +272,13 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
|
||||
CUTLASS_CHECK(gemm_op.can_implement(args));
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
|
||||
|
||||
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
|
||||
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
|
||||
CUTLASS_CHECK(status);
|
||||
}
|
||||
|
||||
|
@ -38,13 +38,7 @@ bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
} else if (cuda_device_capability >= 89) {
|
||||
// CUTLASS Kernels have not been tuned for Ada Lovelace systems
|
||||
// and are slower than torch.mm. Return false unconditionally in this case.
|
||||
return false;
|
||||
|
||||
// Once the CUTLASS kernels have been optimized for Lovelace systems,
|
||||
// use the following check:
|
||||
// return CUDA_VERSION >= 12040;
|
||||
return CUDA_VERSION >= 12040;
|
||||
}
|
||||
#endif
|
||||
|
||||
|
@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
return {}; // Squash missing return statement warning
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
|
@ -7,6 +7,8 @@
|
||||
#include "cuda_compat.h"
|
||||
#include "dispatch_utils.h"
|
||||
|
||||
#include "../../reduction_utils.cuh"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
|
||||
|
||||
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
|
||||
|
||||
template <typename scalar_t>
|
||||
template <bool is_scale_inverted>
|
||||
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
|
||||
const scalar_t val, const float inverted_scale) {
|
||||
float x = static_cast<float>(val) * inverted_scale;
|
||||
float const val, float const scale) {
|
||||
float x = 0.0f;
|
||||
if constexpr (is_scale_inverted) {
|
||||
x = val * scale;
|
||||
} else {
|
||||
x = val / scale;
|
||||
}
|
||||
|
||||
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
|
||||
return static_cast<c10::Float8_e4m3fn>(r);
|
||||
}
|
||||
@ -40,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
|
||||
const scalar_t* __restrict__ input,
|
||||
int64_t num_elems) {
|
||||
__shared__ float cache[1024];
|
||||
int i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
|
||||
|
||||
// First store maximum for all values processes by
|
||||
// the current thread in cache[threadIdx.x]
|
||||
@ -87,6 +95,70 @@ typedef struct __align__(4) {
|
||||
}
|
||||
float8x4_t;
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
|
||||
int64_t const num_elems, int const tid,
|
||||
int const step) {
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vectorized_in =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
|
||||
int64_t const num_vec_elems = num_elems >> 2;
|
||||
float absmax_val = 0.0f;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
absmax_val = max(absmax_val, fabs(in_vec.x));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.y));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.z));
|
||||
absmax_val = max(absmax_val, fabs(in_vec.w));
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
absmax_val = max(absmax_val, fabs(input[i]));
|
||||
}
|
||||
|
||||
return absmax_val;
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool is_scale_inverted>
|
||||
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
|
||||
scalar_t const* __restrict__ input,
|
||||
float const scale,
|
||||
int64_t const num_elems,
|
||||
int const tid, int const step) {
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
vec4_t<scalar_t> const* vectorized_in =
|
||||
reinterpret_cast<vec4_t<scalar_t> const*>(input);
|
||||
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
|
||||
int64_t const num_vec_elems = num_elems >> 2;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int64_t i = tid; i < num_vec_elems; i += step) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
|
||||
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.x), scale);
|
||||
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.y), scale);
|
||||
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.z), scale);
|
||||
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(in_vec.w), scale);
|
||||
vectorized_out[i] = out_vec;
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
|
||||
out[i] = scaled_fp8_conversion<is_scale_inverted>(
|
||||
static_cast<float>(input[i]), scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
|
||||
// Invert the scale so that we can use multiplications to avoid expensive
|
||||
// division.
|
||||
const float inverted_scale = 1.0f / (*scale);
|
||||
scaled_fp8_conversion_vec<scalar_t, true>(
|
||||
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
|
||||
}
|
||||
|
||||
// Vectorized input/output to better utilize memory bandwidth.
|
||||
const vec4_t<scalar_t>* vectorized_in =
|
||||
reinterpret_cast<const vec4_t<scalar_t>*>(input);
|
||||
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
|
||||
template <typename scalar_t>
|
||||
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
|
||||
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
|
||||
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
|
||||
const int hidden_size) {
|
||||
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
|
||||
|
||||
int num_vec_elems = num_elems >> 2;
|
||||
int const tid = threadIdx.x;
|
||||
int const token_idx = blockIdx.x;
|
||||
|
||||
#pragma unroll 4
|
||||
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
|
||||
vec4_t<scalar_t> in_vec = vectorized_in[i];
|
||||
float8x4_t out_vec;
|
||||
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
|
||||
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
|
||||
|
||||
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
|
||||
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
|
||||
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
|
||||
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
|
||||
vectorized_out[i] = out_vec;
|
||||
// For vectorization, token_input and token_output pointers need to be
|
||||
// aligned at 8-byte and 4-byte addresses respectively.
|
||||
bool const can_vectorize = hidden_size % 4 == 0;
|
||||
|
||||
float absmax_val = 0.0f;
|
||||
if (can_vectorize) {
|
||||
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
float const x = static_cast<float>(token_input[i]);
|
||||
absmax_val = max(absmax_val, fabs(x));
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the remaining elements if num_elems is not divisible by 4
|
||||
for (int i = num_vec_elems * 4 + tid; i < num_elems;
|
||||
i += blockDim.x * gridDim.x) {
|
||||
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
|
||||
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
|
||||
__shared__ float token_scale;
|
||||
if (tid == 0) {
|
||||
if (scale_ub) {
|
||||
token_scale = min(block_absmax_val_maybe, *scale_ub);
|
||||
} else {
|
||||
token_scale = block_absmax_val_maybe;
|
||||
}
|
||||
// token scale computation
|
||||
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
|
||||
scale[token_idx] = token_scale;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Note that we don't use inverted scales so we can match FBGemm impl.
|
||||
if (can_vectorize) {
|
||||
scaled_fp8_conversion_vec<scalar_t, false>(
|
||||
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
|
||||
} else {
|
||||
for (int i = tid; i < hidden_size; i += blockDim.x) {
|
||||
token_output[i] = scaled_fp8_conversion<false>(
|
||||
static_cast<float>(token_input[i]), token_scale);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
int64_t num_tokens = input.numel() / input.size(-1);
|
||||
int64_t num_elems = input.numel();
|
||||
@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
scale.data_ptr<float>(), num_elems);
|
||||
});
|
||||
}
|
||||
|
||||
void dynamic_per_token_scaled_fp8_quant(
|
||||
torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
|
||||
int const hidden_size = input.size(-1);
|
||||
int const num_tokens = input.numel() / hidden_size;
|
||||
dim3 const grid(num_tokens);
|
||||
dim3 const block(std::min(hidden_size, 1024));
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
|
||||
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
|
||||
input.data_ptr<scalar_t>(),
|
||||
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
|
||||
hidden_size);
|
||||
});
|
||||
}
|
||||
|
@ -19,10 +19,10 @@
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#include "../gptq_marlin/gptq_marlin.cuh"
|
||||
#include "../gptq_marlin/gptq_marlin_dtypes.cuh"
|
||||
#include "../gptq_marlin/marlin.cuh"
|
||||
#include "../gptq_marlin/marlin_dtypes.cuh"
|
||||
|
||||
using namespace gptq_marlin;
|
||||
using namespace marlin;
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
@ -1224,16 +1224,15 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
", size_k = ", size_k);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", gptq_marlin::tile_size);
|
||||
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
|
||||
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", marlin::tile_size);
|
||||
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
|
||||
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
||||
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
|
||||
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
|
||||
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
|
||||
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
|
||||
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
" is not divisible by tile_size = ", gptq_marlin::tile_size);
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
|
||||
" is not divisible by tile_size = ", marlin::tile_size);
|
||||
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||
", actual_size_n = ", actual_size_n);
|
||||
|
||||
@ -1274,11 +1273,9 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
num_groups = b_scales.size(0);
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(
|
||||
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
|
||||
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
|
||||
int min_workspace_size =
|
||||
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
|
||||
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
|
||||
", is not divisible by min_thread_n = ", marlin::min_thread_n);
|
||||
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = ", workspace.numel(),
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
@ -1290,14 +1287,14 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), num_bits, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
gptq_marlin::max_par);
|
||||
marlin::max_par);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
|
||||
size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
|
||||
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
gptq_marlin::max_par);
|
||||
marlin::max_par);
|
||||
} else {
|
||||
TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
|
||||
__NV_SATFINITE, fp8_type);
|
||||
return (uint8_t)res;
|
||||
#endif
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// float -> fp8
|
||||
@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
|
||||
@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
|
||||
}
|
||||
#endif
|
||||
assert(false);
|
||||
__builtin_unreachable(); // Suppress missing return statement warning
|
||||
}
|
||||
|
||||
// The following macro is used to dispatch the conversion function based on
|
||||
|
269
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
Normal file
269
csrc/quantization/gptq_marlin/awq_marlin_repack.cu
Normal file
@ -0,0 +1,269 @@
|
||||
#include "marlin.cuh"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void awq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits>
|
||||
__global__ void awq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
int k_tiles = size_k / tile_k_size;
|
||||
int n_tiles = size_n / tile_n_size;
|
||||
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
|
||||
|
||||
int start_k_tile = blockIdx.x * block_k_tiles;
|
||||
if (start_k_tile >= k_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
// and can only issue the next fetch when it is guaranteed that the previous
|
||||
// shared memory load is fully complete (as it may otherwise be
|
||||
// overwritten).
|
||||
cp_async_wait<repack_stages - 2>();
|
||||
__syncthreads();
|
||||
};
|
||||
|
||||
extern __shared__ int4 sh[];
|
||||
|
||||
constexpr int tile_n_ints = tile_n_size / pack_factor;
|
||||
|
||||
constexpr int stage_n_threads = tile_n_ints / 4;
|
||||
constexpr int stage_k_threads = tile_k_size;
|
||||
constexpr int stage_size = stage_k_threads * stage_n_threads;
|
||||
|
||||
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
cp_async_fence();
|
||||
return;
|
||||
}
|
||||
|
||||
int first_n = n_tile_id * tile_n_size;
|
||||
int first_n_packed = first_n / pack_factor;
|
||||
|
||||
int4* sh_ptr = sh + stage_size * pipe;
|
||||
|
||||
if (threadIdx.x < stage_size) {
|
||||
int k_id = threadIdx.x / stage_n_threads;
|
||||
int n_id = threadIdx.x % stage_n_threads;
|
||||
|
||||
int first_k = k_tile_id * tile_k_size;
|
||||
|
||||
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
|
||||
reinterpret_cast<int4 const*>(
|
||||
&(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
|
||||
first_n_packed + (n_id * 4)])));
|
||||
}
|
||||
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
|
||||
if (n_tile_id >= n_tiles) {
|
||||
return;
|
||||
}
|
||||
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int th_id = threadIdx.x % 32;
|
||||
|
||||
if (warp_id >= 4) {
|
||||
return;
|
||||
}
|
||||
|
||||
int tc_col = th_id / 4;
|
||||
int tc_row = (th_id % 4) * 2;
|
||||
|
||||
constexpr int tc_offsets[4] = {0, 1, 8, 9};
|
||||
|
||||
int cur_n = warp_id * 16 + tc_col;
|
||||
int cur_n_packed = cur_n / pack_factor;
|
||||
int cur_n_pos = cur_n % pack_factor;
|
||||
|
||||
constexpr int sh_stride = tile_n_ints;
|
||||
constexpr uint32_t mask = (1 << num_bits) - 1;
|
||||
|
||||
int4* sh_stage_ptr = sh + stage_size * pipe;
|
||||
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
|
||||
|
||||
// Undo interleaving
|
||||
int cur_n_pos_unpacked;
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
|
||||
cur_n_pos_unpacked = undo_pack[cur_n_pos];
|
||||
} else {
|
||||
constexpr int undo_pack[4] = {0, 2, 1, 3};
|
||||
cur_n_pos_unpacked = undo_pack[cur_n_pos];
|
||||
}
|
||||
|
||||
uint32_t vals[8];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int cur_elem = tc_row + tc_offsets[i];
|
||||
|
||||
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
|
||||
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
|
||||
sh_stride * cur_elem];
|
||||
|
||||
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
|
||||
}
|
||||
|
||||
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
|
||||
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
|
||||
|
||||
// Result of:
|
||||
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
|
||||
if constexpr (num_bits == 4) {
|
||||
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
|
||||
|
||||
uint32_t res = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
res |= vals[pack_idx[i]] << (i * 4);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 4 + warp_id] = res;
|
||||
|
||||
} else {
|
||||
constexpr int pack_idx[4] = {0, 2, 1, 3};
|
||||
|
||||
uint32_t res1 = 0;
|
||||
uint32_t res2 = 0;
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
res1 |= vals[pack_idx[i]] << (i * 8);
|
||||
res2 |= vals[4 + pack_idx[i]] << (i * 8);
|
||||
}
|
||||
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
|
||||
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
|
||||
}
|
||||
};
|
||||
|
||||
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
|
||||
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
|
||||
}
|
||||
|
||||
wait_for_stage();
|
||||
};
|
||||
#pragma unroll
|
||||
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
|
||||
int n_tile_id = 0;
|
||||
|
||||
start_pipes(k_tile_id, n_tile_id);
|
||||
|
||||
while (n_tile_id < n_tiles) {
|
||||
#pragma unroll
|
||||
for (int pipe = 0; pipe < repack_stages; pipe++) {
|
||||
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
|
||||
n_tile_id + pipe + repack_stages - 1);
|
||||
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
|
||||
wait_for_stage();
|
||||
}
|
||||
n_tile_id += repack_stages;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS) \
|
||||
else if (num_bits == NUM_BITS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
|
||||
int64_t size_n, int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ", marlin::tile_k_size);
|
||||
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
|
||||
" is not divisible by tile_n_size = ", marlin::tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int const pack_factor = 32 / num_bits;
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(b_q_weight.size(0) == size_k,
|
||||
"b_q_weight.size(0) = ", b_q_weight.size(0),
|
||||
" is not size_k = ", size_k);
|
||||
TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
|
||||
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
", size_n = ", size_n, ", pack_factor = ", pack_factor);
|
||||
|
||||
// Verify device and strides
|
||||
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
|
||||
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
|
||||
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
|
||||
|
||||
// Alloc buffers
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(b_q_weight.dtype())
|
||||
.device(b_q_weight.device());
|
||||
torch::Tensor out = torch::empty(
|
||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||
options);
|
||||
|
||||
// Get ptrs
|
||||
uint32_t const* b_q_weight_ptr =
|
||||
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
|
||||
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
|
||||
|
||||
// Get dev info
|
||||
int dev = b_q_weight.get_device();
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
|
||||
int blocks;
|
||||
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
|
||||
|
||||
int max_shared_mem = 0;
|
||||
cudaDeviceGetAttribute(&max_shared_mem,
|
||||
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
|
||||
TORCH_CHECK(max_shared_mem > 0);
|
||||
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4)
|
||||
CALL_IF(8)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
return out;
|
||||
}
|
||||
|
||||
#endif
|
@ -19,8 +19,9 @@
|
||||
* Adapted from https://github.com/IST-DASLab/marlin
|
||||
*/
|
||||
|
||||
#include "gptq_marlin.cuh"
|
||||
#include "gptq_marlin_dtypes.cuh"
|
||||
#include "marlin.cuh"
|
||||
#include "marlin_dtypes.cuh"
|
||||
#include "core/scalar_type.hpp"
|
||||
|
||||
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
|
||||
static_assert(std::is_same<scalar_t, half>::value || \
|
||||
@ -32,7 +33,7 @@ inline std::string str(T x) {
|
||||
return std::to_string(x);
|
||||
}
|
||||
|
||||
namespace gptq_marlin {
|
||||
namespace marlin {
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
@ -59,23 +60,27 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks // extra global storage for barrier synchronization
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
||||
) {}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& g_idx,
|
||||
torch::Tensor& perm, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full) {
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp) {
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(false,
|
||||
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
|
||||
return torch::empty({1, 1});
|
||||
@ -264,6 +269,114 @@ dequant_8bit<nv_bfloat16>(int q) {
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Zero-point dequantizers
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
|
||||
int q) {
|
||||
const int LO = 0x000f000f;
|
||||
const int HI = 0x00f000f0;
|
||||
const int EX = 0x64006400;
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
|
||||
|
||||
const int SUB = 0x64006400;
|
||||
const int MUL = 0x2c002c00;
|
||||
const int ADD = 0xd400d400;
|
||||
typename ScalarType<half>::FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&SUB));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&MUL),
|
||||
*reinterpret_cast<const half2*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant_4bit_zp<nv_bfloat16>(int q) {
|
||||
static constexpr uint32_t MASK = 0x000f000f;
|
||||
static constexpr uint32_t EX = 0x43004300;
|
||||
|
||||
// Guarantee that the `(a & b) | c` operations are LOP3s.
|
||||
|
||||
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
q >>= 4;
|
||||
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
|
||||
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
static constexpr uint32_t MUL = 0x3F803F80;
|
||||
static constexpr uint32_t ADD = 0xC300C300;
|
||||
|
||||
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&MUL),
|
||||
*reinterpret_cast<const nv_bfloat162*>(&ADD));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
|
||||
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
|
||||
int q) {
|
||||
static constexpr uint32_t mask_for_elt_01 = 0x5250;
|
||||
static constexpr uint32_t mask_for_elt_23 = 0x5351;
|
||||
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
|
||||
|
||||
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
|
||||
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
|
||||
|
||||
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
|
||||
|
||||
typename ScalarType<half>::FragB frag_b;
|
||||
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
|
||||
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
template <>
|
||||
__device__ inline typename ScalarType<nv_bfloat16>::FragB
|
||||
dequant_8bit_zp<nv_bfloat16>(int q) {
|
||||
typename ScalarType<nv_bfloat16>::FragB frag_b;
|
||||
|
||||
float fp32_intermediates[4];
|
||||
uint32_t* fp32_intermediates_casted =
|
||||
reinterpret_cast<uint32_t*>(fp32_intermediates);
|
||||
|
||||
static constexpr uint32_t fp32_base = 0x4B000000;
|
||||
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
|
||||
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
|
||||
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
|
||||
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
|
||||
|
||||
fp32_intermediates[0] -= 8388608.f;
|
||||
fp32_intermediates[1] -= 8388608.f;
|
||||
fp32_intermediates[2] -= 8388608.f;
|
||||
fp32_intermediates[3] -= 8388608.f;
|
||||
|
||||
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
|
||||
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
|
||||
fp32_intermediates_casted[1], 0x7632);
|
||||
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
|
||||
fp32_intermediates_casted[3], 0x7632);
|
||||
|
||||
return frag_b;
|
||||
}
|
||||
|
||||
// Multiply dequantized values by the corresponding quantization scale; used
|
||||
// only for grouped quantization.
|
||||
template <typename scalar_t>
|
||||
@ -277,6 +390,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
|
||||
frag_b[1] = __hmul2(frag_b[1], s);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
|
||||
typename ScalarType<scalar_t>::scalar_t2& frag_zp,
|
||||
int i) {
|
||||
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
|
||||
scalar_t2 zp =
|
||||
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
|
||||
frag_b[0] = __hsub2(frag_b[0], zp);
|
||||
frag_b[1] = __hsub2(frag_b[1], zp);
|
||||
}
|
||||
|
||||
// Same as above, but for act_order (each K is multiplied individually)
|
||||
template <typename scalar_t>
|
||||
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
|
||||
@ -404,6 +528,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
|
||||
const int stages, // number of stages for the async global->shared
|
||||
// fetch pipeline
|
||||
const bool has_act_order, // whether act_order is enabled
|
||||
const bool has_zp, // whether zero-points are enabled
|
||||
const int group_blocks = -1 // number of consecutive 16x16 blocks
|
||||
// with a separate quantization scale
|
||||
>
|
||||
@ -411,14 +536,18 @@ __global__ void Marlin(
|
||||
const int4* __restrict__ A, // fp16 input matrix of shape mxk
|
||||
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
|
||||
int4* __restrict__ C, // fp16 output buffer of shape mxn
|
||||
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
|
||||
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
|
||||
// (k/groupsize)xn
|
||||
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
|
||||
// (k/groupsize)x(n/pack_factor)
|
||||
const int* __restrict__ g_idx, // int32 group indices of shape k
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks // extra global storage for barrier synchronization
|
||||
int num_groups, // number of scale groups per output channel
|
||||
int prob_m, // batch dimension m
|
||||
int prob_n, // output dimension n
|
||||
int prob_k, // reduction dimension k
|
||||
int* locks, // extra global storage for barrier synchronization
|
||||
bool use_fp32_reduce // whether to use fp32 global reduce
|
||||
) {
|
||||
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
|
||||
// same size, which might involve multiple column "slices" (of width 16 *
|
||||
@ -437,6 +566,7 @@ __global__ void Marlin(
|
||||
using FragB = typename ScalarType<scalar_t>::FragB;
|
||||
using FragC = typename ScalarType<scalar_t>::FragC;
|
||||
using FragS = typename ScalarType<scalar_t>::FragS;
|
||||
using FragZP = typename ScalarType<scalar_t>::FragZP;
|
||||
|
||||
constexpr int pack_factor = 32 / num_bits;
|
||||
|
||||
@ -471,6 +601,8 @@ __global__ void Marlin(
|
||||
int slice_idx; // index of threadblock in current slice; numbered bottom to
|
||||
// top
|
||||
|
||||
int par_id = 0;
|
||||
|
||||
// We can easily implement parallel problem execution by just remapping
|
||||
// indices and advancing global pointers
|
||||
if (slice_col_par >= n_tiles) {
|
||||
@ -478,6 +610,7 @@ __global__ void Marlin(
|
||||
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
|
||||
locks += (slice_col_par / n_tiles) * n_tiles;
|
||||
slice_col = slice_col_par % n_tiles;
|
||||
par_id = slice_col_par / n_tiles;
|
||||
}
|
||||
|
||||
// Compute all information about the current slice which is required for
|
||||
@ -508,6 +641,7 @@ __global__ void Marlin(
|
||||
C += 16 * thread_m_blocks * prob_n / 8;
|
||||
locks += n_tiles;
|
||||
slice_col = 0;
|
||||
par_id++;
|
||||
}
|
||||
};
|
||||
init_slice();
|
||||
@ -566,6 +700,13 @@ __global__ void Marlin(
|
||||
int tb_n_warps = thread_n_blocks / 4;
|
||||
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
|
||||
|
||||
// Zero-points sizes/strides
|
||||
int zp_gl_stride = (prob_n / pack_factor) / 4;
|
||||
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
|
||||
constexpr int zp_tb_groups = s_tb_groups;
|
||||
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
|
||||
int zp_gl_rd_delta = zp_gl_stride;
|
||||
|
||||
// Global A read index of current thread.
|
||||
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
|
||||
(threadIdx.x % a_gl_rd_delta_o);
|
||||
@ -605,6 +746,19 @@ __global__ void Marlin(
|
||||
int s_sh_wr = threadIdx.x;
|
||||
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
|
||||
|
||||
// Zero-points
|
||||
int zp_gl_rd;
|
||||
if constexpr (has_zp) {
|
||||
if constexpr (group_blocks == -1) {
|
||||
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||
} else {
|
||||
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
|
||||
zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
}
|
||||
int zp_sh_wr = threadIdx.x;
|
||||
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
|
||||
|
||||
// We use a different scale layout for grouped and column-wise quantization as
|
||||
// we scale a `half2` tile in column-major layout in the former and in
|
||||
// row-major in the latter case.
|
||||
@ -616,6 +770,18 @@ __global__ void Marlin(
|
||||
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
(threadIdx.x % 32) % 4;
|
||||
|
||||
// Zero-points have the same read layout as the scales
|
||||
// (without column-wise case)
|
||||
constexpr int num_col_threads = 8;
|
||||
constexpr int num_row_threads = 4;
|
||||
constexpr int num_ints_per_thread = 8 / pack_factor;
|
||||
int zp_sh_rd;
|
||||
if constexpr (has_zp) {
|
||||
zp_sh_rd = num_ints_per_thread * num_col_threads *
|
||||
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
|
||||
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
|
||||
}
|
||||
|
||||
// Precompute which thread should not read memory in which iterations; this is
|
||||
// needed if there are more threads than required for a certain tilesize or
|
||||
// when the batchsize is not a multiple of 16.
|
||||
@ -664,14 +830,17 @@ __global__ void Marlin(
|
||||
int4* sh_a = sh;
|
||||
int4* sh_b = sh_a + (stages * a_sh_stage);
|
||||
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
|
||||
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
|
||||
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
|
||||
int4* sh_s = sh_zp + (stages * zp_sh_stage);
|
||||
|
||||
// Register storage for double buffer of shared memory reads.
|
||||
FragA frag_a[2][thread_m_blocks];
|
||||
I4 frag_b_quant[2][b_thread_vecs];
|
||||
FragC frag_c[thread_m_blocks][4][2];
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
FragS frag_s[2][4]; // No act-order
|
||||
FragS act_frag_s[2][4][4]; // For act-order
|
||||
int frag_qzp[2][num_ints_per_thread]; // Zero-points
|
||||
FragZP frag_zp; // Zero-points in fp16
|
||||
|
||||
// Zero accumulators.
|
||||
auto zero_accums = [&]() {
|
||||
@ -777,6 +946,28 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if constexpr (has_zp && group_blocks != -1) {
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
if constexpr (group_blocks >= thread_k_blocks) {
|
||||
// Only fetch zero-points if this tile starts a new group
|
||||
if (pipe % (group_blocks / thread_k_blocks) == 0) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < zp_tb_groups; i++) {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
|
||||
&zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
zp_gl_rd += zp_gl_rd_delta;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Insert a fence even when we are winding down the pipeline to ensure that
|
||||
@ -784,6 +975,12 @@ __global__ void Marlin(
|
||||
cp_async_fence();
|
||||
};
|
||||
|
||||
auto fetch_zp_to_shared = [&]() {
|
||||
if (zp_sh_wr_pred) {
|
||||
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
|
||||
}
|
||||
};
|
||||
|
||||
// Wait until the next thread tile has been loaded to shared memory.
|
||||
auto wait_for_stage = [&]() {
|
||||
// We only have `stages - 2` active fetches since we are double buffering
|
||||
@ -932,8 +1129,82 @@ __global__ void Marlin(
|
||||
}
|
||||
};
|
||||
|
||||
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
|
||||
// This code does not handle group_blocks == 0,
|
||||
// which signifies act_order.
|
||||
// has_zp implies AWQ, which doesn't have act_order,
|
||||
static_assert(!has_zp || group_blocks != 0);
|
||||
|
||||
if constexpr (has_zp) {
|
||||
int pipe = full_pipe % stages;
|
||||
|
||||
if constexpr (group_blocks == -1) {
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
|
||||
}
|
||||
|
||||
} else if constexpr (group_blocks >= thread_k_blocks) {
|
||||
int4* sh_zp_stage =
|
||||
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
|
||||
(pipe / (group_blocks / thread_k_blocks)));
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
}
|
||||
} else {
|
||||
int warp_id = threadIdx.x / 32;
|
||||
int n_warps = thread_n_blocks / 4;
|
||||
|
||||
int warp_row = warp_id / n_warps;
|
||||
|
||||
int cur_k = warp_row * 16;
|
||||
cur_k += k_iter_size * (k % b_sh_wr_iters);
|
||||
|
||||
int k_blocks = cur_k / 16;
|
||||
int cur_group_id = 0;
|
||||
|
||||
// Suppress bogus and persistent divide-by-zero warning
|
||||
#pragma nv_diagnostic push
|
||||
#pragma nv_diag_suppress divide_by_zero
|
||||
cur_group_id = k_blocks / group_blocks;
|
||||
#pragma nv_diagnostic pop
|
||||
|
||||
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
|
||||
|
||||
sh_zp_stage += cur_group_id * zp_sh_stride;
|
||||
|
||||
for (int i = 0; i < num_ints_per_thread; i++) {
|
||||
frag_qzp[k % 2][i] =
|
||||
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Execute the actual tensor core matmul of a sub-tile.
|
||||
auto matmul = [&](int k) {
|
||||
if constexpr (has_zp) {
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
if constexpr (num_bits == 4) {
|
||||
int zp_quant = frag_qzp[k % 2][0];
|
||||
int zp_quant_shift = zp_quant >> 8;
|
||||
frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
|
||||
frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
|
||||
|
||||
} else {
|
||||
int zp_quant_0 = frag_qzp[k % 2][0];
|
||||
int zp_quant_1 = frag_qzp[k % 2][1];
|
||||
frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
|
||||
frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
|
||||
}
|
||||
|
||||
frag_zp[0] = frag_zp_0[0];
|
||||
frag_zp[1] = frag_zp_0[1];
|
||||
frag_zp[2] = frag_zp_1[0];
|
||||
frag_zp[3] = frag_zp_1[1];
|
||||
}
|
||||
|
||||
// We have the m dimension as the inner loop in order to encourage overlapping
|
||||
// dequantization and matmul operations.
|
||||
#pragma unroll
|
||||
@ -944,16 +1215,32 @@ __global__ void Marlin(
|
||||
int b_quant = frag_b_quant[k % 2][0][j];
|
||||
int b_quant_shift = b_quant >> 8;
|
||||
|
||||
frag_b0 = dequant_4bit<scalar_t>(b_quant);
|
||||
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
|
||||
if constexpr (has_zp) {
|
||||
frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
|
||||
frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
|
||||
|
||||
} else {
|
||||
frag_b0 = dequant_4bit<scalar_t>(b_quant);
|
||||
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
|
||||
}
|
||||
|
||||
} else {
|
||||
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
|
||||
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
|
||||
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
|
||||
|
||||
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
|
||||
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
|
||||
if constexpr (has_zp) {
|
||||
frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
|
||||
frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
|
||||
} else {
|
||||
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
|
||||
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
|
||||
}
|
||||
}
|
||||
|
||||
// Apply zero-point to frag_b0
|
||||
if constexpr (has_zp) {
|
||||
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b0
|
||||
@ -967,6 +1254,11 @@ __global__ void Marlin(
|
||||
}
|
||||
}
|
||||
|
||||
// Apply zero-point to frag_b1
|
||||
if constexpr (has_zp) {
|
||||
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
|
||||
}
|
||||
|
||||
// Apply scale to frag_b1
|
||||
if constexpr (has_act_order) {
|
||||
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
|
||||
@ -1048,7 +1340,7 @@ __global__ void Marlin(
|
||||
// finally have to globally reduce over the results. As the striped
|
||||
// partitioning minimizes the number of such reductions and our outputs are
|
||||
// usually rather small, we perform this reduction serially in L2 cache.
|
||||
auto global_reduce = [&](bool first = false, bool last = false) {
|
||||
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
|
||||
// We are very careful here to reduce directly in the output buffer to
|
||||
// maximize L2 cache utilization in this step. To do this, we write out
|
||||
// results in FP16 (but still reduce with FP32 compute).
|
||||
@ -1109,6 +1401,53 @@ __global__ void Marlin(
|
||||
}
|
||||
};
|
||||
|
||||
// Globally reduce over threadblocks that compute the same column block.
|
||||
// We use a tmp C buffer to reduce in full fp32 precision.
|
||||
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
|
||||
constexpr int tb_m = thread_m_blocks * 16;
|
||||
constexpr int tb_n = thread_n_blocks * 16;
|
||||
|
||||
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
|
||||
|
||||
constexpr int active_threads = 32 * thread_n_blocks / 4;
|
||||
bool is_th_active = threadIdx.x < active_threads;
|
||||
|
||||
int par_offset = c_size * n_tiles * par_id;
|
||||
int slice_offset = c_size * slice_col;
|
||||
|
||||
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
|
||||
constexpr int th_size = num_floats * sizeof(float) / 16;
|
||||
|
||||
int c_cur_offset = par_offset + slice_offset;
|
||||
|
||||
if (!is_th_active) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!first) {
|
||||
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < th_size; k++) {
|
||||
sh[threadIdx.x] =
|
||||
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
|
||||
|
||||
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
|
||||
#pragma unroll
|
||||
for (int f = 0; f < 4; f++) {
|
||||
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (!last) {
|
||||
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
|
||||
#pragma unroll
|
||||
for (int k = 0; k < th_size; k++) {
|
||||
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Write out the reduce final result in the correct layout. We only actually
|
||||
// reshuffle matrix fragments in this step, the reduction above is performed
|
||||
// in fragment layout.
|
||||
@ -1189,6 +1528,12 @@ __global__ void Marlin(
|
||||
}
|
||||
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
|
||||
}
|
||||
|
||||
if constexpr (has_zp && group_blocks == -1) {
|
||||
if (i == 0) {
|
||||
fetch_zp_to_shared();
|
||||
}
|
||||
}
|
||||
fetch_to_shared(i, i, i < slice_iters);
|
||||
}
|
||||
|
||||
@ -1197,6 +1542,7 @@ __global__ void Marlin(
|
||||
init_same_group(0);
|
||||
fetch_to_registers(0, 0);
|
||||
fetch_scales_to_registers(0, 0);
|
||||
fetch_zp_to_registers(0, 0);
|
||||
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
|
||||
slice_k_start_shared_fetch += tb_k * (stages - 1);
|
||||
};
|
||||
@ -1217,6 +1563,7 @@ __global__ void Marlin(
|
||||
for (int k = 0; k < b_sh_wr_iters; k++) {
|
||||
fetch_to_registers(k + 1, pipe % stages);
|
||||
fetch_scales_to_registers(k + 1, pipe);
|
||||
fetch_zp_to_registers(k + 1, pipe);
|
||||
if (k == b_sh_wr_iters - 2) {
|
||||
fetch_to_shared((pipe + stages - 1) % stages, pipe,
|
||||
slice_iters >= stages);
|
||||
@ -1325,7 +1672,11 @@ __global__ void Marlin(
|
||||
if (slice_count > 1) { // only globally reduce if there is more than one
|
||||
// block in a slice
|
||||
barrier_acquire(&locks[slice_col], slice_idx);
|
||||
global_reduce(slice_idx == 0, last);
|
||||
if (use_fp32_reduce) {
|
||||
global_reduce_fp32(slice_idx == 0, last);
|
||||
} else {
|
||||
global_reduce_fp16(slice_idx == 0, last);
|
||||
}
|
||||
barrier_release(&locks[slice_col], last);
|
||||
}
|
||||
if (last) // only the last block in a slice actually writes the result
|
||||
@ -1354,6 +1705,7 @@ __global__ void Marlin(
|
||||
|
||||
} else {
|
||||
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
|
||||
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
|
||||
}
|
||||
|
||||
start_pipes();
|
||||
@ -1363,22 +1715,24 @@ __global__ void Marlin(
|
||||
}
|
||||
|
||||
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
|
||||
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
|
||||
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
|
||||
NUM_THREADS) \
|
||||
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
|
||||
thread_n_blocks == THREAD_N_BLOCKS && \
|
||||
thread_k_blocks == THREAD_K_BLOCKS && \
|
||||
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
|
||||
num_threads == NUM_THREADS) { \
|
||||
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
|
||||
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
|
||||
cudaFuncSetAttribute( \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
GROUP_BLOCKS>, \
|
||||
HAS_ZP, GROUP_BLOCKS>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
|
||||
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
|
||||
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
|
||||
prob_k, locks); \
|
||||
HAS_ZP, GROUP_BLOCKS> \
|
||||
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
|
||||
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
|
||||
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
|
||||
}
|
||||
|
||||
typedef struct {
|
||||
@ -1517,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
|
||||
return true;
|
||||
}
|
||||
|
||||
int determine_reduce_max_m(int prob_m, int max_par) {
|
||||
constexpr int tile_m_size = 16;
|
||||
|
||||
if (prob_m <= tile_m_size) {
|
||||
return tile_m_size;
|
||||
|
||||
} else if (prob_m <= tile_m_size * 2) {
|
||||
return tile_m_size * 2;
|
||||
|
||||
} else if (prob_m <= tile_m_size * 3) {
|
||||
return tile_m_size * 3;
|
||||
|
||||
} else if (prob_m <= tile_m_size * 4) {
|
||||
return tile_m_size * 4;
|
||||
|
||||
} else {
|
||||
int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
|
||||
return tile_m_size * 4 * cur_par;
|
||||
}
|
||||
}
|
||||
|
||||
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
||||
int num_bits, int group_size,
|
||||
bool has_act_order, bool is_k_full,
|
||||
@ -1548,44 +1923,77 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
|
||||
return exec_config_t{0, {-1, -1, -1}};
|
||||
}
|
||||
|
||||
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
|
||||
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
|
||||
|
||||
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
|
||||
\
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
|
||||
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
|
||||
|
||||
template <typename scalar_t>
|
||||
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
||||
void* g_idx, void* perm, void* a_tmp, int prob_m,
|
||||
int prob_n, int prob_k, void* workspace, int num_bits,
|
||||
bool has_act_order, bool is_k_full, int num_groups,
|
||||
int group_size, int dev, cudaStream_t stream, int thread_k,
|
||||
int thread_n, int sms, int max_par) {
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
|
||||
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
|
||||
int prob_n, int prob_k, void* workspace,
|
||||
vllm::ScalarType const& q_type, bool has_act_order,
|
||||
bool is_k_full, bool has_zp, int num_groups, int group_size,
|
||||
int dev, cudaStream_t stream, int thread_k, int thread_n,
|
||||
int sms, int max_par, bool use_fp32_reduce) {
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4 || q_type == vllm::kU8,
|
||||
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
|
||||
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
q_type.str());
|
||||
}
|
||||
|
||||
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
|
||||
", ", prob_n, ", ", prob_k, "]");
|
||||
|
||||
// TODO: remove alias when we start supporting other 8bit types
|
||||
int num_bits = q_type.size_bits();
|
||||
int tot_m = prob_m;
|
||||
int tot_m_blocks = div_ceil(tot_m, 16);
|
||||
int pad = 16 * tot_m_blocks - tot_m;
|
||||
@ -1664,7 +2072,9 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
||||
const int4* A_ptr = (const int4*)A;
|
||||
const int4* B_ptr = (const int4*)B;
|
||||
int4* C_ptr = (int4*)C;
|
||||
int4* C_tmp_ptr = (int4*)C_tmp;
|
||||
const int4* s_ptr = (const int4*)s;
|
||||
const int4* zp_ptr = (const int4*)zp;
|
||||
const int* g_idx_ptr = (const int*)g_idx;
|
||||
const int* perm_ptr = (const int*)perm;
|
||||
int4* a_tmp_ptr = (int4*)a_tmp;
|
||||
@ -1701,28 +2111,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
||||
thread_m_blocks = exec_cfg.max_m_blocks;
|
||||
}
|
||||
|
||||
// Define kernel configurations
|
||||
if (false) {
|
||||
}
|
||||
CALL_IF(4, 32, 2, 256)
|
||||
CALL_IF(4, 16, 4, 256)
|
||||
CALL_IF(4, 8, 8, 256)
|
||||
CALL_IF(4, 8, 4, 128)
|
||||
CALL_IF(4, 4, 8, 128)
|
||||
CALL_IF(8, 32, 2, 256)
|
||||
CALL_IF(8, 16, 4, 256)
|
||||
CALL_IF(8, 8, 8, 256)
|
||||
CALL_IF(8, 8, 4, 128)
|
||||
CALL_IF(8, 4, 8, 128)
|
||||
GPTQ_CALL_IF(4, 16, 4, 256)
|
||||
GPTQ_CALL_IF(4, 8, 8, 256)
|
||||
GPTQ_CALL_IF(4, 8, 4, 128)
|
||||
GPTQ_CALL_IF(4, 4, 8, 128)
|
||||
GPTQ_CALL_IF(8, 16, 4, 256)
|
||||
GPTQ_CALL_IF(8, 8, 8, 256)
|
||||
GPTQ_CALL_IF(8, 8, 4, 128)
|
||||
GPTQ_CALL_IF(8, 4, 8, 128)
|
||||
|
||||
AWQ_CALL_IF(4, 16, 4, 256)
|
||||
AWQ_CALL_IF(4, 8, 8, 256)
|
||||
AWQ_CALL_IF(4, 8, 4, 128)
|
||||
AWQ_CALL_IF(4, 4, 8, 128)
|
||||
AWQ_CALL_IF(8, 16, 4, 256)
|
||||
AWQ_CALL_IF(8, 8, 8, 256)
|
||||
AWQ_CALL_IF(8, 8, 4, 128)
|
||||
AWQ_CALL_IF(8, 4, 8, 128)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
|
||||
str(prob_n) + ", " + str(prob_k) + "]" +
|
||||
", has_act_order = " + str(has_act_order) +
|
||||
", num_groups = " + str(num_groups) +
|
||||
", group_size = " + str(group_size) +
|
||||
", thread_m_blocks = " + str(thread_m_blocks) +
|
||||
", thread_n_blocks = " + str(thread_n_blocks) +
|
||||
", thread_k_blocks = " + str(thread_k_blocks));
|
||||
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
|
||||
", ", prob_k, "]", ", has_act_order = ", has_act_order,
|
||||
", num_groups = ", num_groups, ", group_size = ", group_size,
|
||||
", thread_m_blocks = ", thread_m_blocks,
|
||||
", thread_n_blocks = ", thread_n_blocks,
|
||||
", thread_k_blocks = ", thread_k_blocks,
|
||||
", num_bits = ", num_bits);
|
||||
}
|
||||
|
||||
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
|
||||
@ -1730,17 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor& b_scales, torch::Tensor& g_idx,
|
||||
torch::Tensor& perm, torch::Tensor& workspace,
|
||||
int64_t num_bits, int64_t size_m, int64_t size_n,
|
||||
int64_t size_k, bool is_k_full) {
|
||||
// Verify num_bits
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
int pack_factor = 32 / num_bits;
|
||||
torch::Tensor& b_scales, torch::Tensor& b_zeros,
|
||||
torch::Tensor& g_idx, torch::Tensor& perm,
|
||||
torch::Tensor& workspace,
|
||||
vllm::ScalarTypeTorchPtr const& b_q_type,
|
||||
int64_t size_m, int64_t size_n, int64_t size_k,
|
||||
bool is_k_full, bool has_zp,
|
||||
bool use_fp32_reduce) {
|
||||
if (has_zp) {
|
||||
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
|
||||
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
|
||||
b_q_type->str());
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
|
||||
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
|
||||
b_q_type->str());
|
||||
}
|
||||
|
||||
int pack_factor = 32 / b_q_type->size_bits();
|
||||
|
||||
// Verify A
|
||||
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
|
||||
@ -1749,16 +2175,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
", size_k = ", size_k);
|
||||
|
||||
// Verify B
|
||||
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", gptq_marlin::tile_size);
|
||||
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
|
||||
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_size = ", marlin::tile_size);
|
||||
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
|
||||
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
|
||||
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
|
||||
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
|
||||
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
|
||||
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
|
||||
"b_q_weight.size(1) = ", b_q_weight.size(1),
|
||||
" is not divisible by tile_size = ", gptq_marlin::tile_size);
|
||||
int actual_size_n =
|
||||
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
|
||||
" is not divisible by tile_size = ", marlin::tile_size);
|
||||
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
|
||||
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
|
||||
", actual_size_n = ", actual_size_n);
|
||||
|
||||
@ -1772,6 +2197,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
|
||||
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
|
||||
|
||||
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
|
||||
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
|
||||
|
||||
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
|
||||
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
|
||||
|
||||
@ -1784,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
torch::Tensor c = torch::empty({size_m, size_n}, options);
|
||||
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
|
||||
|
||||
// Alloc C tmp buffer that is going to be used for the global reduce
|
||||
int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
|
||||
int reduce_n = size_n;
|
||||
auto options_fp32 =
|
||||
torch::TensorOptions().dtype(at::kFloat).device(a.device());
|
||||
if (!use_fp32_reduce) {
|
||||
reduce_max_m = 0;
|
||||
reduce_n = 0;
|
||||
}
|
||||
torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
|
||||
|
||||
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
|
||||
// auto -1)
|
||||
int thread_k = -1;
|
||||
@ -1805,8 +2244,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
int group_size = -1;
|
||||
bool has_act_order = g_idx.size(0) != 0;
|
||||
|
||||
int b_rank = b_scales.sizes().size();
|
||||
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
|
||||
int rank = b_scales.sizes().size();
|
||||
TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
|
||||
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
|
||||
" is not size_n = ", size_n);
|
||||
num_groups = b_scales.size(0);
|
||||
@ -1832,34 +2271,45 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
|
||||
}
|
||||
}
|
||||
|
||||
// Verify b_zeros
|
||||
if (has_zp) {
|
||||
int rank = b_zeros.sizes().size();
|
||||
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
|
||||
TORCH_CHECK(b_zeros.size(0) == num_groups,
|
||||
"b_zeros dim 0 = ", b_zeros.size(0),
|
||||
" is not num_groups = ", num_groups);
|
||||
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
|
||||
"b_zeros dim 1 = ", b_scales.size(1),
|
||||
" is not size_n / pack_factor = ", size_n / pack_factor);
|
||||
}
|
||||
|
||||
// Verify workspace size
|
||||
TORCH_CHECK(
|
||||
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
|
||||
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
|
||||
int min_workspace_size =
|
||||
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
|
||||
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
|
||||
", is not divisible by min_thread_n = ", marlin::min_thread_n);
|
||||
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
|
||||
TORCH_CHECK(workspace.numel() >= min_workspace_size,
|
||||
"workspace.numel = ", workspace.numel(),
|
||||
" is below min_workspace_size = ", min_workspace_size);
|
||||
|
||||
int dev = a.get_device();
|
||||
if (a.scalar_type() == at::ScalarType::Half) {
|
||||
gptq_marlin::marlin_mm_f16i4<half>(
|
||||
marlin::marlin_mm<half>(
|
||||
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
|
||||
b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
|
||||
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
|
||||
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
|
||||
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
|
||||
thread_n, sms, gptq_marlin::max_par);
|
||||
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
|
||||
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
|
||||
marlin::marlin_mm<nv_bfloat16>(
|
||||
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
|
||||
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
|
||||
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
|
||||
size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
|
||||
is_k_full, num_groups, group_size, dev,
|
||||
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
|
||||
gptq_marlin::max_par);
|
||||
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
|
||||
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
|
||||
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
|
||||
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
|
||||
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
|
||||
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
|
||||
} else {
|
||||
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
|
||||
}
|
||||
|
@ -1,23 +1,16 @@
|
||||
#include "gptq_marlin.cuh"
|
||||
|
||||
namespace gptq_marlin {
|
||||
|
||||
static constexpr int repack_stages = 8;
|
||||
|
||||
static constexpr int repack_threads = 256;
|
||||
|
||||
static constexpr int tile_k_size = tile_size;
|
||||
static constexpr int tile_n_size = tile_k_size * 4;
|
||||
#include "marlin.cuh"
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void marlin_repack_kernel(
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
|
||||
#else
|
||||
|
||||
namespace marlin {
|
||||
|
||||
template <int const num_threads, int const num_bits, bool const has_perm>
|
||||
__global__ void marlin_repack_kernel(
|
||||
__global__ void gptq_marlin_repack_kernel(
|
||||
uint32_t const* __restrict__ b_q_weight_ptr,
|
||||
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
|
||||
int size_k, int size_n) {
|
||||
@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
||||
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
|
||||
NUM_BITS, HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
#define CALL_IF(NUM_BITS, HAS_PERM) \
|
||||
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
|
||||
cudaFuncSetAttribute( \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM>, \
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
|
||||
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
|
||||
HAS_PERM> \
|
||||
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
|
||||
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
|
||||
}
|
||||
|
||||
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
int64_t size_k, int64_t size_n,
|
||||
int64_t num_bits) {
|
||||
// Verify compatibility with marlin tile of 16x64
|
||||
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
|
||||
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
|
||||
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
|
||||
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
|
||||
" is not divisible by tile_k_size = ", marlin::tile_k_size);
|
||||
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
|
||||
" is not divisible by tile_n_size = ", marlin::tile_n_size);
|
||||
|
||||
TORCH_CHECK(num_bits == 4 || num_bits == 8,
|
||||
"num_bits must be 4 or 8. Got = ", num_bits);
|
||||
@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
|
||||
auto options = torch::TensorOptions()
|
||||
.dtype(b_q_weight.dtype())
|
||||
.device(b_q_weight.device());
|
||||
torch::Tensor out =
|
||||
torch::empty({size_k / gptq_marlin::tile_size,
|
||||
size_n * gptq_marlin::tile_size / pack_factor},
|
||||
options);
|
||||
torch::Tensor out = torch::empty(
|
||||
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
|
||||
options);
|
||||
|
||||
// Detect if there is act_order
|
||||
bool has_perm = perm.size(0) != 0;
|
||||
|
@ -9,7 +9,9 @@
|
||||
#include <cuda_runtime.h>
|
||||
#include <iostream>
|
||||
|
||||
namespace gptq_marlin {
|
||||
namespace marlin {
|
||||
|
||||
// Marlin params
|
||||
|
||||
// 8 warps are a good choice since every SM has 4 schedulers and having more
|
||||
// than 1 warp per schedule allows some more latency hiding. At the same time,
|
||||
@ -25,6 +27,15 @@ static constexpr int min_thread_k = 64;
|
||||
static constexpr int tile_size = 16;
|
||||
static constexpr int max_par = 16;
|
||||
|
||||
// Repack params
|
||||
static constexpr int repack_stages = 8;
|
||||
|
||||
static constexpr int repack_threads = 256;
|
||||
|
||||
static constexpr int tile_k_size = tile_size;
|
||||
static constexpr int tile_n_size = tile_k_size * 4;
|
||||
|
||||
// Helpers
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
@ -73,4 +84,4 @@ __device__ inline void cp_async_wait() {
|
||||
|
||||
#endif
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
@ -1,11 +1,11 @@
|
||||
|
||||
#ifndef _data_types_cuh
|
||||
#define _data_types_cuh
|
||||
#include "gptq_marlin.cuh"
|
||||
#include "marlin.cuh"
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
namespace gptq_marlin {
|
||||
namespace marlin {
|
||||
|
||||
template <typename scalar_t>
|
||||
class ScalarType {};
|
||||
@ -23,6 +23,7 @@ class ScalarType<half> {
|
||||
using FragB = Vec<half2, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<half2, 1>;
|
||||
using FragZP = Vec<half2, 4>;
|
||||
|
||||
static __device__ float inline num2float(const half x) {
|
||||
return __half2float(x);
|
||||
@ -51,6 +52,7 @@ class ScalarType<nv_bfloat16> {
|
||||
using FragB = Vec<nv_bfloat162, 2>;
|
||||
using FragC = Vec<float, 4>;
|
||||
using FragS = Vec<nv_bfloat162, 1>;
|
||||
using FragZP = Vec<nv_bfloat162, 4>;
|
||||
|
||||
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
|
||||
static __device__ float inline num2float(const nv_bfloat16 x) {
|
||||
@ -72,6 +74,6 @@ class ScalarType<nv_bfloat16> {
|
||||
#endif
|
||||
};
|
||||
|
||||
} // namespace gptq_marlin
|
||||
} // namespace marlin
|
||||
|
||||
#endif
|
32
csrc/quantization/marlin/dense/common/base.h
Normal file
32
csrc/quantization/marlin/dense/common/base.h
Normal file
@ -0,0 +1,32 @@
|
||||
/*
|
||||
* Modified by HandH1998
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
|
||||
|
||||
// Instances of `Vec` are used to organize groups of >>registers<<, as needed
|
||||
// for instance as inputs to tensor core operations. Consequently, all
|
||||
// corresponding index accesses must be compile-time constants, which is why we
|
||||
// extensively use `#pragma unroll` throughout the kernel code to guarantee
|
||||
// this.
|
||||
template <typename T, int n>
|
||||
struct Vec {
|
||||
T elems[n];
|
||||
__device__ T& operator[](int i) { return elems[i]; }
|
||||
};
|
89
csrc/quantization/marlin/dense/common/mem.h
Normal file
89
csrc/quantization/marlin/dense/common/mem.h
Normal file
@ -0,0 +1,89 @@
|
||||
/*
|
||||
* Modified by HandH1998
|
||||
* Modified by Neural Magic
|
||||
* Copyright (C) Marlin.2024 Elias Frantar
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
// Predicated asynchronous global->shared copy; used for inputs A where we apply
|
||||
// predication to handle batchsizes that are not multiples of 16.
|
||||
__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
|
||||
bool pred = true) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" .reg .pred p;\n"
|
||||
" setp.ne.b32 p, %0, 0;\n"
|
||||
" @p cp.async.cg.shared.global [%1], [%2], %3;\n"
|
||||
"}\n" ::"r"((int)pred),
|
||||
"r"(smem), "l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Asynchronous global->shared copy
|
||||
__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
|
||||
const int BYTES = 16;
|
||||
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
|
||||
asm volatile(
|
||||
"{\n"
|
||||
" cp.async.cg.shared.global [%0], [%1], %2;\n"
|
||||
"}\n" ::"r"(smem),
|
||||
"l"(glob_ptr), "n"(BYTES));
|
||||
}
|
||||
|
||||
// Async copy fence.
|
||||
__device__ inline void cp_async_fence() {
|
||||
asm volatile("cp.async.commit_group;\n" ::);
|
||||
}
|
||||
|
||||
// Wait until at most `n` async copy stages are still pending.
|
||||
template <int n>
|
||||
__device__ inline void cp_async_wait() {
|
||||
asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
|
||||
}
|
||||
|
||||
// Wait until barrier reaches `count`, then lock for current threadblock.
|
||||
__device__ inline void barrier_acquire(int* lock, int count) {
|
||||
if (threadIdx.x == 0) {
|
||||
int state = -1;
|
||||
do
|
||||
// Guarantee that subsequent writes by this threadblock will be visible
|
||||
// globally.
|
||||
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
|
||||
: "=r"(state)
|
||||
: "l"(lock));
|
||||
while (state != count);
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
|
||||
// Release barrier and increment visitation count.
|
||||
__device__ inline void barrier_release(int* lock, bool reset = false) {
|
||||
__syncthreads();
|
||||
if (threadIdx.x == 0) {
|
||||
if (reset) {
|
||||
lock[0] = 0;
|
||||
return;
|
||||
}
|
||||
int val = 1;
|
||||
// Make sure that all writes since acquiring this barrier are visible
|
||||
// globally, while releasing the barrier.
|
||||
asm volatile("fence.acq_rel.gpu;\n");
|
||||
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
|
||||
:
|
||||
: "l"(lock), "r"(val));
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user