mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-20 14:53:52 +08:00
Compare commits
495 Commits
minus_x
...
v0.10.0rc2
Author | SHA1 | Date | |
---|---|---|---|
6d8d0a24c0 | |||
11ef7a611e | |||
dc2f159f8a | |||
d5b981f8b1 | |||
eec6942014 | |||
fd48d99ffd | |||
f8c15c4efb | |||
aa08a954f9 | |||
13e4ee1dc3 | |||
772ce5af97 | |||
63d92abb7c | |||
11599b0e1f | |||
f3137cdd81 | |||
82ec66f514 | |||
78c13e30e1 | |||
5c9b807b34 | |||
14bf19e39f | |||
4ac7713e32 | |||
8560a5b258 | |||
316b1bf706 | |||
7c734ee09b | |||
f59ec35b7f | |||
2671334d45 | |||
2cc5016a19 | |||
6929f8b437 | |||
32ec9e2f2a | |||
accac82928 | |||
23637dcdef | |||
6364af92f8 | |||
7aaa2bd5a8 | |||
2f5c14de6a | |||
f002e9a870 | |||
a1f3610fc6 | |||
4ecedd1806 | |||
107111a859 | |||
2dec7c1a5d | |||
08d2bd78da | |||
4f76a05f4f | |||
f154bb9ff0 | |||
3ec7170ff1 | |||
c401c64b4c | |||
b77c7d327f | |||
35bc8bd5fb | |||
4594fc3b28 | |||
ae268b6326 | |||
35366ae57c | |||
2226d5bd85 | |||
44554a0068 | |||
226b452a20 | |||
f38ee34a0a | |||
b194557a6c | |||
774d0c014b | |||
2c8db17cfd | |||
4fb56914c5 | |||
0df4d9b06b | |||
ed25054577 | |||
10904e6d75 | |||
a32237665d | |||
bc8a8ce5ec | |||
32142b3c62 | |||
82b8027be6 | |||
3779eb8c81 | |||
9e23ad9655 | |||
e69a92a1ce | |||
8425f785ad | |||
c17231e827 | |||
6e5b5ca580 | |||
488d8a986a | |||
af376ca19d | |||
e7b2042681 | |||
90f1e55421 | |||
5e70dcd6e6 | |||
25d585ab7b | |||
8d0a01a5f2 | |||
0ec82edda5 | |||
005ae9be6c | |||
29d1ffc5b4 | |||
304dce7ec0 | |||
6ece16c4fe | |||
a0e827e07c | |||
a15a50fc17 | |||
6dda13c86b | |||
6b46c4b653 | |||
d97841078b | |||
e6b90a2805 | |||
be54a951a3 | |||
042af0c8d3 | |||
378d33c392 | |||
940af1f03a | |||
92615d7fe8 | |||
8188196a1c | |||
7ba34b1241 | |||
9499e26e2a | |||
51ba839555 | |||
d1fb65bde3 | |||
3a1d8940ae | |||
2b504eb770 | |||
10eb24cc91 | |||
2e8cbb58f3 | |||
752c6ade2e | |||
881e3cbe3b | |||
9f414a12ad | |||
6a971ed692 | |||
da6579bf41 | |||
c81259d33a | |||
e3a0e43d7f | |||
b3d82108e7 | |||
6d0734c562 | |||
7d94577138 | |||
59f935300c | |||
18e519ec86 | |||
1eaff27815 | |||
cf8cc32674 | |||
3a2cb2649d | |||
3e04107d97 | |||
37bd8d6e4c | |||
468e2400fe | |||
dcc6cfb991 | |||
dd572c0ab3 | |||
9ffe905a41 | |||
9a9fda1423 | |||
466e878f2a | |||
217937221b | |||
5782581acf | |||
0f199f197b | |||
b2eb2b5ad7 | |||
21274ab476 | |||
ed8cbfedf8 | |||
45badd05d0 | |||
4adc66f64d | |||
55ad648715 | |||
5895afd780 | |||
ca4eb82bcb | |||
ba2dfbb0c2 | |||
1bf65138f6 | |||
54cf1cae62 | |||
5780121c95 | |||
c7d8724e78 | |||
b38baabcf9 | |||
89cab4d01f | |||
b9a21e9173 | |||
c4e3b12524 | |||
8dfb45ca33 | |||
8a8fc94639 | |||
4de7146351 | |||
ac9fb732a5 | |||
a3a6c695f4 | |||
90bd2ab6e3 | |||
9fb2d22032 | |||
2d6a38209b | |||
89e3c4e9b4 | |||
fe8a2c544a | |||
4ef00b5cac | |||
5a7fb3ab9e | |||
11dfdf21bf | |||
fdc5b43d20 | |||
c5b8b5953a | |||
4fcef49ec4 | |||
8a4e5c5f3c | |||
76b494444f | |||
28a6d5423d | |||
58760e12b1 | |||
a50d918225 | |||
c9ba8104ed | |||
4e7dfbe7b4 | |||
72ad273582 | |||
01513a334a | |||
ac2bf41e53 | |||
a931b4cdcf | |||
a0f8a79646 | |||
18bdcf4113 | |||
1c3198b6c4 | |||
260127ea54 | |||
d0dc4cfca4 | |||
d31a647124 | |||
85431bd9ad | |||
c11013db8b | |||
1eb2b9c102 | |||
6ebf313790 | |||
cfbcb9ed87 | |||
76ddeff293 | |||
f46098335b | |||
e9534c7202 | |||
7976446015 | |||
fcb9f879c1 | |||
3ed94f9d0a | |||
fa839565f2 | |||
75a99b98bf | |||
b5c3b68359 | |||
6cbc4d4bea | |||
153c6f1e61 | |||
34cda778a0 | |||
30800b01c2 | |||
10be209493 | |||
19c863068b | |||
f29fd8a7f8 | |||
ed10f3cea1 | |||
b637e9dcb8 | |||
1e36c8687e | |||
5bac61362b | |||
313ae8c16a | |||
c847e34b39 | |||
e7e3e6d263 | |||
4ffd963fa0 | |||
56fe4bedd6 | |||
d91278181d | |||
20149d84d9 | |||
3534c39a20 | |||
c586b55667 | |||
33d560001e | |||
f148c44c6a | |||
235bfd5dfe | |||
68d28e37b0 | |||
37a7d5d74a | |||
d4d309409f | |||
85bd6599e4 | |||
91b3d190ae | |||
fc017915f5 | |||
9ad0a4588b | |||
016b8d1b7f | |||
80305c1b24 | |||
37e2ecace2 | |||
054c8657e3 | |||
d4170fad39 | |||
946aadb4a0 | |||
bcdfb2a330 | |||
ba8c300018 | |||
8cdc371217 | |||
61e20828da | |||
55e1c66da5 | |||
86f3ac21ce | |||
149f2435a5 | |||
c0569dbc82 | |||
8bb43b9c9e | |||
559756214b | |||
6d0cf239c6 | |||
3fc964433a | |||
0caf61c08a | |||
667624659b | |||
38efa28278 | |||
e8cc53af5e | |||
a4851cfe68 | |||
9887e8ec50 | |||
f326ab9c88 | |||
dcf2a5e208 | |||
1e9438e0b0 | |||
697ef765ee | |||
a99b9f7dee | |||
c488b928a7 | |||
2c7fa47161 | |||
88fc8a97e3 | |||
66f6fbd393 | |||
8632e831ba | |||
4bbfc36b16 | |||
80d38b8ac8 | |||
211b6a6113 | |||
247102f07f | |||
bd4c1e6fdb | |||
99b4f080d8 | |||
020f58abcd | |||
c1acd6d7d4 | |||
3b3b778d4a | |||
42d440c22b | |||
f45a332886 | |||
6e2c176e1f | |||
a86754a12b | |||
c2a2f19aba | |||
2c11a738b3 | |||
b639327ad9 | |||
4afe687a82 | |||
5de8d9f111 | |||
c1c8ca57ff | |||
a3a5a47e48 | |||
fb25e95688 | |||
0d4891cd03 | |||
f56d2996ca | |||
147afb448b | |||
3c7d942da8 | |||
890323dc1b | |||
01cae37713 | |||
11c0198615 | |||
b1235c3e10 | |||
44d02f54db | |||
a8593237c0 | |||
fc0f41d10a | |||
7b828e30d5 | |||
5f0af36af5 | |||
0d21b2664c | |||
9907fc4494 | |||
d47661f0cd | |||
53fa457391 | |||
6fb162447b | |||
66177189c5 | |||
b4f0b5f9aa | |||
cbd14ed561 | |||
7bd4c37ae7 | |||
8020e98c9f | |||
762be26a8e | |||
6a9e6b2abf | |||
5d09152ff1 | |||
31d5c1797f | |||
35514b682a | |||
e2de455c34 | |||
5b032352cc | |||
922f316441 | |||
5923ab9524 | |||
0cf893cae1 | |||
cf75cd2098 | |||
b854321ffe | |||
5b6fe23d05 | |||
f0c98cae27 | |||
574ad60db9 | |||
fdadb6f43a | |||
41060c6e08 | |||
3de2ed767f | |||
299252ea82 | |||
d6902ce79f | |||
5e53c89a74 | |||
c66e38ea4c | |||
251595368f | |||
4bed167768 | |||
b140416abf | |||
5b8366b61a | |||
c7753a9809 | |||
4b9a9435bb | |||
3482fd7e4e | |||
77f77a951e | |||
1a4f35e2ea | |||
be1e128dfb | |||
65393ee064 | |||
dc221ad72d | |||
7571a4a7e5 | |||
f67d986dd1 | |||
cc876d0f29 | |||
fdfd409f8f | |||
ffbcc9e757 | |||
59389c927b | |||
8f2720def9 | |||
ad6c2e1a0b | |||
49e8c7ea25 | |||
805d62ca88 | |||
b7d9e9416f | |||
7c12a765aa | |||
cd587c93ef | |||
332d4cb17b | |||
bf03ff3575 | |||
47043eb678 | |||
31b96d1c64 | |||
e59ba9e142 | |||
403b481573 | |||
138709f8d1 | |||
0bbac1c1b4 | |||
a3e4e85ece | |||
eb58f5953d | |||
4ac9c33f78 | |||
efe73d0575 | |||
853487bc1b | |||
9ff2af6d2b | |||
70ca5484f5 | |||
5358cce5ff | |||
2155e95ef1 | |||
f95570a52d | |||
b6e7e3d58f | |||
e760fcef22 | |||
6bbf1795b7 | |||
9e0ef888f0 | |||
97abeb1daa | |||
34dad19e7b | |||
6db31e7a27 | |||
977180c912 | |||
c40784c794 | |||
baed180aa0 | |||
0b407479ef | |||
5eaf570050 | |||
d8ee5a2ca4 | |||
b9fca83256 | |||
32dffc2772 | |||
c438183e99 | |||
baba0389f7 | |||
c6c22f16d3 | |||
dd382e0fe3 | |||
849590a2a7 | |||
a4c23314c0 | |||
b942c094e3 | |||
b4bab81660 | |||
b91cb3fa5c | |||
71d1d75b7a | |||
72d14d0eed | |||
e34d130c16 | |||
7721ef1786 | |||
8369b7c2a9 | |||
3eb4ad53f3 | |||
90a2769f20 | |||
e60d422f19 | |||
0d914c81a2 | |||
6e428cdd7a | |||
93b9d9f499 | |||
af107d5a0e | |||
31c5d0a1b7 | |||
afb7cff1b9 | |||
d2e841a10a | |||
14601f5fba | |||
042d131f39 | |||
8e807cdfa4 | |||
e601efcb10 | |||
22dd9c2730 | |||
a6d795d593 | |||
a37d75bbec | |||
edd270bc78 | |||
110df74332 | |||
1ad69e8375 | |||
b8a498c9b2 | |||
923147b5e8 | |||
45877ef740 | |||
6e4bef1bea | |||
4ff79a136e | |||
448acad31e | |||
eb0b2d2f08 | |||
3112271f6e | |||
1fd471e957 | |||
2c5ebec064 | |||
2e610deb72 | |||
6e2c19ce22 | |||
47db8c2c15 | |||
462b269280 | |||
c18b3b8e8b | |||
9528e3a05e | |||
9fb52e523a | |||
e202dd2736 | |||
43813e6361 | |||
cede942b87 | |||
fe1e924811 | |||
4548c03c50 | |||
40b86aa05e | |||
432870829d | |||
f73d02aadc | |||
c5ebe040ac | |||
8d763cb891 | |||
cf4cd53982 | |||
32c9be2200 | |||
8aeaa910a2 | |||
906e05d840 | |||
ef9a2990ae | |||
7e90870491 | |||
d3f05c9248 | |||
c108781c85 | |||
3d184b95b8 | |||
2f35a022e6 | |||
ffe00ef77a | |||
5561681d04 | |||
fbd62d8750 | |||
2e26f9156a | |||
9e5452ee34 | |||
0e3fe896e2 | |||
1caca5a589 | |||
783921d889 | |||
4a98edff1f | |||
a7bab0c9e5 | |||
25950dca9b | |||
a4113b035c | |||
7e1665b089 | |||
8d1096e7db | |||
8d775dd30a | |||
78fe77534b | |||
2f2fcb31b8 | |||
1dba2c4ebe | |||
71d6de3a26 | |||
536fd33003 | |||
619b9f5c7e | |||
d1b689c445 | |||
9854dc9040 | |||
ff5c60fad8 | |||
6f1229f91d | |||
1819fbda63 | |||
7f0367109e | |||
fb14d53cf6 | |||
b024a42e93 | |||
cb97f2bfc5 | |||
359200f6ac | |||
220aee902a | |||
67d25eca05 | |||
363528de27 | |||
4ff61ababa | |||
0ec3779df7 | |||
b616f6a53d | |||
2e25bb12a8 | |||
9965c47d0d | |||
059d4cdb49 | |||
bdb84e26b0 | |||
3dd359147d | |||
657f2f301a | |||
a1aafc827a | |||
139508a418 | |||
d265414dbc | |||
48fb076cbc |
@ -46,6 +46,6 @@ while getopts "m:b:l:f:t:" OPT; do
|
||||
done
|
||||
|
||||
lm_eval --model vllm \
|
||||
--model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,distributed_executor_backend=ray,trust_remote_code=true,max_model_len=4096" \
|
||||
--model_args "pretrained=$MODEL,tensor_parallel_size=$TP_SIZE,add_bos_token=true,trust_remote_code=true,max_model_len=4096" \
|
||||
--tasks gsm8k --num_fewshot "$FEWSHOT" --limit "$LIMIT" \
|
||||
--batch_size "$BATCH_SIZE"
|
||||
|
@ -18,12 +18,14 @@ RTOL = 0.08
|
||||
|
||||
def launch_lm_eval(eval_config, tp_size):
|
||||
trust_remote_code = eval_config.get("trust_remote_code", False)
|
||||
max_model_len = eval_config.get("max_model_len", 4096)
|
||||
model_args = (
|
||||
f"pretrained={eval_config['model_name']},"
|
||||
f"tensor_parallel_size={tp_size},"
|
||||
f"enforce_eager=true,"
|
||||
f"add_bos_token=true,"
|
||||
f"trust_remote_code={trust_remote_code}"
|
||||
f"trust_remote_code={trust_remote_code},"
|
||||
f"max_model_len={max_model_len}"
|
||||
)
|
||||
results = lm_eval.simple_evaluate(
|
||||
model="vllm",
|
||||
|
@ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc
|
||||
|
||||
## Performance benchmark quick overview
|
||||
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models.
|
||||
|
||||
**Benchmarking Duration**: about 1hr.
|
||||
|
||||
@ -31,13 +31,27 @@ Performance benchmark will be triggered when:
|
||||
- A PR being merged into vllm.
|
||||
- Every commit for those PRs with `perf-benchmarks` label AND `ready` label.
|
||||
|
||||
Manually Trigger the benchmark
|
||||
|
||||
```bash
|
||||
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
|
||||
```
|
||||
|
||||
Runtime environment variables:
|
||||
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
|
||||
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
|
||||
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
|
||||
- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file).
|
||||
- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
|
||||
- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
|
||||
|
||||
Nightly benchmark will be triggered when:
|
||||
- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.
|
||||
|
||||
## Performance benchmark details
|
||||
|
||||
See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
|
||||
|
||||
> NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead.
|
||||
### Latency test
|
||||
|
||||
Here is an example of one test inside `latency-tests.json`:
|
||||
@ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running.
|
||||
The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file.
|
||||
The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking.
|
||||
|
||||
The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`.
|
||||
When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`.
|
||||
`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
|
||||
|
||||
Here is an example using the script to compare result_a and result_b without detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`
|
||||
|
||||
| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
|
||||
|----|----------------------------------------|----------------------------------------|----------|
|
||||
| 0 | 142.633982 | 156.526018 | 1.097396 |
|
||||
| 1 | 241.620334 | 294.018783 | 1.216863 |
|
||||
| 2 | 218.298905 | 262.664916 | 1.203235 |
|
||||
| 3 | 242.743860 | 299.816190 | 1.235113 |
|
||||
|
||||
Here is an example using the script to compare result_a and result_b with detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`
|
||||
| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
|
||||
|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------|
|
||||
| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
|
||||
| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
|
||||
| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
|
||||
| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
|
||||
| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |
|
||||
|
||||
## Nightly test details
|
||||
|
||||
See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines.
|
||||
|
@ -4,7 +4,8 @@
|
||||
- Input length: 32 tokens.
|
||||
- Output length: 128 tokens.
|
||||
- Batch size: fixed (8).
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: end-to-end latency (mean, median, p99).
|
||||
|
||||
{latency_tests_markdown_table}
|
||||
@ -14,7 +15,8 @@
|
||||
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
|
||||
- Output length: the corresponding output length of these 200 prompts.
|
||||
- Batch size: dynamically determined by vllm to achieve maximum throughput.
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: throughput.
|
||||
|
||||
{throughput_tests_markdown_table}
|
||||
@ -25,12 +27,18 @@
|
||||
- Output length: the corresponding output length of these 200 prompts.
|
||||
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
|
||||
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- We also added a speculative decoding test for llama-3 70B, under QPS 2
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).
|
||||
- For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts.
|
||||
|
||||
{serving_tests_markdown_table}
|
||||
|
||||
## Platform Information
|
||||
|
||||
{platform_markdown_table}
|
||||
|
||||
## json version of the benchmarking tables
|
||||
|
||||
This section contains the data of the markdown tables above in JSON format.
|
||||
|
@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def compare_data_columns(
|
||||
files, name_column, data_column, drop_column, ignore_test_name=False
|
||||
):
|
||||
print("\ncompare_data_column: " + data_column)
|
||||
frames = []
|
||||
compare_frames = []
|
||||
for file in files:
|
||||
data_df = pd.read_json(file)
|
||||
serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
|
||||
if ignore_test_name is False:
|
||||
serving_df = serving_df.rename(columns={name_column: file + "_name"})
|
||||
frames.append(serving_df[file + "_name"])
|
||||
serving_df = serving_df.rename(columns={data_column: file})
|
||||
frames.append(serving_df[file])
|
||||
compare_frames.append(serving_df[file])
|
||||
if len(compare_frames) >= 2:
|
||||
# Compare numbers among two files
|
||||
ratio_df = compare_frames[1] / compare_frames[0]
|
||||
frames.append(ratio_df)
|
||||
compare_frames.pop(1)
|
||||
|
||||
concat_df = pd.concat(frames, axis=1)
|
||||
return concat_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-f", "--file", action="append", type=str, help="input file name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_test_name", action="store_true", help="ignore_test_name or not"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
files = args.file
|
||||
print("comparing : " + ", ".join(files))
|
||||
|
||||
drop_column = "P99"
|
||||
name_column = "Test name"
|
||||
data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"]
|
||||
html_msgs_for_data_cols = [
|
||||
"Compare Output Tokens /n",
|
||||
"Median TTFT /n",
|
||||
"Median TPOT /n",
|
||||
]
|
||||
ignore_test_name = args.ignore_test_name
|
||||
with open("perf_comparison.html", "w") as text_file:
|
||||
for i in range(len(data_cols_to_compare)):
|
||||
output_df = compare_data_columns(
|
||||
files,
|
||||
name_column,
|
||||
data_cols_to_compare[i],
|
||||
drop_column,
|
||||
ignore_test_name=ignore_test_name,
|
||||
)
|
||||
print(output_df)
|
||||
html = output_df.to_html()
|
||||
text_file.write(html_msgs_for_data_cols[i])
|
||||
text_file.write(html)
|
@ -3,9 +3,11 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import psutil
|
||||
from tabulate import tabulate
|
||||
|
||||
results_folder = Path("results/")
|
||||
@ -29,11 +31,11 @@ throughput_results = []
|
||||
throughput_results_column_mapping = {
|
||||
"test_name": "Test name",
|
||||
"gpu_type": "GPU",
|
||||
# "num_requests": "# of req.",
|
||||
# "total_num_tokens": "Total # of tokens",
|
||||
# "elapsed_time": "Elapsed time (s)",
|
||||
"num_requests": "# of req.",
|
||||
"total_num_tokens": "Total # of tokens",
|
||||
"elapsed_time": "Elapsed time (s)",
|
||||
"requests_per_second": "Tput (req/s)",
|
||||
# "tokens_per_second": "Tput (tok/s)",
|
||||
"tokens_per_second": "Tput (tok/s)",
|
||||
}
|
||||
|
||||
# serving results and the keys that will be printed into markdown
|
||||
@ -41,16 +43,18 @@ serving_results = []
|
||||
serving_column_mapping = {
|
||||
"test_name": "Test name",
|
||||
"gpu_type": "GPU",
|
||||
# "completed": "# of req.",
|
||||
"completed": "# of req.",
|
||||
"request_throughput": "Tput (req/s)",
|
||||
# "input_throughput": "Input Tput (tok/s)",
|
||||
# "output_throughput": "Output Tput (tok/s)",
|
||||
"total_token_throughput": "Total Token Tput (tok/s)",
|
||||
"output_throughput": "Output Tput (tok/s)",
|
||||
"total_input_tokens": "Total input tokens",
|
||||
"total_output_tokens": "Total output tokens",
|
||||
"mean_ttft_ms": "Mean TTFT (ms)",
|
||||
"median_ttft_ms": "Median TTFT (ms)",
|
||||
"p99_ttft_ms": "P99 TTFT (ms)",
|
||||
# "mean_tpot_ms": "Mean TPOT (ms)",
|
||||
# "median_tpot_ms": "Median",
|
||||
# "p99_tpot_ms": "P99",
|
||||
"mean_tpot_ms": "Mean TPOT (ms)",
|
||||
"median_tpot_ms": "Median",
|
||||
"p99_tpot_ms": "P99",
|
||||
"mean_itl_ms": "Mean ITL (ms)",
|
||||
"median_itl_ms": "Median ITL (ms)",
|
||||
"p99_itl_ms": "P99 ITL (ms)",
|
||||
@ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving):
|
||||
)
|
||||
|
||||
|
||||
def get_size_with_unit(bytes, suffix="B"):
|
||||
"""
|
||||
Scale bytes to its proper format
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if bytes < factor:
|
||||
return f"{bytes:.2f}{unit}{suffix}"
|
||||
bytes /= factor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# collect results
|
||||
for test_file in results_folder.glob("*.json"):
|
||||
@ -155,6 +173,27 @@ if __name__ == "__main__":
|
||||
serving_results = pd.DataFrame.from_dict(serving_results)
|
||||
throughput_results = pd.DataFrame.from_dict(throughput_results)
|
||||
|
||||
svmem = psutil.virtual_memory()
|
||||
platform_data = {
|
||||
"Physical cores": [psutil.cpu_count(logical=False)],
|
||||
"Total cores": [psutil.cpu_count(logical=True)],
|
||||
"Total Memory": [get_size_with_unit(svmem.total)],
|
||||
}
|
||||
|
||||
if util.find_spec("numa") is not None:
|
||||
from numa import info
|
||||
|
||||
platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()]
|
||||
|
||||
if util.find_spec("cpuinfo") is not None:
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]]
|
||||
|
||||
platform_results = pd.DataFrame.from_dict(
|
||||
platform_data, orient="index", columns=["Platform Info"]
|
||||
)
|
||||
|
||||
raw_results_json = results_to_json(
|
||||
latency_results, throughput_results, serving_results
|
||||
)
|
||||
@ -200,6 +239,9 @@ if __name__ == "__main__":
|
||||
throughput_md_table = tabulate(
|
||||
throughput_results, headers="keys", tablefmt="pipe", showindex=False
|
||||
)
|
||||
platform_md_table = tabulate(
|
||||
platform_results, headers="keys", tablefmt="pipe", showindex=True
|
||||
)
|
||||
|
||||
# document the result
|
||||
with open(results_folder / "benchmark_results.md", "w") as f:
|
||||
@ -211,6 +253,7 @@ if __name__ == "__main__":
|
||||
latency_tests_markdown_table=latency_md_table,
|
||||
throughput_tests_markdown_table=throughput_md_table,
|
||||
serving_tests_markdown_table=serving_md_table,
|
||||
platform_markdown_table=platform_md_table,
|
||||
benchmarking_results_in_json_string=processed_results_json,
|
||||
)
|
||||
f.write(results)
|
||||
|
@ -31,6 +31,20 @@ check_gpus() {
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_cpus() {
|
||||
# check the number of CPUs and NUMA Node and GPU type.
|
||||
declare -g numa_count=$(python3 -c "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)")
|
||||
if [[ $numa_count -gt 0 ]]; then
|
||||
echo "NUMA found."
|
||||
echo $numa_count
|
||||
else
|
||||
echo "Need at least 1 NUMA to run benchmarking."
|
||||
exit 1
|
||||
fi
|
||||
declare -g gpu_type="cpu"
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_hf_token() {
|
||||
# check if HF_TOKEN is available and valid
|
||||
if [[ -z "$HF_TOKEN" ]]; then
|
||||
@ -69,6 +83,22 @@ json2args() {
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
json2envs() {
|
||||
# transforms the JSON string to environment variables.
|
||||
# example:
|
||||
# input: { "VLLM_CPU_KVCACHE_SPACE": 5 }
|
||||
# output: VLLM_CPU_KVCACHE_SPACE=5
|
||||
local json_string=$1
|
||||
local args=$(
|
||||
echo "$json_string" | jq -r '
|
||||
to_entries |
|
||||
map((.key ) + "=" + (.value | tostring)) |
|
||||
join(" ")
|
||||
'
|
||||
)
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
wait_for_server() {
|
||||
# wait for vllm server to start
|
||||
# return 1 if vllm server crashes
|
||||
@ -158,15 +188,24 @@ run_latency_tests() {
|
||||
# get arguments
|
||||
latency_params=$(echo "$params" | jq -r '.parameters')
|
||||
latency_args=$(json2args "$latency_params")
|
||||
latency_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
latency_envs=$(json2envs "$latency_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
latency_command="python3 benchmark_latency.py \
|
||||
latency_command=" $latency_envs python3 benchmark_latency.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$latency_args"
|
||||
|
||||
@ -216,15 +255,24 @@ run_throughput_tests() {
|
||||
# get arguments
|
||||
throughput_params=$(echo "$params" | jq -r '.parameters')
|
||||
throughput_args=$(json2args "$throughput_params")
|
||||
throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
throughput_envs=$(json2envs "$throughput_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
throughput_command="python3 benchmark_throughput.py \
|
||||
throughput_command=" $throughput_envs python3 benchmark_throughput.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$throughput_args"
|
||||
|
||||
@ -272,18 +320,27 @@ run_serving_tests() {
|
||||
|
||||
# get client and server arguments
|
||||
server_params=$(echo "$params" | jq -r '.server_parameters')
|
||||
server_envs=$(echo "$params" | jq -r '.server_environment_variables')
|
||||
client_params=$(echo "$params" | jq -r '.client_parameters')
|
||||
server_args=$(json2args "$server_params")
|
||||
server_envs=$(json2envs "$server_envs")
|
||||
client_args=$(json2args "$client_params")
|
||||
qps_list=$(echo "$params" | jq -r '.qps_list')
|
||||
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
|
||||
echo "Running over qps list $qps_list"
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
# check if there is enough resources to run the test
|
||||
tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
# check if server model and client model is aligned
|
||||
@ -294,23 +351,33 @@ run_serving_tests() {
|
||||
continue
|
||||
fi
|
||||
|
||||
server_command="python3 \
|
||||
server_command="$server_envs python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
$server_args"
|
||||
|
||||
# run the server
|
||||
echo "Running test case $test_name"
|
||||
echo "Server command: $server_command"
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vllm server is up and running."
|
||||
# support remote vllm server
|
||||
client_remote_args=""
|
||||
if [[ -z "${REMOTE_HOST}" ]]; then
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vLLM server is up and running."
|
||||
else
|
||||
echo ""
|
||||
echo "vLLM failed to start within the timeout period."
|
||||
fi
|
||||
else
|
||||
echo ""
|
||||
echo "vllm failed to start within the timeout period."
|
||||
server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT"
|
||||
if [[ ${REMOTE_PORT} ]]; then
|
||||
client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT "
|
||||
else
|
||||
client_remote_args=" --host=$REMOTE_HOST "
|
||||
fi
|
||||
fi
|
||||
|
||||
# iterate over different QPS
|
||||
@ -332,7 +399,7 @@ run_serving_tests() {
|
||||
--result-filename ${new_test_name}.json \
|
||||
--request-rate $qps \
|
||||
--metadata "tensor_parallel_size=$tp" \
|
||||
$client_args"
|
||||
$client_args $client_remote_args "
|
||||
|
||||
echo "Running test case $test_name with qps $qps"
|
||||
echo "Client command: $client_command"
|
||||
@ -360,7 +427,14 @@ run_serving_tests() {
|
||||
}
|
||||
|
||||
main() {
|
||||
check_gpus
|
||||
local ARCH
|
||||
ARCH=''
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
check_cpus
|
||||
ARCH='-cpu'
|
||||
else
|
||||
check_gpus
|
||||
fi
|
||||
check_hf_token
|
||||
|
||||
# Set to v1 to run v1 benchmark
|
||||
@ -386,9 +460,9 @@ main() {
|
||||
QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/
|
||||
|
||||
# benchmarking
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}"
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}"
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}"
|
||||
|
||||
# postprocess benchmarking results
|
||||
pip install tabulate pandas
|
||||
|
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
@ -0,0 +1,30 @@
|
||||
[
|
||||
{
|
||||
"test_name": "latency_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "latency_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
}
|
||||
]
|
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
@ -0,0 +1,158 @@
|
||||
[
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp6_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 6,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
}
|
||||
]
|
@ -0,0 +1,32 @@
|
||||
[
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
}
|
||||
]
|
@ -52,7 +52,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Annotate release workflow"
|
||||
@ -101,7 +101,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
|
@ -107,10 +107,8 @@ fi
|
||||
|
||||
if [[ $commands == *" kernels/attention"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/attention/stest_attention_selector.py \
|
||||
--ignore=kernels/attention/test_blocksparse_attention.py \
|
||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/attention/test_attention_selector.py \
|
||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/attention/test_flash_attn.py \
|
||||
--ignore=kernels/attention/test_flashinfer.py \
|
||||
--ignore=kernels/attention/test_prefix_prefill.py \
|
||||
|
@ -6,6 +6,7 @@ set -ex
|
||||
|
||||
# allow to bind to different cores
|
||||
CORE_RANGE=${CORE_RANGE:-48-95}
|
||||
# used for TP/PP E2E test
|
||||
OMP_CORE_RANGE=${OMP_CORE_RANGE:-48-95}
|
||||
NUMA_NODE=${NUMA_NODE:-1}
|
||||
|
||||
@ -24,8 +25,8 @@ numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --tag cpu-test-"$NUMA_NODE
|
||||
numactl -C "$CORE_RANGE" -N "$NUMA_NODE" docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" --tag cpu-test-"$NUMA_NODE"-avx2 --target vllm-test -f docker/Dockerfile.cpu .
|
||||
|
||||
# Run the image, setting --shm-size=4g for tensor parallel.
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_OMP_THREADS_BIND="$OMP_CORE_RANGE" --env VLLM_CPU_CI_ENV=1 --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE" cpu-test-"$NUMA_NODE"
|
||||
docker run -itd --cpuset-cpus="$CORE_RANGE" --cpuset-mems="$NUMA_NODE" --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --env VLLM_CPU_CI_ENV=1 -e E2E_OMP_THREADS="$OMP_CORE_RANGE" --shm-size=4g --name cpu-test-"$NUMA_NODE"-avx2 cpu-test-"$NUMA_NODE"-avx2
|
||||
|
||||
function cpu_tests() {
|
||||
set -e
|
||||
@ -48,10 +49,16 @@ function cpu_tests() {
|
||||
# Run basic model test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
pytest -v -s tests/models/language/generation -m cpu_model
|
||||
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model
|
||||
# Note: disable until supports V1
|
||||
# pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
# pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
|
||||
# Note: disable Bart until supports V1
|
||||
pytest -v -s tests/models/language/generation -m cpu_model \
|
||||
--ignore=tests/models/language/generation/test_bart.py
|
||||
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model \
|
||||
--ignore=tests/models/language/generation/test_bart.py
|
||||
|
||||
pytest -v -s tests/models/language/pooling -m cpu_model
|
||||
pytest -v -s tests/models/multimodal/generation \
|
||||
--ignore=tests/models/multimodal/generation/test_mllama.py \
|
||||
@ -62,33 +69,26 @@ function cpu_tests() {
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_static_setup \
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_dynamic_per_token"
|
||||
tests/quantization/test_compressed_tensors.py::test_compressed_tensors_w8a8_logprobs[False-10-32-neuralmagic/Llama-3.2-1B-quantized.w8a8]"
|
||||
|
||||
# Note: disable it until supports V1
|
||||
# Run AWQ test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
VLLM_USE_V1=0 pytest -s -v \
|
||||
tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# Run chunked-prefill and prefix-cache test
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
set -e
|
||||
pytest -s -v -k cpu_model \
|
||||
tests/basic_correctness/test_chunked_prefill.py"
|
||||
# docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
# set -e
|
||||
# VLLM_USE_V1=0 pytest -s -v \
|
||||
# tests/quantization/test_ipex_quant.py"
|
||||
|
||||
# online serving
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c '
|
||||
set -e
|
||||
python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m --dtype half &
|
||||
timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1
|
||||
VLLM_CPU_CI_ENV=0 python3 benchmarks/benchmark_serving.py \
|
||||
VLLM_CPU_OMP_THREADS_BIND=$E2E_OMP_THREADS VLLM_CPU_SGL_KERNEL=1 vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 &
|
||||
timeout 600 bash -c "until curl localhost:8000/v1/models; do sleep 1; done" || exit 1
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--dataset-name random \
|
||||
--model facebook/opt-125m \
|
||||
--model meta-llama/Llama-3.2-3B-Instruct \
|
||||
--num-prompts 20 \
|
||||
--endpoint /v1/completions \
|
||||
--tokenizer facebook/opt-125m"
|
||||
--endpoint /v1/completions'
|
||||
|
||||
# Run multi-lora tests
|
||||
docker exec cpu-test-"$NUMA_NODE" bash -c "
|
||||
|
@ -6,19 +6,17 @@ set -exuo pipefail
|
||||
|
||||
# Try building the docker image
|
||||
cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - .
|
||||
FROM 1.22-413-pt2.7.1:latest
|
||||
FROM gaudi-base-image:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
|
||||
RUN VLLM_TARGET_DEVICE=empty pip install .
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
@ -62,7 +62,8 @@ echo "Results will be stored in: $RESULTS_DIR"
|
||||
echo "--- Installing Python dependencies ---"
|
||||
python3 -m pip install --progress-bar off git+https://github.com/thuml/depyf.git \
|
||||
&& python3 -m pip install --progress-bar off pytest pytest-asyncio tpu-info \
|
||||
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4
|
||||
&& python3 -m pip install --progress-bar off lm_eval[api]==0.4.4 \
|
||||
&& python3 -m pip install --progress-bar off hf-transfer
|
||||
echo "--- Python dependencies installed ---"
|
||||
export VLLM_USE_V1=1
|
||||
export VLLM_XLA_CHECK_RECOMPILATION=1
|
||||
@ -70,7 +71,7 @@ export VLLM_XLA_CACHE_PATH=
|
||||
echo "Using VLLM V1"
|
||||
|
||||
echo "--- Hardware Information ---"
|
||||
tpu-info
|
||||
# tpu-info
|
||||
echo "--- Starting Tests ---"
|
||||
set +e
|
||||
overall_script_exit_code=0
|
||||
@ -150,7 +151,7 @@ run_and_track_test 9 "test_multimodal.py" \
|
||||
run_and_track_test 10 "test_pallas.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_pallas.py"
|
||||
run_and_track_test 11 "test_struct_output_generate.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
|
||||
"HF_HUB_DISABLE_XET=1 python3 -m pytest -s -v /workspace/vllm/tests/v1/entrypoints/llm/test_struct_output_generate.py -k \"not test_structured_output_with_reasoning_matrices\""
|
||||
run_and_track_test 12 "test_moe_pallas.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/tpu/test_moe_pallas.py"
|
||||
run_and_track_test 13 "test_lora.py" \
|
||||
|
@ -11,8 +11,8 @@ container_name="xpu_${BUILDKITE_COMMIT}_$(tr -dc A-Za-z0-9 < /dev/urandom | head
|
||||
docker build -t ${image_name} -f docker/Dockerfile.xpu .
|
||||
|
||||
# Setup cleanup
|
||||
remove_docker_container() {
|
||||
docker rm -f "${container_name}" || true;
|
||||
remove_docker_container() {
|
||||
docker rm -f "${container_name}" || true;
|
||||
docker image rm -f "${image_name}" || true;
|
||||
docker system prune -f || true;
|
||||
}
|
||||
@ -26,7 +26,18 @@ docker run \
|
||||
--name "${container_name}" \
|
||||
"${image_name}" \
|
||||
sh -c '
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend ray
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager -tp 2 --distributed-executor-backend mp
|
||||
cd tests
|
||||
pytest -v -s v1/core
|
||||
pytest -v -s v1/engine
|
||||
pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
|
||||
pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
|
||||
pytest -v -s v1/structured_output
|
||||
pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py
|
||||
pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py
|
||||
pytest -v -s v1/test_serial_utils.py
|
||||
pytest -v -s v1/test_utils.py
|
||||
pytest -v -s v1/test_metrics_reader.py
|
||||
'
|
||||
|
@ -22,16 +22,6 @@ trap remove_docker_container EXIT
|
||||
# Remove the container that might not be cleaned up in the previous run.
|
||||
remove_docker_container
|
||||
|
||||
# Build docker image.
|
||||
# TODO: build the image outside the script and share the image with other
|
||||
# tpu test if building time is too long.
|
||||
DOCKER_BUILDKIT=1 docker build \
|
||||
--build-arg max_jobs=16 \
|
||||
--build-arg USE_SCCACHE=1 \
|
||||
--build-arg GIT_REPO_CHECK=0 \
|
||||
--tag vllm/vllm-tpu-bm \
|
||||
--progress plain -f docker/Dockerfile.tpu .
|
||||
|
||||
LOG_ROOT=$(mktemp -d)
|
||||
# If mktemp fails, set -e will cause the script to exit.
|
||||
echo "Results will be stored in: $LOG_ROOT"
|
||||
|
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
@ -0,0 +1,14 @@
|
||||
# Environment config
|
||||
TEST_NAME=llama8bw8a8
|
||||
CONTAINER_NAME=vllm-tpu
|
||||
|
||||
# vllm config
|
||||
MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8
|
||||
MAX_NUM_SEQS=128
|
||||
MAX_NUM_BATCHED_TOKENS=1024
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
MAX_MODEL_LEN=2048
|
||||
DOWNLOAD_DIR=/mnt/disks/persist
|
||||
EXPECTED_THROUGHPUT=10.0
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=128
|
@ -117,7 +117,7 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s core
|
||||
|
||||
- label: Entrypoints Test # 40min
|
||||
- label: Entrypoints Test (LLM) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
@ -125,8 +125,6 @@ steps:
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/llm
|
||||
- tests/entrypoints/openai
|
||||
- tests/entrypoints/test_chat_utils
|
||||
- tests/entrypoints/offline_mode
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
@ -135,9 +133,21 @@ steps:
|
||||
- pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process
|
||||
- pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Entrypoints Test (API Server) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
fast_check: true
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/entrypoints/openai
|
||||
- tests/entrypoints/test_chat_utils
|
||||
commands:
|
||||
- export VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
- pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_chat_with_tool_reasoning.py --ignore=entrypoints/openai/test_oot_registration.py --ignore=entrypoints/openai/test_tensorizer_entrypoint.py --ignore=entrypoints/openai/correctness/
|
||||
- pytest -v -s entrypoints/test_chat_utils.py
|
||||
- VLLM_USE_V1=0 pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests
|
||||
|
||||
- label: Distributed Tests (4 GPUs) # 10min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -149,12 +159,14 @@ steps:
|
||||
- tests/distributed/test_utils
|
||||
- tests/distributed/test_pynccl
|
||||
- tests/distributed/test_events
|
||||
- tests/spec_decode/e2e/test_integration_dist_tp4
|
||||
- tests/compile/test_basic_correctness
|
||||
- examples/offline_inference/rlhf.py
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/test_internal_lb_dp.py
|
||||
- tests/v1/test_hybrid_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
@ -163,14 +175,16 @@ steps:
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_internal_lb_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=4 pytest -v -s v1/test_hybrid_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
- pytest -v -s distributed/test_pynccl.py
|
||||
- pytest -v -s distributed/test_events.py
|
||||
- pytest -v -s spec_decode/e2e/test_integration_dist_tp4.py
|
||||
# TODO: create a dedicated test section for multi-GPU example tests
|
||||
# when we have multiple distributed example tests
|
||||
- pushd ../examples/offline_inference
|
||||
@ -225,7 +239,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: Engine Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
@ -254,6 +268,7 @@ steps:
|
||||
- pytest -v -s v1/structured_output
|
||||
- pytest -v -s v1/spec_decode
|
||||
- pytest -v -s v1/kv_connector/unit
|
||||
- pytest -v -s v1/metrics
|
||||
- pytest -v -s v1/test_serial_utils.py
|
||||
- pytest -v -s v1/test_utils.py
|
||||
- pytest -v -s v1/test_oracle.py
|
||||
@ -262,11 +277,11 @@ steps:
|
||||
# VLLM_USE_FLASHINFER_SAMPLER or not on H100.
|
||||
- pytest -v -s v1/e2e
|
||||
# Integration test for streaming correctness (requires special branch).
|
||||
- pip install -U git+https://github.com/robertgshaw2-neuralmagic/lm-evaluation-harness.git@streaming-api
|
||||
- pip install -U git+https://github.com/robertgshaw2-redhat/lm-evaluation-harness.git@streaming-api
|
||||
- pytest -v -s entrypoints/openai/correctness/test_lmeval.py::test_lm_eval_accuracy_v1_engine
|
||||
|
||||
- label: Examples Test # 25min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
working_dir: "/vllm-workspace/examples"
|
||||
source_file_dependencies:
|
||||
- vllm/entrypoints
|
||||
@ -280,7 +295,7 @@ steps:
|
||||
- python3 offline_inference/llm_engine_example.py
|
||||
- python3 offline_inference/audio_language.py --seed 0
|
||||
- python3 offline_inference/vision_language.py --seed 0
|
||||
- python3 offline_inference/vision_language_embedding.py --seed 0
|
||||
- python3 offline_inference/vision_language_pooling.py --seed 0
|
||||
- python3 offline_inference/vision_language_multi_image.py --seed 0
|
||||
- VLLM_USE_V1=0 python3 others/tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 others/tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors
|
||||
- python3 offline_inference/encoder_decoder.py
|
||||
@ -300,7 +315,7 @@ steps:
|
||||
|
||||
|
||||
- label: Platform Tests (CUDA)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/cuda
|
||||
@ -318,19 +333,8 @@ steps:
|
||||
- pytest -v -s samplers
|
||||
- VLLM_USE_FLASHINFER_SAMPLER=1 pytest -v -s samplers
|
||||
|
||||
- label: Speculative decoding tests # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/spec_decode
|
||||
- tests/spec_decode
|
||||
- vllm/model_executor/models/eagle.py
|
||||
commands:
|
||||
- pytest -v -s spec_decode/e2e/test_multistep_correctness.py
|
||||
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py --ignore=spec_decode/e2e/test_mtp_correctness.py
|
||||
- pytest -v -s spec_decode/e2e/test_eagle_correctness.py
|
||||
|
||||
- label: LoRA Test %N # 15min each
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/lora
|
||||
- tests/lora
|
||||
@ -338,7 +342,7 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -382,7 +386,7 @@ steps:
|
||||
- pytest -v -s kernels/core
|
||||
|
||||
- label: Kernels Attention Test %N
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/attention/
|
||||
- vllm/attention
|
||||
@ -393,7 +397,7 @@ steps:
|
||||
parallelism: 2
|
||||
|
||||
- label: Kernels Quantization Test %N
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- csrc/quantization/
|
||||
- vllm/model_executor/layers/quantization
|
||||
@ -412,7 +416,7 @@ steps:
|
||||
- pytest -v -s kernels/moe
|
||||
|
||||
- label: Kernels Mamba Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- csrc/mamba/
|
||||
- tests/kernels/mamba
|
||||
@ -434,7 +438,6 @@ steps:
|
||||
|
||||
- label: Model Executor Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor
|
||||
- tests/model_executor
|
||||
@ -491,7 +494,7 @@ steps:
|
||||
- pytest -s entrypoints/openai/correctness/
|
||||
|
||||
- label: Encoder Decoder tests # 5min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/encoder_decoder
|
||||
@ -499,7 +502,7 @@ steps:
|
||||
- pytest -v -s encoder_decoder
|
||||
|
||||
- label: OpenAI-Compatible Tool Use # 20 min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
fast_check: false
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -512,7 +515,7 @@ steps:
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 24min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -601,7 +604,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -611,7 +614,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=1) and not core_model'
|
||||
|
||||
- label: Quantized Models Test
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/layers/quantization
|
||||
- tests/models/quantization
|
||||
@ -628,6 +631,18 @@ steps:
|
||||
# e.g. pytest -v -s models/encoder_decoder/vision_language/test_mllama.py
|
||||
# *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR*
|
||||
|
||||
- label: Transformers Nightly Models Test
|
||||
working_dir: "/vllm-workspace/"
|
||||
optional: true
|
||||
commands:
|
||||
- pip install --upgrade git+https://github.com/huggingface/transformers
|
||||
- pytest -v -s tests/models/test_initialization.py
|
||||
- pytest -v -s tests/models/multimodal/processing/
|
||||
- pytest -v -s tests/models/multimodal/test_mapping.py
|
||||
- python3 examples/offline_inference/basic/chat.py
|
||||
- python3 examples/offline_inference/audio_language.py --model-type whisper
|
||||
- python3 examples/offline_inference/vision_language.py --model-type qwen2_5_vl
|
||||
|
||||
##### 1 GPU test #####
|
||||
##### multi gpus test #####
|
||||
|
||||
@ -682,10 +697,12 @@ steps:
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
@ -700,10 +717,10 @@ steps:
|
||||
- pytest -v -s distributed/test_sequence_parallel.py
|
||||
# this test fails consistently.
|
||||
# TODO: investigate and fix
|
||||
# - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py
|
||||
- VLLM_USE_V1=0 CUDA_VISIBLE_DEVICES=0,1 pytest -v -s kv_transfer/test_disagg.py
|
||||
- CUDA_VISIBLE_DEVICES=0,1 pytest -v -s v1/shutdown
|
||||
- pytest -v -s models/multimodal/generation/test_maverick.py
|
||||
|
||||
- label: Plugin Tests (2 GPUs) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
|
6
.gemini/config.yaml
Normal file
6
.gemini/config.yaml
Normal file
@ -0,0 +1,6 @@
|
||||
# https://developers.google.com/gemini-code-assist/docs/customize-gemini-behavior-github
|
||||
have_fun: false # Just review the code
|
||||
code_review:
|
||||
comment_severity_threshold: HIGH # Reduce quantity of comments
|
||||
pull_request_opened:
|
||||
summary: false # Don't summarize the PR in a separate comment
|
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -16,7 +16,8 @@
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm
|
||||
/vllm/entrypoints @aarnphm
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
/vllm/compilation @zou3519 @youkaichao @ProExpertProg
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||
# so spam a lot of people
|
||||
@ -42,7 +43,6 @@ CMakeLists.txt @tlrmchlsmth
|
||||
/tests/multimodal @DarkLight1337 @ywang96
|
||||
/tests/prefix_caching @comaniac @KuntaiDu
|
||||
/tests/quantization @mgoin @robertgshaw2-redhat
|
||||
/tests/spec_decode @njhill @LiuXiaoxuanPKU
|
||||
/tests/test_inputs.py @DarkLight1337 @ywang96
|
||||
/tests/v1/entrypoints/llm/test_struct_output_generate.py @mgoin @russellb @aarnphm
|
||||
/tests/v1/structured_output @mgoin @russellb @aarnphm
|
||||
|
2
.github/ISSUE_TEMPLATE/750-RFC.yml
vendored
2
.github/ISSUE_TEMPLATE/750-RFC.yml
vendored
@ -46,7 +46,7 @@ body:
|
||||
- type: markdown
|
||||
attributes:
|
||||
value: >
|
||||
Thanks for contributing 🎉!
|
||||
Thanks for contributing 🎉! The vLLM core team hosts a biweekly RFC review session at 9:30AM Pacific Time, while most RFCs can be discussed online, you can optionally sign up for a slot to discuss your RFC online [here](https://docs.google.com/document/d/1CiLVBZeIVfR7_PNAKVSusxpceywkoOOB78qoWqHvSZc/edit).
|
||||
- type: checkboxes
|
||||
id: askllm
|
||||
attributes:
|
||||
|
22
.github/mergify.yml
vendored
22
.github/mergify.yml
vendored
@ -74,14 +74,23 @@ pull_request_rules:
|
||||
- files~=^vllm/multimodal/
|
||||
- files~=^tests/multimodal/
|
||||
- files~=^tests/models/multimodal/
|
||||
- files~=^tests/models/*/audio_language/
|
||||
- files~=^tests/models/*/vision_language/
|
||||
- files=tests/models/test_vision.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- multi-modality
|
||||
|
||||
- name: label-new-model
|
||||
description: Automatically apply new-model label
|
||||
conditions:
|
||||
- and:
|
||||
- files~=^vllm/model_executor/models/
|
||||
- files=vllm/model_executor/models/registry.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- new-model
|
||||
|
||||
- name: label-performance
|
||||
description: Automatically apply performance label
|
||||
conditions:
|
||||
@ -155,9 +164,12 @@ pull_request_rules:
|
||||
description: Automatically apply speculative-decoding label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
- files~=^vllm/v1/spec_decode/
|
||||
- files~=^tests/v1/spec_decode/
|
||||
- files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py
|
||||
- files~=^vllm/model_executor/models/.*eagle.*\.py
|
||||
- files=vllm/model_executor/models/mlp_speculator.py
|
||||
- files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
|
2
.github/workflows/lint-and-deploy.yaml
vendored
2
.github/workflows/lint-and-deploy.yaml
vendored
@ -68,7 +68,7 @@ jobs:
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
|
||||
- name: curl test
|
||||
run: |
|
||||
|
1
.gitignore
vendored
1
.gitignore
vendored
@ -146,6 +146,7 @@ venv.bak/
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
docs/argparse
|
||||
docs/examples
|
||||
|
||||
# mypy
|
||||
|
@ -21,7 +21,7 @@ repos:
|
||||
- id: ruff-format
|
||||
files: ^(.buildkite|benchmarks|examples)/.*
|
||||
- repo: https://github.com/crate-ci/typos
|
||||
rev: v1.32.0
|
||||
rev: v1.34.0
|
||||
hooks:
|
||||
- id: typos
|
||||
- repo: https://github.com/PyCQA/isort
|
||||
@ -166,11 +166,11 @@ repos:
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
files: vllm/config.py|tests/test_config.py
|
||||
files: vllm/config.py|tests/test_config.py|vllm/entrypoints/openai/cli_args.py
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
entry: bash -c 'echo "To bypass pre-commit hooks, add --no-verify to git commit."'
|
||||
entry: bash -c 'echo "To bypass all the pre-commit hooks, add --no-verify to git commit. To skip a specific hook, prefix the commit command with SKIP=<hook-id>."'
|
||||
language: system
|
||||
verbose: true
|
||||
pass_filenames: false
|
||||
|
@ -45,7 +45,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx942;gfx950;gfx1030;gfx1100;gfx1
|
||||
# requirements.txt files and should be kept consistent. The ROCm torch
|
||||
# versions are derived from docker/Dockerfile.rocm
|
||||
#
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.7.0")
|
||||
set(TORCH_SUPPORTED_VERSION_CUDA "2.7.1")
|
||||
set(TORCH_SUPPORTED_VERSION_ROCM "2.7.0")
|
||||
|
||||
#
|
||||
@ -171,7 +171,6 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}")
|
||||
endif()
|
||||
|
||||
|
||||
#
|
||||
# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process.
|
||||
# setup.py will override FETCHCONTENT_BASE_DIR to play nicely with sccache.
|
||||
@ -232,7 +231,6 @@ endif()
|
||||
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/mamba/mamba_ssm/selective_scan_fwd.cu"
|
||||
"csrc/mamba/causal_conv1d/causal_conv1d.cu"
|
||||
"csrc/cache_kernels.cu"
|
||||
"csrc/attention/paged_attention_v1.cu"
|
||||
"csrc/attention/paged_attention_v2.cu"
|
||||
@ -259,7 +257,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
|
||||
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -298,7 +296,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu"
|
||||
"csrc/sparse/cutlass/sparse_scaled_mm_entry.cu"
|
||||
"csrc/cutlass_extensions/common.cpp"
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu")
|
||||
"csrc/attention/mla/cutlass_mla_entry.cu"
|
||||
"csrc/quantization/fp8/per_token_group_quant.cu")
|
||||
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${VLLM_EXT_SRC}"
|
||||
@ -393,7 +392,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.0 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm90_fp8.cu"
|
||||
@ -409,7 +408,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
@ -424,7 +423,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
|
||||
@ -438,7 +437,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
@ -453,7 +452,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
|
||||
@ -468,7 +467,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
@ -511,7 +510,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
|
||||
# require CUDA 12.2 or later (and only work on Hopper).
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
@ -520,7 +519,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SPARSE_SCALED_MM_C3X=1")
|
||||
message(STATUS "Building sparse_scaled_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.2 AND SCALED_MM_ARCHS)
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.2 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building sparse_scaled_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.2, we recommend upgrading to CUDA 12.2 or later "
|
||||
"if you intend on running FP8 sparse quantized models on Hopper.")
|
||||
@ -532,7 +531,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# FP4 Archs and flags
|
||||
cuda_archs_loose_intersection(FP4_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND FP4_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND FP4_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
|
||||
"csrc/quantization/fp4/nvfp4_experts_quant.cu"
|
||||
@ -553,9 +552,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
# CUTLASS MLA Archs and flags
|
||||
cuda_archs_loose_intersection(MLA_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND MLA_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND MLA_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu")
|
||||
"csrc/attention/mla/cutlass_mla_kernels.cu"
|
||||
"csrc/attention/mla/sm100_cutlass_mla_kernel.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${MLA_ARCHS}")
|
||||
@ -578,7 +578,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm90.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
@ -596,6 +596,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
@ -615,6 +635,26 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Machete kernels
|
||||
@ -622,7 +662,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# The machete kernels only work on hopper and require CUDA 12.0 or later.
|
||||
# Only build Machete kernels if we are building for something compatible with sm90a
|
||||
cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS)
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND MACHETE_ARCHS)
|
||||
#
|
||||
# For the Machete kernels we automatically generate sources for various
|
||||
# preselected input type pairs and schedules.
|
||||
@ -674,7 +714,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
|
||||
message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0
|
||||
AND MACHETE_ARCHS)
|
||||
message(STATUS "Not building Machete kernels as CUDA Compiler version is "
|
||||
"not >= 12.0, we recommend upgrading to CUDA 12.0 or "
|
||||
|
@ -63,13 +63,11 @@ vLLM is fast with:
|
||||
- Speculative decoding
|
||||
- Chunked prefill
|
||||
|
||||
**Performance benchmark**: We include a performance benchmark at the end of [our blog post](https://blog.vllm.ai/2024/09/05/perf-update.html). It compares the performance of vLLM against other LLM serving engines ([TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM), [SGLang](https://github.com/sgl-project/sglang) and [LMDeploy](https://github.com/InternLM/lmdeploy)). The implementation is under [nightly-benchmarks folder](.buildkite/nightly-benchmarks/) and you can [reproduce](https://github.com/vllm-project/vllm/issues/8176) this benchmark using our one-click runnable script.
|
||||
|
||||
vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular Hugging Face models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs and GPUs, PowerPC CPUs, TPU, and AWS Neuron
|
||||
|
33
RELEASE.md
33
RELEASE.md
@ -52,3 +52,36 @@ After branch cut, we approach finalizing the release branch with clear criteria
|
||||
* Release branch specific changes (e.g. change version identifiers or CI fixes)
|
||||
|
||||
Please note: **No feature work allowed for cherry picks**. All PRs that are considered for cherry-picks need to be merged on trunk, the only exception are Release branch specific changes.
|
||||
|
||||
## Manual validations
|
||||
|
||||
### E2E Performance Validation
|
||||
|
||||
Before each release, we perform end-to-end performance validation to ensure no regressions are introduced. This validation uses the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) on PyTorch CI.
|
||||
|
||||
**Current Coverage:**
|
||||
* Models: Llama3, Llama4, and Mixtral
|
||||
* Hardware: NVIDIA H100 and AMD MI300x
|
||||
* *Note: Coverage may change based on new model releases and hardware availability*
|
||||
|
||||
**Performance Validation Process:**
|
||||
|
||||
**Step 1: Get Access**
|
||||
Request write access to the [pytorch/pytorch-integration-testing](https://github.com/pytorch/pytorch-integration-testing) repository to run the benchmark workflow.
|
||||
|
||||
**Step 2: Review Benchmark Setup**
|
||||
Familiarize yourself with the benchmark configurations:
|
||||
* [CUDA setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/cuda)
|
||||
* [ROCm setup](https://github.com/pytorch/pytorch-integration-testing/tree/main/vllm-benchmarks/benchmarks/rocm)
|
||||
|
||||
**Step 3: Run the Benchmark**
|
||||
Navigate to the [vllm-benchmark workflow](https://github.com/pytorch/pytorch-integration-testing/actions/workflows/vllm-benchmark.yml) and configure:
|
||||
* **vLLM branch**: Set to the release branch (e.g., `releases/v0.9.2`)
|
||||
* **vLLM commit**: Set to the RC commit hash
|
||||
|
||||
**Step 4: Review Results**
|
||||
Once the workflow completes, benchmark results will be available on the [vLLM benchmark dashboard](https://hud.pytorch.org/benchmark/llms?repoName=vllm-project%2Fvllm) under the corresponding branch and commit.
|
||||
|
||||
**Step 5: Performance Comparison**
|
||||
Compare the current results against the previous release to verify no performance regressions have occurred. Here is an
|
||||
example of [v0.9.1 vs v0.9.2](https://hud.pytorch.org/benchmark/llms?startTime=Thu%2C%2017%20Apr%202025%2021%3A43%3A50%20GMT&stopTime=Wed%2C%2016%20Jul%202025%2021%3A43%3A50%20GMT&granularity=week&lBranch=releases/v0.9.1&lCommit=b6553be1bc75f046b00046a4ad7576364d03c835&rBranch=releases/v0.9.2&rCommit=a5dd03c1ebc5e4f56f3c9d3dc0436e9c582c978f&repoName=vllm-project%2Fvllm&benchmarkName=&modelName=All%20Models&backendName=All%20Backends&modeName=All%20Modes&dtypeName=All%20DType&deviceName=All%20Devices&archName=All%20Platforms).
|
||||
|
137
benchmarks/auto_tune/README.md
Normal file
137
benchmarks/auto_tune/README.md
Normal file
@ -0,0 +1,137 @@
|
||||
# Automated vLLM Server Parameter Tuning
|
||||
|
||||
This script automates the process of finding the optimal server parameter combination (`max-num-seqs` and `max-num-batched-tokens`) to maximize throughput for a vLLM server. It also supports additional constraints such as E2E latency and prefix cache hit rate.
|
||||
|
||||
## Table of Contents
|
||||
- [Prerequisites](#prerequisites)
|
||||
- [Configuration](#configuration)
|
||||
- [How to Run](#how-to-run)
|
||||
- [Example Use Cases](#example-use-cases)
|
||||
- [Output](#output)
|
||||
- [How It Works](#how-it-works)
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before running the script, please ensure the following steps are completed:
|
||||
|
||||
1. **Clone vLLM & Set Up Branch**: Clone the vLLM repository and check out to your desired branch.
|
||||
|
||||
```bash
|
||||
git clone https://github.com/vllm-project/vllm.git
|
||||
cd vllm
|
||||
# git checkout <your-branch>
|
||||
```
|
||||
|
||||
1. **Install Environment**: Install or update the correct running environment. For TPU usage, activate your `conda` environment and install the corresponding `torch` and `torch_xla` versions.
|
||||
|
||||
2. **Model Configuration**: If you are using a customized model, ensure its configuration files are correctly placed and accessible.
|
||||
|
||||
## Configuration
|
||||
|
||||
You must set the following variables at the top of the script before execution.
|
||||
|
||||
| Variable | Description | Example Value |
|
||||
| --- | --- | --- |
|
||||
| `BASE` | **Required.** The absolute path to the parent directory of your vLLM repository directory. | `"$HOME"` |
|
||||
| `MODEL` | **Required.** The Hugging Face model identifier to be served by vllm. | `"meta-llama/Llama-3.1-8B-Instruct"` |
|
||||
| `SYSTEM`| **Required.** The hardware you are running on. Choices: `TPU` or `GPU`. (For other systems, it might not support saving profiles) | `"TPU"` |
|
||||
| `TP` | **Required.** The tensor-parallelism size. | `1` |
|
||||
| `DOWNLOAD_DIR` | **Required.** Directory to download and load model weights from. | `""` (default download path) |
|
||||
| `INPUT_LEN` | **Required.** Request input length. | `4000` |
|
||||
| `OUTPUT_LEN` | **Required.** Request output length. | `16` |
|
||||
| `MIN_CACHE_HIT_PCT` | Prefix cache hit rate in percentage (0-100). Set to `0` to disable. | `60` |
|
||||
| `MAX_LATENCY_ALLOWED_MS` | The maximum allowed P99 end-to-end latency in milliseconds. Set to a very large number (e.g., `100000000000`) to effectively ignore the latency constraint. | `500` |
|
||||
| `NUM_SEQS_LIST` | A space-separated string of `max-num-seqs` values to test. | `"128 256"` |
|
||||
| `NUM_BATCHED_TOKENS_LIST` | A space-separated string of `max-num-batched-tokens` values to test. | `"1024 2048 4096"` |
|
||||
|
||||
**Note**: The default `NUM_SEQS_LIST` and `NUM_BATCHED_TOKENS_LIST` are set for medium-sized inputs/outputs. For very short contexts (e.g., 20 input, 20 output tokens), you may need to test larger values for `max-num-seqs`.
|
||||
|
||||
## How to Run
|
||||
|
||||
1. **Configure**: Edit the script and set the variables in the [Configuration](#configuration) section.
|
||||
2. **Execute**: Run the script. Since the process can take a long time, it is highly recommended to use a terminal multiplexer like `tmux` or `screen` to prevent the script from stopping if your connection is lost.
|
||||
|
||||
```
|
||||
cd <FOLDER_OF_THIS_SCRIPT>
|
||||
bash auto_tune.sh
|
||||
```
|
||||
|
||||
Please note that the `bash auto_tune.sh` command cannot contain full or partial path with keyword `vllm`, otherwise `pkill -f vllm` command will also kill this script itself.
|
||||
|
||||
## Example Use Cases
|
||||
|
||||
Here are a few examples of how to configure the script for different goals:
|
||||
|
||||
### 1. Maximize Throughput (No Latency Constraint)
|
||||
- **Goal**: Find the best `max-num-seqs` and `max-num-batched-tokens` to get the highest possible throughput for 1800 input tokens and 20 output tokens.
|
||||
- **Configuration**:
|
||||
|
||||
```bash
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=20
|
||||
MIN_CACHE_HIT_PCT=0
|
||||
MAX_LATENCY_ALLOWED_MS=100000000000 # A very large number
|
||||
```
|
||||
|
||||
#### 2. Maximize Throughput with a Latency Requirement
|
||||
- **Goal**: Find the best server parameters when P99 end-to-end latency must be below 500ms.
|
||||
- **Configuration**:
|
||||
|
||||
```bash
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=20
|
||||
MIN_CACHE_HIT_PCT=0
|
||||
MAX_LATENCY_ALLOWED_MS=500
|
||||
```
|
||||
|
||||
#### 3. Maximize Throughput with Prefix Caching and Latency Requirements
|
||||
- **Goal**: Find the best server parameters assuming a 60% prefix cache hit rate and a latency requirement of 500ms.
|
||||
- **Configuration**:
|
||||
|
||||
```bash
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=20
|
||||
MIN_CACHE_HIT_PCT=60
|
||||
MAX_LATENCY_ALLOWED_MS=500
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
After the script finishes, you will find the results in a new, timestamped directory created inside `$BASE/auto-benchmark/`.
|
||||
|
||||
- **Log Files**: The directory (`$BASE/auto-benchmark/YYYY_MM_DD_HH_MM/`) contains detailed logs for each run:
|
||||
- `vllm_log_...txt`: The log output from the vLLM server for each parameter combination.
|
||||
- `bm_log_...txt`: The log output from the `benchmark_serving.py` script for each benchmark run.
|
||||
|
||||
- **Final Result Summary**: A file named `result.txt` is created in the log directory. It contains a summary of each tested combination and concludes with the overall best parameters found.
|
||||
|
||||
```
|
||||
# Example result.txt content
|
||||
hash:a1b2c3d4...
|
||||
max_num_seqs: 128, max_num_batched_tokens: 2048, request_rate: 10.0, e2el: 450.5, throughput: 9.8, goodput: 9.8
|
||||
max_num_seqs: 128, max_num_batched_tokens: 4096 does not meet latency requirement 500
|
||||
...
|
||||
best_max_num_seqs: 256, best_num_batched_tokens: 2048, best_throughput: 12.5, profile saved in: /home/user/vllm/auto-benchmark/2024_08_01_10_30/profile
|
||||
```
|
||||
|
||||
If it cannot find the best parameters, the final row will be `best_max_num_seqs: 0, best_num_batched_tokens: 0, best_throughput: 0`. This can be due to either the server not starting properly, or the latency requirement being too strict.
|
||||
|
||||
- **Profiler Trace**: A directory named `profile` is created inside the log directory. It contains the profiler trace file (e.g., `.xplane.pb` for TPU or a `.json` trace for GPU) from the single best-performing run.
|
||||
|
||||
## How It Works
|
||||
|
||||
The script follows a systematic process to find the optimal parameters:
|
||||
|
||||
1. **Find Max GPU Memory Utilization**: The script first determines the highest safe `gpu-memory-utilization` (starting from 0.98 and decreasing) that does not cause an Out-Of-Memory (OOM) error when launching the server. This ensures the benchmark runs use the maximum available memory without crashing.
|
||||
|
||||
2. **Iterate and Benchmark**: It then enters a nested loop, iterating through every combination of `max-num-seqs` and `max-num-batched-tokens` provided in the configuration lists.
|
||||
|
||||
3. **Latency-Aware Throughput Search**: For each parameter combination:
|
||||
- The vLLM server is started.
|
||||
- A benchmark is first run with an infinite request rate (`--request-rate inf`).
|
||||
- If the resulting P99 E2E latency is within the `MAX_LATENCY_ALLOWED_MS` limit, this throughput is considered the maximum for this configuration.
|
||||
- If the latency is too high, the script performs a search by iteratively decreasing the request rate until the latency constraint is met. This finds the highest sustainable throughput for the given parameters and latency requirement.
|
||||
|
||||
4. **Track Best Result**: Throughout the process, the script tracks the parameter combination that has yielded the highest valid throughput so far.
|
||||
|
||||
5. **Profile Collection**: For the best-performing run, the script saves the vLLM profiler output, which can be used for deep-dive performance analysis with tools like TensorBoard.
|
@ -1,36 +1,7 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script aims to tune the best server parameter combinations to maximize throughput for given requirement.
|
||||
# The current server parameter combination is max_num_seqs and max_num_batched_tokens
|
||||
# It also supports additional requirement: e2e latency and prefix cache.
|
||||
|
||||
# Pre-requisite:
|
||||
# 1. Checkout to your branch, install/ update the correct running env. For TPU, activate conda env and install the corresponding torch, xla version.
|
||||
# 2. If the model is customized, replace the MODEL's config with the customized config.
|
||||
# 3. Set variables (ALL REQUIRED)
|
||||
# BASE: your directory for vllm repo
|
||||
# MODEL: the model served by vllm
|
||||
# SYSTEM: the hardware, choice TPU or GPU, for other systems, "get best profile" might not support.
|
||||
# TP: ways of tensor parallelism
|
||||
# DOWNLOAD_DIR: directory to download and load model weights.
|
||||
# INPUT_LEN: request input len
|
||||
# OUTPUT_LEN: request output len
|
||||
# MIN_CACHE_HIT_PCT: prefix cache rate
|
||||
# MAX_LATENCY_ALLOWED_MS: (e2e) latency requirement. If there's no latency requirement, set it to a large number like 1000000000
|
||||
# NUM_SEQS_LIST: a list of `max-num-seqs` you want to loop with.
|
||||
# NUM_BATCHED_TOKENS_LIST: a list of `max-num-batched-tokens` you want to loop with.
|
||||
# Note that the default NUM_SEQS_LIST and NUM_BATCHED_TOKENS_LIST are set for medium size input/output len, for extra short context (such as 20:20), you might need to include larger numbers in NUM_SEQS_LIST.
|
||||
# 4. Run the script, it might take a long time, you can use tmux to avoid the script stop if disconnection happens.
|
||||
# 5. The final result will be saved in RESULT file.
|
||||
|
||||
|
||||
# Example use cases
|
||||
# 1. Given input_len=1800, output_len=20, what's the best max_num_seqs and max_num_batched_tokens to get highest throughput?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=100000000000
|
||||
# 2. If we have latency requirement to be lower than 500ms, what's the best server parameter?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=0, MAX_LATENCY_ALLOWED_MS=500
|
||||
# 3. If we want to reach 60% prefix cache, what's the best server parameter?
|
||||
# Use INPUT_LEN=1800, OUTPUT_LEN=20, MIN_CACHE_HIT_PCT=60, MAX_LATENCY_ALLOWED_MS=500
|
||||
# See details in README (benchmarks/auto_tune/README.md).
|
||||
|
||||
TAG=$(date +"%Y_%m_%d_%H_%M")
|
||||
BASE=""
|
||||
@ -155,11 +126,12 @@ run_benchmark() {
|
||||
# get a basic qps by using request-rate inf
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_inf.txt"
|
||||
prefix_len=$(( INPUT_LEN * MIN_CACHE_HIT_PCT / 100 ))
|
||||
python benchmarks/benchmark_serving.py \
|
||||
adjusted_input_len=$(( INPUT_LEN - prefix_len ))
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name random \
|
||||
--random-input-len $INPUT_LEN \
|
||||
--random-input-len $adjusted_input_len \
|
||||
--random-output-len $OUTPUT_LEN \
|
||||
--ignore-eos \
|
||||
--disable-tqdm \
|
||||
@ -188,11 +160,11 @@ run_benchmark() {
|
||||
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
|
||||
sleep 5
|
||||
bm_log="$LOG_FOLDER/bm_log_${max_num_seqs}_${max_num_batched_tokens}_requestrate_${request_rate}.txt"
|
||||
python benchmarks/benchmark_serving.py \
|
||||
python3 benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model $MODEL \
|
||||
--dataset-name random \
|
||||
--random-input-len $INPUT_LEN \
|
||||
--random-input-len $adjusted_input_len \
|
||||
--random-output-len $OUTPUT_LEN \
|
||||
--ignore-eos \
|
||||
--disable-tqdm \
|
@ -324,6 +324,9 @@ class RandomDataset(BenchmarkDataset):
|
||||
input_low = int(real_input_len * (1 - range_ratio))
|
||||
input_high = int(real_input_len * (1 + range_ratio))
|
||||
output_low = int(output_len * (1 - range_ratio))
|
||||
# Ensure the lower bound for output length is at least 1 to prevent
|
||||
# sampling 0 tokens, which can cause request failures.
|
||||
output_low = max(output_low, 1)
|
||||
output_high = int(output_len * (1 + range_ratio))
|
||||
|
||||
# Add logging for debugging
|
||||
@ -701,6 +704,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
self,
|
||||
dataset_path: str,
|
||||
dataset_split: str,
|
||||
no_stream: bool = False,
|
||||
dataset_subset: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
@ -708,6 +712,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
|
||||
self.dataset_split = dataset_split
|
||||
self.dataset_subset = dataset_subset
|
||||
self.load_stream = not no_stream
|
||||
self.load_data()
|
||||
|
||||
def load_data(self) -> None:
|
||||
@ -716,7 +721,7 @@ class HuggingFaceDataset(BenchmarkDataset):
|
||||
self.dataset_path,
|
||||
name=self.dataset_subset,
|
||||
split=self.dataset_split,
|
||||
streaming=True,
|
||||
streaming=self.load_stream,
|
||||
)
|
||||
self.data = self.data.shuffle(seed=self.random_seed)
|
||||
|
||||
|
@ -30,7 +30,7 @@ import os
|
||||
import random
|
||||
import time
|
||||
import warnings
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, Optional
|
||||
@ -73,6 +73,7 @@ from benchmark_dataset import (
|
||||
VisionArenaDataset,
|
||||
)
|
||||
from benchmark_utils import convert_to_pytorch_benchmark_format, write_to_json
|
||||
from vllm.benchmarks.serve import get_request
|
||||
|
||||
MILLISECONDS_TO_SECONDS_CONVERSION = 1000
|
||||
|
||||
@ -107,101 +108,6 @@ class BenchmarkMetrics:
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def _get_current_request_rate(
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
|
||||
ramp_up_start_rps: Optional[int],
|
||||
ramp_up_end_rps: Optional[int],
|
||||
request_index: int,
|
||||
total_requests: int,
|
||||
request_rate: float,
|
||||
) -> float:
|
||||
if (
|
||||
ramp_up_strategy
|
||||
and ramp_up_start_rps is not None
|
||||
and ramp_up_end_rps is not None
|
||||
):
|
||||
progress = request_index / max(total_requests - 1, 1)
|
||||
if ramp_up_strategy == "linear":
|
||||
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
|
||||
return ramp_up_start_rps + increase
|
||||
elif ramp_up_strategy == "exponential":
|
||||
ratio = ramp_up_end_rps / ramp_up_start_rps
|
||||
return ramp_up_start_rps * (ratio**progress)
|
||||
else:
|
||||
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
|
||||
return request_rate
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||
ramp_up_start_rps: Optional[int] = None,
|
||||
ramp_up_end_rps: Optional[int] = None,
|
||||
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
A list of input requests, each represented as a SampleRequest.
|
||||
request_rate:
|
||||
The rate at which requests are generated (requests/s).
|
||||
burstiness (optional):
|
||||
The burstiness factor of the request generation.
|
||||
Only takes effect when request_rate is not inf.
|
||||
Default value is 1, which follows a Poisson process.
|
||||
Otherwise, the request intervals follow a gamma distribution.
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
ramp_up_strategy (optional):
|
||||
The ramp-up strategy. Can be "linear" or "exponential".
|
||||
If None, uses constant request rate (specified by request_rate).
|
||||
ramp_up_start_rps (optional):
|
||||
The starting request rate for ramp-up.
|
||||
ramp_up_end_rps (optional):
|
||||
The ending request rate for ramp-up.
|
||||
"""
|
||||
assert burstiness > 0, (
|
||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
||||
)
|
||||
# Convert to list to get length for ramp-up calculations
|
||||
if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
|
||||
input_requests = list(input_requests)
|
||||
|
||||
total_requests = len(input_requests)
|
||||
request_index = 0
|
||||
|
||||
for request in input_requests:
|
||||
current_request_rate = _get_current_request_rate(
|
||||
ramp_up_strategy,
|
||||
ramp_up_start_rps,
|
||||
ramp_up_end_rps,
|
||||
request_index,
|
||||
total_requests,
|
||||
request_rate,
|
||||
)
|
||||
|
||||
yield request, current_request_rate
|
||||
|
||||
request_index += 1
|
||||
|
||||
if current_request_rate == float("inf"):
|
||||
# If the request rate is infinity, then we don't need to wait.
|
||||
continue
|
||||
|
||||
theta = 1.0 / (current_request_rate * burstiness)
|
||||
|
||||
# Sample the request interval from the gamma distribution.
|
||||
# If burstiness is 1, it follows exponential distribution.
|
||||
interval = np.random.gamma(shape=burstiness, scale=theta)
|
||||
# The next request will be sent after the interval.
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: list[SampleRequest],
|
||||
outputs: list[RequestFuncOutput],
|
||||
@ -825,6 +731,7 @@ def main(args: argparse.Namespace):
|
||||
dataset_subset=args.hf_subset,
|
||||
dataset_split=args.hf_split,
|
||||
random_seed=args.seed,
|
||||
no_stream=args.no_stream,
|
||||
).sample(
|
||||
num_requests=args.num_prompts,
|
||||
tokenizer=tokenizer,
|
||||
@ -1033,6 +940,11 @@ def create_argument_parser():
|
||||
help="Path to the sharegpt/sonnet dataset. "
|
||||
"Or the huggingface dataset ID if using HF dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max-concurrency",
|
||||
type=int,
|
||||
|
@ -356,6 +356,7 @@ def get_requests(args, tokenizer):
|
||||
elif args.dataset_name == "burstgpt":
|
||||
dataset_cls = BurstGPTDataset
|
||||
elif args.dataset_name == "hf":
|
||||
common_kwargs["no_stream"] = args.no_stream
|
||||
if args.dataset_path in VisionArenaDataset.SUPPORTED_DATASET_PATHS:
|
||||
dataset_cls = VisionArenaDataset
|
||||
common_kwargs["dataset_subset"] = None
|
||||
@ -610,6 +611,11 @@ def create_argument_parser():
|
||||
help="Name of the dataset to benchmark on.",
|
||||
default="sharegpt",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-stream",
|
||||
action="store_true",
|
||||
help="Do not load the dataset in streaming mode.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset",
|
||||
type=str,
|
||||
|
@ -1,4 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
141
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
141
benchmarks/kernels/bench_nvfp4_gemm.py
Normal file
@ -0,0 +1,141 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from weight_shapes import WEIGHT_SHAPES
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.scalar_type import scalar_types
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
if not current_platform.has_device_capability(100):
|
||||
raise RuntimeError("NVFP4 requires compute capability of 10.0 (Blackwell)")
|
||||
|
||||
|
||||
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
|
||||
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
|
||||
|
||||
PROVIDER_CFGS = {
|
||||
"torch-bf16": dict(enabled=True),
|
||||
"nvfp4": dict(no_a_quant=False, enabled=True),
|
||||
"nvfp4-noquant": dict(no_a_quant=True, enabled=True),
|
||||
}
|
||||
|
||||
_enabled = [k for k, v in PROVIDER_CFGS.items() if v["enabled"]]
|
||||
|
||||
|
||||
def _quant_weight_nvfp4(b: torch.Tensor, device: str):
|
||||
# Compute global scale for weight
|
||||
b_amax = torch.abs(b).max().to(torch.float32)
|
||||
b_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax
|
||||
b_fp4, scale_b_fp4 = ops.scaled_fp4_quant(b, b_global_scale)
|
||||
return b_fp4, scale_b_fp4, b_global_scale
|
||||
|
||||
|
||||
def build_nvfp4_runner(cfg, a, b, dtype, device):
|
||||
b_fp4, scale_b_fp4, b_global_scale = _quant_weight_nvfp4(b, device)
|
||||
|
||||
# Compute global scale for activation
|
||||
# NOTE: This is generally provided ahead-of-time by the model checkpoint.
|
||||
a_amax = torch.abs(a).max().to(torch.float32)
|
||||
a_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax
|
||||
|
||||
# Alpha for the GEMM operation
|
||||
alpha = 1.0 / (a_global_scale * b_global_scale)
|
||||
|
||||
if cfg["no_a_quant"]:
|
||||
# Pre-quantize activation
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
|
||||
def run():
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
# Quantize activation on-the-fly
|
||||
def run():
|
||||
a_fp4, scale_a_fp4 = ops.scaled_fp4_quant(a, a_global_scale)
|
||||
return ops.cutlass_scaled_fp4_mm(
|
||||
a_fp4, b_fp4, scale_a_fp4, scale_b_fp4, alpha, dtype
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size"],
|
||||
x_vals=[1, 16, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384],
|
||||
x_log=False,
|
||||
line_arg="provider",
|
||||
line_vals=_enabled,
|
||||
line_names=_enabled,
|
||||
ylabel="TFLOP/s (larger is better)",
|
||||
plot_name="BF16 vs NVFP4 GEMMs",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark(batch_size, provider, N, K):
|
||||
M = batch_size
|
||||
device = "cuda"
|
||||
dtype = torch.bfloat16
|
||||
|
||||
a = torch.randn((M, K), device=device, dtype=dtype)
|
||||
b = torch.randn((N, K), device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch-bf16":
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: torch.nn.functional.linear(a, b), quantiles=quantiles
|
||||
)
|
||||
else:
|
||||
cfg = PROVIDER_CFGS[provider]
|
||||
run_quant = build_nvfp4_runner(cfg, a, b, dtype, device)
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(
|
||||
lambda: run_quant(), quantiles=quantiles
|
||||
)
|
||||
|
||||
to_tflops = lambda t_ms: (2 * M * N * K) * 1e-12 / (t_ms * 1e-3)
|
||||
return to_tflops(ms), to_tflops(max_ms), to_tflops(min_ms)
|
||||
|
||||
|
||||
def prepare_shapes(args):
|
||||
out = []
|
||||
for model, tp_size in itertools.product(args.models, args.tp_sizes):
|
||||
for KN, tp_dim in copy.deepcopy(WEIGHT_SHAPES[model]):
|
||||
KN[tp_dim] //= tp_size
|
||||
KN.append(model)
|
||||
out.append(KN)
|
||||
return out
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--models",
|
||||
nargs="+",
|
||||
type=str,
|
||||
default=["meta-llama/Llama-3.1-8B-Instruct"],
|
||||
choices=list(WEIGHT_SHAPES.keys()),
|
||||
)
|
||||
parser.add_argument("--tp-sizes", nargs="+", type=int, default=[1])
|
||||
args = parser.parse_args()
|
||||
|
||||
for K, N, model in prepare_shapes(args):
|
||||
print(f"{model}, N={N} K={K}, BF16 vs NVFP4 GEMMs TFLOP/s:")
|
||||
benchmark.run(
|
||||
print_data=True,
|
||||
show_plots=True,
|
||||
save_path=f"bench_nvfp4_res_n{N}_k{K}",
|
||||
N=N,
|
||||
K=K,
|
||||
)
|
||||
|
||||
print("Benchmark finished!")
|
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
98
benchmarks/kernels/bench_per_token_quant_fp8.py
Normal file
@ -0,0 +1,98 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import itertools
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config
|
||||
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
|
||||
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
# TODO(luka): use standalone_compile utility
|
||||
def with_dyn_arg(fn: Callable, arg_index: int, dim_index: int):
|
||||
def inner(*args):
|
||||
torch._dynamo.mark_dynamic(args[arg_index], dim_index)
|
||||
return fn(*args)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
torch._dynamo.config.recompile_limit = 8888
|
||||
compilation_config = CompilationConfig(custom_ops=["none"])
|
||||
with set_current_vllm_config(VllmConfig(compilation_config=compilation_config)):
|
||||
torch_per_token_quant_fp8 = torch.compile(
|
||||
QuantFP8(False, GroupShape.PER_TOKEN),
|
||||
fullgraph=True,
|
||||
dynamic=False, # recompile for different shapes
|
||||
)
|
||||
|
||||
# First dim is explicitly dynamic to simulate vLLM usage
|
||||
torch_per_token_quant_fp8 = with_dyn_arg(torch_per_token_quant_fp8, 0, 0)
|
||||
|
||||
|
||||
def cuda_per_token_quant_fp8(
|
||||
input: torch.Tensor,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
return ops.scaled_fp8_quant(input)
|
||||
|
||||
|
||||
def calculate_diff(batch_size: int, seq_len: int):
|
||||
"""Calculate difference between Triton and CUDA implementations."""
|
||||
device = torch.device("cuda")
|
||||
x = torch.rand((batch_size * seq_len, 4096), dtype=torch.float16, device=device)
|
||||
|
||||
torch_out, torch_scale = torch_per_token_quant_fp8(x)
|
||||
cuda_out, cuda_scale = cuda_per_token_quant_fp8(x)
|
||||
|
||||
if torch.allclose(
|
||||
cuda_out.to(torch.float32), torch_out.to(torch.float32), rtol=1e-3, atol=1e-5
|
||||
) and torch.allclose(cuda_scale, torch_scale, rtol=1e-3, atol=1e-5):
|
||||
print("✅ All implementations match")
|
||||
else:
|
||||
print("❌ Implementations differ")
|
||||
|
||||
|
||||
batch_size_range = [1, 16, 32, 64, 128]
|
||||
seq_len_range = [1, 16, 64, 128, 256, 512, 1024, 2048, 4096]
|
||||
|
||||
configs = list(itertools.product(batch_size_range, seq_len_range))
|
||||
|
||||
|
||||
@triton.testing.perf_report(
|
||||
triton.testing.Benchmark(
|
||||
x_names=["batch_size", "seq_len"],
|
||||
x_vals=configs,
|
||||
line_arg="provider",
|
||||
line_vals=["torch", "cuda"],
|
||||
line_names=["Torch", "CUDA"],
|
||||
styles=[("blue", "-"), ("green", "-")],
|
||||
ylabel="us",
|
||||
plot_name="per-token-dynamic-quant-fp8-performance",
|
||||
args={},
|
||||
)
|
||||
)
|
||||
def benchmark_quantization(batch_size, seq_len, provider):
|
||||
dtype = torch.float16
|
||||
device = torch.device("cuda")
|
||||
|
||||
x = torch.randn(batch_size * seq_len, 4096, device=device, dtype=dtype)
|
||||
|
||||
quantiles = [0.5, 0.2, 0.8]
|
||||
|
||||
if provider == "torch":
|
||||
fn = lambda: torch_per_token_quant_fp8(x.clone())
|
||||
elif provider == "cuda":
|
||||
fn = lambda: cuda_per_token_quant_fp8(x.clone())
|
||||
|
||||
ms, min_ms, max_ms = triton.testing.do_bench_cudagraph(fn, quantiles=quantiles)
|
||||
|
||||
return 1000 * ms, 1000 * max_ms, 1000 * min_ms
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
calculate_diff(batch_size=4, seq_len=4096)
|
||||
benchmark_quantization.run(print_data=True)
|
@ -86,6 +86,9 @@ def benchmark_config(
|
||||
(num_experts, 2 * shard_intermediate_size), dtype=torch.float32
|
||||
)
|
||||
w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
|
||||
if use_deep_gemm:
|
||||
# we use the default block shape for deepgemm
|
||||
block_quant_shape = [128, 128]
|
||||
if use_fp8_w8a8:
|
||||
if block_quant_shape:
|
||||
block_n, block_k = block_quant_shape[0], block_quant_shape[1]
|
||||
@ -573,7 +576,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("DeepseekV3ForCausalLM", "DeepseekV2ForCausalLM"):
|
||||
elif config.architectures[0] in (
|
||||
"DeepseekV3ForCausalLM",
|
||||
"DeepseekV2ForCausalLM",
|
||||
"Glm4MoeForCausalLM",
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
@ -583,6 +590,11 @@ def main(args: argparse.Namespace):
|
||||
topk = config.num_experts_per_tok
|
||||
intermediate_size = config.moe_intermediate_size
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
elif config.architectures[0] in ("HunYuanMoEV1ForCausalLM"):
|
||||
E = config.num_experts
|
||||
topk = config.moe_topk[0]
|
||||
intermediate_size = config.moe_intermediate_size[0]
|
||||
shard_intermediate_size = 2 * intermediate_size // args.tp_size
|
||||
else:
|
||||
# Support for llama4
|
||||
config = config.get_text_config()
|
||||
@ -620,7 +632,7 @@ def main(args: argparse.Namespace):
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
batch_sizes = args.batch_size
|
||||
|
||||
use_deep_gemm = bool(args.use_deep_gemm)
|
||||
|
||||
@ -728,7 +740,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
|
||||
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
parser.add_argument("--model-prefix", type=str, required=False)
|
||||
|
@ -33,15 +33,13 @@ def check_correctness(num_tokens, num_experts=256, block_size=256, topk=8):
|
||||
sorted_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
sorted_ids_triton.fill_(topk_ids.numel()) # fill with sentinel value
|
||||
expert_ids_triton = torch.zeros(
|
||||
expert_ids_triton = torch.empty(
|
||||
(max_num_tokens_padded // block_size,), dtype=torch.int32, device="cuda"
|
||||
)
|
||||
num_tokens_post_pad_triton = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
||||
sorted_ids_vllm = torch.empty_like(sorted_ids_triton)
|
||||
sorted_ids_vllm.fill_(topk_ids.numel())
|
||||
expert_ids_vllm = torch.zeros_like(expert_ids_triton)
|
||||
expert_ids_vllm = torch.empty_like(expert_ids_triton)
|
||||
num_tokens_post_pad_vllm = torch.empty_like(num_tokens_post_pad_triton)
|
||||
|
||||
# 2. run implementations
|
||||
@ -102,7 +100,6 @@ def benchmark(num_tokens, num_experts, topk, provider):
|
||||
|
||||
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
|
||||
sorted_ids = torch.empty((max_num_tokens_padded,), dtype=torch.int32, device="cuda")
|
||||
sorted_ids.fill_(topk_ids.numel())
|
||||
max_num_m_blocks = max_num_tokens_padded // block_size
|
||||
expert_ids = torch.empty((max_num_m_blocks,), dtype=torch.int32, device="cuda")
|
||||
num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device="cuda")
|
||||
|
@ -318,6 +318,7 @@ def main(args: argparse.Namespace):
|
||||
elif (
|
||||
config.architectures[0] == "DeepseekV3ForCausalLM"
|
||||
or config.architectures[0] == "DeepseekV2ForCausalLM"
|
||||
or config.architectures[0] == "Glm4MoeForCausalLM"
|
||||
):
|
||||
E = config.n_routed_experts
|
||||
topk = config.num_experts_per_tok
|
||||
|
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
240
benchmarks/kernels/benchmark_trtllm_attention.py
Normal file
@ -0,0 +1,240 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import csv
|
||||
import os
|
||||
import random
|
||||
from datetime import datetime
|
||||
|
||||
import flashinfer
|
||||
import torch
|
||||
|
||||
FLOAT32_BYTES = torch.finfo(torch.float).bits // 8
|
||||
|
||||
# KV Cache Layout for TRT-LLM
|
||||
# kv_cache_shape = (num_blocks, 2, num_kv_heads, page_size, head_dim)
|
||||
|
||||
|
||||
def to_float8(x, dtype=torch.float8_e4m3fn):
|
||||
finfo = torch.finfo(dtype)
|
||||
min_val, max_val = x.aminmax()
|
||||
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
|
||||
scale = finfo.max / amax * 0.1
|
||||
x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
|
||||
return x_scl_sat.to(dtype), scale.float().reciprocal()
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def benchmark_decode(
|
||||
num_seqs,
|
||||
max_seq_len,
|
||||
page_size=16,
|
||||
dtype=torch.bfloat16,
|
||||
kv_layout="HND",
|
||||
num_kv_heads=8,
|
||||
kv_cache_dtype="auto",
|
||||
head_dim=128,
|
||||
warmup=10,
|
||||
trials=20,
|
||||
):
|
||||
torch.set_default_device("cuda")
|
||||
device = "cuda"
|
||||
torch.manual_seed(0)
|
||||
|
||||
# Currently only HEAD_GRP_SIZE == 8 is supported
|
||||
HEAD_GRP_SIZE = 8
|
||||
MAX_SEQ_LEN = max_seq_len
|
||||
|
||||
# large number to reduce kv_cache reuse
|
||||
NUM_BLOCKS = int(256000 / page_size)
|
||||
|
||||
workspace_buffer = torch.empty(1024 * 1024 * 1024, dtype=torch.int8, device=device)
|
||||
|
||||
# For decode, batch_size is num_decode_token
|
||||
num_qo_heads = num_kv_heads * HEAD_GRP_SIZE
|
||||
sm_scale = float(1.0 / (head_dim**0.5))
|
||||
q = torch.randn(num_seqs, num_qo_heads, head_dim, device=device, dtype=dtype)
|
||||
kv_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)]
|
||||
|
||||
max_kv_len = max(kv_lens)
|
||||
kv_lens_tensor = torch.tensor(kv_lens, dtype=torch.int, device=device)
|
||||
max_num_blocks_per_seq = (max_kv_len + page_size - 1) // page_size
|
||||
|
||||
block_tables = torch.randint(
|
||||
0, NUM_BLOCKS, (num_seqs, max_num_blocks_per_seq), dtype=torch.int32
|
||||
)
|
||||
|
||||
kv_cache_shape = (NUM_BLOCKS, 2, num_kv_heads, page_size, head_dim)
|
||||
kv_cache = torch.randn(size=kv_cache_shape, device=device, dtype=dtype)
|
||||
k_scale = v_scale = 1.0
|
||||
|
||||
if kv_cache_dtype.startswith("fp8"):
|
||||
kv_cache, _ = to_float8(kv_cache)
|
||||
|
||||
# Benchmark TRT decode
|
||||
def trt_decode():
|
||||
return flashinfer.decode.trtllm_batch_decode_with_kv_cache(
|
||||
q,
|
||||
kv_cache,
|
||||
workspace_buffer,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
sm_scale,
|
||||
block_tables,
|
||||
kv_lens_tensor,
|
||||
page_size,
|
||||
max_kv_len,
|
||||
kv_cache_dtype,
|
||||
k_scale,
|
||||
v_scale,
|
||||
)
|
||||
|
||||
def time_fn(fn, warmup=10, trials=20):
|
||||
torch.cuda.synchronize()
|
||||
start = torch.cuda.Event(enable_timing=True)
|
||||
end = torch.cuda.Event(enable_timing=True)
|
||||
times = []
|
||||
for i in range(warmup):
|
||||
fn()
|
||||
for i in range(trials):
|
||||
start.record()
|
||||
fn()
|
||||
end.record()
|
||||
torch.cuda.synchronize()
|
||||
times.append(start.elapsed_time(end)) # ms
|
||||
return sum(times) / len(times), torch.std(torch.tensor(times))
|
||||
|
||||
# TRT Decode
|
||||
trt_mean, trt_std = time_fn(trt_decode)
|
||||
|
||||
kv_indptr = [0]
|
||||
kv_indices = []
|
||||
kv_last_page_lens = []
|
||||
for i in range(num_seqs):
|
||||
seq_len = kv_lens[i]
|
||||
assert seq_len > 0
|
||||
num_blocks = (seq_len + page_size - 1) // page_size
|
||||
kv_indices.extend(block_tables[i, :num_blocks])
|
||||
kv_indptr.append(kv_indptr[-1] + num_blocks)
|
||||
kv_last_page_len = seq_len % page_size
|
||||
if kv_last_page_len == 0:
|
||||
kv_last_page_len = page_size
|
||||
kv_last_page_lens.append(kv_last_page_len)
|
||||
|
||||
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
|
||||
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
|
||||
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
|
||||
|
||||
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
|
||||
workspace_buffer,
|
||||
kv_layout,
|
||||
use_tensor_cores=((num_qo_heads // num_kv_heads) > 4),
|
||||
)
|
||||
|
||||
wrapper.plan(
|
||||
kv_indptr,
|
||||
kv_indices,
|
||||
kv_last_page_lens,
|
||||
num_qo_heads,
|
||||
num_kv_heads,
|
||||
head_dim,
|
||||
page_size,
|
||||
"NONE",
|
||||
q_data_type=dtype,
|
||||
kv_data_type=torch.float8_e4m3fn if kv_cache_dtype.startswith("fp8") else dtype,
|
||||
)
|
||||
|
||||
def baseline_decode():
|
||||
return wrapper.run(q, kv_cache, sm_scale, k_scale, v_scale)
|
||||
|
||||
baseline_mean, baseline_std = time_fn(baseline_decode)
|
||||
|
||||
# Calculate percentage speedup (positive means TRT is faster)
|
||||
speedup_percent = (baseline_mean - trt_mean) / baseline_mean
|
||||
|
||||
print(
|
||||
f"\t{num_seqs}\t{max_seq_len}\t{trt_mean:.3f}\t{trt_std.item():.3f}"
|
||||
f"\t{baseline_mean:.3f}\t{baseline_std.item():.3f}\t{speedup_percent:.3f}"
|
||||
)
|
||||
|
||||
# Return results for CSV writing
|
||||
return {
|
||||
"num_seqs": num_seqs,
|
||||
"trt_mean": trt_mean,
|
||||
"trt_std": trt_std.item(),
|
||||
"baseline_mean": baseline_mean,
|
||||
"baseline_std": baseline_std.item(),
|
||||
"speedup_percent": speedup_percent,
|
||||
"q_dtype": str(dtype),
|
||||
"kv_cache_dtype": kv_cache_dtype,
|
||||
"page_size": page_size,
|
||||
"num_kv_heads": num_kv_heads,
|
||||
"head_dim": head_dim,
|
||||
"max_seq_len": max_seq_len,
|
||||
}
|
||||
|
||||
|
||||
def write_results_to_csv(results, filename=None):
|
||||
"""Write benchmark results to CSV file."""
|
||||
if filename is None:
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
filename = f"flashinfer_trtllm_benchmark_{timestamp}.csv"
|
||||
|
||||
fieldnames = [
|
||||
"num_seqs",
|
||||
"trt_mean",
|
||||
"trt_std",
|
||||
"baseline_mean",
|
||||
"baseline_std",
|
||||
"speedup_percent",
|
||||
"q_dtype",
|
||||
"kv_cache_dtype",
|
||||
"page_size",
|
||||
"num_kv_heads",
|
||||
"head_dim",
|
||||
"max_seq_len",
|
||||
]
|
||||
|
||||
file_exists = os.path.exists(filename)
|
||||
|
||||
with open(filename, "a", newline="") as csvfile:
|
||||
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
|
||||
|
||||
if not file_exists:
|
||||
writer.writeheader()
|
||||
|
||||
for result in results:
|
||||
writer.writerow(result)
|
||||
|
||||
print(f"Results written to {filename}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
num_seqs = [1, 4, 8, 16, 32, 64, 128, 256]
|
||||
max_seq_lens = [1024, 2048, 4096, 8192, 16384, 32768, 65536, 131072]
|
||||
all_results = []
|
||||
|
||||
print("Running benchmark for kv_cache_dtype: bfloat16")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="auto"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
print("Running benchmark for q_dtype = bfloat16, kv_cache_dtype: fp8")
|
||||
print(
|
||||
"\tnum_seqs\tmax_seq_len\ttrt_mean\ttrt_std\tbaseline_mean\tbaseline_std\tspeedup_percent"
|
||||
)
|
||||
for max_seq_len in max_seq_lens:
|
||||
for bs in num_seqs:
|
||||
result = benchmark_decode(
|
||||
bs, max_seq_len, dtype=torch.bfloat16, kv_cache_dtype="fp8"
|
||||
)
|
||||
all_results.append(result)
|
||||
|
||||
# Write all results to CSV
|
||||
write_results_to_csv(all_results)
|
108
benchmarks/kv_cache/benchmark_block_pool.py
Normal file
108
benchmarks/kv_cache/benchmark_block_pool.py
Normal file
@ -0,0 +1,108 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import gc
|
||||
import time
|
||||
from typing import Optional
|
||||
|
||||
from tabulate import tabulate
|
||||
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.v1.core.block_pool import BlockPool
|
||||
|
||||
|
||||
class Metric:
|
||||
def __init__(self) -> None:
|
||||
self.cnt: int = 0
|
||||
self.sum_v: int = 0
|
||||
self.max_v: Optional[int] = None
|
||||
|
||||
def update(self, v: int) -> None:
|
||||
self.cnt += 1
|
||||
self.sum_v += v
|
||||
if self.max_v is None:
|
||||
self.max_v = v
|
||||
else:
|
||||
self.max_v = max(self.max_v, v)
|
||||
|
||||
def avg_v(self) -> float:
|
||||
return self.sum_v * 1.0 / self.cnt
|
||||
|
||||
|
||||
def main(args):
|
||||
rows = []
|
||||
for allocate_block in args.allocate_blocks:
|
||||
# Enforce a GC collect ahead to minimize the impact among runs
|
||||
gc.collect()
|
||||
block_pool = BlockPool(num_gpu_blocks=args.num_gpu_blocks, enable_caching=True)
|
||||
|
||||
get_blocks_metric: Metric = Metric()
|
||||
free_blocks_metric: Metric = Metric()
|
||||
for _ in range(args.num_iteration):
|
||||
t1 = time.monotonic_ns()
|
||||
blocks = block_pool.get_new_blocks(allocate_block)
|
||||
t2 = time.monotonic_ns()
|
||||
block_pool.free_blocks(blocks)
|
||||
t3 = time.monotonic_ns()
|
||||
get_blocks_metric.update(t2 - t1)
|
||||
free_blocks_metric.update(t3 - t2)
|
||||
|
||||
if get_blocks_metric.max_v is not None and free_blocks_metric.max_v is not None:
|
||||
rows.append(
|
||||
[
|
||||
get_blocks_metric.cnt,
|
||||
args.num_gpu_blocks,
|
||||
allocate_block,
|
||||
get_blocks_metric.avg_v() / 1000000,
|
||||
get_blocks_metric.max_v / 1000000.0,
|
||||
free_blocks_metric.avg_v() / 1000000,
|
||||
free_blocks_metric.max_v / 1000000.0,
|
||||
]
|
||||
)
|
||||
else:
|
||||
print(
|
||||
"No valid metrics found."
|
||||
f" {get_blocks_metric.max_v=} {free_blocks_metric.max_v=}"
|
||||
)
|
||||
|
||||
print(
|
||||
tabulate(
|
||||
rows,
|
||||
headers=[
|
||||
"Iterations",
|
||||
"Total\nBlocks",
|
||||
"Allocated\nBlocks",
|
||||
"Get Blocks\nAvg (ms)",
|
||||
"Get Blocks\nMax (ms)",
|
||||
"Free Blocks\nAvg (ms)",
|
||||
"Free Blocks\nMax (ms)",
|
||||
],
|
||||
tablefmt="grid",
|
||||
floatfmt=".6f",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def invoke_main() -> None:
|
||||
parser = FlexibleArgumentParser(
|
||||
description="Benchmark the performance of BlockPool for KV Cache."
|
||||
)
|
||||
parser.add_argument("--num-gpu-blocks", type=int, default=100000)
|
||||
parser.add_argument(
|
||||
"--num-iteration",
|
||||
type=int,
|
||||
default=1000,
|
||||
help="Number of iterations to run to stablize final data readings",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allocate-blocks",
|
||||
type=int,
|
||||
nargs="*",
|
||||
default=[10, 50, 100, 500, 1000],
|
||||
help="Number of blocks to allocate",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
invoke_main() # pragma: no cover
|
@ -12,9 +12,8 @@ endif()
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
endif()
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
@ -107,10 +106,19 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
||||
if (AVX512VNNI_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
endif()
|
||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
@ -157,17 +165,32 @@ else()
|
||||
endif()
|
||||
|
||||
#
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 platforms)
|
||||
#
|
||||
if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
# Build oneDNN for W8A8 GEMM kernels (only for x86-AVX512 /ARM platforms)
|
||||
# Flag to enable ACL kernels for AARCH64 platforms
|
||||
if ( VLLM_BUILD_ACL STREQUAL "ON")
|
||||
set(USE_ACL ON)
|
||||
else()
|
||||
set(USE_ACL OFF)
|
||||
endif()
|
||||
|
||||
if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR ASIMD_FOUND)
|
||||
FetchContent_Declare(
|
||||
oneDNN
|
||||
GIT_REPOSITORY https://github.com/oneapi-src/oneDNN.git
|
||||
GIT_TAG v3.7.1
|
||||
GIT_TAG v3.8.1
|
||||
GIT_PROGRESS TRUE
|
||||
GIT_SHALLOW TRUE
|
||||
)
|
||||
|
||||
if(USE_ACL)
|
||||
find_library(ARM_COMPUTE_LIBRARY NAMES arm_compute PATHS $ENV{ACL_ROOT_DIR}/build/)
|
||||
if(NOT ARM_COMPUTE_LIBRARY)
|
||||
message(FATAL_ERROR "Could not find ARM Compute Library: please set ACL_ROOT_DIR")
|
||||
endif()
|
||||
set(ONEDNN_AARCH64_USE_ACL "ON")
|
||||
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wl,-rpath,$ENV{ACL_ROOT_DIR}/build/")
|
||||
endif()
|
||||
|
||||
set(ONEDNN_LIBRARY_TYPE "STATIC")
|
||||
set(ONEDNN_BUILD_DOC "OFF")
|
||||
set(ONEDNN_BUILD_EXAMPLES "OFF")
|
||||
@ -256,6 +279,13 @@ elseif(POWER10_FOUND)
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
if (ASIMD_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
|
@ -24,6 +24,7 @@
|
||||
|
||||
#include "attention_dtypes.h"
|
||||
#include "attention_utils.cuh"
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#ifdef USE_ROCM
|
||||
#include <hip/hip_bf16.h>
|
||||
@ -33,12 +34,6 @@ typedef __hip_bfloat16 __nv_bfloat16;
|
||||
#include "../quantization/fp8/nvidia/quant_utils.cuh"
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
|
||||
@ -670,7 +665,6 @@ __global__ void paged_attention_v2_reduce_kernel(
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
372
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
372
csrc/attention/mla/cutlass_sm100_mla/device/sm100_mla.hpp
Normal file
@ -0,0 +1,372 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
/*!
|
||||
\file
|
||||
\brief An universal device layer for cutlass 3.x-style kernels.
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
// common
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/device_kernel.h"
|
||||
|
||||
#if !defined(__CUDACC_RTC__)
|
||||
#include "cutlass/cluster_launch.hpp"
|
||||
#include "cutlass/trace.h"
|
||||
#endif // !defined(__CUDACC_RTC__)
|
||||
|
||||
#include "../kernel/sm100_fmha_mla_tma_warpspecialized.hpp"
|
||||
#include "../kernel/sm100_fmha_mla_reduction.hpp"
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
namespace cutlass::fmha::device {
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
////////////////////////////// CUTLASS 3.x API /////////////////////////////////
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<
|
||||
class Kernel_
|
||||
>
|
||||
class MLA {
|
||||
public:
|
||||
|
||||
using Kernel = Kernel_;
|
||||
|
||||
using ReductionKernel = cutlass::fmha::kernel::Sm100FmhaMlaReductionKernel<
|
||||
typename Kernel::ElementOut,
|
||||
typename Kernel::ElementAcc,
|
||||
typename Kernel::ElementAcc,
|
||||
Kernel::TileShapeH::value,
|
||||
Kernel::TileShapeL::value,
|
||||
256 /*Max split*/
|
||||
>;
|
||||
|
||||
/// Argument structure: User API
|
||||
using KernelArguments = typename Kernel::Arguments;
|
||||
using ReductionArguments = typename ReductionKernel::Arguments;
|
||||
|
||||
using Arguments = KernelArguments;
|
||||
|
||||
/// Argument structure: Kernel API
|
||||
using KernelParams = typename Kernel::Params;
|
||||
using ReductionParams = typename ReductionKernel::Params;
|
||||
struct Params {
|
||||
KernelParams fmha_params;
|
||||
ReductionParams reduction_params;
|
||||
};
|
||||
|
||||
private:
|
||||
|
||||
/// Kernel API parameters object
|
||||
Params params_;
|
||||
|
||||
bool is_initialized(bool set = false) {
|
||||
static bool initialized = false;
|
||||
if (set) initialized = true;
|
||||
return initialized;
|
||||
}
|
||||
|
||||
static ReductionArguments to_reduction_args(Arguments const& args) {
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
return ReductionArguments{
|
||||
nullptr, args.epilogue.ptr_o, nullptr, args.epilogue.ptr_lse,
|
||||
args.mainloop.softmax_scale, B, args.split_kv, K, args.mainloop.ptr_seq,
|
||||
args.ptr_split_kv, Kernel::TileShapeS::value
|
||||
};
|
||||
}
|
||||
|
||||
public:
|
||||
|
||||
/// Access the Params structure
|
||||
Params const& params() const {
|
||||
return params_;
|
||||
}
|
||||
|
||||
static void set_split_kv (KernelArguments& args) {
|
||||
// printf("set_split_kv start");
|
||||
if (args.split_kv >= 1) return;
|
||||
auto [H, K, D, B] = args.problem_shape;
|
||||
// std::cout << H << " " << K << " " << D << " " << B << "\n";
|
||||
int sm_count = args.hw_info.sm_count;
|
||||
// printf(" sm_count = %d\n", sm_count);
|
||||
int max_splits = ceil_div(K, 128);
|
||||
max_splits = min(16, max_splits);
|
||||
// printf(" max_splits = %d\n", max_splits);
|
||||
int sms_per_batch = max(1, sm_count / B);
|
||||
// printf(" sms_per_batch = %d\n", sms_per_batch);
|
||||
int split_heur = min(max_splits, sms_per_batch);
|
||||
int waves = ceil_div(B * split_heur, sm_count);
|
||||
int k_waves = ceil_div(max_splits, split_heur);
|
||||
int split_wave_aware = ceil_div(max_splits, k_waves);
|
||||
args.split_kv = split_wave_aware;
|
||||
// printf(" args.split_kv = %d\n", args.split_kv);
|
||||
|
||||
}
|
||||
|
||||
/// Determines whether the GEMM can execute the given problem.
|
||||
static Status
|
||||
can_implement(Arguments const& args) {
|
||||
if (! Kernel::can_implement(args)) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
if (! ReductionKernel::can_implement(to_reduction_args(args))) {
|
||||
return Status::kInvalid;
|
||||
}
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Gets the workspace size
|
||||
static size_t
|
||||
get_workspace_size(Arguments const& args) {
|
||||
size_t workspace_bytes = 0;
|
||||
workspace_bytes += Kernel::get_workspace_size(args);
|
||||
workspace_bytes += ReductionKernel::get_workspace_size(to_reduction_args(args));
|
||||
return workspace_bytes;
|
||||
}
|
||||
|
||||
/// Computes the maximum number of active blocks per multiprocessor
|
||||
static int maximum_active_blocks(int /* smem_capacity */ = -1) {
|
||||
CUTLASS_TRACE_HOST("MLA::maximum_active_blocks()");
|
||||
int max_active_blocks = -1;
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
// first, account for dynamic smem capacity if needed
|
||||
cudaError_t result;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaFuncSetAttribute() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
}
|
||||
|
||||
// query occupancy after setting smem size
|
||||
result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
|
||||
&max_active_blocks,
|
||||
device_kernel<Kernel>,
|
||||
Kernel::MaxThreadsPerBlock,
|
||||
smem_size);
|
||||
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(
|
||||
" cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: "
|
||||
<< cudaGetErrorString(result));
|
||||
return -1;
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
|
||||
return max_active_blocks;
|
||||
}
|
||||
|
||||
/// Initializes GEMM state from arguments.
|
||||
Status
|
||||
initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::initialize() - workspace "
|
||||
<< workspace << ", stream: " << (stream ? "non-null" : "null"));
|
||||
|
||||
// Initialize the workspace
|
||||
Status status = Kernel::initialize_workspace(args, workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
status = ReductionKernel::initialize_workspace(to_reduction_args(args), workspace, stream);
|
||||
if (status != Status::kSuccess) {
|
||||
return status;
|
||||
}
|
||||
KernelParams kernel_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = kernel_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = kernel_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {kernel_params, reduction_params};
|
||||
|
||||
if (is_initialized()) return Status::kSuccess;
|
||||
|
||||
// account for dynamic smem capacity if needed
|
||||
// no dynamic smem is needed for reduction kernel
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
if (smem_size >= (48 << 10)) {
|
||||
CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size);
|
||||
cudaError_t result = cudaFuncSetAttribute(
|
||||
device_kernel<Kernel>,
|
||||
cudaFuncAttributeMaxDynamicSharedMemorySize,
|
||||
smem_size);
|
||||
if (cudaSuccess != result) {
|
||||
result = cudaGetLastError(); // to clear the error bit
|
||||
CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result));
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
|
||||
is_initialized(true);
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Update API is preserved in 3.0, but does not guarantee a lightweight update of params.
|
||||
Status
|
||||
update(Arguments const& args, void* workspace = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA()::update() - workspace: " << workspace);
|
||||
|
||||
size_t workspace_bytes = get_workspace_size(args);
|
||||
if (workspace_bytes > 0 && nullptr == workspace) {
|
||||
return Status::kErrorWorkspaceNull;
|
||||
}
|
||||
|
||||
auto fmha_params = Kernel::to_underlying_arguments(args, workspace);
|
||||
|
||||
ReductionArguments reduction_args = to_reduction_args(args);
|
||||
if (reduction_args.split_kv > 1) {
|
||||
reduction_args.ptr_oaccum = fmha_params.epilogue.ptr_o_acc;
|
||||
reduction_args.ptr_lseaccum = fmha_params.epilogue.ptr_lse_acc;
|
||||
}
|
||||
ReductionParams reduction_params = ReductionKernel::to_underlying_arguments(reduction_args, workspace);
|
||||
// Initialize the Params structure
|
||||
params_ = Params {fmha_params, reduction_params};
|
||||
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
/// Primary run() entry point API that is static allowing users to create and manage their own params.
|
||||
/// Supplied params struct must be construct by calling Kernel::to_underling_arguments()
|
||||
static Status
|
||||
run(Params& params, cudaStream_t stream = nullptr) {
|
||||
CUTLASS_TRACE_HOST("MLA::run()");
|
||||
dim3 const block = Kernel::get_block_shape();
|
||||
dim3 const grid = Kernel::get_grid_shape(params.fmha_params);
|
||||
|
||||
// configure smem size and carveout
|
||||
int smem_size = Kernel::SharedStorageSize;
|
||||
|
||||
Status launch_result;
|
||||
// Use extended launch API only for mainloops that use it
|
||||
if constexpr(Kernel::ArchTag::kMinComputeCapability >= 90) {
|
||||
dim3 cluster(cute::size<0>(typename Kernel::ClusterShape{}),
|
||||
cute::size<1>(typename Kernel::ClusterShape{}),
|
||||
cute::size<2>(typename Kernel::ClusterShape{}));
|
||||
void const* kernel = (void const*) device_kernel<Kernel>;
|
||||
void* kernel_params[] = {¶ms.fmha_params};
|
||||
launch_result = ClusterLauncher::launch(grid, cluster, block, smem_size, stream, kernel, kernel_params);
|
||||
}
|
||||
else {
|
||||
launch_result = Status::kSuccess;
|
||||
device_kernel<Kernel><<<grid, block, smem_size, stream>>>(params.fmha_params);
|
||||
}
|
||||
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess != result or Status::kSuccess != launch_result) {
|
||||
//return Status::kSuccess;
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
if (params.reduction_params.split_kv > 1) {
|
||||
// launch reduction kernel
|
||||
dim3 const block = ReductionKernel::get_block_shape();
|
||||
dim3 const grid = ReductionKernel::get_grid_shape(params.reduction_params);
|
||||
device_kernel<ReductionKernel><<<grid, block, 0, stream>>>(params.reduction_params);
|
||||
cudaError_t result = cudaGetLastError();
|
||||
if (cudaSuccess == result) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
else {
|
||||
CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result);
|
||||
return Status::kErrorInternal;
|
||||
}
|
||||
}
|
||||
else {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// Non-static launch overloads that first create and set the internal params struct of this kernel handle.
|
||||
//
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
run(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
Status status = initialize(args, workspace, stream);
|
||||
if (Status::kSuccess == status) {
|
||||
status = run(params_, stream);
|
||||
}
|
||||
return status;
|
||||
}
|
||||
|
||||
/// Launches the kernel after first constructing Params internal state from supplied arguments.
|
||||
Status
|
||||
operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr) {
|
||||
return run(args, workspace, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
run(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
|
||||
/// Overload that allows a user to re-launch the same kernel without updating internal params struct.
|
||||
Status
|
||||
operator()(cudaStream_t stream = nullptr) {
|
||||
return run(params_, stream);
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::device
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
@ -0,0 +1,203 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/arch/arch.h"
|
||||
#include "cute/tensor.hpp"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
using namespace cute;
|
||||
template<
|
||||
class ElementOut,
|
||||
class ElementAcc,
|
||||
class ElementScale,
|
||||
size_t kNumHeads,
|
||||
size_t kHeadDimLatent,
|
||||
int kMaxSplits
|
||||
>
|
||||
struct Sm100FmhaMlaReductionKernel {
|
||||
|
||||
static const int SharedStorageSize = 0;
|
||||
static const int MaxThreadsPerBlock = 128;
|
||||
static const int MinBlocksPerMultiprocessor = 1;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
static_assert(kHeadDimLatent % MaxThreadsPerBlock == 0);
|
||||
struct Arguments {
|
||||
ElementAcc* ptr_oaccum = nullptr;
|
||||
ElementOut* ptr_o = nullptr;
|
||||
ElementAcc* ptr_lseaccum = nullptr;
|
||||
ElementAcc* ptr_lse = nullptr;
|
||||
ElementScale scale = 1.f;
|
||||
int num_batches = 0;
|
||||
int split_kv = -1;
|
||||
int dim_k = -1;
|
||||
int* ptr_seq = nullptr;
|
||||
int* ptr_split_kv = nullptr;
|
||||
int tile_shape_s = 128;
|
||||
};
|
||||
using Params = Arguments;
|
||||
|
||||
static Params to_underlying_arguments(Arguments const& args, void* workspace) {
|
||||
return {args.ptr_oaccum, args.ptr_o, args.ptr_lseaccum, args.ptr_lse,
|
||||
args.scale, args.num_batches, args.split_kv, args.dim_k, args.ptr_seq,
|
||||
args.ptr_split_kv, args.tile_shape_s};
|
||||
}
|
||||
|
||||
static size_t get_workspace_size(Arguments const& /*args*/) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
static Status initialize_workspace(
|
||||
Arguments const& /*args*/, void* /*ws*/, cudaStream_t /*stream*/) {
|
||||
return Status::kSuccess;
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return dim3(kNumHeads, 1, params.num_batches);
|
||||
}
|
||||
|
||||
static dim3 get_block_shape() {
|
||||
return dim3(MaxThreadsPerBlock, 1, 1);
|
||||
}
|
||||
|
||||
static bool can_implement(Arguments const& args) {
|
||||
if (args.num_batches <= 0) return false;
|
||||
if (args.split_kv <= 0) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE void operator() (Params const& params, char* smem_raw) {
|
||||
if (params.split_kv <= 1) return;
|
||||
auto blk_coord = make_coord(blockIdx.x, _0{}, blockIdx.z);
|
||||
|
||||
__shared__ ElementAcc sLseScale[kMaxSplits];
|
||||
const size_t offset_lseaccum = get<0>(blk_coord) + kNumHeads * params.split_kv * get<2>(blk_coord);
|
||||
const size_t offset_lse = get<0>(blk_coord) + kNumHeads * get<2>(blk_coord);
|
||||
|
||||
Tensor gLSEaccum = make_tensor(make_gmem_ptr(params.ptr_lseaccum + offset_lseaccum),
|
||||
make_shape(params.split_kv), Stride<Int<kNumHeads>>{});
|
||||
|
||||
Tensor gLSE = make_tensor(make_gmem_ptr(params.ptr_lse + offset_lse),
|
||||
Shape<_1>{}, Stride<_1>{});
|
||||
|
||||
auto dim_k = params.ptr_seq == nullptr ? params.dim_k : params.ptr_seq[get<2>(blk_coord)];
|
||||
auto local_split_kv = params.ptr_split_kv == nullptr ? params.split_kv : params.ptr_split_kv[get<2>(blk_coord)];
|
||||
auto k_tile_total = ceil_div(dim_k, params.tile_shape_s);
|
||||
auto k_tile_per_cta = ceil_div(k_tile_total, local_split_kv);
|
||||
local_split_kv = ceil_div(k_tile_total, k_tile_per_cta);
|
||||
|
||||
int warp_idx = cutlass::canonical_warp_idx_sync();
|
||||
if (warp_idx == 0) {
|
||||
constexpr int kNLsePerThread = cute::ceil_div(kMaxSplits, 32);
|
||||
|
||||
ElementAcc local_lse[kNLsePerThread];
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
local_lse[i] = split < local_split_kv ? gLSEaccum(split) : -std::numeric_limits<ElementAcc>::infinity();
|
||||
}
|
||||
|
||||
ElementAcc lse_max = -std::numeric_limits<ElementAcc>::infinity();
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
lse_max = max(lse_max, local_lse[i]);
|
||||
}
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
lse_max = max(lse_max, __shfl_xor_sync(0xffffffff, lse_max, offset));
|
||||
}
|
||||
lse_max = lse_max == -std::numeric_limits<ElementAcc>::infinity() ? 0.0f : lse_max; // In case all local LSEs are -inf
|
||||
lse_max = __shfl_sync(0xffffffff, lse_max, 0);
|
||||
|
||||
ElementAcc sum_lse = 0;
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
sum_lse = sum_lse + expf(local_lse[i] - lse_max);
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int offset = 16; offset >= 1; offset /= 2) {
|
||||
sum_lse = sum_lse + __shfl_xor_sync(0xffffffff, sum_lse, offset);
|
||||
}
|
||||
|
||||
sum_lse = __shfl_sync(0xffffffff, sum_lse, 0);
|
||||
|
||||
ElementAcc global_lse = (sum_lse == 0.f || sum_lse != sum_lse) ? std::numeric_limits<ElementAcc>::infinity() : logf(sum_lse) + lse_max;
|
||||
if (threadIdx.x == 0 and params.ptr_lse != nullptr) {
|
||||
gLSE(0) = global_lse;
|
||||
}
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for (int i = 0; i < kNLsePerThread; ++i) {
|
||||
const int split = i * 32 + threadIdx.x;
|
||||
if (split < local_split_kv) {
|
||||
sLseScale[split] = expf(local_lse[i] - global_lse);
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
constexpr int Elements = kHeadDimLatent / MaxThreadsPerBlock;
|
||||
const size_t offset_oaccum = kHeadDimLatent * params.split_kv * (get<0>(blk_coord) + kNumHeads * get<2>(blk_coord));
|
||||
Tensor gOaccum = make_tensor(make_gmem_ptr(params.ptr_oaccum + offset_oaccum),
|
||||
Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
ElementAcc local_val[Elements] = {0};
|
||||
for (int split = 0; split < local_split_kv; ++split) {
|
||||
ElementAcc lse_scale = sLseScale[split];
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
local_val[i] += lse_scale * gOaccum(threadIdx.x + MaxThreadsPerBlock * i);
|
||||
}
|
||||
gOaccum.data() = gOaccum.data() + kHeadDimLatent;
|
||||
}
|
||||
auto ptr_o_local = params.ptr_o + (get<0>(blk_coord) + get<2>(blk_coord) * kNumHeads) * kHeadDimLatent;
|
||||
Tensor gO = make_tensor(make_gmem_ptr(ptr_o_local), Shape<Int<kHeadDimLatent>>{}, Stride<_1>{});
|
||||
|
||||
CUTLASS_PRAGMA_UNROLL
|
||||
for(int i = 0; i < Elements; ++i) {
|
||||
gO(threadIdx.x + MaxThreadsPerBlock * i) = static_cast<ElementOut>(local_val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,165 @@
|
||||
/***************************************************************************************************
|
||||
* Copyright (c) 2024 - 2025 NVIDIA CORPORATION & AFFILIATES. All rights
|
||||
*reserved. SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice,
|
||||
*this list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
|
||||
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
|
||||
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
|
||||
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
|
||||
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
|
||||
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
|
||||
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
|
||||
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
|
||||
*POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
|
||||
// clang-format off
|
||||
#pragma once
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "cutlass/fast_math.h"
|
||||
#include "cutlass/kernel_hardware_info.h"
|
||||
|
||||
namespace cutlass::fmha::kernel {
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaIndividualTileScheduler {
|
||||
|
||||
struct Params {
|
||||
dim3 grid;
|
||||
};
|
||||
|
||||
bool valid_ = true;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler(Params const&) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
dim3 grid(get<0>(cluster_shape), get<3>(problem_shape) /* Batch */, split_kv /*Maximum Split KV*/);
|
||||
return Params{ grid };
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
return params.grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return valid_;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
return make_coord(blockIdx.x, _0{}, blockIdx.y, blockIdx.z);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaIndividualTileScheduler& operator++() {
|
||||
valid_ = false;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct Sm100MlaPersistentTileScheduler {
|
||||
|
||||
struct Params {
|
||||
int num_blocks;
|
||||
FastDivmod divmod_m_block;
|
||||
FastDivmod divmod_b;
|
||||
FastDivmod divmod_split_kv;
|
||||
KernelHardwareInfo hw_info;
|
||||
};
|
||||
|
||||
int block_idx = 0;
|
||||
Params params;
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler(Params const& params) : block_idx(blockIdx.x), params(params) {}
|
||||
|
||||
template<class ProblemShape, class ClusterShape>
|
||||
static Params to_underlying_arguments(
|
||||
ProblemShape const& problem_shape, KernelHardwareInfo hw_info,
|
||||
ClusterShape const& cluster_shape, int const& split_kv) {
|
||||
using namespace cute;
|
||||
// Get SM count if needed, otherwise use user supplied SM count
|
||||
int sm_count = hw_info.sm_count;
|
||||
if (sm_count <= 1 || sm_count % size<0>(cluster_shape) != 0) {
|
||||
CUTLASS_TRACE_HOST(" WARNING: Arguments do not include a valid SM count.\n"
|
||||
" For optimal performance, populate the arguments KernelHardwareInfo struct with the SM count.");
|
||||
sm_count = KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
}
|
||||
|
||||
CUTLASS_TRACE_HOST("to_underlying_arguments(): Setting persistent grid SM count to " << sm_count);
|
||||
hw_info.sm_count = sm_count;
|
||||
|
||||
int num_m_blocks = size<0>(cluster_shape);
|
||||
int num_blocks = num_m_blocks * get<3>(problem_shape) /* Batch */;
|
||||
num_blocks *= split_kv; /* Maximum Split KV*/
|
||||
|
||||
return Params {
|
||||
num_blocks,
|
||||
{ num_m_blocks}, { get<3>(problem_shape) }, {split_kv},
|
||||
hw_info
|
||||
};
|
||||
}
|
||||
|
||||
static dim3 get_grid_shape(Params const& params) {
|
||||
dim3 grid(std::min(params.num_blocks, params.hw_info.sm_count), 1, 1);
|
||||
return grid;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
bool is_valid() {
|
||||
return block_idx < params.num_blocks;
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
auto get_block_coord() {
|
||||
using namespace cute;
|
||||
int block_decode = block_idx;
|
||||
int m_block, bidb, n_split_kv;
|
||||
params.divmod_m_block(block_decode, m_block, block_decode);
|
||||
params.divmod_b(block_decode, bidb, block_decode);
|
||||
params.divmod_split_kv(block_decode, n_split_kv, block_decode);
|
||||
return make_coord(m_block, _0{}, bidb, n_split_kv);
|
||||
}
|
||||
|
||||
CUTLASS_DEVICE
|
||||
Sm100MlaPersistentTileScheduler& operator++() {
|
||||
block_idx += gridDim.x;
|
||||
return *this;
|
||||
}
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
} // namespace cutlass::fmha::kernel
|
283
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
Normal file
283
csrc/attention/mla/sm100_cutlass_mla_kernel.cu
Normal file
@ -0,0 +1,283 @@
|
||||
/*
|
||||
Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
/*
|
||||
* Taken from SGLANG PR https://github.com/sgl-project/sglang/pull/6929
|
||||
* by Alcanderian JieXin Liang
|
||||
*/
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cutlass/cutlass.h>
|
||||
#include <cutlass/kernel_hardware_info.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include <cute/tensor.hpp>
|
||||
#include <iostream>
|
||||
|
||||
#include "cutlass_sm100_mla/device/sm100_mla.hpp"
|
||||
#include "cutlass_sm100_mla/kernel/sm100_mla_tile_scheduler.hpp"
|
||||
|
||||
// clang-format off
|
||||
#if !defined(CUDA_VERSION) || CUDA_VERSION < 12040
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_decode");
|
||||
}
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
TORCH_CHECK(false, "CUDA version must be >= 12.4 for cutlass_mla_get_workspace_size");
|
||||
}
|
||||
#else
|
||||
|
||||
#define CUTLASS_CHECK(status) \
|
||||
{ \
|
||||
cutlass::Status error = status; \
|
||||
TORCH_CHECK(error == cutlass::Status::kSuccess, cutlassGetStatusString(error)); \
|
||||
}
|
||||
|
||||
using namespace cute;
|
||||
using namespace cutlass::fmha::kernel;
|
||||
|
||||
template <bool v>
|
||||
struct IsPersistent {
|
||||
static const bool value = v;
|
||||
};
|
||||
|
||||
template <typename T, bool IsPaged128, typename PersistenceOption = IsPersistent<true>>
|
||||
struct MlaSm100 {
|
||||
using Element = T;
|
||||
using ElementAcc = float;
|
||||
using ElementOut = T;
|
||||
|
||||
using TileShape = Shape<_128, _128, Shape<_512, _64>>;
|
||||
using TileShapeH = cute::tuple_element_t<0, TileShape>;
|
||||
using TileShapeD = cute::tuple_element_t<2, TileShape>;
|
||||
|
||||
// H K (D_latent D_rope) B
|
||||
using ProblemShape = cute::tuple<TileShapeH, int, TileShapeD, int>;
|
||||
|
||||
using StrideQ = cute::tuple<int64_t, _1, int64_t>; // H D B
|
||||
using StrideK = cute::tuple<int64_t, _1, int64_t>; // K D B
|
||||
using StrideO = StrideK; // H D B
|
||||
using StrideLSE = cute::tuple<_1, int>; // H B
|
||||
|
||||
using TileScheduler =
|
||||
std::conditional_t<PersistenceOption::value, Sm100MlaPersistentTileScheduler, Sm100MlaIndividualTileScheduler>;
|
||||
|
||||
using FmhaKernel = cutlass::fmha::kernel::Sm100FmhaMlaKernelTmaWarpspecialized<
|
||||
TileShape,
|
||||
Element,
|
||||
ElementAcc,
|
||||
ElementOut,
|
||||
ElementAcc,
|
||||
TileScheduler,
|
||||
/*kIsCpAsync=*/!IsPaged128>;
|
||||
using Fmha = cutlass::fmha::device::MLA<FmhaKernel>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Fmha::Arguments args_from_options(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = q_nope.device().index();
|
||||
hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id);
|
||||
|
||||
int batches = q_nope.sizes()[0];
|
||||
int page_count_per_seq = page_table.sizes()[1];
|
||||
int page_count_total = kv_c_and_k_pe_cache.sizes()[0];
|
||||
int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
int max_seq_len = page_size * page_count_per_seq;
|
||||
using TileShapeH = typename T::TileShapeH;
|
||||
using TileShapeD = typename T::TileShapeD;
|
||||
auto problem_shape = cute::make_tuple(TileShapeH{}, max_seq_len, TileShapeD{}, batches);
|
||||
|
||||
auto [H, K, D, B] = problem_shape;
|
||||
auto [D_latent, D_rope] = D;
|
||||
|
||||
float scale = float(sm_scale);
|
||||
|
||||
using StrideQ = typename T::StrideQ;
|
||||
using StrideK = typename T::StrideK;
|
||||
using StrideO = typename T::StrideO;
|
||||
using StrideLSE = typename T::StrideLSE;
|
||||
|
||||
StrideQ stride_Q_nope = cute::make_tuple(
|
||||
static_cast<int64_t>(q_nope.stride(1)), _1{}, static_cast<int64_t>(q_nope.stride(0)));
|
||||
StrideQ stride_Q_pe = cute::make_tuple(
|
||||
static_cast<int64_t>(q_pe.stride(1)), _1{}, static_cast<int64_t>(q_pe.stride(0)));
|
||||
|
||||
StrideK stride_C = cute::make_tuple(
|
||||
static_cast<int64_t>(0 + D_latent + D_rope), _1{}, static_cast<int64_t>(page_size * (D_latent + D_rope)));
|
||||
StrideLSE stride_PT = cute::make_stride(_1{}, page_count_per_seq);
|
||||
StrideLSE stride_LSE = cute::make_tuple(_1{}, 0 + H);
|
||||
StrideO stride_O = cute::make_tuple(static_cast<int64_t>(0 + D_latent), _1{}, static_cast<int64_t>(0 + H * D_latent));
|
||||
|
||||
using Element = typename T::Element;
|
||||
using ElementOut = typename T::ElementOut;
|
||||
using ElementAcc = typename T::ElementAcc;
|
||||
auto Q_nope_ptr = static_cast<Element*>(q_nope.data_ptr());
|
||||
auto Q_pe_ptr = static_cast<Element*>(q_pe.data_ptr());
|
||||
auto C_ptr = static_cast<Element*>(kv_c_and_k_pe_cache.data_ptr());
|
||||
typename T::Fmha::Arguments arguments{
|
||||
problem_shape,
|
||||
{scale,
|
||||
Q_nope_ptr,
|
||||
stride_Q_nope,
|
||||
Q_pe_ptr,
|
||||
stride_Q_pe,
|
||||
C_ptr,
|
||||
stride_C,
|
||||
C_ptr + D_latent,
|
||||
stride_C,
|
||||
static_cast<int*>(seq_lens.data_ptr()),
|
||||
static_cast<int*>(page_table.data_ptr()),
|
||||
stride_PT,
|
||||
page_count_total,
|
||||
page_size},
|
||||
{static_cast<ElementOut*>(out.data_ptr()), stride_O, static_cast<ElementAcc*>(nullptr), stride_LSE},
|
||||
hw_info,
|
||||
// TODO(trevor-m): Change split_kv back to -1 when
|
||||
// https://github.com/NVIDIA/cutlass/issues/2274 is fixed. Split_kv=1 will
|
||||
// perform worse with larger context length and smaller batch sizes.
|
||||
num_kv_splits, // split_kv
|
||||
nullptr, // is_var_split_kv
|
||||
};
|
||||
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
|
||||
// split_kv automatically based on batch size and sequence length to balance
|
||||
// workload across available SMs. Consider using var_split_kv for manual
|
||||
// control if needed.
|
||||
T::Fmha::set_split_kv(arguments);
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename Element, bool IsPaged128, typename PersistenceOption>
|
||||
void runMla(
|
||||
at::Tensor const& out,
|
||||
at::Tensor const& q_nope,
|
||||
at::Tensor const& q_pe,
|
||||
at::Tensor const& kv_c_and_k_pe_cache,
|
||||
at::Tensor const& seq_lens,
|
||||
at::Tensor const& page_table,
|
||||
at::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits,
|
||||
cudaStream_t stream) {
|
||||
using MlaSm100Type = MlaSm100<Element, IsPaged128, PersistenceOption>;
|
||||
typename MlaSm100Type::Fmha fmha;
|
||||
auto arguments = args_from_options<MlaSm100Type>(out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, sm_scale, num_kv_splits);
|
||||
|
||||
CUTLASS_CHECK(fmha.can_implement(arguments));
|
||||
|
||||
CUTLASS_CHECK(fmha.initialize(arguments, workspace.data_ptr(), stream));
|
||||
|
||||
CUTLASS_CHECK(fmha.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
#define DISPATCH_BOOL(expr, const_expr, ...) \
|
||||
[&]() -> bool { \
|
||||
if (expr) { \
|
||||
constexpr bool const_expr = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool const_expr = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
void sm100_cutlass_mla_decode(
|
||||
torch::Tensor const& out,
|
||||
torch::Tensor const& q_nope,
|
||||
torch::Tensor const& q_pe,
|
||||
torch::Tensor const& kv_c_and_k_pe_cache,
|
||||
torch::Tensor const& seq_lens,
|
||||
torch::Tensor const& page_table,
|
||||
torch::Tensor const& workspace,
|
||||
double sm_scale,
|
||||
int64_t num_kv_splits) {
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
const int page_size = kv_c_and_k_pe_cache.sizes()[1];
|
||||
|
||||
// NOTE(alcanderian): IsPersistent has bug with manual split_kv.
|
||||
// Kernel will hang if batch is too large with large num_kv_splits. (for example bs=8, num_kv_splits=8)
|
||||
// Maybe per batch split kv will fix this.
|
||||
DISPATCH_BOOL(page_size == 128, IsPaged128, [&] {
|
||||
DISPATCH_BOOL(num_kv_splits <= 1, NotManualSplitKV, [&] {
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
runMla<cutlass::half_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::BFloat16) {
|
||||
runMla<cutlass::bfloat16_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else if (in_dtype == at::ScalarType::Float8_e4m3fn) {
|
||||
runMla<cutlass::float_e4m3_t, IsPaged128, IsPersistent<NotManualSplitKV>>(
|
||||
out, q_nope, q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace, sm_scale, num_kv_splits, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported input data type of MLA");
|
||||
}
|
||||
return true;
|
||||
});
|
||||
return true;
|
||||
});
|
||||
}
|
||||
|
||||
int64_t sm100_cutlass_mla_get_workspace_size(int64_t max_seq_len, int64_t num_batches, int64_t sm_count, int64_t num_kv_splits) {
|
||||
// Workspace size depends on ElementAcc and ElementLSE (same as ElementAcc)
|
||||
// which are float, so Element type here doesn't matter.
|
||||
using MlaSm100Type = MlaSm100<cutlass::half_t, true>;
|
||||
|
||||
// Get split kv. Requires problem shape and sm_count only.
|
||||
typename MlaSm100Type::Fmha::Arguments arguments;
|
||||
using TileShapeH = typename MlaSm100Type::TileShapeH;
|
||||
using TileShapeD = typename MlaSm100Type::TileShapeD;
|
||||
arguments.problem_shape =
|
||||
cute::make_tuple(TileShapeH{}, static_cast<int>(max_seq_len), TileShapeD{}, static_cast<int>(num_batches));
|
||||
// Assumes device 0 when getting sm_count.
|
||||
arguments.hw_info.sm_count =
|
||||
sm_count <= 0 ? cutlass::KernelHardwareInfo::query_device_multiprocessor_count(/*device_id=*/0) : sm_count;
|
||||
arguments.split_kv = num_kv_splits;
|
||||
MlaSm100Type::Fmha::set_split_kv(arguments);
|
||||
|
||||
return MlaSm100Type::Fmha::get_workspace_size(arguments);
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("sm100_cutlass_mla_decode", &sm100_cutlass_mla_decode);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CatchAll, m) {
|
||||
m.impl("sm100_cutlass_mla_get_workspace_size", &sm100_cutlass_mla_get_workspace_size);
|
||||
}
|
||||
|
||||
// clang-format on
|
@ -18,12 +18,7 @@
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -187,7 +182,6 @@ void paged_attention_v1(
|
||||
CALL_V1_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
@ -18,12 +18,7 @@
|
||||
*/
|
||||
|
||||
#include "attention_kernels.cuh"
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#endif
|
||||
#include "cuda_compat.h"
|
||||
|
||||
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||
#define MIN(a, b) ((a) < (b) ? (a) : (b))
|
||||
@ -197,7 +192,6 @@ void paged_attention_v2(
|
||||
CALL_V2_LAUNCHER_BLOCK_SIZE)
|
||||
}
|
||||
|
||||
#undef WARP_SIZE
|
||||
#undef MAX
|
||||
#undef MIN
|
||||
#undef DIVIDE_ROUND_UP
|
||||
|
@ -33,6 +33,8 @@ namespace vec_op {
|
||||
#endif
|
||||
|
||||
#define FORCE_INLINE __attribute__((always_inline)) inline
|
||||
// Number of elements in single ASIMD vector of given Datatype
|
||||
#define NUM_ELEMENTS_REG(vec) (sizeof(vec) / sizeof(vec[0]))
|
||||
|
||||
namespace {
|
||||
template <typename T, T... indexes, typename F>
|
||||
@ -86,8 +88,8 @@ struct FP16Vec16 : public Vec<FP16Vec16> {
|
||||
}
|
||||
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / 8;
|
||||
int remainder = elem_num % 8;
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
if (full_blocks > 0) {
|
||||
vst1q_f16(reinterpret_cast<__fp16*>(ptr), reg.val[0]);
|
||||
@ -197,6 +199,25 @@ struct BF16Vec16 : public Vec<BF16Vec16> {
|
||||
vcvtq_high_bf16_f32(vcvtq_low_bf16_f32(v.val[2]), v.val[3])}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x2_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
if (remainder > 0) base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
@ -213,6 +234,25 @@ struct BF16Vec32 : public Vec<BF16Vec32> {
|
||||
: reg({vec8_data.reg, vec8_data.reg, vec8_data.reg, vec8_data.reg}) {};
|
||||
|
||||
void save(void* ptr) const { *reinterpret_cast<bfloat16x8x4_t*>(ptr) = reg; };
|
||||
void save(void* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_bf16(
|
||||
reinterpret_cast<__bf16*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
if (remainder > 0) {
|
||||
bfloat16x8_t temp = reg.val[full_blocks];
|
||||
bfloat16_t* base = reinterpret_cast<bfloat16_t*>(ptr) + full_blocks * 8;
|
||||
base[0] = vgetq_lane_bf16(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_bf16(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_bf16(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_bf16(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_bf16(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_bf16(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_bf16(temp, 6);
|
||||
}
|
||||
};
|
||||
};
|
||||
#endif
|
||||
|
||||
@ -372,6 +412,48 @@ struct FP32Vec8 : public Vec<FP32Vec8> {
|
||||
}
|
||||
};
|
||||
|
||||
struct INT32Vec16 : public Vec<INT32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int32x4x4_t reg;
|
||||
int32_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int32x4x4_t reg;
|
||||
|
||||
explicit INT32Vec16(const void* ptr) {
|
||||
reg.val[0] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr));
|
||||
reg.val[1] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 4);
|
||||
reg.val[2] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 8);
|
||||
reg.val[3] = vld1q_s32(reinterpret_cast<const int32_t*>(ptr) + 12);
|
||||
}
|
||||
|
||||
void save(int32_t* ptr) const {
|
||||
vst1q_s32(ptr, reg.val[0]);
|
||||
vst1q_s32(ptr + 4, reg.val[1]);
|
||||
vst1q_s32(ptr + 8, reg.val[2]);
|
||||
vst1q_s32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(int32_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s32(
|
||||
reinterpret_cast<__int32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
int32x4_t temp = reg.val[full_blocks];
|
||||
int32_t* base = reinterpret_cast<int32_t*>(ptr) + full_blocks * 4;
|
||||
if (remainder > 0) base[0] = vgetq_lane_s32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s32(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s32(temp, 3);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
@ -434,7 +516,12 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
reg.val[2] = vcvt_f32_f16(vget_low_f16(v.reg.val[1]));
|
||||
reg.val[3] = vcvt_f32_f16(vget_high_f16(v.reg.val[1]));
|
||||
};
|
||||
|
||||
explicit FP32Vec16(const INT32Vec16& v) {
|
||||
reg.val[0] = vcvtq_f32_s32(v.reg.val[0]);
|
||||
reg.val[1] = vcvtq_f32_s32(v.reg.val[1]);
|
||||
reg.val[2] = vcvtq_f32_s32(v.reg.val[2]);
|
||||
reg.val[3] = vcvtq_f32_s32(v.reg.val[3]);
|
||||
};
|
||||
FP32Vec16 operator+(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vaddq_f32(reg.val[0], b.reg.val[0]),
|
||||
vaddq_f32(reg.val[1], b.reg.val[1]),
|
||||
@ -463,6 +550,85 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vdivq_f32(reg.val[3], b.reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 clamp(const FP32Vec16& min, const FP32Vec16& max) const {
|
||||
return FP32Vec16(float32x4x4_t(
|
||||
{vminq_f32(max.reg.val[0], vmaxq_f32(min.reg.val[0], reg.val[0])),
|
||||
vminq_f32(max.reg.val[1], vmaxq_f32(min.reg.val[1], reg.val[1])),
|
||||
vminq_f32(max.reg.val[2], vmaxq_f32(min.reg.val[2], reg.val[2])),
|
||||
vminq_f32(max.reg.val[3], vmaxq_f32(min.reg.val[3], reg.val[3]))}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({vmaxq_f32(b.reg.val[0], reg.val[0]),
|
||||
vmaxq_f32(b.reg.val[1], reg.val[1]),
|
||||
vmaxq_f32(b.reg.val[2], reg.val[2]),
|
||||
vmaxq_f32(b.reg.val[3], reg.val[3])}));
|
||||
};
|
||||
|
||||
FP32Vec16 max(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vmaxq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float max_v = std::max(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(max_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
|
||||
FP32Vec16 min(const FP32Vec16& b) const {
|
||||
return FP32Vec16(float32x4x4_t({
|
||||
vminq_f32(b.reg.val[0], reg.val[0]),
|
||||
vminq_f32(b.reg.val[1], reg.val[1]),
|
||||
vminq_f32(b.reg.val[2], reg.val[2]),
|
||||
vminq_f32(b.reg.val[3], reg.val[3]),
|
||||
}));
|
||||
};
|
||||
FP32Vec16 min(const FP32Vec16& b, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
const int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
float32x4x4_t temp;
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
temp.val[i] = vminq_f32(b.reg.val[i], reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 0),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 0));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 0);
|
||||
}
|
||||
if (remainder > 1) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 1),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 1));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 1);
|
||||
}
|
||||
if (remainder > 2) {
|
||||
float min_v = std::min(vgetq_lane_f32(reg.val[full_blocks], 2),
|
||||
vgetq_lane_f32(b.reg.val[full_blocks], 2));
|
||||
temp.val[full_blocks] = vsetq_lane_f32(min_v, temp.val[full_blocks], 2);
|
||||
}
|
||||
|
||||
return FP32Vec16(temp);
|
||||
};
|
||||
FP32Vec16 abs() const {
|
||||
return FP32Vec16(
|
||||
float32x4x4_t({vabsq_f32(reg.val[0]), vabsq_f32(reg.val[1]),
|
||||
vabsq_f32(reg.val[2]), vabsq_f32(reg.val[3])}));
|
||||
}
|
||||
float reduce_sum() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
@ -473,6 +639,24 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
return answer;
|
||||
};
|
||||
|
||||
float reduce_max() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float max_v = std::numeric_limits<float>::lowest();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&max_v, &ar](int i) { max_v = std::max(max_v, ar.values[i]); });
|
||||
return max_v;
|
||||
}
|
||||
|
||||
float reduce_min() const {
|
||||
AliasReg ar;
|
||||
ar.reg = reg;
|
||||
float min_v = std::numeric_limits<float>::max();
|
||||
unroll_loop<int, VEC_ELEM_NUM>(
|
||||
[&min_v, &ar](int i) { min_v = std::min(min_v, ar.values[i]); });
|
||||
return min_v;
|
||||
}
|
||||
|
||||
template <int group_size>
|
||||
float reduce_sub_sum(int idx) {
|
||||
static_assert(VEC_ELEM_NUM % group_size == 0);
|
||||
@ -493,6 +677,83 @@ struct FP32Vec16 : public Vec<FP32Vec16> {
|
||||
vst1q_f32(ptr + 8, reg.val[2]);
|
||||
vst1q_f32(ptr + 12, reg.val[3]);
|
||||
};
|
||||
|
||||
void save(float* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg.val[0]);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg.val[0]);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_f32(
|
||||
reinterpret_cast<float32_t*>(ptr) + NUM_ELEMENTS_REG(reg.val[0]) * i,
|
||||
reg.val[i]);
|
||||
|
||||
if (remainder > 0) {
|
||||
float32x4_t temp = reg.val[full_blocks];
|
||||
float* base = reinterpret_cast<float32_t*>(ptr) +
|
||||
full_blocks * NUM_ELEMENTS_REG(reg.val[0]);
|
||||
if (remainder > 0) base[0] = vgetq_lane_f32(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_f32(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_f32(temp, 2);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
struct INT8Vec16 : public Vec<INT8Vec16> {
|
||||
constexpr static int VEC_ELEM_NUM = 16;
|
||||
union AliasReg {
|
||||
int8x16_t reg;
|
||||
int8_t values[VEC_ELEM_NUM];
|
||||
};
|
||||
int8x16_t reg;
|
||||
|
||||
explicit INT8Vec16(const FP32Vec16& vec) {
|
||||
// Convert each 128-bit float32 vector to int32
|
||||
int32x4_t part0 =
|
||||
vcvtq_s32_f32(vec.reg.val[0]); // Convert first 128-bit block
|
||||
int32x4_t part1 =
|
||||
vcvtq_s32_f32(vec.reg.val[1]); // Convert second 128-bit block
|
||||
int32x4_t part2 =
|
||||
vcvtq_s32_f32(vec.reg.val[2]); // Convert third 128-bit block
|
||||
int32x4_t part3 =
|
||||
vcvtq_s32_f32(vec.reg.val[3]); // Convert fourth 128-bit block
|
||||
|
||||
// Narrow each 32-bit vector to 8 bits and combine
|
||||
int8x8_t lower =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part0), vqmovn_s32(part1)));
|
||||
int8x8_t upper =
|
||||
vqmovn_s16(vcombine_s16(vqmovn_s32(part2), vqmovn_s32(part3)));
|
||||
reg = vcombine_s8(lower, upper); // Combine to form a single 128-bit vector
|
||||
}
|
||||
|
||||
void save(int8_t* ptr) const { vst1q_s8(ptr, reg); };
|
||||
|
||||
void save(int8_t* ptr, const int elem_num) const {
|
||||
int full_blocks = elem_num / NUM_ELEMENTS_REG(reg);
|
||||
int remainder = elem_num % NUM_ELEMENTS_REG(reg);
|
||||
|
||||
for (int i = 0; i < full_blocks; i++)
|
||||
vst1q_s8(reinterpret_cast<int8_t*>(ptr) + NUM_ELEMENTS_REG(reg) * i, reg);
|
||||
if (remainder > 0) {
|
||||
int8x16_t temp = reg;
|
||||
int8_t* base =
|
||||
reinterpret_cast<int8_t*>(ptr) + full_blocks * NUM_ELEMENTS_REG(reg);
|
||||
if (remainder > 0) base[0] = vgetq_lane_s8(temp, 0);
|
||||
if (remainder > 1) base[1] = vgetq_lane_s8(temp, 1);
|
||||
if (remainder > 2) base[2] = vgetq_lane_s8(temp, 2);
|
||||
if (remainder > 3) base[3] = vgetq_lane_s8(temp, 3);
|
||||
if (remainder > 4) base[4] = vgetq_lane_s8(temp, 4);
|
||||
if (remainder > 5) base[5] = vgetq_lane_s8(temp, 5);
|
||||
if (remainder > 6) base[6] = vgetq_lane_s8(temp, 6);
|
||||
if (remainder > 7) base[7] = vgetq_lane_s8(temp, 7);
|
||||
if (remainder > 8) base[8] = vgetq_lane_s8(temp, 8);
|
||||
if (remainder > 9) base[9] = vgetq_lane_s8(temp, 9);
|
||||
if (remainder > 10) base[10] = vgetq_lane_s8(temp, 10);
|
||||
if (remainder > 11) base[11] = vgetq_lane_s8(temp, 11);
|
||||
if (remainder > 12) base[12] = vgetq_lane_s8(temp, 12);
|
||||
if (remainder > 13) base[13] = vgetq_lane_s8(temp, 13);
|
||||
if (remainder > 14) base[14] = vgetq_lane_s8(temp, 14);
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
@ -57,6 +57,7 @@ class DNNLPrimitiveHelper {
|
||||
// Note: Due to the limitation of oneDNN
|
||||
// (https://github.com/oneapi-src/oneDNN/issues/1636), the quantized bias is
|
||||
// not supported.
|
||||
|
||||
template <typename OutputT, typename BiasT>
|
||||
static void gemm_s8s8_jit(const int8_t* a, const int8_t* b, OutputT* c,
|
||||
const BiasT* bias, dnnl_dim_t M, dnnl_dim_t N,
|
||||
@ -90,6 +91,27 @@ class DNNLPrimitiveHelper {
|
||||
}
|
||||
|
||||
dnnl::matmul::primitive_desc matmul_pd;
|
||||
// Create memory descriptors with format_tag::any for the primitive. This
|
||||
// enables the matmul primitive to choose memory layouts for an
|
||||
// optimized primitive implementation, and these layouts may differ from the
|
||||
// ones provided by the user.
|
||||
#ifdef __aarch64__
|
||||
auto mat_src_md = dnnl::memory::desc({M, K}, dnnl::memory::data_type::s8,
|
||||
dnnl::memory::format_tag::any);
|
||||
auto mat_weights_md = dnnl::memory::desc(
|
||||
{K, N}, dnnl::memory::data_type::s8, dnnl::memory::format_tag::any);
|
||||
auto mat_dst_md =
|
||||
dnnl::memory::desc({M, N}, OutputType, dnnl::memory::format_tag::any);
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), mat_src_md,
|
||||
mat_weights_md, bias_md,
|
||||
mat_dst_md, attr);
|
||||
} else {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(
|
||||
default_engine(), mat_src_md, mat_weights_md, mat_dst_md, attr);
|
||||
}
|
||||
#else
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({1, N}, BiasType, {N, 1});
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
@ -98,6 +120,7 @@ class DNNLPrimitiveHelper {
|
||||
matmul_pd = dnnl::matmul::primitive_desc(default_engine(), a_md, b_md,
|
||||
c_md, attr);
|
||||
}
|
||||
#endif
|
||||
dnnl::matmul matmul(matmul_pd);
|
||||
|
||||
auto& engine = default_engine();
|
||||
@ -111,24 +134,34 @@ class DNNLPrimitiveHelper {
|
||||
(void*)b_scales);
|
||||
|
||||
auto& stream = default_stream();
|
||||
|
||||
auto mat_src_mem = a_m;
|
||||
auto mat_weights_mem = b_m;
|
||||
auto mat_dst_mem = c_m;
|
||||
#ifdef __aarch64__
|
||||
if (matmul_pd.weights_desc() != b_m.get_desc()) {
|
||||
mat_weights_mem = dnnl::memory(matmul_pd.weights_desc(), engine);
|
||||
dnnl::reorder(b_m, mat_weights_mem).execute(stream, b_m, mat_weights_mem);
|
||||
}
|
||||
#endif
|
||||
if constexpr (InputNoScale) {
|
||||
if (bias) {
|
||||
dnnl::memory::desc bias_md({N}, BiasType, {1});
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
}
|
||||
@ -138,19 +171,19 @@ class DNNLPrimitiveHelper {
|
||||
dnnl::memory bias_m(bias_md, engine, (void*)bias);
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_BIAS, bias_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
} else {
|
||||
matmul.execute(
|
||||
stream, {
|
||||
{DNNL_ARG_SRC, a_m},
|
||||
{DNNL_ARG_WEIGHTS, b_m},
|
||||
{DNNL_ARG_DST, c_m},
|
||||
{DNNL_ARG_SRC, mat_src_mem},
|
||||
{DNNL_ARG_WEIGHTS, mat_weights_mem},
|
||||
{DNNL_ARG_DST, mat_dst_mem},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, a_scales_m},
|
||||
{DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, b_scales_m},
|
||||
});
|
||||
@ -170,5 +203,4 @@ class DNNLPrimitiveHelper {
|
||||
return stream;
|
||||
}
|
||||
};
|
||||
|
||||
#endif
|
||||
|
@ -36,7 +36,7 @@ struct KernelVecType<c10::Half> {
|
||||
using cvt_vec_type = vec_op::FP32Vec16;
|
||||
};
|
||||
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
template <bool AZP, typename scalar_t>
|
||||
void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
@ -598,8 +598,9 @@ void static_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
const float* scale, const int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false, "static_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"static_scaled_int8_quant_impl requires AVX512/powerpc64/AArch64 "
|
||||
"support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -607,9 +608,9 @@ void dynamic_scaled_int8_quant_impl(const scalar_t* input, int8_t* output,
|
||||
float* scale, int32_t* azp,
|
||||
const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_scaled_int8_quant_impl requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_scaled_int8_quant_impl requires "
|
||||
"AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <bool PerChannel, typename scalar_t>
|
||||
@ -617,7 +618,8 @@ void static_quant_epilogue(const float* input, scalar_t* output,
|
||||
const float a_scale, const float* b_scale,
|
||||
const int32_t* azp_with_adj, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false, "static_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false, "static_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -626,8 +628,9 @@ void dynamic_quant_epilogue(const float* input, scalar_t* output,
|
||||
const int32_t* azp, const int32_t* azp_with_adj,
|
||||
const scalar_t* bias, const int num_tokens,
|
||||
const int hidden_size) {
|
||||
TORCH_CHECK(false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64 support.")
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"dynamic_quant_epilogue requires AVX512/powerpc64/AArch64 support.")
|
||||
}
|
||||
#endif
|
||||
} // namespace
|
||||
|
@ -58,7 +58,7 @@ namespace {
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention")
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
|
@ -126,7 +126,7 @@ void fused_experts_int4_w4a16_kernel_impl(
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implememntation for int8 w8a8
|
||||
// shared expert implementation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
|
@ -41,7 +41,7 @@ struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but luckly we use 4x2
|
||||
// oops! 4x4 spills but luckily we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
|
@ -37,7 +37,7 @@ inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vecto
|
||||
#define CVT_FP16_TO_FP32(a) \
|
||||
_mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
|
||||
// this doesn't hanel NaN.
|
||||
// this doesn't handle NaN.
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
||||
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
|
||||
|
@ -7,7 +7,7 @@
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024)
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
|
||||
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
|
||||
#define MIN_THREAD_PROCESS_SIZE (256)
|
||||
@ -34,9 +34,10 @@ struct KernelVecType<c10::Half> {
|
||||
};
|
||||
|
||||
struct ThreadSHMContext {
|
||||
volatile char _curr_thread_stamp;
|
||||
volatile char _ready_thread_stamp;
|
||||
char _padding1[6];
|
||||
volatile char _curr_thread_stamp[2];
|
||||
volatile char _ready_thread_stamp[2];
|
||||
int local_stamp_buffer_idx;
|
||||
int remote_stamp_buffer_idx;
|
||||
int thread_id;
|
||||
int thread_num;
|
||||
int rank;
|
||||
@ -45,23 +46,28 @@ struct ThreadSHMContext {
|
||||
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||
size_t _thread_buffer_mask;
|
||||
char _padding2[56];
|
||||
size_t _thread_buffer_mask[2];
|
||||
char _padding2[40];
|
||||
|
||||
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||
const int group_size, void* thread_shm_ptr)
|
||||
: _curr_thread_stamp(1),
|
||||
_ready_thread_stamp(0),
|
||||
: local_stamp_buffer_idx(0),
|
||||
remote_stamp_buffer_idx(0),
|
||||
thread_id(thread_id),
|
||||
thread_num(thread_num),
|
||||
rank(rank),
|
||||
group_size(group_size),
|
||||
_spinning_count(0),
|
||||
_thread_buffer_mask(0) {
|
||||
_spinning_count(0) {
|
||||
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
TORCH_CHECK((size_t)thread_shm_ptr % 64 == 0);
|
||||
_curr_thread_stamp[0] = 1;
|
||||
_curr_thread_stamp[1] = 1;
|
||||
_ready_thread_stamp[0] = 0;
|
||||
_ready_thread_stamp[1] = 0;
|
||||
_thread_buffer_mask[0] = 0;
|
||||
_thread_buffer_mask[1] = 0;
|
||||
for (int i = 0; i < MAX_SHM_RANK_NUM; ++i) {
|
||||
shm_contexts[i] = nullptr;
|
||||
thread_shm_ptrs[i] = nullptr;
|
||||
@ -70,6 +76,11 @@ struct ThreadSHMContext {
|
||||
set_context(rank, this, thread_shm_ptr);
|
||||
}
|
||||
|
||||
void set_stamp_buffer_idx(int local, int remote) {
|
||||
local_stamp_buffer_idx = local;
|
||||
remote_stamp_buffer_idx = remote;
|
||||
}
|
||||
|
||||
void set_context(int rank, ThreadSHMContext* ptr, void* thread_shm_ptr) {
|
||||
TORCH_CHECK(rank < MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK(ptr);
|
||||
@ -84,23 +95,27 @@ struct ThreadSHMContext {
|
||||
T* get_thread_shm_ptr(int rank) {
|
||||
return reinterpret_cast<T*>(
|
||||
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
|
||||
(PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask));
|
||||
(PER_THREAD_SHM_BUFFER_OFFSET &
|
||||
_thread_buffer_mask[local_stamp_buffer_idx]));
|
||||
}
|
||||
|
||||
void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; }
|
||||
void next_buffer() {
|
||||
_thread_buffer_mask[local_stamp_buffer_idx] ^= 0xFFFFFFFFFFFFFFFF;
|
||||
}
|
||||
|
||||
char get_curr_stamp() const { return _curr_thread_stamp; }
|
||||
char get_curr_stamp(int idx) const { return _curr_thread_stamp[idx]; }
|
||||
|
||||
char get_ready_stamp() const { return _ready_thread_stamp; }
|
||||
char get_ready_stamp(int idx) const { return _ready_thread_stamp[idx]; }
|
||||
|
||||
void next_stamp() {
|
||||
_mm_mfence();
|
||||
_curr_thread_stamp += 1;
|
||||
_curr_thread_stamp[local_stamp_buffer_idx] += 1;
|
||||
}
|
||||
|
||||
void commit_ready_stamp() {
|
||||
_mm_mfence();
|
||||
_ready_thread_stamp = _curr_thread_stamp;
|
||||
_ready_thread_stamp[local_stamp_buffer_idx] =
|
||||
_curr_thread_stamp[local_stamp_buffer_idx];
|
||||
}
|
||||
|
||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||
@ -117,10 +132,11 @@ struct ThreadSHMContext {
|
||||
void wait_for_one(int rank, Cond&& cond) {
|
||||
ThreadSHMContext* rank_ctx = shm_contexts[rank];
|
||||
for (;;) {
|
||||
char local_curr_stamp = get_curr_stamp();
|
||||
char local_ready_stamp = get_ready_stamp();
|
||||
char rank_curr_stamp = rank_ctx->get_curr_stamp();
|
||||
char rank_ready_stamp = rank_ctx->get_ready_stamp();
|
||||
char local_curr_stamp = get_curr_stamp(local_stamp_buffer_idx);
|
||||
char local_ready_stamp = get_ready_stamp(local_stamp_buffer_idx);
|
||||
char rank_curr_stamp = rank_ctx->get_curr_stamp(remote_stamp_buffer_idx);
|
||||
char rank_ready_stamp =
|
||||
rank_ctx->get_ready_stamp(remote_stamp_buffer_idx);
|
||||
if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
|
||||
rank_ready_stamp)) {
|
||||
break;
|
||||
@ -361,6 +377,15 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void reset_threads_stamp_buffer_idx(ThreadSHMContext* ctx, int local,
|
||||
int remote) {
|
||||
int thread_num = ctx->thread_num;
|
||||
for (int i = 0; i < thread_num; ++i) {
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
thread_ctx->set_stamp_buffer_idx(local, remote);
|
||||
}
|
||||
}
|
||||
}; // namespace shm_cc_ops
|
||||
|
||||
namespace shm_cc_ops {
|
||||
@ -632,6 +657,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
|
||||
TensorListMeta* metadata = new (metadata_tensor.data_ptr()) TensorListMeta();
|
||||
metadata->bind_tensor_list(tensor_list_with_metadata);
|
||||
|
||||
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 0, 1);
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata->total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
@ -659,6 +685,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
torch::Tensor metadata_tensor =
|
||||
torch::empty({sizeof(TensorListMeta)}, options);
|
||||
|
||||
shm_cc_ops::reset_threads_stamp_buffer_idx(ctx, 1, 0);
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
|
||||
ctx->get_thread_shm_ptr<void>(src),
|
||||
@ -677,7 +704,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
ctx, metadata.total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
thread_ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
|
||||
|
@ -151,8 +151,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.impl("rotary_embedding", torch::kCPU, &rotary_embedding);
|
||||
|
||||
// Quantization
|
||||
#ifdef __AVX512F__
|
||||
#if defined(__AVX512F__) || defined(__aarch64__)
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
|
@ -4,10 +4,10 @@
|
||||
#include <hip/hip_runtime.h>
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#define WARP_SIZE 32
|
||||
#if defined(USE_ROCM) && defined(__GFX9__)
|
||||
#define WARP_SIZE 64
|
||||
#else
|
||||
#define WARP_SIZE warpSize
|
||||
#define WARP_SIZE 32
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
|
@ -153,7 +153,7 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -210,7 +210,7 @@ struct ScaledEpilogueBiasAzp
|
||||
EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -288,7 +288,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::threadblock::VisitorCompute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
|
@ -195,7 +195,7 @@ struct ScaledEpilogueBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -238,7 +238,7 @@ struct ScaledEpilogueColumnBias
|
||||
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
|
||||
|
||||
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -295,7 +295,7 @@ struct ScaledEpilogueBiasAzp
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAzp>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
@ -371,7 +371,7 @@ struct ScaledEpilogueBiasAzpToken
|
||||
cutlass::epilogue::fusion::Sm90EVT<ComputeScaleB, ScaleB, EVTComputeAcc>;
|
||||
|
||||
using ComputeScaleBiasA = cutlass::epilogue::fusion::Sm90Compute<
|
||||
cutlass::multiply_add, ElementD, float,
|
||||
cutlass::homogeneous_multiply_add, ElementD, float,
|
||||
cutlass::FloatRoundStyle::round_to_nearest>;
|
||||
|
||||
public:
|
||||
|
@ -45,7 +45,6 @@
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
|
@ -15,15 +15,16 @@ namespace vllm {
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t>
|
||||
__global__ void rms_norm_kernel(
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int64_t input_stride,
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
const float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
|
||||
@ -37,7 +38,7 @@ __global__ void rms_norm_kernel(
|
||||
__syncthreads();
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
@ -50,7 +51,8 @@ __global__ void rms_norm_kernel(
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int64_t input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
@ -59,6 +61,7 @@ fused_add_rms_norm_kernel(
|
||||
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
const int64_t vec_input_stride = input_stride / width;
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
@ -73,7 +76,8 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16Vec<scalar_t, width> temp = input_v[id];
|
||||
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
|
||||
_f16Vec<scalar_t, width> temp = input_v[strided_id];
|
||||
temp += residual_v[id];
|
||||
variance += temp.sum_squares();
|
||||
residual_v[id] = temp;
|
||||
@ -90,10 +94,11 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
int64_t strided_id = blockIdx.x * vec_input_stride + idx;
|
||||
_f16Vec<scalar_t, width> temp = residual_v[id];
|
||||
temp *= s_variance;
|
||||
temp *= weight_v[idx];
|
||||
input_v[id] = temp;
|
||||
input_v[strided_id] = temp;
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,7 +108,8 @@ fused_add_rms_norm_kernel(
|
||||
template <typename scalar_t, int width>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_kernel(
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int64_t input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
@ -111,7 +117,7 @@ fused_add_rms_norm_kernel(
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
scalar_t z = input[blockIdx.x * hidden_size + idx];
|
||||
scalar_t z = input[blockIdx.x * input_stride + idx];
|
||||
z += residual[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)z;
|
||||
variance += x * x;
|
||||
@ -129,7 +135,7 @@ fused_add_rms_norm_kernel(
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)residual[blockIdx.x * hidden_size + idx];
|
||||
input[blockIdx.x * hidden_size + idx] =
|
||||
input[blockIdx.x * input_stride + idx] =
|
||||
((scalar_t)(x * s_variance)) * weight[idx];
|
||||
}
|
||||
}
|
||||
@ -141,11 +147,12 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(input.stride(-1) == 1);
|
||||
TORCH_CHECK(weight.is_contiguous());
|
||||
|
||||
int hidden_size = input.size(-1);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
int64_t input_stride = input.stride(-2);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
dim3 block(std::min(hidden_size, 1024));
|
||||
@ -153,26 +160,29 @@ void rms_norm(torch::Tensor& out, // [..., hidden_size]
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] {
|
||||
vllm::rms_norm_kernel<scalar_t><<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(),
|
||||
out.data_ptr<scalar_t>(), input.data_ptr<scalar_t>(), input_stride,
|
||||
weight.data_ptr<scalar_t>(), epsilon, num_tokens, hidden_size);
|
||||
});
|
||||
}
|
||||
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), epsilon, \
|
||||
num_tokens, hidden_size); \
|
||||
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
|
||||
VLLM_DISPATCH_FLOATING_TYPES( \
|
||||
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
|
||||
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
input.data_ptr<scalar_t>(), input_stride, \
|
||||
residual.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(), \
|
||||
epsilon, num_tokens, hidden_size); \
|
||||
});
|
||||
|
||||
void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
torch::Tensor& residual, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(residual.is_contiguous());
|
||||
TORCH_CHECK(weight.is_contiguous());
|
||||
int hidden_size = input.size(-1);
|
||||
int64_t input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@ -194,9 +204,16 @@ void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size]
|
||||
auto inp_ptr = reinterpret_cast<std::uintptr_t>(input.data_ptr());
|
||||
auto res_ptr = reinterpret_cast<std::uintptr_t>(residual.data_ptr());
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
constexpr int vector_width = 8;
|
||||
constexpr int req_alignment_bytes =
|
||||
vector_width * 2; // vector_width * sizeof(bfloat16 or float16) (float32
|
||||
// falls back to non-vectorized version anyway)
|
||||
bool ptrs_are_aligned = inp_ptr % req_alignment_bytes == 0 &&
|
||||
res_ptr % req_alignment_bytes == 0 &&
|
||||
wt_ptr % req_alignment_bytes == 0;
|
||||
bool offsets_are_multiple_of_vector_width =
|
||||
hidden_size % vector_width == 0 && input_stride % vector_width == 0;
|
||||
if (ptrs_are_aligned && offsets_are_multiple_of_vector_width) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
|
@ -23,8 +23,9 @@ namespace vllm {
|
||||
// TODO(woosuk): Further optimize this kernel.
|
||||
template <typename scalar_t, typename fp8_type>
|
||||
__global__ void rms_norm_static_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int input_stride,
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
const float epsilon, const int num_tokens, const int hidden_size) {
|
||||
@ -32,7 +33,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
const float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
const float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
variance += x * x;
|
||||
}
|
||||
|
||||
@ -49,7 +50,7 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
float const scale_inv = 1.0f / *scale;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
float x = (float)input[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)input[blockIdx.x * input_stride + idx];
|
||||
float const out_norm = ((scalar_t)(x * s_variance)) * weight[idx];
|
||||
out[blockIdx.x * hidden_size + idx] =
|
||||
scaled_fp8_conversion<true, fp8_type>(out_norm, scale_inv);
|
||||
@ -63,8 +64,9 @@ __global__ void rms_norm_static_fp8_quant_kernel(
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width > 0) && _typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
@ -74,6 +76,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
static_assert(sizeof(_f16Vec<scalar_t, width>) == sizeof(scalar_t) * width);
|
||||
|
||||
const int vec_hidden_size = hidden_size / width;
|
||||
const int vec_input_stride = input_stride / width;
|
||||
__shared__ float s_variance;
|
||||
float variance = 0.0f;
|
||||
/* These and the argument pointers are all declared `restrict` as they are
|
||||
@ -87,8 +90,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
reinterpret_cast<const _f16Vec<scalar_t, width>*>(weight);
|
||||
|
||||
for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) {
|
||||
int stride_id = blockIdx.x * vec_input_stride + idx;
|
||||
int id = blockIdx.x * vec_hidden_size + idx;
|
||||
_f16Vec<scalar_t, width> temp = input_v[id];
|
||||
_f16Vec<scalar_t, width> temp = input_v[stride_id];
|
||||
temp += residual_v[id];
|
||||
variance += temp.sum_squares();
|
||||
residual_v[id] = temp;
|
||||
@ -125,8 +129,9 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
template <typename scalar_t, int width, typename fp8_type>
|
||||
__global__ std::enable_if_t<(width == 0) || !_typeConvert<scalar_t>::exists>
|
||||
fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
fp8_type* __restrict__ out, // [..., hidden_size]
|
||||
scalar_t* __restrict__ input, // [..., hidden_size]
|
||||
const int input_stride,
|
||||
scalar_t* __restrict__ residual, // [..., hidden_size]
|
||||
const scalar_t* __restrict__ weight, // [hidden_size]
|
||||
const float* __restrict__ scale, // [1]
|
||||
@ -135,7 +140,7 @@ fused_add_rms_norm_static_fp8_quant_kernel(
|
||||
float variance = 0.0f;
|
||||
|
||||
for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) {
|
||||
scalar_t z = input[blockIdx.x * hidden_size + idx];
|
||||
scalar_t z = input[blockIdx.x * input_stride + idx];
|
||||
z += residual[blockIdx.x * hidden_size + idx];
|
||||
float x = (float)z;
|
||||
variance += x * x;
|
||||
@ -169,7 +174,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
torch::Tensor& scale, // [1]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
int hidden_size = input.size(-1);
|
||||
int input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@ -183,8 +190,9 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
vllm::rms_norm_static_fp8_quant_kernel<scalar_t, fp8_t>
|
||||
<<<grid, block, 0, stream>>>(
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(),
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(),
|
||||
epsilon, num_tokens, hidden_size);
|
||||
input_stride, weight.data_ptr<scalar_t>(),
|
||||
scale.data_ptr<float>(), epsilon, num_tokens,
|
||||
hidden_size);
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -198,7 +206,7 @@ void rms_norm_static_fp8_quant(torch::Tensor& out, // [..., hidden_size]
|
||||
width, fp8_t> \
|
||||
<<<grid, block, 0, stream>>>( \
|
||||
out.data_ptr<fp8_t>(), input.data_ptr<scalar_t>(), \
|
||||
residual.data_ptr<scalar_t>(), \
|
||||
input_stride, residual.data_ptr<scalar_t>(), \
|
||||
weight.data_ptr<scalar_t>(), scale.data_ptr<float>(), \
|
||||
epsilon, num_tokens, hidden_size); \
|
||||
}); \
|
||||
@ -210,7 +218,10 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
torch::Tensor& weight, // [hidden_size]
|
||||
torch::Tensor& scale, // [1]
|
||||
double epsilon) {
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
TORCH_CHECK(residual.is_contiguous());
|
||||
int hidden_size = input.size(-1);
|
||||
int input_stride = input.stride(-2);
|
||||
int num_tokens = input.numel() / hidden_size;
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
@ -234,7 +245,7 @@ void fused_add_rms_norm_static_fp8_quant(
|
||||
auto wt_ptr = reinterpret_cast<std::uintptr_t>(weight.data_ptr());
|
||||
bool ptrs_are_aligned =
|
||||
inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0;
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0) {
|
||||
if (ptrs_are_aligned && hidden_size % 8 == 0 && input_stride % 8 == 0) {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(8);
|
||||
} else {
|
||||
LAUNCH_FUSED_ADD_RMS_NORM(0);
|
||||
|
@ -1,656 +0,0 @@
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_fwd.cu
|
||||
// and https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d_update.cu
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
|
||||
#include "causal_conv1d.h"
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
|
||||
#include <cub/block/block_load.cuh>
|
||||
#include <cub/block/block_store.cuh>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
namespace cub = hipcub;
|
||||
#endif
|
||||
|
||||
#include "static_switch.h"
|
||||
|
||||
|
||||
|
||||
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
||||
|
||||
#define DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
||||
if (ITYPE == at::ScalarType::Half) { \
|
||||
using input_t = at::Half; \
|
||||
using weight_t = at::Half; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
||||
using input_t = at::BFloat16; \
|
||||
using weight_t = at::BFloat16; \
|
||||
__VA_ARGS__(); \
|
||||
} else if (ITYPE == at::ScalarType::Float) { \
|
||||
using input_t = float; \
|
||||
using weight_t = float; \
|
||||
__VA_ARGS__(); \
|
||||
} else { \
|
||||
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
||||
}
|
||||
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
||||
// sizes
|
||||
const size_t batch,
|
||||
const size_t dim,
|
||||
const size_t seqlen,
|
||||
const size_t width,
|
||||
// device pointers
|
||||
const at::Tensor x,
|
||||
const at::Tensor weight,
|
||||
const at::Tensor out,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool silu_activation,
|
||||
int64_t pad_slot_id,
|
||||
const std::optional<at::Tensor>& query_start_loc = std::nullopt,
|
||||
const std::optional<at::Tensor>& cache_indices = std::nullopt,
|
||||
const std::optional<at::Tensor>& has_initial_state = std::nullopt) {
|
||||
|
||||
// Reset the parameters
|
||||
memset(¶ms, 0, sizeof(params));
|
||||
|
||||
params.batch = batch;
|
||||
params.dim = dim;
|
||||
params.seqlen = seqlen;
|
||||
params.width = width;
|
||||
params.pad_slot_id = pad_slot_id;
|
||||
|
||||
params.silu_activation = silu_activation;
|
||||
|
||||
// Set the pointers and strides.
|
||||
params.x_ptr = x.data_ptr();
|
||||
params.weight_ptr = weight.data_ptr();
|
||||
params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr;
|
||||
params.out_ptr = out.data_ptr();
|
||||
// All stride are in elements, not bytes.
|
||||
params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr;
|
||||
params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr;
|
||||
params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr;
|
||||
const bool varlen = params.query_start_loc_ptr != nullptr;
|
||||
params.x_batch_stride = x.stride(varlen ? 1 : 0);
|
||||
params.x_c_stride = x.stride(varlen ? 0 : 1);
|
||||
params.x_l_stride = x.stride(varlen ? 1 : -1);
|
||||
params.weight_c_stride = weight.stride(0);
|
||||
params.weight_width_stride = weight.stride(1);
|
||||
params.out_batch_stride = out.stride(varlen ? 1 : 0);
|
||||
params.out_c_stride = out.stride(varlen ? 0 : 1);
|
||||
params.out_l_stride = out.stride(varlen ? 1 : -1);
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
const std::optional<at::Tensor> &conv_states,
|
||||
const std::optional<at::Tensor> &query_start_loc,
|
||||
const std::optional<at::Tensor> &cache_indices,
|
||||
const std::optional<at::Tensor> &has_initial_state,
|
||||
bool silu_activation,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const bool varlen = query_start_loc.has_value() ? true : false;
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0];
|
||||
const int dim = varlen ? sizes[0] : sizes[1];
|
||||
const int seqlen = varlen ? sizes[1] : sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(x, dim, seqlen);
|
||||
}
|
||||
else {
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
}
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
|
||||
if (has_initial_state.has_value()) {
|
||||
auto has_initial_state_ = has_initial_state.value();
|
||||
TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool);
|
||||
TORCH_CHECK(has_initial_state_.is_cuda());
|
||||
CHECK_SHAPE(has_initial_state_, batch_size);
|
||||
}
|
||||
|
||||
|
||||
if (query_start_loc.has_value()) {
|
||||
auto query_start_loc_ = query_start_loc.value();
|
||||
TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(query_start_loc_.is_cuda());
|
||||
}
|
||||
|
||||
|
||||
if (cache_indices.has_value()) {
|
||||
auto cache_indices_ = cache_indices.value();
|
||||
TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int);
|
||||
TORCH_CHECK(cache_indices_.is_cuda());
|
||||
CHECK_SHAPE(cache_indices_, batch_size);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id,
|
||||
query_start_loc,
|
||||
cache_indices,
|
||||
has_initial_state
|
||||
);
|
||||
|
||||
if (conv_states.has_value()) {
|
||||
auto conv_states_ = conv_states.value();
|
||||
TORCH_CHECK(conv_states_.scalar_type() == input_type);
|
||||
TORCH_CHECK(conv_states_.is_cuda());
|
||||
params.conv_states_ptr = conv_states_.data_ptr();
|
||||
params.conv_states_batch_stride = conv_states_.stride(0);
|
||||
params.conv_states_c_stride = conv_states_.stride(1);
|
||||
params.conv_states_l_stride = conv_states_.stride(2);
|
||||
} else {
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
void causal_conv1d_update(const at::Tensor &x,
|
||||
const at::Tensor &conv_state,
|
||||
const at::Tensor &weight,
|
||||
const std::optional<at::Tensor> &bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor> &cache_seqlens_,
|
||||
const std::optional<at::Tensor> &conv_state_indices_,
|
||||
// used to identify padding entries if cache_indices provided
|
||||
// in case of padding, the kernel will return early
|
||||
int64_t pad_slot_id) {
|
||||
auto input_type = x.scalar_type();
|
||||
auto weight_type = weight.scalar_type();
|
||||
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
||||
TORCH_CHECK(weight_type == input_type, "weight type must equal to input type, other variations are disabled due to binary size limitations");
|
||||
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
||||
|
||||
TORCH_CHECK(x.is_cuda());
|
||||
TORCH_CHECK(conv_state.is_cuda());
|
||||
TORCH_CHECK(weight.is_cuda());
|
||||
|
||||
const auto sizes = x.sizes();
|
||||
const int batch_size = sizes[0];
|
||||
const int dim = sizes[1];
|
||||
const int seqlen = sizes[2];
|
||||
const int width = weight.size(-1);
|
||||
const int conv_state_len = conv_state.size(2);
|
||||
TORCH_CHECK(conv_state_len >= width - 1);
|
||||
|
||||
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
||||
CHECK_SHAPE(weight, dim, width);
|
||||
|
||||
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
||||
|
||||
if (bias_.has_value()) {
|
||||
auto bias = bias_.value();
|
||||
TORCH_CHECK(bias.scalar_type() == weight_type);
|
||||
TORCH_CHECK(bias.is_cuda());
|
||||
TORCH_CHECK(bias.stride(-1) == 1);
|
||||
CHECK_SHAPE(bias, dim);
|
||||
}
|
||||
|
||||
at::Tensor out = x;
|
||||
|
||||
ConvParamsBase params;
|
||||
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
||||
bias_,
|
||||
silu_activation,
|
||||
pad_slot_id);
|
||||
params.conv_state_ptr = conv_state.data_ptr();
|
||||
params.conv_state_len = conv_state_len;
|
||||
// All stride are in elements, not bytes.
|
||||
params.conv_state_batch_stride = conv_state.stride(0);
|
||||
params.conv_state_c_stride = conv_state.stride(1);
|
||||
params.conv_state_l_stride = conv_state.stride(2);
|
||||
|
||||
if (cache_seqlens_.has_value()) {
|
||||
auto cache_seqlens = cache_seqlens_.value();
|
||||
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
||||
TORCH_CHECK(cache_seqlens.is_cuda());
|
||||
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
||||
CHECK_SHAPE(cache_seqlens, batch_size);
|
||||
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
||||
} else {
|
||||
params.cache_seqlens = nullptr;
|
||||
}
|
||||
|
||||
if (conv_state_indices_.has_value()) {
|
||||
auto conv_state_indices = conv_state_indices_.value();
|
||||
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
||||
TORCH_CHECK(conv_state_indices.is_cuda());
|
||||
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
||||
CHECK_SHAPE(conv_state_indices, batch_size);
|
||||
|
||||
int conv_state_entries = conv_state.size(0);
|
||||
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
||||
|
||||
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
||||
} else {
|
||||
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_fwd_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
||||
static_assert(kWidth <= kNElts);
|
||||
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
||||
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
||||
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
||||
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
||||
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
||||
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
||||
static constexpr int kSmemIOSize = kIsVecLoad
|
||||
? 0
|
||||
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
||||
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
||||
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
||||
};
|
||||
|
||||
template<typename Ktraits>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
constexpr int kNElts = Ktraits::kNElts;
|
||||
constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using vec_t = typename Ktraits::vec_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
// Shared memory.
|
||||
extern __shared__ char smem_[];
|
||||
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
||||
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
||||
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
||||
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
||||
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
||||
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y;
|
||||
const int *query_start_loc = kVarlen ? reinterpret_cast<int *>(params.query_start_loc_ptr) : nullptr;
|
||||
const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id;
|
||||
const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + sequence_start_index * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + sequence_start_index * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
bool has_initial_state = params.has_initial_state_ptr == nullptr ? false
|
||||
: reinterpret_cast<bool *>(params.has_initial_state_ptr)[batch_id];
|
||||
|
||||
int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<int *>(params.cache_indices_ptr);
|
||||
int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id];
|
||||
// cache_index == params.pad_slot_id is defined as padding, so we exit early
|
||||
if (cache_index == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr
|
||||
: reinterpret_cast<input_t *>(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride;
|
||||
|
||||
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
||||
if (tidx == 0) {
|
||||
input_t initial_state[kNElts] = {0};
|
||||
if (has_initial_state) {
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; }
|
||||
}
|
||||
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(initial_state)[0];
|
||||
}
|
||||
|
||||
float weight_vals[kWidth];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
constexpr int kChunkSize = kNThreads * kNElts;
|
||||
const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize;
|
||||
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
||||
input_t x_vals_load[2 * kNElts] = {0};
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
__syncthreads();
|
||||
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize);
|
||||
}
|
||||
x += kChunkSize;
|
||||
__syncthreads();
|
||||
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
||||
// the last elements of the previous chunk.
|
||||
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
__syncthreads();
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
||||
__syncthreads();
|
||||
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
||||
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
||||
|
||||
float x_vals[2 * kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
||||
|
||||
float out_vals[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = bias_val;
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth; ++w) {
|
||||
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
||||
}
|
||||
}
|
||||
|
||||
if (params.silu_activation) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) {
|
||||
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
||||
}
|
||||
}
|
||||
|
||||
input_t out_vals_store[kNElts];
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
||||
if constexpr(kIsVecLoad) {
|
||||
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts);
|
||||
} else {
|
||||
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize);
|
||||
}
|
||||
out += kChunkSize;
|
||||
|
||||
int final_state_position = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize);
|
||||
// in case the final state is separated between the last "smem_exchange" and
|
||||
// and the one before it (chunk = n_chunks - 1 and chunk = n_chunks - 2),
|
||||
// (which occurs when `final_state_position` is a non-positive index)
|
||||
// we load the correct data from smem_exchange from both chunks, the last chunk iteration and the one before it
|
||||
if (conv_states != nullptr && final_state_position < 0 && seqlen > kWidth){
|
||||
input_t vals_load[kNElts] = {0};
|
||||
if ((chunk == n_chunks - 2) && (tidx == kNThreads - 1)){
|
||||
// chunk = n_chunks - 2, a segment of the final state sits in the last index
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[kNThreads - 1];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < -final_state_position; ++w){
|
||||
conv_states[w] = vals_load[kNElts + final_state_position + w];
|
||||
}
|
||||
}
|
||||
if ((chunk == n_chunks - 1) && tidx == 0){
|
||||
// chunk = n_chunks - 1, the second segment of the final state first positions
|
||||
reinterpret_cast<vec_t *>(vals_load)[0] = smem_exchange[0];
|
||||
for (int w = -final_state_position; w < kWidth - 1; ++w){
|
||||
conv_states[w] = vals_load[w + final_state_position];
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Final state is stored in the smem_exchange last token slot,
|
||||
// in case seqlen < kWidth, we would need to take the final state from the
|
||||
// initial state which is stored in conv_states
|
||||
// in case seqlen > kWidth, we would need to load the last kWidth - 1 data
|
||||
// and load it into conv_state accordingly
|
||||
int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts;
|
||||
if (conv_states != nullptr && tidx == last_thread) {
|
||||
input_t x_vals_load[kNElts * 2] = {0};
|
||||
// in case we are on the first kWidth tokens
|
||||
if (last_thread == 0 && seqlen < kWidth){
|
||||
// Need to take the initial state
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[0];
|
||||
const int offset = seqlen - (kWidth - 1);
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
// pad the existing state
|
||||
if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; }
|
||||
else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); }
|
||||
}
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
if (offset + w >= 0)
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
else {
|
||||
// in case the final state is in between the threads data
|
||||
const int offset = ((seqlen - (kWidth - 1)) % (kNElts));
|
||||
if ((offset + kWidth - 2) >= kNElts && (last_thread + 1 < kNThreads)){
|
||||
// In case last_thread == kNThreads - 1, accessing last_thread + 1 will result in a
|
||||
// illegal access error on H100.
|
||||
// Therefore, we access last_thread + 1, only if the final state data sits there
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[1] = smem_exchange[last_thread + 1];
|
||||
}
|
||||
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[last_thread];
|
||||
#pragma unroll
|
||||
for (int w = 0; w < kWidth - 1; ++w){
|
||||
conv_states[w] = x_vals_load[offset + w ];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
||||
const bool kVarlen = params.query_start_loc_ptr != nullptr;
|
||||
BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] {
|
||||
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize;
|
||||
dim3 grid(params.batch, params.dim);
|
||||
|
||||
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
||||
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
|
||||
|
||||
|
||||
|
||||
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
||||
struct Causal_conv1d_update_kernel_traits {
|
||||
using input_t = input_t_;
|
||||
using weight_t = weight_t_;
|
||||
static constexpr int kNThreads = kNThreads_;
|
||||
static constexpr int kWidth = kWidth_;
|
||||
static constexpr int kNBytes = sizeof(input_t);
|
||||
static_assert(kNBytes == 2 || kNBytes == 4);
|
||||
};
|
||||
|
||||
template<typename Ktraits, bool kIsCircularBuffer>
|
||||
__global__ __launch_bounds__(Ktraits::kNThreads)
|
||||
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
||||
constexpr int kWidth = Ktraits::kWidth;
|
||||
constexpr int kNThreads = Ktraits::kNThreads;
|
||||
using input_t = typename Ktraits::input_t;
|
||||
using weight_t = typename Ktraits::weight_t;
|
||||
|
||||
const int tidx = threadIdx.x;
|
||||
const int batch_id = blockIdx.x;
|
||||
const int channel_id = blockIdx.y * kNThreads + tidx;
|
||||
if (channel_id >= params.dim) return;
|
||||
|
||||
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
||||
+ channel_id * params.x_c_stride;
|
||||
|
||||
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
||||
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
||||
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
||||
? batch_id
|
||||
: params.conv_state_indices_ptr[batch_id];
|
||||
// conv_state_batch_coord == params.pad_slot_id is defined as padding so we exit early
|
||||
if (conv_state_batch_coord == params.pad_slot_id){
|
||||
return;
|
||||
}
|
||||
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
||||
+ conv_state_batch_coord * params.conv_state_batch_stride
|
||||
+ channel_id * params.conv_state_c_stride;
|
||||
|
||||
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
||||
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
||||
+ channel_id * params.out_c_stride;
|
||||
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
||||
|
||||
int state_len = params.conv_state_len;
|
||||
int advance_len = params.seqlen;
|
||||
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
||||
int update_idx = cache_seqlen - (kWidth - 1);
|
||||
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
||||
|
||||
float weight_vals[kWidth] = {0};
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
||||
|
||||
float x_vals[kWidth] = {0};
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
||||
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
||||
}
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) {
|
||||
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
||||
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
||||
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
||||
}
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
} else {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
||||
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
||||
x_vals[i] = float(state_val);
|
||||
}
|
||||
}
|
||||
#pragma unroll 2
|
||||
for (int i = 0; i < params.seqlen; ++i) {
|
||||
input_t x_val = x[i * params.x_l_stride];
|
||||
if constexpr (!kIsCircularBuffer) {
|
||||
if (i < advance_len && state_len - advance_len + i >= 0) {
|
||||
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
||||
}
|
||||
} else {
|
||||
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
||||
++update_idx;
|
||||
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
||||
}
|
||||
x_vals[kWidth - 1] = float(x_val);
|
||||
float out_val = bias_val;
|
||||
#pragma unroll
|
||||
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
||||
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
||||
out[i * params.out_l_stride] = input_t(out_val);
|
||||
// Shift the input buffer by 1
|
||||
#pragma unroll
|
||||
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
||||
}
|
||||
}
|
||||
|
||||
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
||||
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
||||
auto kernel = params.cache_seqlens == nullptr
|
||||
? &causal_conv1d_update_kernel<Ktraits, false>
|
||||
: &causal_conv1d_update_kernel<Ktraits, true>;
|
||||
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
}
|
||||
|
||||
template<typename input_t, typename weight_t>
|
||||
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
||||
if (params.width == 2) {
|
||||
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 3) {
|
||||
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
||||
} else if (params.width == 4) {
|
||||
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
||||
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
@ -1,159 +0,0 @@
|
||||
/******************************************************************************
|
||||
* Copyright (c) 2024, Tri Dao.
|
||||
******************************************************************************/
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/causal_conv1d.h
|
||||
#pragma once
|
||||
|
||||
#include <cuda_bf16.h>
|
||||
#include <cuda_fp16.h>
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
struct ConvParamsBase {
|
||||
using index_t = uint32_t;
|
||||
|
||||
int batch, dim, seqlen, width;
|
||||
int64_t pad_slot_id;
|
||||
bool silu_activation;
|
||||
|
||||
index_t x_batch_stride;
|
||||
index_t x_c_stride;
|
||||
index_t x_l_stride;
|
||||
index_t weight_c_stride;
|
||||
index_t weight_width_stride;
|
||||
index_t out_batch_stride;
|
||||
index_t out_c_stride;
|
||||
index_t out_l_stride;
|
||||
|
||||
int conv_state_len;
|
||||
index_t conv_state_batch_stride;
|
||||
index_t conv_state_c_stride;
|
||||
index_t conv_state_l_stride;
|
||||
|
||||
// Common data pointers.
|
||||
void *__restrict__ x_ptr;
|
||||
void *__restrict__ weight_ptr;
|
||||
void *__restrict__ bias_ptr;
|
||||
void *__restrict__ out_ptr;
|
||||
|
||||
void *__restrict__ conv_state_ptr;
|
||||
void *__restrict__ query_start_loc_ptr;
|
||||
void *__restrict__ has_initial_state_ptr;
|
||||
void *__restrict__ cache_indices_ptr;
|
||||
int32_t *__restrict__ cache_seqlens;
|
||||
|
||||
// For the continuous batching case. Makes it so that the mamba state for
|
||||
// the current batch doesn't need to be a contiguous tensor.
|
||||
int32_t *__restrict__ conv_state_indices_ptr;
|
||||
|
||||
void *__restrict__ seq_idx_ptr;
|
||||
|
||||
// No __restrict__ since initial_states could be the same as final_states.
|
||||
void * initial_states_ptr;
|
||||
index_t initial_states_batch_stride;
|
||||
index_t initial_states_l_stride;
|
||||
index_t initial_states_c_stride;
|
||||
|
||||
void * final_states_ptr;
|
||||
index_t final_states_batch_stride;
|
||||
index_t final_states_l_stride;
|
||||
index_t final_states_c_stride;
|
||||
|
||||
void * conv_states_ptr;
|
||||
index_t conv_states_batch_stride;
|
||||
index_t conv_states_l_stride;
|
||||
index_t conv_states_c_stride;
|
||||
};
|
||||
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
||||
}
|
||||
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return std::max(ilist);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return std::min(a, b);
|
||||
}
|
||||
|
||||
#else
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
template<typename T>
|
||||
__device__ inline T shuffle_xor(T val, int offset) {
|
||||
return __shfl_xor(val, offset);
|
||||
}
|
||||
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
||||
{
|
||||
return *std::max_element(ilist.begin(), ilist.end());
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
constexpr T constexpr_min(T a, T b) {
|
||||
return a < b ? a : b;
|
||||
}
|
||||
#endif
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<int BYTES> struct BytesToType {};
|
||||
|
||||
template<> struct BytesToType<16> {
|
||||
using Type = uint4;
|
||||
static_assert(sizeof(Type) == 16);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<8> {
|
||||
using Type = uint64_t;
|
||||
static_assert(sizeof(Type) == 8);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<4> {
|
||||
using Type = uint32_t;
|
||||
static_assert(sizeof(Type) == 4);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<2> {
|
||||
using Type = uint16_t;
|
||||
static_assert(sizeof(Type) == 2);
|
||||
};
|
||||
|
||||
template<> struct BytesToType<1> {
|
||||
using Type = uint8_t;
|
||||
static_assert(sizeof(Type) == 1);
|
||||
};
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
template<typename T>
|
||||
struct SumOp {
|
||||
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
||||
};
|
||||
|
||||
template<int THREADS>
|
||||
struct Allreduce {
|
||||
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
constexpr int OFFSET = THREADS / 2;
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, OFFSET));
|
||||
return Allreduce<OFFSET>::run(x, op);
|
||||
}
|
||||
};
|
||||
|
||||
template<>
|
||||
struct Allreduce<2> {
|
||||
template<typename T, typename Operator>
|
||||
static __device__ inline T run(T x, Operator &op) {
|
||||
x = op(x, __shfl_xor_sync(uint32_t(-1), x, 1));
|
||||
return x;
|
||||
}
|
||||
};
|
@ -1,28 +0,0 @@
|
||||
// Inspired by
|
||||
// https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
||||
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
||||
// clang-format off
|
||||
// adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/csrc/static_switch.h
|
||||
|
||||
#pragma once
|
||||
|
||||
/// @param COND - a boolean expression to switch by
|
||||
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
||||
/// @param ... - code to execute for true and false
|
||||
///
|
||||
/// Usage:
|
||||
/// ```
|
||||
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
||||
/// some_function<BoolConst>(...);
|
||||
/// });
|
||||
/// ```
|
||||
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
||||
[&] { \
|
||||
if (COND) { \
|
||||
static constexpr bool CONST_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
static constexpr bool CONST_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
@ -7,7 +7,11 @@
|
||||
|
||||
#include <c10/util/BFloat16.h>
|
||||
#include <c10/util/Half.h>
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#ifdef USE_ROCM
|
||||
#include <c10/hip/HIPException.h> // For C10_HIP_CHECK and C10_HIP_KERNEL_LAUNCH_CHECK
|
||||
#else
|
||||
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
||||
#endif
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include <cub/block/block_load.cuh>
|
||||
@ -312,19 +316,25 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) {
|
||||
// kIsVariableB, kIsVariableC and kHasZ are all set to True to reduce binary size
|
||||
constexpr bool kIsVariableB = true;
|
||||
constexpr bool kIsVariableC = true;
|
||||
constexpr bool kHasZ = true;
|
||||
BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
BOOL_SWITCH(params.z_ptr != nullptr , kHasZ, [&] {
|
||||
BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] {
|
||||
using Ktraits = Selective_Scan_fwd_kernel_traits<kNThreads, kNItems, kNRows, kIsEvenLen, kIsVariableB, kIsVariableC, kHasZ, kVarlen, input_t, weight_t>;
|
||||
constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t);
|
||||
dim3 grid(params.batch, params.dim / kNRows);
|
||||
auto kernel = &selective_scan_fwd_kernel<Ktraits>;
|
||||
if (kSmemSize >= 48 * 1024) {
|
||||
#ifdef USE_ROCM
|
||||
C10_HIP_CHECK(hipFuncSetAttribute(
|
||||
reinterpret_cast<const void*>(kernel), hipFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#else
|
||||
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
||||
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
||||
#endif
|
||||
}
|
||||
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
||||
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
||||
});
|
||||
});
|
||||
});
|
||||
}
|
||||
@ -612,19 +622,20 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
|
||||
at::Tensor z, out_z;
|
||||
const bool has_z = z_.has_value();
|
||||
TORCH_CHECK(has_z, "has_z = False is disabled in favor of reduced binary size")
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
if (has_z) {
|
||||
z = z_.value();
|
||||
TORCH_CHECK(z.scalar_type() == input_type);
|
||||
TORCH_CHECK(z.is_cuda());
|
||||
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
|
||||
if (varlen){
|
||||
CHECK_SHAPE(z, dim, seqlen);
|
||||
} else {
|
||||
CHECK_SHAPE(z, batch_size, dim, seqlen);
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
}
|
||||
|
||||
out_z = z;
|
||||
|
||||
// Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout
|
||||
at::Tensor out = delta;
|
||||
TORCH_CHECK(ssm_states.scalar_type() == input_type);
|
||||
@ -653,4 +664,3 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
#include <torch/all.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/Atomic.cuh>
|
||||
@ -19,9 +20,14 @@ __global__ void moe_align_block_size_kernel(
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t padded_num_experts, int32_t experts_per_warp, int32_t block_size,
|
||||
size_t numel, int32_t* __restrict__ cumsum) {
|
||||
size_t numel, int32_t* __restrict__ cumsum, int32_t max_num_tokens_padded) {
|
||||
extern __shared__ int32_t shared_counts[];
|
||||
|
||||
// Initialize sorted_token_ids with numel
|
||||
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
|
||||
sorted_token_ids[it] = numel;
|
||||
}
|
||||
|
||||
const int warp_id = threadIdx.x / WARP_SIZE;
|
||||
const int my_expert_start = warp_id * experts_per_warp;
|
||||
|
||||
@ -45,18 +51,27 @@ __global__ void moe_align_block_size_kernel(
|
||||
|
||||
__syncthreads();
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
cumsum[0] = 0;
|
||||
for (int i = 1; i <= num_experts; ++i) {
|
||||
int expert_count = 0;
|
||||
int warp_idx = (i - 1) / experts_per_warp;
|
||||
int expert_offset = (i - 1) % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||
// Compute prefix sum over token counts per expert
|
||||
using BlockScan = cub::BlockScan<int32_t, 1024>;
|
||||
__shared__ typename BlockScan::TempStorage temp_storage;
|
||||
|
||||
cumsum[i] =
|
||||
cumsum[i - 1] + CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
*total_tokens_post_pad = cumsum[num_experts];
|
||||
int expert_count = 0;
|
||||
int expert_id = threadIdx.x;
|
||||
if (expert_id < num_experts) {
|
||||
int warp_idx = expert_id / experts_per_warp;
|
||||
int expert_offset = expert_id % experts_per_warp;
|
||||
expert_count = shared_counts[warp_idx * experts_per_warp + expert_offset];
|
||||
expert_count = CEILDIV(expert_count, block_size) * block_size;
|
||||
}
|
||||
|
||||
int cumsum_val;
|
||||
BlockScan(temp_storage).ExclusiveSum(expert_count, cumsum_val);
|
||||
if (expert_id <= num_experts) {
|
||||
cumsum[expert_id] = cumsum_val;
|
||||
}
|
||||
|
||||
if (expert_id == num_experts) {
|
||||
*total_tokens_post_pad = cumsum_val;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
@ -67,6 +82,13 @@ __global__ void moe_align_block_size_kernel(
|
||||
expert_ids[i / block_size] = threadIdx.x;
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
|
||||
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
|
||||
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
|
||||
expert_ids[i] = 0;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -105,7 +127,12 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
const scalar_t* __restrict__ topk_ids,
|
||||
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
|
||||
int32_t* __restrict__ total_tokens_post_pad, int32_t num_experts,
|
||||
int32_t block_size, size_t numel) {
|
||||
int32_t block_size, size_t numel, int32_t max_num_tokens_padded) {
|
||||
// Initialize sorted_token_ids with numel
|
||||
for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
|
||||
sorted_token_ids[it] = numel;
|
||||
}
|
||||
|
||||
const size_t tid = threadIdx.x;
|
||||
const size_t stride = blockDim.x;
|
||||
|
||||
@ -153,6 +180,13 @@ __global__ void moe_align_block_size_small_batch_expert_kernel(
|
||||
}
|
||||
}
|
||||
|
||||
// Fill remaining expert_ids with 0
|
||||
const size_t fill_start_idx = cumsum[num_experts] / block_size + threadIdx.x;
|
||||
const size_t expert_ids_size = CEILDIV(max_num_tokens_padded, block_size);
|
||||
for (size_t i = fill_start_idx; i < expert_ids_size; i += blockDim.x) {
|
||||
expert_ids[i] = 0;
|
||||
}
|
||||
|
||||
for (size_t i = tid; i < numel; i += stride) {
|
||||
int32_t expert_id = topk_ids[i];
|
||||
int32_t rank_post_pad =
|
||||
@ -179,13 +213,17 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
int threads = 1024;
|
||||
threads = ((threads + WARP_SIZE - 1) / WARP_SIZE) * WARP_SIZE;
|
||||
|
||||
// BlockScan uses 1024 threads and assigns one thread per expert.
|
||||
TORCH_CHECK(padded_num_experts < 1024,
|
||||
"padded_num_experts must be less than 1024");
|
||||
|
||||
VLLM_DISPATCH_INTEGRAL_AND_UNSIGNED_TYPES(
|
||||
topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] {
|
||||
// calc needed amount of shared mem for `cumsum` tensors
|
||||
auto options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt).device(topk_ids.device());
|
||||
torch::Tensor cumsum_buffer =
|
||||
torch::zeros({num_experts + 1}, options_int);
|
||||
torch::empty({num_experts + 1}, options_int);
|
||||
bool small_batch_expert_mode =
|
||||
(topk_ids.numel() < 1024) && (num_experts <= 64);
|
||||
|
||||
@ -203,7 +241,7 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
sorted_token_ids.data_ptr<int32_t>(),
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts, block_size,
|
||||
topk_ids.numel());
|
||||
topk_ids.numel(), sorted_token_ids.size(0));
|
||||
} else {
|
||||
auto align_kernel = vllm::moe::moe_align_block_size_kernel<scalar_t>;
|
||||
|
||||
@ -217,7 +255,8 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts,
|
||||
experts_ids.data_ptr<int32_t>(),
|
||||
num_tokens_post_pad.data_ptr<int32_t>(), num_experts,
|
||||
padded_num_experts, experts_per_warp, block_size,
|
||||
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>());
|
||||
topk_ids.numel(), cumsum_buffer.data_ptr<int32_t>(),
|
||||
sorted_token_ids.size(0));
|
||||
|
||||
const int block_threads = std::min(256, (int)threads);
|
||||
const int num_blocks =
|
||||
|
21
csrc/ops.h
21
csrc/ops.h
@ -287,6 +287,11 @@ void scaled_fp4_experts_quant(
|
||||
torch::Tensor const& input, torch::Tensor const& input_global_scale,
|
||||
torch::Tensor const& input_offset_by_experts,
|
||||
torch::Tensor const& output_scale_offset_by_experts);
|
||||
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0);
|
||||
#endif
|
||||
|
||||
void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
|
||||
@ -326,22 +331,6 @@ void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta,
|
||||
const std::optional<torch::Tensor>& has_initial_state,
|
||||
const torch::Tensor& ssm_states, int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_update(const at::Tensor& x, const at::Tensor& conv_state,
|
||||
const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
bool silu_activation,
|
||||
const std::optional<at::Tensor>& cache_seqlens_,
|
||||
const std::optional<at::Tensor>& conv_state_indices_,
|
||||
int64_t pad_slot_id);
|
||||
|
||||
void causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight,
|
||||
const std::optional<at::Tensor>& bias_,
|
||||
const std::optional<at::Tensor>& conv_states,
|
||||
const std::optional<at::Tensor>& query_start_loc,
|
||||
const std::optional<at::Tensor>& cache_indices,
|
||||
const std::optional<at::Tensor>& has_initial_state,
|
||||
bool silu_activation, int64_t pad_slot_id);
|
||||
|
||||
using fptr_t = int64_t;
|
||||
fptr_t init_custom_ar(const std::vector<int64_t>& fake_ipc_ptrs,
|
||||
torch::Tensor& rank_data, int64_t rank,
|
||||
|
@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
|
||||
// calculate for absmax
|
||||
float thread_max = 0.f;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
const auto v = fabsf(static_cast<float>(row_in[i]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(
|
||||
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
|
||||
const float v = fabsf(static_cast<float>(src));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
});
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
||||
@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
|
||||
// 1. calculate min & max
|
||||
MinMax thread_mm;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
thread_mm += static_cast<float>(row_in[i]);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
|
||||
[&] __device__(const scalar_t& src) {
|
||||
thread_mm += static_cast<float>(src);
|
||||
});
|
||||
|
||||
using BlockReduce = cub::BlockReduce<MinMax, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
|
@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
|
@ -0,0 +1,373 @@
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include <cassert>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator,
|
||||
typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
__global__ void get_ggemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scale_offsets,
|
||||
ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scale_base_as_int,
|
||||
ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int,
|
||||
LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
}
|
||||
|
||||
int m = problem_sizes[expert_id * 3];
|
||||
int n = problem_sizes[expert_id * 3 + 1];
|
||||
int k = problem_sizes[expert_id * 3 + 2];
|
||||
|
||||
int32_t expert_offset = expert_offsets[expert_id];
|
||||
int a_stride = expert_offset * k;
|
||||
int b_stride = expert_id * k * n;
|
||||
int a_scale_stride = expert_offset * k / 128;
|
||||
int b_scale_stride = expert_id * k * n / 128 / 128;
|
||||
|
||||
a_offsets[expert_id] = a_base_as_int + a_stride;
|
||||
b_offsets[expert_id] = b_base_as_int + b_stride;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride;
|
||||
b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
*layout_sfa_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||
*layout_sfb_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||
}
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
|
||||
ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
|
||||
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
|
||||
static_cast<int*>(problem_sizes.data_ptr())); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
void run_get_ggemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& layout_sfa,
|
||||
torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
|
||||
TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
||||
void run_blockwise_scaled_group_mm(
|
||||
torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs,
|
||||
const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs,
|
||||
const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b, const torch::Tensor& stride_c,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
// Types
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = LayoutD;
|
||||
|
||||
// Alignments
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*,
|
||||
AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA,
|
||||
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
|
||||
AlignmentA, ElementB,
|
||||
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
|
||||
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue, void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
// Mainloop Arguments
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementA**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(stride_a.data_ptr()),
|
||||
static_cast<const ElementB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(stride_b.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(
|
||||
layout_sfa.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
|
||||
layout_sfb.data_ptr())};
|
||||
|
||||
int device_id = a_ptrs.device().index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, // epilogue.thread
|
||||
nullptr,
|
||||
static_cast<StrideC*>(stride_c.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(stride_c.data_ptr())};
|
||||
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
|
||||
// Gemm Arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info};
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void blockwise_scaled_group_mm_dispatch_shape(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
struct MmaConfig {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
|
||||
1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
};
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
auto a_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto out_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto a_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
auto layout_sfa = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
auto layout_sfb = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
|
||||
auto stride_a = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_b = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_c = torch::full(
|
||||
{num_experts}, output.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
torch::TensorOptions options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
run_get_ggemm_starts<typename MmaConfig::LayoutSFA,
|
||||
typename MmaConfig::LayoutSFB,
|
||||
typename MmaConfig::ScaleConfig>(
|
||||
expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a,
|
||||
b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes);
|
||||
|
||||
run_blockwise_scaled_group_mm<OutType, MmaConfig,
|
||||
typename MmaConfig::LayoutC>(
|
||||
out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a,
|
||||
stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes,
|
||||
expert_offsets);
|
||||
}
|
||||
|
||||
void cutlass_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"a must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"b must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(output.scalar_type() == torch::kBFloat16 ||
|
||||
output.scalar_type() == torch::kHalf,
|
||||
"output must be bfloat16 or half");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32,
|
||||
"scales_a must be float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32,
|
||||
"scales_b must be float32");
|
||||
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
|
||||
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
|
||||
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
|
||||
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
||||
|
||||
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else if (output.scalar_type() == torch::kFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::half_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_blockwise_scaled_grouped_mm",
|
||||
&cutlass_blockwise_scaled_grouped_mm);
|
||||
}
|
@ -18,28 +18,34 @@ using ProblemShape =
|
||||
cutlass::gemm::GroupProblemShape<cute::Shape<int, int, int>>;
|
||||
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutA_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using LayoutB_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutB>::type;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
using LayoutD_Transpose =
|
||||
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
|
||||
using LayoutC = LayoutD;
|
||||
using LayoutC_Transpose = LayoutD_Transpose;
|
||||
|
||||
template <typename ElementAB_, typename ElementC_,
|
||||
template <typename ElementAB_, typename ElementC_, typename ArchTag_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
typename EpilogueSchedule, bool swap_ab_ = false>
|
||||
struct cutlass_3x_group_gemm {
|
||||
static constexpr bool swap_ab = swap_ab_;
|
||||
using ElementAB = ElementAB_;
|
||||
using ElementC = void;
|
||||
using ElementD = ElementC_;
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = ArchTag_;
|
||||
|
||||
using Epilogue = Epilogue_<ElementAccumulator, ElementD, TileShape>;
|
||||
|
||||
using StrideC =
|
||||
cute::remove_pointer_t<cute::Stride<int64_t, cute::Int<1>, cute::Int<0>>>;
|
||||
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
@ -50,21 +56,28 @@ struct cutlass_3x_group_gemm {
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, TileShape, ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, ElementC, LayoutC*, AlignmentC, ElementD,
|
||||
LayoutC*, AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
ElementAccumulator, ElementC,
|
||||
conditional_t<swap_ab, LayoutC_Transpose*, LayoutC*>, AlignmentC,
|
||||
ElementD, conditional_t<swap_ab, LayoutD_Transpose*, LayoutD*>,
|
||||
AlignmentC, EpilogueSchedule, EVTCompute>::CollectiveOp;
|
||||
|
||||
static constexpr size_t CEStorageSize =
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage);
|
||||
using Stages = typename cutlass::gemm::collective::StageCountAutoCarveout<
|
||||
static_cast<int>(CEStorageSize)>;
|
||||
|
||||
using CollectiveMainloop =
|
||||
using CollectiveMainloop = conditional_t<
|
||||
swap_ab,
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementAB, LayoutB_Transpose*, AlignmentAB,
|
||||
ElementAB, LayoutA_Transpose*, AlignmentAB, ElementAccumulator,
|
||||
TileShape, ClusterShape, Stages, KernelSchedule>::CollectiveOp,
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementAB, LayoutA*, AlignmentAB, ElementAB,
|
||||
LayoutB*, AlignmentAB, ElementAccumulator, TileShape, ClusterShape,
|
||||
Stages, KernelSchedule>::CollectiveOp;
|
||||
Stages, KernelSchedule>::CollectiveOp>;
|
||||
|
||||
using KernelType = enable_sm90_only<cutlass::gemm::kernel::GemmUniversal<
|
||||
using KernelType = enable_sm90_or_later<cutlass::gemm::kernel::GemmUniversal<
|
||||
ProblemShape, CollectiveMainloop, CollectiveEpilogue>>;
|
||||
|
||||
struct GemmKernel : public KernelType {};
|
||||
@ -78,12 +91,12 @@ void cutlass_group_gemm_caller(
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
static constexpr bool swap_ab = Gemm::swap_ab;
|
||||
|
||||
using ElementAB = typename Gemm::ElementAB;
|
||||
using ElementD = typename Gemm::ElementD;
|
||||
|
||||
int num_experts = static_cast<int>(expert_offsets.size(0));
|
||||
int k_size = a_tensors.size(1);
|
||||
int n_size = out_tensors.size(1);
|
||||
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
@ -110,26 +123,47 @@ void cutlass_group_gemm_caller(
|
||||
problem_sizes.data_ptr());
|
||||
ProblemShape prob_shape{num_experts, problem_sizes_as_shapes, nullptr};
|
||||
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(a_strides.data_ptr()),
|
||||
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(b_strides.data_ptr())};
|
||||
typename GemmKernel::MainloopArguments mainloop_args;
|
||||
if constexpr (swap_ab) {
|
||||
mainloop_args = typename GemmKernel::MainloopArguments{
|
||||
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(b_strides.data_ptr()),
|
||||
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(a_strides.data_ptr())};
|
||||
} else {
|
||||
mainloop_args = typename GemmKernel::MainloopArguments{
|
||||
static_cast<const ElementAB**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(a_strides.data_ptr()),
|
||||
static_cast<const ElementAB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(b_strides.data_ptr())};
|
||||
}
|
||||
|
||||
// Currently, we are only able to do broadcast on either all or none a_scales
|
||||
// and on either all or none b_scales
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
Gemm::Epilogue::prepare_args(
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
per_act_token, per_out_ch),
|
||||
swap_ab ? static_cast<const ElementAccumulator**>(
|
||||
b_scales_ptrs.data_ptr())
|
||||
: static_cast<const ElementAccumulator**>(
|
||||
a_scales_ptrs.data_ptr()),
|
||||
swap_ab ? static_cast<const ElementAccumulator**>(
|
||||
a_scales_ptrs.data_ptr())
|
||||
: static_cast<const ElementAccumulator**>(
|
||||
b_scales_ptrs.data_ptr()),
|
||||
swap_ab ? per_out_ch : per_act_token,
|
||||
swap_ab ? per_act_token : per_out_ch),
|
||||
nullptr, static_cast<StrideC*>(c_strides.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(c_strides.data_ptr())};
|
||||
|
||||
int device_id = a_tensors.device().index();
|
||||
static const cutlass::KernelHardwareInfo hw_info{
|
||||
device_id, cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
device_id)};
|
||||
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped, prob_shape, mainloop_args,
|
||||
epilogue_args};
|
||||
epilogue_args, hw_info};
|
||||
|
||||
using GemmOp = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
GemmOp gemm_op;
|
||||
|
140
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu
Normal file
140
csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x_sm100.cu
Normal file
@ -0,0 +1,140 @@
|
||||
#include <cudaTypedefs.h>
|
||||
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "cutlass/cutlass.h"
|
||||
#include "grouped_mm_c3x.cuh"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_default {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M64 {
|
||||
// M in [1,64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule,
|
||||
true>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_N8192 {
|
||||
// N in [8192, inf)
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecialized2SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized2Sm;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_256, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
void run_cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
TORCH_CHECK(a_tensors.size(0) > 0, "No input A tensors provided.");
|
||||
TORCH_CHECK(b_tensors.size(0) > 0, "No input B tensors provided.");
|
||||
TORCH_CHECK(out_tensors.size(0) > 0, "No output tensors provided.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"A tensors must be of type float8_e4m3fn.");
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
using Cutlass3xGemmDefault = typename sm100_fp8_config_default<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmN8192 = typename sm100_fp8_config_N8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 = typename sm100_fp8_config_M64<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a_tensors.size(0);
|
||||
uint32_t const n = out_tensors.size(1);
|
||||
|
||||
if (m <= 64) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
} else if (n >= 8192) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
} else {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
void dispatch_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
if (out_tensors.dtype() == torch::kBFloat16) {
|
||||
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::bfloat16_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
} else {
|
||||
run_cutlass_moe_mm_sm100<cutlass::float_e4m3_t, cutlass::half_t>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
}
|
||||
}
|
||||
|
||||
void cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
dispatch_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
}
|
@ -21,27 +21,49 @@ struct sm90_fp8_config_default {
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_256, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_2, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
struct sm90_fp8_config_M4 {
|
||||
// M in [1, 4]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_64, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_4, cute::_1>;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule,
|
||||
true>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm90_fp8_config_M64 {
|
||||
// M in (4, 64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum;
|
||||
using EpilogueSchedule =
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_16, cute::_256>;
|
||||
using ClusterShape = cute::Shape<cute::_2, cute::_1, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule,
|
||||
true>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -55,10 +77,11 @@ struct sm90_fp8_config_K8192 {
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
@ -72,10 +95,11 @@ struct sm90_fp8_config_N8192 {
|
||||
cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
|
||||
using TileShape = cute::Shape<cute::_64, cute::_128, cute::_256>;
|
||||
using ClusterShape = cute::Shape<cute::_1, cute::_8, cute::_1>;
|
||||
using ArchTag = cutlass::arch::Sm90;
|
||||
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_group_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
cutlass_3x_group_gemm<InType, OutType, ArchTag, Epilogue, TileShape,
|
||||
ClusterShape, KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType>
|
||||
@ -95,14 +119,13 @@ void run_cutlass_moe_mm_sm90(
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn,
|
||||
"B tensors must be of type float8_e4m3fn.");
|
||||
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmN8192 = typename sm90_fp8_config_N8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmK8192 = typename sm90_fp8_config_K8192<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM16 = typename sm90_fp8_config_M16<
|
||||
using Cutlass3xGemmM4 = typename sm90_fp8_config_M4<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 = typename sm90_fp8_config_M64<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmDefault = typename sm90_fp8_config_default<
|
||||
InType, OutType, vllm::c3x::ScaledEpilogueArray>::Cutlass3xGemm;
|
||||
@ -111,7 +134,18 @@ void run_cutlass_moe_mm_sm90(
|
||||
uint32_t const n = out_tensors.size(1);
|
||||
uint32_t const k = a_tensors.size(1);
|
||||
|
||||
if (n >= 8192) {
|
||||
// Use swap_ab for M <= 64 by default to reduce padding
|
||||
if (m <= 4) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmM4>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
} else if (m <= 64) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmM64>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
} else if (n >= 8192) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmN8192>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
@ -121,11 +155,6 @@ void run_cutlass_moe_mm_sm90(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
} else if (m <= 16) {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmM16>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
||||
problem_sizes, a_strides, b_strides, c_strides, per_act_token,
|
||||
per_out_ch);
|
||||
} else {
|
||||
cutlass_group_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out_tensors, a_tensors, b_tensors, a_scales, b_scales, expert_offsets,
|
@ -6,8 +6,11 @@
|
||||
#include <iostream>
|
||||
|
||||
constexpr uint64_t THREADS_PER_EXPERT = 512;
|
||||
// threshold must match the dispatch logic in run_cutlass_moe_mm_sm90()
|
||||
constexpr int SWAP_AB_THRESHOLD = 64;
|
||||
|
||||
__global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
|
||||
template <bool SWAP_AB>
|
||||
__global__ void compute_problem_sizes(const int32_t* __restrict__ topk_ids,
|
||||
int32_t* problem_sizes1,
|
||||
int32_t* problem_sizes2,
|
||||
int32_t* atomic_buffer,
|
||||
@ -24,45 +27,58 @@ __global__ void compute_problem_sizes(const uint32_t* __restrict__ topk_ids,
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
int final_occurrences = atomic_buffer[expert_id];
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
if constexpr (!SWAP_AB) {
|
||||
problem_sizes1[expert_id * 3] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 1] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 1] = k;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
} else {
|
||||
problem_sizes1[expert_id * 3] = 2 * n;
|
||||
problem_sizes1[expert_id * 3 + 1] = final_occurrences;
|
||||
problem_sizes1[expert_id * 3 + 2] = k;
|
||||
problem_sizes2[expert_id * 3] = k;
|
||||
problem_sizes2[expert_id * 3 + 1] = final_occurrences;
|
||||
problem_sizes2[expert_id * 3 + 2] = n;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_expert_offsets(
|
||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||
int32_t* atomic_buffer, const int num_experts) {
|
||||
int32_t* atomic_buffer, const int num_experts, const int topk_length) {
|
||||
int32_t tot_offset = 0;
|
||||
expert_offsets[0] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
atomic_buffer[i] = tot_offset;
|
||||
tot_offset += problem_sizes1[i * 3];
|
||||
tot_offset += topk_length > SWAP_AB_THRESHOLD ? problem_sizes1[i * 3]
|
||||
: problem_sizes1[i * 3 + 1];
|
||||
expert_offsets[i + 1] = tot_offset;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_expert_blockscale_offsets(
|
||||
const int32_t* __restrict__ problem_sizes1, int32_t* expert_offsets,
|
||||
int32_t* blockscale_offsets, int32_t* atomic_buffer,
|
||||
const int num_experts) {
|
||||
int32_t* blockscale_offsets, int32_t* atomic_buffer, const int num_experts,
|
||||
const int topk_length) {
|
||||
int32_t tot_offset = 0;
|
||||
int32_t tot_offset_round = 0;
|
||||
expert_offsets[0] = 0;
|
||||
blockscale_offsets[0] = 0;
|
||||
for (int i = 0; i < num_experts; ++i) {
|
||||
int32_t cur_offset = topk_length > SWAP_AB_THRESHOLD
|
||||
? problem_sizes1[i * 3]
|
||||
: problem_sizes1[i * 3 + 1];
|
||||
atomic_buffer[i] = tot_offset;
|
||||
tot_offset += problem_sizes1[i * 3];
|
||||
tot_offset += cur_offset;
|
||||
expert_offsets[i + 1] = tot_offset;
|
||||
tot_offset_round += (problem_sizes1[i * 3] + (128 - 1)) / 128 * 128;
|
||||
tot_offset_round += (cur_offset + (128 - 1)) / 128 * 128;
|
||||
blockscale_offsets[i + 1] = tot_offset_round;
|
||||
}
|
||||
}
|
||||
|
||||
__global__ void compute_arg_sorts(const uint32_t* __restrict__ topk_ids,
|
||||
__global__ void compute_arg_sorts(const int32_t* __restrict__ topk_ids,
|
||||
const int32_t* __restrict__ expert_offsets,
|
||||
int32_t* input_permutation,
|
||||
int32_t* output_permutation,
|
||||
@ -102,25 +118,39 @@ void get_cutlass_moe_mm_data_caller(
|
||||
torch::Tensor atomic_buffer = torch::zeros(num_experts, options_int32);
|
||||
|
||||
int num_threads = min(THREADS_PER_EXPERT, topk_ids.numel());
|
||||
compute_problem_sizes<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n, k);
|
||||
|
||||
if (topk_ids.numel() > SWAP_AB_THRESHOLD) {
|
||||
compute_problem_sizes<false><<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||
k);
|
||||
} else {
|
||||
compute_problem_sizes<true><<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), topk_ids.numel(), n,
|
||||
k);
|
||||
}
|
||||
|
||||
if (blockscale_offsets.has_value()) {
|
||||
compute_expert_blockscale_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(blockscale_offsets.value().data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
||||
topk_ids.numel());
|
||||
} else {
|
||||
compute_expert_offsets<<<1, 1, 0, stream>>>(
|
||||
static_cast<const int32_t*>(problem_sizes1.data_ptr()),
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts);
|
||||
static_cast<int32_t*>(atomic_buffer.data_ptr()), num_experts,
|
||||
topk_ids.numel());
|
||||
}
|
||||
compute_arg_sorts<<<num_experts, num_threads, 0, stream>>>(
|
||||
static_cast<const uint32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(topk_ids.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_offsets.data_ptr()),
|
||||
static_cast<int32_t*>(input_permutation.data_ptr()),
|
||||
static_cast<int32_t*>(output_permutation.data_ptr()),
|
||||
@ -160,4 +190,4 @@ void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets,
|
||||
static_cast<int32_t*>(problem_sizes2.data_ptr()),
|
||||
static_cast<const int32_t*>(expert_num_tokens.data_ptr()), padded_m, n,
|
||||
k);
|
||||
}
|
||||
}
|
@ -41,6 +41,16 @@ void cutlass_moe_mm_sm90(
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
void cutlass_moe_mm_sm100(
|
||||
torch::Tensor& out_tensors, torch::Tensor const& a_tensors,
|
||||
torch::Tensor const& b_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& expert_offsets,
|
||||
torch::Tensor const& problem_sizes, torch::Tensor const& a_strides,
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
@ -130,10 +140,10 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
// and at least SM90 (Hopper)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability >= 90 && cuda_device_capability < 100) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
} else if (cuda_device_capability >= 100) {
|
||||
if (cuda_device_capability >= 100) {
|
||||
return CUDA_VERSION >= 12080;
|
||||
} else if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12000;
|
||||
}
|
||||
#endif
|
||||
|
||||
@ -141,11 +151,14 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
|
||||
}
|
||||
|
||||
bool cutlass_group_gemm_supported(int64_t cuda_device_capability) {
|
||||
// CUTLASS grouped FP8 kernels need at least CUDA 12.3
|
||||
// and SM90 (Hopper)
|
||||
// CUTLASS grouped FP8 kernels need at least CUDA 12.3 and SM90 (Hopper)
|
||||
// or CUDA 12.8 and SM100 (Blackwell)
|
||||
|
||||
#if defined CUDA_VERSION
|
||||
if (cuda_device_capability == 90) {
|
||||
if (cuda_device_capability >= 100) {
|
||||
return CUDA_VERSION >= 12080;
|
||||
}
|
||||
if (cuda_device_capability >= 90) {
|
||||
return CUDA_VERSION >= 12030;
|
||||
}
|
||||
#endif
|
||||
@ -234,16 +247,26 @@ void cutlass_moe_mm(
|
||||
torch::Tensor const& b_strides, torch::Tensor const& c_strides,
|
||||
bool per_act_token, bool per_out_ch) {
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100
|
||||
if (version_num >= 100) {
|
||||
cutlass_moe_mm_sm100(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
if (version_num >= 90) {
|
||||
cutlass_moe_mm_sm90(out_tensors, a_tensors, b_tensors, a_scales, b_scales,
|
||||
expert_offsets, problem_sizes, a_strides, b_strides,
|
||||
c_strides, per_act_token, per_out_ch);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
TORCH_CHECK_NOT_IMPLEMENTED(
|
||||
false,
|
||||
"No compiled cutlass_scaled_mm for CUDA device capability: ", version_num,
|
||||
". Required capability: 90");
|
||||
". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_moe_mm_data(
|
||||
|
@ -30,35 +30,40 @@
|
||||
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
|
||||
#include "core/math.hpp"
|
||||
|
||||
using namespace cute;
|
||||
|
||||
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
|
||||
// Kernel Perf config
|
||||
template <typename T>
|
||||
struct KernelTraits;
|
||||
|
||||
template <>
|
||||
struct KernelTraits<float> {
|
||||
using MmaTileShape = Shape<_128, _128, _256>;
|
||||
// Configuration for M in (256, inf)
|
||||
struct sm100_fp4_config_default {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in (16, 256]
|
||||
struct sm100_fp4_config_M256 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_256, _128, _256>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
// Configuration for M in [1, 16]
|
||||
struct sm100_fp4_config_M16 {
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _128, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::half_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <>
|
||||
struct KernelTraits<cutlass::bfloat16_t> {
|
||||
using MmaTileShape = Shape<_256, _256, _256>;
|
||||
using ClusterShape = Shape<_4, _4, _1>;
|
||||
using PerSmTileShape_MNK = Shape<_128, _256, _256>;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
template <typename Config, typename OutType>
|
||||
struct Fp4GemmSm100 {
|
||||
// A matrix configuration
|
||||
using ElementA = cutlass::nv_float4_t<cutlass::float_e2m1_t>;
|
||||
@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
|
||||
static constexpr int AlignmentB = 32;
|
||||
|
||||
// C/D matrix configuration
|
||||
using ElementD = T;
|
||||
using ElementC = T;
|
||||
using ElementD = OutType;
|
||||
using ElementC = OutType;
|
||||
using LayoutCTag = cutlass::layout::RowMajor;
|
||||
using LayoutDTag = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = 128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
// Kernel functional config
|
||||
using ElementAccumulator = float;
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp;
|
||||
|
||||
// Kernel Perf config
|
||||
using MmaTileShape = typename KernelTraits<T>::MmaTileShape;
|
||||
using ClusterShape = typename KernelTraits<T>::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename KernelTraits<T>::PerSmTileShape_MNK;
|
||||
// Use config's tile shapes
|
||||
using MmaTileShape = typename Config::TileShape;
|
||||
using ClusterShape = typename Config::ClusterShape;
|
||||
using PerSmTileShape_MNK = typename Config::PerSmTileShape_MNK;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
|
||||
using LayoutD = decltype(cute::make_layout(make_shape(0, 0, 0), StrideD{}));
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
typename T::Gemm::Arguments args_from_options(
|
||||
template <typename Config>
|
||||
typename Config::Gemm::Arguments args_from_options(
|
||||
at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf, at::Tensor const& alpha,
|
||||
int64_t M, int64_t N, int64_t K) {
|
||||
using ElementA = typename T::Gemm::ElementA;
|
||||
using ElementB = typename T::Gemm::ElementB;
|
||||
using ElementA = typename Config::Gemm::ElementA;
|
||||
using ElementB = typename Config::Gemm::ElementB;
|
||||
using ElementSFA = cutlass::float_ue4m3_t;
|
||||
using ElementSFB = cutlass::float_ue4m3_t;
|
||||
using ElementD = typename T::Gemm::ElementD;
|
||||
using ElementD = typename Config::Gemm::ElementD;
|
||||
using ElementCompute = float;
|
||||
using StrideA = typename T::StrideA;
|
||||
using StrideB = typename T::StrideB;
|
||||
using StrideD = typename T::StrideD;
|
||||
using Sm100BlkScaledConfig =
|
||||
typename T::Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
using StrideA = typename Config::StrideA;
|
||||
using StrideB = typename Config::StrideB;
|
||||
using StrideD = typename Config::StrideD;
|
||||
using Sm100BlkScaledConfig = typename Config::Gemm::GemmKernel::
|
||||
CollectiveMainloop::Sm1xxBlkScaledConfig;
|
||||
|
||||
int m = static_cast<int>(M);
|
||||
int n = static_cast<int>(N);
|
||||
@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
|
||||
auto layout_SFB = Sm100BlkScaledConfig::tile_atom_to_shape_SFB(
|
||||
cute::make_shape(m, n, k, 1));
|
||||
|
||||
typename T::Gemm::Arguments arguments{
|
||||
typename Config::Gemm::Arguments arguments{
|
||||
cutlass::gemm::GemmUniversalMode::kGemm,
|
||||
{m, n, k, 1},
|
||||
{// Mainloop arguments
|
||||
@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
|
||||
return arguments;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
template <typename Config>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
typename Fp4GemmSm100<T>::Gemm gemm;
|
||||
typename Config::Gemm gemm;
|
||||
|
||||
auto arguments =
|
||||
args_from_options<Fp4GemmSm100<T>>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
args_from_options<Config>(D, A, B, A_sf, B_sf, alpha, m, n, k);
|
||||
|
||||
size_t workspace_size = Fp4GemmSm100<T>::Gemm::get_workspace_size(arguments);
|
||||
size_t workspace_size = Config::Gemm::get_workspace_size(arguments);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(A.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
|
||||
CUTLASS_CHECK(gemm.run(arguments, workspace.data_ptr(), stream));
|
||||
}
|
||||
|
||||
// Dispatch function to select appropriate config based on M
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
uint32_t const mp2 = std::max(static_cast<uint32_t>(16), next_pow_2(m));
|
||||
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (16, 256]
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else {
|
||||
// m in (256, inf)
|
||||
runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
|
||||
D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
}
|
||||
}
|
||||
|
||||
#else
|
||||
template <typename T>
|
||||
void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
|
||||
at::Tensor const& A_sf, at::Tensor const& B_sf,
|
||||
at::Tensor const& alpha, int64_t m, int64_t n, int64_t k,
|
||||
cudaStream_t stream) {
|
||||
template <typename OutType>
|
||||
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
|
||||
torch::Tensor const& B,
|
||||
torch::Tensor const& A_sf,
|
||||
torch::Tensor const& B_sf,
|
||||
torch::Tensor const& alpha, int64_t m, int64_t n,
|
||||
int64_t k, cudaStream_t stream) {
|
||||
TORCH_CHECK(false,
|
||||
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
|
||||
"a CUTLASS 3.8 source directory to enable support.");
|
||||
@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
runGemm<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::half_t>(D, A, B, A_sf, B_sf, alpha, m, n,
|
||||
k, stream);
|
||||
} else if (out_dtype == at::ScalarType::BFloat16) {
|
||||
runGemm<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
} else if (out_dtype == at::ScalarType::Float) {
|
||||
runGemm<float>(D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
|
||||
cutlass_fp4_gemm_dispatch<cutlass::bfloat16_t>(D, A, B, A_sf, B_sf, alpha,
|
||||
m, n, k, stream);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm");
|
||||
TORCH_CHECK(false, "Unsupported output data type of nvfp4 mm (", out_dtype,
|
||||
")");
|
||||
}
|
||||
}
|
||||
|
@ -88,6 +88,8 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor const& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
int const block_size = 256;
|
||||
int const num_tokens = input.numel() / input.size(-1);
|
||||
int const num_elems = input.numel();
|
||||
@ -111,6 +113,8 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
|
||||
torch::Tensor const& input, // [..., d]
|
||||
torch::Tensor& scale) // [1]
|
||||
{
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(out.is_contiguous());
|
||||
int const block_size = 256;
|
||||
int const num_tokens = input.numel() / input.size(-1);
|
||||
int const num_elems = input.numel();
|
||||
|
213
csrc/quantization/fp8/per_token_group_quant.cu
Normal file
213
csrc/quantization/fp8/per_token_group_quant.cu
Normal file
@ -0,0 +1,213 @@
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include <cuda_fp16.h>
|
||||
#include <cuda_bf16.h>
|
||||
|
||||
#include <torch/all.h>
|
||||
|
||||
#include "../vectorization.cuh"
|
||||
#include "../vectorization_utils.cuh"
|
||||
#include "../../dispatch_utils.h"
|
||||
|
||||
__device__ __forceinline__ float GroupReduceMax(float val, const int tid) {
|
||||
unsigned mask = 0xffff;
|
||||
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 8));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 4));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 2));
|
||||
val = fmaxf(val, __shfl_xor_sync(mask, val, 1));
|
||||
return val;
|
||||
}
|
||||
|
||||
template <typename T, typename DST_DTYPE, bool IS_COLUMN_MAJOR = false,
|
||||
bool SCALE_UE8M0 = false, typename scale_packed_t = float>
|
||||
__global__ void per_token_group_quant_8bit_kernel(
|
||||
const T* __restrict__ input, void* __restrict__ output_q,
|
||||
scale_packed_t* __restrict__ output_s, const int group_size,
|
||||
const int num_groups, const int groups_per_block, const float eps,
|
||||
const float min_8bit, const float max_8bit, const int scale_num_rows = 0,
|
||||
const int scale_stride = 0) {
|
||||
const int threads_per_group = 16;
|
||||
const int64_t local_group_id = threadIdx.x / threads_per_group;
|
||||
const int lane_id = threadIdx.x % threads_per_group;
|
||||
|
||||
const int64_t block_group_id = blockIdx.x * groups_per_block;
|
||||
const int64_t global_group_id = block_group_id + local_group_id;
|
||||
const int64_t block_group_offset = global_group_id * group_size;
|
||||
|
||||
float local_absmax = eps;
|
||||
|
||||
using scale_element_t = float;
|
||||
static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0);
|
||||
|
||||
const T* group_input = input + block_group_offset;
|
||||
DST_DTYPE* group_output =
|
||||
static_cast<DST_DTYPE*>(output_q) + block_group_offset;
|
||||
scale_element_t* scale_output;
|
||||
|
||||
if constexpr (IS_COLUMN_MAJOR) {
|
||||
const int num_elems_per_pack =
|
||||
static_cast<int>(sizeof(scale_packed_t) / sizeof(scale_element_t));
|
||||
const int scale_num_rows_element = scale_num_rows * num_elems_per_pack;
|
||||
const int row_idx = global_group_id / scale_num_rows_element;
|
||||
const int col_idx_raw = global_group_id % scale_num_rows_element;
|
||||
const int col_idx = col_idx_raw / num_elems_per_pack;
|
||||
const int pack_idx = col_idx_raw % num_elems_per_pack;
|
||||
scale_output = reinterpret_cast<scale_element_t*>(output_s) +
|
||||
(col_idx * scale_stride * num_elems_per_pack +
|
||||
row_idx * num_elems_per_pack + pack_idx);
|
||||
} else {
|
||||
scale_output = output_s + global_group_id;
|
||||
}
|
||||
|
||||
// shared memory to cache each group's data to avoid double DRAM reads.
|
||||
extern __shared__ __align__(16) char smem_raw[];
|
||||
T* smem = reinterpret_cast<T*>(smem_raw);
|
||||
T* smem_group = smem + local_group_id * group_size;
|
||||
|
||||
constexpr int vec_size = 16 / sizeof(T);
|
||||
using vec_t = vllm::vec_n_t<T, vec_size>;
|
||||
|
||||
// copy global -> shared & compute absmax
|
||||
auto scalar_op_cache = [&] __device__(T & dst, const T& src) {
|
||||
float abs_v = fabsf(static_cast<float>(src));
|
||||
local_absmax = fmaxf(local_absmax, abs_v);
|
||||
dst = src;
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
group_input, // in
|
||||
smem_group, // out (shared)
|
||||
group_size, // elements per group
|
||||
lane_id, // thread id
|
||||
threads_per_group, // stride in group
|
||||
scalar_op_cache); // scalar handler
|
||||
|
||||
local_absmax = GroupReduceMax(local_absmax, lane_id);
|
||||
|
||||
float y_s = local_absmax / max_8bit;
|
||||
if constexpr (SCALE_UE8M0) {
|
||||
y_s = exp2f(ceilf(log2f(fmaxf(fabsf(y_s), 1e-10f))));
|
||||
}
|
||||
|
||||
scale_element_t y_s_quant = y_s;
|
||||
|
||||
if (lane_id == 0) {
|
||||
*scale_output = y_s_quant;
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
|
||||
// quantize shared -> global 8-bit
|
||||
auto scalar_op_quant = [&] __device__(DST_DTYPE & dst, const T& src) {
|
||||
float q = fminf(fmaxf(static_cast<float>(src) / y_s, min_8bit), max_8bit);
|
||||
dst = DST_DTYPE(q);
|
||||
};
|
||||
|
||||
vllm::vectorize_with_alignment<vec_size>(
|
||||
smem_group, // in (shared)
|
||||
group_output, // out (global quant tensor)
|
||||
group_size, // elements
|
||||
lane_id, // tid
|
||||
threads_per_group, // stride
|
||||
scalar_op_quant); // scalar handler
|
||||
}
|
||||
|
||||
void per_token_group_quant_8bit(const torch::Tensor& input,
|
||||
torch::Tensor& output_q,
|
||||
torch::Tensor& output_s, int64_t group_size,
|
||||
double eps, double min_8bit, double max_8bit,
|
||||
bool scale_ue8m0 = false) {
|
||||
TORCH_CHECK(input.is_contiguous());
|
||||
TORCH_CHECK(output_q.is_contiguous());
|
||||
|
||||
const int num_groups = input.numel() / group_size;
|
||||
|
||||
TORCH_CHECK(input.numel() % group_size == 0);
|
||||
TORCH_CHECK(output_s.dim() == 2);
|
||||
|
||||
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
|
||||
|
||||
constexpr int THREADS_PER_GROUP = 16;
|
||||
|
||||
int groups_per_block = 1;
|
||||
|
||||
if (num_groups % 16 == 0) {
|
||||
groups_per_block = 16;
|
||||
} else if (num_groups % 8 == 0) {
|
||||
groups_per_block = 8;
|
||||
} else if (num_groups % 4 == 0) {
|
||||
groups_per_block = 4;
|
||||
} else if (num_groups % 2 == 0) {
|
||||
groups_per_block = 2;
|
||||
}
|
||||
|
||||
auto dst_type = output_q.scalar_type();
|
||||
const int num_blocks = num_groups / groups_per_block;
|
||||
const int num_threads = groups_per_block * THREADS_PER_GROUP;
|
||||
|
||||
const bool is_column_major = output_s.stride(0) < output_s.stride(1);
|
||||
const int scale_num_rows = output_s.size(1);
|
||||
const int scale_stride = output_s.stride(1);
|
||||
|
||||
#define LAUNCH_KERNEL(T, DST_DTYPE) \
|
||||
do { \
|
||||
dim3 grid(num_blocks); \
|
||||
dim3 block(num_threads); \
|
||||
size_t smem_bytes = \
|
||||
static_cast<size_t>(groups_per_block) * group_size * sizeof(T); \
|
||||
if (is_column_major) { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, true> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit, scale_num_rows, scale_stride); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, true, false> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit, scale_num_rows, scale_stride); \
|
||||
} \
|
||||
} else { \
|
||||
if (scale_ue8m0) { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, true> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} else { \
|
||||
per_token_group_quant_8bit_kernel<T, DST_DTYPE, false, false> \
|
||||
<<<grid, block, smem_bytes, stream>>>( \
|
||||
static_cast<T*>(input.data_ptr()), output_q.data_ptr(), \
|
||||
static_cast<float*>(output_s.data_ptr()), group_size, \
|
||||
num_groups, groups_per_block, (float)eps, (float)min_8bit, \
|
||||
(float)max_8bit); \
|
||||
} \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
VLLM_DISPATCH_FLOATING_TYPES(
|
||||
input.scalar_type(), "per_token_group_quant_8bit", ([&] {
|
||||
if (dst_type == at::ScalarType::Float8_e4m3fn) {
|
||||
LAUNCH_KERNEL(scalar_t, c10::Float8_e4m3fn);
|
||||
}
|
||||
}));
|
||||
|
||||
#undef LAUNCH_KERNEL
|
||||
}
|
||||
|
||||
void per_token_group_quant_fp8(const torch::Tensor& input,
|
||||
torch::Tensor& output_q, torch::Tensor& output_s,
|
||||
int64_t group_size, double eps, double fp8_min,
|
||||
double fp8_max, bool scale_ue8m0) {
|
||||
per_token_group_quant_8bit(input, output_q, output_s, group_size, eps,
|
||||
fp8_min, fp8_max, scale_ue8m0);
|
||||
}
|
@ -38,7 +38,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/atom/copy_traits_sm90_tma.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
|
||||
|
@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
// Note: currently the output is guaranteed to be same as the input, so we
|
||||
// don't check it here, comments here just for future reference.
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
using vout_t = vec_n_t<OutT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
auto* v_out = reinterpret_cast<vout_t*>(out);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vout_t tmp;
|
||||
vec_op(tmp, v_in[i]);
|
||||
v_out[i] = tmp;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
|
||||
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
|
||||
@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
struct DefaultReadVecOp {
|
||||
ScaOp scalar_op;
|
||||
|
||||
__device__ __forceinline__ void operator()(
|
||||
const vec_n_t<InT, VEC_SIZE>& src) const {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
scalar_op(src.val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// read-only version: iterate over the input with alignment guarantees
|
||||
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
|
||||
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
|
||||
int tid, int stride,
|
||||
VecOp&& vec_op,
|
||||
ScaOp&& scalar_op) {
|
||||
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
|
||||
"VEC_SIZE must be a positive power-of-two");
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT);
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1);
|
||||
int alignment_bytes = WIDTH - misalignment_offset;
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1);
|
||||
prefix_elems /= sizeof(InT);
|
||||
prefix_elems = min(prefix_elems, len);
|
||||
|
||||
// 1. handle the possibly unaligned prefix with scalar access.
|
||||
for (int i = tid; i < prefix_elems; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
|
||||
in += prefix_elems;
|
||||
len -= prefix_elems;
|
||||
|
||||
int num_vec = len / VEC_SIZE;
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
// 2. vectorized traversal of the main aligned region.
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
|
||||
// 3. handle remaining tail elements.
|
||||
int tail_start = num_vec * VEC_SIZE;
|
||||
for (int i = tid + tail_start; i < len; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// overload that requires only a scalar_op
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
__device__ __forceinline__ void vectorize_read_with_alignment(
|
||||
const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
|
||||
using Vec = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
|
||||
vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride, Vec{scalar_op},
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -59,6 +59,8 @@ void apply_repetition_penalties_(
|
||||
int vocab_size = logits.size(-1);
|
||||
int num_seqs = logits.size(0);
|
||||
|
||||
if (num_seqs == 0) return;
|
||||
|
||||
// Get number of SMs on the current device
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
|
||||
|
@ -79,7 +79,8 @@ struct cutlass_sparse_3x_gemm {
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
|
@ -20,13 +20,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
//
|
||||
|
||||
// The default behavior in PyTorch 2.6 is "requires_contiguous", so we need
|
||||
// The default behavior in PyTorch 2.6 was changed to "requires_contiguous",
|
||||
// so we need
|
||||
// to override this for many GEMMs with the following tag. Otherwise,
|
||||
// torch.compile will force all input tensors to be contiguous(), which
|
||||
// will break many custom ops that require column-major weight matrices.
|
||||
// TODO: remove this for PyTorch 2.8, when the default is planned to switch
|
||||
// to match exact eager-mode strides.
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
// This was a bug and PyTorch 2.7 has since fixed this.
|
||||
#if TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 6
|
||||
#define stride_tag at::Tag::needs_fixed_stride_order
|
||||
#else
|
||||
#define stride_tag
|
||||
#endif
|
||||
|
||||
ops.def("weak_ref_tensor(Tensor input) -> Tensor");
|
||||
ops.impl("weak_ref_tensor", torch::kCUDA, &weak_ref_tensor);
|
||||
@ -393,6 +397,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// cutlass blockwise scaledgroup GEMM
|
||||
ops.def(
|
||||
"cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
|
||||
"Tensor scales_a, Tensor scales_b, "
|
||||
"Tensor problem_sizes, Tensor expert_offsets) -> ()",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
@ -506,6 +518,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
" Tensor page_table, float scale) -> ()");
|
||||
ops.impl("cutlass_mla_decode", torch::kCUDA, &cutlass_mla_decode);
|
||||
|
||||
// SM100 CUTLASS MLA decode
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_decode(Tensor! out, Tensor q_nope, Tensor q_pe,"
|
||||
" Tensor kv_c_and_k_pe_cache, Tensor seq_lens,"
|
||||
" Tensor page_table, Tensor workspace, float "
|
||||
"scale,"
|
||||
" int num_kv_splits) -> ()");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// SM100 CUTLASS MLA workspace
|
||||
ops.def(
|
||||
"sm100_cutlass_mla_get_workspace_size(int max_seq_len, int num_batches,"
|
||||
" int sm_count, int num_kv_splits) "
|
||||
"-> int");
|
||||
// conditionally compiled so impl in source file
|
||||
|
||||
// Compute NVFP4 block quantized tensor.
|
||||
ops.def(
|
||||
"scaled_fp4_quant(Tensor! output, Tensor input,"
|
||||
@ -586,29 +614,16 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_update(Tensor! x,"
|
||||
"Tensor! conv_state,"
|
||||
"Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"bool silu_activation,"
|
||||
"Tensor? cache_seqlens_,"
|
||||
"Tensor? conv_state_indices,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update);
|
||||
|
||||
ops.def(
|
||||
"causal_conv1d_fwd(Tensor! x, Tensor! weight,"
|
||||
"Tensor? bias_,"
|
||||
"Tensor!? conv_states,"
|
||||
"Tensor? query_start_loc,"
|
||||
"Tensor? cache_indices,"
|
||||
"Tensor? has_initial_state,"
|
||||
"bool silu_activation,"
|
||||
"int pad_slot_id) -> ()");
|
||||
ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd);
|
||||
|
||||
#ifndef USE_ROCM
|
||||
// Compute per-token-group FP8 quantized tensor and scaling factor.
|
||||
ops.def(
|
||||
"per_token_group_fp8_quant(Tensor input, Tensor! output_q, Tensor! "
|
||||
"output_s, "
|
||||
"int group_size, float eps, float fp8_min, float fp8_max, bool "
|
||||
"scale_ue8m0) -> ()");
|
||||
ops.impl("per_token_group_fp8_quant", torch::kCUDA,
|
||||
&per_token_group_quant_fp8);
|
||||
|
||||
// reorder weight for AllSpark Ampere W8A16 Fused Gemm kernel
|
||||
ops.def(
|
||||
"rearrange_kn_weight_as_n32k16_order(Tensor b_qweight, Tensor b_scales, "
|
||||
|
@ -1,3 +1,4 @@
|
||||
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
@ -62,12 +63,16 @@ ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
|
||||
ARG PIP_KEYRING_PROVIDER=disabled
|
||||
ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
|
||||
|
||||
# Flag enables built-in KV-connector dependency libs into docker images
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
@ -202,6 +207,19 @@ ARG SCCACHE_ENDPOINT
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
|
||||
# Flag to control whether to use pre-built vLLM wheels
|
||||
ARG VLLM_USE_PRECOMPILED
|
||||
# TODO: in setup.py VLLM_USE_PRECOMPILED is sensitive to truthiness, it will take =0 as "true", this should be fixed
|
||||
ENV VLLM_USE_PRECOMPILED=""
|
||||
RUN if [ "${VLLM_USE_PRECOMPILED}" = "1" ]; then \
|
||||
export VLLM_USE_PRECOMPILED=1 && \
|
||||
echo "Using precompiled wheels"; \
|
||||
else \
|
||||
unset VLLM_USE_PRECOMPILED && \
|
||||
echo "Leaving VLLM_USE_PRECOMPILED unset to build wheels from source"; \
|
||||
fi
|
||||
|
||||
# if USE_SCCACHE is set, use sccache to speed up compilation
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
@ -247,7 +265,7 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
#################### EXTENSION Build IMAGE ####################
|
||||
|
||||
#################### DEV IMAGE ####################
|
||||
FROM base as dev
|
||||
FROM base AS dev
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
@ -276,6 +294,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
@ -369,28 +388,34 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
# -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
|
||||
# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
|
||||
|
||||
# Allow specifying a version, Git revision or local .whl file
|
||||
ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
|
||||
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
|
||||
# Install FlashInfer from source
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
|
||||
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL} ; \
|
||||
else \
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0' && \
|
||||
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive && \
|
||||
# Needed to build AOT kernels
|
||||
(cd flashinfer && \
|
||||
python3 -m flashinfer.aot && \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
) && \
|
||||
rm -rf flashinfer; \
|
||||
fi \
|
||||
fi
|
||||
ARG FLASHINFER_GIT_REF="v0.2.8rc1"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch ${FLASHINFER_GIT_REF} \
|
||||
${FLASHINFER_GIT_REPO} flashinfer
|
||||
# Exclude CUDA arches for older versions (11.x and 12.0-12.7)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
elif [[ "${CUDA_VERSION}" == 12.[0-7]* ]]; then
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
else
|
||||
# CUDA 12.8+ supports 10.0a and 12.0
|
||||
FI_TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a 10.0a 12.0"
|
||||
fi
|
||||
echo "🏗️ Building FlashInfer for arches: ${FI_TORCH_CUDA_ARCH_LIST}"
|
||||
# Needed to build AOT kernels
|
||||
pushd flashinfer
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${FI_TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation .
|
||||
popd
|
||||
rm -rf flashinfer
|
||||
BASH
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
@ -464,6 +489,7 @@ RUN mv mkdocs.yaml test_docs/
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
@ -472,13 +498,19 @@ ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
COPY requirements/kv_connectors.txt requirements/kv_connectors.txt
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r requirements/kv_connectors.txt; \
|
||||
fi; \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
BITSANDBYTES_VERSION="0.42.0"; \
|
||||
else \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.3' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
fi
|
||||
BITSANDBYTES_VERSION="0.46.1"; \
|
||||
fi; \
|
||||
uv pip install --system accelerate hf_transfer modelscope "bitsandbytes>=${BITSANDBYTES_VERSION}" 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
|
@ -8,6 +8,8 @@
|
||||
# Build arguments:
|
||||
# PYTHON_VERSION=3.12 (default)|3.11|3.10|3.9
|
||||
# VLLM_CPU_DISABLE_AVX512=false (default)|true
|
||||
# VLLM_CPU_AVX512BF16=false (default)|true
|
||||
# VLLM_CPU_AVX512VNNI=false (default)|true
|
||||
#
|
||||
|
||||
######################### BASE IMAGE #########################
|
||||
@ -25,7 +27,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt-get update -y \
|
||||
&& apt-get install -y --no-install-recommends ccache git curl wget ca-certificates \
|
||||
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 \
|
||||
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof \
|
||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
@ -60,8 +62,14 @@ FROM base AS vllm-build
|
||||
|
||||
ARG GIT_REPO_CHECK=0
|
||||
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
|
||||
ARG VLLM_CPU_DISABLE_AVX512
|
||||
ARG VLLM_CPU_DISABLE_AVX512=0
|
||||
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
|
||||
# Support for building with AVX512BF16 ISA: docker build --build-arg VLLM_CPU_AVX512BF16="true" ...
|
||||
ARG VLLM_CPU_AVX512BF16=0
|
||||
ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
|
||||
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
|
||||
ARG VLLM_CPU_AVX512VNNI=0
|
||||
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
@ -87,7 +95,7 @@ WORKDIR /workspace/vllm
|
||||
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
|
||||
cp requirements/test.in requirements/cpu-test.in && \
|
||||
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
|
||||
sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
|
||||
sed -i 's/^torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
|
||||
uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
|
||||
@ -134,6 +142,7 @@ ADD ./tests/ ./tests/
|
||||
ADD ./examples/ ./examples/
|
||||
ADD ./benchmarks/ ./benchmarks/
|
||||
ADD ./vllm/collect_env.py .
|
||||
ADD ./.buildkite/ ./.buildkite/
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
@ -1,21 +0,0 @@
|
||||
FROM vault.habana.ai/gaudi-docker/1.20.1/ubuntu22.04/habanalabs/pytorch-installer-2.6.0:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||
|
||||
ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"]
|
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="6487649"
|
||||
ARG AITER_BRANCH="916bf3c"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
@ -1,5 +1,5 @@
|
||||
ARG NIGHTLY_DATE="20250124"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_$NIGHTLY_DATE"
|
||||
ARG NIGHTLY_DATE="20250714"
|
||||
ARG BASE_IMAGE="us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.12_tpuvm_$NIGHTLY_DATE"
|
||||
|
||||
FROM $BASE_IMAGE
|
||||
WORKDIR /workspace/vllm
|
||||
|
@ -47,7 +47,7 @@ FROM vllm-base AS vllm-openai
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
pip install accelerate hf_transfer 'modelscope!=1.15.0'
|
||||
pip install accelerate hf_transfer pytest pytest_asyncio lm_eval[api] modelscope
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image \
|
||||
TRITON_XPU_PROFILE 1
|
||||
|
@ -55,6 +55,7 @@ nav:
|
||||
- contributing/model/registration.md
|
||||
- contributing/model/tests.md
|
||||
- contributing/model/multimodal.md
|
||||
- CI: contributing/ci
|
||||
- Design Documents:
|
||||
- V0: design
|
||||
- V1: design/v1
|
||||
|
@ -36,7 +36,7 @@ vLLM is flexible and easy to use with:
|
||||
|
||||
- Seamless integration with popular HuggingFace models
|
||||
- High-throughput serving with various decoding algorithms, including *parallel sampling*, *beam search*, and more
|
||||
- Tensor parallelism and pipeline parallelism support for distributed inference
|
||||
- Tensor, pipeline, data and expert parallelism support for distributed inference
|
||||
- Streaming outputs
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
||||
@ -48,4 +48,4 @@ For more information, check out the following:
|
||||
- [vLLM announcing blog post](https://vllm.ai) (intro to PagedAttention)
|
||||
- [vLLM paper](https://arxiv.org/abs/2309.06180) (SOSP 2023)
|
||||
- [How continuous batching enables 23x throughput in LLM inference while reducing p50 latency](https://www.anyscale.com/blog/continuous-batching-llm-inference) by Cade Daniel et al.
|
||||
- [vLLM Meetups][meetups]
|
||||
- [vLLM Meetups](community/meetups.md)
|
||||
|
@ -8,14 +8,12 @@ API documentation for vLLM's configuration classes.
|
||||
|
||||
- [vllm.config.ModelConfig][]
|
||||
- [vllm.config.CacheConfig][]
|
||||
- [vllm.config.TokenizerPoolConfig][]
|
||||
- [vllm.config.LoadConfig][]
|
||||
- [vllm.config.ParallelConfig][]
|
||||
- [vllm.config.SchedulerConfig][]
|
||||
- [vllm.config.DeviceConfig][]
|
||||
- [vllm.config.SpeculativeConfig][]
|
||||
- [vllm.config.LoRAConfig][]
|
||||
- [vllm.config.PromptAdapterConfig][]
|
||||
- [vllm.config.MultiModalConfig][]
|
||||
- [vllm.config.PoolerConfig][]
|
||||
- [vllm.config.DecodingConfig][]
|
||||
@ -64,7 +62,7 @@ vLLM provides experimental support for multi-modal models through the [vllm.mult
|
||||
Multi-modal inputs can be passed alongside text and token prompts to [supported models][supported-mm-models]
|
||||
via the `multi_modal_data` field in [vllm.inputs.PromptType][].
|
||||
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed [here][supports-multimodal].
|
||||
Looking to add your own multi-modal model? Please follow the instructions listed [here](../contributing/model/multimodal.md).
|
||||
|
||||
- [vllm.multimodal.MULTIMODAL_REGISTRY][]
|
||||
|
||||
|
BIN
docs/assets/deployment/dp_external_lb.png
Normal file
BIN
docs/assets/deployment/dp_external_lb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 84 KiB |
BIN
docs/assets/deployment/dp_internal_lb.png
Normal file
BIN
docs/assets/deployment/dp_internal_lb.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 68 KiB |
Binary file not shown.
Before Width: | Height: | Size: 68 KiB After Width: | Height: | Size: 57 KiB |
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user