Mirror of https://github.com/vllm-project/vllm.git
Synced 2025-10-20 23:03:52 +08:00

Compare commits: 312 commits (deepep_twe ... v0.9.2rc2)
@@ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc

## Performance benchmark quick overview

**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models.

**Benchmarking Duration**: about 1hr.
@@ -31,13 +31,27 @@ Performance benchmark will be triggered when:
- A PR being merged into vllm.
- Every commit for those PRs with `perf-benchmarks` label AND `ready` label.

Manually trigger the benchmark:

```bash
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
```

Runtime environment variables:
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
- `THROUGHPUT_JSON`: JSON file to use for the throughput tests. Default value is empty string (use default file).
- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
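For example (values here are illustrative), a CPU run that uses a custom serving test file and benchmarks a remote endpoint could look like:

```bash
# ON_CPU selects the Intel Xeon code path; REMOTE_HOST/REMOTE_PORT are illustrative.
# "my-serving-tests.json" is a hypothetical file under .buildkite/nightly-benchmarks/tests/.
ON_CPU=1 SERVING_JSON=my-serving-tests.json \
REMOTE_HOST=10.0.0.5 REMOTE_PORT=8000 \
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
```

Per the runner script, the `*_JSON` values are file names resolved under `.buildkite/nightly-benchmarks/tests/`, and when `REMOTE_HOST` is set the serving tests target that endpoint instead of launching a local vLLM server.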
Nightly benchmark will be triggered when:
- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.

## Performance benchmark details

See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.

> NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead.

### Latency test

Here is an example of one test inside `latency-tests.json`:
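For reference, the first entry of the new `tests/latency-tests-cpu.json` added in this change (shown in full later in this diff) follows the same schema:

```json
{
    "test_name": "latency_llama8B_tp1",
    "environment_variables": {
        "VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
        "VLLM_CPU_KVCACHE_SPACE": 40
    },
    "parameters": {
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "tensor_parallel_size": 1,
        "load_format": "dummy",
        "num_iters_warmup": 5,
        "num_iters": 15
    }
}
```

Everything under `parameters` is converted into command-line arguments for `benchmark_latency.py`, and the `environment_variables` block (new in this change) is set in the environment of that benchmark command.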
@@ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running.

The json version of the table (together with the json version of the benchmark) will also be attached to the markdown file.
The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking.

The `compare-json-results.py` script helps compare benchmark result JSON files produced by `convert-results-json-to-markdown.py`.
When run, the benchmark script generates results under the `benchmark/results` folder, along with `benchmark_results.md` and `benchmark_results.json`.
`compare-json-results.py` compares two `benchmark_results.json` files and reports the performance ratio for metrics such as Output Tput, Median TTFT and Median TPOT. The `perf_ratio` column is the value from the second `-f` file divided by the value from the first.

Here is an example using the script to compare result_a and result_b without detailed test names:
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`

| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
|----|----------------------------------------|----------------------------------------|----------|
| 0 | 142.633982 | 156.526018 | 1.097396 |
| 1 | 241.620334 | 294.018783 | 1.216863 |
| 2 | 218.298905 | 262.664916 | 1.203235 |
| 3 | 242.743860 | 299.816190 | 1.235113 |

Here is an example using the script to compare result_a and result_b with detailed test names:
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`

| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------|
| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |

## Nightly test details

See [nightly-descriptions.md](nightly-descriptions.md) for a detailed description of the test workload, models and docker containers used to benchmark other LLM engines.
@@ -16,7 +16,7 @@ Please download the visualization scripts in the post
- Download `nightly-benchmarks.zip`.
- In the same folder, run the following code:

```console
```bash
export HF_TOKEN=<your HF token>
apt update
apt install -y git
@@ -4,7 +4,8 @@
- Input length: 32 tokens.
- Output length: 128 tokens.
- Batch size: fixed (8).
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
- CPU Models: llama-3.1 8B.
- Evaluation metrics: end-to-end latency (mean, median, p99).

{latency_tests_markdown_table}

@@ -14,7 +15,8 @@
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
- Output length: the corresponding output length of these 200 prompts.
- Batch size: dynamically determined by vllm to achieve maximum throughput.
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
- CPU Models: llama-3.1 8B.
- Evaluation metrics: throughput.

{throughput_tests_markdown_table}

@@ -25,12 +27,18 @@
- Output length: the corresponding output length of these 200 prompts.
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
  - We also added a speculative decoding test for llama-3 70B, under QPS 2
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
  - We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2
- CPU Models: llama-3.1 8B.
- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).
- For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts.

{serving_tests_markdown_table}

## Platform Information

{platform_markdown_table}

## json version of the benchmarking tables

This section contains the data of the markdown tables above in JSON format.
@@ -0,0 +1,66 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import argparse

import pandas as pd


def compare_data_columns(
    files, name_column, data_column, drop_column, ignore_test_name=False
):
    print("\ncompare_data_column: " + data_column)
    frames = []
    compare_frames = []
    for file in files:
        data_df = pd.read_json(file)
        # Rows without the drop_column (e.g. "P99") are not serving results.
        serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
        if ignore_test_name is False:
            serving_df = serving_df.rename(columns={name_column: file + "_name"})
            frames.append(serving_df[file + "_name"])
        serving_df = serving_df.rename(columns={data_column: file})
        frames.append(serving_df[file])
        compare_frames.append(serving_df[file])
        if len(compare_frames) >= 2:
            # Compare numbers among two files
            ratio_df = compare_frames[1] / compare_frames[0]
            frames.append(ratio_df)
            compare_frames.pop(1)

    concat_df = pd.concat(frames, axis=1)
    return concat_df


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-f", "--file", action="append", type=str, help="input file name"
    )
    parser.add_argument(
        "--ignore_test_name", action="store_true", help="ignore_test_name or not"
    )
    args = parser.parse_args()
    files = args.file
    print("comparing : " + ", ".join(files))

    drop_column = "P99"
    name_column = "Test name"
    data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"]
    html_msgs_for_data_cols = [
        "Compare Output Tokens /n",
        "Median TTFT /n",
        "Median TPOT /n",
    ]
    ignore_test_name = args.ignore_test_name
    with open("perf_comparison.html", "w") as text_file:
        for i in range(len(data_cols_to_compare)):
            output_df = compare_data_columns(
                files,
                name_column,
                data_cols_to_compare[i],
                drop_column,
                ignore_test_name=ignore_test_name,
            )
            print(output_df)
            html = output_df.to_html()
            text_file.write(html_msgs_for_data_cols[i])
            text_file.write(html)
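A minimal usage sketch (result paths are illustrative): the script writes its comparison tables to `perf_comparison.html` in the working directory, one table per compared metric, with the ratio computed as the second file's value divided by the first's.

```bash
python3 compare-json-results.py \
    -f results_a/benchmark_results.json \
    -f results_b/benchmark_results.json
```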
@@ -3,9 +3,11 @@

import json
import os
from importlib import util
from pathlib import Path

import pandas as pd
import psutil
from tabulate import tabulate

results_folder = Path("results/")

@@ -29,11 +31,11 @@ throughput_results = []
throughput_results_column_mapping = {
    "test_name": "Test name",
    "gpu_type": "GPU",
    # "num_requests": "# of req.",
    # "total_num_tokens": "Total # of tokens",
    # "elapsed_time": "Elapsed time (s)",
    "num_requests": "# of req.",
    "total_num_tokens": "Total # of tokens",
    "elapsed_time": "Elapsed time (s)",
    "requests_per_second": "Tput (req/s)",
    # "tokens_per_second": "Tput (tok/s)",
    "tokens_per_second": "Tput (tok/s)",
}

# serving results and the keys that will be printed into markdown

@@ -41,16 +43,18 @@ serving_results = []
serving_column_mapping = {
    "test_name": "Test name",
    "gpu_type": "GPU",
    # "completed": "# of req.",
    "completed": "# of req.",
    "request_throughput": "Tput (req/s)",
    # "input_throughput": "Input Tput (tok/s)",
    # "output_throughput": "Output Tput (tok/s)",
    "total_token_throughput": "Total Token Tput (tok/s)",
    "output_throughput": "Output Tput (tok/s)",
    "total_input_tokens": "Total input tokens",
    "total_output_tokens": "Total output tokens",
    "mean_ttft_ms": "Mean TTFT (ms)",
    "median_ttft_ms": "Median TTFT (ms)",
    "p99_ttft_ms": "P99 TTFT (ms)",
    # "mean_tpot_ms": "Mean TPOT (ms)",
    # "median_tpot_ms": "Median",
    # "p99_tpot_ms": "P99",
    "mean_tpot_ms": "Mean TPOT (ms)",
    "median_tpot_ms": "Median",
    "p99_tpot_ms": "P99",
    "mean_itl_ms": "Mean ITL (ms)",
    "median_itl_ms": "Median ITL (ms)",
    "p99_itl_ms": "P99 ITL (ms)",

@@ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving):
    )


def get_size_with_unit(bytes, suffix="B"):
    """
    Scale bytes to its proper format
    e.g:
        1253656 => '1.20MB'
        1253656678 => '1.17GB'
    """
    factor = 1024
    for unit in ["", "K", "M", "G", "T", "P"]:
        if bytes < factor:
            return f"{bytes:.2f}{unit}{suffix}"
        bytes /= factor


if __name__ == "__main__":
    # collect results
    for test_file in results_folder.glob("*.json"):

@@ -155,6 +173,27 @@ if __name__ == "__main__":
    serving_results = pd.DataFrame.from_dict(serving_results)
    throughput_results = pd.DataFrame.from_dict(throughput_results)

    svmem = psutil.virtual_memory()
    platform_data = {
        "Physical cores": [psutil.cpu_count(logical=False)],
        "Total cores": [psutil.cpu_count(logical=True)],
        "Total Memory": [get_size_with_unit(svmem.total)],
    }

    if util.find_spec("numa") is not None:
        from numa import info

        platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()]

    if util.find_spec("cpuinfo") is not None:
        from cpuinfo import get_cpu_info

        platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]]

    platform_results = pd.DataFrame.from_dict(
        platform_data, orient="index", columns=["Platform Info"]
    )

    raw_results_json = results_to_json(
        latency_results, throughput_results, serving_results
    )

@@ -200,6 +239,9 @@ if __name__ == "__main__":
    throughput_md_table = tabulate(
        throughput_results, headers="keys", tablefmt="pipe", showindex=False
    )
    platform_md_table = tabulate(
        platform_results, headers="keys", tablefmt="pipe", showindex=True
    )

    # document the result
    with open(results_folder / "benchmark_results.md", "w") as f:

@@ -211,6 +253,7 @@ if __name__ == "__main__":
            latency_tests_markdown_table=latency_md_table,
            throughput_tests_markdown_table=throughput_md_table,
            serving_tests_markdown_table=serving_md_table,
            platform_markdown_table=platform_md_table,
            benchmarking_results_in_json_string=processed_results_json,
        )
        f.write(results)
@ -31,6 +31,20 @@ check_gpus() {
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_cpus() {
|
||||
# check the number of CPUs and NUMA Node and GPU type.
|
||||
declare -g numa_count=$(python3 -c "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)")
|
||||
if [[ $numa_count -gt 0 ]]; then
|
||||
echo "NUMA found."
|
||||
echo $numa_count
|
||||
else
|
||||
echo "Need at least 1 NUMA to run benchmarking."
|
||||
exit 1
|
||||
fi
|
||||
declare -g gpu_type="cpu"
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_hf_token() {
|
||||
# check if HF_TOKEN is available and valid
|
||||
if [[ -z "$HF_TOKEN" ]]; then
|
||||
@ -69,6 +83,22 @@ json2args() {
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
json2envs() {
|
||||
# transforms the JSON string to environment variables.
|
||||
# example:
|
||||
# input: { "VLLM_CPU_KVCACHE_SPACE": 5 }
|
||||
# output: VLLM_CPU_KVCACHE_SPACE=5
|
||||
local json_string=$1
|
||||
local args=$(
|
||||
echo "$json_string" | jq -r '
|
||||
to_entries |
|
||||
map((.key ) + "=" + (.value | tostring)) |
|
||||
join(" ")
|
||||
'
|
||||
)
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
wait_for_server() {
|
||||
# wait for vllm server to start
|
||||
# return 1 if vllm server crashes
|
||||
@ -158,15 +188,24 @@ run_latency_tests() {
|
||||
# get arguments
|
||||
latency_params=$(echo "$params" | jq -r '.parameters')
|
||||
latency_args=$(json2args "$latency_params")
|
||||
latency_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
latency_envs=$(json2envs "$latency_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
latency_command="python3 benchmark_latency.py \
|
||||
latency_command=" $latency_envs python3 benchmark_latency.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$latency_args"
|
||||
|
||||
@ -216,15 +255,24 @@ run_throughput_tests() {
|
||||
# get arguments
|
||||
throughput_params=$(echo "$params" | jq -r '.parameters')
|
||||
throughput_args=$(json2args "$throughput_params")
|
||||
throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
throughput_envs=$(json2envs "$throughput_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
throughput_command="python3 benchmark_throughput.py \
|
||||
throughput_command=" $throughput_envs python3 benchmark_throughput.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$throughput_args"
|
||||
|
||||
@ -272,18 +320,27 @@ run_serving_tests() {
|
||||
|
||||
# get client and server arguments
|
||||
server_params=$(echo "$params" | jq -r '.server_parameters')
|
||||
server_envs=$(echo "$params" | jq -r '.server_environment_variables')
|
||||
client_params=$(echo "$params" | jq -r '.client_parameters')
|
||||
server_args=$(json2args "$server_params")
|
||||
server_envs=$(json2envs "$server_envs")
|
||||
client_args=$(json2args "$client_params")
|
||||
qps_list=$(echo "$params" | jq -r '.qps_list')
|
||||
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
|
||||
echo "Running over qps list $qps_list"
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
# check if there is enough resources to run the test
|
||||
tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
# check if server model and client model is aligned
|
||||
@ -294,23 +351,33 @@ run_serving_tests() {
|
||||
continue
|
||||
fi
|
||||
|
||||
server_command="python3 \
|
||||
server_command="$server_envs python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
$server_args"
|
||||
|
||||
# run the server
|
||||
echo "Running test case $test_name"
|
||||
echo "Server command: $server_command"
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vllm server is up and running."
|
||||
# support remote vllm server
|
||||
client_remote_args=""
|
||||
if [[ -z "${REMOTE_HOST}" ]]; then
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vLLM server is up and running."
|
||||
else
|
||||
echo ""
|
||||
echo "vLLM failed to start within the timeout period."
|
||||
fi
|
||||
else
|
||||
echo ""
|
||||
echo "vllm failed to start within the timeout period."
|
||||
server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT"
|
||||
if [[ ${REMOTE_PORT} ]]; then
|
||||
client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT "
|
||||
else
|
||||
client_remote_args=" --host=$REMOTE_HOST "
|
||||
fi
|
||||
fi
|
||||
|
||||
# iterate over different QPS
|
||||
@ -332,7 +399,7 @@ run_serving_tests() {
|
||||
--result-filename ${new_test_name}.json \
|
||||
--request-rate $qps \
|
||||
--metadata "tensor_parallel_size=$tp" \
|
||||
$client_args"
|
||||
$client_args $client_remote_args "
|
||||
|
||||
echo "Running test case $test_name with qps $qps"
|
||||
echo "Client command: $client_command"
|
||||
@ -360,7 +427,14 @@ run_serving_tests() {
|
||||
}
|
||||
|
||||
main() {
|
||||
check_gpus
|
||||
local ARCH
|
||||
ARCH=''
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
check_cpus
|
||||
ARCH='-cpu'
|
||||
else
|
||||
check_gpus
|
||||
fi
|
||||
check_hf_token
|
||||
|
||||
# Set to v1 to run v1 benchmark
|
||||
@ -386,9 +460,9 @@ main() {
|
||||
QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/
|
||||
|
||||
# benchmarking
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}"
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}"
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}"
|
||||
|
||||
# postprocess benchmarking results
|
||||
pip install tabulate pandas
|
||||
|
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
@ -0,0 +1,30 @@
|
||||
[
|
||||
{
|
||||
"test_name": "latency_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "latency_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
}
|
||||
]
|
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
@ -0,0 +1,158 @@
|
||||
[
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp6_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 6,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
}
|
||||
]
|
@ -0,0 +1,32 @@
|
||||
[
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
}
|
||||
]
|
@ -52,7 +52,7 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg USE_SCCACHE=1 --build-arg GIT_REPO_CHECK=1 --build-arg CUDA_VERSION=12.8.1 --build-arg INSTALL_KV_CONNECTORS=true --tag public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT --target vllm-openai --progress plain -f docker/Dockerfile ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-release-repo:$BUILDKITE_COMMIT"
|
||||
|
||||
- label: "Annotate release workflow"
|
||||
@ -101,7 +101,8 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
@ -117,6 +118,7 @@ steps:
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
@ -107,10 +107,9 @@ fi
|
||||
|
||||
if [[ $commands == *" kernels/attention"* ]]; then
|
||||
commands="${commands} \
|
||||
--ignore=kernels/attention/stest_attention_selector.py \
|
||||
--ignore=kernels/attention/test_attention_selector.py \
|
||||
--ignore=kernels/attention/test_blocksparse_attention.py \
|
||||
--ignore=kernels/attention/test_encoder_decoder_attn.py \
|
||||
--ignore=kernels/attention/test_attention_selector.py \
|
||||
--ignore=kernels/attention/test_flash_attn.py \
|
||||
--ignore=kernels/attention/test_flashinfer.py \
|
||||
--ignore=kernels/attention/test_prefix_prefill.py \
|
||||
|
@ -51,6 +51,7 @@ function cpu_tests() {
|
||||
pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
pytest -v -s tests/models/language/generation -m cpu_model
|
||||
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model
|
||||
pytest -v -s tests/models/language/pooling -m cpu_model
|
||||
pytest -v -s tests/models/multimodal/generation \
|
||||
--ignore=tests/models/multimodal/generation/test_mllama.py \
|
||||
@ -98,4 +99,4 @@ function cpu_tests() {
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
@ -2,10 +2,34 @@
|
||||
|
||||
# This script build the CPU docker image and run the offline inference inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
set -exuo pipefail
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - .
|
||||
FROM 1.22-413-pt2.7.1:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN git clone https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||
|
||||
EOF
|
||||
|
||||
# Setup cleanup
|
||||
# certain versions of HPU software stack have a bug that can
|
||||
@ -14,13 +38,21 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
# functions, while other platforms only need one remove_docker_container
|
||||
# function.
|
||||
EXITCODE=1
|
||||
remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; }
|
||||
remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; }
|
||||
trap remove_docker_containers_and_exit EXIT
|
||||
remove_docker_containers() { docker rm -f hpu-plugin-v1-test || true; }
|
||||
trap 'remove_docker_containers; exit $EXITCODE;' EXIT
|
||||
remove_docker_containers
|
||||
|
||||
# Run the image and launch offline inference
|
||||
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2
|
||||
echo "Running HPU plugin v1 test"
|
||||
docker run --rm --runtime=habana --name=hpu-plugin-v1-test --network=host \
|
||||
-e HABANA_VISIBLE_DEVICES=all \
|
||||
hpu-plugin-v1-test-env \
|
||||
/bin/bash "/workspace/vllm-gaudi/tests/upstream_tests/ci_tests.sh"
|
||||
|
||||
EXITCODE=$?
|
||||
if [ $EXITCODE -eq 0 ]; then
|
||||
echo "Test with basic model passed"
|
||||
else
|
||||
echo "Test with basic model FAILED with exit code: $EXITCODE" >&2
|
||||
fi
|
||||
|
||||
# The trap will handle the container removal and final exit.
|
@ -54,10 +54,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \
|
||||
--name "${container_name}" \
|
||||
${image_name} \
|
||||
/bin/bash -c "
|
||||
set -e; # Exit on first error
|
||||
python3 /workspace/vllm/examples/offline_inference/neuron.py;
|
||||
python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
|
||||
for f in /workspace/vllm/tests/neuron/2_core/*.py; do
|
||||
echo 'Running test file: '$f;
|
||||
echo \"Running test file: \$f\";
|
||||
python3 -m pytest \$f -v --capture=tee-sys;
|
||||
done
|
||||
"
|
@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
|
||||
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
|
||||
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
|
||||
|
||||
# After all tests have been attempted, exit with the overall status.
|
||||
if [ "$overall_script_exit_code" -ne 0 ]; then
|
||||
|
@ -28,4 +28,5 @@ docker run \
|
||||
sh -c '
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
'
|
||||
|
@ -4,8 +4,8 @@ CONTAINER_NAME=vllm-tpu
|
||||
|
||||
# vllm config
|
||||
MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
MAX_NUM_SEQS=512
|
||||
MAX_NUM_BATCHED_TOKENS=512
|
||||
MAX_NUM_SEQS=256
|
||||
MAX_NUM_BATCHED_TOKENS=1024
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
MAX_MODEL_LEN=2048
|
||||
DOWNLOAD_DIR=/mnt/disks/persist
|
||||
|
@ -68,7 +68,7 @@ docker run \
|
||||
|
||||
echo "run script..."
|
||||
echo
|
||||
docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh"
|
||||
docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/tpu/run_bm.sh"
|
||||
|
||||
echo "copy result back..."
|
||||
VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt
|
||||
|
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
@ -0,0 +1,14 @@
|
||||
# Environment config
|
||||
TEST_NAME=llama8bw8a8
|
||||
CONTAINER_NAME=vllm-tpu
|
||||
|
||||
# vllm config
|
||||
MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8
|
||||
MAX_NUM_SEQS=128
|
||||
MAX_NUM_BATCHED_TOKENS=1024
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
MAX_MODEL_LEN=2048
|
||||
DOWNLOAD_DIR=/mnt/disks/persist
|
||||
EXPECTED_THROUGHPUT=10.0
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=128
|
@ -41,6 +41,16 @@ steps:
|
||||
# TODO: add `--strict` once warnings in docstrings are fixed
|
||||
- mkdocs build
|
||||
|
||||
- label: Pytorch Nightly Dependency Override Check # 2min
|
||||
# if this test fails, it means the nightly torch version is not compatible with some
|
||||
# of the dependencies. Please check the error message and add the package to whitelist
|
||||
# in /vllm/tools/generate_nightly_torch_test.py
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- requirements/nightly_torch_test.txt
|
||||
commands:
|
||||
- bash standalone_tests/pytorch_nightly_dependency.sh
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
@ -89,7 +99,7 @@ steps:
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Chunked Prefill Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_chunked_prefill
|
||||
@ -145,6 +155,7 @@ steps:
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
@ -153,8 +164,9 @@ steps:
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
@ -168,6 +180,23 @@ steps:
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- popd
|
||||
|
||||
- label: EPLB Algorithm Test
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/eplb
|
||||
- tests/distributed/test_eplb_algo.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_eplb_algo.py
|
||||
|
||||
- label: EPLB Execution Test # 5min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/eplb
|
||||
- tests/distributed/test_eplb_execute.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_eplb_execute.py
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
num_gpus: 2
|
||||
@ -188,7 +217,7 @@ steps:
|
||||
##### 1 GPU test #####
|
||||
|
||||
- label: Regression Test # 5min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_regression
|
||||
@ -198,7 +227,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: Engine Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
@ -271,6 +300,15 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
|
||||
- label: Platform Tests (CUDA)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/cuda
|
||||
commands:
|
||||
- pytest -v -s cuda/test_cuda_context.py
|
||||
|
||||
- label: Samplers Test # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
@ -302,7 +340,7 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -384,7 +422,7 @@ steps:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
@ -476,7 +514,7 @@ steps:
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 24min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -500,6 +538,17 @@ steps:
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/language -m core_model
|
||||
|
||||
- label: Language Models Test (Hybrid) # 35 min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language/generation
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m hybrid_model
|
||||
|
||||
- label: Language Models Test (Extended Generation) # 1hr20min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
@ -509,7 +558,7 @@ steps:
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m 'not core_model'
|
||||
- pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
|
||||
|
||||
- label: Language Models Test (Extended Pooling) # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -554,7 +603,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -606,13 +655,18 @@ steps:
|
||||
- vllm/executor/
|
||||
- vllm/model_executor/models/
|
||||
- tests/distributed/
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
commands:
|
||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
|
||||
- label: Distributed Tests (2 GPUs) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -630,10 +684,12 @@ steps:
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
@ -736,7 +792,7 @@ steps:
|
||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
||||
|
||||
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
gpu: a100
|
||||
|
6
.github/CODEOWNERS
vendored
6
.github/CODEOWNERS
vendored
@ -16,7 +16,11 @@
|
||||
/vllm/lora @jeejeelee
|
||||
/vllm/reasoning @aarnphm
|
||||
/vllm/entrypoints @aarnphm
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
CMakeLists.txt @tlrmchlsmth @LucasWilkinson
|
||||
|
||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||
# so spam a lot of people
|
||||
/vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
|
52
.github/mergify.yml
vendored
52
.github/mergify.yml
vendored
@ -27,6 +27,22 @@ pull_request_rules:
|
||||
add:
|
||||
- ci/build
|
||||
|
||||
- name: label-deepseek
|
||||
description: Automatically apply deepseek label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^examples/.*deepseek.*\.py
|
||||
- files~=^tests/.*deepseek.*\.py
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/.*deepseek.*\.py
|
||||
- files~=^vllm/model_executor/models/.*deepseek.*\.py
|
||||
- files~=^vllm/reasoning/.*deepseek.*\.py
|
||||
- files~=^vllm/transformers_utils/.*deepseek.*\.py
|
||||
- title~=(?i)DeepSeek
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- deepseek
|
||||
|
||||
- name: label-frontend
|
||||
description: Automatically apply frontend label
|
||||
conditions:
|
||||
@ -45,6 +61,7 @@ pull_request_rules:
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/llama.*\.py
|
||||
- files~=^vllm/model_executor/models/.*llama.*\.py
|
||||
- files~=^vllm/transformers_utils/configs/.*llama.*\.py
|
||||
- title~=(?i)llama
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
@ -57,14 +74,38 @@ pull_request_rules:
|
||||
- files~=^vllm/multimodal/
|
||||
- files~=^tests/multimodal/
|
||||
- files~=^tests/models/multimodal/
|
||||
- files~=^tests/models/*/audio_language/
|
||||
- files~=^tests/models/*/vision_language/
|
||||
- files=tests/models/test_vision.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- multi-modality
|
||||
|
||||
- name: label-new-model
|
||||
description: Automatically apply new-model label
|
||||
conditions:
|
||||
- and:
|
||||
- files~=^vllm/model_executor/models/
|
||||
- files=vllm/model_executor/models/registry.py
|
||||
- files=tests/models/registry.py
|
||||
- files=docs/models/supported_models.md
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- new-model
|
||||
|
||||
- name: label-performance
|
||||
description: Automatically apply performance label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^benchmarks/
|
||||
- files~=^vllm/benchmarks/
|
||||
- files~=^tests/benchmarks/
|
||||
- files~=^\.buildkite/nightly-benchmarks/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- performance
|
||||
|
||||
- name: label-qwen
|
||||
description: Automatically apply qwen label
|
||||
conditions:
|
||||
@ -74,7 +115,6 @@ pull_request_rules:
|
||||
- files~=^vllm/model_executor/models/.*qwen.*\.py
|
||||
- files~=^vllm/reasoning/.*qwen.*\.py
|
||||
- title~=(?i)Qwen
|
||||
- body~=(?i)Qwen
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
@ -127,8 +167,14 @@ pull_request_rules:
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files~=^vllm/v1/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
- files~=^tests/v1/spec_decode/
|
||||
- files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py
|
||||
- files~=^vllm/model_executor/models/.*eagle.*\.py
|
||||
- files=vllm/model_executor/models/mlp_speculator.py
|
||||
- files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
2 .github/workflows/lint-and-deploy.yaml vendored
@ -68,7 +68,7 @@ jobs:
|
||||
export AWS_ACCESS_KEY_ID=minioadmin
|
||||
export AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
sleep 30 && kubectl -n ns-vllm logs -f "$(kubectl -n ns-vllm get pods | awk '/deployment/ {print $1;exit}')" &
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
helm install --wait --wait-for-jobs --timeout 5m0s --debug --create-namespace --namespace=ns-vllm test-vllm examples/online_serving/chart-helm -f examples/online_serving/chart-helm/values.yaml --set secrets.s3endpoint=http://minio:9000 --set secrets.s3bucketname=testbucket --set secrets.s3accesskeyid=$AWS_ACCESS_KEY_ID --set secrets.s3accesskey=$AWS_SECRET_ACCESS_KEY --set resources.requests.cpu=1 --set resources.requests.memory=4Gi --set resources.limits.cpu=2 --set resources.limits.memory=5Gi --set image.env[0].name=VLLM_CPU_KVCACHE_SPACE --set image.env[1].name=VLLM_LOGGING_LEVEL --set image.env[2].name=VLLM_CPU_CI_ENV --set-string image.env[0].value="1" --set-string image.env[1].value="DEBUG" --set-string image.env[2].value="1" --set-string extraInit.s3modelpath="opt-125m/" --set-string 'resources.limits.nvidia\.com/gpu=0' --set-string 'resources.requests.nvidia\.com/gpu=0' --set-string image.repository="vllm-cpu-env"
|
||||
|
||||
- name: curl test
|
||||
run: |
|
||||
|
@ -53,6 +53,11 @@ repos:
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: format-torch-nightly-test
|
||||
name: reformat nightly_torch_test.txt to be in sync with test.in
|
||||
language: python
|
||||
entry: python tools/generate_nightly_torch_test.py
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- id: mypy-local
|
||||
name: Run mypy for local Python installation
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
@ -115,6 +120,11 @@ repos:
|
||||
entry: python tools/check_spdx_header.py
|
||||
language: python
|
||||
types: [python]
|
||||
- id: check-root-lazy-imports
|
||||
name: Check root lazy imports
|
||||
entry: python tools/check_init_lazy_imports.py
|
||||
language: python
|
||||
types: [python]
|
||||
- id: check-filenames
|
||||
name: Check for spaces in all filenames
|
||||
entry: bash
|
||||
@ -150,6 +160,13 @@ repos:
|
||||
types: [python]
|
||||
pass_filenames: false
|
||||
additional_dependencies: [pathspec, regex]
|
||||
- id: validate-config
|
||||
name: Validate configuration has default values and that each field has a docstring
|
||||
entry: python tools/validate_config.py
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
files: vllm/config.py|tests/test_config.py
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
|
@ -259,7 +259,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library")
|
||||
|
||||
# Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building.
|
||||
set(CUTLASS_REVISION "v3.9.2" CACHE STRING "CUTLASS revision to use")
|
||||
set(CUTLASS_REVISION "v4.0.0" CACHE STRING "CUTLASS revision to use")
|
||||
|
||||
# Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided
|
||||
if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR})
|
||||
@ -420,6 +420,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
|
||||
@ -513,6 +543,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
@ -547,8 +578,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
@ -562,6 +592,46 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/blockwise_scaled_group_mm_sm100.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building blockwise_scaled_group_mm_sm100 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 kernels as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building blockwise_scaled_group_mm_sm100 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
@ -638,6 +708,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
if (VLLM_GPU_LANG STREQUAL "HIP")
|
||||
# Add QuickReduce kernels
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/custom_quickreduce.cu"
|
||||
)
|
||||
# if ROCM endif
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C extension.")
|
||||
define_gpu_extension_target(
|
||||
_C
|
||||
|
@ -154,11 +154,13 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
|
||||
|
||||
## Contact Us
|
||||
|
||||
<!-- --8<-- [start:contact-us] -->
|
||||
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
|
||||
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
|
||||
- coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
|
||||
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
|
||||
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
|
||||
- For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu)
|
||||
<!-- --8<-- [end:contact-us] -->
|
||||
|
||||
## Media Kit
|
||||
|
||||
|
@ -4,7 +4,7 @@ This README guides you through running benchmark tests with the extensive
|
||||
datasets supported on vLLM. It’s a living document, updated as new features and datasets
|
||||
become available.
|
||||
|
||||
## Dataset Overview
|
||||
**Dataset Overview**
|
||||
|
||||
<table style="width:100%; border-collapse: collapse;">
|
||||
<thead>
|
||||
@ -82,7 +82,10 @@ become available.
|
||||
**Note**: HuggingFace dataset's `dataset-name` should be set to `hf`
|
||||
|
||||
---
|
||||
## Example - Online Benchmark
|
||||
<details>
|
||||
<summary><b>🚀 Example - Online Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
First start serving your model
|
||||
|
||||
@ -130,7 +133,8 @@ P99 ITL (ms): 8.39
|
||||
==================================================
|
||||
```
|
||||
|
||||
### Custom Dataset
|
||||
**Custom Dataset**
|
||||
|
||||
If the dataset you want to benchmark is not yet supported in vLLM, you can still benchmark it using `CustomDataset`. Your data needs to be in `.jsonl` format, with a "prompt" field per entry, e.g., data.jsonl
|
||||
|
||||
```
|
||||
@ -162,7 +166,7 @@ python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detaile
|
||||
|
||||
You can skip applying the chat template if your data already has it by using `--custom-skip-chat-template`.
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
**VisionArena Benchmark for Vision Language Models**
|
||||
|
||||
```bash
|
||||
# need a model with vision capability here
|
||||
@ -180,7 +184,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
**InstructCoder Benchmark with Speculative Decoding**
|
||||
|
||||
``` bash
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
@ -197,7 +201,7 @@ python3 benchmarks/benchmark_serving.py \
|
||||
--num-prompts 2048
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
**Other HuggingFaceDataset Examples**
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
@ -251,7 +255,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--num-prompts 80
|
||||
```
|
||||
|
||||
### Running With Sampling Parameters
|
||||
**Running With Sampling Parameters**
|
||||
|
||||
When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
parameters can be specified. Example client command:
|
||||
@ -269,8 +273,27 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
---
|
||||
## Example - Offline Throughput Benchmark
|
||||
**Running With Ramp-Up Request Rate**
|
||||
|
||||
The benchmark tool also supports ramping up the request rate over the
|
||||
duration of the benchmark run. This can be useful for stress testing the
|
||||
server or finding the maximum throughput that it can handle, given some latency budget.
|
||||
|
||||
Two ramp-up strategies are supported:
|
||||
- `linear`: Increases the request rate linearly from a start value to an end value.
|
||||
- `exponential`: Increases the request rate exponentially.
|
||||
|
||||
The following arguments can be used to control the ramp-up:
|
||||
- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`).
|
||||
- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
|
||||
- `--ramp-up-end-rps`: The request rate at the end of the benchmark.
|
||||
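For intuition, the sketch below (not part of the benchmark script) shows how the effective request rate evolves under the two strategies; it mirrors the ramp-up calculation added to `benchmarks/benchmark_serving.py` later in this diff, and the start/end/total values are illustrative assumptions.

```python
# Minimal sketch of the ramp-up schedule; values are illustrative, not defaults.
def ramp_up_rate(strategy: str, start_rps: float, end_rps: float,
                 request_index: int, total_requests: int) -> float:
    progress = request_index / max(total_requests - 1, 1)
    if strategy == "linear":
        return start_rps + (end_rps - start_rps) * progress
    if strategy == "exponential":
        return start_rps * (end_rps / start_rps) ** progress
    raise ValueError(f"Unknown ramp-up strategy: {strategy}")

for i in (0, 250, 500, 750, 999):
    print(i,
          round(ramp_up_rate("linear", 1, 20, i, 1000), 2),
          round(ramp_up_rate("exponential", 1, 20, i, 1000), 2))
```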
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>📈 Example - Offline Throughput Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
@ -288,7 +311,7 @@ Total num prompt tokens: 5014
|
||||
Total num output tokens: 1500
|
||||
```
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
**VisionArena Benchmark for Vision Language Models**
|
||||
|
||||
``` bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
@ -308,7 +331,7 @@ Total num prompt tokens: 14527
|
||||
Total num output tokens: 1280
|
||||
```
|
||||
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
**InstructCoder Benchmark with Speculative Decoding**
|
||||
|
||||
``` bash
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
@ -332,7 +355,7 @@ Total num prompt tokens: 261136
|
||||
Total num output tokens: 204800
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
**Other HuggingFaceDataset Examples**
|
||||
|
||||
**`lmms-lab/LLaVA-OneVision-Data`**
|
||||
|
||||
@ -371,7 +394,7 @@ python3 benchmarks/benchmark_throughput.py \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
### Benchmark with LoRA Adapters
|
||||
**Benchmark with LoRA Adapters**
|
||||
|
||||
``` bash
|
||||
# download dataset
|
||||
@ -387,3 +410,196 @@ python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--enable-lora \
|
||||
--lora-path yard1/llama-2-7b-sql-lora-test
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>🛠️ Example - Structured Output Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of structured output generation (JSON, grammar, regex).
|
||||
|
||||
**Server Setup**
|
||||
|
||||
```bash
|
||||
vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests
|
||||
```
|
||||
|
||||
**JSON Schema Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset json \
|
||||
--structured-output-ratio 1.0 \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**Grammar-based Generation Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset grammar \
|
||||
--structure-type grammar \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**Regex-based Generation Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset regex \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**Choice-based Generation Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset choice \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**XGrammar Benchmark Dataset**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset xgrammar_bench \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>📚 Example - Long Document QA Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of long document question-answering with prefix caching.
|
||||
|
||||
**Basic Long Document QA Test**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 16 \
|
||||
--document-length 2000 \
|
||||
--output-len 50 \
|
||||
--repeat-count 5
|
||||
```
|
||||
|
||||
**Different Repeat Modes**
|
||||
|
||||
```bash
|
||||
# Random mode (default) - shuffle prompts randomly
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode random
|
||||
|
||||
# Tile mode - repeat entire prompt list in sequence
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode tile
|
||||
|
||||
# Interleave mode - repeat each prompt consecutively
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode interleave
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>🗂️ Example - Prefix Caching Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the efficiency of automatic prefix caching.
|
||||
|
||||
**Fixed Prompt with Prefix Caching**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 1 \
|
||||
--repeat-count 100 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
**ShareGPT Dataset with Prefix Caching**
|
||||
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 20 \
|
||||
--repeat-count 5 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>⚡ Example - Request Prioritization Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of request prioritization in vLLM.
|
||||
|
||||
**Basic Prioritization Test**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority
|
||||
```
|
||||
|
||||
**Multiple Sequences per Prompt**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority \
|
||||
--n 2
|
||||
```
|
||||
|
||||
</details>
|
||||
|
@ -10,6 +10,7 @@
|
||||
# 3. Set variables (ALL REQUIRED)
|
||||
# BASE: your directory for vllm repo
|
||||
# MODEL: the model served by vllm
|
||||
# SYSTEM: the hardware, either TPU or GPU; for other systems, "get best profile" might not be supported.
|
||||
# TP: ways of tensor parallelism
|
||||
# DOWNLOAD_DIR: directory to download and load model weights.
|
||||
# INPUT_LEN: request input len
|
||||
@ -34,6 +35,7 @@
|
||||
TAG=$(date +"%Y_%m_%d_%H_%M")
|
||||
BASE=""
|
||||
MODEL="meta-llama/Llama-3.1-8B-Instruct"
|
||||
SYSTEM="TPU"
|
||||
TP=1
|
||||
DOWNLOAD_DIR=""
|
||||
INPUT_LEN=4000
|
||||
@ -45,12 +47,15 @@ NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096"
|
||||
|
||||
LOG_FOLDER="$BASE/auto-benchmark/$TAG"
|
||||
RESULT="$LOG_FOLDER/result.txt"
|
||||
PROFILE_PATH="$LOG_FOLDER/profile"
|
||||
|
||||
echo "result file: $RESULT"
|
||||
echo "model: $MODEL"
|
||||
|
||||
rm -rf $LOG_FOLDER
|
||||
rm -rf $PROFILE_PATH
|
||||
mkdir -p $LOG_FOLDER
|
||||
mkdir -p $PROFILE_PATH
|
||||
|
||||
cd "$BASE/vllm"
|
||||
|
||||
@ -70,10 +75,11 @@ start_server() {
|
||||
local max_num_seqs=$2
|
||||
local max_num_batched_tokens=$3
|
||||
local vllm_log=$4
|
||||
local profile_dir=$5
|
||||
|
||||
pkill -f vllm
|
||||
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \
|
||||
--disable-log-requests \
|
||||
--port 8004 \
|
||||
--gpu-memory-utilization $gpu_memory_utilization \
|
||||
@ -105,19 +111,37 @@ start_server() {
|
||||
fi
|
||||
}
|
||||
|
||||
update_best_profile() {
|
||||
local profile_dir=$1
|
||||
local profile_index=$2
|
||||
sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort))
|
||||
selected_profile_file=
|
||||
if [[ "$SYSTEM" == "TPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb"
|
||||
fi
|
||||
if [[ "$SYSTEM" == "GPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}"
|
||||
fi
|
||||
rm -f $PROFILE_PATH/*
|
||||
cp $selected_profile_file $PROFILE_PATH
|
||||
}
|
||||
|
||||
run_benchmark() {
|
||||
local max_num_seqs=$1
|
||||
local max_num_batched_tokens=$2
|
||||
local gpu_memory_utilization=$3
|
||||
echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||
local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt"
|
||||
local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}"
|
||||
echo "vllm_log: $vllm_log"
|
||||
echo
|
||||
rm -f $vllm_log
|
||||
mkdir -p $profile_dir
|
||||
pkill -f vllm
|
||||
local profile_index=0
|
||||
|
||||
echo "starting server..."
|
||||
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log
|
||||
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir
|
||||
result=$?
|
||||
if [[ "$result" -eq 1 ]]; then
|
||||
echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||
@ -144,7 +168,8 @@ run_benchmark() {
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 1000 \
|
||||
--random-prefix-len $prefix_len \
|
||||
--port 8004 &> "$bm_log"
|
||||
--port 8004 \
|
||||
--profile &> "$bm_log"
|
||||
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
@ -158,6 +183,7 @@ run_benchmark() {
|
||||
# start from request-rate as int(throughput) + 1
|
||||
request_rate=$((${throughput%.*} + 1))
|
||||
while ((request_rate > 0)); do
|
||||
profile_index=$((profile_index+1))
|
||||
# clear prefix cache
|
||||
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
|
||||
sleep 5
|
||||
@ -195,6 +221,12 @@ run_benchmark() {
|
||||
best_max_num_seqs=$max_num_seqs
|
||||
best_num_batched_tokens=$max_num_batched_tokens
|
||||
best_goodput=$goodput
|
||||
if [[ "$SYSTEM" == "TPU" ]]; then
|
||||
update_best_profile "$profile_dir/plugins/profile" $profile_index
|
||||
fi
|
||||
if [[ "$SYSTEM" == "GPU" ]]; then
|
||||
update_best_profile "$profile_dir" $profile_index
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}"
|
||||
@ -239,6 +271,6 @@ for num_seqs in "${num_seqs_list[@]}"; do
|
||||
done
|
||||
done
|
||||
echo "finish permutations"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"
|
||||
|
||||
|
@ -404,8 +404,14 @@ async def async_request_openai_chat_completions(
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||
# NOTE: SSE comments (often used as pings) start with a colon.
|
||||
# These are not JSON data payloads and should be skipped.
|
||||
if chunk_bytes.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.removeprefix("data: ")
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
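To make the colon check concrete: Server-Sent Events streams can interleave comment lines (commonly used as keep-alive pings) with `data:` lines, and only the latter carry JSON. The payloads below are made up for illustration, not captured from a real server.

```python
import json

raw_chunks = [b": ping", b"", b'data: {"choices": []}', b"data: [DONE]"]
for chunk_bytes in raw_chunks:
    chunk_bytes = chunk_bytes.strip()
    if not chunk_bytes:
        continue                          # skip blank keep-alive lines
    text = chunk_bytes.decode("utf-8")
    if text.startswith(":"):
        continue                          # SSE comment/ping, not a JSON payload
    chunk = text.removeprefix("data: ")
    if chunk != "[DONE]":
        print(json.loads(chunk))          # only real data payloads reach json.loads
```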
|
@ -349,11 +349,12 @@ class RandomDataset(BenchmarkDataset):
|
||||
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||
# To avoid uncontrolled change of the prompt length,
|
||||
# the encoded sequence is truncated before being decoded again.
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
|
||||
: input_lens[i]
|
||||
:total_input_len
|
||||
]
|
||||
prompt = tokenizer.decode(re_encoded_sequence)
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
total_input_len = len(re_encoded_sequence)
|
||||
requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
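The motivation for the change above is that decode → re-encode is not length-preserving for random token sequences, so the re-encoded ids are truncated to the full target length (prefix plus sampled input length) and the reported length is recomputed from what survived. A hedged sketch of that idea, with `tokenizer` standing in for any Hugging Face tokenizer:

```python
# Sketch only; lengths and the tokenizer object are placeholders.
def stabilize_prompt(tokenizer, prompt: str, prefix_len: int, input_len: int):
    total_input_len = prefix_len + input_len
    re_encoded = tokenizer.encode(prompt, add_special_tokens=False)[:total_input_len]
    prompt = tokenizer.decode(re_encoded)
    # Report the token count actually kept, not the requested one.
    return prompt, len(re_encoded)
```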
|
@ -33,7 +33,7 @@ import warnings
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm.asyncio import tqdm
|
||||
@ -107,14 +107,42 @@ class BenchmarkMetrics:
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def _get_current_request_rate(
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
|
||||
ramp_up_start_rps: Optional[int],
|
||||
ramp_up_end_rps: Optional[int],
|
||||
request_index: int,
|
||||
total_requests: int,
|
||||
request_rate: float,
|
||||
) -> float:
|
||||
if (
|
||||
ramp_up_strategy
|
||||
and ramp_up_start_rps is not None
|
||||
and ramp_up_end_rps is not None
|
||||
):
|
||||
progress = request_index / max(total_requests - 1, 1)
|
||||
if ramp_up_strategy == "linear":
|
||||
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
|
||||
return ramp_up_start_rps + increase
|
||||
elif ramp_up_strategy == "exponential":
|
||||
ratio = ramp_up_end_rps / ramp_up_start_rps
|
||||
return ramp_up_start_rps * (ratio**progress)
|
||||
else:
|
||||
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
|
||||
return request_rate
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[SampleRequest, None]:
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||
ramp_up_start_rps: Optional[int] = None,
|
||||
ramp_up_end_rps: Optional[int] = None,
|
||||
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
@ -129,22 +157,44 @@ async def get_request(
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
ramp_up_strategy (optional):
|
||||
The ramp-up strategy. Can be "linear" or "exponential".
|
||||
If None, uses constant request rate (specified by request_rate).
|
||||
ramp_up_start_rps (optional):
|
||||
The starting request rate for ramp-up.
|
||||
ramp_up_end_rps (optional):
|
||||
The ending request rate for ramp-up.
|
||||
"""
|
||||
input_requests: Iterable[SampleRequest] = iter(input_requests)
|
||||
|
||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||
assert burstiness > 0, (
|
||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
||||
)
|
||||
theta = 1.0 / (request_rate * burstiness)
|
||||
# Convert to list to get length for ramp-up calculations
|
||||
if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
|
||||
input_requests = list(input_requests)
|
||||
|
||||
total_requests = len(input_requests)
|
||||
request_index = 0
|
||||
|
||||
for request in input_requests:
|
||||
yield request
|
||||
current_request_rate = _get_current_request_rate(
|
||||
ramp_up_strategy,
|
||||
ramp_up_start_rps,
|
||||
ramp_up_end_rps,
|
||||
request_index,
|
||||
total_requests,
|
||||
request_rate,
|
||||
)
|
||||
|
||||
if request_rate == float("inf"):
|
||||
yield request, current_request_rate
|
||||
|
||||
request_index += 1
|
||||
|
||||
if current_request_rate == float("inf"):
|
||||
# If the request rate is infinity, then we don't need to wait.
|
||||
continue
|
||||
|
||||
theta = 1.0 / (current_request_rate * burstiness)
|
||||
|
||||
# Sample the request interval from the gamma distribution.
|
||||
# If burstiness is 1, it follows exponential distribution.
|
||||
interval = np.random.gamma(shape=burstiness, scale=theta)
|
||||
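As a side note, the gamma draw above keeps the mean inter-arrival time at `1 / request_rate` regardless of burstiness, since the mean of a gamma distribution is shape × scale = burstiness × theta. A quick self-contained check with illustrative values, not taken from the benchmark:

```python
import numpy as np

rate, burstiness = 5.0, 0.5            # illustrative RPS and burstiness factor
theta = 1.0 / (rate * burstiness)      # same scale parameter as in get_request()
intervals = np.random.gamma(shape=burstiness, scale=theta, size=100_000)
print(intervals.mean())                # ~0.2 s, i.e. 1 / rate; variance grows as burstiness shrinks
```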
@ -290,6 +340,9 @@ async def benchmark(
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
extra_body: Optional[dict],
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||
ramp_up_start_rps: Optional[int] = None,
|
||||
ramp_up_end_rps: Optional[int] = None,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -353,7 +406,15 @@ async def benchmark(
|
||||
|
||||
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
if ramp_up_strategy is not None:
|
||||
print(
|
||||
f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase "
|
||||
f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over "
|
||||
"the duration of the benchmark."
|
||||
)
|
||||
else:
|
||||
print(f"Traffic request rate: {request_rate} RPS.")
|
||||
|
||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||
print(f"Maximum request concurrency: {max_concurrency}")
|
||||
|
||||
@ -373,7 +434,34 @@ async def benchmark(
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
|
||||
rps_change_events = []
|
||||
last_int_rps = -1
|
||||
if ramp_up_strategy is not None and ramp_up_start_rps is not None:
|
||||
last_int_rps = ramp_up_start_rps
|
||||
rps_change_events.append(
|
||||
{
|
||||
"rps": last_int_rps,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
async for request, current_request_rate in get_request(
|
||||
input_requests,
|
||||
request_rate,
|
||||
burstiness,
|
||||
ramp_up_strategy,
|
||||
ramp_up_start_rps,
|
||||
ramp_up_end_rps,
|
||||
):
|
||||
if ramp_up_strategy is not None:
|
||||
current_int_rps = int(current_request_rate)
|
||||
if current_int_rps > last_int_rps:
|
||||
timestamp = datetime.now().isoformat()
|
||||
for rps_val in range(last_int_rps + 1, current_int_rps + 1):
|
||||
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
|
||||
last_int_rps = current_int_rps
|
||||
|
||||
prompt, prompt_len, output_len, mm_content = (
|
||||
request.prompt,
|
||||
request.prompt_len,
|
||||
@ -397,11 +485,8 @@ async def benchmark(
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
)
|
||||
)
|
||||
task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
tasks.append(asyncio.create_task(task))
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
@ -466,7 +551,7 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:": metrics.request_goodput if goodput_config_dict else None,
|
||||
"request_goodput": metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -477,6 +562,9 @@ async def benchmark(
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
|
||||
if rps_change_events:
|
||||
result["rps_change_events"] = rps_change_events
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
@ -610,6 +698,26 @@ def main(args: argparse.Namespace):
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
tokenizer_mode = args.tokenizer_mode
|
||||
|
||||
# Validate ramp-up arguments
|
||||
if args.ramp_up_strategy is not None:
|
||||
if args.request_rate != float("inf"):
|
||||
raise ValueError(
|
||||
"When using ramp-up, do not specify --request-rate. "
|
||||
"The request rate will be controlled by ramp-up parameters. "
|
||||
"Please remove the --request-rate argument."
|
||||
)
|
||||
if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
|
||||
raise ValueError(
|
||||
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
|
||||
"--ramp-up-end-rps must be specified"
|
||||
)
|
||||
if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
|
||||
raise ValueError("Ramp-up start and end RPS must be non-negative")
|
||||
if args.ramp_up_start_rps > args.ramp_up_end_rps:
|
||||
raise ValueError("Ramp-up start RPS must be less than end RPS")
|
||||
if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0:
|
||||
raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
|
||||
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
base_url = f"{args.base_url}"
|
||||
@ -802,6 +910,9 @@ def main(args: argparse.Namespace):
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
ramp_up_strategy=args.ramp_up_strategy,
|
||||
ramp_up_start_rps=args.ramp_up_start_rps,
|
||||
ramp_up_end_rps=args.ramp_up_end_rps,
|
||||
)
|
||||
)
|
||||
|
||||
@ -834,6 +945,11 @@ def main(args: argparse.Namespace):
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
if args.ramp_up_strategy is not None:
|
||||
result_json["ramp_up_strategy"] = args.ramp_up_strategy
|
||||
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
|
||||
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
@ -859,7 +975,10 @@ def main(args: argparse.Namespace):
|
||||
if args.max_concurrency is not None
|
||||
else ""
|
||||
)
|
||||
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
if args.ramp_up_strategy is not None:
|
||||
file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
else:
|
||||
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
@ -1225,6 +1344,31 @@ def create_argument_parser():
|
||||
"script chooses a LoRA module at random.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ramp-up-strategy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["linear", "exponential"],
|
||||
help="The ramp-up strategy. This would be used to "
|
||||
"ramp up the request rate from initial RPS to final "
|
||||
"RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). "
|
||||
"over the duration of the benchmark.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ramp-up-start-rps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The starting request rate for ramp-up (RPS). "
|
||||
"Needs to be specified when --ramp-up-strategy is used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ramp-up-end-rps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The ending request rate for ramp-up (RPS). "
|
||||
"Needs to be specified when --ramp-up-strategy is used.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
@ -97,7 +97,7 @@ def run_vllm(
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
prompts = [request.prompt for request in requests]
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0][2]
|
||||
output_len = requests[0].expected_output_len
|
||||
for request in requests:
|
||||
assert request.expected_output_len == output_len
|
||||
start = time.perf_counter()
|
||||
|
@ -19,7 +19,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.utils import FlexibleArgumentParser, cdiv
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
@ -117,14 +117,9 @@ def bench_fp8(
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
block_scale_a = torch.rand(
|
||||
(m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
|
||||
block_scale_b = torch.rand(
|
||||
ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
|
||||
cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||
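For reference, `vllm.utils.cdiv` used above is a ceiling-division helper with the same behavior as the removed local `ceil_div`; the standalone re-implementation below is only an illustration, not the vLLM source.

```python
def cdiv(x: int, y: int) -> int:
    # Integer ceiling division: equivalent to math.ceil(x / y) for positive ints.
    return -(-x // y)

assert cdiv(4096, 128) == 32
assert cdiv(4100, 128) == 33  # 4100 / 128 = 32.03..., rounded up to the next block
```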
|
@ -1,4 +1,5 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
import copy
|
||||
import itertools
|
||||
|
@ -113,6 +113,7 @@ def bench_run(
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
@ -124,7 +125,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -148,7 +150,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -227,6 +230,7 @@ def bench_run(
|
||||
"w2_q": w2_q,
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -287,12 +291,13 @@ def bench_run(
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
|
||||
fn = lambda: ops.gptq_marlin_gemm(
|
||||
a=bt.a,
|
||||
c=None,
|
||||
b_q_weight=w_q,
|
||||
b_scales=w_s,
|
||||
global_scale=None,
|
||||
b_zeros=w_zp,
|
||||
g_idx=g_idx,
|
||||
perm=sort_indices,
|
||||
|
@ -22,8 +22,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
query_marlin_supported_quant_types,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
rand_marlin_weight_fp4_like,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace,
|
||||
awq_marlin_quantize,
|
||||
marlin_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
@ -35,7 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
@ -57,80 +65,144 @@ def bench_run(
|
||||
size_n: int,
|
||||
):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
|
||||
model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
|
||||
)
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
||||
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
if act_order and (group_size == -1 or group_size == size_k or has_zp):
|
||||
return
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
|
||||
|
||||
# Marlin quant
|
||||
(
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
marlin_rand_perm,
|
||||
) = marlin_quantize(b, quant_type, group_size, act_order)
|
||||
|
||||
# Marlin_24 quant
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||
marlin_24_quantize(b, quant_type, group_size)
|
||||
marlin_24_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
|
||||
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
|
||||
# GPTQ quant
|
||||
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
||||
b, quant_type, group_size, act_order
|
||||
repack_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in MARLIN_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
# so that group ids are increasing
|
||||
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
if act_order:
|
||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||
|
||||
# Prepare
|
||||
marlin_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
|
||||
marlin_24_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||
|
||||
# AllSpark W8A16 quant
|
||||
as_supported_case = (
|
||||
allspark_supported = (
|
||||
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1
|
||||
and not act_order
|
||||
and is_k_full
|
||||
)
|
||||
if as_supported_case:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = sm_version >= 80 and sm_version < 90
|
||||
as_supported_case = as_supported_case and supported_arch
|
||||
if supported_arch:
|
||||
has_zp = False
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp
|
||||
def gen_marlin_params():
|
||||
# Marlin quant
|
||||
marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size != 16 or act_order:
|
||||
return
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
|
||||
b.T, group_size
|
||||
)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128] or act_order:
|
||||
return
|
||||
marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
|
||||
elif group_size == 16:
|
||||
return
|
||||
elif has_zp:
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b, quant_type, group_size
|
||||
)
|
||||
else:
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
|
||||
marlin_quantize(b, quant_type, group_size, act_order)
|
||||
)
|
||||
return (
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
)
|
||||
|
||||
def gen_marlin_24_params():
|
||||
marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
|
||||
if marlin_24_supported:
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||
marlin_24_quantize(b, quant_type, group_size)
|
||||
)
|
||||
return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)
|
||||
|
||||
def gen_repack_params():
|
||||
q_w_gptq = None
|
||||
repack_sort_indices = None
|
||||
if repack_supported:
|
||||
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
||||
b, quant_type, group_size, act_order
|
||||
)
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
# so that group ids are increasing
|
||||
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
if act_order:
|
||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||
return q_w_gptq, repack_sort_indices
|
||||
|
||||
def gen_allspark_params():
|
||||
qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
|
||||
CUBLAS_M_THRESHOLD
|
||||
) = None
|
||||
nonlocal allspark_supported
|
||||
if allspark_supported:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = sm_version >= 80 and sm_version < 90
|
||||
allspark_supported = allspark_supported and supported_arch
|
||||
if supported_arch:
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp
|
||||
)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
return (
|
||||
qw_reorder,
|
||||
s_reorder,
|
||||
zp_reorder,
|
||||
sm_count,
|
||||
sm_version,
|
||||
CUBLAS_M_THRESHOLD,
|
||||
)
|
||||
|
||||
(
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
) = gen_marlin_params()
|
||||
marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
|
||||
gen_marlin_24_params()
|
||||
)
|
||||
q_w_gptq, repack_sort_indices = gen_repack_params()
|
||||
qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
|
||||
gen_allspark_params()
|
||||
)
|
||||
|
||||
# Prepare
|
||||
marlin_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
marlin_24_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
|
||||
globals = {
|
||||
# Gen params
|
||||
@ -140,15 +212,14 @@ def bench_run(
|
||||
"size_n": size_n,
|
||||
"size_k": size_k,
|
||||
"a": a,
|
||||
"a_tmp": a_tmp,
|
||||
# Marlin params
|
||||
"marlin_w_ref": marlin_w_ref,
|
||||
"marlin_q_w": marlin_q_w,
|
||||
"marlin_s": marlin_s,
|
||||
"marlin_s2": marlin_s2,
|
||||
"marlin_zp": marlin_zp,
|
||||
"marlin_g_idx": marlin_g_idx,
|
||||
"marlin_sort_indices": marlin_sort_indices,
|
||||
"marlin_rand_perm": marlin_rand_perm,
|
||||
"marlin_workspace": marlin_workspace,
|
||||
"is_k_full": is_k_full,
|
||||
# Marlin_24 params
|
||||
@ -161,12 +232,12 @@ def bench_run(
|
||||
"q_w_gptq": q_w_gptq,
|
||||
"repack_sort_indices": repack_sort_indices,
|
||||
# AllSpark W8A16 params
|
||||
"qw_reorder": qw_reorder if as_supported_case else None,
|
||||
"s_reorder": s_reorder if as_supported_case else None,
|
||||
"zp_reorder": zp_reorder if as_supported_case else None,
|
||||
"sm_count": sm_count if as_supported_case else None,
|
||||
"sm_version": sm_version if as_supported_case else None,
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||
"qw_reorder": qw_reorder,
|
||||
"s_reorder": s_reorder,
|
||||
"zp_reorder": zp_reorder,
|
||||
"sm_count": sm_count,
|
||||
"sm_version": sm_version,
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
|
||||
# Kernels
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
@ -177,7 +248,7 @@ def bench_run(
|
||||
min_run_time = 1
|
||||
|
||||
# Warmup pytorch
|
||||
for i in range(5):
|
||||
for _ in range(5):
|
||||
torch.matmul(a, marlin_w_ref)
|
||||
|
||||
results.append(
|
||||
@ -192,17 +263,17 @@ def bench_run(
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm_fp16",
|
||||
description="gptq_marlin_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
@ -210,10 +281,7 @@ def bench_run(
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
):
|
||||
if marlin_24_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||
@ -224,17 +292,18 @@ def bench_run(
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
if repack_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if as_supported_case:
|
||||
if allspark_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||
@@ -250,7 +319,6 @@ def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
@@ -278,14 +346,17 @@ def main(args):
|
||||
):
|
||||
continue
|
||||
|
||||
for quant_type in query_marlin_supported_quant_types(False):
|
||||
for quant_type in query_marlin_supported_quant_types():
|
||||
if (
|
||||
len(args.limit_num_bits) > 0
|
||||
and quant_type.size_bits not in args.limit_num_bits
|
||||
):
|
||||
continue
|
||||
|
||||
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
for group_size in (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES
|
||||
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES
|
||||
):
|
||||
if (
|
||||
len(args.limit_group_size) > 0
|
||||
and group_size not in args.limit_group_size
|
||||
|
@@ -620,7 +620,7 @@ def main(args: argparse.Namespace):
|
||||
4096,
|
||||
]
|
||||
else:
|
||||
batch_sizes = [args.batch_size]
|
||||
batch_sizes = args.batch_size
|
||||
|
||||
use_deep_gemm = bool(args.use_deep_gemm)
|
||||
|
||||
@@ -728,7 +728,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument("--use-deep-gemm", action="store_true")
|
||||
parser.add_argument("--seed", type=int, default=0)
|
||||
parser.add_argument("--batch-size", type=int, required=False)
|
||||
parser.add_argument("--batch-size", type=int, nargs="+", required=False)
|
||||
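# Note (added for clarity): with nargs="+", an invocation like
# `--batch-size 1 16 64` yields args.batch_size == [1, 16, 64], which matches
# the `batch_sizes = args.batch_size` change above.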
parser.add_argument("--tune", action="store_true")
|
||||
parser.add_argument("--trust-remote-code", action="store_true")
|
||||
parser.add_argument("--model-prefix", type=str, required=False)
|
||||
|
@@ -4,12 +4,12 @@ import argparse
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.fused_moe.moe_align_block_size import (
|
||||
moe_align_block_size_triton,
|
||||
)
|
||||
from vllm.triton_utils import triton
|
||||
|
||||
|
||||
def get_topk_ids(num_tokens: int, num_experts: int, topk: int) -> torch.Tensor:
|
||||
|
@@ -85,12 +85,6 @@ def benchmark_shape(m: int,
|
||||
|
||||
# === DeepGEMM Implementation ===
|
||||
def deepgemm_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
|
||||
# A, block_size[1])
|
||||
# A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
# C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
|
||||
(B_deepgemm, B_scale_deepgemm),
|
||||
C_deepgemm)
|
||||
@@ -98,8 +92,6 @@ def benchmark_shape(m: int,
|
||||
|
||||
# === vLLM Triton Implementation ===
|
||||
def vllm_triton_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
return w8a8_block_fp8_matmul(A_vllm,
|
||||
B_vllm,
|
||||
A_scale_vllm,
|
||||
@@ -109,9 +101,6 @@ def benchmark_shape(m: int,
|
||||
|
||||
# === vLLM CUTLASS Implementation ===
|
||||
def vllm_cutlass_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
# A, block_size[1], column_major_scales=True)
|
||||
return ops.cutlass_scaled_mm(A_vllm_cutlass,
|
||||
B_vllm.T,
|
||||
scale_a=A_scale_vllm_cutlass,
|
||||
|
@@ -12,9 +12,8 @@ endif()
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
endif()
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
@@ -96,12 +95,30 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||
endif()
|
||||
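# Example (assumed usage, not part of this change): when cross-compiling on a
# host whose CPU flags do not include avx512_bf16 / avx512_vnni, the environment
# overrides above can force-enable the corresponding code paths, e.g.
#   VLLM_CPU_AVX512BF16=1 VLLM_CPU_AVX512VNNI=1 cmake ...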
|
||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
@@ -231,12 +248,25 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
elseif(POWER10_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
|
@@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 763ad155a1c826f71ff318f41edb1e4e5e376ddb
|
||||
GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
@@ -122,6 +122,7 @@ function (get_torch_gpu_compiler_flags OUT_GPU_FLAGS GPU_LANG)
|
||||
"-DENABLE_FP8"
|
||||
"-U__HIP_NO_HALF_CONVERSIONS__"
|
||||
"-U__HIP_NO_HALF_OPERATORS__"
|
||||
"-Werror=unused-variable"
|
||||
"-fno-gpu-rdc")
|
||||
|
||||
endif()
|
||||
@@ -264,8 +265,8 @@ macro(set_gencode_flags_for_srcs)
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||
@@ -277,7 +278,7 @@ endmacro()
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
@@ -312,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "\\a$")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
string(REPLACE "a" "" _base "${_arch}")
|
||||
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
|
||||
list(APPEND _CUDA_ARCHS "${_arch}")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
||||
set(_CUDA_ARCHS "10.0a")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
@@ -358,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
|
||||
|
||||
# reapply +PTX suffix to architectures that requested PTX
|
||||
set(_FINAL_ARCHS)
|
||||
foreach(_arch ${_CUDA_ARCHS})
|
||||
@@ -369,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endif()
|
||||
endforeach()
|
||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||
|
||||
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
|
@@ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
"page_table must be a 32-bit integer tensor");
|
||||
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
|
238
csrc/cpu/sgl-kernels/common.h
Normal file
@@ -0,0 +1,238 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// dispatch bool
|
||||
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
|
||||
[&] { \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
|
||||
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::BFloat16 : { \
|
||||
using packed_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using packed_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Char : { \
|
||||
using packed_t = int8_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Float8_e4m3fn : { \
|
||||
using packed_t = at::Float8_e4m3fn; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at the last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// parallel routines
|
||||
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) { return (x + y - 1) / y; }
|
||||
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
// onednn partition pattern
|
||||
T& n_my = n_end;
|
||||
if (nth <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
T n1 = div_up(n, nth);
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * nth;
|
||||
n_my = ith < T1 ? n1 : n2;
|
||||
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
|
||||
}
|
||||
n_end += n_start;
|
||||
#else
|
||||
// pytorch aten partition pattern
|
||||
T n_my = div_up(n, nth);
|
||||
n_start = ith * n_my;
|
||||
n_end = std::min(n_start + n_my, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int n, const func_t& f) {
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
int actual_nth = at::get_num_threads();
|
||||
if (m == 1) {
|
||||
return actual_nth;
|
||||
}
|
||||
return std::max(1, (actual_nth >> 1) * 2);
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_2d(int m, int n, const func_t& f) {
|
||||
|
||||
// make sure we have even num_threads
|
||||
int nth = adjust_num_threads(m);
|
||||
|
||||
// [NOTE] thread blocking:
|
||||
//
|
||||
// 1) prefer square block per thread
|
||||
// 2) use even number of CPU cores
|
||||
// 3) use all `num_threads` cores
|
||||
//
|
||||
// we have:
|
||||
// TM * TN = T
|
||||
// BM / TM = BN / TN
|
||||
// then:
|
||||
// TM = ((BM / BN) * T) ^ 0.5
|
||||
//
|
||||
float r = float(m) / n;
|
||||
int nth_m = std::ceil(std::sqrt(r * nth));
|
||||
int nth_n = 1;
|
||||
for (; nth_m > 0; --nth_m) {
|
||||
nth_n = nth / nth_m;
|
||||
if (nth_m * nth_n == nth) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(nth)
|
||||
{
|
||||
int ith = omp_get_thread_num();
|
||||
int ith_m = ith / nth_n;
|
||||
int ith_n = ith % nth_n;
|
||||
|
||||
int thread_block_m = div_up(m, nth_m);
|
||||
int thread_block_n = div_up(n, nth_n);
|
||||
|
||||
int begin_m = ith_m * thread_block_m;
|
||||
int end_m = std::min(m, begin_m + thread_block_m);
|
||||
int begin_n = ith_n * thread_block_n;
|
||||
int end_n = std::min(n, begin_n + thread_block_n);
|
||||
|
||||
f(begin_m, end_m, begin_n, end_n);
|
||||
}
|
||||
#else
|
||||
f(0, m, 0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int get_cache_blocks(int BLOCK_SIZE, int K) {
|
||||
// L2 2MB and ratio of 50%
|
||||
const int L2_size = 2048 * 1024 >> 1;
|
||||
return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T))));
|
||||
}
|
||||
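// Example (illustrative numbers): for bfloat16 with BLOCK_SIZE = 32 and
// K = 4096, one N-block of B takes 32 * 4096 * 2 = 256 KiB, so the 1 MiB
// budget above (half of a 2 MiB L2) yields 4 cache blocks.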
|
||||
// data indexing for dimension collapse
|
||||
template <typename T>
|
||||
inline T data_index_init(T offset) {
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
|
||||
offset = data_index_init(offset, std::forward<Args>(args)...);
|
||||
x = offset % X;
|
||||
return offset / X;
|
||||
}
|
||||
|
||||
inline bool data_index_step() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
||||
if (data_index_step(std::forward<Args>(args)...)) {
|
||||
x = ((x + 1) == X) ? 0 : (x + 1);
|
||||
return x == 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
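// Example (illustrative): for a flattened [E, NB] loop with E = 3, NB = 5 and
// begin = 7, data_index_init(7, e, E, nb, NB) gives e = 1, nb = 2
// (1 * 5 + 2 == 7); data_index_step(e, E, nb, NB) then advances nb first and
// carries into e whenever nb wraps back to 0.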
|
||||
// forced unroll for perf critical path
|
||||
|
||||
#if __has_attribute(always_inline)
|
||||
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
||||
#else
|
||||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
template <int n>
|
||||
struct Unroll {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
Unroll<n - 1>{}(f, args...);
|
||||
f(std::integral_constant<int, n - 1>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
f(std::integral_constant<int, 0>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
464
csrc/cpu/sgl-kernels/gemm.cpp
Normal file
@@ -0,0 +1,464 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
// packed layout:
|
||||
// quants {N, K} int8_t
|
||||
// comp {N} int32_t
|
||||
template <int BLOCK_N>
|
||||
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
__m512i vcomp[COLS];
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
vcomp[col] = _mm512_setzero_si512();
|
||||
}
|
||||
|
||||
const int64_t offset = BLOCK_N * K;
|
||||
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
||||
for (int k = 0; k < K / 4; ++k) {
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
__m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64));
|
||||
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
|
||||
}
|
||||
}
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
_mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "s8s8_compensation not implemented!");
|
||||
#endif
|
||||
}
|
||||
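// Added note: each int32 lane of the trailing "comp" row accumulates
// 128 * sum_k(B[n][k]) via dpbusd with the 0x80 byte vector; this appears to be
// the usual u8s8 compensation term for correcting the +128 shift applied when
// activations are quantized to uint8.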
|
||||
// convert to vnni format
|
||||
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename packed_t>
|
||||
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
|
||||
const int VNNI_BLK = 2;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
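// Example (illustrative): for bfloat16/float16 with VNNI_BLK = 2, element
// weight[n][k] lands at packed[(k / 2) * N * 2 + n * 2 + (k % 2)], i.e. the
// [N, K] weight is viewed as [K/2, N, 2] so the bf16 dot-product path can
// consume K in pairs.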
|
||||
template <>
|
||||
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(N == BLOCK_N);
|
||||
|
||||
const int VNNI_BLK = 4;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
s8s8_compensation<BLOCK_N>(packed, K);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_set1_ps(0.f);
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K2 = K >> 1;
|
||||
const int64_t lda2 = lda >> 1;
|
||||
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const float* b_ptr = reinterpret_cast<const float*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K2; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
// for COLS = 1, 3 use 256bit store
|
||||
if constexpr (COLS % 2 == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
} else {
|
||||
_mm256_storeu_si256(
|
||||
reinterpret_cast<__m256i*>(C + row * ldc + col * 16),
|
||||
(__m256i)(_mm512_cvtneps_pbh(vc[i])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, const float* __restrict__ bias,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb, BLOCK_N, /* add_C */false,
|
||||
A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, has_bias>::apply(
|
||||
A, B, C, Ctmp, bias,
|
||||
M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
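// The switch key packs (mb_size, nb_size / 16) into one byte; e.g.
// mb_size = 2, nb_size = 64 gives (2 << 4) | (64 >> 4) = 0x24.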
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void weight_packed_linear_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
|
||||
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
|
||||
|
||||
// l2 cache block for n
|
||||
int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
|
||||
for (int64_t mb = begin_mb; mb < end_mb; ++mb) {
|
||||
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
}}}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \
|
||||
int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight) {
|
||||
// for 3d moe weights
|
||||
// weight : [E, OC, IC]
|
||||
// w1 : [E, 2N, K]
|
||||
// w2 : [E, K, N]
|
||||
CHECK_INPUT(weight);
|
||||
|
||||
const int64_t ndim = weight.ndimension();
|
||||
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
|
||||
const auto st = weight.scalar_type();
|
||||
const int64_t E = ndim == 3 ? weight.size(0) : 1;
|
||||
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
|
||||
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
|
||||
|
||||
// we handle 2 TILE_N at a time.
|
||||
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
|
||||
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
|
||||
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t NB = div_up(OC, BLOCK_N);
|
||||
|
||||
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
|
||||
auto packed_weight = at::empty({}, weight.options());
|
||||
const int64_t stride = OC * IC;
|
||||
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
|
||||
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
|
||||
|
||||
CPU_DISPATCH_PACKED_TYPES(st, [&] {
|
||||
// adjust most inner dimension size
|
||||
const int packed_row_size = get_row_size<packed_t>(IC);
|
||||
auto sizes = weight.sizes().vec();
|
||||
sizes[ndim - 1] = packed_row_size;
|
||||
packed_weight.resize_(sizes);
|
||||
|
||||
const packed_t* w_data = weight.data_ptr<packed_t>();
|
||||
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
|
||||
|
||||
// parallel on {E, NB}
|
||||
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t e{0}, nb{0};
|
||||
data_index_init(begin, e, E, nb, NB);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
|
||||
int64_t n = nb * BLOCK_N;
|
||||
int64_t n_size = std::min(BLOCK_N, OC - n);
|
||||
pack_vnni<packed_t>(
|
||||
packed_data + e * OC * packed_row_size + n * packed_row_size,
|
||||
w_data + e * stride + n * IC,
|
||||
n_size,
|
||||
IC);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(e, E, nb, NB);
|
||||
}
|
||||
});
|
||||
});
|
||||
return packed_weight;
|
||||
}
|
||||
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options());
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
|
||||
weight_packed_linear_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
266
csrc/cpu/sgl-kernels/gemm.h
Normal file
@@ -0,0 +1,266 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
|
||||
// block size for AMX gemm
|
||||
constexpr int block_size_m() { return 2 * TILE_M; }
|
||||
constexpr int block_size_n() { return 2 * TILE_N; }
|
||||
|
||||
// define threshold using brgemm (intel AMX)
|
||||
template <typename T> inline bool can_use_brgemm(int M);
|
||||
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
|
||||
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
||||
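// Added note: these thresholds mirror the heuristic in gemm.cpp, i.e. brgemm
// (AMX) is preferred once M > 4, while small-M bfloat16 stays on the
// avx512-bf16 tinygemm path; int8 stays off brgemm until the u8s8 path lands.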
|
||||
// work around compiler internal error
|
||||
#define BLOCK_K 128 // 4 * TILE_K
|
||||
|
||||
// adjust leading dimension size for K
|
||||
template <typename T>
|
||||
inline int64_t get_row_size(int64_t K) {
|
||||
return K;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64_t get_row_size<int8_t>(int64_t K) {
|
||||
return K + sizeof(int32_t);
|
||||
}
|
||||
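// Added note: the extra sizeof(int32_t) bytes per row hold the s8s8
// compensation value appended by pack_vnni<int8_t> in gemm.cpp.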
|
||||
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
|
||||
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for fp8 w8a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for int4 w4a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int4_w4a16_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::quint4x2* __restrict__ packed_w1,
|
||||
const at::quint4x2* __restrict__ packed_w2,
|
||||
const uint8_t* __restrict__ w1z,
|
||||
const uint8_t* __restrict__ w2z,
|
||||
const scalar_t* __restrict__ w1s,
|
||||
const scalar_t* __restrict__ w2s,
|
||||
int group_size,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implementation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::quint4x2* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const uint8_t* __restrict__ Bz,
|
||||
const scalar_t* __restrict__ Bs,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int group_size,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
int64_t strideBz,
|
||||
int64_t strideBs,
|
||||
bool brg);
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
inline void print_16x32i(const __m512i x) {
|
||||
int32_t a[16];
|
||||
_mm512_storeu_si512((__m512i *)a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
inline void print_16x32(const __m512 x) {
|
||||
float a[16];
|
||||
_mm512_storeu_ps((__m512 *)a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
||||
inline void print_32x8u(const __m256i x) {
|
||||
uint8_t a[32];
|
||||
_mm256_storeu_si256((__m256i *)a, x);
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
std::cout << int32_t(a[i]) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
530
csrc/cpu/sgl-kernels/gemm_fp8.cpp
Normal file
@@ -0,0 +1,530 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
// we use 4x32 for BLOCK_M
|
||||
#define BLOCK_SIZE_M_SCALE 4
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void unpack_B(
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_B,
|
||||
int N,
|
||||
int K,
|
||||
int ldb,
|
||||
int ldb_tmp,
|
||||
float scale) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// [K/2, N, 2]
|
||||
const int K2 = K >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
|
||||
const __m512 vd = _mm512_set1_ps(scale);
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
static_assert(BLOCK_N == 32);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (int k = 0; k < K2; ++k) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
|
||||
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
|
||||
|
||||
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
|
||||
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
|
||||
|
||||
// Apply scale
|
||||
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
|
||||
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
|
||||
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
|
||||
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
|
||||
|
||||
f0_lo = _mm512_mul_ps(f0_lo, vd);
|
||||
f0_hi = _mm512_mul_ps(f0_hi, vd);
|
||||
f1_lo = _mm512_mul_ps(f1_lo, vd);
|
||||
f1_hi = _mm512_mul_ps(f1_hi, vd);
|
||||
|
||||
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
|
||||
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
|
||||
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
const int KB = div_up(K, BLOCK_K);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
constexpr int PREFETCH_SIZE_KB = 1;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
__m512 vsum[ROWS * COLS];
|
||||
|
||||
// block quant scale
|
||||
__m512 vscale;
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_setzero_ps();
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int lda2 = lda >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
|
||||
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
|
||||
}
|
||||
}
|
||||
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
|
||||
};
|
||||
|
||||
constexpr int BLOCK_K2 = BLOCK_K >> 1;
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
int kb_start = kb * BLOCK_K2;
|
||||
int kb_end = std::min(K, kb_start + BLOCK_K2);
|
||||
// 1. load scale vector
|
||||
vscale = _mm512_set1_ps(scale[kb]);
|
||||
if constexpr (PREFETCH_SIZE_KB > 0) {
|
||||
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
|
||||
}
|
||||
// 2. zero vsum for each block
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vsum[i] = _mm512_setzero_ps();
|
||||
});
|
||||
// 3. accumulate across each block
|
||||
for (int k = kb_start; k < kb_end; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
// 4. apply scale
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]);
|
||||
});
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2,4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K);
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const packed_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
template <bool has_bias>
|
||||
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
|
||||
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
|
||||
const int ldb_tmp = BLOCK_N;
|
||||
|
||||
for (int k = 0; k < K; k += BLOCK_K) {
|
||||
int kb_size = std::min(BLOCK_K, K - k);
|
||||
|
||||
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
|
||||
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
|
||||
}
|
||||
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
|
||||
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fp8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const at::Float8_e4m3fn* __restrict__ mat2,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
scalar_t* __restrict__ buffer,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
int64_t buffer_size_per_thread) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
const int64_t scale_size_K = div_up(K, block_size_K);
|
||||
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
|
||||
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ scale_ptr,
|
||||
/* bias */ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const at::Float8_e4m3fn* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
TYPE* __restrict__ Btmp, \
|
||||
float* __restrict__ Ctmp, \
|
||||
const float* __restrict__ scale, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
std::vector<int64_t> block_size, std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales2 to be float32.");
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
TORCH_CHECK(block_size.size() == 2,
|
||||
"fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
|
||||
|
||||
int64_t block_size_N = block_size[0];
|
||||
int64_t block_size_K = block_size[1];
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be a multiple of BLOCK_N");
|
||||
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K to equal BLOCK_K");
|
||||
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
|
||||
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales to be float32.");
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
// Btmp : [T, BLOCK_N * K]
|
||||
// Ctmp : [T, BLOCK_M * BLOCK_N]
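// Ctmp is float32 while the buffer dtype is 16-bit, hence the factor of 2 in the sizing below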
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
|
||||
fp8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<at::Float8_e4m3fn>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
buffer.data_ptr<scalar_t>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM,
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
size_per_thread);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
csrc/cpu/sgl-kernels/gemm_int8.cpp (new file, +440 lines)
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but luckily we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
//
|
||||
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensation:
|
||||
//
|
||||
// a * b = (a + 128) * b - 128 * b
|
||||
// s s u s u s
|
||||
//
|
||||
// 1) 128 * b is pre-computed when packing B to vnni formats
|
||||
// 2) a + 128 is fused when dynamically quantizing A
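// worked check with hypothetical values a = -3, b = 5:
//   (-3 + 128) * 5 - 128 * 5 = 625 - 640 = -15 = (-3) * 5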
|
||||
//
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
vd0 = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
if constexpr (has_bias) {
|
||||
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
|
||||
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
|
||||
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
|
||||
if constexpr (has_bias) {
|
||||
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
|
||||
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
|
||||
} else {
|
||||
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
|
||||
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
|
||||
}
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
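// the packed [BLOCK_N, K] int8 tile is followed by BLOCK_N int32 compensation terms (the pre-computed 128 * b values from the s8s8 note above)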
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void int8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const uint8_t* __restrict__ mat1,
|
||||
const int8_t* __restrict__ mat2,
|
||||
const float* __restrict__ scales1,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * K,
|
||||
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
|
||||
/* C */ out + mb_start * N + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* As */ scales1 + mb_start,
|
||||
/* Bs */ scales2 + nb_start,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
|
||||
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
|
||||
CHECK_DIM(2, A);
|
||||
|
||||
int64_t M = A.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
|
||||
const auto st = A.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"per_token_quant_int8: expect A to be bfloat16 or half.");
|
||||
|
||||
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
|
||||
auto As = at::empty({M}, A.options().dtype(at::kFloat));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
|
||||
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = As.data_ptr<float>();
|
||||
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
});
|
||||
return std::make_tuple(Aq, As);
|
||||
}
|
||||
|
||||
// weight : static, per-channel, symmetric
|
||||
// activation : dynamic, per-token, symmetric
|
||||
//
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// scales1 : [M]
|
||||
// scales2 : [N]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales1, at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales1);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales1.numel(), M);
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm: expect scales to be float32.");
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<uint8_t>(),
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
scales1.data_ptr<float>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
int64_t lda = mat1.stride(0);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar,
|
||||
"int8_scaled_mm_with_quant: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm_with_quant: expect scales to be float32.");
|
||||
|
||||
const int64_t buffer_size = M * K + M * sizeof(float);
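// buffer layout: Aq [M, K] in uint8 followed by As [M] in float32 (M * sizeof(float) bytes)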
|
||||
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
|
||||
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
csrc/cpu/sgl-kernels/moe.cpp (new file, +1330 lines; diff suppressed because it is too large)
csrc/cpu/sgl-kernels/moe_fp8.cpp (new file, +502 lines)
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
x0 = x0 * weight_vec;
|
||||
x1 = x1 * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
float scale,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void silu_and_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
int64_t size) {
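// silu(x) = x / (1 + exp(-x)) is applied to `input`, then multiplied elementwise by `input2`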
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec one = fVec(1.f);
|
||||
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += bVec::size()) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
bVec y = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y);
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
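// pipeline: ic0 = A @ w1 -> ic1 = silu_and_mul(ic0) -> ic2 = (ic1 @ w2) * topk_weight -> output = ic2 summed over topk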
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_N = div_up(2 * N, block_size_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
// parallel on [MB, NB] over the full 2N; silu_and_mul is applied in a separate pass below
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
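// sorted_ids index the flattened [M, topk] assignment, so dividing by topk recovers the source token row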
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
}
|
||||
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
scale_size_N = div_up(K, block_size_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
const int64_t stride_e2 = OC * IC;
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, \
|
||||
TYPE* __restrict__ A_tmp, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const float* __restrict__ topk_weights, \
|
||||
const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, \
|
||||
const int32_t* __restrict__ offsets, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t E, \
|
||||
int64_t topk, \
|
||||
int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
/* B */ packed_w1 + nb * BLOCK_N * K,
|
||||
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(K, BLOCK_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ ic1 + mb * BLOCK_M * N,
|
||||
/* B */ packed_w2 + nb * BLOCK_N * N,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
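// i.e. output = C + routed_scaling_factor * fused_experts_out (add_mul_stub computes input + input2 * scale)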
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const TYPE* __restrict__ fused_experts_out, \
|
||||
float routed_scaling_factor, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);
|
csrc/cpu/sgl-kernels/moe_int8.cpp (new file, +769 lines)
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
|
||||
// size might be 64x + 32
|
||||
std::memcpy(out, input, size * sizeof(uint8_t));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) * weight_vec;
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2, float scale, int64_t size) {
|
||||
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec x0 = fVec::loadu(input + d);
|
||||
fVec x1 = fVec::loadu(input + d + fVec::size());
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w13
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb0[COLS];
|
||||
__m512i vb1[COLS];
|
||||
__m512i vc0[ROWS * COLS];
|
||||
__m512i vc1[ROWS * COLS];
|
||||
__m512i vcomp0[COLS];
|
||||
__m512i vcomp1[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs0[COLS];
|
||||
__m512 vbs1[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc0[i] = _mm512_set1_epi32(0);
|
||||
vc1[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
|
||||
const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
|
||||
vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
|
||||
vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto scalec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp
|
||||
if constexpr (row == 0) {
|
||||
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
|
||||
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
|
||||
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
|
||||
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
|
||||
}
|
||||
__m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
|
||||
__m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
|
||||
vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col]));
|
||||
vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col]));
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(scalec);
|
||||
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const Vec one = Vec(1.f);
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
|
||||
Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
|
||||
Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
|
||||
Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);
|
||||
// silu
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
// mul
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, As + mb_start, \
|
||||
Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// pattern: 1-(2+2)-(8+8)
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 32;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w2
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2 {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
__m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
|
||||
x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]);
|
||||
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// strides for w1: [E, 2N, K]
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
|
||||
const int64_t stride_e = 2 * N * packed_K;
|
||||
const int64_t stride_n = packed_K;
|
||||
// here we only parallelize over half of 2N, to fuse silu_and_mul with the gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
alignas(64) float As[BLOCK_M];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, Aq_tmp + index * K, K);
|
||||
As[m] = As_tmp[index];
|
||||
}
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_e2 = OC * packed_N;
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copying A to the tmp buffer again
|
||||
const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
|
||||
const float* __restrict__ As = As_tmp + offsets[mb];
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
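Stage 1 above relies on the usual MoE routing layout: sorted_ids lists (token * topk + slot) entries grouped by expert and padded to BLOCK_M, offsets[mb] gives the first row of block mb, and expert_ids[mb] names the expert that owns it. Assuming that layout, the `A_ids[m] / topk` indexing in step 1.a recovers the source token row, as in this small sketch:

// Sketch under the routing-layout assumption described above.
// sorted_ids entries encode token * topk + topk_slot, so dividing by topk
// recovers the original token row in the quantized activation buffer.
void gather_block_rows(const int32_t* sorted_ids, int64_t mb, int64_t block_m,
                       int64_t topk, int64_t m_size, int32_t* token_out) {
  const int32_t* ids = sorted_ids + mb * block_m;
  for (int64_t m = 0; m < m_size; ++m) {
    token_out[m] = ids[m] / topk;  // index into Aq_tmp / As_tmp
  }
}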
|
||||
|
||||
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not a multiple of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
// here we only parallelize over half of 2N to fuse silu_and_mul with the gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
|
||||
const float* As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M, N]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A shape [m_size, IC]
|
||||
const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
|
||||
const float* __restrict__ As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
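add_mul_stub is not shown in this hunk; from its call site in step 2.b it is assumed to combine the shared-expert GEMM output with the routed-expert result as out = C + routed_scaling_factor * fused_experts_out, element-wise. A scalar sketch of that assumed contract:

// Assumed semantics of add_mul_stub, inferred from the call site above;
// the real helper is defined elsewhere in this file.
template <typename scalar_t>
void add_mul_stub_sketch(scalar_t* out, const float* c,
                         const scalar_t* routed, float scale, int64_t n) {
  for (int64_t i = 0; i < n; ++i) {
    out[i] = static_cast<scalar_t>(c[i] + static_cast<float>(routed[i]) * scale);
  }
}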
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \
|
||||
int64_t M, int64_t N, int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);
|
csrc/cpu/sgl-kernels/vec.h (new file)
@@ -0,0 +1,308 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
|
||||
#define CPU_CAPABILITY_AVX512
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace at::vec;
|
||||
|
||||
template <typename scalar_t,
|
||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return at::vec::convert_from_float<scalar_t>(a, b);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
|
||||
// use native instruction for bfloat16->float32 conversion
|
||||
template <>
|
||||
inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
|
||||
}
|
||||
|
||||
#define CVT_BF16_TO_FP32(a) \
|
||||
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
||||
|
||||
#define CVT_FP16_TO_FP32(a) \
|
||||
_mm512_cvtph_ps(a)
|
||||
|
||||
// this doesn't handle NaN.
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
||||
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
|
||||
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
|
||||
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
|
||||
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
|
||||
const __m512i nonsign = _mm512_or_si512(exp, mant);
|
||||
|
||||
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
|
||||
const __m512i combined = _mm512_or_si512(nonsign, sign);
|
||||
|
||||
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
|
||||
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
|
||||
}
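The intrinsic sequence above is a branch-free bit rewrite from fp8 e4m3 to bf16: the 3 mantissa bits move to the top of bf16's 7-bit mantissa field, the 4-bit exponent is rebiased by +120 (bias 7 to bias 127), and the sign is shifted to bit 15. A scalar reference, under the same assumptions and with the same corner cases left unhandled (denormals and NaN), for checking one lane at a time:

// Scalar reference for the fast path above; returns raw bf16 bits.
inline uint16_t e4m3_to_bf16_bits_sketch(uint8_t v) {
  if (v == 0) return 0;  // only exact +0 is special-cased, as in the vector path
  const uint16_t mant = (v & 0x07) << 4;                 // 3 -> top of 7 mantissa bits
  const uint16_t exp  = (((v >> 3) & 0x0f) + 120) << 7;  // rebias 7 -> 127
  const uint16_t sign = static_cast<uint16_t>(v & 0x80) << 8;
  return static_cast<uint16_t>(sign | exp | mant);
}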
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
|
||||
// The following conversion is without denorm behavior, that is to say,
|
||||
// Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6)
|
||||
// Min subnorm : S.0000.001 = 2**(−9)
|
||||
// 0.0019 ~ 0.0137 cannot be converted correctly.
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
auto mask = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_setzero_si512()); // mask = (x & 0x7f) != 0
|
||||
auto mask_nan = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_set1_epi16(127)); // mask_nan = (x & 0x7f) != 0x7f, i.e. not NaN
|
||||
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
|
||||
auto exponent = _mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
|
||||
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
|
||||
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
|
||||
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
nonsign,
|
||||
_mm512_slli_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(128)),
|
||||
8))); // add sign (x & 128) << 8
|
||||
}
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
__m512i lg2mant = _mm512_mask_mov_epi16(
|
||||
_mm512_mask_mov_epi16(
|
||||
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
|
||||
_mm512_set1_epi16(2));
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
_mm512_maskz_mov_epi16(
|
||||
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
|
||||
_mm512_mask_blend_epi16(
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
|
||||
_mm512_or_si512(
|
||||
_mm512_and_si512(
|
||||
_mm512_sllv_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
|
||||
_mm512_set1_epi16(0x007f)),
|
||||
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
|
||||
_mm512_or_si512(
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
|
||||
_mm512_slli_epi16(
|
||||
_mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
|
||||
7)))),
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
|
||||
}
|
||||
|
||||
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
|
||||
#ifdef SGLANG_CPU_FP8_CVT_FTZ
|
||||
return cvt_e4m3_bf16_intrinsic_no_nan(a);
|
||||
#else
|
||||
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// vector to scalar reduction
|
||||
#if defined(CPU_CAPABILITY_AVX512) && 0
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_add_ps(__m512(a));
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_max_ps(__m512(a));
|
||||
}
|
||||
#else
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
|
||||
}
|
||||
#endif
|
||||
|
||||
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
|
||||
template <typename scalar_t>
|
||||
inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As,
|
||||
const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
|
||||
|
||||
float amax = 0.f; // absolute max
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]);
|
||||
amax = std::max(amax, std::abs(val));
|
||||
}
|
||||
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]) * inv_scale;
|
||||
Aq[k] = static_cast<uint8_t>(static_cast<int32_t>(std::round(val)) + 128);
|
||||
}
|
||||
As = scale;
|
||||
}
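As a concrete example of the asymmetric uint8 scheme above: for a row whose largest magnitude is amax = 2.54, the stored scale is 2.54 / 127 = 0.02, and the value -1.0 quantizes to round(-1.0 * 127 / 2.54) + 128 = 78. Dequantization (implied by the kernels that consume As, not defined in this file) undoes the +128 offset, as in this sketch:

// Round-trip sketch for the quantization above.
inline float dequant_u8_sketch(uint8_t q, float scale) {
  return (static_cast<int>(q) - 128) * scale;
}
// usage: with amax = 2.54f, scale = 2.54f / 127 = 0.02f,
//        quantizing -1.0f gives 78 and dequant_u8_sketch(78, 0.02f) == -1.0f.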
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <>
|
||||
inline void quantize_row_int8<at::BFloat16>(uint8_t* __restrict__ Aq, float& As,
|
||||
const at::BFloat16* __restrict__ A, int64_t K, float eps) {
|
||||
|
||||
const __m512 signBit = _mm512_set1_ps(-0.0f);
|
||||
const __m512i off = _mm512_set1_epi32(128);
|
||||
|
||||
// K is a multiple of 32, so there is no remainder loop
|
||||
float amax = 0.f;
|
||||
__m512 vamax0 = _mm512_set1_ps(0.f);
|
||||
__m512 vamax1 = _mm512_set1_ps(0.f);
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
|
||||
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
|
||||
}
|
||||
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
const __m512 vd = _mm512_set1_ps(inv_scale);
|
||||
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
va0 = _mm512_mul_ps(va0, vd);
|
||||
va1 = _mm512_mul_ps(va1, vd);
|
||||
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
|
||||
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
#endif
|
||||
|
||||
// transpose utils
|
||||
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline void transpose_16x16_32bit(__m512i * v) {
|
||||
__m512i v1[16];
|
||||
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
|
||||
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
|
||||
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
|
||||
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
|
||||
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
|
||||
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
|
||||
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
|
||||
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
|
||||
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
|
||||
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
|
||||
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
|
||||
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
|
||||
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
|
||||
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
|
||||
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
|
||||
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
|
||||
|
||||
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
|
||||
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
|
||||
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
|
||||
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
|
||||
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
|
||||
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
|
||||
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
|
||||
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
|
||||
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
|
||||
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
|
||||
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
|
||||
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
|
||||
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
|
||||
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
|
||||
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
|
||||
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
|
||||
|
||||
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
|
||||
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
|
||||
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
|
||||
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
|
||||
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
|
||||
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
|
||||
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
|
||||
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
|
||||
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
|
||||
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
|
||||
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
|
||||
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
|
||||
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
|
||||
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
|
||||
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
|
||||
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
|
||||
|
||||
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
|
||||
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
|
||||
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
|
||||
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
|
||||
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
|
||||
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
|
||||
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
|
||||
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
|
||||
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
|
||||
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
|
||||
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
|
||||
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
|
||||
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
|
||||
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
|
||||
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
|
||||
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
|
||||
}
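The unpack/shuffle cascade above is the in-register equivalent of a plain 16x16 transpose of 32-bit lanes; a scalar reference that can be used to check it:

// Scalar reference: transpose a 16x16 matrix of 32-bit values in place.
inline void transpose_16x16_32bit_ref(uint32_t m[16][16]) {
  for (int i = 0; i < 16; ++i) {
    for (int j = i + 1; j < 16; ++j) {
      uint32_t tmp = m[i][j];
      m[i][j] = m[j][i];
      m[j][i] = tmp;
    }
  }
}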
|
||||
|
||||
// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes]
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
|
||||
// transpose from [2, 32] to [32, 2]
|
||||
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
|
||||
// r0: {a0, a1, ..., a31}
|
||||
// r1: {b0, b1, ..., b31}
|
||||
//
|
||||
// d0: {a0, b0, ..., a15, b15}
|
||||
// d1: {a16, b16, ..., a31, b31}
|
||||
//
|
||||
__m512i d0 = _mm512_unpacklo_epi16(r0, r1);
|
||||
__m512i d1 = _mm512_unpackhi_epi16(r0, r1);
|
||||
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
|
||||
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
|
||||
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
|
||||
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
|
||||
return std::make_tuple(d0, d1);
|
||||
}
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#endif
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
template<typename scalar_t>
|
||||
void print_array(scalar_t* ptr, int size) {
|
||||
for (int d = 0; d < size; ++d) {
|
||||
if (d % 16 == 0) { std::cout << std::endl; }
|
||||
std::cout << ptr[d] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
csrc/cpu/shm.cpp
@@ -7,9 +7,10 @@
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define MAX_THREAD_NUM 12
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024)
|
||||
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
|
||||
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
|
||||
#define MIN_THREAD_PROCESS_SIZE (256)
|
||||
#define MAX_P2P_SEND_TENSOR_NUM 8
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -32,10 +33,10 @@ struct KernelVecType<c10::Half> {
|
||||
using scalar_vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
||||
|
||||
struct ThreadSHMContext {
|
||||
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
|
||||
volatile char _curr_thread_stamp;
|
||||
volatile char _ready_thread_stamp;
|
||||
char _padding1[6];
|
||||
int thread_id;
|
||||
int thread_num;
|
||||
int rank;
|
||||
@ -44,14 +45,19 @@ struct ThreadSHMContext {
|
||||
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||
size_t _thread_buffer_mask;
|
||||
char _padding2[56];
|
||||
|
||||
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||
const int group_size, void* thread_shm_ptr)
|
||||
: thread_id(thread_id),
|
||||
: _curr_thread_stamp(1),
|
||||
_ready_thread_stamp(0),
|
||||
thread_id(thread_id),
|
||||
thread_num(thread_num),
|
||||
rank(rank),
|
||||
group_size(group_size),
|
||||
_spinning_count(0) {
|
||||
_spinning_count(0),
|
||||
_thread_buffer_mask(0) {
|
||||
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
@ -60,7 +66,6 @@ struct ThreadSHMContext {
|
||||
shm_contexts[i] = nullptr;
|
||||
thread_shm_ptrs[i] = nullptr;
|
||||
swizzled_ranks[i] = (i + rank) % group_size;
|
||||
thread_stats[i] = ThreadSHMStat::DONE;
|
||||
}
|
||||
set_context(rank, this, thread_shm_ptr);
|
||||
}
|
||||
@ -77,59 +82,66 @@ struct ThreadSHMContext {
|
||||
|
||||
template <typename T>
|
||||
T* get_thread_shm_ptr(int rank) {
|
||||
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
|
||||
return reinterpret_cast<T*>(
|
||||
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
|
||||
(PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask));
|
||||
}
|
||||
|
||||
void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; }
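next_buffer() flips between the two halves of the per-thread shared-memory region: the mask alternates between 0 and all-ones, so the offset added in get_thread_shm_ptr() is either 0 or PER_THREAD_SHM_BUFFER_OFFSET. A standalone sketch of the same ping-pong selection:

// Sketch: double-buffer offset selection, mirroring _thread_buffer_mask.
struct PingPongSketch {
  size_t mask = 0;           // 0 or ~0
  size_t half = 1 << 20;     // PER_THREAD_SHM_BUFFER_OFFSET (1 MiB of the 2 MiB buffer)
  size_t offset() const { return half & mask; }
  void flip() { mask ^= ~size_t(0); }  // toggle to the other half
};
// iteration 0 writes at offset 0, iteration 1 at offset `half`,
// iteration 2 at offset 0 again, and so on.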
|
||||
|
||||
char get_curr_stamp() const { return _curr_thread_stamp; }
|
||||
|
||||
char get_ready_stamp() const { return _ready_thread_stamp; }
|
||||
|
||||
void next_stamp() {
|
||||
_mm_mfence();
|
||||
_curr_thread_stamp += 1;
|
||||
}
|
||||
|
||||
void commit_ready_stamp() {
|
||||
_mm_mfence();
|
||||
_ready_thread_stamp = _curr_thread_stamp;
|
||||
}
|
||||
|
||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||
|
||||
void wait_for_all(ThreadSHMStat prev_stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
template <typename Cond>
|
||||
void wait_for_all(Cond&& cond) {
|
||||
for (int idx = 1; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
wait_for_one(rank, std::forward<Cond>(cond));
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void wait_for_one(int rank, ThreadSHMStat prev_stat) {
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
template <typename Cond>
|
||||
void wait_for_one(int rank, Cond&& cond) {
|
||||
ThreadSHMContext* rank_ctx = shm_contexts[rank];
|
||||
for (;;) {
|
||||
char local_curr_stamp = get_curr_stamp();
|
||||
char local_ready_stamp = get_ready_stamp();
|
||||
char rank_curr_stamp = rank_ctx->get_curr_stamp();
|
||||
char rank_ready_stamp = rank_ctx->get_ready_stamp();
|
||||
if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
|
||||
rank_ready_stamp)) {
|
||||
break;
|
||||
}
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void set_thread_stat(ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[this->rank] = stat;
|
||||
}
|
||||
static bool check_no_buffer_conflict(char local_curr_stamp,
|
||||
char local_ready_stamp,
|
||||
char rank_curr_stamp,
|
||||
char rank_ready_stamp) {
|
||||
char temp = rank_curr_stamp + 2;
|
||||
return local_curr_stamp != temp;
|
||||
}
|
||||
|
||||
void set_thread_stat(int target_rank, ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[target_rank] = stat;
|
||||
}
|
||||
}
|
||||
|
||||
// barrier for all ranks in the group, used for all2all ops
|
||||
// DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
|
||||
void barrier(ThreadSHMStat next_stat) {
|
||||
if (next_stat == ThreadSHMStat::THREAD_READY) {
|
||||
set_thread_stat(ThreadSHMStat::THREAD_READY);
|
||||
wait_for_all(ThreadSHMStat::DONE);
|
||||
} else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
|
||||
set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
|
||||
wait_for_all(ThreadSHMStat::THREAD_READY);
|
||||
} else if (next_stat == ThreadSHMStat::DONE) {
|
||||
set_thread_stat(ThreadSHMStat::DONE);
|
||||
wait_for_all(ThreadSHMStat::SHM_DATA_READY);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid next_stat to barrier.");
|
||||
}
|
||||
static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp,
|
||||
char rank_curr_stamp, char rank_ready_stamp) {
|
||||
char temp = local_curr_stamp + 1;
|
||||
return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp);
|
||||
}
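The two predicates above implement a stamp-based handshake over the double buffer: each loop iteration bumps _curr_thread_stamp, commit_ready_stamp() publishes data for the current stamp, check_stamp_ready waits until a peer's ready stamp has caught up with the local current stamp, and check_no_buffer_conflict keeps a fast producer from running two iterations (one full buffer cycle) ahead of a slow peer. A compact sketch of the same rules on plain integers (stamps are chars, so all comparisons wrap mod 256):

// Sketch of the stamp protocol; `local` is this rank, `peer` is the remote.
struct StampSketch {
  unsigned char curr = 1;   // bumped once per loop iteration
  unsigned char ready = 0;  // last stamp whose shm data is published
};
inline bool stamp_ready(const StampSketch& local, const StampSketch& peer) {
  // peer has published the data for our current iteration (or the next one)
  return peer.ready == local.curr ||
         peer.ready == static_cast<unsigned char>(local.curr + 1);
}
inline bool no_buffer_conflict(const StampSketch& local, const StampSketch& peer) {
  // never run two iterations ahead: that would reuse the half-buffer the
  // peer may still be reading
  return local.curr != static_cast<unsigned char>(peer.curr + 2);
}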
|
||||
|
||||
std::string to_string() const {
|
||||
@ -164,7 +176,7 @@ class SHMManager {
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
|
||||
_thread_num(torch::get_num_threads()),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
@ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
(total_units_num + thread_num - 1) / thread_num;
|
||||
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
|
||||
int64_t max_per_thread_iteration_elem_num =
|
||||
PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
|
||||
(PER_THREAD_SHM_BUFFER_BYTES >> 1) /
|
||||
sizeof(scalar_t); // Note: double buffer
|
||||
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
@ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
int64_t curr_elem_num =
|
||||
std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num);
|
||||
|
||||
while (curr_elem_num > 0) {
|
||||
inner_func(thread_ctx, offset, curr_elem_num);
|
||||
inner_func(thread_ctx, offset, curr_elem_num, fast_mode);
|
||||
|
||||
thread_ctx->next_stamp();
|
||||
thread_ctx->next_buffer();
|
||||
offset += max_per_thread_iteration_elem_num;
|
||||
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
}
|
||||
@ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
@ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
thread_ctx->get_swizzled_rank(idx + 1));
|
||||
});
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
|
||||
thread_data_elem_num);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
thread_ctx->commit_ready_stamp();
|
||||
int64_t aligned_data_elem_num =
|
||||
(data_elem_num / vec_elem_num) * vec_elem_num;
|
||||
int64_t i = 0;
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready);
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_data_elem_num; i += vec_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
@ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
reduced_data.save(thread_data_ptr + i,
|
||||
data_elem_num - aligned_data_elem_num);
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->commit_ready_stamp();
|
||||
if (rank == dst) {
|
||||
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
@ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
scalar_t* src_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
|
||||
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
|
||||
shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->wait_for_one(src_rank,
|
||||
ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
}
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -599,7 +614,7 @@ struct TensorListMeta {
|
||||
int8_t _padding[40];
|
||||
};
|
||||
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
|
||||
const std::vector<torch::Tensor>& tensor_list) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata;
|
||||
@ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata->total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
// Wait until the receiver set the stat to DONE
|
||||
thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t curr_shm_offset = 0;
|
||||
thread_ctx->wait_for_one(dst,
|
||||
ThreadSHMContext::check_no_buffer_conflict);
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
@ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.ptr, frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
thread_ctx->commit_ready_stamp();
|
||||
});
|
||||
}
|
||||
|
||||
@ -646,8 +659,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
torch::Tensor metadata_tensor =
|
||||
torch::empty({sizeof(TensorListMeta)}, options);
|
||||
|
||||
// Wait until the sender set the stat of the thread 0 to SHM_DATA_READY
|
||||
ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
|
||||
ctx->get_thread_shm_ptr<void>(src),
|
||||
sizeof(TensorListMeta));
|
||||
@ -664,9 +676,8 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata.total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
// Wait until the sender set the stat to SHM_DATA_READY
|
||||
thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
|
||||
@ -677,8 +688,6 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
@ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle,
|
||||
int64_t dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
|
||||
shm_send_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst,
|
||||
tensor_list);
|
||||
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
|
||||
}
|
||||
|
||||
@ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) {
|
||||
TORCH_CHECK(shm_manager);
|
||||
shm_manager->join(name);
|
||||
return shm_manager->get_shm_ctx()->to_string();
|
||||
}
|
||||
}
|
||||
|
@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle,
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
|
||||
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool is_vnni);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
at::Tensor fused_experts_cpu(
|
||||
at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2,
|
||||
at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace,
|
||||
bool use_int8_w8a8, bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale, bool is_vnni);
|
||||
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -131,16 +152,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Quantization
|
||||
#ifdef __AVX512F__
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
"Tensor? azp) -> ()");
|
||||
"Tensor? azp) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
|
||||
|
||||
// Compute int8 quantized tensor and scaling factor
|
||||
ops.def(
|
||||
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
|
||||
"Tensor!? azp) -> ()");
|
||||
"Tensor!? azp) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||
&dynamic_scaled_int8_quant);
|
||||
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
@ -148,7 +172,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
" Tensor b_scales, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
|
||||
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
@ -156,7 +181,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
" Tensor? azp, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#elif defined(__powerpc64__)
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
@ -209,6 +235,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||
&shm_recv_tensor_list);
|
||||
#endif
|
||||
|
||||
// sgl-kernels
|
||||
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
|
||||
ops.def(
|
||||
"weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
|
||||
"bias, bool is_vnni) -> Tensor");
|
||||
ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
|
||||
ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
|
||||
ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
|
||||
ops.def(
|
||||
"fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
|
||||
"topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
|
||||
"use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
|
||||
"block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
|
||||
"Tensor");
|
||||
ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||
ops.def(
|
||||
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
|
||||
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
|
||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||
&int8_scaled_mm_with_quant);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
csrc/custom_quickreduce.cu (new file)
@@ -0,0 +1,114 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include "quickreduce/quick_reduce.h"
|
||||
|
||||
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size,
|
||||
std::optional<int64_t> qr_max_size) {
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size == 6)
|
||||
throw std::invalid_argument("world size == 6 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
|
||||
fptr->init(world_size, rank, qr_max_size);
|
||||
return (quickreduce::fptr_t)fptr;
|
||||
}
|
||||
|
||||
void qr_destroy(quickreduce::fptr_t _fa) {
|
||||
if (_fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
fa->destroy();
|
||||
delete fa;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
hipIpcMemHandle_t handle = fa->get_handle();
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto data_handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
|
||||
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
|
||||
return data_handle;
|
||||
}
|
||||
|
||||
void qr_open_handles(quickreduce::fptr_t _fa,
|
||||
const std::vector<torch::Tensor>& handles) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
std::vector<hipIpcMemHandle_t> ipc_handles;
|
||||
ipc_handles.reserve(handles.size());
|
||||
for (auto& handle : handles) {
|
||||
// Ensure the tensor is on the same device as the current device.
|
||||
hipIpcMemHandle_t ipc_handle;
|
||||
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
|
||||
ipc_handles.push_back(ipc_handle);
|
||||
}
|
||||
fa->open_ipc_handles(ipc_handles);
|
||||
}
|
||||
|
||||
void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp,
|
||||
torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
|
||||
if (out.scalar_type() == at::ScalarType::Half) {
|
||||
fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(), quant_level, stream);
|
||||
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
|
||||
if (cast_bf2half) {
|
||||
fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(), quant_level, stream);
|
||||
} else {
|
||||
fa->allreduce<quickreduce::nv_bfloat16, false>(
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
|
||||
out.numel(), quant_level, stream);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"quick allreduce only supports float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
int64_t qr_max_size() {
|
||||
// The default is 2GB (2,147,483,648 bytes)
|
||||
return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, \
|
||||
cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, \
|
||||
cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
|
||||
|
||||
#endif // USE_ROCM
|
@ -45,7 +45,6 @@
|
||||
#include "cute/algorithm/functional.hpp"
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
|
||||
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
|
||||
|
@ -185,9 +185,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
@ -278,9 +276,7 @@ void causal_conv1d_update(const at::Tensor &x,
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
|
@ -647,9 +647,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
);
|
||||
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
|
@ -1255,8 +1255,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
@ -239,7 +239,7 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = output.numel() / hidden_size;
|
||||
const auto num_tokens = output.numel() / hidden_size;
|
||||
const int topk = input.size(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
|
@ -492,7 +492,7 @@ void topk_softmax(
|
||||
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const int num_tokens = gating_output.numel() / num_experts;
|
||||
const auto num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
|
csrc/ops.h
@@ -360,3 +360,14 @@
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
|
||||
std::optional<int64_t> qr_max_size = std::nullopt);
|
||||
void qr_destroy(fptr_t _fa);
|
||||
torch::Tensor qr_get_handle(fptr_t _fa);
|
||||
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
|
||||
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
int64_t quant_level, bool cast_bf2half = false);
|
||||
int64_t qr_max_size();
|
||||
#endif
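Taken together with the ops.h declarations just above, a host-side call sequence for the quick all-reduce path would look roughly like the sketch below (ROCm build assumed; the cross-rank exchange of IPC handle tensors is application-specific and only indicated by a comment, and quant_level = 0 is assumed to select the unquantized codec):

// Sketch of driving the quickreduce API declared above (USE_ROCM builds).
void quick_allreduce_example(int64_t rank, int64_t world_size,
                             torch::Tensor& inp, torch::Tensor& out,
                             std::vector<torch::Tensor>& peer_handles) {
  fptr_t fa = init_custom_qr(rank, world_size);   // default qr_max_size
  torch::Tensor my_handle = qr_get_handle(fa);    // share with peers...
  // ...gather `peer_handles` from all ranks via the caller's own comm layer, then:
  qr_open_handles(fa, peer_handles);
  qr_all_reduce(fa, inp, out, /*quant_level=*/0, /*cast_bf2half=*/false);
  qr_destroy(fa);
}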
|
@ -162,10 +162,11 @@ __global__ void dynamic_scaled_int8_quant_kernel(
|
||||
|
||||
// calculate for absmax
|
||||
float thread_max = 0.f;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
const auto v = fabsf(static_cast<float>(row_in[i]));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(
|
||||
row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
|
||||
const float v = fabsf(static_cast<float>(src));
|
||||
thread_max = fmaxf(thread_max, v);
|
||||
});
|
||||
using BlockReduce = cub::BlockReduce<float, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
float block_max = BlockReduce(tmp).Reduce(thread_max, cub::Max{}, blockDim.x);
|
||||
@ -232,9 +233,10 @@ __global__ void dynamic_scaled_int8_azp_quant_kernel(
|
||||
|
||||
// 1. calculate min & max
|
||||
MinMax thread_mm;
|
||||
for (int i = tid; i < hidden_size; i += stride) {
|
||||
thread_mm += static_cast<float>(row_in[i]);
|
||||
}
|
||||
vectorize_read_with_alignment<16>(row_in, hidden_size, tid, stride,
|
||||
[&] __device__(const scalar_t& src) {
|
||||
thread_mm += static_cast<float>(src);
|
||||
});
|
||||
|
||||
using BlockReduce = cub::BlockReduce<MinMax, 256>;
|
||||
__shared__ typename BlockReduce::TempStorage tmp;
|
||||
|
@ -51,7 +51,8 @@ struct cutlass_3x_gemm {
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
@ -144,4 +145,65 @@ struct cutlass_3x_gemm_sm100 {
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm_sm120 {
|
||||
using ElementAB = ElementAB_;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||
|
||||
using ElementD = ElementD_;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementD;
|
||||
using LayoutAux = LayoutD;
|
||||
using ElementAmax = float;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
|
@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias);

void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
                                           torch::Tensor const& a,
                                           torch::Tensor const& b,
@ -29,26 +29,12 @@ struct sm100_fp8_config_default {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M256 {
|
||||
// M in (128, 256]
|
||||
// M in (64, 256]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_2, _4, _1>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
@ -57,12 +43,26 @@ struct sm100_fp8_config_M128 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M64 {
|
||||
// M in [1, 64]
|
||||
// M in (16, 64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
using TileShape = Shape<_64, _64, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_64, _64, _128>;
|
||||
using ClusterShape = Shape<_1, _4, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm100_fp8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM16 =
|
||||
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM256 =
|
||||
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 64) {
|
||||
// m in [1, 64]
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM16>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (16, 64]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (128, 256]
|
||||
// m in (64, 256]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM256>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
|
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu (new file, 24 lines)
@ -0,0 +1,24 @@
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

namespace vllm {

void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  if (bias) {
    TORCH_CHECK(bias->dtype() == out.dtype(),
                "currently bias dtype must match output dtype ", out.dtype());
    return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
        out, a, b, a_scales, b_scales, *bias);
  } else {
    return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
  }
}

}  // namespace vllm
@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM120 (fp8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm120_fp8_config_default {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Only work with Shape<_1, _1, _1>
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm120_fp8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
@ -0,0 +1,374 @@
|
||||
#include "core/registration.h"
|
||||
|
||||
#include <torch/all.h>
|
||||
#include <cutlass/arch/arch.h>
|
||||
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
|
||||
#include "cute/tensor.hpp"
|
||||
#include "cutlass/tensor_ref.h"
|
||||
#include "cutlass/epilogue/collective/default_epilogue.hpp"
|
||||
#include "cutlass/epilogue/thread/linear_combination.h"
|
||||
#include "cutlass/gemm/dispatch_policy.hpp"
|
||||
#include "cutlass/gemm/group_array_problem_shape.hpp"
|
||||
#include "cutlass/gemm/collective/collective_builder.hpp"
|
||||
#include "cutlass/epilogue/collective/collective_builder.hpp"
|
||||
#include "cutlass/gemm/device/gemm_universal_adapter.h"
|
||||
#include "cutlass/gemm/kernel/gemm_universal.hpp"
|
||||
|
||||
#include "cutlass/util/command_line.h"
|
||||
#include "cutlass/util/distribution.h"
|
||||
#include "cutlass/util/host_tensor.h"
|
||||
#include "cutlass/util/packed_stride.hpp"
|
||||
#include "cutlass/util/tensor_view_io.h"
|
||||
#include "cutlass/util/reference/device/gemm.h"
|
||||
#include "cutlass/util/reference/device/tensor_compare.h"
|
||||
#include "cutlass/util/reference/host/tensor_fill.h"
|
||||
#include "cutlass/util/reference/host/gett.hpp"
|
||||
#include "cutlass/util/reference/host/tensor_norm.h"
|
||||
#include "cutlass/util/reference/host/tensor_compare.h"
|
||||
#include <cassert>
|
||||
|
||||
using namespace cute;
|
||||
|
||||
template <typename ElementAB, typename ElementC, typename ElementAccumulator,
|
||||
typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
__global__ void get_ggemm_starts(
|
||||
int32_t* expert_offsets, ElementAB** a_offsets, ElementAB** b_offsets,
|
||||
ElementC** out_offsets, ElementAccumulator** a_scale_offsets,
|
||||
ElementAccumulator** b_scale_offsets, ElementAB* a_base_as_int,
|
||||
ElementAB* b_base_as_int, ElementC* out_base_as_int,
|
||||
ElementAccumulator* a_scale_base_as_int,
|
||||
ElementAccumulator* b_scale_base_as_int, LayoutSFA* layout_sfa_base_as_int,
|
||||
LayoutSFB* layout_sfb_base_as_int, int* problem_sizes) {
|
||||
int expert_id = threadIdx.x;
|
||||
|
||||
if (expert_id >= gridDim.x * blockDim.x) {
|
||||
return;
|
||||
}
|
||||
|
||||
int m = problem_sizes[expert_id * 3];
|
||||
int n = problem_sizes[expert_id * 3 + 1];
|
||||
int k = problem_sizes[expert_id * 3 + 2];
|
||||
|
||||
int32_t expert_offset = expert_offsets[expert_id];
|
||||
int a_stride = expert_offset * k;
|
||||
int b_stride = expert_id * k * n;
|
||||
int a_scale_stride = expert_offset * k / 128;
|
||||
int b_scale_stride = expert_id * k * n / 128 / 128;
|
||||
|
||||
a_offsets[expert_id] = a_base_as_int + a_stride;
|
||||
b_offsets[expert_id] = b_base_as_int + b_stride;
|
||||
out_offsets[expert_id] = out_base_as_int + expert_offset * n;
|
||||
a_scale_offsets[expert_id] = a_scale_base_as_int + a_scale_stride;
|
||||
b_scale_offsets[expert_id] = b_scale_base_as_int + b_scale_stride;
|
||||
|
||||
LayoutSFA* layout_sfa_ptr = layout_sfa_base_as_int + expert_id;
|
||||
LayoutSFB* layout_sfb_ptr = layout_sfb_base_as_int + expert_id;
|
||||
|
||||
*layout_sfa_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFA(cute::make_shape(m, n, k, 1));
|
||||
*layout_sfb_ptr =
|
||||
ScaleConfig::tile_atom_to_shape_SFB(cute::make_shape(m, n, k, 1));
|
||||
}
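A quick worked check of the pointer arithmetic above; the sizes are illustrative (not taken from the diff), and the 1x128 / 128x128 scale granularity is inferred from the /128 divisions:

  // Example with k = 256, n = 512, expert_offsets = {0, 64, 96}, expert_id = 1:
  //   a_stride       = expert_offset * k            = 64 * 256       = 16384 fp8 values
  //   b_stride       = expert_id * k * n            = 1 * 256 * 512  = 131072 fp8 values
  //   out offset     = expert_offset * n            = 64 * 512       = 32768 output values
  //   a_scale_stride = expert_offset * k / 128      = 64 * 256 / 128 = 128 floats
  //   b_scale_stride = expert_id * k * n / 128 / 128                 = 8 floats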
|
||||
|
||||
#define __CALL_GET_STARTS_KERNEL(TENSOR_C_TYPE, C_TYPE, LayoutSFA, LayoutSFB, \
|
||||
ScaleConfig) \
|
||||
else if (out_tensors.dtype() == TENSOR_C_TYPE) { \
|
||||
get_ggemm_starts<cutlass::float_e4m3_t, C_TYPE, float, LayoutSFA, \
|
||||
LayoutSFB, ScaleConfig><<<1, num_experts, 0, stream>>>( \
|
||||
static_cast<int32_t*>(expert_offsets.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(a_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t**>(b_ptrs.data_ptr()), \
|
||||
static_cast<C_TYPE**>(out_ptrs.data_ptr()), \
|
||||
static_cast<float**>(a_scales_ptrs.data_ptr()), \
|
||||
static_cast<float**>(b_scales_ptrs.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(a_tensors.data_ptr()), \
|
||||
static_cast<cutlass::float_e4m3_t*>(b_tensors.data_ptr()), \
|
||||
static_cast<C_TYPE*>(out_tensors.data_ptr()), \
|
||||
static_cast<float*>(a_scales.data_ptr()), \
|
||||
static_cast<float*>(b_scales.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFA*>(layout_sfa.data_ptr()), \
|
||||
reinterpret_cast<LayoutSFB*>(layout_sfb.data_ptr()), \
|
||||
static_cast<int*>(problem_sizes.data_ptr())); \
|
||||
}
|
||||
|
||||
template <typename LayoutSFA, typename LayoutSFB, typename ScaleConfig>
|
||||
void run_get_ggemm_starts(
|
||||
torch::Tensor const& expert_offsets, torch::Tensor& a_ptrs,
|
||||
torch::Tensor& b_ptrs, torch::Tensor& out_ptrs,
|
||||
torch::Tensor& a_scales_ptrs, torch::Tensor& b_scales_ptrs,
|
||||
torch::Tensor const& a_tensors, torch::Tensor const& b_tensors,
|
||||
torch::Tensor out_tensors, torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales, torch::Tensor const& layout_sfa,
|
||||
torch::Tensor const& layout_sfb, torch::Tensor const& problem_sizes) {
|
||||
TORCH_CHECK(a_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b_tensors.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(out_tensors.size(1) % 128 == 0 or out_tensors.size(0) % 128 == 0);
|
||||
TORCH_CHECK(a_tensors.size(1) % 128 == 0 or a_tensors.size(0) % 128 == 0);
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
auto stream = at::cuda::getCurrentCUDAStream(a_tensors.device().index());
|
||||
|
||||
if (false) {
|
||||
}
|
||||
__CALL_GET_STARTS_KERNEL(torch::kBFloat16, cutlass::bfloat16_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
__CALL_GET_STARTS_KERNEL(torch::kFloat16, cutlass::half_t, LayoutSFA,
|
||||
LayoutSFB, ScaleConfig)
|
||||
else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
}
|
||||
|
||||
template <typename OutType, typename ScheduleConfig, typename LayoutD>
|
||||
void run_blockwise_scaled_group_mm(
|
||||
torch::Tensor& out_ptrs, const torch::Tensor& a_ptrs,
|
||||
const torch::Tensor& b_ptrs, const torch::Tensor& a_scales_ptrs,
|
||||
const torch::Tensor& b_scales_ptrs, const torch::Tensor& stride_a,
|
||||
const torch::Tensor& stride_b, const torch::Tensor& stride_c,
|
||||
const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
using ProblemShape = cutlass::gemm::GroupProblemShape<Shape<int, int, int>>;
|
||||
|
||||
// Types
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using ElementB = cutlass::float_e4m3_t;
|
||||
using ElementC = OutType;
|
||||
using ElementD = ElementC;
|
||||
using ElementAccumulator = float;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
using LayoutC = LayoutD;
|
||||
|
||||
// Alignments
|
||||
static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value;
|
||||
static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value;
|
||||
static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value;
|
||||
|
||||
using ArchTag = cutlass::arch::Sm100;
|
||||
using OperatorClass = cutlass::arch::OpClassTensorOp;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator,
|
||||
ElementAccumulator, void, LayoutC*, AlignmentC, ElementD, LayoutC*,
|
||||
AlignmentC, typename ScheduleConfig::EpilogueSchedule>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
ArchTag, OperatorClass, ElementA,
|
||||
cute::tuple<LayoutA*, typename ScheduleConfig::LayoutSFA*>,
|
||||
AlignmentA, ElementB,
|
||||
cute::tuple<LayoutB*, typename ScheduleConfig::LayoutSFB*>,
|
||||
AlignmentB, ElementAccumulator, typename ScheduleConfig::MmaTileShape,
|
||||
typename ScheduleConfig::ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
typename ScheduleConfig::KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel =
|
||||
cutlass::gemm::kernel::GemmUniversal<ProblemShape, CollectiveMainloop,
|
||||
CollectiveEpilogue, void>;
|
||||
|
||||
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
|
||||
using StrideA = typename Gemm::GemmKernel::InternalStrideA;
|
||||
using StrideB = typename Gemm::GemmKernel::InternalStrideB;
|
||||
using StrideC = typename Gemm::GemmKernel::InternalStrideC;
|
||||
using StrideD = typename Gemm::GemmKernel::InternalStrideD;
|
||||
|
||||
using UnderlyingProblemShape = ProblemShape::UnderlyingProblemShape;
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
Gemm gemm_op;
|
||||
|
||||
// Mainloop Arguments
|
||||
typename GemmKernel::MainloopArguments mainloop_args{
|
||||
static_cast<const ElementA**>(a_ptrs.data_ptr()),
|
||||
static_cast<StrideA*>(stride_a.data_ptr()),
|
||||
static_cast<const ElementB**>(b_ptrs.data_ptr()),
|
||||
static_cast<StrideB*>(stride_b.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(a_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFA*>(
|
||||
layout_sfa.data_ptr()),
|
||||
static_cast<const ElementAccumulator**>(b_scales_ptrs.data_ptr()),
|
||||
reinterpret_cast<typename ScheduleConfig::LayoutSFB*>(
|
||||
layout_sfb.data_ptr())};
|
||||
|
||||
cutlass::KernelHardwareInfo hw_info;
|
||||
hw_info.device_id = a_ptrs.get_device();
|
||||
hw_info.sm_count =
|
||||
cutlass::KernelHardwareInfo::query_device_multiprocessor_count(
|
||||
hw_info.device_id);
|
||||
|
||||
// Epilogue Arguments
|
||||
typename GemmKernel::EpilogueArguments epilogue_args{
|
||||
{}, // epilogue.thread
|
||||
nullptr,
|
||||
static_cast<StrideC*>(stride_c.data_ptr()),
|
||||
static_cast<ElementD**>(out_ptrs.data_ptr()),
|
||||
static_cast<StrideC*>(stride_c.data_ptr())};
|
||||
|
||||
UnderlyingProblemShape* problem_sizes_as_shapes =
|
||||
static_cast<UnderlyingProblemShape*>(problem_sizes.data_ptr());
|
||||
|
||||
// Gemm Arguments
|
||||
typename GemmKernel::Arguments args{
|
||||
cutlass::gemm::GemmUniversalMode::kGrouped,
|
||||
{num_experts, problem_sizes_as_shapes, nullptr},
|
||||
mainloop_args,
|
||||
epilogue_args,
|
||||
hw_info};
|
||||
|
||||
at::cuda::CUDAGuard device_guard{(char)a_ptrs.device().index()};
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(a_ptrs.get_device());
|
||||
|
||||
auto can_implement_status = gemm_op.can_implement(args);
|
||||
TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess,
|
||||
"Failed to implement GEMM");
|
||||
|
||||
size_t workspace_size = gemm_op.get_workspace_size(args);
|
||||
auto const workspace_options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(a_ptrs.device());
|
||||
auto workspace = torch::empty(workspace_size, workspace_options);
|
||||
|
||||
auto status = gemm_op.initialize(args, workspace.data_ptr(), stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM");
|
||||
|
||||
status = gemm_op.run(stream);
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
template <typename OutType>
|
||||
void blockwise_scaled_group_mm_dispatch_shape(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
struct MmaConfig {
|
||||
using ElementA = cutlass::float_e4m3_t;
|
||||
using KernelSchedule =
|
||||
cutlass::gemm::KernelPtrArrayTmaWarpSpecializedBlockwise1SmSm100;
|
||||
using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm;
|
||||
using ScaleConfig = cutlass::detail::Sm100BlockwiseScaleConfig<
|
||||
1, 128, 128, cute::UMMA::Major::K, cute::UMMA::Major::K>;
|
||||
using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
|
||||
using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
using MmaTileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
};
|
||||
|
||||
int num_experts = (int)expert_offsets.size(0);
|
||||
|
||||
auto a_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto out_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto a_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto b_scales_ptrs = torch::empty(
|
||||
{num_experts},
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
auto layout_sfa = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
auto layout_sfb = torch::empty(
|
||||
{num_experts, 5},
|
||||
torch::TensorOptions().dtype(torch::kInt32).device(a.device()));
|
||||
|
||||
auto stride_a = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_b = torch::full(
|
||||
{num_experts}, a.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
auto stride_c = torch::full(
|
||||
{num_experts}, output.size(1),
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device()));
|
||||
|
||||
torch::TensorOptions options_int =
|
||||
torch::TensorOptions().dtype(torch::kInt64).device(a.device());
|
||||
|
||||
run_get_ggemm_starts<typename MmaConfig::LayoutSFA,
|
||||
typename MmaConfig::LayoutSFB,
|
||||
typename MmaConfig::ScaleConfig>(
|
||||
expert_offsets, a_ptrs, b_ptrs, out_ptrs, a_scales_ptrs, b_scales_ptrs, a,
|
||||
b, output, scales_a, scales_b, layout_sfa, layout_sfb, problem_sizes);
|
||||
|
||||
run_blockwise_scaled_group_mm<OutType, MmaConfig,
|
||||
typename MmaConfig::LayoutC>(
|
||||
out_ptrs, a_ptrs, b_ptrs, a_scales_ptrs, b_scales_ptrs, stride_a,
|
||||
stride_b, stride_c, layout_sfa, layout_sfb, problem_sizes,
|
||||
expert_offsets);
|
||||
}
|
||||
|
||||
void cutlass_blockwise_scaled_grouped_mm(
|
||||
torch::Tensor& output, const torch::Tensor& a, const torch::Tensor& b,
|
||||
const torch::Tensor& scales_a, const torch::Tensor& scales_b,
|
||||
const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets) {
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(a.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"a must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(b.scalar_type() == torch::kFloat8_e4m3fn,
|
||||
"b must be kFloat8_e4m3fn");
|
||||
TORCH_CHECK(output.scalar_type() == torch::kBFloat16 ||
|
||||
output.scalar_type() == torch::kHalf,
|
||||
"output must be bfloat16 or half");
|
||||
TORCH_CHECK(scales_a.scalar_type() == torch::kFloat32,
|
||||
"scales_a must be float32");
|
||||
TORCH_CHECK(scales_b.scalar_type() == torch::kFloat32,
|
||||
"scales_b must be float32");
|
||||
TORCH_CHECK(expert_offsets.scalar_type() == torch::kInt32,
|
||||
"expert_offsets must be int32");
|
||||
|
||||
TORCH_CHECK(output.dim() == 2, "output must be 2D tensor");
|
||||
TORCH_CHECK(a.dim() == 2, "a must be 2D tensor");
|
||||
TORCH_CHECK(b.dim() == 3, "b must be 3D tensor");
|
||||
TORCH_CHECK(scales_a.dim() == 2, "scales_a must be 2D tensor");
|
||||
TORCH_CHECK(scales_b.dim() == 3, "scales_b must be 3D tensor");
|
||||
TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor");
|
||||
TORCH_CHECK(problem_sizes.size(1) == 3,
|
||||
"problem_sizes must have shape (num_experts, 3)");
|
||||
TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0),
|
||||
"Number of experts in problem_sizes must match expert_offsets");
|
||||
TORCH_CHECK(problem_sizes.dtype() == torch::kInt32,
|
||||
"problem_sizes must be int32");
|
||||
TORCH_CHECK(expert_offsets.dim() == 1, "expert_offsets must be 1D tensor");
|
||||
|
||||
#if defined(ENABLE_CUTLASS_MOE_SM100) && ENABLE_CUTLASS_MOE_SM100
|
||||
if (output.scalar_type() == torch::kBFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::bfloat16_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else if (output.scalar_type() == torch::kFloat16) {
|
||||
blockwise_scaled_group_mm_dispatch_shape<cutlass::half_t>(
|
||||
output, a, b, scales_a, scales_b, problem_sizes, expert_offsets);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Unsupported output tensor type");
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
|
||||
m.impl("cutlass_blockwise_scaled_grouped_mm",
|
||||
&cutlass_blockwise_scaled_grouped_mm);
|
||||
}
|
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu (new file, 34 lines)
@ -0,0 +1,34 @@
#include <cudaTypedefs.h>
#include "c3x/scaled_mm_kernels.hpp"

#include "cuda_utils.h"

/*
   This file defines quantized GEMM operations using the CUTLASS 3.x API, for
   NVIDIA GPUs with sm120 (Blackwell Geforce).
*/

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120

void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  int M = a.size(0), N = b.size(1), K = a.size(1);
  TORCH_CHECK(
      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
      "Currently, block scaled fp8 gemm is not implemented for Blackwell");

  // Standard per-tensor/per-token/per-channel scaling
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
              "Currently, only fp8 gemm is implemented for Blackwell");
  vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
}

#endif
@ -41,6 +41,14 @@ void cutlass_moe_mm_sm90(

#endif

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias);
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
@ -168,8 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
  at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
  int32_t version_num = get_sm_version_num();

#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
  if (version_num >= 120) {
    cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
    return;
  }
#endif

#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
  if (version_num >= 100) {
  if (version_num >= 100 && version_num < 120) {
    cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
    return;
  }
@ -241,7 +256,7 @@ void get_cutlass_moe_mm_data(
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
@ -252,7 +267,7 @@ void get_cutlass_moe_mm_data(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90");
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
@ -265,7 +280,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||
problem_sizes2, expert_num_tokens,
|
||||
num_local_experts, padded_m, n, k);
|
||||
@ -275,7 +291,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
false,
|
||||
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
|
||||
"for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90");
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
@ -335,8 +335,10 @@ void run_fp4_blockwise_scaled_group_mm(
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
#endif
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
|
@ -561,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
|
||||
auto in_dtype = input.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
@ -579,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
} else {
|
||||
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
|
||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
// We don't support e8m0 scales at this moment.
|
||||
|
@ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
B_sf.sizes()[1], ")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
|
@ -1113,8 +1113,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
@ -38,7 +38,6 @@
|
||||
#include "cute/atom/mma_atom.hpp"
|
||||
#include "cute/atom/copy_traits_sm90_tma.hpp"
|
||||
#include "cute/algorithm/gemm.hpp"
|
||||
#include "cute/tensor_predicate.hpp"
|
||||
#include "cute/numeric/arithmetic_tuple.hpp"
|
||||
#include "cutlass/pipeline/pipeline.hpp"
|
||||
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
|
||||
|
@ -27,6 +27,26 @@ __device__ inline void vectorize_with_alignment(
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT); // eg: 64 B
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
// Note: currently the output is guaranteed to be same as the input, so we
|
||||
// don't check it here, comments here just for future reference.
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
using vout_t = vec_n_t<OutT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
auto* v_out = reinterpret_cast<vout_t*>(out);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vout_t tmp;
|
||||
vec_op(tmp, v_in[i]);
|
||||
v_out[i] = tmp;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1); // addr % 64
|
||||
int alignment_bytes = WIDTH - misalignment_offset; // 64 - (addr % 64)
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1); // handle 64
|
||||
@ -72,4 +92,81 @@ __device__ __forceinline__ void vectorize_with_alignment(const InT* in,
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
struct DefaultReadVecOp {
|
||||
ScaOp scalar_op;
|
||||
|
||||
__device__ __forceinline__ void operator()(
|
||||
const vec_n_t<InT, VEC_SIZE>& src) const {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < VEC_SIZE; ++i) {
|
||||
scalar_op(src.val[i]);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// read-only version: iterate over the input with alignment guarantees
|
||||
template <int VEC_SIZE, typename InT, typename VecOp, typename ScaOp>
|
||||
__device__ inline void vectorize_read_with_alignment(const InT* in, int len,
|
||||
int tid, int stride,
|
||||
VecOp&& vec_op,
|
||||
ScaOp&& scalar_op) {
|
||||
static_assert(VEC_SIZE > 0 && (VEC_SIZE & (VEC_SIZE - 1)) == 0,
|
||||
"VEC_SIZE must be a positive power-of-two");
|
||||
constexpr int WIDTH = VEC_SIZE * sizeof(InT);
|
||||
uintptr_t addr = reinterpret_cast<uintptr_t>(in);
|
||||
|
||||
// fast path when the whole region is already aligned
|
||||
bool can_vec = ((addr & (WIDTH - 1)) == 0) && ((len & (VEC_SIZE - 1)) == 0);
|
||||
if (can_vec) {
|
||||
int num_vec = len / VEC_SIZE;
|
||||
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
int misalignment_offset = addr & (WIDTH - 1);
|
||||
int alignment_bytes = WIDTH - misalignment_offset;
|
||||
int prefix_elems = alignment_bytes & (WIDTH - 1);
|
||||
prefix_elems /= sizeof(InT);
|
||||
prefix_elems = min(prefix_elems, len);
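  // Worked example of the arithmetic above, with illustrative numbers:
  // InT = half and VEC_SIZE = 8 give WIDTH = 16 bytes. If in points at an
  // address ending in 0x6, then misalignment_offset = 6, alignment_bytes = 10,
  // and prefix_elems = 10 & 15 = 10 bytes, i.e. 5 half elements are handled
  // scalarly before the vectorized loop starts on a 16 B boundary.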
|
||||
|
||||
// 1. handle the possibly unaligned prefix with scalar access.
|
||||
for (int i = tid; i < prefix_elems; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
|
||||
in += prefix_elems;
|
||||
len -= prefix_elems;
|
||||
|
||||
int num_vec = len / VEC_SIZE;
|
||||
using vin_t = vec_n_t<InT, VEC_SIZE>;
|
||||
auto* v_in = reinterpret_cast<const vin_t*>(in);
|
||||
|
||||
// 2. vectorized traversal of the main aligned region.
|
||||
for (int i = tid; i < num_vec; i += stride) {
|
||||
vec_op(v_in[i]);
|
||||
}
|
||||
|
||||
// 3. handle remaining tail elements.
|
||||
int tail_start = num_vec * VEC_SIZE;
|
||||
for (int i = tid + tail_start; i < len; i += stride) {
|
||||
scalar_op(in[i]);
|
||||
}
|
||||
}
|
||||
|
||||
// overload that requires only a scalar_op
|
||||
template <int VEC_SIZE, typename InT, typename ScaOp>
|
||||
__device__ __forceinline__ void vectorize_read_with_alignment(
|
||||
const InT* in, int len, int tid, int stride, ScaOp&& scalar_op) {
|
||||
using Vec = DefaultReadVecOp<VEC_SIZE, InT, std::decay_t<ScaOp>>;
|
||||
vectorize_read_with_alignment<VEC_SIZE>(in, len, tid, stride, Vec{scalar_op},
|
||||
std::forward<ScaOp>(scalar_op));
|
||||
}
|
||||
|
||||
} // namespace vllm
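As a usage illustration, a sketch under assumptions: the kernel name, buffer names, and the 256-thread block size are invented here, and the translation unit is assumed to include cub. It shows how the new read-only helper replaces a hand-rolled strided loop:

// Hedged sketch: per-row absolute sum using the read-only helper above.
template <typename scalar_t>
__global__ void row_abs_sum_kernel(const scalar_t* __restrict__ in,
                                   float* __restrict__ out,
                                   const int hidden_size) {
  const scalar_t* row_in = in + blockIdx.x * static_cast<int64_t>(hidden_size);
  const int tid = threadIdx.x;
  const int stride = blockDim.x;

  float thread_sum = 0.f;
  vllm::vectorize_read_with_alignment<16>(
      row_in, hidden_size, tid, stride, [&] __device__(const scalar_t& src) {
        thread_sum += fabsf(static_cast<float>(src));
      });

  using BlockReduce = cub::BlockReduce<float, 256>;  // assumes blockDim.x == 256
  __shared__ typename BlockReduce::TempStorage tmp;
  const float block_sum = BlockReduce(tmp).Sum(thread_sum);
  if (tid == 0) out[blockIdx.x] = block_sum;
}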
|
||||
|
csrc/quickreduce/base.h (new file, 338 lines)
@ -0,0 +1,338 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
#define __quickreduce_device_inline__ __device__ __forceinline__
|
||||
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
|
||||
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||
|
||||
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
|
||||
// as per architecture.
|
||||
#if defined(__gfx942__)
|
||||
// CDNA3: Scope bits sc0, sc1
|
||||
#define MUBUF_ACQUIRE 16
|
||||
#define MUBUF_RELEASE 16
|
||||
#elif (defined(__gfx908__) || defined(__gfx90a__))
|
||||
// CDNA1 and CDNA2 - glc bit
|
||||
#define MUBUF_ACQUIRE 1
|
||||
#define MUBUF_RELEASE 0
|
||||
#endif
|
||||
|
||||
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
|
||||
|
||||
// Number of atoms (4xf16x2_t) processed by a single thread
|
||||
static constexpr int kAtoms = 8;
|
||||
|
||||
// We use a workgroup of 256 threads
|
||||
static constexpr int kBlockSize = 256;
|
||||
static constexpr int kAtomStride = kBlockSize;
|
||||
|
||||
// Size and atom stride of source/destination data that the block will
|
||||
// process.
|
||||
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
|
||||
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
|
||||
|
||||
// Max number of blocks. 304 CUs on MI300
|
||||
static constexpr int kMaxNumBlocks = 304 * 4;
|
||||
|
||||
// Standard CDNA wavefront size.
|
||||
static constexpr int kWavefront = 64;
|
||||
|
||||
// 256 thread, 4 wavefronts.
|
||||
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
|
||||
|
||||
// Number of threads in a group for quantization
|
||||
// It corresponds to 32 F16 elements in quantization block
|
||||
static constexpr int kThreadGroupSize = 8;
|
||||
|
||||
// Methods
|
||||
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x,
|
||||
unsigned long y) {
|
||||
return ((x + y - 1) / y);
|
||||
}
|
||||
|
||||
union BufferResource {
|
||||
__quickreduce_device_inline__ constexpr BufferResource()
|
||||
: config(0x00020000U) {}
|
||||
|
||||
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address,
|
||||
uint32_t buffer_size)
|
||||
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
|
||||
|
||||
int32x4_t descriptor;
|
||||
struct {
|
||||
void* address; // 8B, out of which first 48b is address, and 16b is stride
|
||||
// (unused)
|
||||
uint32_t range; // Byte range for the buffer resource
|
||||
uint32_t config; // Constant, DFMT=32b
|
||||
};
|
||||
};
|
||||
|
||||
__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
|
||||
int32x4_t srsrc, int32_t voffset, int32_t soffset,
|
||||
int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void buffer_store_dwordx4(
|
||||
int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset,
|
||||
int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
|
||||
#if defined(__gfx942__)
|
||||
if (value) {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
|
||||
} else {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
union bf162_int_union {
|
||||
int i;
|
||||
nv_bfloat162 bf2;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A,
|
||||
int32x4_t* B);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A,
|
||||
int32x4_t* B) {
|
||||
int32x4_t& tR_fragment = A[0];
|
||||
int32x4_t& tA_fragment = B[0];
|
||||
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[0])
|
||||
: "v"(tR_fragment[0]), "v"(tA_fragment[0]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[1])
|
||||
: "v"(tR_fragment[1]), "v"(tA_fragment[1]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[2])
|
||||
: "v"(tR_fragment[2]), "v"(tA_fragment[2]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[3])
|
||||
: "v"(tR_fragment[3]), "v"(tA_fragment[3]));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(
|
||||
int32x4_t* A, int32x4_t* B) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tA[i] = __hadd2(tA[i], tB[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmax2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_min(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmin2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
|
||||
half2 wmaxh2 = __builtin_bit_cast(half2, a);
|
||||
half2 wminh2 = __builtin_bit_cast(half2, b);
|
||||
half2 wblockmaxh2;
|
||||
|
||||
wblockmaxh2.x =
|
||||
__hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
|
||||
wblockmaxh2.y =
|
||||
__hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
|
||||
return __builtin_bit_cast(int, wblockmaxh2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
|
||||
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_add(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hadd2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_sub(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
|
||||
int result;
|
||||
|
||||
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2 %3"
|
||||
: "=v"(result)
|
||||
: "v"(kNegOne), "v"(b), "v"(a));
|
||||
return result;
|
||||
}
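// In other words, packed_sub<half>(a, b) computes fma(-1, b, a) = a - b per
// f16 lane; the "min + max" wording above presumably refers to how callers use
// it (range = max - min), not to the operation itself.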
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hsub2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_mul(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
|
||||
nv_bfloat162 tR = __hmul2(*tA, *tB);
|
||||
return *(reinterpret_cast<int*>(&tR));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_rcp(int a);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
|
||||
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
|
||||
bf162_int_union A, R;
|
||||
A.i = a;
|
||||
R.bf2 = h2rcp(A.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
// changes dtype
|
||||
__quickreduce_device_inline__ float T2float_cast(half a) {
|
||||
return __half2float(a);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
|
||||
return __bfloat162float(a);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
|
||||
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
|
||||
|
||||
int wmax, wmin, wblockmax;
|
||||
int a, b;
|
||||
a = packed_max<T>(atom[0], atom[1]);
|
||||
b = packed_max<T>(atom[2], atom[3]);
|
||||
|
||||
wmax = packed_max<T>(a, b);
|
||||
|
||||
a = packed_min<T>(atom[0], atom[1]);
|
||||
b = packed_min<T>(atom[2], atom[3]);
|
||||
|
||||
wmin = packed_min<T>(a, b);
|
||||
|
||||
// Reduce the max among a group of threads
|
||||
// Note: This is basically 2 blocks of values setup as the
|
||||
// upper/lower halves of the f16x2_t
|
||||
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
|
||||
int x = __shfl_down(wmax, i);
|
||||
wmax = packed_max<T>(wmax, x);
|
||||
|
||||
int y = __shfl_down(wmin, i);
|
||||
wmin = packed_min<T>(wmin, y);
|
||||
}
|
||||
wblockmax = packed_abs_max<T>(wmax, wmin);
|
||||
// Share with the cohort
|
||||
wblockmax = __shfl(wblockmax, group_leader);
|
||||
return wblockmax;
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr,
|
||||
uint32_t flag) {
|
||||
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr,
|
||||
uint32_t flag) {
|
||||
while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace quickreduce
|
csrc/quickreduce/quick_reduce.h (new file, 196 lines)
@ -0,0 +1,196 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "quick_reduce_impl.cuh"
|
||||
|
||||
#define HIP_CHECK(err) \
|
||||
do { \
|
||||
hipError_t err_ = (err); \
|
||||
if (err_ != hipSuccess) { \
|
||||
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \
|
||||
hipGetErrorString(err_)); \
|
||||
throw std::runtime_error("HIP error"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace quickreduce {
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
template <typename AllReduceKernel, typename T>
|
||||
__global__ __quickreduce_launch_bounds_two_shot__ static void
|
||||
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
|
||||
int rank, uint8_t** dbuffer_list,
|
||||
uint32_t data_offset, uint32_t flag_color) {
|
||||
int block = blockIdx.x;
|
||||
int grid = gridDim.x;
|
||||
|
||||
while (block < num_blocks) {
|
||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
|
||||
flag_color);
|
||||
block += grid;
|
||||
flag_color++;
|
||||
}
|
||||
}
|
||||
|
||||
#define TWOSHOT_DISPATCH(__codec) \
|
||||
if (world_size == 2) { \
|
||||
using LineCodec = __codec<T, 2>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 4) { \
|
||||
using LineCodec = __codec<T, 4>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 8) { \
|
||||
using LineCodec = __codec<T, 8>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
}
|
||||
|
||||
enum QuickReduceQuantLevel {
|
||||
F16 = 0,
|
||||
INT8 = 1,
|
||||
INT6 = 2,
|
||||
INT4 = 3,
|
||||
};
|
||||
|
||||
struct DeviceComms {
|
||||
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
|
||||
int64_t kMaxProblemSize =
|
||||
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
|
||||
// Max TP-8
|
||||
static int constexpr kMaxWorldSize = 8;
|
||||
|
||||
bool initialized = false;
|
||||
uint32_t flag_color = 1;
|
||||
int world_size;
|
||||
int rank;
|
||||
|
||||
uint8_t* dbuffer;
|
||||
uint8_t** dbuffer_list;
|
||||
hipIpcMemHandle_t buffer_ipc_handle;
|
||||
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
|
||||
std::vector<uint8_t*> buffer_list;
|
||||
uint32_t data_offset;
|
||||
|
||||
DeviceComms() : initialized(false), world_size(1), rank(0) {}
|
||||
~DeviceComms() { destroy(); }
|
||||
|
||||
void init(int world_size, int rank,
|
||||
std::optional<int64_t> max_problem_size = std::nullopt) {
|
||||
destroy();
|
||||
this->world_size = world_size;
|
||||
this->rank = rank;
|
||||
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
|
||||
this->kMaxProblemSize = max_problem_size.value();
|
||||
}
|
||||
// Allocate buffer size for worst case: F16 2-stage buffer.
|
||||
uint32_t flags_buffer_size =
|
||||
2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
|
||||
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
|
||||
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
|
||||
data_offset = flags_buffer_size;
|
||||
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
|
||||
hipDeviceMallocUncached));
|
||||
|
||||
// Clear the flags buffer.
|
||||
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
|
||||
|
||||
// Device-side list of IPC buffers.
|
||||
buffer_list.resize(world_size);
|
||||
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
|
||||
|
||||
// Create IPC handles for rank's communication buffer.
|
||||
all_buffer_ipc_handles.resize(world_size);
|
||||
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
int get_world_size() { return world_size; }
|
||||
int get_rank() { return rank; }
|
||||
bool status() { return initialized; }
|
||||
hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; }
|
||||
|
||||
void destroy() {
|
||||
if (initialized) {
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipFree(dbuffer));
|
||||
HIP_CHECK(hipFree(dbuffer_list));
|
||||
|
||||
initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
|
||||
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
all_buffer_ipc_handles[i] = ipc_handles[i];
|
||||
}
|
||||
|
||||
// Open device memory access to the IPC communication buffers.
|
||||
// Note: For our own rank, we do not need to open a handle.
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i],
|
||||
all_buffer_ipc_handles[i],
|
||||
hipIpcMemLazyEnablePeerAccess));
|
||||
} else {
|
||||
buffer_list[i] = dbuffer;
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(),
|
||||
world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
  template <typename T, bool cast_bf2half>
  void allreduce(T const* A, T* B, uint32_t N, int quant_level,
                 hipStream_t stream) {
    if (world_size != 2 && world_size != 4 && world_size != 8) {
      throw std::runtime_error("All Reduce not supported for world_size = " +
                               std::to_string(world_size));
    }

    // Configuration.
    uint32_t msg_size = N * sizeof(T);
    uint32_t num_blocks = divceil(msg_size, kTileSize);
    uint32_t grid = min(kMaxNumBlocks, num_blocks);
    auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
    switch (quant_level_) {
      case QuickReduceQuantLevel::INT8:
        TWOSHOT_DISPATCH(CodecQ8)
        break;
      case QuickReduceQuantLevel::INT6:
        TWOSHOT_DISPATCH(CodecQ6)
        break;
      case QuickReduceQuantLevel::INT4:
        TWOSHOT_DISPATCH(CodecQ4)
        break;
      default:
        TWOSHOT_DISPATCH(CodecFP)
        break;
    }
    HIP_CHECK(cudaGetLastError());
    // Rotate the flag color.
    flag_color += divceil(N, grid);
  }
};
|
||||
|
||||
} // namespace quickreduce
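A hedged host-side sketch of the intended DeviceComms call order (illustrative only; exchange_handles is a hypothetical stand-in for the handle exchange that presumably happens on the Python side, e.g. via torch.distributed, before qr_open_handles is invoked):

#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <vector>

// Hypothetical stand-in: in a real deployment each rank's IPC handle must be
// gathered from every process; here we replicate our own handle so the
// sketch compiles as a single-process example.
static std::vector<hipIpcMemHandle_t> exchange_handles(hipIpcMemHandle_t mine,
                                                       int world_size) {
  return std::vector<hipIpcMemHandle_t>(world_size, mine);
}

void example_quickreduce_allreduce(const half* d_in, half* d_out, uint32_t n,
                                   int world_size, int rank,
                                   hipStream_t stream) {
  quickreduce::DeviceComms comms;
  comms.init(world_size, rank);  // allocates the uncached IPC buffer
  auto handles = exchange_handles(comms.get_handle(), world_size);
  comms.open_ipc_handles(handles);  // maps peer buffers on this device
  // INT8 quantization; 0 (F16) is lossless, 3 (INT4) trades accuracy for
  // bandwidth.
  comms.allreduce<half, /*cast_bf2half=*/false>(
      d_in, d_out, n, quickreduce::QuickReduceQuantLevel::INT8, stream);
}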
csrc/quickreduce/quick_reduce_impl.cuh (new file, 698 lines)
@ -0,0 +1,698 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "base.h"
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
struct CodecBase {
|
||||
const int thread;
|
||||
const int rank;
|
||||
const int group_leader;
|
||||
__quickreduce_device_inline__ CodecBase(int thread, int rank)
|
||||
: thread(thread),
|
||||
rank(rank),
|
||||
group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
|
||||
set_fp16_ovfl(true);
|
||||
}
|
||||
};
|
||||
|
||||
// Default full precision codec.
|
||||
template <typename T, int world_size>
|
||||
struct CodecFP : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
|
||||
// Codec tile size processed by this workgroup.
|
||||
// Each thread processes atoms of f16x8_t (16B).
|
||||
static constexpr int kRankTransmittedTileSize =
|
||||
kBlockSize * kRankAtoms * sizeof(int32x4_t);
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
__quickreduce_device_inline__ CodecFP(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
const int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
__builtin_nontemporal_store(data[i], send_buffer + thread);
|
||||
send_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
|
||||
*recv_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int4 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ4 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size processed by this workgroup.
// Each thread processes a fragment of fp16x8_t (16B),
// into an int4x8_t (4B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1152;
|
||||
static constexpr int kRankTileScaleOffset = 1024;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride =
|
||||
kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/8.0h, -1/8.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor =
|
||||
std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon =
|
||||
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-8, -8}, f16x2_t
|
||||
static constexpr int kRangeMin =
|
||||
std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
|
||||
|
||||
// {+7, +7}, f16x2_t
|
||||
static constexpr int kRangeMax =
|
||||
std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
|
||||
|
||||
// {+8, +8}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00080008;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ4(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q4 into int32_t
|
||||
int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr =
|
||||
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
int32_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q4 into f16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static constexpr uint kMask000F = 0x000F000F;
|
||||
static constexpr uint kHalf2_1024 =
|
||||
0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1032 =
|
||||
0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q4, kHalf2_1032);
|
||||
} else {
|
||||
int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
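A plain-C++ reference of the CodecQ4 nibble layout (a sketch for clarity, not code from the diff): after the +8 bias each of the eight codes fits in 4 bits; the four lower-lane codes land in the low 16 bits of the packed word and the four upper-lane codes in the high 16 bits, mirroring qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12) above.

#include <cstdint>

// Pack eight biased 4-bit codes, given as four {lo, hi} f16x2-lane pairs.
uint32_t pack_q4(const uint8_t lo[4], const uint8_t hi[4]) {
  uint32_t qw = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t lane = (uint32_t)(lo[i] & 0xF) | ((uint32_t)(hi[i] & 0xF) << 16);
    qw |= lane << (4 * i);  // lo codes fill bits 0..15, hi codes bits 16..31
  }
  return qw;
}

// Inverse, mirroring the kernel's `(qw >> (i * 4)) & 0x000F000F` unpack.
void unpack_q4(uint32_t qw, uint8_t lo[4], uint8_t hi[4]) {
  for (int i = 0; i < 4; ++i) {
    uint32_t lane = (qw >> (4 * i)) & 0x000F000Fu;
    lo[i] = (uint8_t)(lane & 0xF);
    hi[i] = (uint8_t)((lane >> 16) & 0xF);
  }
}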
|
||||
// Int6 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ6 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size processed by this workgroup.
// Each thread processes a fragment of fp16x8_t (16B),
// into an int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1664;
|
||||
static constexpr int kRankTileQ2Offset = 1024;
|
||||
static constexpr int kRankTileScaleOffset = 1536;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride =
|
||||
kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/32.0h, -1/32.0h}, fp16x2_t
|
||||
static constexpr int kScaleFactor =
|
||||
std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
|
||||
|
||||
// {1e-7, 1e-7}, fp16x2_t
|
||||
static constexpr int kScaleEpsilon =
|
||||
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-32, -32}, fp16x2_t
|
||||
static constexpr int kRangeMin =
|
||||
std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
|
||||
|
||||
// {+31, +31}, fp16x2_t
|
||||
static constexpr int kRangeMax =
|
||||
std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
|
||||
|
||||
// {+32, +32}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00200020;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ6(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q6 into int32_t + int16_t
|
||||
uint32_t q4w;
|
||||
uint16_t q2w = 0;
|
||||
q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) |
|
||||
((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
|
||||
{
|
||||
int16_t* tw = reinterpret_cast<int16_t*>(&q);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q2w |= (tw[i] >> 4) << (i * 2);
|
||||
}
|
||||
}
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr =
|
||||
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr =
|
||||
reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(q4w, q4w_ptr);
|
||||
__builtin_nontemporal_store(q2w, q2w_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr =
|
||||
reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
|
||||
uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q6 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask000F = 0x000F000F;
|
||||
static uint constexpr kHalf2_1024 =
|
||||
0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1056 =
|
||||
0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int32_t q4 = q4w & kMask000F;
|
||||
int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
|
||||
q4w >>= 4;
|
||||
q2w >>= 4;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(w[i])
|
||||
: "v"(q6), "v"(kHalf2_1056));
|
||||
} else {
|
||||
int32_t int16_2 = q4 | (q2 << 4);
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
// That's pretty much it...
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
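The 6-bit codec splits every code across two words. A short reference of that split (a sketch only, using the same {lo, hi} lane convention as the Q4 sketch above): the low nibble of each code is packed into q4w exactly like CodecQ4, while the top two bits go into q2w at two bits per code, which the kernel's recv reassembles via `(q2w & 0x3) | ((q2w & 0xC) << 14)`.

#include <cstdint>

// Pack eight biased 6-bit codes (0..63) into the q4w/q2w pair used by CodecQ6.
void pack_q6(const uint8_t lo[4], const uint8_t hi[4], uint32_t& q4w,
             uint16_t& q2w) {
  q4w = 0;
  q2w = 0;
  for (int i = 0; i < 4; ++i) {
    // Low nibbles: identical layout to CodecQ4.
    uint32_t lane = (uint32_t)(lo[i] & 0xF) | ((uint32_t)(hi[i] & 0xF) << 16);
    q4w |= lane << (4 * i);
    // Top two bits: two bits per code, in lo/hi order per lane.
    q2w |= (uint16_t)((lo[i] >> 4) << (4 * i));
    q2w |= (uint16_t)((hi[i] >> 4) << (4 * i + 2));
  }
}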
|
||||
|
||||
// Int8 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ8 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size processed by this workgroup.
// Each thread processes a fragment of f16x8_t (16B),
// into an int8x8_t (8B) and a f16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 2176;
|
||||
static constexpr int kRankTileScaleOffset = 2048;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride =
|
||||
kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/128.0h, -1/128.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor =
|
||||
std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon =
|
||||
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-128, -128}, f16x2_t
|
||||
static constexpr int kRangeMin =
|
||||
std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
|
||||
// {+127, +127}, f16x2_t
|
||||
static constexpr int kRangeMax =
|
||||
std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
|
||||
|
||||
// {+128, +128}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00800080;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ8(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
int32x4_t const* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q8 into int32x2_t
|
||||
int32x2_t qw;
|
||||
qw[0] = q[0] | (q[1] << 8);
|
||||
qw[1] = q[2] | (q[3] << 8);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr =
|
||||
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q8 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask00FF = 0x00FF00FF;
|
||||
|
||||
// {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1024 = 0x64006400;
|
||||
|
||||
// {-1152.0, -1152.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1152 = 0xE480E480;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q8 =
|
||||
((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q8, kHalf2_1152);
|
||||
} else {
|
||||
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Twoshot All Reduce
|
||||
template <typename T, class Codec, bool cast_bf2half>
|
||||
struct AllReduceTwoshot {
|
||||
static_assert(sizeof(T) == 2);
|
||||
|
||||
static constexpr int kWorldSize = Codec::kWorldSize;
|
||||
|
||||
__device__ static void run(
|
||||
T const* __restrict__ input, T* __restrict__ output,
|
||||
uint32_t const N, // number of elements
|
||||
int const block, // block index
|
||||
int const rank, // rank index
|
||||
uint8_t** __restrict__ buffer_list, // communication buffers
|
||||
uint32_t const data_offset, // offset to start of the data buffer
|
||||
uint32_t flag_color) {
|
||||
// Topology
|
||||
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
||||
uint8_t* rank_buffer = buffer_list[rank];
|
||||
Codec codec(thread, rank);
|
||||
int block_id = blockIdx.x;
|
||||
int grid_size = gridDim.x;
|
||||
// --------------------------------------------------------
|
||||
// Read input into registers
|
||||
int32x4_t tA[kAtoms];
|
||||
|
||||
BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
|
||||
uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
|
||||
src_offset += kAtomStride * sizeof(int32x4_t);
|
||||
if constexpr (cast_bf2half) {
|
||||
const nv_bfloat162* bf_buf =
|
||||
reinterpret_cast<const nv_bfloat162*>(&tA[i]);
|
||||
half2 half_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __bfloat1622float2(bf_buf[j]);
|
||||
half_buf[j] = __float22half2_rn(f);
|
||||
}
|
||||
tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Phase-1A: Write segment data into the communication buffer of the target
|
||||
// rank responsible for this segment.
|
||||
uint32_t comm_data0_offset =
|
||||
data_offset + block_id * Codec::kTransmittedTileSize;
|
||||
uint32_t comm_data1_offset =
|
||||
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
|
||||
|
||||
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
||||
uint32_t comm_flags1_offset =
|
||||
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset +
|
||||
rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
|
||||
buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
// --------------------------------------------------------
|
||||
// Phase-1B: Reduce the segment data from the communication buffers.
|
||||
int32x4_t tR[Codec::kRankAtoms] = {};
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer =
|
||||
reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
|
||||
uint32_t* flag_ptr =
|
||||
reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// note: we reuse tA as temp buffer here
|
||||
codec.recv(&recv_buffer, tA);
|
||||
|
||||
for (int i = 0; i < Codec::kRankAtoms; i++) {
|
||||
packed_assign_add<T>(&tR[i], &tA[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase-2: Write the reduced segment to every other rank
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset +
|
||||
rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, tR);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
|
||||
buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
|
||||
// Phase-2: Read the gather segments from the rank's communication buffer.
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer =
|
||||
reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
|
||||
uint32_t* flag_ptr =
|
||||
reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Gather all reduced and final rank segments into tA.
|
||||
codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Write the result to output.
|
||||
BufferResource dst_buffer(output, N * sizeof(T));
|
||||
uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
if constexpr (cast_bf2half) {
|
||||
const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
|
||||
nv_bfloat162 bf16_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __half22float2(half_buf[j]);
|
||||
bf16_buf[j] = __float22bfloat162_rn(f);
|
||||
}
|
||||
buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf),
|
||||
dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
} else {
|
||||
buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
}
|
||||
dst_offset += kAtomStride * sizeof(int32x4_t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
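To summarize the control flow of AllReduceTwoshot above, here is a tiny single-process model of the two-shot schedule in plain C++ (no quantization, flags, or IPC; purely illustrative): phase 1 reduce-scatters so rank r owns the sum of segment r, and phase 2 all-gathers the reduced segments so every rank ends with the full result.

#include <cstdio>
#include <vector>

// Single-process model of a two-shot all-reduce over float buffers.
// ranks[r] holds rank r's input; on return every entry holds the sum.
void twoshot_allreduce_model(std::vector<std::vector<float>>& ranks) {
  const int world_size = static_cast<int>(ranks.size());
  const size_t n = ranks[0].size();
  const size_t seg = n / world_size;  // assume n divisible by world_size

  // Phase 1 (reduce-scatter): rank r owns segment r and sums the
  // contributions "received" from every peer.
  std::vector<std::vector<float>> reduced(world_size,
                                          std::vector<float>(seg, 0.f));
  for (int r = 0; r < world_size; ++r)
    for (int peer = 0; peer < world_size; ++peer)
      for (size_t i = 0; i < seg; ++i)
        reduced[r][i] += ranks[peer][r * seg + i];

  // Phase 2 (all-gather): every rank copies back each owner's reduced
  // segment, so all ranks end with the identical full result.
  for (int r = 0; r < world_size; ++r)
    for (int owner = 0; owner < world_size; ++owner)
      for (size_t i = 0; i < seg; ++i)
        ranks[r][owner * seg + i] = reduced[owner][i];
}

int main() {
  std::vector<std::vector<float>> ranks = {{1, 2, 3, 4}, {10, 20, 30, 40}};
  twoshot_allreduce_model(ranks);
  // Both ranks now hold {11, 22, 33, 44}.
  printf("%g %g %g %g\n", ranks[0][0], ranks[0][1], ranks[0][2], ranks[0][3]);
}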
|
@ -1598,7 +1598,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane2id = laneid % 2;
|
||||
const int lane4id = laneid % 4;
|
||||
const int lane16id = laneid % 16;
|
||||
const int rowid = laneid / 16;
|
||||
|
||||
@ -1745,7 +1744,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
|
||||
const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;
|
||||
|
||||
@ -2368,7 +2366,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane2id = laneid % 2;
|
||||
const int lane4id = laneid % 4;
|
||||
const int lane16id = laneid % 16;
|
||||
const int rowid = laneid / 16;
|
||||
|
||||
@ -2514,7 +2511,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
|
||||
const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;
|
||||
|
||||
|
@ -59,6 +59,8 @@ void apply_repetition_penalties_(
|
||||
int vocab_size = logits.size(-1);
|
||||
int num_seqs = logits.size(0);
|
||||
|
||||
if (num_seqs == 0) return;
|
||||
|
||||
// Get number of SMs on the current device
|
||||
int sms = 0;
|
||||
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount,
|
||||
|
@ -79,7 +79,8 @@ struct cutlass_sparse_3x_gemm {
|
||||
// These are the minimum alignments needed for the kernels to compile
|
||||
static constexpr int AlignmentAB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
static constexpr int AlignmentCD = 4;
|
||||
static constexpr int AlignmentCD =
|
||||
128 / cutlass::sizeof_bits<ElementD>::value;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
|
@ -393,6 +393,14 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_fp4_mm", torch::kCUDA, &cutlass_scaled_fp4_mm);
|
||||
|
||||
// cutlass blockwise scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_blockwise_scaled_grouped_mm(Tensor! output, Tensor a, Tensor b, "
|
||||
"Tensor scales_a, Tensor scales_b, "
|
||||
"Tensor problem_sizes, Tensor expert_offsets) -> ()",
|
||||
{stride_tag});
|
||||
// conditionally compiled so impl registration is in source file
|
||||
|
||||
// cutlass nvfp4 block scaled group GEMM
|
||||
ops.def(
|
||||
"cutlass_fp4_group_mm(Tensor! out, Tensor a, Tensor b,"
|
||||
@ -725,6 +733,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
|
||||
|
||||
custom_ar.def("free_shared_buffer", &free_shared_buffer);
|
||||
#ifdef USE_ROCM
|
||||
// Quick Reduce all-reduce kernels
|
||||
custom_ar.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
custom_ar.def("init_custom_qr", &init_custom_qr);
|
||||
custom_ar.def("qr_destroy", &qr_destroy);
|
||||
|
||||
custom_ar.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
// Max input size in bytes
|
||||
custom_ar.def("qr_max_size", &qr_max_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
@ -1,3 +1,4 @@
|
||||
|
||||
# The vLLM Dockerfile is used to construct vLLM image that can be directly used
|
||||
# to run the OpenAI compatible server.
|
||||
|
||||
@ -6,30 +7,110 @@
|
||||
# docs/assets/contributing/dockerfile-stages-dependency.png
|
||||
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
# By parameterizing the base images, we allow third parties to use their own
# base images. One use case is hermetic builds with base images stored in
# private registries that use different repository naming conventions.
|
||||
#
|
||||
# Example:
|
||||
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||
|
||||
# By parameterizing the Deadsnakes repository URL, we allow third parties to use
|
||||
# their own mirror. When doing so, we don't benefit from the transparent
|
||||
# installation of the GPG key of the PPA, as done by add-apt-repository, so we
|
||||
# also need a URL for the GPG key.
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
|
||||
# The PyPA get-pip.py script is a self-contained script+zip file that provides
|
||||
# both the installer script and the pip base85-encoded zip archive. This allows
|
||||
# bootstrapping pip in environments where a distribution package does not exist.
|
||||
#
|
||||
# By parameterizing the URL for get-pip.py installation script, we allow
|
||||
# third-party to use their own copy of the script stored in a private mirror.
|
||||
# We set the default value to the PyPA owned get-pip.py script.
|
||||
#
|
||||
# Reference: https://pip.pypa.io/en/stable/installation/#get-pip-py
|
||||
ARG GET_PIP_URL="https://bootstrap.pypa.io/get-pip.py"
|
||||
|
||||
# PIP supports fetching the packages from custom indexes, allowing third-party
|
||||
# to host the packages in private mirrors. The PIP_INDEX_URL and
|
||||
# PIP_EXTRA_INDEX_URL are standard PIP environment variables to override the
|
||||
# default indexes. By leaving them empty by default, PIP will use its default
|
||||
# indexes if the build process doesn't override the indexes.
|
||||
#
|
||||
# Uv uses different variables. We set them by default to the same values as
|
||||
# PIP, but they can be overridden.
|
||||
ARG PIP_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL
|
||||
ARG UV_INDEX_URL=${PIP_INDEX_URL}
|
||||
ARG UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
|
||||
# PyTorch provides its own indexes for standard and nightly builds
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
|
||||
|
||||
# PIP supports multiple authentication schemes, including keyring
|
||||
# By parameterizing the PIP_KEYRING_PROVIDER variable and setting it to
|
||||
# disabled by default, we allow third parties to use keyring authentication for
|
||||
# their private Python indexes, while not changing the default behavior which
|
||||
# is no authentication.
|
||||
#
|
||||
# Reference: https://pip.pypa.io/en/stable/topics/authentication/#keyring-support
|
||||
ARG PIP_KEYRING_PROVIDER=disabled
|
||||
ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
|
||||
|
||||
# Flag that enables building KV-connector dependency libs into the docker images
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done \
|
||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
||||
fi ; \
|
||||
else \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done ; \
|
||||
fi \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
@ -63,21 +144,25 @@ WORKDIR /workspace
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
"torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
--pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
fi
|
||||
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/cuda.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
# explicitly set the list to avoid issues with torch 2.2
|
||||
# see https://github.com/pytorch/pytorch/pull/123243
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# Override the arch list for flash-attn to reduce the binary size
|
||||
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
|
||||
@ -88,6 +173,10 @@ ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
|
||||
FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
|
||||
@ -98,7 +187,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
@ -113,6 +202,8 @@ ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz
|
||||
ARG SCCACHE_ENDPOINT
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
@ -121,10 +212,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
echo "Installing sccache..." \
|
||||
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
|
||||
&& curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \
|
||||
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
|
||||
@ -162,6 +254,10 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
#################### DEV IMAGE ####################
|
||||
FROM base as dev
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
@ -176,21 +272,26 @@ COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/dev.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||
|
||||
@ -200,17 +301,33 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
|
||||
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done \
|
||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
||||
fi ; \
|
||||
else \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done ; \
|
||||
fi \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
@ -232,19 +349,23 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
"torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319" ; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
--pre pytorch_triton==3.3.0+gitab727c40 ; \
|
||||
fi
|
||||
|
||||
# Install vllm wheel first, so that torch etc will be installed.
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system dist/*.whl --verbose \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# If we need to build FlashInfer wheel before its release:
|
||||
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
|
||||
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a'
|
||||
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
|
||||
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
|
||||
# $ cd flashinfer
|
||||
# $ git checkout v0.2.6.post1
|
||||
@ -254,23 +375,49 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
# -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
|
||||
# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
|
||||
uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl; \
|
||||
else \
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a' && \
|
||||
git clone https://github.com/flashinfer-ai/flashinfer.git --single-branch --branch v0.2.6.post1 --recursive && \
|
||||
# Needed to build AOT kernels
|
||||
(cd flashinfer && \
|
||||
python3 -m flashinfer.aot && \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
) && \
|
||||
rm -rf flashinfer; \
|
||||
fi \
|
||||
fi
|
||||
# Allow specifying a version, Git revision or local .whl file
|
||||
ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
|
||||
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then
|
||||
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
|
||||
else
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
|
||||
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
|
||||
# Needed to build AOT kernels
|
||||
(cd flashinfer && \
|
||||
python3 -m flashinfer.aot && \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
)
|
||||
rm -rf flashinfer
|
||||
|
||||
# Default arches (skipping 10.0a and 12.0 since these need 12.8)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
fi
|
||||
echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch v0.2.6.post1 \
|
||||
https://github.com/flashinfer-ai/flashinfer.git flashinfer
|
||||
|
||||
pushd flashinfer
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation .
|
||||
popd
|
||||
|
||||
rm -rf flashinfer
|
||||
fi \
|
||||
fi
|
||||
BASH
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
@ -286,7 +433,7 @@ uv pip list
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
@ -297,6 +444,11 @@ FROM vllm-base AS test
|
||||
|
||||
ADD . /vllm-workspace/
|
||||
|
||||
ARG PYTHON_VERSION
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
@ -307,7 +459,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
|
||||
if [ "$CUDA_MAJOR" -ge 12 ]; then \
|
||||
uv pip install --system -r requirements/dev.txt; \
|
||||
@ -323,7 +475,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
|
||||
# Copy in the v1 package for testing (it isn't distributed yet)
|
||||
COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
|
||||
COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
|
||||
|
||||
# doc requires source code
|
||||
# we hide them inside `test_docs/` , so that this source code
|
||||
@ -339,17 +491,26 @@ RUN mv mkdocs.yaml test_docs/
|
||||
# base openai image with additional requirements, for any subsequent openai-style images
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
ARG TARGETPLATFORM
|
||||
ARG INSTALL_KV_CONNECTORS=false
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
|
||||
COPY requirements/kv_connectors.txt requirements/kv_connectors.txt
|
||||
|
||||
# install additional dependencies for openai api server
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$INSTALL_KV_CONNECTORS" = "true" ]; then \
|
||||
uv pip install --system -r requirements/kv_connectors.txt; \
|
||||
fi; \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
else \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.3' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.46.1' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
@@ -8,6 +8,8 @@
# Build arguments:
# PYTHON_VERSION=3.12 (default)|3.11|3.10|3.9
# VLLM_CPU_DISABLE_AVX512=false (default)|true
# VLLM_CPU_AVX512BF16=false (default)|true
# VLLM_CPU_AVX512VNNI=false (default)|true
#

######################### BASE IMAGE #########################
@@ -25,7 +27,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
    --mount=type=cache,target=/var/lib/apt,sharing=locked \
    apt-get update -y \
    && apt-get install -y --no-install-recommends ccache git curl wget ca-certificates \
        gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 \
        gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof \
    && update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
    && curl -LsSf https://astral.sh/uv/install.sh | sh

@@ -60,13 +62,19 @@ FROM base AS vllm-build

ARG GIT_REPO_CHECK=0
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
ARG VLLM_CPU_DISABLE_AVX512
ARG VLLM_CPU_DISABLE_AVX512=0
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
# Support for building with AVX512BF16 ISA: docker build --build-arg VLLM_CPU_AVX512BF16="true" ...
ARG VLLM_CPU_AVX512BF16=0
ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
ARG VLLM_CPU_AVX512VNNI=0
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}

WORKDIR /workspace/vllm

RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
    --mount=type=bind,src=requirements/cpu-build.txt,target=requirements/build.txt \
    uv pip install -r requirements/build.txt

COPY . .
@@ -79,6 +87,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=.git,target=.git \
    VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel

######################### TEST DEPS #########################
FROM base AS vllm-test-deps

WORKDIR /workspace/vllm

RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
    cp requirements/test.in requirements/cpu-test.in && \
    sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
    sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
    sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
    sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
    uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu

RUN --mount=type=cache,target=/root/.cache/uv \
    uv pip install -r requirements/cpu-test.txt

######################### DEV IMAGE #########################
FROM vllm-build AS vllm-dev

@@ -97,28 +121,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,source=.git,target=.git \
    VLLM_TARGET_DEVICE=cpu python3 setup.py develop

COPY --from=vllm-test-deps /workspace/vllm/requirements/cpu-test.txt requirements/test.txt

RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
    cp requirements/test.in requirements/test-cpu.in && \
    sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
    uv pip compile requirements/test-cpu.in -o requirements/test.txt && \
    uv pip install -r requirements/dev.txt && \
    pre-commit install --hook-type pre-commit --hook-type commit-msg

ENTRYPOINT ["bash"]

######################### TEST IMAGE #########################
FROM base AS vllm-test
FROM vllm-test-deps AS vllm-test

WORKDIR /workspace/

RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
    cp requirements/test.in requirements/test-cpu.in && \
    sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
    uv pip compile requirements/test-cpu.in -o requirements/cpu-test.txt && \
    uv pip install -r requirements/cpu-test.txt

RUN --mount=type=cache,target=/root/.cache/uv \
    --mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
    uv pip install dist/*.whl
@@ -127,6 +142,7 @@ ADD ./tests/ ./tests/
ADD ./examples/ ./examples/
ADD ./benchmarks/ ./benchmarks/
ADD ./vllm/collect_env.py .
ADD ./.buildkite/ ./.buildkite/

# install development dependencies (for testing)
RUN --mount=type=cache,target=/root/.cache/uv \
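The CPU Dockerfile hunks above thread three ISA build arguments through matching `ARG`/`ENV` pairs, so the values chosen at build time are visible to the later `setup.py` steps. A sketch of a build that opts into the BF16 and VNNI paths — the file path, target, and tag are assumptions:

```bash
# Hypothetical CPU build enabling the optional AVX512 extensions added above.
docker build \
    -f docker/Dockerfile.cpu \
    --build-arg VLLM_CPU_AVX512BF16=true \
    --build-arg VLLM_CPU_AVX512VNNI=true \
    --target vllm-test \
    -t vllm-cpu:test .
```

Leaving the arguments at their defaults of `0` keeps the previous behavior, and `VLLM_CPU_DISABLE_AVX512=true` still disables AVX512 entirely.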
@@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
ARG FA_BRANCH="1a7f4dfa"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="c1debd8"
ARG AITER_BRANCH="6487649"
ARG AITER_REPO="https://github.com/ROCm/aiter.git"

FROM ${BASE_IMAGE} AS base
@@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
    if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi

ENV VLLM_TARGET_DEVICE=xpu
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn

RUN --mount=type=cache,target=/root/.cache/pip \
    --mount=type=bind,source=.git,target=.git \
@@ -39,6 +39,7 @@ nav:
- models/generative_models.md
- models/pooling_models.md
- models/extensions
- Hardware Supported Models: models/hardware_supported_models
- Features:
- features/compatibility_matrix.md
- features/*
@@ -48,7 +49,12 @@ nav:
- General:
- glob: contributing/*
flatten_single_child_sections: true
- Model Implementation: contributing/model
- Model Implementation:
- contributing/model/README.md
- contributing/model/basic.md
- contributing/model/registration.md
- contributing/model/tests.md
- contributing/model/multimodal.md
- Design Documents:
- V0: design
- V1: design/v1
@@ -1,7 +1,8 @@
# Welcome to vLLM

<figure markdown="span">
{ align="center" alt="vLLM" class="no-scaled-link" width="60%" }
{ align="center" alt="vLLM Light" class="logo-light" width="60%" }
{ align="center" alt="vLLM Dark" class="logo-dark" width="60%" }
</figure>

<p style="text-align:center">
@@ -40,7 +41,7 @@ vLLM is flexible and easy to use with:
- OpenAI-compatible API server
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
- Prefix caching support
- Multi-lora support
- Multi-LoRA support

For more information, check out the following:

@@ -91,7 +91,7 @@ source to unblock the update process.
### FlashInfer
Here is how to build and install it from source with torch2.7.0+cu128 in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271):

```
```bash
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'
export FLASHINFER_ENABLE_SM90=1
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
@@ -105,14 +105,14 @@ team if you want to get the package published there.
### xFormers
Similar to FlashInfer, here is how to build and install xFormers from source:

```
```bash
export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
MAX_JOBS=16 uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
```

### Mamba

```
```bash
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
```

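After installing FlashInfer, xFormers, and Mamba from source as shown, a quick sanity check is to confirm which versions actually landed in the environment; this is a generic check rather than a step from the docs above:

```bash
# List the source-built packages that ended up installed (PyPI names may differ slightly).
uv pip list | grep -Ei 'flashinfer|xformers|mamba'
```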
@@ -16,35 +16,33 @@ vllm {chat,complete,serve,bench,collect-env,run-batch}

Start the vLLM OpenAI Compatible API server.

Examples:
??? Examples

```bash
# Start with a model
vllm serve meta-llama/Llama-2-7b-hf
```bash
# Start with a model
vllm serve meta-llama/Llama-2-7b-hf

# Specify the port
vllm serve meta-llama/Llama-2-7b-hf --port 8100
# Specify the port
vllm serve meta-llama/Llama-2-7b-hf --port 8100

# Check with --help for more options
# To list all groups
vllm serve --help=listgroup
# Check with --help for more options
# To list all groups
vllm serve --help=listgroup

# To view an argument group
vllm serve --help=ModelConfig
# To view an argument group
vllm serve --help=ModelConfig

# To view a single argument
vllm serve --help=max-num-seqs
# To view a single argument
vllm serve --help=max-num-seqs

# To search by keyword
vllm serve --help=max
```
# To search by keyword
vllm serve --help=max
```

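Once `vllm serve` is running, it exposes the OpenAI-compatible HTTP API, so a minimal smoke test is a plain `curl`; the model name echoes the example above and the default port of 8000 is assumed:

```bash
# Query the completions endpoint of a locally running `vllm serve` instance.
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "meta-llama/Llama-2-7b-hf", "prompt": "The future of AI is", "max_tokens": 32}'
```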
## chat

Generate chat completions via the running API server.

Examples:

```bash
# Directly connect to localhost API without arguments
vllm chat
@@ -60,8 +58,6 @@ vllm chat --quick "hi"

Generate text completions based on the given prompt via the running API server.

Examples:

```bash
# Directly connect to localhost API without arguments
vllm complete
@@ -73,6 +69,8 @@ vllm complete --url http://{vllm-serve-host}:{vllm-serve-port}/v1
vllm complete --quick "The future of AI is"
```

</details>

## bench

Run benchmark tests for latency, online serving throughput, and offline inference throughput.
@@ -89,8 +87,6 @@ vllm bench {latency, serve, throughput}

Benchmark the latency of a single batch of requests.

Example:

```bash
vllm bench latency \
--model meta-llama/Llama-3.2-1B-Instruct \
@@ -104,8 +100,6 @@ vllm bench latency \

Benchmark the online serving throughput.

Example:

```bash
vllm bench serve \
--model meta-llama/Llama-3.2-1B-Instruct \
@@ -120,8 +114,6 @@ vllm bench serve \

Benchmark offline inference throughput.

Example:

```bash
vllm bench throughput \
--model meta-llama/Llama-3.2-1B-Instruct \
@@ -143,7 +135,8 @@ vllm collect-env

Run batch prompts and write results to file.

Examples:
<details>
<summary>Examples</summary>

```bash
# Running with a local file
@@ -159,6 +152,8 @@ vllm run-batch \
--model meta-llama/Meta-Llama-3-8B-Instruct
```

</details>

## More Help

For detailed options of any subcommand, use:
docs/community/contact_us.md (new file, 6 lines)
@@ -0,0 +1,6 @@
---
title: Contact Us
---
[](){ #contactus }

--8<-- "README.md:contact-us"
@@ -57,19 +57,21 @@ By default, we optimize model inference using CUDA graphs which take up extra me

You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:

```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel
??? Code

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        # By default, it goes up to max_num_seqs
        cudagraph_capture_sizes=[1, 2, 4, 8, 16],
    ),
)
```
```python
from vllm import LLM
from vllm.config import CompilationConfig, CompilationLevel

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    compilation_config=CompilationConfig(
        level=CompilationLevel.PIECEWISE,
        # By default, it goes up to max_num_seqs
        cudagraph_capture_sizes=[1, 2, 4, 8, 16],
    ),
)
```

You can disable graph capturing completely via the `enforce_eager` flag:

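The hunk stops at the colon above, so the accompanying snippet is not shown here; a minimal sketch of the same `enforce_eager` switch on the CLI side, assuming the flag name is unchanged in your version:

```bash
# Disable CUDA graph capture entirely; saves capture-time memory at some speed cost.
vllm serve meta-llama/Llama-3.1-8B-Instruct --enforce-eager
```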
@@ -127,18 +129,20 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.

Here are some examples:

```python
from vllm import LLM
??? Code

# Available for Qwen2-VL series models
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
    mm_processor_kwargs={
        "max_pixels": 768 * 768, # Default is 1280 * 28 * 28
    })
```python
from vllm import LLM

# Available for InternVL series models
llm = LLM(model="OpenGVLab/InternVL2-2B",
    mm_processor_kwargs={
        "max_dynamic_patch": 4, # Default is 12
    })
```
# Available for Qwen2-VL series models
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
    mm_processor_kwargs={
        "max_pixels": 768 * 768, # Default is 1280 * 28 * 28
    })

# Available for InternVL series models
llm = LLM(model="OpenGVLab/InternVL2-2B",
    mm_processor_kwargs={
        "max_dynamic_patch": 4, # Default is 12
    })
```
@@ -6,7 +6,7 @@ title: Engine Arguments
Engine arguments control the behavior of the vLLM engine.

- For [offline inference][offline-inference], they are part of the arguments to [LLM][vllm.LLM] class.
- For [online serving][openai-compatible-server], they are part of the arguments to `vllm serve`.
- For [online serving][serving-openai-compatible-server], they are part of the arguments to `vllm serve`.

You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments.

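To make the offline/online symmetry concrete: the same engine arguments that are keyword arguments on `LLM(...)` become flags on `vllm serve`. The two flags below are common engine arguments and are assumed to still exist under these names:

```bash
# Online serving: engine arguments are passed as CLI flags.
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --max-model-len 8192 \
    --gpu-memory-utilization 0.80
```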
@@ -7,6 +7,8 @@ vLLM uses the following environment variables to configure the system:

All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables).

```python
--8<-- "vllm/envs.py:env-vars-definition"
```
??? Code

```python
--8<-- "vllm/envs.py:env-vars-definition"
```
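Because every option in that snippet is read from a `VLLM_`-prefixed variable, configuration is a plain environment export; `VLLM_LOGGING_LEVEL` is used here only as a familiar example and may differ between versions:

```bash
# Set a VLLM_-prefixed variable for a single server run (variable name assumed).
VLLM_LOGGING_LEVEL=DEBUG vllm serve meta-llama/Llama-3.1-8B-Instruct
```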
@@ -29,6 +29,8 @@ See <gh-file:LICENSE>.
Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation.
Check out the [building from source][build-from-source] documentation for details.

For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations.

### Building the docs with MkDocs

#### Introduction to MkDocs
@@ -93,25 +95,27 @@ For additional features and advanced configurations, refer to the official [MkDo

## Testing

```bash
pip install -r requirements/dev.txt
??? note "Commands"

# Linting, formatting and static type checking
pre-commit install --hook-type pre-commit --hook-type commit-msg
```bash
pip install -r requirements/dev.txt

# You can manually run pre-commit with
pre-commit run --all-files
# Linting, formatting and static type checking
pre-commit install --hook-type pre-commit --hook-type commit-msg

# To manually run something from CI that does not run
# locally by default, you can run:
pre-commit run mypy-3.9 --hook-stage manual --all-files
# You can manually run pre-commit with
pre-commit run --all-files

# Unit tests
pytest tests/
# To manually run something from CI that does not run
# locally by default, you can run:
pre-commit run mypy-3.9 --hook-stage manual --all-files

# Run tests for a single test file with detailed output
pytest -s -v tests/test_logger.py
```
# Unit tests
pytest tests/

# Run tests for a single test file with detailed output
pytest -s -v tests/test_logger.py
```

!!! tip
Since the <gh-file:docker/Dockerfile> ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12.
@@ -147,6 +151,14 @@ the terms of the DCO.

Using `-s` with `git commit` will automatically add this header.

!!! tip
You can enable automatic sign-off via your IDE:

- **PyCharm**: Click on the `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window.
It will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`.
- **VSCode**: Open the [Settings editor](https://code.visualstudio.com/docs/configure/settings)
and enable the `Git: Always Sign Off` (`git.alwaysSignOff`) field.

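Following the DCO note above, the sign-off trailer can also be added straight from the command line; the commit message below is only a placeholder:

```bash
# -s appends a "Signed-off-by: Your Name <you@example.com>" trailer to the commit message.
git commit -s -m "[Doc] Fix typo in the contributing guide"
```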
### PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed
@@ -186,6 +198,7 @@ The PR needs to meet the following code quality standards:

### Adding or Changing Kernels

When actively developing or modifying kernels, using the [Incremental Compilation Workflow](./incremental_build.md) is highly recommended for faster build times.
Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.

- Make sure custom ops are registered following PyTorch guidelines:
@@ -37,14 +37,14 @@ multiple Y releases:
- **Timeline**: A removal version is explicitly stated in the deprecation
warning (e.g., "This will be removed in v0.10.0").
- **Communication**: Deprecation is noted in the following, as applicable:
- Help strings
- Log output
- API responses
- `/metrics` output (for metrics features)
- User-facing documentation
- Release notes
- GitHub Issue (RFC) for feedback
- Documentation and use of the `@typing_extensions.deprecated` decorator for Python APIs
- Help strings
- Log output
- API responses
- `/metrics` output (for metrics features)
- User-facing documentation
- Release notes
- GitHub Issue (RFC) for feedback
- Documentation and use of the `@typing_extensions.deprecated` decorator for Python APIs

**2. Deprecated (Off By Default)**

Some files were not shown because too many files have changed in this diff.