mirror of
https://github.com/vllm-project/vllm.git
synced 2025-10-28 04:14:34 +08:00
Compare commits
259 Commits
fix-precom
...
add-utils
| Author | SHA1 | Date | |
|---|---|---|---|
| 7d092fc32c | |||
| 1a6c27f271 | |||
| 3c6fd286b4 | |||
| 536fd33003 | |||
| 619b9f5c7e | |||
| d1b689c445 | |||
| 9854dc9040 | |||
| ff5c60fad8 | |||
| 6f1229f91d | |||
| 1819fbda63 | |||
| 7f0367109e | |||
| fb14d53cf6 | |||
| b024a42e93 | |||
| cb97f2bfc5 | |||
| 359200f6ac | |||
| 220aee902a | |||
| 67d25eca05 | |||
| 363528de27 | |||
| 4ff61ababa | |||
| 0ec3779df7 | |||
| b616f6a53d | |||
| 2e25bb12a8 | |||
| 9965c47d0d | |||
| 059d4cdb49 | |||
| bdb84e26b0 | |||
| 3dd359147d | |||
| 657f2f301a | |||
| a1aafc827a | |||
| 139508a418 | |||
| d265414dbc | |||
| 48fb076cbc | |||
| c1909e7e8c | |||
| b95877509b | |||
| 706ff13224 | |||
| ccbfb1d1c9 | |||
| 9e5552aa13 | |||
| 0c600b9ab6 | |||
| e303dcf523 | |||
| ae9c4d416f | |||
| d853520b3e | |||
| ba51aea65e | |||
| 8452946c06 | |||
| 2e7cbf2d7d | |||
| 7da296be04 | |||
| b205e8467d | |||
| be0cfb2b68 | |||
| 1a03dd496b | |||
| 27b8017636 | |||
| 9ec1e3065a | |||
| 9dae7d46bf | |||
| 7058d7dd5d | |||
| a0389e0554 | |||
| 3be8d312a2 | |||
| 3abfe22154 | |||
| e81fbefe8a | |||
| 9290de5667 | |||
| 7f280d69c9 | |||
| 02cabff207 | |||
| 3d19d47d91 | |||
| 8acb4badee | |||
| 314af8617c | |||
| 0e96cc9b7e | |||
| ecad851cbd | |||
| ed70f3c64f | |||
| 650d5dbd04 | |||
| 9025a9a705 | |||
| c05596f1a3 | |||
| 787b13389e | |||
| 96453cfa83 | |||
| b1c1fe35a5 | |||
| 08d81f1014 | |||
| 6cc1e7d96d | |||
| 9909726d2a | |||
| 22e9d42040 | |||
| 86debab54c | |||
| be250bbc67 | |||
| 27949354fa | |||
| bd5038af07 | |||
| a2f14dc8f9 | |||
| 92ee7baaf9 | |||
| 7151f92241 | |||
| e28533a16f | |||
| 6d42ce8315 | |||
| ded1fb635b | |||
| 97d9524fe9 | |||
| d8cf819a9a | |||
| 551ef1631a | |||
| 2863befce3 | |||
| 2965c99c86 | |||
| 2062c0723d | |||
| 1c50e100a9 | |||
| 3ee56e26be | |||
| 8fe7fc8634 | |||
| e936e401de | |||
| f5dfa07531 | |||
| 022c58b80f | |||
| 19108ef311 | |||
| 5a52f389dd | |||
| 65b1cbb138 | |||
| 6c9837a761 | |||
| 6f2f53a82d | |||
| 7b1895e6ce | |||
| 4d36693687 | |||
| daec9dea6e | |||
| daceac57c7 | |||
| 8615d9776f | |||
| 7b460c25f9 | |||
| f719772281 | |||
| d45417b804 | |||
| a29e62ea34 | |||
| e53be6f00a | |||
| c329ceca6d | |||
| 3c545c0c3b | |||
| e8c3bd2cd1 | |||
| c6c983053d | |||
| aafabaa0d5 | |||
| 94a55c7681 | |||
| aa0dc77ef5 | |||
| 4ab3ac285e | |||
| d1c956dc0f | |||
| dec197e3e5 | |||
| 6e244ae091 | |||
| cd4cfee689 | |||
| e110930680 | |||
| 8b64c895c0 | |||
| 0740e29b66 | |||
| 44d2e6af63 | |||
| 2d7779f888 | |||
| a57d57fa72 | |||
| 71799fd005 | |||
| e9fd658a73 | |||
| 07b8fae219 | |||
| 562308816c | |||
| 04e1642e32 | |||
| b69781f107 | |||
| 0bceac9810 | |||
| 34878a0b48 | |||
| 6393b03986 | |||
| 0907d507bf | |||
| c894c5dc1f | |||
| 1f5d178e9c | |||
| 27c065df50 | |||
| 84c260caeb | |||
| 167aca45cb | |||
| 0567c8249f | |||
| d188913d99 | |||
| 1d7c29f5fe | |||
| 65397e40f5 | |||
| 9502c38138 | |||
| 2582683566 | |||
| 754b00edb3 | |||
| 296ce95d8e | |||
| 2d7620c3eb | |||
| 55c65ab495 | |||
| 2cc2069970 | |||
| 9f0608fc16 | |||
| 4e0db57fff | |||
| c40692bf9a | |||
| 4734704b30 | |||
| 8b8c209e35 | |||
| 23a04e0895 | |||
| 02c97d9a92 | |||
| e795d723ed | |||
| 8359f4c8d8 | |||
| bf5181583f | |||
| c53fec1fcb | |||
| 0f9e7354f5 | |||
| ba7ba35cda | |||
| 015fab8c2f | |||
| f59fc60fb3 | |||
| 879f69bed3 | |||
| 7108934142 | |||
| 3443aaf8dd | |||
| 2273ec322c | |||
| a6c4b87fbc | |||
| 1afa9948f5 | |||
| 0d06b533a0 | |||
| c01d1c5aba | |||
| ead369845d | |||
| c6e3bba8e6 | |||
| 91f7d9d0b6 | |||
| 8619e7158c | |||
| c635c5f744 | |||
| a045b7e89a | |||
| 981eeca41a | |||
| 26d34eb67e | |||
| 53da4cd397 | |||
| 9a3b88328f | |||
| 3014c920da | |||
| 0eed516951 | |||
| ee5ad8d2c5 | |||
| a738dbb2a1 | |||
| 33d5e29be9 | |||
| 4671ac6e2a | |||
| dd2ccf8dde | |||
| a3bc76e4b5 | |||
| e6327c9b3e | |||
| d0132f025d | |||
| 61f4fc5dc6 | |||
| 68aaeb3749 | |||
| c3649e4fee | |||
| 53243e5c42 | |||
| a6e6604d32 | |||
| b82e0f82cb | |||
| 5111642a6f | |||
| 1bcd15edc7 | |||
| 2ebff5b77c | |||
| f17aec0d63 | |||
| 493c275352 | |||
| f39ab2d4bd | |||
| 4a0f7888a3 | |||
| c4cf260677 | |||
| 33d51f599e | |||
| e91386cde1 | |||
| 2c11a29f0b | |||
| c76a506bd6 | |||
| ec0db6f51c | |||
| c305a2109d | |||
| 202c5df935 | |||
| 2bb246b8f7 | |||
| 4c409cabc2 | |||
| 3b1e4c6a23 | |||
| 2c5302fadd | |||
| caa680fd2e | |||
| c3bf9bad11 | |||
| 6f170f11dd | |||
| 8ca81bb069 | |||
| e773a9e1c2 | |||
| 71baf85ae1 | |||
| 79f2f1c2a1 | |||
| 2e3e3c86dc | |||
| 7e8977fcd4 | |||
| f1e840e842 | |||
| 7771d1de88 | |||
| 71d1219545 | |||
| e384f2f108 | |||
| 089a306f19 | |||
| 5e666f72cd | |||
| e3a3e4db46 | |||
| e41bf15cd0 | |||
| 5aa4a015ce | |||
| b6bad3d186 | |||
| ee9a1531aa | |||
| 10d82f9ac5 | |||
| ea10dd9d9e | |||
| ead2110297 | |||
| 01220ce89a | |||
| 6f68c49220 | |||
| 4719460644 | |||
| 466166dcfd | |||
| 1d0ae26c85 | |||
| 6021999573 | |||
| c7b370c603 | |||
| aa20d10a91 | |||
| 2de12be428 | |||
| 83ca9ae47b | |||
| e2148dc5ea | |||
| b1098b4072 | |||
| 799397ee4f |
@ -11,7 +11,7 @@ See [vLLM performance dashboard](https://perf.vllm.ai) for the latest performanc
|
||||
|
||||
## Performance benchmark quick overview
|
||||
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!), with different models.
|
||||
**Benchmarking Coverage**: latency, throughput and fix-qps serving on A100 (the support for FP8 benchmark on H100 is coming!) and Intel® Xeon® Processors, with different models.
|
||||
|
||||
**Benchmarking Duration**: about 1hr.
|
||||
|
||||
@ -31,13 +31,27 @@ Performance benchmark will be triggered when:
|
||||
- A PR being merged into vllm.
|
||||
- Every commit for those PRs with `perf-benchmarks` label AND `ready` label.
|
||||
|
||||
Manually Trigger the benchmark
|
||||
|
||||
```bash
|
||||
bash .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh
|
||||
```
|
||||
|
||||
Runtime environment variables:
|
||||
- `ON_CPU`: set the value to '1' on Intel® Xeon® Processors. Default value is 0.
|
||||
- `SERVING_JSON`: JSON file to use for the serving tests. Default value is empty string (use default file).
|
||||
- `LATENCY_JSON`: JSON file to use for the latency tests. Default value is empty string (use default file).
|
||||
- `THROUGHPUT_JSON`: JSON file to use for the throughout tests. Default value is empty string (use default file).
|
||||
- `REMOTE_HOST`: IP for the remote vLLM service to benchmark. Default value is empty string.
|
||||
- `REMOTE_PORT`: Port for the remote vLLM service to benchmark. Default value is empty string.
|
||||
|
||||
Nightly benchmark will be triggered when:
|
||||
- Every commit for those PRs with `perf-benchmarks` label and `nightly-benchmarks` label.
|
||||
|
||||
## Performance benchmark details
|
||||
|
||||
See [performance-benchmarks-descriptions.md](performance-benchmarks-descriptions.md) for detailed descriptions, and use `tests/latency-tests.json`, `tests/throughput-tests.json`, `tests/serving-tests.json` to configure the test cases.
|
||||
|
||||
> NOTE: For Intel® Xeon® Processors, use `tests/latency-tests-cpu.json`, `tests/throughput-tests-cpu.json`, `tests/serving-tests-cpu.json` instead.
|
||||
### Latency test
|
||||
|
||||
Here is an example of one test inside `latency-tests.json`:
|
||||
@ -119,6 +133,30 @@ If you do not see the table, please wait till the benchmark finish running.
|
||||
The json version of the table (together with the json version of the benchmark) will be also attached to the markdown file.
|
||||
The raw benchmarking results (in the format of json files) are in the `Artifacts` tab of the benchmarking.
|
||||
|
||||
The `compare-json-results.py` helps to compare benchmark results JSON files converted using `convert-results-json-to-markdown.py`.
|
||||
When run, benchmark script generates results under `benchmark/results` folder, along with the `benchmark_results.md` and `benchmark_results.json`.
|
||||
`compare-json-results.py` compares two `benchmark_results.json` files and provides performance ratio e.g. for Output Tput, Median TTFT and Median TPOT.
|
||||
|
||||
Here is an example using the script to compare result_a and result_b without detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json --ignore_test_name`
|
||||
|
||||
| | results_a/benchmark_results.json | results_b/benchmark_results.json | perf_ratio |
|
||||
|----|----------------------------------------|----------------------------------------|----------|
|
||||
| 0 | 142.633982 | 156.526018 | 1.097396 |
|
||||
| 1 | 241.620334 | 294.018783 | 1.216863 |
|
||||
| 2 | 218.298905 | 262.664916 | 1.203235 |
|
||||
| 3 | 242.743860 | 299.816190 | 1.235113 |
|
||||
|
||||
Here is an example using the script to compare result_a and result_b with detail test name.
|
||||
`python3 compare-json-results.py -f results_a/benchmark_results.json -f results_b/benchmark_results.json`
|
||||
| | results_a/benchmark_results.json_name | results_a/benchmark_results.json | results_b/benchmark_results.json_name | results_b/benchmark_results.json | perf_ratio |
|
||||
|---|---------------------------------------------|----------------------------------------|---------------------------------------------|----------------------------------------|----------|
|
||||
| 0 | serving_llama8B_tp1_sharegpt_qps_1 | 142.633982 | serving_llama8B_tp1_sharegpt_qps_1 | 156.526018 | 1.097396 |
|
||||
| 1 | serving_llama8B_tp1_sharegpt_qps_16 | 241.620334 | serving_llama8B_tp1_sharegpt_qps_16 | 294.018783 | 1.216863 |
|
||||
| 2 | serving_llama8B_tp1_sharegpt_qps_4 | 218.298905 | serving_llama8B_tp1_sharegpt_qps_4 | 262.664916 | 1.203235 |
|
||||
| 3 | serving_llama8B_tp1_sharegpt_qps_inf | 242.743860 | serving_llama8B_tp1_sharegpt_qps_inf | 299.816190 | 1.235113 |
|
||||
| 4 | serving_llama8B_tp2_random_1024_128_qps_1 | 96.613390 | serving_llama8B_tp4_random_1024_128_qps_1 | 108.404853 | 1.122048 |
|
||||
|
||||
## Nightly test details
|
||||
|
||||
See [nightly-descriptions.md](nightly-descriptions.md) for the detailed description on test workload, models and docker containers of benchmarking other llm engines.
|
||||
|
||||
@ -16,7 +16,7 @@ Please download the visualization scripts in the post
|
||||
- Download `nightly-benchmarks.zip`.
|
||||
- In the same folder, run the following code:
|
||||
|
||||
```console
|
||||
```bash
|
||||
export HF_TOKEN=<your HF token>
|
||||
apt update
|
||||
apt install -y git
|
||||
|
||||
@ -4,7 +4,8 @@
|
||||
- Input length: 32 tokens.
|
||||
- Output length: 128 tokens.
|
||||
- Batch size: fixed (8).
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: end-to-end latency (mean, median, p99).
|
||||
|
||||
{latency_tests_markdown_table}
|
||||
@ -14,7 +15,8 @@
|
||||
- Input length: randomly sample 200 prompts from ShareGPT dataset (with fixed random seed).
|
||||
- Output length: the corresponding output length of these 200 prompts.
|
||||
- Batch size: dynamically determined by vllm to achieve maximum throughput.
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: throughput.
|
||||
|
||||
{throughput_tests_markdown_table}
|
||||
@ -25,12 +27,18 @@
|
||||
- Output length: the corresponding output length of these 200 prompts.
|
||||
- Batch size: dynamically determined by vllm and the arrival pattern of the requests.
|
||||
- **Average QPS (query per second)**: 1, 4, 16 and inf. QPS = inf means all requests come at once. For other QPS values, the arrival time of each query is determined using a random Poisson process (with fixed random seed).
|
||||
- Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- We also added a speculative decoding test for llama-3 70B, under QPS 2
|
||||
- GPU Models: llama-3.1 8B, llama-3 70B, mixtral 8x7B.
|
||||
- We also added a speculative decoding test for llama-3 70B on GPU, under QPS 2
|
||||
- CPU Models: llama-3.1 8B.
|
||||
- Evaluation metrics: throughput, TTFT (time to the first token, with mean, median and p99), ITL (inter-token latency, with mean, median and p99).
|
||||
- For CPU, we added random dataset tests to benchmark fixed input/output length with 100 prompts.
|
||||
|
||||
{serving_tests_markdown_table}
|
||||
|
||||
## Platform Information
|
||||
|
||||
{platform_markdown_table}
|
||||
|
||||
## json version of the benchmarking tables
|
||||
|
||||
This section contains the data of the markdown tables above in JSON format.
|
||||
|
||||
@ -0,0 +1,66 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
|
||||
|
||||
def compare_data_columns(
|
||||
files, name_column, data_column, drop_column, ignore_test_name=False
|
||||
):
|
||||
print("\ncompare_data_column: " + data_column)
|
||||
frames = []
|
||||
compare_frames = []
|
||||
for file in files:
|
||||
data_df = pd.read_json(file)
|
||||
serving_df = data_df.dropna(subset=[drop_column], ignore_index=True)
|
||||
if ignore_test_name is False:
|
||||
serving_df = serving_df.rename(columns={name_column: file + "_name"})
|
||||
frames.append(serving_df[file + "_name"])
|
||||
serving_df = serving_df.rename(columns={data_column: file})
|
||||
frames.append(serving_df[file])
|
||||
compare_frames.append(serving_df[file])
|
||||
if len(compare_frames) >= 2:
|
||||
# Compare numbers among two files
|
||||
ratio_df = compare_frames[1] / compare_frames[0]
|
||||
frames.append(ratio_df)
|
||||
compare_frames.pop(1)
|
||||
|
||||
concat_df = pd.concat(frames, axis=1)
|
||||
return concat_df
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"-f", "--file", action="append", type=str, help="input file name"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ignore_test_name", action="store_true", help="ignore_test_name or not"
|
||||
)
|
||||
args = parser.parse_args()
|
||||
files = args.file
|
||||
print("comparing : " + ", ".join(files))
|
||||
|
||||
drop_column = "P99"
|
||||
name_column = "Test name"
|
||||
data_cols_to_compare = ["Output Tput (tok/s)", "Median TTFT (ms)", "Median"]
|
||||
html_msgs_for_data_cols = [
|
||||
"Compare Output Tokens /n",
|
||||
"Median TTFT /n",
|
||||
"Median TPOT /n",
|
||||
]
|
||||
ignore_test_name = args.ignore_test_name
|
||||
with open("perf_comparison.html", "w") as text_file:
|
||||
for i in range(len(data_cols_to_compare)):
|
||||
output_df = compare_data_columns(
|
||||
files,
|
||||
name_column,
|
||||
data_cols_to_compare[i],
|
||||
drop_column,
|
||||
ignore_test_name=ignore_test_name,
|
||||
)
|
||||
print(output_df)
|
||||
html = output_df.to_html()
|
||||
text_file.write(html_msgs_for_data_cols[i])
|
||||
text_file.write(html)
|
||||
@ -3,9 +3,11 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from importlib import util
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import psutil
|
||||
from tabulate import tabulate
|
||||
|
||||
results_folder = Path("results/")
|
||||
@ -29,11 +31,11 @@ throughput_results = []
|
||||
throughput_results_column_mapping = {
|
||||
"test_name": "Test name",
|
||||
"gpu_type": "GPU",
|
||||
# "num_requests": "# of req.",
|
||||
# "total_num_tokens": "Total # of tokens",
|
||||
# "elapsed_time": "Elapsed time (s)",
|
||||
"num_requests": "# of req.",
|
||||
"total_num_tokens": "Total # of tokens",
|
||||
"elapsed_time": "Elapsed time (s)",
|
||||
"requests_per_second": "Tput (req/s)",
|
||||
# "tokens_per_second": "Tput (tok/s)",
|
||||
"tokens_per_second": "Tput (tok/s)",
|
||||
}
|
||||
|
||||
# serving results and the keys that will be printed into markdown
|
||||
@ -41,16 +43,18 @@ serving_results = []
|
||||
serving_column_mapping = {
|
||||
"test_name": "Test name",
|
||||
"gpu_type": "GPU",
|
||||
# "completed": "# of req.",
|
||||
"completed": "# of req.",
|
||||
"request_throughput": "Tput (req/s)",
|
||||
# "input_throughput": "Input Tput (tok/s)",
|
||||
# "output_throughput": "Output Tput (tok/s)",
|
||||
"total_token_throughput": "Total Token Tput (tok/s)",
|
||||
"output_throughput": "Output Tput (tok/s)",
|
||||
"total_input_tokens": "Total input tokens",
|
||||
"total_output_tokens": "Total output tokens",
|
||||
"mean_ttft_ms": "Mean TTFT (ms)",
|
||||
"median_ttft_ms": "Median TTFT (ms)",
|
||||
"p99_ttft_ms": "P99 TTFT (ms)",
|
||||
# "mean_tpot_ms": "Mean TPOT (ms)",
|
||||
# "median_tpot_ms": "Median",
|
||||
# "p99_tpot_ms": "P99",
|
||||
"mean_tpot_ms": "Mean TPOT (ms)",
|
||||
"median_tpot_ms": "Median",
|
||||
"p99_tpot_ms": "P99",
|
||||
"mean_itl_ms": "Mean ITL (ms)",
|
||||
"median_itl_ms": "Median ITL (ms)",
|
||||
"p99_itl_ms": "P99 ITL (ms)",
|
||||
@ -75,6 +79,20 @@ def results_to_json(latency, throughput, serving):
|
||||
)
|
||||
|
||||
|
||||
def get_size_with_unit(bytes, suffix="B"):
|
||||
"""
|
||||
Scale bytes to its proper format
|
||||
e.g:
|
||||
1253656 => '1.20MB'
|
||||
1253656678 => '1.17GB'
|
||||
"""
|
||||
factor = 1024
|
||||
for unit in ["", "K", "M", "G", "T", "P"]:
|
||||
if bytes < factor:
|
||||
return f"{bytes:.2f}{unit}{suffix}"
|
||||
bytes /= factor
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# collect results
|
||||
for test_file in results_folder.glob("*.json"):
|
||||
@ -155,6 +173,27 @@ if __name__ == "__main__":
|
||||
serving_results = pd.DataFrame.from_dict(serving_results)
|
||||
throughput_results = pd.DataFrame.from_dict(throughput_results)
|
||||
|
||||
svmem = psutil.virtual_memory()
|
||||
platform_data = {
|
||||
"Physical cores": [psutil.cpu_count(logical=False)],
|
||||
"Total cores": [psutil.cpu_count(logical=True)],
|
||||
"Total Memory": [get_size_with_unit(svmem.total)],
|
||||
}
|
||||
|
||||
if util.find_spec("numa") is not None:
|
||||
from numa import info
|
||||
|
||||
platform_data["Total NUMA nodes"] = [info.get_num_configured_nodes()]
|
||||
|
||||
if util.find_spec("cpuinfo") is not None:
|
||||
from cpuinfo import get_cpu_info
|
||||
|
||||
platform_data["CPU Brand"] = [get_cpu_info()["brand_raw"]]
|
||||
|
||||
platform_results = pd.DataFrame.from_dict(
|
||||
platform_data, orient="index", columns=["Platform Info"]
|
||||
)
|
||||
|
||||
raw_results_json = results_to_json(
|
||||
latency_results, throughput_results, serving_results
|
||||
)
|
||||
@ -200,6 +239,9 @@ if __name__ == "__main__":
|
||||
throughput_md_table = tabulate(
|
||||
throughput_results, headers="keys", tablefmt="pipe", showindex=False
|
||||
)
|
||||
platform_md_table = tabulate(
|
||||
platform_results, headers="keys", tablefmt="pipe", showindex=True
|
||||
)
|
||||
|
||||
# document the result
|
||||
with open(results_folder / "benchmark_results.md", "w") as f:
|
||||
@ -211,6 +253,7 @@ if __name__ == "__main__":
|
||||
latency_tests_markdown_table=latency_md_table,
|
||||
throughput_tests_markdown_table=throughput_md_table,
|
||||
serving_tests_markdown_table=serving_md_table,
|
||||
platform_markdown_table=platform_md_table,
|
||||
benchmarking_results_in_json_string=processed_results_json,
|
||||
)
|
||||
f.write(results)
|
||||
|
||||
@ -31,6 +31,20 @@ check_gpus() {
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_cpus() {
|
||||
# check the number of CPUs and NUMA Node and GPU type.
|
||||
declare -g numa_count=$(python3 -c "from numa import info;numa_size = info.get_num_configured_nodes(); print(numa_size)")
|
||||
if [[ $numa_count -gt 0 ]]; then
|
||||
echo "NUMA found."
|
||||
echo $numa_count
|
||||
else
|
||||
echo "Need at least 1 NUMA to run benchmarking."
|
||||
exit 1
|
||||
fi
|
||||
declare -g gpu_type="cpu"
|
||||
echo "GPU type is $gpu_type"
|
||||
}
|
||||
|
||||
check_hf_token() {
|
||||
# check if HF_TOKEN is available and valid
|
||||
if [[ -z "$HF_TOKEN" ]]; then
|
||||
@ -69,6 +83,22 @@ json2args() {
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
json2envs() {
|
||||
# transforms the JSON string to environment variables.
|
||||
# example:
|
||||
# input: { "VLLM_CPU_KVCACHE_SPACE": 5 }
|
||||
# output: VLLM_CPU_KVCACHE_SPACE=5
|
||||
local json_string=$1
|
||||
local args=$(
|
||||
echo "$json_string" | jq -r '
|
||||
to_entries |
|
||||
map((.key ) + "=" + (.value | tostring)) |
|
||||
join(" ")
|
||||
'
|
||||
)
|
||||
echo "$args"
|
||||
}
|
||||
|
||||
wait_for_server() {
|
||||
# wait for vllm server to start
|
||||
# return 1 if vllm server crashes
|
||||
@ -158,15 +188,24 @@ run_latency_tests() {
|
||||
# get arguments
|
||||
latency_params=$(echo "$params" | jq -r '.parameters')
|
||||
latency_args=$(json2args "$latency_params")
|
||||
latency_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
latency_envs=$(json2envs "$latency_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$latency_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
latency_command="python3 benchmark_latency.py \
|
||||
latency_command=" $latency_envs python3 benchmark_latency.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$latency_args"
|
||||
|
||||
@ -216,15 +255,24 @@ run_throughput_tests() {
|
||||
# get arguments
|
||||
throughput_params=$(echo "$params" | jq -r '.parameters')
|
||||
throughput_args=$(json2args "$throughput_params")
|
||||
throughput_environment_variables=$(echo "$params" | jq -r '.environment_variables')
|
||||
throughput_envs=$(json2envs "$throughput_environment_variables")
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
tp=$(echo "$throughput_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
throughput_command="python3 benchmark_throughput.py \
|
||||
throughput_command=" $throughput_envs python3 benchmark_throughput.py \
|
||||
--output-json $RESULTS_FOLDER/${test_name}.json \
|
||||
$throughput_args"
|
||||
|
||||
@ -272,18 +320,27 @@ run_serving_tests() {
|
||||
|
||||
# get client and server arguments
|
||||
server_params=$(echo "$params" | jq -r '.server_parameters')
|
||||
server_envs=$(echo "$params" | jq -r '.server_environment_variables')
|
||||
client_params=$(echo "$params" | jq -r '.client_parameters')
|
||||
server_args=$(json2args "$server_params")
|
||||
server_envs=$(json2envs "$server_envs")
|
||||
client_args=$(json2args "$client_params")
|
||||
qps_list=$(echo "$params" | jq -r '.qps_list')
|
||||
qps_list=$(echo "$qps_list" | jq -r '.[] | @sh')
|
||||
echo "Running over qps list $qps_list"
|
||||
|
||||
# check if there is enough GPU to run the test
|
||||
# check if there is enough resources to run the test
|
||||
tp=$(echo "$server_params" | jq -r '.tensor_parallel_size')
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
if [[ $numa_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $numa_count NUMA nodes found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
else
|
||||
if [[ $gpu_count -lt $tp ]]; then
|
||||
echo "Required tensor-parallel-size $tp but only $gpu_count GPU found. Skip testcase $test_name."
|
||||
continue
|
||||
fi
|
||||
fi
|
||||
|
||||
# check if server model and client model is aligned
|
||||
@ -294,23 +351,33 @@ run_serving_tests() {
|
||||
continue
|
||||
fi
|
||||
|
||||
server_command="python3 \
|
||||
server_command="$server_envs python3 \
|
||||
-m vllm.entrypoints.openai.api_server \
|
||||
$server_args"
|
||||
|
||||
# run the server
|
||||
echo "Running test case $test_name"
|
||||
echo "Server command: $server_command"
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vllm server is up and running."
|
||||
# support remote vllm server
|
||||
client_remote_args=""
|
||||
if [[ -z "${REMOTE_HOST}" ]]; then
|
||||
bash -c "$server_command" &
|
||||
server_pid=$!
|
||||
# wait until the server is alive
|
||||
if wait_for_server; then
|
||||
echo ""
|
||||
echo "vLLM server is up and running."
|
||||
else
|
||||
echo ""
|
||||
echo "vLLM failed to start within the timeout period."
|
||||
fi
|
||||
else
|
||||
echo ""
|
||||
echo "vllm failed to start within the timeout period."
|
||||
server_command="Using Remote Server $REMOTE_HOST $REMOTE_PORT"
|
||||
if [[ ${REMOTE_PORT} ]]; then
|
||||
client_remote_args=" --host=$REMOTE_HOST --port=$REMOTE_PORT "
|
||||
else
|
||||
client_remote_args=" --host=$REMOTE_HOST "
|
||||
fi
|
||||
fi
|
||||
|
||||
# iterate over different QPS
|
||||
@ -332,7 +399,7 @@ run_serving_tests() {
|
||||
--result-filename ${new_test_name}.json \
|
||||
--request-rate $qps \
|
||||
--metadata "tensor_parallel_size=$tp" \
|
||||
$client_args"
|
||||
$client_args $client_remote_args "
|
||||
|
||||
echo "Running test case $test_name with qps $qps"
|
||||
echo "Client command: $client_command"
|
||||
@ -360,7 +427,14 @@ run_serving_tests() {
|
||||
}
|
||||
|
||||
main() {
|
||||
check_gpus
|
||||
local ARCH
|
||||
ARCH=''
|
||||
if [ "$ON_CPU" == "1" ];then
|
||||
check_cpus
|
||||
ARCH='-cpu'
|
||||
else
|
||||
check_gpus
|
||||
fi
|
||||
check_hf_token
|
||||
|
||||
# Set to v1 to run v1 benchmark
|
||||
@ -386,9 +460,9 @@ main() {
|
||||
QUICK_BENCHMARK_ROOT=../.buildkite/nightly-benchmarks/
|
||||
|
||||
# benchmarking
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/serving-tests.json
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/latency-tests.json
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/throughput-tests.json
|
||||
run_serving_tests $QUICK_BENCHMARK_ROOT/tests/"${SERVING_JSON:-serving-tests$ARCH.json}"
|
||||
run_latency_tests $QUICK_BENCHMARK_ROOT/tests/"${LATENCY_JSON:-latency-tests$ARCH.json}"
|
||||
run_throughput_tests $QUICK_BENCHMARK_ROOT/tests/"${THROUGHPUT_JSON:-throughput-tests$ARCH.json}"
|
||||
|
||||
# postprocess benchmarking results
|
||||
pip install tabulate pandas
|
||||
|
||||
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
30
.buildkite/nightly-benchmarks/tests/latency-tests-cpu.json
Normal file
@ -0,0 +1,30 @@
|
||||
[
|
||||
{
|
||||
"test_name": "latency_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "latency_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"num_iters_warmup": 5,
|
||||
"num_iters": 15
|
||||
}
|
||||
}
|
||||
]
|
||||
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
158
.buildkite/nightly-benchmarks/tests/serving-tests-cpu.json
Normal file
@ -0,0 +1,158 @@
|
||||
[
|
||||
{
|
||||
"test_name": "serving_llama8B_tp1_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp2_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 2,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_sharegpt",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "sharegpt",
|
||||
"dataset_path": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"max_concurrency": 60,
|
||||
"num_prompts": 200
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_tp4_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "serving_llama8B_pp6_random_1024_128",
|
||||
"qps_list": [1, 4, 16, "inf"],
|
||||
"server_environment_variables": {
|
||||
"VLLM_RPC_TIMEOUT": 100000,
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_ENGINE_ITERATION_TIMEOUT_S": 120,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"server_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"pipeline_parallel_size": 6,
|
||||
"dtype": "bfloat16",
|
||||
"distributed_executor_backend": "mp",
|
||||
"block_size": 128,
|
||||
"trust_remote_code": "",
|
||||
"enable_chunked_prefill": "",
|
||||
"disable_log_stats": "",
|
||||
"disable_log_requests": "",
|
||||
"enforce_eager": "",
|
||||
"load_format": "dummy"
|
||||
},
|
||||
"client_parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"backend": "vllm",
|
||||
"dataset_name": "random",
|
||||
"random-input-len": 1024,
|
||||
"random-output-len": 128,
|
||||
"ignore-eos": "",
|
||||
"max_concurrency": 100,
|
||||
"num_prompts": 100
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -0,0 +1,32 @@
|
||||
[
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp1",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 1,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
},
|
||||
{
|
||||
"test_name": "throughput_llama8B_tp4",
|
||||
"environment_variables": {
|
||||
"VLLM_ALLOW_LONG_MAX_MODEL_LEN": 1,
|
||||
"VLLM_CPU_KVCACHE_SPACE": 40
|
||||
},
|
||||
"parameters": {
|
||||
"model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
|
||||
"tensor_parallel_size": 4,
|
||||
"load_format": "dummy",
|
||||
"dataset": "./ShareGPT_V3_unfiltered_cleaned_split.json",
|
||||
"num_prompts": 200,
|
||||
"backend": "vllm"
|
||||
}
|
||||
}
|
||||
]
|
||||
@ -101,7 +101,8 @@ steps:
|
||||
queue: cpu_queue_postmerge
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --build-arg VLLM_CPU_AVX512BF16=true --build-arg VLLM_CPU_AVX512VNNI=true --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest --progress plain --target vllm-openai -f docker/Dockerfile.cpu ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-cpu-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
@ -117,6 +118,7 @@ steps:
|
||||
commands:
|
||||
- "aws ecr-public get-login-password --region us-east-1 | docker login --username AWS --password-stdin public.ecr.aws/q9t5s3a7"
|
||||
- "DOCKER_BUILDKIT=1 docker build --build-arg max_jobs=16 --build-arg GIT_REPO_CHECK=1 --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version) --tag public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest --progress plain -f docker/Dockerfile.neuron ."
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:latest"
|
||||
- "docker push public.ecr.aws/q9t5s3a7/vllm-neuron-release-repo:$(buildkite-agent meta-data get release-version)"
|
||||
env:
|
||||
DOCKER_BUILDKIT: "1"
|
||||
|
||||
@ -51,6 +51,7 @@ function cpu_tests() {
|
||||
pytest -v -s tests/kernels/attention/test_cache.py -m cpu_model
|
||||
pytest -v -s tests/kernels/attention/test_mla_decode_cpu.py -m cpu_model
|
||||
pytest -v -s tests/models/language/generation -m cpu_model
|
||||
VLLM_CPU_SGL_KERNEL=1 pytest -v -s tests/models/language/generation -m cpu_model
|
||||
pytest -v -s tests/models/language/pooling -m cpu_model
|
||||
pytest -v -s tests/models/multimodal/generation \
|
||||
--ignore=tests/models/multimodal/generation/test_mllama.py \
|
||||
@ -98,4 +99,4 @@ function cpu_tests() {
|
||||
|
||||
# All of CPU tests are expected to be finished less than 40 mins.
|
||||
export -f cpu_tests
|
||||
timeout 1h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
timeout 1.5h bash -c "cpu_tests $CORE_RANGE $NUMA_NODE"
|
||||
|
||||
@ -2,10 +2,34 @@
|
||||
|
||||
# This script build the CPU docker image and run the offline inference inside the container.
|
||||
# It serves a sanity check for compilation and basic model usage.
|
||||
set -ex
|
||||
set -exuo pipefail
|
||||
|
||||
# Try building the docker image
|
||||
docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
cat <<EOF | docker build -t hpu-plugin-v1-test-env -f - .
|
||||
FROM 1.22-413-pt2.7.1:latest
|
||||
|
||||
COPY ./ /workspace/vllm
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN pip install -v -r requirements/hpu.txt
|
||||
RUN pip install git+https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
ENV no_proxy=localhost,127.0.0.1
|
||||
ENV PT_HPU_ENABLE_LAZY_COLLECTIVES=true
|
||||
|
||||
RUN VLLM_TARGET_DEVICE=hpu python3 setup.py install
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN python3 -m pip install -e tests/vllm_test_utils
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN git clone https://github.com/vllm-project/vllm-gaudi.git
|
||||
|
||||
RUN ln -s /workspace/vllm/tests && ln -s /workspace/vllm/examples && ln -s /workspace/vllm/benchmarks
|
||||
|
||||
EOF
|
||||
|
||||
# Setup cleanup
|
||||
# certain versions of HPU software stack have a bug that can
|
||||
@ -14,13 +38,21 @@ docker build -t hpu-test-env -f docker/Dockerfile.hpu .
|
||||
# functions, while other platforms only need one remove_docker_container
|
||||
# function.
|
||||
EXITCODE=1
|
||||
remove_docker_containers() { docker rm -f hpu-test || true; docker rm -f hpu-test-tp2 || true; }
|
||||
remove_docker_containers_and_exit() { remove_docker_containers; exit $EXITCODE; }
|
||||
trap remove_docker_containers_and_exit EXIT
|
||||
remove_docker_containers() { docker rm -f hpu-plugin-v1-test || true; }
|
||||
trap 'remove_docker_containers; exit $EXITCODE;' EXIT
|
||||
remove_docker_containers
|
||||
|
||||
# Run the image and launch offline inference
|
||||
docker run --runtime=habana --name=hpu-test --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
docker run --runtime=habana --name=hpu-test-tp2 --network=host -e HABANA_VISIBLE_DEVICES=all -e VLLM_SKIP_WARMUP=true --entrypoint="" hpu-test-env python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --tensor-parallel-size 2
|
||||
echo "Running HPU plugin v1 test"
|
||||
docker run --rm --runtime=habana --name=hpu-plugin-v1-test --network=host \
|
||||
-e HABANA_VISIBLE_DEVICES=all \
|
||||
hpu-plugin-v1-test-env \
|
||||
/bin/bash "/workspace/vllm-gaudi/tests/upstream_tests/ci_tests.sh"
|
||||
|
||||
EXITCODE=$?
|
||||
if [ $EXITCODE -eq 0 ]; then
|
||||
echo "Test with basic model passed"
|
||||
else
|
||||
echo "Test with basic model FAILED with exit code: $EXITCODE" >&2
|
||||
fi
|
||||
|
||||
# The trap will handle the container removal and final exit.
|
||||
@ -54,10 +54,11 @@ docker run --rm -it --device=/dev/neuron0 --network bridge \
|
||||
--name "${container_name}" \
|
||||
${image_name} \
|
||||
/bin/bash -c "
|
||||
set -e; # Exit on first error
|
||||
python3 /workspace/vllm/examples/offline_inference/neuron.py;
|
||||
python3 -m pytest /workspace/vllm/tests/neuron/1_core/ -v --capture=tee-sys;
|
||||
for f in /workspace/vllm/tests/neuron/2_core/*.py; do
|
||||
echo 'Running test file: '$f;
|
||||
echo \"Running test file: \$f\";
|
||||
python3 -m pytest \$f -v --capture=tee-sys;
|
||||
done
|
||||
"
|
||||
@ -159,6 +159,8 @@ run_and_track_test 14 "test_tpu_qkv_linear.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_tpu_qkv_linear.py"
|
||||
run_and_track_test 15 "test_spmd_model_weight_loading.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_spmd_model_weight_loading.py"
|
||||
run_and_track_test 16 "test_kv_cache_update_kernel.py" \
|
||||
"python3 -m pytest -s -v /workspace/vllm/tests/v1/tpu/test_kv_cache_update_kernel.py"
|
||||
|
||||
# After all tests have been attempted, exit with the overall status.
|
||||
if [ "$overall_script_exit_code" -ne 0 ]; then
|
||||
|
||||
@ -28,4 +28,5 @@ docker run \
|
||||
sh -c '
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m
|
||||
VLLM_USE_V1=0 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m -tp 2
|
||||
VLLM_USE_V1=1 python3 examples/offline_inference/basic/generate.py --model facebook/opt-125m --block-size 64 --enforce-eager
|
||||
'
|
||||
|
||||
@ -4,8 +4,8 @@ CONTAINER_NAME=vllm-tpu
|
||||
|
||||
# vllm config
|
||||
MODEL=meta-llama/Llama-3.1-8B-Instruct
|
||||
MAX_NUM_SEQS=512
|
||||
MAX_NUM_BATCHED_TOKENS=512
|
||||
MAX_NUM_SEQS=256
|
||||
MAX_NUM_BATCHED_TOKENS=1024
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
MAX_MODEL_LEN=2048
|
||||
DOWNLOAD_DIR=/mnt/disks/persist
|
||||
|
||||
@ -68,7 +68,7 @@ docker run \
|
||||
|
||||
echo "run script..."
|
||||
echo
|
||||
docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/hardware_ci/run_bm.sh"
|
||||
docker exec "$CONTAINER_NAME" /bin/bash -c ".buildkite/scripts/tpu/run_bm.sh"
|
||||
|
||||
echo "copy result back..."
|
||||
VLLM_LOG="$LOG_ROOT/$TEST_NAME"_vllm_log.txt
|
||||
|
||||
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
14
.buildkite/scripts/tpu/quantized_v6e_1.env
Normal file
@ -0,0 +1,14 @@
|
||||
# Environment config
|
||||
TEST_NAME=llama8bw8a8
|
||||
CONTAINER_NAME=vllm-tpu
|
||||
|
||||
# vllm config
|
||||
MODEL=RedHatAI/Meta-Llama-3.1-8B-Instruct-quantized.w8a8
|
||||
MAX_NUM_SEQS=128
|
||||
MAX_NUM_BATCHED_TOKENS=1024
|
||||
TENSOR_PARALLEL_SIZE=1
|
||||
MAX_MODEL_LEN=2048
|
||||
DOWNLOAD_DIR=/mnt/disks/persist
|
||||
EXPECTED_THROUGHPUT=10.0
|
||||
INPUT_LEN=1800
|
||||
OUTPUT_LEN=128
|
||||
@ -41,6 +41,16 @@ steps:
|
||||
# TODO: add `--strict` once warnings in docstrings are fixed
|
||||
- mkdocs build
|
||||
|
||||
- label: Pytorch Nightly Dependency Override Check # 2min
|
||||
# if this test fails, it means the nightly torch version is not compatible with some
|
||||
# of the dependencies. Please check the error message and add the package to whitelist
|
||||
# in /vllm/tools/generate_nightly_torch_test.py
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- requirements/nightly_torch_test.txt
|
||||
commands:
|
||||
- bash standalone_tests/pytorch_nightly_dependency.sh
|
||||
|
||||
- label: Async Engine, Inputs, Utils, Worker Test # 24min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
@ -89,7 +99,7 @@ steps:
|
||||
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
|
||||
|
||||
- label: Chunked Prefill Test
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/basic_correctness/test_chunked_prefill
|
||||
@ -145,6 +155,7 @@ steps:
|
||||
- examples/offline_inference/rlhf_colocate.py
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/engine/test_engine_core_client.py
|
||||
commands:
|
||||
# test with tp=2 and external_dp=2
|
||||
@ -153,8 +164,9 @@ steps:
|
||||
# test with tp=2 and pp=2
|
||||
- PP_SIZE=2 torchrun --nproc-per-node=4 distributed/test_torchrun_example.py
|
||||
# test with internal dp
|
||||
- python3 ../examples/offline_inference/data_parallel.py
|
||||
- python3 ../examples/offline_inference/data_parallel.py --enforce-eager
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=2 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- pytest -v -s v1/engine/test_engine_core_client.py::test_kv_cache_events_dp
|
||||
- pytest -v -s distributed/test_utils.py
|
||||
- pytest -v -s compile/test_basic_correctness.py
|
||||
@ -168,6 +180,23 @@ steps:
|
||||
- VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py
|
||||
- popd
|
||||
|
||||
- label: EPLB Algorithm Test
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/eplb
|
||||
- tests/distributed/test_eplb_algo.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_eplb_algo.py
|
||||
|
||||
- label: EPLB Execution Test # 5min
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 4
|
||||
source_file_dependencies:
|
||||
- vllm/distributed/eplb
|
||||
- tests/distributed/test_eplb_execute.py
|
||||
commands:
|
||||
- pytest -v -s distributed/test_eplb_execute.py
|
||||
|
||||
- label: Metrics, Tracing Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
num_gpus: 2
|
||||
@ -188,7 +217,7 @@ steps:
|
||||
##### 1 GPU test #####
|
||||
|
||||
- label: Regression Test # 5min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/test_regression
|
||||
@ -198,7 +227,7 @@ steps:
|
||||
working_dir: "/vllm-workspace/tests" # optional
|
||||
|
||||
- label: Engine Test # 10min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/engine
|
||||
@ -271,6 +300,15 @@ steps:
|
||||
commands:
|
||||
- pytest -v -s prefix_caching
|
||||
|
||||
|
||||
- label: Platform Tests (CUDA)
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/cuda
|
||||
commands:
|
||||
- pytest -v -s cuda/test_cuda_context.py
|
||||
|
||||
- label: Samplers Test # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
source_file_dependencies:
|
||||
@ -302,7 +340,7 @@ steps:
|
||||
parallelism: 4
|
||||
|
||||
- label: PyTorch Compilation Unit Tests
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -384,7 +422,7 @@ steps:
|
||||
- pytest -v -s kernels/mamba
|
||||
|
||||
- label: Tensorizer Test # 11min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
soft_fail: true
|
||||
source_file_dependencies:
|
||||
- vllm/model_executor/model_loader
|
||||
@ -476,7 +514,7 @@ steps:
|
||||
##### models test #####
|
||||
|
||||
- label: Basic Models Test # 24min
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -500,6 +538,17 @@ steps:
|
||||
- pip freeze | grep -E 'torch'
|
||||
- pytest -v -s models/language -m core_model
|
||||
|
||||
- label: Language Models Test (Hybrid) # 35 min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
torch_nightly: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
- tests/models/language/generation
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m hybrid_model
|
||||
|
||||
- label: Language Models Test (Extended Generation) # 1hr20min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
@ -509,7 +558,7 @@ steps:
|
||||
commands:
|
||||
# Install causal-conv1d for plamo2 models here, as it is not compatible with pip-compile.
|
||||
- pip install 'git+https://github.com/Dao-AILab/causal-conv1d@v1.5.0.post8'
|
||||
- pytest -v -s models/language/generation -m 'not core_model'
|
||||
- pytest -v -s models/language/generation -m '(not core_model) and (not hybrid_model)'
|
||||
|
||||
- label: Language Models Test (Extended Pooling) # 36min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -554,7 +603,7 @@ steps:
|
||||
- pytest -v -s models/multimodal/generation/test_common.py -m 'split(group=0) and not core_model'
|
||||
|
||||
- label: Multi-Modal Models Test (Extended) 3
|
||||
mirror_hardwares: [amdexperimental, amdproduction]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
optional: true
|
||||
source_file_dependencies:
|
||||
- vllm/
|
||||
@ -606,13 +655,18 @@ steps:
|
||||
- vllm/executor/
|
||||
- vllm/model_executor/models/
|
||||
- tests/distributed/
|
||||
- tests/examples/offline_inference/data_parallel.py
|
||||
commands:
|
||||
- # the following commands are for the first node, with ip 192.168.10.10 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=0 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_multi_node_assignment.py
|
||||
- VLLM_MULTI_NODE=1 pytest -v -s distributed/test_pipeline_parallel.py
|
||||
- # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up)
|
||||
- VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep 'Same node test passed'
|
||||
- NUM_NODES=2 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_node_count.py | grep 'Node count test passed'
|
||||
- python3 ../examples/offline_inference/data_parallel.py --dp-size=2 --tp-size=1 --node-size=2 --node-rank=1 --master-addr=192.168.10.10 --master-port=12345 --enforce-eager --trust-remote-code
|
||||
|
||||
- label: Distributed Tests (2 GPUs) # 40min
|
||||
mirror_hardwares: [amdexperimental]
|
||||
@ -630,10 +684,12 @@ steps:
|
||||
- vllm/worker/model_runner.py
|
||||
- entrypoints/llm/test_collective_rpc.py
|
||||
- tests/v1/test_async_llm_dp.py
|
||||
- tests/v1/test_external_lb_dp.py
|
||||
- tests/v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- vllm/v1/engine/
|
||||
commands:
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_async_llm_dp.py
|
||||
- TP_SIZE=1 DP_SIZE=2 pytest -v -s v1/test_external_lb_dp.py
|
||||
- DP_SIZE=2 pytest -v -s v1/entrypoints/openai/test_multi_api_servers.py
|
||||
- pytest -v -s entrypoints/llm/test_collective_rpc.py
|
||||
- pytest -v -s ./compile/test_basic_correctness.py
|
||||
@ -736,7 +792,7 @@ steps:
|
||||
- bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models.txt
|
||||
|
||||
- label: Weight Loading Multiple GPU Test - Large Models # optional
|
||||
mirror_hardwares: [amdexperimental]
|
||||
mirror_hardwares: [amdexperimental]
|
||||
working_dir: "/vllm-workspace/tests"
|
||||
num_gpus: 2
|
||||
gpu: a100
|
||||
|
||||
4
.github/CODEOWNERS
vendored
4
.github/CODEOWNERS
vendored
@ -18,6 +18,10 @@
|
||||
/vllm/entrypoints @aarnphm
|
||||
CMakeLists.txt @tlrmchlsmth
|
||||
|
||||
# Any change to the VllmConfig changes can have a large user-facing impact,
|
||||
# so spam a lot of people
|
||||
/vllm/config.py @simon-mo @WoosukKwon @youkaichao @robertgshaw2-redhat @mgoin @tlrmchlsmth @houseroad @hmellor
|
||||
|
||||
# vLLM V1
|
||||
/vllm/v1 @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @comaniac @alexm-redhat
|
||||
/vllm/v1/structured_output @mgoin @russellb @aarnphm
|
||||
|
||||
52
.github/mergify.yml
vendored
52
.github/mergify.yml
vendored
@ -27,6 +27,22 @@ pull_request_rules:
|
||||
add:
|
||||
- ci/build
|
||||
|
||||
- name: label-deepseek
|
||||
description: Automatically apply deepseek label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^examples/.*deepseek.*\.py
|
||||
- files~=^tests/.*deepseek.*\.py
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/.*deepseek.*\.py
|
||||
- files~=^vllm/model_executor/models/.*deepseek.*\.py
|
||||
- files~=^vllm/reasoning/.*deepseek.*\.py
|
||||
- files~=^vllm/transformers_utils/.*deepseek.*\.py
|
||||
- title~=(?i)DeepSeek
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- deepseek
|
||||
|
||||
- name: label-frontend
|
||||
description: Automatically apply frontend label
|
||||
conditions:
|
||||
@ -45,6 +61,7 @@ pull_request_rules:
|
||||
- files~=^vllm/entrypoints/openai/tool_parsers/llama.*\.py
|
||||
- files~=^vllm/model_executor/models/.*llama.*\.py
|
||||
- files~=^vllm/transformers_utils/configs/.*llama.*\.py
|
||||
- title~=(?i)llama
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
@ -57,14 +74,38 @@ pull_request_rules:
|
||||
- files~=^vllm/multimodal/
|
||||
- files~=^tests/multimodal/
|
||||
- files~=^tests/models/multimodal/
|
||||
- files~=^tests/models/*/audio_language/
|
||||
- files~=^tests/models/*/vision_language/
|
||||
- files=tests/models/test_vision.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- multi-modality
|
||||
|
||||
- name: label-new-model
|
||||
description: Automatically apply new-model label
|
||||
conditions:
|
||||
- and:
|
||||
- files~=^vllm/model_executor/models/
|
||||
- files=vllm/model_executor/models/registry.py
|
||||
- files=tests/models/registry.py
|
||||
- files=docs/models/supported_models.md
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- new-model
|
||||
|
||||
- name: label-performance
|
||||
description: Automatically apply performance label
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^benchmarks/
|
||||
- files~=^vllm/benchmarks/
|
||||
- files~=^tests/benchmarks/
|
||||
- files~=^\.buildkite/nightly-benchmarks/
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
- performance
|
||||
|
||||
- name: label-qwen
|
||||
description: Automatically apply qwen label
|
||||
conditions:
|
||||
@ -74,7 +115,6 @@ pull_request_rules:
|
||||
- files~=^vllm/model_executor/models/.*qwen.*\.py
|
||||
- files~=^vllm/reasoning/.*qwen.*\.py
|
||||
- title~=(?i)Qwen
|
||||
- body~=(?i)Qwen
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
@ -127,8 +167,14 @@ pull_request_rules:
|
||||
conditions:
|
||||
- or:
|
||||
- files~=^vllm/spec_decode/
|
||||
- files~=^vllm/v1/spec_decode/
|
||||
- files=vllm/model_executor/layers/spec_decode_base_sampler.py
|
||||
- files~=^tests/spec_decode/
|
||||
- files~=^tests/v1/spec_decode/
|
||||
- files~=^examples/.*(spec_decode|mlpspeculator|eagle|speculation).*\.py
|
||||
- files~=^vllm/model_executor/models/.*eagle.*\.py
|
||||
- files=vllm/model_executor/models/mlp_speculator.py
|
||||
- files~=^vllm/transformers_utils/configs/(eagle|medusa|mlp_speculator)\.py
|
||||
actions:
|
||||
label:
|
||||
add:
|
||||
|
||||
@ -53,6 +53,11 @@ repos:
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- repo: local
|
||||
hooks:
|
||||
- id: format-torch-nightly-test
|
||||
name: reformat nightly_torch_test.txt to be in sync with test.in
|
||||
language: python
|
||||
entry: python tools/generate_nightly_torch_test.py
|
||||
files: ^requirements/test\.(in|txt)$
|
||||
- id: mypy-local
|
||||
name: Run mypy for local Python installation
|
||||
entry: tools/mypy.sh 0 "local"
|
||||
@ -115,6 +120,11 @@ repos:
|
||||
entry: python tools/check_spdx_header.py
|
||||
language: python
|
||||
types: [python]
|
||||
- id: check-root-lazy-imports
|
||||
name: Check root lazy imports
|
||||
entry: python tools/check_init_lazy_imports.py
|
||||
language: python
|
||||
types: [python]
|
||||
- id: check-filenames
|
||||
name: Check for spaces in all filenames
|
||||
entry: bash
|
||||
@ -150,6 +160,13 @@ repos:
|
||||
types: [python]
|
||||
pass_filenames: false
|
||||
additional_dependencies: [pathspec, regex]
|
||||
- id: validate-config
|
||||
name: Validate configuration has default values and that each field has a docstring
|
||||
entry: python tools/validate_config.py
|
||||
language: python
|
||||
types: [python]
|
||||
pass_filenames: true
|
||||
files: vllm/config.py|tests/test_config.py
|
||||
# Keep `suggestion` last
|
||||
- id: suggestion
|
||||
name: Suggestion
|
||||
|
||||
@ -420,6 +420,36 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Geforce Blackwell SM120 (c3x, i.e. CUTLASS 3.x) require
|
||||
# CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0;12.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
set(SRCS
|
||||
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
|
||||
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
|
||||
)
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1")
|
||||
# Let scaled_mm_c2x know it doesn't need to build these arches
|
||||
list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}")
|
||||
message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8 AND SCALED_MM_ARCHS)
|
||||
message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is "
|
||||
"not >= 12.8, we recommend upgrading to CUDA 12.8 or "
|
||||
"later if you intend on running FP8 quantized models on "
|
||||
"Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
|
||||
# The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x)
|
||||
# require CUDA 12.8 or later
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a" "${CUDA_ARCHS}")
|
||||
@ -513,6 +543,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
CUDA_ARCHS "${FP4_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_NVFP4=1")
|
||||
list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1")
|
||||
message(STATUS "Building NVFP4 for archs: ${FP4_ARCHS}")
|
||||
else()
|
||||
message(STATUS "Not building NVFP4 as no compatible archs were found.")
|
||||
@ -547,8 +578,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# if it's possible to compile MoE kernels that use its output.
|
||||
cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu"
|
||||
"csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/grouped_mm_c3x.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${SCALED_MM_ARCHS}")
|
||||
@ -562,7 +592,27 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
"if you intend on running FP8 quantized MoE models on Hopper.")
|
||||
else()
|
||||
message(STATUS "Not building grouped_mm_c3x as no compatible archs found "
|
||||
"in CUDA target architectures")
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
# moe_data.cu is used by all CUTLASS MoE kernels.
|
||||
cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a" "${CUDA_ARCHS}")
|
||||
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
set(SRCS "csrc/quantization/cutlass_w8a8/moe/moe_data.cu")
|
||||
set_gencode_flags_for_srcs(
|
||||
SRCS "${SRCS}"
|
||||
CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}")
|
||||
list(APPEND VLLM_EXT_SRC "${SRCS}")
|
||||
message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}")
|
||||
else()
|
||||
if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS)
|
||||
message(STATUS "Not building moe_data as CUDA Compiler version is "
|
||||
"not >= 12.3, we recommend upgrading to CUDA 12.3 or later "
|
||||
"if you intend on running FP8 quantized MoE models on Hopper or Blackwell.")
|
||||
else()
|
||||
message(STATUS "Not building moe_data as no compatible archs found "
|
||||
"in CUDA target architectures.")
|
||||
endif()
|
||||
endif()
|
||||
|
||||
@ -638,6 +688,14 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
|
||||
# if CUDA endif
|
||||
endif()
|
||||
|
||||
if (VLLM_GPU_LANG STREQUAL "HIP")
|
||||
# Add QuickReduce kernels
|
||||
list(APPEND VLLM_EXT_SRC
|
||||
"csrc/custom_quickreduce.cu"
|
||||
)
|
||||
# if ROCM endif
|
||||
endif()
|
||||
|
||||
message(STATUS "Enabling C extension.")
|
||||
define_gpu_extension_target(
|
||||
_C
|
||||
|
||||
@ -154,11 +154,13 @@ If you use vLLM for your research, please cite our [paper](https://arxiv.org/abs
|
||||
|
||||
## Contact Us
|
||||
|
||||
<!-- --8<-- [start:contact-us] -->
|
||||
- For technical questions and feature requests, please use GitHub [Issues](https://github.com/vllm-project/vllm/issues) or [Discussions](https://github.com/vllm-project/vllm/discussions)
|
||||
- For discussing with fellow users, please use the [vLLM Forum](https://discuss.vllm.ai)
|
||||
- For coordinating contributions and development, please use [Slack](https://slack.vllm.ai)
|
||||
- For security disclosures, please use GitHub's [Security Advisories](https://github.com/vllm-project/vllm/security/advisories) feature
|
||||
- For collaborations and partnerships, please contact us at [vllm-questions@lists.berkeley.edu](mailto:vllm-questions@lists.berkeley.edu)
|
||||
<!-- --8<-- [end:contact-us] -->
|
||||
|
||||
## Media Kit
|
||||
|
||||
|
||||
@ -4,7 +4,7 @@ This README guides you through running benchmark tests with the extensive
|
||||
datasets supported on vLLM. It’s a living document, updated as new features and datasets
|
||||
become available.
|
||||
|
||||
## Dataset Overview
|
||||
**Dataset Overview**
|
||||
|
||||
<table style="width:100%; border-collapse: collapse;">
|
||||
<thead>
|
||||
@ -82,7 +82,10 @@ become available.
|
||||
**Note**: HuggingFace dataset's `dataset-name` should be set to `hf`
|
||||
|
||||
---
|
||||
## Example - Online Benchmark
|
||||
<details>
|
||||
<summary><b>🚀 Example - Online Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
First start serving your model
|
||||
|
||||
@ -130,7 +133,8 @@ P99 ITL (ms): 8.39
|
||||
==================================================
|
||||
```
|
||||
|
||||
### Custom Dataset
|
||||
**Custom Dataset**
|
||||
|
||||
If the dataset you want to benchmark is not supported yet in vLLM, even then you can benchmark on it using `CustomDataset`. Your data needs to be in `.jsonl` format and needs to have "prompt" field per entry, e.g., data.jsonl
|
||||
|
||||
```
|
||||
@ -162,7 +166,7 @@ python3 benchmarks/benchmark_serving.py --port 9001 --save-result --save-detaile
|
||||
|
||||
You can skip applying chat template if your data already has it by using `--custom-skip-chat-template`.
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
**VisionArena Benchmark for Vision Language Models**
|
||||
|
||||
```bash
|
||||
# need a model with vision capability here
|
||||
@ -180,7 +184,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
**InstructCoder Benchmark with Speculative Decoding**
|
||||
|
||||
``` bash
|
||||
VLLM_USE_V1=1 vllm serve meta-llama/Meta-Llama-3-8B-Instruct \
|
||||
@ -197,7 +201,7 @@ python3 benchmarks/benchmark_serving.py \
|
||||
--num-prompts 2048
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
**Other HuggingFaceDataset Examples**
|
||||
|
||||
```bash
|
||||
vllm serve Qwen/Qwen2-VL-7B-Instruct --disable-log-requests
|
||||
@ -251,7 +255,7 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--num-prompts 80
|
||||
```
|
||||
|
||||
### Running With Sampling Parameters
|
||||
**Running With Sampling Parameters**
|
||||
|
||||
When using OpenAI-compatible backends such as `vllm`, optional sampling
|
||||
parameters can be specified. Example client command:
|
||||
@ -269,8 +273,27 @@ python3 vllm/benchmarks/benchmark_serving.py \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
---
|
||||
## Example - Offline Throughput Benchmark
|
||||
**Running With Ramp-Up Request Rate**
|
||||
|
||||
The benchmark tool also supports ramping up the request rate over the
|
||||
duration of the benchmark run. This can be useful for stress testing the
|
||||
server or finding the maximum throughput that it can handle, given some latency budget.
|
||||
|
||||
Two ramp-up strategies are supported:
|
||||
- `linear`: Increases the request rate linearly from a start value to an end value.
|
||||
- `exponential`: Increases the request rate exponentially.
|
||||
|
||||
The following arguments can be used to control the ramp-up:
|
||||
- `--ramp-up-strategy`: The ramp-up strategy to use (`linear` or `exponential`).
|
||||
- `--ramp-up-start-rps`: The request rate at the beginning of the benchmark.
|
||||
- `--ramp-up-end-rps`: The request rate at the end of the benchmark.
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>📈 Example - Offline Throughput Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
```bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
@ -288,7 +311,7 @@ Total num prompt tokens: 5014
|
||||
Total num output tokens: 1500
|
||||
```
|
||||
|
||||
### VisionArena Benchmark for Vision Language Models
|
||||
**VisionArena Benchmark for Vision Language Models**
|
||||
|
||||
``` bash
|
||||
python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
@ -308,7 +331,7 @@ Total num prompt tokens: 14527
|
||||
Total num output tokens: 1280
|
||||
```
|
||||
|
||||
### InstructCoder Benchmark with Speculative Decoding
|
||||
**InstructCoder Benchmark with Speculative Decoding**
|
||||
|
||||
``` bash
|
||||
VLLM_WORKER_MULTIPROC_METHOD=spawn \
|
||||
@ -332,7 +355,7 @@ Total num prompt tokens: 261136
|
||||
Total num output tokens: 204800
|
||||
```
|
||||
|
||||
### Other HuggingFaceDataset Examples
|
||||
**Other HuggingFaceDataset Examples**
|
||||
|
||||
**`lmms-lab/LLaVA-OneVision-Data`**
|
||||
|
||||
@ -371,7 +394,7 @@ python3 benchmarks/benchmark_throughput.py \
|
||||
--num-prompts 10
|
||||
```
|
||||
|
||||
### Benchmark with LoRA Adapters
|
||||
**Benchmark with LoRA Adapters**
|
||||
|
||||
``` bash
|
||||
# download dataset
|
||||
@ -387,3 +410,196 @@ python3 vllm/benchmarks/benchmark_throughput.py \
|
||||
--enable-lora \
|
||||
--lora-path yard1/llama-2-7b-sql-lora-test
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>🛠️ Example - Structured Output Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of structured output generation (JSON, grammar, regex).
|
||||
|
||||
**Server Setup**
|
||||
|
||||
```bash
|
||||
vllm serve NousResearch/Hermes-3-Llama-3.1-8B --disable-log-requests
|
||||
```
|
||||
|
||||
**JSON Schema Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset json \
|
||||
--structured-output-ratio 1.0 \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**Grammar-based Generation Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset grammar \
|
||||
--structure-type grammar \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**Regex-based Generation Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset regex \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**Choice-based Generation Benchmark**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset choice \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
**XGrammar Benchmark Dataset**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_serving_structured_output.py \
|
||||
--backend vllm \
|
||||
--model NousResearch/Hermes-3-Llama-3.1-8B \
|
||||
--dataset xgrammar_bench \
|
||||
--request-rate 10 \
|
||||
--num-prompts 1000
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>📚 Example - Long Document QA Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of long document question-answering with prefix caching.
|
||||
|
||||
**Basic Long Document QA Test**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 16 \
|
||||
--document-length 2000 \
|
||||
--output-len 50 \
|
||||
--repeat-count 5
|
||||
```
|
||||
|
||||
**Different Repeat Modes**
|
||||
|
||||
```bash
|
||||
# Random mode (default) - shuffle prompts randomly
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode random
|
||||
|
||||
# Tile mode - repeat entire prompt list in sequence
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode tile
|
||||
|
||||
# Interleave mode - repeat each prompt consecutively
|
||||
python3 benchmarks/benchmark_long_document_qa_throughput.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-documents 8 \
|
||||
--document-length 3000 \
|
||||
--repeat-count 3 \
|
||||
--repeat-mode interleave
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>🗂️ Example - Prefix Caching Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the efficiency of automatic prefix caching.
|
||||
|
||||
**Fixed Prompt with Prefix Caching**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 1 \
|
||||
--repeat-count 100 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
**ShareGPT Dataset with Prefix Caching**
|
||||
|
||||
```bash
|
||||
# download dataset
|
||||
# wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
|
||||
|
||||
python3 benchmarks/benchmark_prefix_caching.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--dataset-path /path/ShareGPT_V3_unfiltered_cleaned_split.json \
|
||||
--enable-prefix-caching \
|
||||
--num-prompts 20 \
|
||||
--repeat-count 5 \
|
||||
--input-length-range 128:256
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary><b>⚡ Example - Request Prioritization Benchmark</b></summary>
|
||||
|
||||
<br/>
|
||||
|
||||
Benchmark the performance of request prioritization in vLLM.
|
||||
|
||||
**Basic Prioritization Test**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority
|
||||
```
|
||||
|
||||
**Multiple Sequences per Prompt**
|
||||
|
||||
```bash
|
||||
python3 benchmarks/benchmark_prioritization.py \
|
||||
--model meta-llama/Llama-2-7b-chat-hf \
|
||||
--input-len 128 \
|
||||
--output-len 64 \
|
||||
--num-prompts 100 \
|
||||
--scheduling-policy priority \
|
||||
--n 2
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
@ -10,6 +10,7 @@
|
||||
# 3. Set variables (ALL REQUIRED)
|
||||
# BASE: your directory for vllm repo
|
||||
# MODEL: the model served by vllm
|
||||
# SYSTEM: the hardware, choice TPU or GPU, for other systems, "get best profile" might not support.
|
||||
# TP: ways of tensor parallelism
|
||||
# DOWNLOAD_DIR: directory to download and load model weights.
|
||||
# INPUT_LEN: request input len
|
||||
@ -34,6 +35,7 @@
|
||||
TAG=$(date +"%Y_%m_%d_%H_%M")
|
||||
BASE=""
|
||||
MODEL="meta-llama/Llama-3.1-8B-Instruct"
|
||||
SYSTEM="TPU"
|
||||
TP=1
|
||||
DOWNLOAD_DIR=""
|
||||
INPUT_LEN=4000
|
||||
@ -45,12 +47,15 @@ NUM_BATCHED_TOKENS_LIST="512 1024 2048 4096"
|
||||
|
||||
LOG_FOLDER="$BASE/auto-benchmark/$TAG"
|
||||
RESULT="$LOG_FOLDER/result.txt"
|
||||
PROFILE_PATH="$LOG_FOLDER/profile"
|
||||
|
||||
echo "result file: $RESULT"
|
||||
echo "model: $MODEL"
|
||||
|
||||
rm -rf $LOG_FOLDER
|
||||
rm -rf $PROFILE_PATH
|
||||
mkdir -p $LOG_FOLDER
|
||||
mkdir -p $PROFILE_PATH
|
||||
|
||||
cd "$BASE/vllm"
|
||||
|
||||
@ -70,10 +75,11 @@ start_server() {
|
||||
local max_num_seqs=$2
|
||||
local max_num_batched_tokens=$3
|
||||
local vllm_log=$4
|
||||
local profile_dir=$5
|
||||
|
||||
pkill -f vllm
|
||||
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 vllm serve $MODEL \
|
||||
VLLM_USE_V1=1 VLLM_SERVER_DEV_MODE=1 VLLM_TORCH_PROFILER_DIR=$profile_dir vllm serve $MODEL \
|
||||
--disable-log-requests \
|
||||
--port 8004 \
|
||||
--gpu-memory-utilization $gpu_memory_utilization \
|
||||
@ -105,19 +111,37 @@ start_server() {
|
||||
fi
|
||||
}
|
||||
|
||||
update_best_profile() {
|
||||
local profile_dir=$1
|
||||
local profile_index=$2
|
||||
sorted_paths=($(find "$profile_dir" -maxdepth 1 -not -path "$profile_dir" | sort))
|
||||
selected_profile_file=
|
||||
if [[ "$SYSTEM" == "TPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}/*.xplane.pb"
|
||||
fi
|
||||
if [[ "$SYSTEM" == "GPU" ]]; then
|
||||
selected_profile_file="${sorted_paths[$profile_index]}"
|
||||
fi
|
||||
rm -f $PROFILE_PATH/*
|
||||
cp $selected_profile_file $PROFILE_PATH
|
||||
}
|
||||
|
||||
run_benchmark() {
|
||||
local max_num_seqs=$1
|
||||
local max_num_batched_tokens=$2
|
||||
local gpu_memory_utilization=$3
|
||||
echo "max_num_seq: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||
local vllm_log="$LOG_FOLDER/vllm_log_${max_num_seqs}_${max_num_batched_tokens}.txt"
|
||||
local profile_dir="$LOG_FOLDER/profile_${max_num_seqs}_${max_num_batched_tokens}"
|
||||
echo "vllm_log: $vllm_log"
|
||||
echo
|
||||
rm -f $vllm_log
|
||||
mkdir -p $profile_dir
|
||||
pkill -f vllm
|
||||
local profile_index=0
|
||||
|
||||
echo "starting server..."
|
||||
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log
|
||||
start_server $gpu_memory_utilization $max_num_seqs $max_num_batched_tokens $vllm_log $profile_dir
|
||||
result=$?
|
||||
if [[ "$result" -eq 1 ]]; then
|
||||
echo "server failed to start. gpu_memory_utilization:$gpu_memory_utilization, max_num_seqs:$max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens"
|
||||
@ -144,7 +168,8 @@ run_benchmark() {
|
||||
--goodput e2el:$MAX_LATENCY_ALLOWED_MS \
|
||||
--num-prompts 1000 \
|
||||
--random-prefix-len $prefix_len \
|
||||
--port 8004 &> "$bm_log"
|
||||
--port 8004 \
|
||||
--profile &> "$bm_log"
|
||||
throughput=$(grep "Request throughput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
e2el=$(grep "P99 E2EL (ms):" "$bm_log" | awk '{print $NF}')
|
||||
goodput=$(grep "Request goodput (req/s):" "$bm_log" | sed 's/[^0-9.]//g')
|
||||
@ -158,6 +183,7 @@ run_benchmark() {
|
||||
# start from request-rate as int(throughput) + 1
|
||||
request_rate=$((${throughput%.*} + 1))
|
||||
while ((request_rate > 0)); do
|
||||
profile_index=$((profile_index+1))
|
||||
# clear prefix cache
|
||||
curl -X POST http://0.0.0.0:8004/reset_prefix_cache
|
||||
sleep 5
|
||||
@ -195,6 +221,12 @@ run_benchmark() {
|
||||
best_max_num_seqs=$max_num_seqs
|
||||
best_num_batched_tokens=$max_num_batched_tokens
|
||||
best_goodput=$goodput
|
||||
if [[ "$SYSTEM" == "TPU" ]]; then
|
||||
update_best_profile "$profile_dir/plugins/profile" $profile_index
|
||||
fi
|
||||
if [[ "$SYSTEM" == "GPU" ]]; then
|
||||
update_best_profile "$profile_dir" $profile_index
|
||||
fi
|
||||
fi
|
||||
else
|
||||
echo "max_num_seqs: $max_num_seqs, max_num_batched_tokens: $max_num_batched_tokens does not meet latency requirement ${MAX_LATENCY_ALLOWED_MS}"
|
||||
@ -239,6 +271,6 @@ for num_seqs in "${num_seqs_list[@]}"; do
|
||||
done
|
||||
done
|
||||
echo "finish permutations"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput" >> "$RESULT"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH"
|
||||
echo "best_max_num_seqs: $best_max_num_seqs, best_num_batched_tokens: $best_num_batched_tokens, best_throughput: $best_throughput, profile saved in: $PROFILE_PATH" >> "$RESULT"
|
||||
|
||||
|
||||
@ -404,8 +404,14 @@ async def async_request_openai_chat_completions(
|
||||
chunk_bytes = chunk_bytes.strip()
|
||||
if not chunk_bytes:
|
||||
continue
|
||||
chunk_bytes = chunk_bytes.decode("utf-8")
|
||||
# NOTE: SSE comments (often used as pings) start with a colon.
|
||||
# These are not JSON data payload and should be skipped.
|
||||
if chunk_bytes.startswith(":"):
|
||||
continue
|
||||
|
||||
chunk = chunk_bytes.removeprefix("data: ")
|
||||
|
||||
chunk = chunk_bytes.decode("utf-8").removeprefix("data: ")
|
||||
if chunk != "[DONE]":
|
||||
timestamp = time.perf_counter()
|
||||
data = json.loads(chunk)
|
||||
|
||||
@ -349,11 +349,12 @@ class RandomDataset(BenchmarkDataset):
|
||||
# [1650, 939, 486] -> ['Ġcall', 'sh', 'ere']
|
||||
# To avoid uncontrolled change of the prompt length,
|
||||
# the encoded sequence is truncated before being decode again.
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
re_encoded_sequence = tokenizer.encode(prompt, add_special_tokens=False)[
|
||||
: input_lens[i]
|
||||
:total_input_len
|
||||
]
|
||||
prompt = tokenizer.decode(re_encoded_sequence)
|
||||
total_input_len = prefix_len + int(input_lens[i])
|
||||
total_input_len = len(re_encoded_sequence)
|
||||
requests.append(
|
||||
SampleRequest(
|
||||
prompt=prompt,
|
||||
|
||||
362
benchmarks/benchmark_one_concurrent.py
Normal file
362
benchmarks/benchmark_one_concurrent.py
Normal file
@ -0,0 +1,362 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import aiohttp # Import aiohttp
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
from backend_request_func import RequestFuncInput, RequestFuncOutput
|
||||
from benchmark_dataset import RandomDataset, SampleRequest
|
||||
|
||||
try:
|
||||
from vllm.transformers_utils.tokenizer import get_tokenizer
|
||||
except ImportError:
|
||||
from backend_request_func import get_tokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class BenchmarkMetrics:
|
||||
completed: int
|
||||
total_input: int
|
||||
total_output: int
|
||||
mean_ttft_ms: float
|
||||
median_ttft_ms: float
|
||||
std_ttft_ms: float
|
||||
percentiles_ttft_ms: list[tuple[float, float]]
|
||||
mean_itl_ms: float
|
||||
median_itl_ms: float
|
||||
std_itl_ms: float
|
||||
percentiles_itl_ms: list[tuple[float, float]]
|
||||
mean_e2el_ms: float
|
||||
median_e2el_ms: float
|
||||
std_e2el_ms: float
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
async def reset_cache(reset_url: str):
|
||||
"""Sends a POST request to reset the prefix cache."""
|
||||
logger.debug("Resetting prefix cache at %s", reset_url)
|
||||
try:
|
||||
async with (
|
||||
aiohttp.ClientSession() as session,
|
||||
session.post(reset_url) as response,
|
||||
):
|
||||
response.raise_for_status() # Raise an exception for bad status codes (4xx or 5xx)
|
||||
logger.debug("Prefix cache reset successful: %s", response.status)
|
||||
except aiohttp.ClientConnectorError as e:
|
||||
logger.error("Failed to connect to cache reset endpoint %s: %s}", reset_url, e)
|
||||
except aiohttp.ClientResponseError as e:
|
||||
logger.error(
|
||||
"Cache reset request failed with status %s: %s", e.status, e.message
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error("An unexpected error occurred during cache reset: %s", e)
|
||||
|
||||
|
||||
async def sequential_benchmark(
|
||||
backend: str,
|
||||
api_url: str,
|
||||
model_id: str,
|
||||
tokenizer,
|
||||
input_requests: list[SampleRequest],
|
||||
request_func,
|
||||
selected_percentiles: list[float],
|
||||
cache_reset_url: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Benchmark that processes requests sequentially, waiting for each to complete
|
||||
before starting the next one. Resets prefix cache between requests.
|
||||
"""
|
||||
outputs = []
|
||||
|
||||
pbar = tqdm(total=len(input_requests))
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
|
||||
# Process requests sequentially
|
||||
for request in input_requests:
|
||||
prompt, prompt_len, output_len = (
|
||||
request.prompt,
|
||||
request.prompt_len,
|
||||
request.expected_output_len,
|
||||
)
|
||||
|
||||
logger.info("Sending request with len %s", request.prompt_len)
|
||||
logger.debug('Request str: "%s"', request.prompt[:50])
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
request_func_input = RequestFuncInput(
|
||||
model=model_id,
|
||||
prompt=prompt,
|
||||
api_url=api_url,
|
||||
prompt_len=prompt_len,
|
||||
output_len=output_len,
|
||||
)
|
||||
|
||||
output = await request_func(request_func_input=request_func_input)
|
||||
|
||||
request_end_time = time.perf_counter()
|
||||
# Add timing information
|
||||
if output.success and not hasattr(output, "latency"):
|
||||
output.latency = request_end_time - request_start_time
|
||||
logger.info("Finished request with latency %.4f s", output.latency)
|
||||
|
||||
outputs.append(output)
|
||||
pbar.update(1)
|
||||
|
||||
pbar.close()
|
||||
|
||||
benchmark_duration = time.perf_counter() - benchmark_start_time
|
||||
|
||||
# Calculate metrics
|
||||
metrics = calculate_metrics(
|
||||
input_requests=input_requests,
|
||||
outputs=outputs,
|
||||
dur_s=benchmark_duration,
|
||||
tokenizer=tokenizer,
|
||||
selected_percentiles=selected_percentiles,
|
||||
)
|
||||
|
||||
print_results(metrics, benchmark_duration)
|
||||
|
||||
result = {
|
||||
"duration": benchmark_duration,
|
||||
"completed": metrics.completed,
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"input_lens": [request.prompt_len for request in input_requests],
|
||||
"output_lens": [
|
||||
output.output_tokens if output.success else 0 for output in outputs
|
||||
],
|
||||
"ttfts": [output.ttft for output in outputs if output.success],
|
||||
"itls": [output.itl for output in outputs if output.success],
|
||||
"generated_texts": [
|
||||
output.generated_text for output in outputs if output.success
|
||||
],
|
||||
"errors": [output.error for output in outputs if not output.success],
|
||||
}
|
||||
|
||||
# Add summary statistics
|
||||
for stat_name in ["ttft", "itl", "e2el"]:
|
||||
for metric_name in ["mean", "median", "std"]:
|
||||
result[f"{metric_name}_{stat_name}_ms"] = getattr(
|
||||
metrics, f"{metric_name}_{stat_name}_ms"
|
||||
)
|
||||
|
||||
for p, value in getattr(metrics, f"percentiles_{stat_name}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
result[f"p{p_word}_{stat_name}_ms"] = value
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def calculate_metrics(
|
||||
input_requests: list[SampleRequest],
|
||||
outputs: list[RequestFuncOutput],
|
||||
dur_s: float,
|
||||
tokenizer,
|
||||
selected_percentiles: list[float],
|
||||
) -> BenchmarkMetrics:
|
||||
"""Calculate benchmark metrics from results."""
|
||||
total_input = 0
|
||||
completed = 0
|
||||
total_output = 0
|
||||
ttfts = []
|
||||
itls = []
|
||||
e2els = []
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
if output.success:
|
||||
output_len = output.output_tokens
|
||||
|
||||
if not output_len:
|
||||
# Use tokenizer to count output tokens if not provided
|
||||
output_len = len(
|
||||
tokenizer(output.generated_text, add_special_tokens=False).input_ids
|
||||
)
|
||||
|
||||
total_output += output_len
|
||||
total_input += input_requests[i].prompt_len
|
||||
|
||||
if hasattr(output, "ttft") and output.ttft is not None:
|
||||
ttfts.append(output.ttft)
|
||||
|
||||
if hasattr(output, "itl") and output.itl:
|
||||
# Ensure itl is a list of floats
|
||||
if isinstance(output.itl, list):
|
||||
itls.extend(output.itl)
|
||||
else:
|
||||
logger.warning(
|
||||
"Expected list for ITL but got %s. Appending as is.",
|
||||
type(output.itl),
|
||||
)
|
||||
itls.append(output.itl)
|
||||
|
||||
if hasattr(output, "latency") and output.latency is not None:
|
||||
e2els.append(output.latency)
|
||||
|
||||
completed += 1
|
||||
|
||||
return BenchmarkMetrics(
|
||||
completed=completed,
|
||||
total_input=total_input,
|
||||
total_output=total_output,
|
||||
mean_ttft_ms=np.mean(ttfts or [0]) * 1000,
|
||||
median_ttft_ms=np.median(ttfts or [0]) * 1000,
|
||||
std_ttft_ms=np.std(ttfts or [0]) * 1000,
|
||||
percentiles_ttft_ms=[
|
||||
(p, np.percentile(ttfts or [0], p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_itl_ms=np.mean(itls or [0]) * 1000,
|
||||
median_itl_ms=np.median(itls or [0]) * 1000,
|
||||
std_itl_ms=np.std(itls or [0]) * 1000,
|
||||
percentiles_itl_ms=[
|
||||
(p, np.percentile(itls or [0], p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
mean_e2el_ms=np.mean(e2els or [0]) * 1000,
|
||||
median_e2el_ms=np.median(e2els or [0]) * 1000,
|
||||
std_e2el_ms=np.std(e2els or [0]) * 1000,
|
||||
percentiles_e2el_ms=[
|
||||
(p, np.percentile(e2els or [0], p) * 1000) for p in selected_percentiles
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
def print_results(metrics: BenchmarkMetrics, benchmark_duration: float):
|
||||
"""Print benchmark results in a formatted way."""
|
||||
print("{s:{c}^{n}}".format(s=" Sequential Benchmark Result ", n=60, c="="))
|
||||
print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
|
||||
print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
|
||||
print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
|
||||
print("{:<40} {:<10}".format("Total generated tokens:", metrics.total_output))
|
||||
|
||||
def print_metric_stats(metric_name, header):
|
||||
print("{s:{c}^{n}}".format(s=header, n=60, c="-"))
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Mean {metric_name} (ms):",
|
||||
getattr(metrics, f"mean_{metric_name.lower()}_ms"),
|
||||
)
|
||||
)
|
||||
print(
|
||||
"{:<40} {:<10.2f}".format(
|
||||
f"Median {metric_name} (ms):",
|
||||
getattr(metrics, f"median_{metric_name.lower()}_ms"),
|
||||
)
|
||||
)
|
||||
|
||||
for p, value in getattr(metrics, f"percentiles_{metric_name.lower()}_ms"):
|
||||
p_word = str(int(p)) if int(p) == p else str(p)
|
||||
print("{:<40} {:<10.2f}".format(f"P{p_word} {metric_name} (ms):", value))
|
||||
|
||||
print_metric_stats("TTFT", "Time to First Token")
|
||||
print_metric_stats("ITL", "Inter-token Latency")
|
||||
print_metric_stats("E2EL", "End-to-end Latency")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
async def main_async(args):
|
||||
# Import needed functions based on your setup
|
||||
from backend_request_func import ASYNC_REQUEST_FUNCS
|
||||
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
|
||||
# Set up API URL
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
else:
|
||||
api_url = f"http://{args.host}:{args.port}{args.endpoint}"
|
||||
|
||||
# Set up Cache Reset URL
|
||||
cache_reset_url = f"http://{args.host}:{args.port}/reset_prefix_cache"
|
||||
logger.info("Prefix cache reset configured at: %s", cache_reset_url)
|
||||
|
||||
# Get tokenizer
|
||||
tokenizer = get_tokenizer(tokenizer_id, trust_remote_code=args.trust_remote_code)
|
||||
|
||||
# Get request function
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
else:
|
||||
raise ValueError(f"Unknown backend: {backend}")
|
||||
|
||||
input_requests = RandomDataset().sample(
|
||||
tokenizer=tokenizer,
|
||||
num_requests=args.num_requests,
|
||||
prefix_len=0,
|
||||
input_len=args.input_len,
|
||||
output_len=args.output_len,
|
||||
range_ratio=0.0,
|
||||
)
|
||||
|
||||
# Run benchmark
|
||||
result = await sequential_benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_func=request_func,
|
||||
selected_percentiles=[50, 90, 95, 99],
|
||||
cache_reset_url=cache_reset_url,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main(args):
|
||||
print(args)
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
|
||||
asyncio.run(main_async(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Sequential benchmark for LLM serving")
|
||||
parser.add_argument(
|
||||
"--backend", type=str, default="vllm", help="Backend to use for requests"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--base-url",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Server base URL (overrides --host and --port)",
|
||||
)
|
||||
parser.add_argument("--host", type=str, default="127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=8000)
|
||||
parser.add_argument(
|
||||
"--endpoint", type=str, default="/v1/completions", help="API endpoint"
|
||||
)
|
||||
parser.add_argument("--model", type=str, required=True, help="Name of the model")
|
||||
parser.add_argument(
|
||||
"--tokenizer", type=str, help="Name of the tokenizer (defaults to model name)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-requests", type=int, default=100, help="Number of requests to process"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input-len", type=int, default=128, help="Input len for generated prompts"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-len", type=int, default=None, help="Override output len for requests"
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=42)
|
||||
parser.add_argument(
|
||||
"--trust-remote-code",
|
||||
action="store_true",
|
||||
help="Trust remote code from HuggingFace",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@ -33,7 +33,7 @@ import warnings
|
||||
from collections.abc import AsyncGenerator, Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import numpy as np
|
||||
from tqdm.asyncio import tqdm
|
||||
@ -107,14 +107,42 @@ class BenchmarkMetrics:
|
||||
percentiles_e2el_ms: list[tuple[float, float]]
|
||||
|
||||
|
||||
def _get_current_request_rate(
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]],
|
||||
ramp_up_start_rps: Optional[int],
|
||||
ramp_up_end_rps: Optional[int],
|
||||
request_index: int,
|
||||
total_requests: int,
|
||||
request_rate: float,
|
||||
) -> float:
|
||||
if (
|
||||
ramp_up_strategy
|
||||
and ramp_up_start_rps is not None
|
||||
and ramp_up_end_rps is not None
|
||||
):
|
||||
progress = request_index / max(total_requests - 1, 1)
|
||||
if ramp_up_strategy == "linear":
|
||||
increase = (ramp_up_end_rps - ramp_up_start_rps) * progress
|
||||
return ramp_up_start_rps + increase
|
||||
elif ramp_up_strategy == "exponential":
|
||||
ratio = ramp_up_end_rps / ramp_up_start_rps
|
||||
return ramp_up_start_rps * (ratio**progress)
|
||||
else:
|
||||
raise ValueError(f"Unknown ramp-up strategy: {ramp_up_strategy}")
|
||||
return request_rate
|
||||
|
||||
|
||||
async def get_request(
|
||||
input_requests: list[SampleRequest],
|
||||
request_rate: float,
|
||||
burstiness: float = 1.0,
|
||||
) -> AsyncGenerator[SampleRequest, None]:
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||
ramp_up_start_rps: Optional[int] = None,
|
||||
ramp_up_end_rps: Optional[int] = None,
|
||||
) -> AsyncGenerator[tuple[SampleRequest, float], None]:
|
||||
"""
|
||||
Asynchronously generates requests at a specified rate
|
||||
with OPTIONAL burstiness.
|
||||
with OPTIONAL burstiness and OPTIONAL ramp-up strategy.
|
||||
|
||||
Args:
|
||||
input_requests:
|
||||
@ -129,22 +157,44 @@ async def get_request(
|
||||
A lower burstiness value (0 < burstiness < 1) results
|
||||
in more bursty requests, while a higher burstiness value
|
||||
(burstiness > 1) results in a more uniform arrival of requests.
|
||||
ramp_up_strategy (optional):
|
||||
The ramp-up strategy. Can be "linear" or "exponential".
|
||||
If None, uses constant request rate (specified by request_rate).
|
||||
ramp_up_start_rps (optional):
|
||||
The starting request rate for ramp-up.
|
||||
ramp_up_end_rps (optional):
|
||||
The ending request rate for ramp-up.
|
||||
"""
|
||||
input_requests: Iterable[SampleRequest] = iter(input_requests)
|
||||
|
||||
# Calculate scale parameter theta to maintain the desired request_rate.
|
||||
assert burstiness > 0, (
|
||||
f"A positive burstiness factor is expected, but given {burstiness}."
|
||||
)
|
||||
theta = 1.0 / (request_rate * burstiness)
|
||||
# Convert to list to get length for ramp-up calculations
|
||||
if isinstance(input_requests, Iterable) and not isinstance(input_requests, list):
|
||||
input_requests = list(input_requests)
|
||||
|
||||
total_requests = len(input_requests)
|
||||
request_index = 0
|
||||
|
||||
for request in input_requests:
|
||||
yield request
|
||||
current_request_rate = _get_current_request_rate(
|
||||
ramp_up_strategy,
|
||||
ramp_up_start_rps,
|
||||
ramp_up_end_rps,
|
||||
request_index,
|
||||
total_requests,
|
||||
request_rate,
|
||||
)
|
||||
|
||||
if request_rate == float("inf"):
|
||||
yield request, current_request_rate
|
||||
|
||||
request_index += 1
|
||||
|
||||
if current_request_rate == float("inf"):
|
||||
# If the request rate is infinity, then we don't need to wait.
|
||||
continue
|
||||
|
||||
theta = 1.0 / (current_request_rate * burstiness)
|
||||
|
||||
# Sample the request interval from the gamma distribution.
|
||||
# If burstiness is 1, it follows exponential distribution.
|
||||
interval = np.random.gamma(shape=burstiness, scale=theta)
|
||||
@ -290,6 +340,9 @@ async def benchmark(
|
||||
max_concurrency: Optional[int],
|
||||
lora_modules: Optional[Iterable[str]],
|
||||
extra_body: Optional[dict],
|
||||
ramp_up_strategy: Optional[Literal["linear", "exponential"]] = None,
|
||||
ramp_up_start_rps: Optional[int] = None,
|
||||
ramp_up_end_rps: Optional[int] = None,
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
request_func = ASYNC_REQUEST_FUNCS[backend]
|
||||
@ -353,7 +406,15 @@ async def benchmark(
|
||||
|
||||
distribution = "Poisson process" if burstiness == 1.0 else "Gamma distribution"
|
||||
|
||||
print(f"Traffic request rate: {request_rate}")
|
||||
if ramp_up_strategy is not None:
|
||||
print(
|
||||
f"Traffic ramp-up strategy: {ramp_up_strategy}. Will increase "
|
||||
f"RPS from {ramp_up_start_rps} to {ramp_up_end_rps} RPS over "
|
||||
"the duration of the benchmark."
|
||||
)
|
||||
else:
|
||||
print(f"Traffic request rate: {request_rate} RPS.")
|
||||
|
||||
print(f"Burstiness factor: {burstiness} ({distribution})")
|
||||
print(f"Maximum request concurrency: {max_concurrency}")
|
||||
|
||||
@ -373,7 +434,34 @@ async def benchmark(
|
||||
|
||||
benchmark_start_time = time.perf_counter()
|
||||
tasks: list[asyncio.Task] = []
|
||||
async for request in get_request(input_requests, request_rate, burstiness):
|
||||
|
||||
rps_change_events = []
|
||||
last_int_rps = -1
|
||||
if ramp_up_strategy is not None and ramp_up_start_rps is not None:
|
||||
last_int_rps = ramp_up_start_rps
|
||||
rps_change_events.append(
|
||||
{
|
||||
"rps": last_int_rps,
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
}
|
||||
)
|
||||
|
||||
async for request, current_request_rate in get_request(
|
||||
input_requests,
|
||||
request_rate,
|
||||
burstiness,
|
||||
ramp_up_strategy,
|
||||
ramp_up_start_rps,
|
||||
ramp_up_end_rps,
|
||||
):
|
||||
if ramp_up_strategy is not None:
|
||||
current_int_rps = int(current_request_rate)
|
||||
if current_int_rps > last_int_rps:
|
||||
timestamp = datetime.now().isoformat()
|
||||
for rps_val in range(last_int_rps + 1, current_int_rps + 1):
|
||||
rps_change_events.append({"rps": rps_val, "timestamp": timestamp})
|
||||
last_int_rps = current_int_rps
|
||||
|
||||
prompt, prompt_len, output_len, mm_content = (
|
||||
request.prompt,
|
||||
request.prompt_len,
|
||||
@ -397,11 +485,8 @@ async def benchmark(
|
||||
ignore_eos=ignore_eos,
|
||||
extra_body=extra_body,
|
||||
)
|
||||
tasks.append(
|
||||
asyncio.create_task(
|
||||
limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
)
|
||||
)
|
||||
task = limited_request_func(request_func_input=request_func_input, pbar=pbar)
|
||||
tasks.append(asyncio.create_task(task))
|
||||
outputs: list[RequestFuncOutput] = await asyncio.gather(*tasks)
|
||||
|
||||
if profile:
|
||||
@ -466,7 +551,7 @@ async def benchmark(
|
||||
"total_input_tokens": metrics.total_input,
|
||||
"total_output_tokens": metrics.total_output,
|
||||
"request_throughput": metrics.request_throughput,
|
||||
"request_goodput:": metrics.request_goodput if goodput_config_dict else None,
|
||||
"request_goodput": metrics.request_goodput if goodput_config_dict else None,
|
||||
"output_throughput": metrics.output_throughput,
|
||||
"total_token_throughput": metrics.total_token_throughput,
|
||||
"input_lens": [output.prompt_len for output in outputs],
|
||||
@ -477,6 +562,9 @@ async def benchmark(
|
||||
"errors": [output.error for output in outputs],
|
||||
}
|
||||
|
||||
if rps_change_events:
|
||||
result["rps_change_events"] = rps_change_events
|
||||
|
||||
def process_one_metric(
|
||||
# E.g., "ttft"
|
||||
metric_attribute_name: str,
|
||||
@ -610,6 +698,26 @@ def main(args: argparse.Namespace):
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
tokenizer_mode = args.tokenizer_mode
|
||||
|
||||
# Validate ramp-up arguments
|
||||
if args.ramp_up_strategy is not None:
|
||||
if args.request_rate != float("inf"):
|
||||
raise ValueError(
|
||||
"When using ramp-up, do not specify --request-rate. "
|
||||
"The request rate will be controlled by ramp-up parameters. "
|
||||
"Please remove the --request-rate argument."
|
||||
)
|
||||
if args.ramp_up_start_rps is None or args.ramp_up_end_rps is None:
|
||||
raise ValueError(
|
||||
"When using --ramp-up-strategy, both --ramp-up-start-rps and "
|
||||
"--ramp-up-end-rps must be specified"
|
||||
)
|
||||
if args.ramp_up_start_rps < 0 or args.ramp_up_end_rps < 0:
|
||||
raise ValueError("Ramp-up start and end RPS must be non-negative")
|
||||
if args.ramp_up_start_rps > args.ramp_up_end_rps:
|
||||
raise ValueError("Ramp-up start RPS must be less than end RPS")
|
||||
if args.ramp_up_strategy == "exponential" and args.ramp_up_start_rps == 0:
|
||||
raise ValueError("For exponential ramp-up, the start RPS cannot be 0.")
|
||||
|
||||
if args.base_url is not None:
|
||||
api_url = f"{args.base_url}{args.endpoint}"
|
||||
base_url = f"{args.base_url}"
|
||||
@ -802,6 +910,9 @@ def main(args: argparse.Namespace):
|
||||
max_concurrency=args.max_concurrency,
|
||||
lora_modules=args.lora_modules,
|
||||
extra_body=sampling_params,
|
||||
ramp_up_strategy=args.ramp_up_strategy,
|
||||
ramp_up_start_rps=args.ramp_up_start_rps,
|
||||
ramp_up_end_rps=args.ramp_up_end_rps,
|
||||
)
|
||||
)
|
||||
|
||||
@ -834,6 +945,11 @@ def main(args: argparse.Namespace):
|
||||
result_json["burstiness"] = args.burstiness
|
||||
result_json["max_concurrency"] = args.max_concurrency
|
||||
|
||||
if args.ramp_up_strategy is not None:
|
||||
result_json["ramp_up_strategy"] = args.ramp_up_strategy
|
||||
result_json["ramp_up_start_rps"] = args.ramp_up_start_rps
|
||||
result_json["ramp_up_end_rps"] = args.ramp_up_end_rps
|
||||
|
||||
# Merge with benchmark result
|
||||
result_json = {**result_json, **benchmark_result}
|
||||
|
||||
@ -859,7 +975,10 @@ def main(args: argparse.Namespace):
|
||||
if args.max_concurrency is not None
|
||||
else ""
|
||||
)
|
||||
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
if args.ramp_up_strategy is not None:
|
||||
file_name = f"{backend}-ramp-up-{args.ramp_up_strategy}-{args.ramp_up_start_rps}qps-{args.ramp_up_end_rps}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
else:
|
||||
file_name = f"{backend}-{args.request_rate}qps{max_concurrency_str}-{base_model_id}-{current_dt}.json" # noqa
|
||||
if args.result_filename:
|
||||
file_name = args.result_filename
|
||||
if args.result_dir:
|
||||
@ -1225,6 +1344,31 @@ def create_argument_parser():
|
||||
"script chooses a LoRA module at random.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--ramp-up-strategy",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["linear", "exponential"],
|
||||
help="The ramp-up strategy. This would be used to "
|
||||
"ramp up the request rate from initial RPS to final "
|
||||
"RPS rate (specified by --ramp-up-start-rps and --ramp-up-end-rps). "
|
||||
"over the duration of the benchmark.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ramp-up-start-rps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The starting request rate for ramp-up (RPS). "
|
||||
"Needs to be specified when --ramp-up-strategy is used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ramp-up-end-rps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="The ending request rate for ramp-up (RPS). "
|
||||
"Needs to be specified when --ramp-up-strategy is used.",
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
@ -97,7 +97,7 @@ def run_vllm(
|
||||
assert lora_requests is None, "BeamSearch API does not support LoRA"
|
||||
prompts = [request.prompt for request in requests]
|
||||
# output_len should be the same for all requests.
|
||||
output_len = requests[0][2]
|
||||
output_len = requests[0].expected_output_len
|
||||
for request in requests:
|
||||
assert request.expected_output_len == output_len
|
||||
start = time.perf_counter()
|
||||
|
||||
@ -19,7 +19,7 @@ from vllm import _custom_ops as ops
|
||||
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
|
||||
w8a8_block_fp8_matmul,
|
||||
)
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
from vllm.utils import FlexibleArgumentParser, cdiv
|
||||
|
||||
DEFAULT_MODELS = list(WEIGHT_SHAPES.keys())
|
||||
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512]
|
||||
@ -117,14 +117,9 @@ def bench_fp8(
|
||||
scale_a = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
scale_b = torch.tensor(1.0, device="cuda", dtype=torch.float32)
|
||||
|
||||
def ceil_div(x: int, y: int) -> int:
|
||||
return (x + y - 1) // y
|
||||
|
||||
block_scale_a = torch.rand(
|
||||
(m, ceil_div(k, 128)), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_a = torch.rand((m, cdiv(k, 128)), device="cuda", dtype=torch.float32)
|
||||
block_scale_b = torch.rand(
|
||||
ceil_div(k, 128), ceil_div(n, 128), device="cuda", dtype=torch.float32
|
||||
cdiv(k, 128), cdiv(n, 128), device="cuda", dtype=torch.float32
|
||||
)
|
||||
block_scale_a_M_major = block_scale_a.t().contiguous().t()
|
||||
block_scale_b_K_major = block_scale_b.t().contiguous().t()
|
||||
|
||||
@ -113,6 +113,7 @@ def bench_run(
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
per_act_token: bool,
|
||||
num_repeats: int,
|
||||
):
|
||||
for _ in range(num_repeats):
|
||||
@ -124,7 +125,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_cutlass_from_graph(
|
||||
@ -148,7 +150,8 @@ def bench_run(
|
||||
topk_ids,
|
||||
w1_scale,
|
||||
w2_scale,
|
||||
a1_scale=a_scale,
|
||||
per_act_token,
|
||||
a1_scale=None,
|
||||
)
|
||||
|
||||
def run_triton_from_graph(
|
||||
@ -227,6 +230,7 @@ def bench_run(
|
||||
"w2_q": w2_q,
|
||||
"w1_scale": w1_scale,
|
||||
"w2_scale": w2_scale,
|
||||
"per_act_token": per_act_token,
|
||||
# cuda graph params
|
||||
"cutlass_graph": cutlass_graph,
|
||||
"triton_graph": triton_graph,
|
||||
@ -287,12 +291,13 @@ def bench_run(
|
||||
w2_scale,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
per_act_token,
|
||||
num_warmup,
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, num_runs)", # noqa: E501
|
||||
stmt="run_cutlass_moe(a, a_scale, w1_q, w2_q, w1_scale, w2_scale, topk_weights, topk_ids, per_act_token, num_runs)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
|
||||
@ -234,8 +234,10 @@ def marlin_create_bench_fn(bt: BenchmarkTensors) -> Callable:
|
||||
|
||||
fn = lambda: ops.gptq_marlin_gemm(
|
||||
a=bt.a,
|
||||
c=None,
|
||||
b_q_weight=w_q,
|
||||
b_scales=w_s,
|
||||
global_scale=None,
|
||||
b_zeros=w_zp,
|
||||
g_idx=g_idx,
|
||||
perm=sort_indices,
|
||||
|
||||
@ -22,8 +22,16 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
query_marlin_supported_quant_types,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import (
|
||||
FP4_MARLIN_SUPPORTED_GROUP_SIZES,
|
||||
rand_marlin_weight_fp4_like,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
|
||||
marlin_quant_fp8_torch,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test import (
|
||||
MarlinWorkspace,
|
||||
awq_marlin_quantize,
|
||||
marlin_quantize,
|
||||
)
|
||||
from vllm.model_executor.layers.quantization.utils.marlin_utils_test_24 import (
|
||||
@ -35,7 +43,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
|
||||
quantize_weights,
|
||||
sort_weights,
|
||||
)
|
||||
from vllm.scalar_type import ScalarType
|
||||
from vllm.scalar_type import ScalarType, scalar_types
|
||||
from vllm.utils import FlexibleArgumentParser
|
||||
|
||||
DEFAULT_MODELS = ["meta-llama/Llama-2-7b-hf/TP1"]
|
||||
@ -57,80 +65,144 @@ def bench_run(
|
||||
size_n: int,
|
||||
):
|
||||
label = "Quant Matmul"
|
||||
|
||||
sub_label = "{}, act={} k_full={}, q={}, g={}, MKN=({}x{}x{})".format(
|
||||
model, act_order, is_k_full, str(quant_type), group_size, size_m, size_k, size_n
|
||||
)
|
||||
|
||||
print(f"Testing: {sub_label}")
|
||||
|
||||
a = torch.randn(size_m, size_k).to(torch.half).cuda()
|
||||
b = torch.rand(size_k, size_n).to(torch.half).cuda()
|
||||
has_zp = quant_type in [scalar_types.uint4, scalar_types.uint8]
|
||||
if act_order and (group_size == -1 or group_size == size_k or has_zp):
|
||||
return
|
||||
if size_k % group_size != 0:
|
||||
return
|
||||
|
||||
a_tmp = torch.zeros(size_m, size_k).to(torch.half).cuda()
|
||||
|
||||
# Marlin quant
|
||||
(
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
marlin_rand_perm,
|
||||
) = marlin_quantize(b, quant_type, group_size, act_order)
|
||||
|
||||
# Marlin_24 quant
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||
marlin_24_quantize(b, quant_type, group_size)
|
||||
marlin_24_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
|
||||
marlin_zp = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
|
||||
# GPTQ quant
|
||||
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
||||
b, quant_type, group_size, act_order
|
||||
repack_supported = (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in MARLIN_SUPPORTED_GROUP_SIZES
|
||||
)
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
# so that group ids are increasing
|
||||
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
if act_order:
|
||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||
|
||||
# Prepare
|
||||
marlin_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
|
||||
marlin_24_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
marlin_zp = torch.zeros_like(marlin_s, dtype=torch.int)
|
||||
|
||||
# AllSpark W8A16 quant
|
||||
as_supported_case = (
|
||||
allspark_supported = (
|
||||
quant_type in ALLSPARK_SUPPORTED_QUANT_TYPES
|
||||
and group_size == -1
|
||||
and not act_order
|
||||
and is_k_full
|
||||
)
|
||||
if as_supported_case:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = sm_version >= 80 and sm_version < 90
|
||||
as_supported_case = as_supported_case and supported_arch
|
||||
if supported_arch:
|
||||
has_zp = False
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp
|
||||
def gen_marlin_params():
|
||||
# Marlin quant
|
||||
marlin_g_idx = marlin_sort_indices = marlin_zp = marlin_s2 = None
|
||||
if quant_type == scalar_types.float4_e2m1f:
|
||||
if group_size != 16 or act_order:
|
||||
return
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_s2 = rand_marlin_weight_fp4_like(
|
||||
b.T, group_size
|
||||
)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
elif quant_type == scalar_types.float8_e4m3fn:
|
||||
if group_size not in [-1, 128] or act_order:
|
||||
return
|
||||
marlin_w_ref, marlin_q_w, marlin_s = marlin_quant_fp8_torch(b.T, group_size)
|
||||
elif group_size == 16:
|
||||
return
|
||||
elif has_zp:
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_zp = awq_marlin_quantize(
|
||||
b, quant_type, group_size
|
||||
)
|
||||
else:
|
||||
marlin_w_ref, marlin_q_w, marlin_s, marlin_g_idx, marlin_sort_indices, _ = (
|
||||
marlin_quantize(b, quant_type, group_size, act_order)
|
||||
)
|
||||
return (
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
)
|
||||
|
||||
def gen_marlin_24_params():
|
||||
marlin_24_w_ref = marlin_24_q_w_comp = marlin_24_meta = marlin_24_s = None
|
||||
if marlin_24_supported:
|
||||
(marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s) = (
|
||||
marlin_24_quantize(b, quant_type, group_size)
|
||||
)
|
||||
return (marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s)
|
||||
|
||||
def gen_repack_params():
|
||||
q_w_gptq = None
|
||||
repack_sort_indices = None
|
||||
if repack_supported:
|
||||
(w_ref, q_w, s, g_idx, rand_perm) = gptq_quantize_weights(
|
||||
b, quant_type, group_size, act_order
|
||||
)
|
||||
q_w_gptq = gptq_pack(q_w, quant_type.size_bits, size_k, size_n)
|
||||
|
||||
# For act_order, sort the "weights" and "g_idx"
|
||||
# so that group ids are increasing
|
||||
repack_sort_indices = torch.empty(0, dtype=torch.int, device=b.device)
|
||||
if act_order:
|
||||
(q_w, g_idx, repack_sort_indices) = sort_weights(q_w, g_idx)
|
||||
return q_w_gptq, repack_sort_indices
|
||||
|
||||
def gen_allspark_params():
|
||||
qw_reorder = s_reorder = zp_reorder = sm_count = sm_version = (
|
||||
CUBLAS_M_THRESHOLD
|
||||
) = None
|
||||
nonlocal allspark_supported
|
||||
if allspark_supported:
|
||||
properties = torch.cuda.get_device_properties(b.device.index)
|
||||
sm_count = properties.multi_processor_count
|
||||
sm_version = properties.major * 10 + properties.minor
|
||||
|
||||
supported_arch = sm_version >= 80 and sm_version < 90
|
||||
allspark_supported = allspark_supported and supported_arch
|
||||
if supported_arch:
|
||||
w_ref, qw, s, zp = quantize_weights(b, quant_type, group_size, has_zp)
|
||||
qw = qw.to(torch.uint8)
|
||||
|
||||
qw_reorder, s_reorder, zp_reorder = ops.allspark_repack_weight(
|
||||
qw, s, zp, has_zp
|
||||
)
|
||||
CUBLAS_M_THRESHOLD = ALLSPARK_AMPERE_M_CUBLAS_THRESHOLD
|
||||
return (
|
||||
qw_reorder,
|
||||
s_reorder,
|
||||
zp_reorder,
|
||||
sm_count,
|
||||
sm_version,
|
||||
CUBLAS_M_THRESHOLD,
|
||||
)
|
||||
|
||||
(
|
||||
marlin_w_ref,
|
||||
marlin_q_w,
|
||||
marlin_s,
|
||||
marlin_s2,
|
||||
marlin_zp,
|
||||
marlin_g_idx,
|
||||
marlin_sort_indices,
|
||||
) = gen_marlin_params()
|
||||
marlin_24_w_ref, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s = (
|
||||
gen_marlin_24_params()
|
||||
)
|
||||
q_w_gptq, repack_sort_indices = gen_repack_params()
|
||||
qw_reorder, s_reorder, zp_reorder, sm_count, sm_version, CUBLAS_M_THRESHOLD = (
|
||||
gen_allspark_params()
|
||||
)
|
||||
|
||||
# Prepare
|
||||
marlin_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_MIN_THREAD_N, GPTQ_MARLIN_MAX_PARALLEL
|
||||
)
|
||||
marlin_24_workspace = MarlinWorkspace(
|
||||
size_n, GPTQ_MARLIN_24_MIN_THREAD_N, GPTQ_MARLIN_24_MAX_PARALLEL
|
||||
)
|
||||
|
||||
globals = {
|
||||
# Gen params
|
||||
@ -140,15 +212,14 @@ def bench_run(
|
||||
"size_n": size_n,
|
||||
"size_k": size_k,
|
||||
"a": a,
|
||||
"a_tmp": a_tmp,
|
||||
# Marlin params
|
||||
"marlin_w_ref": marlin_w_ref,
|
||||
"marlin_q_w": marlin_q_w,
|
||||
"marlin_s": marlin_s,
|
||||
"marlin_s2": marlin_s2,
|
||||
"marlin_zp": marlin_zp,
|
||||
"marlin_g_idx": marlin_g_idx,
|
||||
"marlin_sort_indices": marlin_sort_indices,
|
||||
"marlin_rand_perm": marlin_rand_perm,
|
||||
"marlin_workspace": marlin_workspace,
|
||||
"is_k_full": is_k_full,
|
||||
# Marlin_24 params
|
||||
@ -161,12 +232,12 @@ def bench_run(
|
||||
"q_w_gptq": q_w_gptq,
|
||||
"repack_sort_indices": repack_sort_indices,
|
||||
# AllSpark W8A16 params
|
||||
"qw_reorder": qw_reorder if as_supported_case else None,
|
||||
"s_reorder": s_reorder if as_supported_case else None,
|
||||
"zp_reorder": zp_reorder if as_supported_case else None,
|
||||
"sm_count": sm_count if as_supported_case else None,
|
||||
"sm_version": sm_version if as_supported_case else None,
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD if as_supported_case else None,
|
||||
"qw_reorder": qw_reorder,
|
||||
"s_reorder": s_reorder,
|
||||
"zp_reorder": zp_reorder,
|
||||
"sm_count": sm_count,
|
||||
"sm_version": sm_version,
|
||||
"CUBLAS_M_THRESHOLD": CUBLAS_M_THRESHOLD,
|
||||
# Kernels
|
||||
"gptq_marlin_gemm": ops.gptq_marlin_gemm,
|
||||
"gptq_marlin_24_gemm": ops.gptq_marlin_24_gemm,
|
||||
@ -177,7 +248,7 @@ def bench_run(
|
||||
min_run_time = 1
|
||||
|
||||
# Warmup pytorch
|
||||
for i in range(5):
|
||||
for _ in range(5):
|
||||
torch.matmul(a, marlin_w_ref)
|
||||
|
||||
results.append(
|
||||
@ -192,17 +263,17 @@ def bench_run(
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, False, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_gemm_fp16",
|
||||
description="gptq_marlin_gemm",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_gemm(a, marlin_q_w, marlin_s, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
stmt="output = gptq_marlin_gemm(a, None, marlin_q_w, marlin_s, marlin_s2, marlin_zp, marlin_g_idx, marlin_sort_indices, marlin_workspace.scratch, quant_type, size_m, size_n, size_k, is_k_full, False, True, False)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
@ -210,10 +281,7 @@ def bench_run(
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if (
|
||||
quant_type in GPTQ_MARLIN_24_SUPPORTED_QUANT_TYPES
|
||||
and group_size in GPTQ_MARLIN_24_SUPPORTED_GROUP_SIZES
|
||||
):
|
||||
if marlin_24_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = gptq_marlin_24_gemm(a, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s, marlin_24_workspace.scratch, quant_type, size_m, size_n, size_k)", # noqa: E501
|
||||
@ -224,17 +292,18 @@ def bench_run(
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
if repack_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="q_res = gptq_marlin_repack(q_w_gptq, repack_sort_indices, size_k, size_n, quant_type.size_bits)", # noqa: E501
|
||||
globals=globals,
|
||||
label=label,
|
||||
sub_label=sub_label,
|
||||
description="gptq_marlin_repack",
|
||||
).blocked_autorange(min_run_time=min_run_time)
|
||||
)
|
||||
|
||||
if as_supported_case:
|
||||
if allspark_supported:
|
||||
results.append(
|
||||
benchmark.Timer(
|
||||
stmt="output = allspark_w8a16_gemm(a, qw_reorder, s_reorder, zp_reorder, size_n, group_size, sm_count, sm_version, CUBLAS_M_THRESHOLD, False, True)", # noqa: E501
|
||||
@ -250,7 +319,6 @@ def main(args):
|
||||
print("Benchmarking models:")
|
||||
for i, model in enumerate(args.models):
|
||||
print(f"[{i}] {model}")
|
||||
|
||||
results: list[benchmark.Measurement] = []
|
||||
|
||||
for model in args.models:
|
||||
@ -278,14 +346,17 @@ def main(args):
|
||||
):
|
||||
continue
|
||||
|
||||
for quant_type in query_marlin_supported_quant_types(False):
|
||||
for quant_type in query_marlin_supported_quant_types():
|
||||
if (
|
||||
len(args.limit_num_bits) > 0
|
||||
and quant_type.size_bits not in args.limit_num_bits
|
||||
):
|
||||
continue
|
||||
|
||||
for group_size in MARLIN_SUPPORTED_GROUP_SIZES:
|
||||
for group_size in (
|
||||
MARLIN_SUPPORTED_GROUP_SIZES
|
||||
+ FP4_MARLIN_SUPPORTED_GROUP_SIZES
|
||||
):
|
||||
if (
|
||||
len(args.limit_group_size) > 0
|
||||
and group_size not in args.limit_group_size
|
||||
|
||||
@ -85,12 +85,6 @@ def benchmark_shape(m: int,
|
||||
|
||||
# === DeepGEMM Implementation ===
|
||||
def deepgemm_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_cast_to_fp8(A)
|
||||
# A_deepgemm, A_scale_deepgemm = per_token_group_quant_fp8(
|
||||
# A, block_size[1])
|
||||
# A_scale_aligned = get_col_major_tma_aligned_tensor(A_scale_deepgemm)
|
||||
# C_deepgemm = torch.empty((m, n), device='cuda', dtype=torch.bfloat16)
|
||||
deep_gemm.gemm_fp8_fp8_bf16_nt((A_deepgemm, A_scale_deepgemm),
|
||||
(B_deepgemm, B_scale_deepgemm),
|
||||
C_deepgemm)
|
||||
@ -98,8 +92,6 @@ def benchmark_shape(m: int,
|
||||
|
||||
# === vLLM Triton Implementation ===
|
||||
def vllm_triton_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm, A_scale_vllm = per_token_group_quant_fp8(A, block_size[1])
|
||||
return w8a8_block_fp8_matmul(A_vllm,
|
||||
B_vllm,
|
||||
A_scale_vllm,
|
||||
@ -109,9 +101,6 @@ def benchmark_shape(m: int,
|
||||
|
||||
# === vLLM CUTLASS Implementation ===
|
||||
def vllm_cutlass_gemm():
|
||||
# A quantization is inside the loop as it depends on activations
|
||||
# A_vllm_cutlass, A_scale_vllm_cutlass = per_token_group_quant_fp8(
|
||||
# A, block_size[1], column_major_scales=True)
|
||||
return ops.cutlass_scaled_mm(A_vllm_cutlass,
|
||||
B_vllm.T,
|
||||
scale_a=A_scale_vllm_cutlass,
|
||||
|
||||
@ -12,9 +12,8 @@ endif()
|
||||
#
|
||||
# Define environment variables for special configurations
|
||||
#
|
||||
if(DEFINED ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
endif()
|
||||
set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16})
|
||||
set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI})
|
||||
|
||||
include_directories("${CMAKE_SOURCE_DIR}/csrc")
|
||||
|
||||
@ -96,12 +95,30 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16")
|
||||
set(ENABLE_AVX512BF16 ON)
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512BF16 OFF)
|
||||
message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.")
|
||||
endif()
|
||||
|
||||
find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND)
|
||||
if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI)
|
||||
if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND
|
||||
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni")
|
||||
set(ENABLE_AVX512VNNI ON)
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3")
|
||||
endif()
|
||||
else()
|
||||
set(ENABLE_AVX512VNNI OFF)
|
||||
message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.")
|
||||
endif()
|
||||
|
||||
elseif (AVX2_FOUND)
|
||||
list(APPEND CXX_COMPILE_FLAGS "-mavx2")
|
||||
@ -231,12 +248,25 @@ if (AVX512_FOUND AND NOT AVX512_DISABLED)
|
||||
"csrc/cpu/quant.cpp"
|
||||
"csrc/cpu/shm.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/sgl-kernels/gemm.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/gemm_fp8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_int8.cpp"
|
||||
"csrc/cpu/sgl-kernels/moe_fp8.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
add_compile_definitions(-DCPU_CAPABILITY_AVX512)
|
||||
endif()
|
||||
elseif(POWER10_FOUND)
|
||||
set(VLLM_EXT_SRC
|
||||
"csrc/cpu/quant.cpp"
|
||||
${VLLM_EXT_SRC})
|
||||
endif()
|
||||
|
||||
message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}")
|
||||
|
||||
#
|
||||
# Define extension targets
|
||||
#
|
||||
|
||||
@ -38,7 +38,7 @@ else()
|
||||
FetchContent_Declare(
|
||||
vllm-flash-attn
|
||||
GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
|
||||
GIT_TAG 763ad155a1c826f71ff318f41edb1e4e5e376ddb
|
||||
GIT_TAG 1c2624e53c078854e0637ee566c72fe2107e75f4
|
||||
GIT_PROGRESS TRUE
|
||||
# Don't share the vllm-flash-attn build between build types
|
||||
BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
|
||||
|
||||
@ -265,8 +265,8 @@ macro(set_gencode_flags_for_srcs)
|
||||
endmacro()
|
||||
|
||||
#
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form
|
||||
# `<major>.<minor>[letter]` compute the "loose intersection" with the
|
||||
# `TGT_CUDA_ARCHS` list of gencodes. We also support the `+PTX` suffix in
|
||||
# `SRC_CUDA_ARCHS` which indicates that the PTX code should be built when there
|
||||
# is a CUDA_ARCH in `TGT_CUDA_ARCHS` that is equal to or larger than the
|
||||
@ -278,7 +278,7 @@ endmacro()
|
||||
# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`.
|
||||
# We have special handling for x.0a, if x.0a is in `SRC_CUDA_ARCHS` and x.0 is
|
||||
# in `TGT_CUDA_ARCHS` then we should remove x.0a from `SRC_CUDA_ARCHS` and add
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# x.0a to the result (and remove x.0 from TGT_CUDA_ARCHS).
|
||||
# The result is stored in `OUT_CUDA_ARCHS`.
|
||||
#
|
||||
# Example:
|
||||
@ -313,21 +313,16 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
# if x.0a is in SRC_CUDA_ARCHS and x.0 is in CUDA_ARCHS then we should
|
||||
# remove x.0a from SRC_CUDA_ARCHS and add x.0a to _CUDA_ARCHS
|
||||
set(_CUDA_ARCHS)
|
||||
if ("9.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "9.0a")
|
||||
if ("9.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "9.0")
|
||||
set(_CUDA_ARCHS "9.0a")
|
||||
foreach(_arch ${_SRC_CUDA_ARCHS})
|
||||
if(_arch MATCHES "\\a$")
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "${_arch}")
|
||||
string(REPLACE "a" "" _base "${_arch}")
|
||||
if ("${_base}" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "${_base}")
|
||||
list(APPEND _CUDA_ARCHS "${_arch}")
|
||||
endif()
|
||||
endif()
|
||||
endif()
|
||||
|
||||
if ("10.0a" IN_LIST _SRC_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _SRC_CUDA_ARCHS "10.0a")
|
||||
if ("10.0" IN_LIST TGT_CUDA_ARCHS)
|
||||
list(REMOVE_ITEM _TGT_CUDA_ARCHS "10.0")
|
||||
set(_CUDA_ARCHS "10.0a")
|
||||
endif()
|
||||
endif()
|
||||
endforeach()
|
||||
|
||||
list(SORT _SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING)
|
||||
|
||||
@ -359,7 +354,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endforeach()
|
||||
|
||||
list(REMOVE_DUPLICATES _CUDA_ARCHS)
|
||||
|
||||
|
||||
# reapply +PTX suffix to architectures that requested PTX
|
||||
set(_FINAL_ARCHS)
|
||||
foreach(_arch ${_CUDA_ARCHS})
|
||||
@ -370,7 +365,7 @@ function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_AR
|
||||
endif()
|
||||
endforeach()
|
||||
set(_CUDA_ARCHS ${_FINAL_ARCHS})
|
||||
|
||||
|
||||
set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE)
|
||||
endfunction()
|
||||
|
||||
|
||||
@ -207,7 +207,7 @@ void cutlass_mla_decode_sm100a(torch::Tensor const& out,
|
||||
"page_table must be a 32-bit integer tensor");
|
||||
|
||||
auto in_dtype = q_nope.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)q_nope.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(q_nope));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(q_nope.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
|
||||
238
csrc/cpu/sgl-kernels/common.h
Normal file
238
csrc/cpu/sgl-kernels/common.h
Normal file
@ -0,0 +1,238 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/record_function.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#include <omp.h>
|
||||
#endif
|
||||
|
||||
namespace {
|
||||
|
||||
// dispatch bool
|
||||
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
|
||||
[&] { \
|
||||
if (BOOL_V) { \
|
||||
constexpr bool BOOL_NAME = true; \
|
||||
return __VA_ARGS__(); \
|
||||
} else { \
|
||||
constexpr bool BOOL_NAME = false; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
}()
|
||||
|
||||
// dispatch: bfloat16, float16, int8_t, fp8_e4m3
|
||||
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
|
||||
[&] { \
|
||||
switch (TYPE) { \
|
||||
case at::ScalarType::BFloat16 : { \
|
||||
using packed_t = at::BFloat16; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Half: { \
|
||||
using packed_t = at::Half; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Char : { \
|
||||
using packed_t = int8_t; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
case at::ScalarType::Float8_e4m3fn : { \
|
||||
using packed_t = at::Float8_e4m3fn; \
|
||||
return __VA_ARGS__(); \
|
||||
} \
|
||||
default: \
|
||||
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
|
||||
} \
|
||||
}()
|
||||
|
||||
#define UNUSED(x) (void)(x)
|
||||
|
||||
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimention")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CPU(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
// parallel routines
|
||||
constexpr int GRAIN_SIZE = 1024;
|
||||
|
||||
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
|
||||
inline T div_up(T x, T y) { return (x + y - 1) / y; }
|
||||
|
||||
template <typename T>
|
||||
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
|
||||
#if 0
|
||||
// onednn partition pattern
|
||||
T& n_my = n_end;
|
||||
if (nth <= 1 || n == 0) {
|
||||
n_start = 0;
|
||||
n_my = n;
|
||||
} else {
|
||||
T n1 = div_up(n, nth);
|
||||
T n2 = n1 - 1;
|
||||
T T1 = n - n2 * nth;
|
||||
n_my = ith < T1 ? n1 : n2;
|
||||
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
|
||||
}
|
||||
n_end += n_start;
|
||||
#else
|
||||
// pytorch aten partition pattern
|
||||
T n_my = div_up(n, nth);
|
||||
n_start = ith * n_my;
|
||||
n_end = std::min(n_start + n_my, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_for(int n, const func_t& f) {
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel
|
||||
{
|
||||
int nth = omp_get_num_threads();
|
||||
int ith = omp_get_thread_num();
|
||||
int tbegin, tend;
|
||||
balance211(n, nth, ith, tbegin, tend);
|
||||
f(tbegin, tend);
|
||||
}
|
||||
#else
|
||||
f(0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
// for 1d parallel, use `actual_nth`
|
||||
// for 2d parallel, use even nths, e.g. 43->42
|
||||
int inline adjust_num_threads(int m) {
|
||||
int actual_nth = at::get_num_threads();
|
||||
if (m == 1) {
|
||||
return actual_nth;
|
||||
}
|
||||
return std::max(1, (actual_nth >> 1) * 2);
|
||||
}
|
||||
|
||||
template <typename func_t>
|
||||
inline void parallel_2d(int m, int n, const func_t& f) {
|
||||
|
||||
// make sure we have even num_threads
|
||||
int nth = adjust_num_threads(m);
|
||||
|
||||
// [NOTE] thread blocking:
|
||||
//
|
||||
// 1) prefer square block per thread
|
||||
// 2) use even number of CPU cores
|
||||
// 3) use all `num_threads` cores
|
||||
//
|
||||
// we have:
|
||||
// TM * TN = T
|
||||
// BM / TM = BN / TN
|
||||
// then:
|
||||
// TM = ((BM / BN) * T) ^ 0.5
|
||||
//
|
||||
float r = float(m) / n;
|
||||
int nth_m = std::ceil(std::sqrt(r * nth));
|
||||
int nth_n = 1;
|
||||
for (; nth_m > 0; --nth_m) {
|
||||
nth_n = nth / nth_m;
|
||||
if (nth_m * nth_n == nth) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
#if defined(_OPENMP)
|
||||
#pragma omp parallel num_threads(nth)
|
||||
{
|
||||
int ith = omp_get_thread_num();
|
||||
int ith_m = ith / nth_n;
|
||||
int ith_n = ith % nth_n;
|
||||
|
||||
int thread_block_m = div_up(m, nth_m);
|
||||
int thread_block_n = div_up(n, nth_n);
|
||||
|
||||
int begin_m = ith_m * thread_block_m;
|
||||
int end_m = std::min(m, begin_m + thread_block_m);
|
||||
int begin_n = ith_n * thread_block_n;
|
||||
int end_n = std::min(n, begin_n + thread_block_n);
|
||||
|
||||
f(begin_m, end_m, begin_n, end_n);
|
||||
}
|
||||
#else
|
||||
f(0, m, 0, n);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
int get_cache_blocks(int BLOCK_SIZE, int K) {
|
||||
// L2 2MB and ratio of 50%
|
||||
const int L2_size = 2048 * 1024 >> 1;
|
||||
return std::max(1, int(L2_size / (BLOCK_SIZE * K * sizeof(T))));
|
||||
}
|
||||
|
||||
// data indexing for dimension collapse
|
||||
template <typename T>
|
||||
inline T data_index_init(T offset) {
|
||||
return offset;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
|
||||
offset = data_index_init(offset, std::forward<Args>(args)...);
|
||||
x = offset % X;
|
||||
return offset / X;
|
||||
}
|
||||
|
||||
inline bool data_index_step() {
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T, typename... Args>
|
||||
inline bool data_index_step(T& x, const T& X, Args&&... args) {
|
||||
if (data_index_step(std::forward<Args>(args)...)) {
|
||||
x = ((x + 1) == X) ? 0 : (x + 1);
|
||||
return x == 0;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
// forced unroll for perf critical path
|
||||
|
||||
#if __has_attribute(always_inline)
|
||||
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
|
||||
#else
|
||||
#define ALWAYS_INLINE inline
|
||||
#endif
|
||||
|
||||
template <int n>
|
||||
struct Unroll {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
Unroll<n - 1>{}(f, args...);
|
||||
f(std::integral_constant<int, n - 1>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
template <>
|
||||
struct Unroll<1> {
|
||||
template <typename Func, typename... Args>
|
||||
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
|
||||
f(std::integral_constant<int, 0>{}, args...);
|
||||
}
|
||||
};
|
||||
|
||||
} // anonymous namespace
|
||||
464
csrc/cpu/sgl-kernels/gemm.cpp
Normal file
464
csrc/cpu/sgl-kernels/gemm.cpp
Normal file
@ -0,0 +1,464 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
// packed layout:
|
||||
// quants {N, K} int8_t
|
||||
// comp {N} int32_t
|
||||
template <int BLOCK_N>
|
||||
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
__m512i vcomp[COLS];
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
vcomp[col] = _mm512_setzero_si512();
|
||||
}
|
||||
|
||||
const int64_t offset = BLOCK_N * K;
|
||||
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
|
||||
for (int k = 0; k < K / 4; ++k) {
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
__m512i vb = _mm512_loadu_si512((const __m512i *)(packed + k * BLOCK_N * 4 + col * 64));
|
||||
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
|
||||
}
|
||||
}
|
||||
|
||||
for (int col = 0; col < COLS; ++col) {
|
||||
_mm512_storeu_si512((__m512i *)(packed + offset + col * 64), vcomp[col]);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "s8s8_compensation not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
// convert to vnni format
|
||||
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
|
||||
template <typename packed_t>
|
||||
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
|
||||
const int VNNI_BLK = 2;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(N == BLOCK_N);
|
||||
|
||||
const int VNNI_BLK = 4;
|
||||
for (int n = 0; n < N; ++n) {
|
||||
for (int k = 0; k < K / VNNI_BLK; ++k) {
|
||||
for (int d = 0; d < VNNI_BLK; ++d) {
|
||||
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
|
||||
}
|
||||
}
|
||||
}
|
||||
s8s8_compensation<BLOCK_N>(packed, K);
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::BFloat16* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_set1_ps(0.f);
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K2 = K >> 1;
|
||||
const int64_t lda2 = lda >> 1;
|
||||
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const float* b_ptr = reinterpret_cast<const float*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K2; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
// for COLS = 1, 3 use 256bit store
|
||||
if constexpr (COLS % 2 == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
} else {
|
||||
_mm256_storeu_si256(
|
||||
reinterpret_cast<__m256i*>(C + row * ldc + col * 16),
|
||||
(__m256i)(_mm512_cvtneps_pbh(vc[i])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, const float* __restrict__ bias,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb, BLOCK_N, /* add_C */false,
|
||||
A, B, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int64_t m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, has_bias>::apply(
|
||||
A, B, C, Ctmp, bias,
|
||||
M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void weight_packed_linear_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const scalar_t* __restrict__ mat2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// use avx512-bf16 when a) M is small; b) dtype is bfloat16, otherwise use amx
|
||||
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
|
||||
|
||||
// l2 cache block for n
|
||||
int64_t cache_blocks_nb = get_cache_blocks<scalar_t>(BLOCK_N, K);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
parallel_2d(MB, NB, [&](int64_t begin_mb, int64_t end_mb, int64_t begin_nb, int64_t end_nb) {
|
||||
|
||||
// for brgemm, use float32 for accumulate
|
||||
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int64_t nbb = begin_nb; nbb < end_nb; nbb += cache_blocks_nb) {
|
||||
for (int64_t mb = begin_mb; mb < end_mb; ++mb) {
|
||||
for (int64_t nb = nbb; nb < std::min(nbb + cache_blocks_nb, end_nb); ++nb) {
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm);
|
||||
}}}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const scalar_t* __restrict__ A, const scalar_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, const TYPE* __restrict__ B, TYPE* __restrict__ C, \
|
||||
float* __restrict__ Ctmp, int64_t M, int64_t N, int64_t K, int64_t lda, \
|
||||
int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight) {
|
||||
// for 3d moe weights
|
||||
// weight : [E, OC, IC]
|
||||
// w1 : [E, 2N, K]
|
||||
// w2 : [E, K, N]
|
||||
CHECK_INPUT(weight);
|
||||
|
||||
const int64_t ndim = weight.ndimension();
|
||||
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
|
||||
const auto st = weight.scalar_type();
|
||||
const int64_t E = ndim == 3 ? weight.size(0) : 1;
|
||||
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
|
||||
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
|
||||
|
||||
// we handle 2 TILE_N at a time.
|
||||
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
|
||||
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
|
||||
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t NB = div_up(OC, BLOCK_N);
|
||||
|
||||
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
|
||||
auto packed_weight = at::empty({}, weight.options());
|
||||
const int64_t stride = OC * IC;
|
||||
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf || st == at::kChar || st == at::kFloat8_e4m3fn,
|
||||
"expect weight to be bfloat16, float16, int8 or fp8_e4m3.");
|
||||
|
||||
CPU_DISPATCH_PACKED_TYPES(st, [&] {
|
||||
// adjust most inner dimension size
|
||||
const int packed_row_size = get_row_size<packed_t>(IC);
|
||||
auto sizes = weight.sizes().vec();
|
||||
sizes[ndim - 1] = packed_row_size;
|
||||
packed_weight.resize_(sizes);
|
||||
|
||||
const packed_t* w_data = weight.data_ptr<packed_t>();
|
||||
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
|
||||
|
||||
// parallel on {E, NB}
|
||||
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t e{0}, nb{0};
|
||||
data_index_init(begin, e, E, nb, NB);
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
|
||||
int64_t n = nb * BLOCK_N;
|
||||
int64_t n_size = std::min(BLOCK_N, OC - n);
|
||||
pack_vnni<packed_t>(
|
||||
packed_data + e * OC * packed_row_size + n * packed_row_size,
|
||||
w_data + e * stride + n * IC,
|
||||
n_size,
|
||||
IC);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(e, E, nb, NB);
|
||||
}
|
||||
});
|
||||
});
|
||||
return packed_weight;
|
||||
}
|
||||
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias, bool is_vnni) {
|
||||
RECORD_FUNCTION(
|
||||
"sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options());
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
|
||||
weight_packed_linear_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<scalar_t>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
266
csrc/cpu/sgl-kernels/gemm.h
Normal file
266
csrc/cpu/sgl-kernels/gemm.h
Normal file
@ -0,0 +1,266 @@
|
||||
#pragma once
|
||||
|
||||
#include <ATen/native/CPUBlas.h>
|
||||
|
||||
// clang-format off
|
||||
|
||||
// amx-bf16
|
||||
#define TILE_M 16
|
||||
#define TILE_N 16
|
||||
#define TILE_K 32
|
||||
|
||||
// block size for AMX gemm
|
||||
constexpr int block_size_m() { return 2 * TILE_M; }
|
||||
constexpr int block_size_n() { return 2 * TILE_N; }
|
||||
|
||||
// define threshold using brgemm (intel AMX)
|
||||
template <typename T> inline bool can_use_brgemm(int M);
|
||||
template <> inline bool can_use_brgemm<at::BFloat16>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::Half>(int M) { return true; }
|
||||
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
|
||||
template <> inline bool can_use_brgemm<int8_t>(int M) { return false; }
|
||||
template <> inline bool can_use_brgemm<at::Float8_e4m3fn>(int M) { return M > 4; }
|
||||
template <> inline bool can_use_brgemm<at::quint4x2>(int M) { return M > 4; }
|
||||
|
||||
// work around compiler internal error
|
||||
#define BLOCK_K 128 // 4 * TILE_K
|
||||
|
||||
// adjust leading dimension size for K
|
||||
template <typename T>
|
||||
inline int64_t get_row_size(int64_t K) {
|
||||
return K;
|
||||
}
|
||||
|
||||
template <>
|
||||
inline int64_t get_row_size<int8_t>(int64_t K) {
|
||||
return K + sizeof(int32_t);
|
||||
}
|
||||
|
||||
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
|
||||
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
|
||||
}
|
||||
|
||||
// pack weight to vnni format
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
// moe implementations for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for fp8 w8a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// moe implementations for int4 w4a16
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int4_w4a16_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::quint4x2* __restrict__ packed_w1,
|
||||
const at::quint4x2* __restrict__ packed_w2,
|
||||
const uint8_t* __restrict__ w1z,
|
||||
const uint8_t* __restrict__ w2z,
|
||||
const scalar_t* __restrict__ w1s,
|
||||
const scalar_t* __restrict__ w2s,
|
||||
int group_size,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad);
|
||||
|
||||
// shared expert implememntation for int8 w8a8
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K);
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const scalar_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::quint4x2* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
const uint8_t* __restrict__ Bz,
|
||||
const scalar_t* __restrict__ Bs,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int group_size,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
int64_t strideBz,
|
||||
int64_t strideBs,
|
||||
bool brg);
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
inline void print_16x32i(const __m512i x) {
|
||||
int32_t a[16];
|
||||
_mm512_storeu_si512((__m512i *)a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
inline void print_16x32(const __m512 x) {
|
||||
float a[16];
|
||||
_mm512_storeu_ps((__m512 *)a, x);
|
||||
|
||||
for (int i = 0; i < 16; i++){
|
||||
std::cout << a[i] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
|
||||
inline void print_32x8u(const __m256i x) {
|
||||
uint8_t a[32];
|
||||
_mm256_storeu_si256((__m256i *)a, x);
|
||||
|
||||
for (int i = 0; i < 32; ++i) {
|
||||
std::cout << int32_t(a[i]) << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
530
csrc/cpu/sgl-kernels/gemm_fp8.cpp
Normal file
530
csrc/cpu/sgl-kernels/gemm_fp8.cpp
Normal file
@ -0,0 +1,530 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
// we use 4x32 for BLOCK_M
|
||||
#define BLOCK_SIZE_M_SCALE 4
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_add_stub(scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
|
||||
}
|
||||
}
|
||||
|
||||
inline void unpack_B(
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_B,
|
||||
int N,
|
||||
int K,
|
||||
int ldb,
|
||||
int ldb_tmp,
|
||||
float scale) {
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
// [K/2, N, 2]
|
||||
const int K2 = K >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(packed_B);
|
||||
const __m512 vd = _mm512_set1_ps(scale);
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
static_assert(BLOCK_N == 32);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
|
||||
#pragma GCC unroll 4
|
||||
for (int k = 0; k < K2; ++k) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
|
||||
}
|
||||
|
||||
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
|
||||
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
|
||||
|
||||
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
|
||||
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
|
||||
|
||||
// Apply scale
|
||||
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
|
||||
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
|
||||
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
|
||||
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
|
||||
|
||||
f0_lo = _mm512_mul_ps(f0_lo, vd);
|
||||
f0_hi = _mm512_mul_ps(f0_hi, vd);
|
||||
f1_lo = _mm512_mul_ps(f1_lo, vd);
|
||||
f1_hi = _mm512_mul_ps(f1_hi, vd);
|
||||
|
||||
bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
|
||||
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
|
||||
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
|
||||
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
|
||||
}
|
||||
#else
|
||||
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A, const packed_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A, const at::Float8_e4m3fn* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ bias, const float* __restrict__ scale, int K, int lda, int ldb, int ldc, int64_t block_size_K) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
|
||||
const int KB = div_up(K, BLOCK_K);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 64;
|
||||
constexpr int PREFETCH_SIZE_KB = 1;
|
||||
|
||||
__m512bh va;
|
||||
__m512bh vb[COLS];
|
||||
__m512 vc[ROWS * COLS];
|
||||
__m512 vsum[ROWS * COLS];
|
||||
|
||||
// block quant scale
|
||||
__m512 vscale;
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
constexpr int col = i % COLS;
|
||||
if constexpr (has_bias) {
|
||||
vc[i] = _mm512_loadu_ps(bias + col * 16);
|
||||
} else {
|
||||
vc[i] = _mm512_setzero_ps();
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int lda2 = lda >> 1;
|
||||
const int ldb2 = ldb; // ldb * 2 >> 1;
|
||||
const float* a_ptr = reinterpret_cast<const float*>(A);
|
||||
const uint16_t* b_ptr = reinterpret_cast<const uint16_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
|
||||
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
|
||||
}
|
||||
}
|
||||
vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
|
||||
};
|
||||
|
||||
constexpr int BLOCK_K2 = BLOCK_K >> 1;
|
||||
for (int kb = 0; kb < KB; ++kb) {
|
||||
int kb_start = kb * BLOCK_K2;
|
||||
int kb_end = std::min(K, kb_start + BLOCK_K2);
|
||||
// 1. load scale vector
|
||||
vscale = _mm512_set1_ps(scale[kb]);
|
||||
if constexpr (PREFETCH_SIZE_KB > 0) {
|
||||
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
|
||||
}
|
||||
// 2. zero vsum for each block
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vsum[i] = _mm512_setzero_ps();
|
||||
});
|
||||
// 3. accumulate across each block
|
||||
for (int k = kb_start; k < kb_end; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
// 4. apply scale
|
||||
Unroll<ROWS * COLS>{}([&](auto i) {
|
||||
vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]);
|
||||
});
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2,4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, at::Float8_e4m3fn, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, scale, K, lda, ldb, ldc, block_size_K);
|
||||
|
||||
template <typename scalar_t, typename packed_t, bool has_bias>
|
||||
struct brgemm {
|
||||
static inline void apply(
|
||||
const scalar_t* __restrict__ A,
|
||||
const packed_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
TORCH_CHECK(false, "struct brgemm: primary template not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
template <bool has_bias>
|
||||
struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
|
||||
static inline void apply(
|
||||
const at::BFloat16* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
at::BFloat16* __restrict__ C,
|
||||
at::BFloat16* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ bias,
|
||||
const float* __restrict__ scale,
|
||||
int M,
|
||||
int N,
|
||||
int K,
|
||||
int lda,
|
||||
int ldb,
|
||||
int ldc) {
|
||||
|
||||
constexpr int BLOCK_N = block_size_n();
|
||||
|
||||
// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
|
||||
const int ldb_tmp = BLOCK_N;
|
||||
|
||||
for (int k = 0; k < K; k += BLOCK_K) {
|
||||
int kb_size = std::min(BLOCK_K, K - k);
|
||||
|
||||
int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
|
||||
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
|
||||
}
|
||||
|
||||
at::native::cpublas::brgemm(
|
||||
M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
|
||||
|
||||
// copy from Ctmp to C
|
||||
for (int m = 0; m < M; ++m) {
|
||||
if constexpr (has_bias) {
|
||||
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
|
||||
} else {
|
||||
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
|
||||
}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
|
||||
if (brg) {
|
||||
brgemm<scalar_t, at::Float8_e4m3fn, has_bias>::apply(
|
||||
A, B, C, Btmp, Ctmp, bias, scale, M, N, K, lda, ldb, ldc);
|
||||
return;
|
||||
}
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
void fp8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ mat1,
|
||||
const at::Float8_e4m3fn* __restrict__ mat2,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
scalar_t* __restrict__ buffer,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t mat1_strideM,
|
||||
int64_t out_strideM,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
int64_t buffer_size_per_thread) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
const int64_t scale_size_K = div_up(K, block_size_K);
|
||||
const int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
// parallel on [MB, NB]
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
|
||||
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
const float* scale_ptr = scales2 + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * mat1_strideM,
|
||||
/* B */ mat2 + nb_start * K, // nb * BLOCK_N * K
|
||||
/* C */ out + mb_start * out_strideM + nb_start,
|
||||
/* Btmp */ Btmp,
|
||||
/* Ctmp */ Ctmp,
|
||||
/* scale */ scale_ptr,
|
||||
/* bias */ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ mat1_strideM,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ out_strideM,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const scalar_t* __restrict__ A,
|
||||
const at::Float8_e4m3fn* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
scalar_t* __restrict__ Btmp,
|
||||
float* __restrict__ Ctmp,
|
||||
const float* __restrict__ scale,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg,
|
||||
int64_t block_size_K) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Btmp, Ctmp, scale, nullptr, M, N, K, lda, ldb, ldc, brg, block_size_K);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const TYPE* __restrict__ A, \
|
||||
const at::Float8_e4m3fn* __restrict__ B, \
|
||||
TYPE* __restrict__ C, \
|
||||
TYPE* __restrict__ Btmp, \
|
||||
float* __restrict__ Ctmp, \
|
||||
const float* __restrict__ scale, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t lda, \
|
||||
int64_t ldb, \
|
||||
int64_t ldc, \
|
||||
bool brg, \
|
||||
int64_t block_size_K)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
at::Tensor fp8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
std::vector<int64_t> block_size, std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::fp8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, block_size, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales2 to be float32.");
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat2.size(1);
|
||||
|
||||
CHECK_EQ(mat1.size(1), K);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
TORCH_CHECK(block_size.size() == 2,
|
||||
"fp8_scaled_mm_cpu: expect block_size.size() to be 2.");
|
||||
|
||||
int64_t block_size_N = block_size[0];
|
||||
int64_t block_size_K = block_size[1];
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
|
||||
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");
|
||||
CHECK_EQ(scales2.size(0), div_up(N, block_size_N));
|
||||
CHECK_EQ(scales2.size(1), div_up(K, block_size_K));
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"fp8_scaled_mm_cpu: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"fp8_scaled_mm_cpu: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"fp8_scaled_mm_cpu: expect mat2 to be fp8_e4m3.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"fp8_scaled_mm_cpu: expect scales to be float32.");
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
// strides
|
||||
int64_t mat1_strideM = mat1.stride(0);
|
||||
int64_t out_strideM = out.stride(0);
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
// Btmp : [T, BLOCK_N * K]
|
||||
// Ctmp : [T, BLOCK_M * BLOCK_N]
|
||||
int num_threads = at::get_num_threads();
|
||||
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
|
||||
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
|
||||
fp8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<scalar_t>(),
|
||||
packed_w.data_ptr<at::Float8_e4m3fn>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
buffer.data_ptr<scalar_t>(),
|
||||
M,
|
||||
N,
|
||||
K,
|
||||
mat1_strideM,
|
||||
out_strideM,
|
||||
block_size_N,
|
||||
block_size_K,
|
||||
size_per_thread);
|
||||
});
|
||||
|
||||
return out;
|
||||
}
|
||||
440
csrc/cpu/sgl-kernels/gemm_int8.cpp
Normal file
440
csrc/cpu/sgl-kernels/gemm_int8.cpp
Normal file
@ -0,0 +1,440 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <bool has_bias, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
const float* __restrict__ bias, int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
// prefetch distance
|
||||
constexpr int PREFETCH_SIZE_K = 0;
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 vd0;
|
||||
__m512 vd1[COLS];
|
||||
|
||||
// oops! 4x4 spills but luckly we use 4x2
|
||||
__m512 vbias[COLS];
|
||||
|
||||
// [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
//
|
||||
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
|
||||
//
|
||||
// a * b = (a + 128) * b - 128 * b
|
||||
// s s u s u s
|
||||
//
|
||||
// 1) 128 * b is pre-computed when packing B to vnni formats
|
||||
// 2) a + 128 is fused when dynamically quantize A
|
||||
//
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
if constexpr (PREFETCH_SIZE_K > 0) {
|
||||
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
|
||||
}
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
vd0 = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
if constexpr (has_bias) {
|
||||
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
|
||||
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
|
||||
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
|
||||
if constexpr (has_bias) {
|
||||
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
|
||||
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
|
||||
} else {
|
||||
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
|
||||
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
|
||||
}
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
has_bias ? bias + nb_start : nullptr, K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t, bool has_bias>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc,
|
||||
bool brg) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
// mb_size = 1
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_NN(1, 32); break;
|
||||
case 0x14: LAUNCH_TINYGEMM_KERNEL_NN(1, 64); break;
|
||||
// mb_size = 2
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_NN(2, 32); break;
|
||||
case 0x24: LAUNCH_TINYGEMM_KERNEL_NN(2, 64); break;
|
||||
// mb_size = 3
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_NN(3, 32); break;
|
||||
case 0x34: LAUNCH_TINYGEMM_KERNEL_NN(3, 64); break;
|
||||
// mb_size = 4
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_NN(4, 32); break;
|
||||
case 0x44: LAUNCH_TINYGEMM_KERNEL_NN(4, 64); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename scalar_t>
|
||||
void int8_scaled_mm_kernel_impl(
|
||||
scalar_t* __restrict__ out,
|
||||
const uint8_t* __restrict__ mat1,
|
||||
const int8_t* __restrict__ mat2,
|
||||
const float* __restrict__ scales1,
|
||||
const float* __restrict__ scales2,
|
||||
const float* __restrict__ bias,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
|
||||
const bool use_brgemm = false;
|
||||
|
||||
// K + 4 after compensation
|
||||
const int64_t packed_row_size = get_row_size<int8_t>(K);
|
||||
|
||||
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int64_t mb{0}, nb{0};
|
||||
data_index_init(begin, mb, MB, nb, NB);
|
||||
|
||||
// for brgemm, use int32_t for accumulate
|
||||
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
|
||||
|
||||
for (int i = begin; i < end; ++i) {
|
||||
UNUSED(i);
|
||||
int mb_start = mb * BLOCK_M;
|
||||
int mb_size = std::min(M - mb_start, BLOCK_M);
|
||||
int nb_start = nb * BLOCK_N;
|
||||
int nb_size = std::min(N - nb_start, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t, has_bias>(
|
||||
/* A */ mat1 + mb_start * K,
|
||||
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
|
||||
/* C */ out + mb_start * N + nb_start,
|
||||
/* Ctmp*/ Ctmp,
|
||||
/* As */ scales1 + mb_start,
|
||||
/* Bs */ scales2 + nb_start,
|
||||
/* bias*/ bias + nb_start,
|
||||
/* M */ mb_size,
|
||||
/* N */ nb_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ nb_size,
|
||||
/* ldc */ N,
|
||||
/* brg */ use_brgemm);
|
||||
|
||||
// move to the next index
|
||||
data_index_step(mb, MB, nb, NB);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
});
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
// tinygemm interface
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(const uint8_t* __restrict__ A, const int8_t* __restrict__ B, scalar_t* __restrict__ C,
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs,
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg) {
|
||||
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
|
||||
}
|
||||
|
||||
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
|
||||
template void tinygemm_kernel<TYPE>( \
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, TYPE* __restrict__ C, \
|
||||
int32_t* __restrict__ Ctmp, const float* __restrict__ As, const float* __restrict__ Bs, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t lda, int64_t ldb, int64_t ldc, bool brg)
|
||||
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
|
||||
|
||||
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
|
||||
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
|
||||
CHECK_DIM(2, A);
|
||||
|
||||
int64_t M = A.size(0);
|
||||
int64_t K = A.size(1);
|
||||
int64_t lda = A.stride(0);
|
||||
|
||||
const auto st = A.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"per_token_quant_int8: expect A to be bfloat16 or half.");
|
||||
|
||||
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
|
||||
auto As = at::empty({M}, A.options().dtype(at::kFloat));
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
|
||||
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = As.data_ptr<float>();
|
||||
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
});
|
||||
return std::make_tuple(Aq, As);
|
||||
}
|
||||
|
||||
// weight : static, per-channel, symmetric
|
||||
// activation : dynamic, per-token, symmetric
|
||||
//
|
||||
// mat1 : [M, K]
|
||||
// mat2 : [N, K]
|
||||
// scales1 : [M]
|
||||
// scales2 : [N]
|
||||
// bias : [N]
|
||||
// out : [M, N]
|
||||
//
|
||||
at::Tensor int8_scaled_mm_cpu(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales1, at::Tensor& scales2,
|
||||
std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales1);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales1.numel(), M);
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm: expect scales to be float32.");
|
||||
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
mat1.data_ptr<uint8_t>(),
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
scales1.data_ptr<float>(),
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
|
||||
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2, at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias, at::ScalarType out_dtype, bool is_vnni) {
|
||||
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
|
||||
|
||||
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
|
||||
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
|
||||
CHECK_INPUT(mat2);
|
||||
CHECK_INPUT(scales2);
|
||||
CHECK_DIM(2, mat1);
|
||||
CHECK_DIM(2, mat2);
|
||||
|
||||
int64_t M = mat1.size(0);
|
||||
int64_t N = mat2.size(0);
|
||||
int64_t K = mat1.size(1);
|
||||
int64_t lda = mat1.stride(0);
|
||||
|
||||
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
|
||||
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
|
||||
CHECK_EQ(scales2.numel(), N);
|
||||
|
||||
const auto st = mat1.scalar_type();
|
||||
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf,
|
||||
"int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
|
||||
TORCH_CHECK(st == out_dtype,
|
||||
"int8_scaled_mm_with_quant: expect A has same dtype with out_dtype.");
|
||||
TORCH_CHECK(mat2.scalar_type() == at::kChar,
|
||||
"int8_scaled_mm_with_quant: expect mat2 to be int8.");
|
||||
TORCH_CHECK(scales2.scalar_type() == at::kFloat,
|
||||
"int8_scaled_mm_with_quant: expect scales to be float32.");
|
||||
|
||||
const int64_t buffer_size = M * K + M * sizeof(float);
|
||||
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
|
||||
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
|
||||
|
||||
const bool has_bias = bias.has_value();
|
||||
const float* bias_data = nullptr;
|
||||
if (has_bias) {
|
||||
CHECK_EQ(bias.value().size(0), N);
|
||||
bias_data = bias.value().data_ptr<float>();
|
||||
}
|
||||
|
||||
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
|
||||
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
|
||||
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
|
||||
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
|
||||
|
||||
at::parallel_for(0, M, 0, [&] (int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_data + m * K,
|
||||
As_data[m],
|
||||
A_data + m * lda,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
int8_scaled_mm_kernel_impl<scalar_t>(
|
||||
out.data_ptr<scalar_t>(),
|
||||
Aq_data,
|
||||
packed_w.data_ptr<int8_t>(),
|
||||
As_data,
|
||||
scales2.data_ptr<float>(),
|
||||
bias_data,
|
||||
M,
|
||||
N,
|
||||
K);
|
||||
});
|
||||
return out;
|
||||
}
|
||||
1330
csrc/cpu/sgl-kernels/moe.cpp
Normal file
1330
csrc/cpu/sgl-kernels/moe.cpp
Normal file
File diff suppressed because it is too large
Load Diff
502
csrc/cpu/sgl-kernels/moe_fp8.cpp
Normal file
502
csrc/cpu/sgl-kernels/moe_fp8.cpp
Normal file
@ -0,0 +1,502 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "gemm.h"
|
||||
#include "vec.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
x0 = x0 * weight_vec;
|
||||
x1 = x1 * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
float scale,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
bVec x_bvec = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void silu_and_mul_stub(
|
||||
scalar_t* __restrict__ out,
|
||||
const scalar_t* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2,
|
||||
int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
const fVec one = fVec(1.f);
|
||||
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += bVec::size()) {
|
||||
bVec x = bVec::loadu(input + d);
|
||||
fVec x0, x1;
|
||||
std::tie(x0, x1) = at::vec::convert_to_float(x);
|
||||
bVec y = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y);
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
scalar_t* __restrict__ A_tmp,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_N = div_up(2 * N, block_size_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const int64_t stride_e = 2 * N * K;
|
||||
const int64_t stride_n = K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w1 + expert_id * stride_e + nb * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs = w1s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, input + index * K, K);
|
||||
}
|
||||
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
scale_size_N = div_up(K, block_size_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
const int64_t stride_e2 = OC * IC;
|
||||
const int64_t stride_oc = IC;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
bool is_brgemm_used = false;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(m_size);
|
||||
is_brgemm_used = is_brgemm_used || use_brgemm;
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const at::Float8_e4m3fn* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * scale_size_N * scale_size_K + (nb / blocks_n_per_group) * scale_size_K;
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
|
||||
if (is_brgemm_used) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_FP8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, \
|
||||
TYPE* __restrict__ A_tmp, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const float* __restrict__ topk_weights, \
|
||||
const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, \
|
||||
const int32_t* __restrict__ offsets, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K, \
|
||||
int64_t E, \
|
||||
int64_t topk, \
|
||||
int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_FP8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_fp8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic0,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ B_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1,
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
int64_t block_size_N,
|
||||
int64_t block_size_K,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 1: intermediate_cache0 = hidden_states @ w1
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(2 * N, BLOCK_N);
|
||||
int64_t scale_size_K = div_up(K, block_size_K);
|
||||
int64_t blocks_n_per_group = block_size_N / BLOCK_N;
|
||||
|
||||
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);
|
||||
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(2 * N - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ input + mb * BLOCK_M * K,
|
||||
/* B */ packed_w1 + nb * BLOCK_N * K,
|
||||
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ 2 * N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
}
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: intermediate_cache1 = silu(intermediate_cache0)
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
silu_and_mul_stub(
|
||||
ic1 + m * N,
|
||||
ic0 + m * 2 * N,
|
||||
ic0 + m * 2 * N + N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(K, BLOCK_N);
|
||||
scale_size_K = div_up(N, block_size_K);
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
int tid = at::get_thread_num();
|
||||
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ ic1 + mb * BLOCK_M * N,
|
||||
/* B */ packed_w2 + nb * BLOCK_N * N,
|
||||
/* C */ C,
|
||||
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
|
||||
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
|
||||
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N,
|
||||
/* brg */ use_brgemm,
|
||||
/* block_size_K */ block_size_K);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (use_brgemm) {
|
||||
at::native::cpublas::brgemm_release();
|
||||
}
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_fp8_kernel_impl<TYPE>( \
|
||||
TYPE* __restrict__ output, \
|
||||
TYPE* __restrict__ ic0, \
|
||||
TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ B_tmp, \
|
||||
float* __restrict__ C_tmp, \
|
||||
const TYPE* __restrict__ input, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w1, \
|
||||
const at::Float8_e4m3fn* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, \
|
||||
const float* __restrict__ w2s, \
|
||||
int64_t block_size_N, \
|
||||
int64_t block_size_K, \
|
||||
const TYPE* __restrict__ fused_experts_out, \
|
||||
float routed_scaling_factor, \
|
||||
int64_t M, \
|
||||
int64_t N, \
|
||||
int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_FP8_TEMPLATE(at::Half);
|
||||
769
csrc/cpu/sgl-kernels/moe_int8.cpp
Normal file
769
csrc/cpu/sgl-kernels/moe_int8.cpp
Normal file
@ -0,0 +1,769 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#include "common.h"
|
||||
#include "vec.h"
|
||||
#include "gemm.h"
|
||||
|
||||
// clang-format off
|
||||
|
||||
namespace {
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
|
||||
using Vec = at::vec::Vectorized<scalar_t>;
|
||||
// no remainder
|
||||
#pragma GCC unroll 4
|
||||
for (int64_t d = 0; d < size; d += Vec::size()) {
|
||||
Vec data = Vec::loadu(input + d);
|
||||
data.store(out + d);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
|
||||
// size might be 64x + 32
|
||||
std::memcpy(out, input, size * sizeof(uint8_t));
|
||||
}
|
||||
|
||||
template <typename scalar_t>
|
||||
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec weight_vec = fVec(weight);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec data0 = fVec::loadu(input + d) * weight_vec;
|
||||
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] * weight);
|
||||
}
|
||||
}
|
||||
|
||||
// acc from [topk, K] to [K]
|
||||
template <typename scalar_t>
|
||||
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
if (topk == 1) {
|
||||
// do copy for topk = 1
|
||||
copy_stub(out, input, K);
|
||||
} else {
|
||||
// do sum for topk != 1
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= K - kVecSize; d += kVecSize) {
|
||||
fVec sum_fvec0 = fVec(0.f);
|
||||
fVec sum_fvec1 = fVec(0.f);
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
bVec x_bvec = bVec::loadu(input + t * K + d);
|
||||
fVec x_fvec0, x_fvec1;
|
||||
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
|
||||
|
||||
sum_fvec0 += x_fvec0;
|
||||
sum_fvec1 += x_fvec1;
|
||||
}
|
||||
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
|
||||
out_bvec.store(out + d);
|
||||
}
|
||||
for (; d < K; ++d) {
|
||||
float sum_val = 0.f;
|
||||
for (int t = 0; t < topk; ++t) {
|
||||
sum_val += static_cast<float>(input[t * K + d]);
|
||||
}
|
||||
out[d] = static_cast<scalar_t>(sum_val);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// out = input + input2 * scale
|
||||
template <typename scalar_t>
|
||||
inline void add_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input,
|
||||
const scalar_t* __restrict__ input2, float scale, int64_t size) {
|
||||
|
||||
using bVec = at::vec::Vectorized<scalar_t>;
|
||||
using fVec = at::vec::Vectorized<float>;
|
||||
constexpr int kVecSize = bVec::size();
|
||||
const fVec s_vec = fVec(scale);
|
||||
int64_t d;
|
||||
#pragma GCC unroll 4
|
||||
for (d = 0; d <= size - kVecSize; d += kVecSize) {
|
||||
fVec x0 = fVec::loadu(input + d);
|
||||
fVec x1 = fVec::loadu(input + d + fVec::size());
|
||||
|
||||
bVec y_bvec = bVec::loadu(input2 + d);
|
||||
fVec y0, y1;
|
||||
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
|
||||
|
||||
x0 = x0 + y0 * s_vec;
|
||||
x1 = x1 + y1 * s_vec;
|
||||
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
|
||||
out_vec.store(out + d);
|
||||
}
|
||||
for (; d < size; ++d) {
|
||||
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w13
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B0, const int8_t* __restrict__ B1, at::BFloat16* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs0, const float* __restrict__ Bs1,
|
||||
const int32_t* __restrict__ Bcomp0, const int32_t* __restrict__ Bcomp1,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb0[COLS];
|
||||
__m512i vb1[COLS];
|
||||
__m512i vc0[ROWS * COLS];
|
||||
__m512i vc1[ROWS * COLS];
|
||||
__m512i vcomp0[COLS];
|
||||
__m512i vcomp1[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs0[COLS];
|
||||
__m512 vbs1[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc0[i] = _mm512_set1_epi32(0);
|
||||
vc1[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
|
||||
const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
|
||||
vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
|
||||
vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto scalec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp
|
||||
if constexpr (row == 0) {
|
||||
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
|
||||
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
|
||||
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
|
||||
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
|
||||
}
|
||||
__m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
|
||||
__m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
|
||||
vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, was), vbs0[col]));
|
||||
vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, was), vbs1[col]));
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(scalec);
|
||||
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
const Vec one = Vec(1.f);
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
// for COLS = 2, 4 use 512bit store
|
||||
if constexpr (col % 2 == 0) {
|
||||
Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
|
||||
Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
|
||||
Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
|
||||
Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);
|
||||
// silu
|
||||
x0 = x0 / (one + x0.neg().exp_u20());
|
||||
x1 = x1 / (one + x1.neg().exp_u20());
|
||||
// mul
|
||||
x0 = x0 * y0;
|
||||
x1 = x1 * y1;
|
||||
|
||||
_mm512_storeu_si512(
|
||||
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
|
||||
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
|
||||
}
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B0 + nb_start * 4, B1 + nb_start * 4, \
|
||||
C + mb_start * ldc + nb_start, As + mb_start, \
|
||||
Bs0 + nb_start, Bs1 + nb_start, Bcomp0 + nb_start, Bcomp1 + nb_start,\
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B0,
|
||||
const int8_t* __restrict__ B1,
|
||||
scalar_t* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs0,
|
||||
const float* __restrict__ Bs1,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
|
||||
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
|
||||
|
||||
// pattern: 1-(2+2)-(8+8)
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 32;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// gemm for w2
|
||||
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2 {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
|
||||
}
|
||||
};
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <int BLOCK_M, int BLOCK_N>
|
||||
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
|
||||
static inline void apply(
|
||||
const uint8_t* __restrict__ A, const int8_t* __restrict__ B, float* __restrict__ C,
|
||||
const float* __restrict__ As, const float* __restrict__ Bs, const int32_t* __restrict__ Bcomp,
|
||||
int64_t K, int64_t lda, int64_t ldb, int64_t ldc) {
|
||||
|
||||
constexpr int ROWS = BLOCK_M;
|
||||
constexpr int COLS = BLOCK_N / 16;
|
||||
static_assert(COLS % 2 == 0);
|
||||
|
||||
__m512i va;
|
||||
__m512i vb[COLS];
|
||||
__m512i vc[ROWS * COLS];
|
||||
__m512i vcomp[COLS];
|
||||
__m512 was;
|
||||
__m512 vbs[COLS];
|
||||
|
||||
auto loadc = [&](auto i) {
|
||||
vc[i] = _mm512_set1_epi32(0);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(loadc);
|
||||
|
||||
const int64_t K4 = K >> 2;
|
||||
const int64_t lda4 = lda >> 2;
|
||||
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
|
||||
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
|
||||
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
|
||||
|
||||
auto compute = [&](auto i, int64_t k) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
if constexpr (col == 0) {
|
||||
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
|
||||
}
|
||||
if constexpr (row == 0) {
|
||||
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
|
||||
}
|
||||
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
|
||||
};
|
||||
for (int64_t k = 0; k < K4; ++k) {
|
||||
Unroll<ROWS * COLS>{}(compute, k);
|
||||
}
|
||||
|
||||
auto storec = [&](auto i) {
|
||||
constexpr int row = i / COLS;
|
||||
constexpr int col = i % COLS;
|
||||
|
||||
// load a scale
|
||||
if constexpr(col == 0) {
|
||||
was = _mm512_set1_ps(As[row]);
|
||||
}
|
||||
// load b scale and vcomp per 2 vectors
|
||||
// also load bias if any
|
||||
if constexpr (row == 0) {
|
||||
if constexpr (col % 2 == 0) {
|
||||
vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
|
||||
vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
|
||||
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
|
||||
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
|
||||
}
|
||||
}
|
||||
__m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
|
||||
x = _mm512_mul_ps(_mm512_mul_ps(x, was), vbs[col]);
|
||||
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
|
||||
};
|
||||
Unroll<ROWS * COLS>{}(storec);
|
||||
}
|
||||
};
|
||||
#endif
|
||||
|
||||
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \
|
||||
tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
|
||||
A + mb_start * lda, B + nb_start * 4, C + mb_start * ldc + nb_start, \
|
||||
As + mb_start, Bs + nb_start, Bcomp + nb_start, \
|
||||
K, lda, ldb, ldc);
|
||||
|
||||
template <typename scalar_t>
|
||||
void tinygemm_kernel(
|
||||
const uint8_t* __restrict__ A,
|
||||
const int8_t* __restrict__ B,
|
||||
float* __restrict__ C,
|
||||
const float* __restrict__ As,
|
||||
const float* __restrict__ Bs,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t lda,
|
||||
int64_t ldb,
|
||||
int64_t ldc) {
|
||||
|
||||
// B compensation
|
||||
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
|
||||
|
||||
// pattern: 1-4-16
|
||||
constexpr int64_t BLOCK_M = 4;
|
||||
constexpr int64_t BLOCK_N = 64;
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
for (int64_t mb = 0; mb < MB; ++mb) {
|
||||
int64_t mb_start = mb * BLOCK_M;
|
||||
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
|
||||
for (int64_t nb = 0; nb < NB; ++nb) {
|
||||
int64_t nb_start = nb * BLOCK_N;
|
||||
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
|
||||
|
||||
switch(mb_size << 4 | nb_size >> 4) {
|
||||
case 0x12: LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32); break;
|
||||
case 0x22: LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32); break;
|
||||
case 0x32: LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32); break;
|
||||
case 0x42: LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32); break;
|
||||
default: TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename scalar_t>
|
||||
void fused_experts_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
scalar_t* __restrict__ ic2,
|
||||
uint8_t* __restrict__ A_tmp,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const float* __restrict__ topk_weights,
|
||||
const int32_t* __restrict__ sorted_ids,
|
||||
const int32_t* __restrict__ expert_ids,
|
||||
const int32_t* __restrict__ offsets,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K,
|
||||
int64_t E,
|
||||
int64_t topk,
|
||||
int64_t num_tokens_post_pad) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
// strides for w1: [E, 2N, K]
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
|
||||
const int64_t stride_e = 2 * N * packed_K;
|
||||
const int64_t stride_n = packed_K;
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
|
||||
|
||||
alignas(64) float As[BLOCK_M];
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
|
||||
|
||||
// 1.a load A
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m] / topk;
|
||||
copy_stub(A + m * K, Aq_tmp + index * K, K);
|
||||
As[m] = As_tmp[index];
|
||||
}
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
const int64_t offset = offsets[mb];
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + offset * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [E, K, N] as [E, OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_e2 = OC * packed_N;
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = offsets[mb + 1] - offsets[mb];
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A ptr from ic1 of [M * topk, N] in sorted order
|
||||
// so as to avoid copy A to tmp buffer again
|
||||
const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
|
||||
const float* __restrict__ As = As_tmp + offsets[mb];
|
||||
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
int32_t expert_id = expert_ids[mb];
|
||||
const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to ic2 in original order
|
||||
// and also mul topk_weights in float32
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
int32_t index = A_ids[m];
|
||||
float weight = topk_weights[index];
|
||||
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// stage 3: out = intermediate_cache2.sum(dim=1)
|
||||
// from [M, topk, K] to [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \
|
||||
template void fused_experts_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
TYPE* __restrict__ ic2, uint8_t* __restrict__ A_tmp, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const float* __restrict__ topk_weights, const int32_t* __restrict__ sorted_ids, \
|
||||
const int32_t* __restrict__ expert_ids, const int32_t* __restrict__ offsets, \
|
||||
int64_t M, int64_t N, int64_t K, int64_t E, int64_t topk, int64_t num_tokens_post_pad)
|
||||
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
|
||||
|
||||
template <typename scalar_t>
|
||||
void shared_expert_int8_kernel_impl(
|
||||
scalar_t* __restrict__ output,
|
||||
scalar_t* __restrict__ ic1,
|
||||
float* __restrict__ C_tmp,
|
||||
uint8_t* __restrict__ Aq_tmp,
|
||||
float* __restrict__ As_tmp,
|
||||
const scalar_t* __restrict__ input,
|
||||
const int8_t* __restrict__ packed_w1,
|
||||
const int8_t* __restrict__ packed_w2,
|
||||
const float* __restrict__ w1s,
|
||||
const float* __restrict__ w2s,
|
||||
const scalar_t* __restrict__ fused_experts_out,
|
||||
float routed_scaling_factor,
|
||||
int64_t M,
|
||||
int64_t N,
|
||||
int64_t K) {
|
||||
|
||||
// handle 2 tiles per block
|
||||
constexpr int64_t BLOCK_M = block_size_m();
|
||||
constexpr int64_t BLOCK_N = block_size_n();
|
||||
|
||||
// stage 0: quantize input to uint8, [M, K]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * K,
|
||||
As_tmp[m],
|
||||
input + m * K,
|
||||
K);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
|
||||
const int64_t MB = div_up(M, BLOCK_M);
|
||||
const int64_t NB = div_up(N, BLOCK_N);
|
||||
|
||||
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
|
||||
|
||||
// K and N are packed for int8
|
||||
const int64_t packed_K = get_row_size<int8_t>(K);
|
||||
const int64_t packed_N = get_row_size<int8_t>(N);
|
||||
const int64_t stride_n = packed_K;
|
||||
|
||||
// here we only parallel on half of 2N to fuse silu_and_mul with gemm
|
||||
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB;
|
||||
int64_t nb = i % NB;
|
||||
|
||||
// nb0 from top half and nb1 from bottom half
|
||||
int64_t nb0 = nb, nb1 = nb + NB;
|
||||
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
|
||||
// A shape [m_size, K]
|
||||
const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
|
||||
const float* As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [K, n_size] in vnni format
|
||||
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
|
||||
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
|
||||
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
|
||||
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
|
||||
|
||||
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
|
||||
tinygemm_kernel(
|
||||
/* A */ A,
|
||||
/* B0 */ B0,
|
||||
/* B1 */ B1,
|
||||
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
|
||||
/* As */ As,
|
||||
/* Bs0 */ Bs0,
|
||||
/* Bs1 */ Bs1,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ K,
|
||||
/* lda */ K,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
|
||||
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t m = begin; m < end; ++m) {
|
||||
quantize_row_int8<scalar_t>(
|
||||
Aq_tmp + m * N,
|
||||
As_tmp[m],
|
||||
ic1 + m * N,
|
||||
N);
|
||||
}
|
||||
});
|
||||
|
||||
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
|
||||
// w2 : [K, N] as [OC, IC]
|
||||
const int64_t OC = K; // rename K as OC
|
||||
const int64_t IC = N; // rename N as IC
|
||||
const int64_t MB2 = MB;
|
||||
const int64_t NB2 = div_up(OC, BLOCK_N);
|
||||
const int64_t stride_oc = packed_N;
|
||||
|
||||
// parallel on [MB2, NB2]
|
||||
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
|
||||
// get local pointers
|
||||
int tid = at::get_thread_num();
|
||||
// we won't be using C1 for gemm2
|
||||
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
|
||||
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
int64_t mb = i / NB2;
|
||||
int64_t nb = i % NB2;
|
||||
|
||||
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
|
||||
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
|
||||
|
||||
// A shape [m_size, IC]
|
||||
const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
|
||||
const float* __restrict__ As = As_tmp + mb * BLOCK_M;
|
||||
|
||||
// B shape [IC, n_size] in vnni format
|
||||
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
|
||||
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
|
||||
|
||||
// 2.a gemm: C = A @ B
|
||||
tinygemm_kernel<scalar_t>(
|
||||
/* A */ A,
|
||||
/* B */ B,
|
||||
/* C */ C,
|
||||
/* As */ As,
|
||||
/* Bs */ Bs,
|
||||
/* M */ m_size,
|
||||
/* N */ n_size,
|
||||
/* K */ IC,
|
||||
/* lda */ IC,
|
||||
/* ldb */ n_size,
|
||||
/* ldc */ BLOCK_N);
|
||||
|
||||
// 2.b copy from C to output and add fused_experts_out
|
||||
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
|
||||
for (int64_t m = 0; m < m_size; ++m) {
|
||||
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \
|
||||
template void shared_expert_int8_kernel_impl<TYPE> ( \
|
||||
TYPE* __restrict__ output, TYPE* __restrict__ ic1, \
|
||||
float* __restrict__ C_tmp, uint8_t* __restrict__ Aq_tmp, \
|
||||
float* __restrict__ As_tmp, const TYPE* __restrict__ input, \
|
||||
const int8_t* __restrict__ packed_w1, const int8_t* __restrict__ packed_w2, \
|
||||
const float* __restrict__ w1s, const float* __restrict__ w2s, \
|
||||
const TYPE* __restrict__ fused_experts_out, float routed_scaling_factor, \
|
||||
int64_t M, int64_t N, int64_t K)
|
||||
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
|
||||
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);
|
||||
308
csrc/cpu/sgl-kernels/vec.h
Normal file
308
csrc/cpu/sgl-kernels/vec.h
Normal file
@ -0,0 +1,308 @@
|
||||
// Adapted from
|
||||
// https://github.com/sgl-project/sglang/tree/main/sgl-kernel/csrc/cpu
|
||||
|
||||
#pragma once
|
||||
|
||||
// clang-format off
|
||||
|
||||
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
|
||||
#define CPU_CAPABILITY_AVX512
|
||||
#endif
|
||||
|
||||
#include <ATen/cpu/vec/functional.h>
|
||||
#include <ATen/cpu/vec/vec.h>
|
||||
|
||||
namespace {
|
||||
|
||||
using namespace at::vec;
|
||||
|
||||
template <typename scalar_t,
|
||||
typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
|
||||
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return at::vec::convert_from_float<scalar_t>(a, b);
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
|
||||
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics
|
||||
// use native instruction for bfloat16->float32 conversion
|
||||
template <>
|
||||
inline Vectorized<at::BFloat16> convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
|
||||
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
|
||||
}
|
||||
|
||||
#define CVT_BF16_TO_FP32(a) \
|
||||
_mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
|
||||
|
||||
#define CVT_FP16_TO_FP32(a) \
|
||||
_mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
|
||||
|
||||
// this doesn't hanel NaN.
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
|
||||
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
|
||||
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
|
||||
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
|
||||
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
|
||||
const __m512i nonsign = _mm512_or_si512(exp, mant);
|
||||
|
||||
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
|
||||
const __m512i combined = _mm512_or_si512(nonsign, sign);
|
||||
|
||||
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
|
||||
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
|
||||
}
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
|
||||
// The following conversion is without denorm behavior, that is to say,
|
||||
// Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6)
|
||||
// Min subnorm : S.0000.001 = 2**(−9)
|
||||
// 0.0019 ~ 0.0137 cannot be converted correctly.
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
auto mask = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_setzero_si512()); // mask = x & 0x7f
|
||||
auto mask_nan = _mm512_cmpneq_epi16_mask(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(127)),
|
||||
_mm512_set1_epi16(127)); // mask_nan = x & 0x7f
|
||||
auto mantissa = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4); // mantissa = (x & 7) << 4
|
||||
auto exponent = _mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3),
|
||||
_mm512_set1_epi16(120)); // exponent = (((x >> 3) & 15) + 120)
|
||||
auto nonsign = _mm512_maskz_mov_epi16(mask, _mm512_or_si512(mantissa, _mm512_slli_epi16(exponent, 7)));
|
||||
nonsign = _mm512_mask_mov_epi16(_mm512_set1_epi16(0x7fff), mask_nan, nonsign); // deal with Nan
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
nonsign,
|
||||
_mm512_slli_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(128)),
|
||||
8))); // add sign (x & 128) << 8
|
||||
}
|
||||
|
||||
inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
|
||||
__m512i x = _mm512_cvtepu8_epi16(fp8_vec);
|
||||
__m512i lg2mant = _mm512_mask_mov_epi16(
|
||||
_mm512_mask_mov_epi16(
|
||||
_mm512_setzero_si512(), _mm512_test_epi16_mask(x, _mm512_set1_epi16(2)), _mm512_set1_epi16(1)),
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(4)),
|
||||
_mm512_set1_epi16(2));
|
||||
return (__m512bh)(_mm512_or_si512(
|
||||
_mm512_maskz_mov_epi16(
|
||||
_mm512_cmpneq_epi16_mask(_mm512_and_si512(x, _mm512_set1_epi16(127)), _mm512_setzero_si512()),
|
||||
_mm512_mask_blend_epi16(
|
||||
_mm512_test_epi16_mask(x, _mm512_set1_epi16(120)),
|
||||
_mm512_or_si512(
|
||||
_mm512_and_si512(
|
||||
_mm512_sllv_epi16(
|
||||
_mm512_and_si512(x, _mm512_set1_epi16(3)), _mm512_sub_epi16(_mm512_set1_epi16(7), lg2mant)),
|
||||
_mm512_set1_epi16(0x007f)),
|
||||
_mm512_slli_epi16(_mm512_add_epi16(lg2mant, _mm512_set1_epi16(118)), 7)),
|
||||
_mm512_or_si512(
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(7)), 4),
|
||||
_mm512_slli_epi16(
|
||||
_mm512_add_epi16(
|
||||
_mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(120)), 3), _mm512_set1_epi16(120)),
|
||||
7)))),
|
||||
_mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(128)), 8)));
|
||||
}
|
||||
|
||||
inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
|
||||
#ifdef SGLANG_CPU_FP8_CVT_FTZ
|
||||
return cvt_e4m3_bf16_intrinsic_no_nan(a);
|
||||
#else
|
||||
return cvt_e4m3_bf16_intrinsic_with_denorm(a);
|
||||
#endif
|
||||
}
|
||||
|
||||
#endif
|
||||
|
||||
// vector to scalar reduction
|
||||
#if defined(CPU_CAPABILITY_AVX512) && 0
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_add_ps(__m512(a));
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return _mm512_reduce_max_ps(__m512(a));
|
||||
}
|
||||
#else
|
||||
inline float vec_reduce_sum(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
|
||||
}
|
||||
|
||||
inline float vec_reduce_max(const Vectorized<float>& a) {
|
||||
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
|
||||
}
|
||||
#endif
|
||||
|
||||
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
|
||||
template <typename scalar_t>
|
||||
inline void quantize_row_int8(uint8_t* __restrict__ Aq, float& As,
|
||||
const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
|
||||
|
||||
float amax = 0.f; // absolute max
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]);
|
||||
amax = std::max(amax, std::abs(val));
|
||||
}
|
||||
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
|
||||
for (int64_t k = 0; k < K; ++k) {
|
||||
const float val = static_cast<float>(A[k]) * inv_scale;
|
||||
Aq[k] = (uint8_t)(std::round(val)) + 128;
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
template <>
|
||||
inline void quantize_row_int8<at::BFloat16>(uint8_t* __restrict__ Aq, float& As,
|
||||
const at::BFloat16* __restrict__ A, int64_t K, float eps) {
|
||||
|
||||
const __m512 signBit = _mm512_set1_ps(-0.0f);
|
||||
const __m512i off = _mm512_set1_epi32(128);
|
||||
|
||||
// K is 32x, no remainder
|
||||
float amax = 0.f;
|
||||
__m512 vamax0 = _mm512_set1_ps(0.f);
|
||||
__m512 vamax1 = _mm512_set1_ps(0.f);
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
|
||||
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
|
||||
}
|
||||
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
|
||||
amax = std::max(amax, eps);
|
||||
const float scale = amax / 127;
|
||||
const float inv_scale = 127 / amax;
|
||||
const __m512 vd = _mm512_set1_ps(inv_scale);
|
||||
|
||||
for (int64_t k = 0; k < K; k += 32) {
|
||||
__m512i va = _mm512_loadu_si512((void*)(A + k));
|
||||
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
|
||||
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
|
||||
va0 = _mm512_mul_ps(va0, vd);
|
||||
va1 = _mm512_mul_ps(va1, vd);
|
||||
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
|
||||
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
|
||||
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
|
||||
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
|
||||
}
|
||||
As = scale;
|
||||
}
|
||||
#endif
|
||||
|
||||
// transpose utils
|
||||
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
|
||||
#if defined(CPU_CAPABILITY_AVX512)
|
||||
inline void transpose_16x16_32bit(__m512i * v) {
|
||||
__m512i v1[16];
|
||||
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
|
||||
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
|
||||
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
|
||||
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
|
||||
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
|
||||
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
|
||||
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
|
||||
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
|
||||
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
|
||||
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
|
||||
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
|
||||
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
|
||||
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
|
||||
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
|
||||
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
|
||||
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
|
||||
|
||||
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
|
||||
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
|
||||
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
|
||||
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
|
||||
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
|
||||
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
|
||||
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
|
||||
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
|
||||
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
|
||||
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
|
||||
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
|
||||
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
|
||||
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
|
||||
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
|
||||
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
|
||||
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
|
||||
|
||||
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
|
||||
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
|
||||
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
|
||||
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
|
||||
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
|
||||
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
|
||||
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
|
||||
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
|
||||
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
|
||||
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
|
||||
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
|
||||
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
|
||||
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
|
||||
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
|
||||
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
|
||||
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
|
||||
|
||||
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
|
||||
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
|
||||
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
|
||||
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
|
||||
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
|
||||
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
|
||||
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
|
||||
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
|
||||
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
|
||||
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
|
||||
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
|
||||
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
|
||||
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
|
||||
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
|
||||
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
|
||||
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
|
||||
}
|
||||
|
||||
// remove warning : ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes]
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wignored-attributes"
|
||||
|
||||
// transpose from [2, 32] to [32, 2]
|
||||
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
|
||||
// r0: {a0, a1, ..., a31}
|
||||
// r1: {b0, b1, ..., b31}
|
||||
//
|
||||
// d0: {a0, b0, ..., a15, b15}
|
||||
// d1: {a16, b16, ..., a31, b31}
|
||||
//
|
||||
__m512i d0 = _mm512_unpacklo_epi16(r0, r1);
|
||||
__m512i d1 = _mm512_unpackhi_epi16(r0, r1);
|
||||
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
|
||||
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
|
||||
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
|
||||
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
|
||||
return std::make_tuple(d0, d1);
|
||||
}
|
||||
#pragma GCC diagnostic pop
|
||||
|
||||
#endif
|
||||
|
||||
// TODO: debug print, remove me later
|
||||
template<typename scalar_t>
|
||||
void print_array(scalar_t* ptr, int size) {
|
||||
for (int d = 0; d < size; ++d) {
|
||||
if (d % 16 == 0) { std::cout << std::endl; }
|
||||
std::cout << ptr[d] << " ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
178
csrc/cpu/shm.cpp
178
csrc/cpu/shm.cpp
@ -7,9 +7,10 @@
|
||||
|
||||
namespace {
|
||||
#define MAX_SHM_RANK_NUM 8
|
||||
#define MAX_THREAD_NUM 12
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (4 * 1024 * 1024)
|
||||
#define MIN_THREAD_PROCESS_SIZE (8 * 1024)
|
||||
#define PER_THREAD_SHM_BUFFER_BYTES (2 * 1024 * 1024)
|
||||
static_assert(PER_THREAD_SHM_BUFFER_BYTES % 2 == 0);
|
||||
#define PER_THREAD_SHM_BUFFER_OFFSET (PER_THREAD_SHM_BUFFER_BYTES >> 1)
|
||||
#define MIN_THREAD_PROCESS_SIZE (256)
|
||||
#define MAX_P2P_SEND_TENSOR_NUM 8
|
||||
|
||||
template <typename scalar_t>
|
||||
@ -32,10 +33,10 @@ struct KernelVecType<c10::Half> {
|
||||
using scalar_vec_t = vec_op::FP16Vec16;
|
||||
};
|
||||
|
||||
enum class ThreadSHMStat : char { THREAD_READY = 0, SHM_DATA_READY, DONE };
|
||||
|
||||
struct ThreadSHMContext {
|
||||
volatile ThreadSHMStat thread_stats[MAX_SHM_RANK_NUM];
|
||||
volatile char _curr_thread_stamp;
|
||||
volatile char _ready_thread_stamp;
|
||||
char _padding1[6];
|
||||
int thread_id;
|
||||
int thread_num;
|
||||
int rank;
|
||||
@ -44,14 +45,19 @@ struct ThreadSHMContext {
|
||||
int swizzled_ranks[MAX_SHM_RANK_NUM];
|
||||
void* thread_shm_ptrs[MAX_SHM_RANK_NUM];
|
||||
ThreadSHMContext* shm_contexts[MAX_SHM_RANK_NUM];
|
||||
size_t _thread_buffer_mask;
|
||||
char _padding2[56];
|
||||
|
||||
ThreadSHMContext(const int thread_id, const int thread_num, const int rank,
|
||||
const int group_size, void* thread_shm_ptr)
|
||||
: thread_id(thread_id),
|
||||
: _curr_thread_stamp(1),
|
||||
_ready_thread_stamp(0),
|
||||
thread_id(thread_id),
|
||||
thread_num(thread_num),
|
||||
rank(rank),
|
||||
group_size(group_size),
|
||||
_spinning_count(0) {
|
||||
_spinning_count(0),
|
||||
_thread_buffer_mask(0) {
|
||||
static_assert(sizeof(ThreadSHMContext) % 64 == 0);
|
||||
TORCH_CHECK(group_size <= MAX_SHM_RANK_NUM);
|
||||
TORCH_CHECK((size_t)this % 64 == 0);
|
||||
@ -60,7 +66,6 @@ struct ThreadSHMContext {
|
||||
shm_contexts[i] = nullptr;
|
||||
thread_shm_ptrs[i] = nullptr;
|
||||
swizzled_ranks[i] = (i + rank) % group_size;
|
||||
thread_stats[i] = ThreadSHMStat::DONE;
|
||||
}
|
||||
set_context(rank, this, thread_shm_ptr);
|
||||
}
|
||||
@ -77,59 +82,66 @@ struct ThreadSHMContext {
|
||||
|
||||
template <typename T>
|
||||
T* get_thread_shm_ptr(int rank) {
|
||||
return reinterpret_cast<T*>(thread_shm_ptrs[rank]);
|
||||
return reinterpret_cast<T*>(
|
||||
reinterpret_cast<int8_t*>(thread_shm_ptrs[rank]) +
|
||||
(PER_THREAD_SHM_BUFFER_OFFSET & _thread_buffer_mask));
|
||||
}
|
||||
|
||||
void next_buffer() { _thread_buffer_mask ^= 0xFFFFFFFFFFFFFFFF; }
|
||||
|
||||
char get_curr_stamp() const { return _curr_thread_stamp; }
|
||||
|
||||
char get_ready_stamp() const { return _ready_thread_stamp; }
|
||||
|
||||
void next_stamp() {
|
||||
_mm_mfence();
|
||||
_curr_thread_stamp += 1;
|
||||
}
|
||||
|
||||
void commit_ready_stamp() {
|
||||
_mm_mfence();
|
||||
_ready_thread_stamp = _curr_thread_stamp;
|
||||
}
|
||||
|
||||
int get_swizzled_rank(int idx) { return swizzled_ranks[idx]; }
|
||||
|
||||
void wait_for_all(ThreadSHMStat prev_stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
template <typename Cond>
|
||||
void wait_for_all(Cond&& cond) {
|
||||
for (int idx = 1; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
wait_for_one(rank, std::forward<Cond>(cond));
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void wait_for_one(int rank, ThreadSHMStat prev_stat) {
|
||||
while (thread_stats[rank] == prev_stat) {
|
||||
template <typename Cond>
|
||||
void wait_for_one(int rank, Cond&& cond) {
|
||||
ThreadSHMContext* rank_ctx = shm_contexts[rank];
|
||||
for (;;) {
|
||||
char local_curr_stamp = get_curr_stamp();
|
||||
char local_ready_stamp = get_ready_stamp();
|
||||
char rank_curr_stamp = rank_ctx->get_curr_stamp();
|
||||
char rank_ready_stamp = rank_ctx->get_ready_stamp();
|
||||
if (cond(local_curr_stamp, local_ready_stamp, rank_curr_stamp,
|
||||
rank_ready_stamp)) {
|
||||
break;
|
||||
}
|
||||
++_spinning_count;
|
||||
_mm_pause();
|
||||
}
|
||||
vec_op::mem_barrier();
|
||||
}
|
||||
|
||||
void set_thread_stat(ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[this->rank] = stat;
|
||||
}
|
||||
static bool check_no_buffer_conflict(char local_curr_stamp,
|
||||
char local_ready_stamp,
|
||||
char rank_curr_stamp,
|
||||
char rank_ready_stamp) {
|
||||
char temp = rank_curr_stamp + 2;
|
||||
return local_curr_stamp != temp;
|
||||
}
|
||||
|
||||
void set_thread_stat(int target_rank, ThreadSHMStat stat) {
|
||||
for (int idx = 0; idx < group_size; ++idx) {
|
||||
int rank = get_swizzled_rank(idx);
|
||||
shm_contexts[rank]->thread_stats[target_rank] = stat;
|
||||
}
|
||||
}
|
||||
|
||||
// barrier for all ranks in the group, used for all2all ops
|
||||
// DONE -> THREAD_READY -> SHM_DATA_READY -> DONE -> ...
|
||||
void barrier(ThreadSHMStat next_stat) {
|
||||
if (next_stat == ThreadSHMStat::THREAD_READY) {
|
||||
set_thread_stat(ThreadSHMStat::THREAD_READY);
|
||||
wait_for_all(ThreadSHMStat::DONE);
|
||||
} else if (next_stat == ThreadSHMStat::SHM_DATA_READY) {
|
||||
set_thread_stat(ThreadSHMStat::SHM_DATA_READY);
|
||||
wait_for_all(ThreadSHMStat::THREAD_READY);
|
||||
} else if (next_stat == ThreadSHMStat::DONE) {
|
||||
set_thread_stat(ThreadSHMStat::DONE);
|
||||
wait_for_all(ThreadSHMStat::SHM_DATA_READY);
|
||||
} else {
|
||||
TORCH_CHECK(false, "Invalid next_stat to barrier.");
|
||||
}
|
||||
static bool check_stamp_ready(char local_curr_stamp, char local_ready_stamp,
|
||||
char rank_curr_stamp, char rank_ready_stamp) {
|
||||
char temp = local_curr_stamp + 1;
|
||||
return (local_curr_stamp == rank_ready_stamp) || (temp == rank_ready_stamp);
|
||||
}
|
||||
|
||||
std::string to_string() const {
|
||||
@ -164,7 +176,7 @@ class SHMManager {
|
||||
const int group_size)
|
||||
: _rank(rank),
|
||||
_group_size(group_size),
|
||||
_thread_num(std::min(torch::get_num_threads(), MAX_THREAD_NUM)),
|
||||
_thread_num(torch::get_num_threads()),
|
||||
_shm_names({""}),
|
||||
_shared_mem_ptrs({nullptr}),
|
||||
_shm_ctx(nullptr) {
|
||||
@ -326,7 +338,8 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
(total_units_num + thread_num - 1) / thread_num;
|
||||
int64_t per_unit_elem_num = MIN_THREAD_PROCESS_SIZE / sizeof(scalar_t);
|
||||
int64_t max_per_thread_iteration_elem_num =
|
||||
PER_THREAD_SHM_BUFFER_BYTES / sizeof(scalar_t);
|
||||
(PER_THREAD_SHM_BUFFER_BYTES >> 1) /
|
||||
sizeof(scalar_t); // Note: double buffer
|
||||
int64_t per_thread_elem_num = per_unit_elem_num * per_thread_units_num;
|
||||
|
||||
#pragma omp parallel for schedule(static, 1)
|
||||
@ -336,10 +349,13 @@ void shm_cc_loop(ThreadSHMContext* ctx, int64_t elem_num, F&& inner_func) {
|
||||
int64_t curr_elem_num =
|
||||
std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
ThreadSHMContext* thread_ctx = ctx + i;
|
||||
bool fast_mode = ((end - offset) <= max_per_thread_iteration_elem_num);
|
||||
|
||||
while (curr_elem_num > 0) {
|
||||
inner_func(thread_ctx, offset, curr_elem_num);
|
||||
inner_func(thread_ctx, offset, curr_elem_num, fast_mode);
|
||||
|
||||
thread_ctx->next_stamp();
|
||||
thread_ctx->next_buffer();
|
||||
offset += max_per_thread_iteration_elem_num;
|
||||
curr_elem_num = std::min(max_per_thread_iteration_elem_num, end - offset);
|
||||
}
|
||||
@ -397,7 +413,7 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
@ -410,16 +426,17 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
thread_ctx->get_swizzled_rank(idx + 1));
|
||||
});
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, thread_data_ptr,
|
||||
thread_data_elem_num);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
thread_ctx->commit_ready_stamp();
|
||||
int64_t aligned_data_elem_num =
|
||||
(data_elem_num / vec_elem_num) * vec_elem_num;
|
||||
int64_t i = 0;
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_stamp_ready);
|
||||
#pragma GCC unroll 4
|
||||
for (; i < aligned_data_elem_num; i += vec_elem_num) {
|
||||
vec_t local_data(thread_data_ptr + i); // load from cache
|
||||
@ -447,8 +464,6 @@ void all_reduce_sum_impl(ThreadSHMContext* ctx, scalar_t* data,
|
||||
reduced_data.save(thread_data_ptr + i,
|
||||
data_elem_num - aligned_data_elem_num);
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -488,18 +503,18 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
shm_cc_ops::shm_cc_loop<scalar_t>(
|
||||
ctx, elem_num,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
scalar_t* thread_shm_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(rank);
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::THREAD_READY);
|
||||
|
||||
shm_cc_ops::memcpy_to_shm(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::SHM_DATA_READY);
|
||||
if (!fast_mode) {
|
||||
thread_ctx->wait_for_all(ThreadSHMContext::check_no_buffer_conflict);
|
||||
}
|
||||
|
||||
shm_cc_ops::memcpy(thread_shm_ptr, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->commit_ready_stamp();
|
||||
if (rank == dst) {
|
||||
shm_cc_ops::memcpy(outputs[rank] + data_offset, data + data_offset,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
@ -508,12 +523,12 @@ void shm_gather_impl(ThreadSHMContext* ctx, scalar_t* data, size_t elem_num,
|
||||
scalar_t* src_ptr =
|
||||
thread_ctx->get_thread_shm_ptr<scalar_t>(src_rank); // shm
|
||||
scalar_t* dst_ptr = outputs[src_rank] + data_offset;
|
||||
shm_cc_ops::memcpy_from_shm(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
thread_ctx->wait_for_one(src_rank,
|
||||
ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(dst_ptr, src_ptr,
|
||||
data_elem_num * sizeof(scalar_t));
|
||||
}
|
||||
}
|
||||
|
||||
thread_ctx->barrier(ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
return;
|
||||
@ -599,7 +614,7 @@ struct TensorListMeta {
|
||||
int8_t _padding[40];
|
||||
};
|
||||
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
void shm_send_tensor_list_impl(ThreadSHMContext* ctx, int64_t dst,
|
||||
const std::vector<torch::Tensor>& tensor_list) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list_impl)
|
||||
std::vector<torch::Tensor> tensor_list_with_metadata;
|
||||
@ -620,12 +635,11 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata->total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
int rank = thread_ctx->rank;
|
||||
// Wait until the receiver set the stat to DONE
|
||||
thread_ctx->wait_for_one(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
|
||||
int64_t curr_shm_offset = 0;
|
||||
thread_ctx->wait_for_one(dst,
|
||||
ThreadSHMContext::check_no_buffer_conflict);
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata->get_data(data_offset + curr_shm_offset);
|
||||
frag.size = std::min(frag.size, data_elem_num - curr_shm_offset);
|
||||
@ -634,8 +648,7 @@ void shm_send_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.ptr, frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(rank, ThreadSHMStat::SHM_DATA_READY);
|
||||
thread_ctx->commit_ready_stamp();
|
||||
});
|
||||
}
|
||||
|
||||
@ -646,8 +659,7 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
torch::Tensor metadata_tensor =
|
||||
torch::empty({sizeof(TensorListMeta)}, options);
|
||||
|
||||
// Wait until the sender set the stat of the thread 0 to SHM_DATA_READY
|
||||
ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
shm_cc_ops::memcpy(metadata_tensor.data_ptr(),
|
||||
ctx->get_thread_shm_ptr<void>(src),
|
||||
sizeof(TensorListMeta));
|
||||
@ -664,9 +676,8 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
shm_cc_ops::shm_cc_loop<int8_t>(
|
||||
ctx, metadata.total_bytes,
|
||||
[&](ThreadSHMContext* thread_ctx, int64_t data_offset,
|
||||
int64_t data_elem_num) {
|
||||
// Wait until the sender set the stat to SHM_DATA_READY
|
||||
thread_ctx->wait_for_one(src, ThreadSHMStat::DONE);
|
||||
int64_t data_elem_num, bool fast_mode) {
|
||||
ctx->wait_for_one(src, ThreadSHMContext::check_stamp_ready);
|
||||
int64_t curr_shm_offset = 0;
|
||||
while (curr_shm_offset < data_elem_num) {
|
||||
MemPiece frag = metadata.get_data(data_offset + curr_shm_offset);
|
||||
@ -677,8 +688,6 @@ std::vector<torch::Tensor> shm_recv_tensor_list_impl(ThreadSHMContext* ctx,
|
||||
frag.size);
|
||||
curr_shm_offset += frag.size;
|
||||
}
|
||||
|
||||
thread_ctx->set_thread_stat(src, ThreadSHMStat::DONE);
|
||||
});
|
||||
|
||||
std::vector<torch::Tensor> tensor_list;
|
||||
@ -756,7 +765,8 @@ void shm_send_tensor_list(int64_t handle,
|
||||
int64_t dst) {
|
||||
CPU_KERNEL_GUARD_IN(shm_send_tensor_list)
|
||||
shm_send_tensor_list_impl(
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), tensor_list);
|
||||
SHMManager::get_singleton_instance(handle)->get_shm_ctx(), dst,
|
||||
tensor_list);
|
||||
CPU_KERNEL_GUARD_OUT(shm_send_tensor_list)
|
||||
}
|
||||
|
||||
@ -778,4 +788,4 @@ std::string join_shm_manager(int64_t handle, const std::string& name) {
|
||||
TORCH_CHECK(shm_manager);
|
||||
shm_manager->join(name);
|
||||
return shm_manager->get_shm_ctx()->to_string();
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,6 +50,27 @@ void shm_send_tensor_list(int64_t handle,
|
||||
|
||||
std::vector<torch::Tensor> shm_recv_tensor_list(int64_t handle, int64_t src);
|
||||
|
||||
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
bool is_vnni);
|
||||
|
||||
at::Tensor convert_weight_packed(at::Tensor& weight);
|
||||
|
||||
at::Tensor fused_experts_cpu(
|
||||
at::Tensor& hidden_states, at::Tensor& w1, at::Tensor& w2,
|
||||
at::Tensor& topk_weights, at::Tensor& topk_ids, bool inplace,
|
||||
bool use_int8_w8a8, bool use_fp8_w8a16,
|
||||
const std::optional<at::Tensor>& w1_scale,
|
||||
const std::optional<at::Tensor>& w2_scale,
|
||||
const std::optional<std::vector<int64_t>> block_size,
|
||||
const std::optional<at::Tensor>& a1_scale,
|
||||
const std::optional<at::Tensor>& a2_scale, bool is_vnni);
|
||||
|
||||
at::Tensor int8_scaled_mm_with_quant(at::Tensor& mat1, at::Tensor& mat2,
|
||||
at::Tensor& scales2,
|
||||
const std::optional<at::Tensor>& bias,
|
||||
at::ScalarType out_dtype, bool is_vnni);
|
||||
|
||||
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
// vLLM custom ops
|
||||
|
||||
@ -131,16 +152,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
|
||||
// Quantization
|
||||
#ifdef __AVX512F__
|
||||
at::Tag stride_tag = at::Tag::needs_fixed_stride_order;
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
ops.def(
|
||||
"static_scaled_int8_quant(Tensor! out, Tensor input, Tensor scale,"
|
||||
"Tensor? azp) -> ()");
|
||||
"Tensor? azp) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("static_scaled_int8_quant", torch::kCPU, &static_scaled_int8_quant);
|
||||
|
||||
// Compute int8 quantized tensor and scaling factor
|
||||
ops.def(
|
||||
"dynamic_scaled_int8_quant(Tensor! out, Tensor input, Tensor! scale, "
|
||||
"Tensor!? azp) -> ()");
|
||||
"Tensor!? azp) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("dynamic_scaled_int8_quant", torch::kCPU,
|
||||
&dynamic_scaled_int8_quant);
|
||||
// W8A8 GEMM, supporting symmetric per-tensor or per-row/column
|
||||
@ -148,7 +172,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def(
|
||||
"cutlass_scaled_mm(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor? bias) -> ()");
|
||||
" Tensor b_scales, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm", torch::kCPU, &int8_scaled_mm);
|
||||
// w8a8 GEMM, supporting asymmetric per-tensor or per-row/column
|
||||
// quantization.
|
||||
@ -156,7 +181,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
"cutlass_scaled_mm_azp(Tensor! out, Tensor a,"
|
||||
" Tensor b, Tensor a_scales,"
|
||||
" Tensor b_scales, Tensor azp_adj,"
|
||||
" Tensor? azp, Tensor? bias) -> ()");
|
||||
" Tensor? azp, Tensor? bias) -> ()",
|
||||
{stride_tag});
|
||||
ops.impl("cutlass_scaled_mm_azp", torch::kCPU, &int8_scaled_mm_azp);
|
||||
#elif defined(__powerpc64__)
|
||||
// Compute int8 quantized tensor for given scaling factor.
|
||||
@ -209,6 +235,28 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
||||
ops.def("shm_recv_tensor_list(int handle, int src) -> Tensor[](a)",
|
||||
&shm_recv_tensor_list);
|
||||
#endif
|
||||
|
||||
// sgl-kernels
|
||||
#if defined(__AVX512BF16__) && defined(__AVX512F__) && defined(__AVX512VNNI__)
|
||||
ops.def(
|
||||
"weight_packed_linear(Tensor(a0!) mat1, Tensor(a1!) mat2, Tensor(a2!)? "
|
||||
"bias, bool is_vnni) -> Tensor");
|
||||
ops.impl("weight_packed_linear", torch::kCPU, &weight_packed_linear);
|
||||
ops.def("convert_weight_packed(Tensor! weight) -> Tensor");
|
||||
ops.impl("convert_weight_packed", torch::kCPU, &convert_weight_packed);
|
||||
ops.def(
|
||||
"fused_experts_cpu(Tensor! hidden_states, Tensor w1, Tensor w2, Tensor "
|
||||
"topk_weights, Tensor topk_ids, bool inplace, bool use_int8_w8a8, bool "
|
||||
"use_fp8_w8a16, Tensor? w1_scale, Tensor? w2_scale, SymInt[]? "
|
||||
"block_size, Tensor? a1_scale, Tensor? a2_scale, bool is_vnni) -> "
|
||||
"Tensor");
|
||||
ops.impl("fused_experts_cpu", torch::kCPU, &fused_experts_cpu);
|
||||
ops.def(
|
||||
"int8_scaled_mm_with_quant(Tensor mat1, Tensor mat2, Tensor scales2, "
|
||||
"Tensor? bias, ScalarType out_dtype, bool is_vnni) -> Tensor");
|
||||
ops.impl("int8_scaled_mm_with_quant", torch::kCPU,
|
||||
&int8_scaled_mm_with_quant);
|
||||
#endif
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
|
||||
|
||||
114
csrc/custom_quickreduce.cu
Normal file
114
csrc/custom_quickreduce.cu
Normal file
@ -0,0 +1,114 @@
|
||||
#include <ATen/cuda/Exceptions.h>
|
||||
#include <c10/cuda/CUDAGuard.h>
|
||||
#include <c10/cuda/CUDAStream.h>
|
||||
#include <torch/all.h>
|
||||
|
||||
#ifdef USE_ROCM
|
||||
|
||||
#include "quickreduce/quick_reduce.h"
|
||||
|
||||
quickreduce::fptr_t init_custom_qr(int64_t rank, int64_t world_size,
|
||||
std::optional<int64_t> qr_max_size) {
|
||||
if (world_size > 8)
|
||||
throw std::invalid_argument("world size > 8 is not supported");
|
||||
if (world_size == 6)
|
||||
throw std::invalid_argument("world size == 6 is not supported");
|
||||
if (world_size % 2 != 0)
|
||||
throw std::invalid_argument("Odd num gpus is not supported for now");
|
||||
if (rank < 0 || rank >= world_size)
|
||||
throw std::invalid_argument("invalid rank passed in");
|
||||
quickreduce::DeviceComms* fptr = new quickreduce::DeviceComms();
|
||||
fptr->init(world_size, rank, qr_max_size);
|
||||
return (quickreduce::fptr_t)fptr;
|
||||
}
|
||||
|
||||
void qr_destroy(quickreduce::fptr_t _fa) {
|
||||
if (_fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
fa->destroy();
|
||||
delete fa;
|
||||
}
|
||||
}
|
||||
|
||||
torch::Tensor qr_get_handle(quickreduce::fptr_t _fa) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
hipIpcMemHandle_t handle = fa->get_handle();
|
||||
auto options =
|
||||
torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU);
|
||||
auto data_handle =
|
||||
torch::empty({static_cast<int64_t>(sizeof(hipIpcMemHandle_t))}, options);
|
||||
std::memcpy(data_handle.data_ptr(), &handle, sizeof(hipIpcMemHandle_t));
|
||||
return data_handle;
|
||||
}
|
||||
|
||||
void qr_open_handles(quickreduce::fptr_t _fa,
|
||||
const std::vector<torch::Tensor>& handles) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
std::vector<hipIpcMemHandle_t> ipc_handles;
|
||||
ipc_handles.reserve(handles.size());
|
||||
for (auto& handle : handles) {
|
||||
// Ensure the tensor is on the same device as the current device.
|
||||
hipIpcMemHandle_t ipc_handle;
|
||||
std::memcpy(&ipc_handle, handle.data_ptr(), sizeof(hipIpcMemHandle_t));
|
||||
ipc_handles.push_back(ipc_handle);
|
||||
}
|
||||
fa->open_ipc_handles(ipc_handles);
|
||||
}
|
||||
|
||||
void qr_all_reduce(quickreduce::fptr_t _fa, torch::Tensor& inp,
|
||||
torch::Tensor& out, int64_t quant_level, bool cast_bf2half) {
|
||||
auto fa = reinterpret_cast<quickreduce::DeviceComms*>(_fa);
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(inp));
|
||||
auto stream = at::cuda::getCurrentHIPStreamMasqueradingAsCUDA();
|
||||
|
||||
TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type());
|
||||
TORCH_CHECK_EQ(inp.numel(), out.numel());
|
||||
TORCH_CHECK_LE(out.numel(), fa->kMaxProblemSize);
|
||||
if (out.scalar_type() == at::ScalarType::Half) {
|
||||
fa->allreduce<half, false>(reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(), quant_level, stream);
|
||||
} else if (out.scalar_type() == at::ScalarType::BFloat16) {
|
||||
if (cast_bf2half) {
|
||||
fa->allreduce<half, true>(reinterpret_cast<half*>(inp.data_ptr()),
|
||||
reinterpret_cast<half*>(out.data_ptr()),
|
||||
out.numel(), quant_level, stream);
|
||||
} else {
|
||||
fa->allreduce<quickreduce::nv_bfloat16, false>(
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(inp.data_ptr()),
|
||||
reinterpret_cast<quickreduce::nv_bfloat16*>(out.data_ptr()),
|
||||
out.numel(), quant_level, stream);
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error(
|
||||
"quick allreduce only supports float16 and bfloat16");
|
||||
}
|
||||
}
|
||||
|
||||
int64_t qr_max_size() {
|
||||
// The default is 2GB (2,147,483,648 bytes)
|
||||
return static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
}
|
||||
|
||||
#define INSTANTIATE_FOR_WORLDSIZE(T, Codec, cast_bf2half) \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 2>, \
|
||||
cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 4>, \
|
||||
cast_bf2half>; \
|
||||
template struct quickreduce::AllReduceTwoshot<T, Codec<T, 8>, cast_bf2half>;
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecFP, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ4, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ6, true)
|
||||
INSTANTIATE_FOR_WORLDSIZE(quickreduce::nv_bfloat16, quickreduce::CodecQ8, true)
|
||||
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecFP, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ4, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ6, false)
|
||||
INSTANTIATE_FOR_WORLDSIZE(half, quickreduce::CodecQ8, false)
|
||||
|
||||
#endif // USE_ROCM
|
||||
@ -185,9 +185,7 @@ void causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight,
|
||||
params.conv_states_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
||||
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
@ -278,9 +276,7 @@ void causal_conv1d_update(const at::Tensor &x,
|
||||
params.conv_state_indices_ptr = nullptr;
|
||||
}
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)x.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
||||
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
||||
|
||||
@ -647,9 +647,7 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta,
|
||||
);
|
||||
|
||||
|
||||
// Otherwise the kernel will be launched from cuda:0 device
|
||||
// Cast to char to avoid compiler warning about narrowing
|
||||
at::cuda::CUDAGuard device_guard{(char)u.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(u));
|
||||
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
||||
DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] {
|
||||
selective_scan_fwd_cuda<input_t, weight_t>(params, stream);
|
||||
|
||||
@ -1255,8 +1255,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
||||
@ -239,7 +239,7 @@ void moe_sum(torch::Tensor& input, // [num_tokens, topk, hidden_size]
|
||||
torch::Tensor& output) // [num_tokens, hidden_size]
|
||||
{
|
||||
const int hidden_size = input.size(-1);
|
||||
const int num_tokens = output.numel() / hidden_size;
|
||||
const auto num_tokens = output.numel() / hidden_size;
|
||||
const int topk = input.size(1);
|
||||
|
||||
dim3 grid(num_tokens);
|
||||
|
||||
@ -492,7 +492,7 @@ void topk_softmax(
|
||||
torch::Tensor& gating_output) // [num_tokens, num_experts]
|
||||
{
|
||||
const int num_experts = gating_output.size(-1);
|
||||
const int num_tokens = gating_output.numel() / num_experts;
|
||||
const auto num_tokens = gating_output.numel() / num_experts;
|
||||
const int topk = topk_weights.size(-1);
|
||||
|
||||
const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0);
|
||||
|
||||
11
csrc/ops.h
11
csrc/ops.h
@ -360,3 +360,14 @@ std::tuple<int64_t, torch::Tensor> allocate_shared_buffer_and_handle(
|
||||
int64_t size);
|
||||
int64_t open_mem_handle(torch::Tensor& mem_handle);
|
||||
void free_shared_buffer(int64_t buffer);
|
||||
|
||||
#ifdef USE_ROCM
|
||||
fptr_t init_custom_qr(int64_t rank, int64_t world_size,
|
||||
std::optional<int64_t> qr_max_size = std::nullopt);
|
||||
void qr_destroy(fptr_t _fa);
|
||||
torch::Tensor qr_get_handle(fptr_t _fa);
|
||||
void qr_open_handles(fptr_t _fa, const std::vector<torch::Tensor>& handles);
|
||||
void qr_all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out,
|
||||
int64_t quant_level, bool cast_bf2half = false);
|
||||
int64_t qr_max_size();
|
||||
#endif
|
||||
@ -144,4 +144,65 @@ struct cutlass_3x_gemm_sm100 {
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
template <typename ElementAB_, typename ElementD_,
|
||||
template <typename, typename, typename> typename Epilogue_,
|
||||
typename TileShape, typename ClusterShape, typename KernelSchedule,
|
||||
typename EpilogueSchedule>
|
||||
struct cutlass_3x_gemm_sm120 {
|
||||
using ElementAB = ElementAB_;
|
||||
using LayoutA = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentA =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using LayoutB = cutlass::layout::ColumnMajor;
|
||||
static constexpr int AlignmentB =
|
||||
128 / cutlass::sizeof_bits<ElementAB>::value;
|
||||
|
||||
using ElementC = void;
|
||||
using LayoutC = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentC =
|
||||
128 / cutlass::sizeof_bits<ElementD_>::value;
|
||||
|
||||
using ElementD = ElementD_;
|
||||
using LayoutD = cutlass::layout::RowMajor;
|
||||
static constexpr int AlignmentD = AlignmentC;
|
||||
|
||||
using ElementAcc =
|
||||
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
|
||||
float>::type;
|
||||
using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;
|
||||
|
||||
// MMA type
|
||||
using ElementAccumulator = float;
|
||||
|
||||
// Epilogue types
|
||||
using ElementBias = cutlass::half_t;
|
||||
using ElementCompute = float;
|
||||
using ElementAux = ElementD;
|
||||
using LayoutAux = LayoutD;
|
||||
using ElementAmax = float;
|
||||
|
||||
using EVTCompute = typename Epilogue::EVTCompute;
|
||||
|
||||
using CollectiveEpilogue =
|
||||
typename cutlass::epilogue::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
|
||||
ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto,
|
||||
ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
|
||||
ElementD, LayoutD, AlignmentD, EpilogueSchedule,
|
||||
EVTCompute>::CollectiveOp;
|
||||
|
||||
using CollectiveMainloop =
|
||||
typename cutlass::gemm::collective::CollectiveBuilder<
|
||||
cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
|
||||
LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
|
||||
ElementAccumulator, TileShape, ClusterShape,
|
||||
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
|
||||
sizeof(typename CollectiveEpilogue::SharedStorage))>,
|
||||
KernelSchedule>::CollectiveOp;
|
||||
|
||||
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
|
||||
Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
|
||||
};
|
||||
|
||||
} // namespace vllm
|
||||
|
||||
@ -36,6 +36,12 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
|
||||
void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
|
||||
@ -29,26 +29,12 @@ struct sm100_fp8_config_default {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M256 {
|
||||
// M in (128, 256]
|
||||
// M in (64, 256]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_2, _2, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M128 {
|
||||
// M in (64, 128]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _256>;
|
||||
using ClusterShape = Shape<_2, _4, _1>;
|
||||
using ClusterShape = Shape<_2, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
@ -57,12 +43,26 @@ struct sm100_fp8_config_M128 {
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M64 {
|
||||
// M in [1, 64]
|
||||
// M in (16, 64]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_64, _64, _256>;
|
||||
using ClusterShape = Shape<_1, _8, _1>;
|
||||
using TileShape = Shape<_64, _64, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm100_fp8_config_M16 {
|
||||
// M in [1, 16]
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_64, _64, _128>;
|
||||
using ClusterShape = Shape<_1, _4, _1>;
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm100<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
@ -82,27 +82,27 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm100_fp8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM16 =
|
||||
typename sm100_fp8_config_M16<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM64 =
|
||||
typename sm100_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM128 =
|
||||
typename sm100_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
using Cutlass3xGemmM256 =
|
||||
typename sm100_fp8_config_M256<InType, OutType, Epilogue>::Cutlass3xGemm;
|
||||
|
||||
uint32_t const m = a.size(0);
|
||||
uint32_t const mp2 =
|
||||
std::max(static_cast<uint32_t>(64), next_pow_2(m)); // next power of 2
|
||||
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
|
||||
|
||||
if (mp2 <= 64) {
|
||||
// m in [1, 64]
|
||||
if (mp2 <= 16) {
|
||||
// m in [1, 16]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM16>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 64) {
|
||||
// m in (16, 64]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM64>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 128) {
|
||||
// m in (64, 128]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM128>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else if (mp2 <= 256) {
|
||||
// m in (128, 256]
|
||||
// m in (64, 256]
|
||||
return cutlass_gemm_caller<Cutlass3xGemmM256>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
} else {
|
||||
|
||||
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
Normal file
24
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu
Normal file
@ -0,0 +1,24 @@
|
||||
#include "scaled_mm_kernels.hpp"
|
||||
#include "scaled_mm_sm120_fp8_dispatch.cuh"
|
||||
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
|
||||
|
||||
namespace vllm {
|
||||
|
||||
void cutlass_scaled_mm_sm120_fp8(torch::Tensor& out, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
if (bias) {
|
||||
TORCH_CHECK(bias->dtype() == out.dtype(),
|
||||
"currently bias dtype must match output dtype ", out.dtype());
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogueBias>(
|
||||
out, a, b, a_scales, b_scales, *bias);
|
||||
} else {
|
||||
return cutlass_scaled_mm_sm120_fp8_epilogue<c3x::ScaledEpilogue>(
|
||||
out, a, b, a_scales, b_scales);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
@ -0,0 +1,67 @@
|
||||
#pragma once
|
||||
|
||||
#include "scaled_mm.cuh"
|
||||
#include "cutlass_gemm_caller.cuh"
|
||||
|
||||
/**
|
||||
* This file defines Gemm kernel configurations for SM120 (fp8) based on the
|
||||
* Gemm shape.
|
||||
*/
|
||||
|
||||
namespace vllm {
|
||||
|
||||
using c3x::cutlass_gemm_caller;
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue>
|
||||
struct sm120_fp8_config_default {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
|
||||
using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
|
||||
using TileShape = Shape<_128, _128, _128>;
|
||||
using ClusterShape = Shape<_1, _1, _1>; // Only work with Shape<_1, _1, _1>
|
||||
using Cutlass3xGemm =
|
||||
cutlass_3x_gemm_sm120<InType, OutType, Epilogue, TileShape, ClusterShape,
|
||||
KernelSchedule, EpilogueSchedule>;
|
||||
};
|
||||
|
||||
template <typename InType, typename OutType,
|
||||
template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... args) {
|
||||
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
using Cutlass3xGemmDefault =
|
||||
typename sm120_fp8_config_default<InType, OutType,
|
||||
Epilogue>::Cutlass3xGemm;
|
||||
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
|
||||
out, a, b, std::forward<EpilogueArgs>(args)...);
|
||||
}
|
||||
|
||||
template <template <typename, typename, typename> typename Epilogue,
|
||||
typename... EpilogueArgs>
|
||||
void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
|
||||
torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
EpilogueArgs&&... epilogue_args) {
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
|
||||
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
|
||||
|
||||
if (out.dtype() == torch::kBFloat16) {
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::bfloat16_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
} else {
|
||||
TORCH_CHECK(out.dtype() == torch::kFloat16);
|
||||
return cutlass_gemm_sm120_fp8_dispatch<cutlass::float_e4m3_t,
|
||||
cutlass::half_t, Epilogue>(
|
||||
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace vllm
|
||||
34
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
Normal file
34
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
Normal file
@ -0,0 +1,34 @@
|
||||
#include <cudaTypedefs.h>
|
||||
#include "c3x/scaled_mm_kernels.hpp"
|
||||
|
||||
#include "cuda_utils.h"
|
||||
|
||||
/*
|
||||
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
|
||||
NVIDIA GPUs with sm120 (Blackwell Geforce).
|
||||
*/
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias) {
|
||||
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
|
||||
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
|
||||
|
||||
int M = a.size(0), N = b.size(1), K = a.size(1);
|
||||
TORCH_CHECK(
|
||||
(a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
|
||||
(b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
|
||||
"Currently, block scaled fp8 gemm is not implemented for Blackwell");
|
||||
|
||||
// Standard per-tensor/per-token/per-channel scaling
|
||||
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
|
||||
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
|
||||
"Currently, only fp8 gemm is implemented for Blackwell");
|
||||
vllm::cutlass_scaled_mm_sm120_fp8(c, a, b, a_scales, b_scales, bias);
|
||||
}
|
||||
|
||||
#endif
|
||||
@ -41,6 +41,14 @@ void cutlass_moe_mm_sm90(
|
||||
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
torch::Tensor const& a_scales,
|
||||
torch::Tensor const& b_scales,
|
||||
std::optional<torch::Tensor> const& bias);
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
|
||||
torch::Tensor const& b,
|
||||
@ -168,8 +176,15 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
|
||||
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
|
||||
int32_t version_num = get_sm_version_num();
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
|
||||
if (version_num >= 120) {
|
||||
cutlass_scaled_mm_sm120(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM100
|
||||
if (version_num >= 100) {
|
||||
if (version_num >= 100 && version_num < 120) {
|
||||
cutlass_scaled_mm_sm100(c, a, b, a_scales, b_scales, bias);
|
||||
return;
|
||||
}
|
||||
@ -241,7 +256,7 @@ void get_cutlass_moe_mm_data(
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_SCALED_MM_SM100 && ENABLE_SCALED_MM_SM90)
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
get_cutlass_moe_mm_data_caller(topk_ids, expert_offsets, problem_sizes1,
|
||||
problem_sizes2, input_permutation,
|
||||
output_permutation, num_experts, n, k,
|
||||
@ -252,7 +267,7 @@ void get_cutlass_moe_mm_data(
|
||||
false,
|
||||
"No compiled get_cutlass_moe_mm_data: no cutlass_scaled_mm kernel for "
|
||||
"CUDA device capability: ",
|
||||
version_num, ". Required capability: 90");
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
@ -265,7 +280,8 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
// This function currently gets compiled only if we have a valid cutlass moe
|
||||
// mm to run it for.
|
||||
int32_t version_num = get_sm_version_num();
|
||||
#if defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90
|
||||
#if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \
|
||||
(defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100)
|
||||
get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1,
|
||||
problem_sizes2, expert_num_tokens,
|
||||
num_local_experts, padded_m, n, k);
|
||||
@ -275,7 +291,7 @@ void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets,
|
||||
false,
|
||||
"No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel "
|
||||
"for CUDA device capability: ",
|
||||
version_num, ". Required capability: 90");
|
||||
version_num, ". Required capability: 90 or 100");
|
||||
}
|
||||
|
||||
void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a,
|
||||
|
||||
@ -335,8 +335,10 @@ void run_fp4_blockwise_scaled_group_mm(
|
||||
TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM");
|
||||
}
|
||||
|
||||
#if defined ENABLE_NVFP4 && ENABLE_NVFP4
|
||||
constexpr auto FLOAT4_E2M1X2 = at::ScalarType::Byte;
|
||||
constexpr auto SF_DTYPE = at::ScalarType::Float8_e4m3fn;
|
||||
#endif
|
||||
|
||||
#define CHECK_TYPE(x, st, m) \
|
||||
TORCH_CHECK(x.scalar_type() == st, ": Inconsistency of Tensor type:", m)
|
||||
|
||||
@ -561,7 +561,7 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
TORCH_CHECK(output_scale.size(1) * 4 == padded_k);
|
||||
|
||||
auto in_dtype = input.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
const cudaStream_t stream =
|
||||
at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
if (in_dtype == at::ScalarType::Half) {
|
||||
@ -579,4 +579,4 @@ void scaled_fp4_experts_quant_sm100a(
|
||||
} else {
|
||||
TORCH_CHECK(false, "Expected input data type to be half or bfloat16");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -347,7 +347,7 @@ void scaled_fp4_quant_sm100a(torch::Tensor const& output,
|
||||
auto input_sf_ptr = static_cast<float const*>(input_sf.data_ptr());
|
||||
auto sf_out = static_cast<int32_t*>(output_sf.data_ptr());
|
||||
auto output_ptr = static_cast<int64_t*>(output.data_ptr());
|
||||
at::cuda::CUDAGuard device_guard{(char)input.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
|
||||
auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
|
||||
|
||||
// We don't support e8m0 scales at this moment.
|
||||
|
||||
@ -267,7 +267,7 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
|
||||
B_sf.sizes()[1], ")");
|
||||
|
||||
auto out_dtype = D.dtype();
|
||||
at::cuda::CUDAGuard device_guard{(char)A.get_device()};
|
||||
const at::cuda::OptionalCUDAGuard device_guard(device_of(A));
|
||||
const cudaStream_t stream = at::cuda::getCurrentCUDAStream(A.get_device());
|
||||
|
||||
if (out_dtype == at::ScalarType::Half) {
|
||||
|
||||
@ -1113,8 +1113,6 @@ __global__ void Marlin(
|
||||
if constexpr (has_zp && !is_zp_float) {
|
||||
if (is_new_zp) {
|
||||
if constexpr (group_blocks == -1) is_first_matmul_in_slice = false;
|
||||
FragB frag_zp_0;
|
||||
FragB frag_zp_1;
|
||||
int zp_quant_0, zp_quant_1;
|
||||
|
||||
if constexpr (w_type.size_bits() == 4) {
|
||||
|
||||
338
csrc/quickreduce/base.h
Normal file
338
csrc/quickreduce/base.h
Normal file
@ -0,0 +1,338 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstdint>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#include <hip/hip_bf16.h>
|
||||
|
||||
#define __quickreduce_device_inline__ __device__ __forceinline__
|
||||
#define __quickreduce_launch_bounds_two_shot__ __launch_bounds__(256, 4)
|
||||
#define __quickreduce_launch_bounds_one_shot__ __launch_bounds__(512, 4)
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
typedef __hip_bfloat16 nv_bfloat16;
|
||||
typedef __hip_bfloat162 nv_bfloat162;
|
||||
|
||||
using int32x2_t = __attribute__((__vector_size__(2 * sizeof(int)))) int;
|
||||
using int32x4_t = __attribute__((__vector_size__(4 * sizeof(int)))) int;
|
||||
|
||||
// Setup acquire-release semantics for vector memory reads (mubuf instruction)
|
||||
// as per architecture.
|
||||
#if defined(__gfx942__)
|
||||
// CDNA3: Scope bits sc0, sc1
|
||||
#define MUBUF_ACQUIRE 16
|
||||
#define MUBUF_RELEASE 16
|
||||
#elif (defined(__gfx908__) || defined(__gfx90a__))
|
||||
// CDNA1 and CDNA2 - glc bit
|
||||
#define MUBUF_ACQUIRE 1
|
||||
#define MUBUF_RELEASE 0
|
||||
#endif
|
||||
|
||||
static constexpr int kNegOne = 0xBC00BC00; // {-1, -1}, fp16x2_t
|
||||
|
||||
// Number of atoms (4xf16x2_t) processed by a single thread
|
||||
static constexpr int kAtoms = 8;
|
||||
|
||||
// We use a workgroup of 256 threads
|
||||
static constexpr int kBlockSize = 256;
|
||||
static constexpr int kAtomStride = kBlockSize;
|
||||
|
||||
// Size and atom stride of source/destination data that the block will
|
||||
// process.
|
||||
// Workgroup scope = Tile = (256 threads x 8 atoms x 16B)
|
||||
static constexpr int kTileSize = kBlockSize * kAtoms * sizeof(int32x4_t);
|
||||
|
||||
// Max number of blocks. 304 CUs on MI300
|
||||
static constexpr int kMaxNumBlocks = 304 * 4;
|
||||
|
||||
// Standard CDNA wavefront size.
|
||||
static constexpr int kWavefront = 64;
|
||||
|
||||
// 256 thread, 4 wavefronts.
|
||||
static dim3 constexpr kBlockTwoShot = {kWavefront, kBlockSize / kWavefront, 1};
|
||||
|
||||
// Number of threads in a group for quantization
|
||||
// It corresponds to 32 F16 elements in quantization block
|
||||
static constexpr int kThreadGroupSize = 8;
|
||||
|
||||
// Methods
|
||||
__quickreduce_device_inline__ __host__ unsigned long divceil(unsigned long x,
|
||||
unsigned long y) {
|
||||
return ((x + y - 1) / y);
|
||||
}
|
||||
|
||||
union BufferResource {
|
||||
__quickreduce_device_inline__ constexpr BufferResource()
|
||||
: config(0x00020000U) {}
|
||||
|
||||
__quickreduce_device_inline__ constexpr BufferResource(void* buffer_address,
|
||||
uint32_t buffer_size)
|
||||
: address(buffer_address), range(buffer_size), config(0x00020000U) {}
|
||||
|
||||
int32x4_t descriptor;
|
||||
struct {
|
||||
void* address; // 8B, out of which first 48b is address, and 16b is stride
|
||||
// (unused)
|
||||
uint32_t range; // Byte range for the buffer resource
|
||||
uint32_t config; // Constant, DFMT=32b
|
||||
};
|
||||
};
|
||||
|
||||
__quickreduce_device_inline__ static int32x4_t buffer_load_dwordx4(
|
||||
int32x4_t srsrc, int32_t voffset, int32_t soffset,
|
||||
int32_t aux) __asm("llvm.amdgcn.raw.buffer.load.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void buffer_store_dwordx4(
|
||||
int32x4_t data, int32x4_t srsrc, int32_t voffset, int32_t soffset,
|
||||
int32_t aux) __asm("llvm.amdgcn.raw.buffer.store.v4i32");
|
||||
|
||||
__quickreduce_device_inline__ static void set_fp16_ovfl(bool const value) {
|
||||
#if defined(__gfx942__)
|
||||
if (value) {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 1;" ::);
|
||||
} else {
|
||||
asm volatile("s_setreg_imm32_b32 0xdc1, 0;" ::);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
union bf162_int_union {
|
||||
int i;
|
||||
nv_bfloat162 bf2;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ void packed_assign_add(int32x4_t* A,
|
||||
int32x4_t* B);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<half>(int32x4_t* A,
|
||||
int32x4_t* B) {
|
||||
int32x4_t& tR_fragment = A[0];
|
||||
int32x4_t& tA_fragment = B[0];
|
||||
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[0])
|
||||
: "v"(tR_fragment[0]), "v"(tA_fragment[0]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[1])
|
||||
: "v"(tR_fragment[1]), "v"(tA_fragment[1]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[2])
|
||||
: "v"(tR_fragment[2]), "v"(tA_fragment[2]));
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(tR_fragment[3])
|
||||
: "v"(tR_fragment[3]), "v"(tA_fragment[3]));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ void packed_assign_add<nv_bfloat16>(
|
||||
int32x4_t* A, int32x4_t* B) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(A);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(B);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
tA[i] = __hadd2(tA[i], tB[i]);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_max_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmax2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_min(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_min_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_min<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hmin2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_abs_max(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<half>(int a, int b) {
|
||||
half2 wmaxh2 = __builtin_bit_cast(half2, a);
|
||||
half2 wminh2 = __builtin_bit_cast(half2, b);
|
||||
half2 wblockmaxh2;
|
||||
|
||||
wblockmaxh2.x =
|
||||
__hgt(__habs(wmaxh2.x), __habs(wminh2.x)) ? wmaxh2.x : wminh2.x;
|
||||
wblockmaxh2.y =
|
||||
__hgt(__habs(wmaxh2.y), __habs(wminh2.y)) ? wmaxh2.y : wminh2.y;
|
||||
return __builtin_bit_cast(int, wblockmaxh2);
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_abs_max<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2.x = __hgt(__habs(A.bf2.x), __habs(B.bf2.x)) ? A.bf2.x : B.bf2.x;
|
||||
R.bf2.y = __hgt(__habs(A.bf2.y), __habs(B.bf2.y)) ? A.bf2.y : B.bf2.y;
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_add(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hadd2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_add<int16_t>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_add_i16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_sub(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<half>(int a, int b) {
|
||||
int result;
|
||||
|
||||
// MI300 lacks packed fp16 sub instruction. So we do -1 * min + max
|
||||
asm volatile("v_pk_fma_f16 %0, %1, %2 %3"
|
||||
: "=v"(result)
|
||||
: "v"(kNegOne), "v"(b), "v"(a));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_sub<nv_bfloat16>(int a, int b) {
|
||||
bf162_int_union A, B, R;
|
||||
A.i = a;
|
||||
B.i = b;
|
||||
R.bf2 = __hsub2(A.bf2, B.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_mul(int a, int b);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<half>(int a, int b) {
|
||||
int result;
|
||||
asm volatile("v_pk_mul_f16 %0, %1, %2" : "=v"(result) : "v"(a), "v"(b));
|
||||
return result;
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_mul<nv_bfloat16>(int a, int b) {
|
||||
nv_bfloat162* tA = reinterpret_cast<nv_bfloat162*>(&a);
|
||||
nv_bfloat162* tB = reinterpret_cast<nv_bfloat162*>(&b);
|
||||
nv_bfloat162 tR = __hmul2(*tA, *tB);
|
||||
return *(reinterpret_cast<int*>(&tR));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int packed_rcp(int a);
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<half>(int a) {
|
||||
return __builtin_bit_cast(int, h2rcp(__builtin_bit_cast(half2, a)));
|
||||
}
|
||||
|
||||
template <>
|
||||
__quickreduce_device_inline__ int packed_rcp<nv_bfloat16>(int a) {
|
||||
bf162_int_union A, R;
|
||||
A.i = a;
|
||||
R.bf2 = h2rcp(A.bf2);
|
||||
return R.i;
|
||||
}
|
||||
|
||||
// changes dtype
|
||||
__quickreduce_device_inline__ float T2float_cast(half a) {
|
||||
return __half2float(a);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ float T2float_cast(nv_bfloat16 a) {
|
||||
return __bfloat162float(a);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__quickreduce_device_inline__ int group_abs_max(int32x4_t atom) {
|
||||
const int group_leader = (threadIdx.x / kThreadGroupSize) * kThreadGroupSize;
|
||||
|
||||
int wmax, wmin, wblockmax;
|
||||
int a, b;
|
||||
a = packed_max<T>(atom[0], atom[1]);
|
||||
b = packed_max<T>(atom[2], atom[3]);
|
||||
|
||||
wmax = packed_max<T>(a, b);
|
||||
|
||||
a = packed_min<T>(atom[0], atom[1]);
|
||||
b = packed_min<T>(atom[2], atom[3]);
|
||||
|
||||
wmin = packed_min<T>(a, b);
|
||||
|
||||
// Reduce the max among a group of threads
|
||||
// Note: This is basically 2 blocks of values setup as the
|
||||
// upper/lower halves of the f16x2_t
|
||||
for (int i = 1; i < kThreadGroupSize; i <<= 1) {
|
||||
int x = __shfl_down(wmax, i);
|
||||
wmax = packed_max<T>(wmax, x);
|
||||
|
||||
int y = __shfl_down(wmin, i);
|
||||
wmin = packed_min<T>(wmin, y);
|
||||
}
|
||||
wblockmax = packed_abs_max<T>(wmax, wmin);
|
||||
// Share with the cohort
|
||||
wblockmax = __shfl(wblockmax, group_leader);
|
||||
return wblockmax;
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void set_sync_flag(uint32_t* flag_ptr,
|
||||
uint32_t flag) {
|
||||
__atomic_store_n(flag_ptr, flag, __ATOMIC_RELEASE);
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void wait_sync_flag(uint32_t* flag_ptr,
|
||||
uint32_t flag) {
|
||||
while (__atomic_load_n(flag_ptr, __ATOMIC_RELAXED) != flag) {
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace quickreduce
|
||||
196
csrc/quickreduce/quick_reduce.h
Normal file
196
csrc/quickreduce/quick_reduce.h
Normal file
@ -0,0 +1,196 @@
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "quick_reduce_impl.cuh"
|
||||
|
||||
#define HIP_CHECK(err) \
|
||||
do { \
|
||||
hipError_t err_ = (err); \
|
||||
if (err_ != hipSuccess) { \
|
||||
std::printf("HIP error %d at %s:%d. %s\n", err_, __FILE__, __LINE__, \
|
||||
hipGetErrorString(err_)); \
|
||||
throw std::runtime_error("HIP error"); \
|
||||
} \
|
||||
} while (0)
|
||||
|
||||
namespace quickreduce {
|
||||
using fptr_t = int64_t;
|
||||
static_assert(sizeof(void*) == sizeof(fptr_t));
|
||||
|
||||
template <typename AllReduceKernel, typename T>
|
||||
__global__ __quickreduce_launch_bounds_two_shot__ static void
|
||||
allreduce_prototype_twoshot(T const* A, T* B, uint32_t N, uint32_t num_blocks,
|
||||
int rank, uint8_t** dbuffer_list,
|
||||
uint32_t data_offset, uint32_t flag_color) {
|
||||
int block = blockIdx.x;
|
||||
int grid = gridDim.x;
|
||||
|
||||
while (block < num_blocks) {
|
||||
AllReduceKernel::run(A, B, N, block, rank, dbuffer_list, data_offset,
|
||||
flag_color);
|
||||
block += grid;
|
||||
flag_color++;
|
||||
}
|
||||
}
|
||||
|
||||
#define TWOSHOT_DISPATCH(__codec) \
|
||||
if (world_size == 2) { \
|
||||
using LineCodec = __codec<T, 2>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 4) { \
|
||||
using LineCodec = __codec<T, 4>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
} else if (world_size == 8) { \
|
||||
using LineCodec = __codec<T, 8>; \
|
||||
using AllReduceKernel = AllReduceTwoshot<T, LineCodec, cast_bf2half>; \
|
||||
hipLaunchKernelGGL((allreduce_prototype_twoshot<AllReduceKernel, T>), \
|
||||
dim3(grid), dim3(kBlockTwoShot), 0, stream, A, B, N, \
|
||||
num_blocks, rank, dbuffer_list, data_offset, \
|
||||
flag_color); \
|
||||
}
|
||||
|
||||
enum QuickReduceQuantLevel {
|
||||
F16 = 0,
|
||||
INT8 = 1,
|
||||
INT6 = 2,
|
||||
INT4 = 3,
|
||||
};
|
||||
|
||||
struct DeviceComms {
|
||||
// Max problem size is 2GB (in bytes) or half of uint32_t max value.
|
||||
int64_t kMaxProblemSize =
|
||||
static_cast<int64_t>(std::numeric_limits<int32_t>::max()) + 1;
|
||||
|
||||
// Max TP-8
|
||||
static int constexpr kMaxWorldSize = 8;
|
||||
|
||||
bool initialized = false;
|
||||
uint32_t flag_color = 1;
|
||||
int world_size;
|
||||
int rank;
|
||||
|
||||
uint8_t* dbuffer;
|
||||
uint8_t** dbuffer_list;
|
||||
hipIpcMemHandle_t buffer_ipc_handle;
|
||||
std::vector<hipIpcMemHandle_t> all_buffer_ipc_handles;
|
||||
std::vector<uint8_t*> buffer_list;
|
||||
uint32_t data_offset;
|
||||
|
||||
DeviceComms() : initialized(false), world_size(1), rank(0) {}
|
||||
~DeviceComms() { destroy(); }
|
||||
|
||||
void init(int world_size, int rank,
|
||||
std::optional<int64_t> max_problem_size = std::nullopt) {
|
||||
destroy();
|
||||
this->world_size = world_size;
|
||||
this->rank = rank;
|
||||
if (max_problem_size.has_value() && max_problem_size.value() > 0) {
|
||||
this->kMaxProblemSize = max_problem_size.value();
|
||||
}
|
||||
// Allocate buffer size for worst case: F16 2-stage buffer.
|
||||
uint32_t flags_buffer_size =
|
||||
2 * world_size * kMaxNumBlocks * sizeof(uint32_t);
|
||||
static int64_t data_buffer_size = 2 * this->kMaxProblemSize;
|
||||
int64_t total_buffer_size = flags_buffer_size + data_buffer_size;
|
||||
data_offset = flags_buffer_size;
|
||||
HIP_CHECK(hipExtMallocWithFlags((void**)&dbuffer, total_buffer_size,
|
||||
hipDeviceMallocUncached));
|
||||
|
||||
// Clear the flags buffer.
|
||||
HIP_CHECK(hipMemset(dbuffer, 0, flags_buffer_size));
|
||||
|
||||
// Device-side list of IPC buffers.
|
||||
buffer_list.resize(world_size);
|
||||
HIP_CHECK(hipMalloc(&dbuffer_list, world_size * sizeof(uint8_t*)));
|
||||
|
||||
// Create IPC handles for rank's communication buffer.
|
||||
all_buffer_ipc_handles.resize(world_size);
|
||||
HIP_CHECK(hipIpcGetMemHandle(&buffer_ipc_handle, dbuffer));
|
||||
|
||||
initialized = true;
|
||||
}
|
||||
int get_world_size() { return world_size; }
|
||||
int get_rank() { return rank; }
|
||||
bool status() { return initialized; }
|
||||
hipIpcMemHandle_t const get_handle() { return buffer_ipc_handle; }
|
||||
|
||||
void destroy() {
|
||||
if (initialized) {
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcCloseMemHandle(dbuffer_list[i]));
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipFree(dbuffer));
|
||||
HIP_CHECK(hipFree(dbuffer_list));
|
||||
|
||||
initialized = false;
|
||||
}
|
||||
}
|
||||
|
||||
void open_ipc_handles(std::vector<hipIpcMemHandle_t> const& ipc_handles) {
|
||||
assert(ipc_handles.size() == all_buffer_ipc_handles.size());
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
all_buffer_ipc_handles[i] = ipc_handles[i];
|
||||
}
|
||||
|
||||
// Open device memory access to the IPC communication buffers.
|
||||
// Note: For our own rank, we do not need to open a handle.
|
||||
for (int i = 0; i < world_size; i++) {
|
||||
if (i != rank) {
|
||||
HIP_CHECK(hipIpcOpenMemHandle((void**)&buffer_list[i],
|
||||
all_buffer_ipc_handles[i],
|
||||
hipIpcMemLazyEnablePeerAccess));
|
||||
} else {
|
||||
buffer_list[i] = dbuffer;
|
||||
}
|
||||
}
|
||||
|
||||
HIP_CHECK(hipMemcpy(dbuffer_list, buffer_list.data(),
|
||||
world_size * sizeof(uint8_t*), hipMemcpyHostToDevice));
|
||||
}
|
||||
|
||||
template <typename T, bool cast_bf2half>
|
||||
void allreduce(T const* A, T* B, uint32_t N, int quant_level,
|
||||
hipStream_t stream) {
|
||||
if (world_size != 2 && world_size != 4 && world_size != 8) {
|
||||
throw std::runtime_error("All Reduce not supported for world_size = " +
|
||||
std::to_string(world_size));
|
||||
}
|
||||
|
||||
// Configuration.
|
||||
uint32_t msg_size = N * sizeof(T);
|
||||
uint32_t num_blocks = divceil(msg_size, kTileSize);
|
||||
uint32_t grid = min(kMaxNumBlocks, num_blocks);
|
||||
auto quant_level_ = static_cast<QuickReduceQuantLevel>(quant_level);
|
||||
switch (quant_level_) {
|
||||
case QuickReduceQuantLevel::INT8:
|
||||
TWOSHOT_DISPATCH(CodecQ8)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT6:
|
||||
TWOSHOT_DISPATCH(CodecQ6)
|
||||
break;
|
||||
case QuickReduceQuantLevel::INT4:
|
||||
TWOSHOT_DISPATCH(CodecQ4)
|
||||
break;
|
||||
default:
|
||||
TWOSHOT_DISPATCH(CodecFP)
|
||||
break;
|
||||
}
|
||||
HIP_CHECK(cudaGetLastError());
|
||||
// Rotate the flag color.
|
||||
flag_color += divceil(N, grid);
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
698
csrc/quickreduce/quick_reduce_impl.cuh
Normal file
698
csrc/quickreduce/quick_reduce_impl.cuh
Normal file
@ -0,0 +1,698 @@
|
||||
#pragma once
|
||||
|
||||
#include <hip/hip_runtime.h>
|
||||
#include "base.h"
|
||||
|
||||
namespace quickreduce {
|
||||
|
||||
struct CodecBase {
|
||||
const int thread;
|
||||
const int rank;
|
||||
const int group_leader;
|
||||
__quickreduce_device_inline__ CodecBase(int thread, int rank)
|
||||
: thread(thread),
|
||||
rank(rank),
|
||||
group_leader((threadIdx.x / kThreadGroupSize) * kThreadGroupSize) {
|
||||
set_fp16_ovfl(true);
|
||||
}
|
||||
};
|
||||
|
||||
// Default full precision codec.
|
||||
template <typename T, int world_size>
|
||||
struct CodecFP : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each thread processes atoms of f16x8_t (16B).
|
||||
static constexpr int kRankTransmittedTileSize =
|
||||
kBlockSize * kRankAtoms * sizeof(int32x4_t);
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
__quickreduce_device_inline__ CodecFP(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
const int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
__builtin_nontemporal_store(data[i], send_buffer + thread);
|
||||
send_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int i = 0; i < kRankAtoms; i++) {
|
||||
data[i] = __builtin_nontemporal_load(*recv_buffer + thread);
|
||||
*recv_buffer += kAtomStride;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int4 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int4 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ4 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int4x8_t (4B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1152;
|
||||
static constexpr int kRankTileScaleOffset = 1024;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride =
|
||||
kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/8.0h, -1/8.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor =
|
||||
std::is_same<T, half>::value ? 0xB000B000 : 0xBE00BE00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon =
|
||||
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-8, -8}, f16x2_t
|
||||
static constexpr int kRangeMin =
|
||||
std::is_same<T, half>::value ? 0xC800C800 : 0xC100C100;
|
||||
|
||||
// {+7, +7}, f16x2_t
|
||||
static constexpr int kRangeMax =
|
||||
std::is_same<T, half>::value ? 0x47004700 : 0x40E040E0;
|
||||
|
||||
// {+8, +8}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00080008;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ4(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q4 into int32_t
|
||||
int qw = q[0] | (q[1] << 4) | (q[2] << 8) | (q[3] << 12);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr =
|
||||
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32_t* qw_ptr = reinterpret_cast<int32_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
int32_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q4 into f16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static constexpr uint kMask000F = 0x000F000F;
|
||||
static constexpr uint kHalf2_1024 =
|
||||
0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1032 =
|
||||
0xE408E408; // {-1032.0, -1032.0}, fp16x2_t
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q4 = ((qw >> (i * 4)) & kMask000F) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q4, kHalf2_1032);
|
||||
} else {
|
||||
int32_t int16_2 = (qw >> (i * 4)) & kMask000F;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int6 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int6 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ6 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of fp16x8_t (16B),
|
||||
// into a int6x8_t (4B + 2B) and a fp16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 1664;
|
||||
static constexpr int kRankTileQ2Offset = 1024;
|
||||
static constexpr int kRankTileScaleOffset = 1536;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTransmittedTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride =
|
||||
kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/32.0h, -1/32.0h}, fp16x2_t
|
||||
static constexpr int kScaleFactor =
|
||||
std::is_same<T, half>::value ? 0xA800A800 : 0xBD00BD00;
|
||||
|
||||
// {1e-7, 1e-7}, fp16x2_t
|
||||
static constexpr int kScaleEpsilon =
|
||||
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-32, -32}, fp16x2_t
|
||||
static constexpr int kRangeMin =
|
||||
std::is_same<T, half>::value ? 0xD000D000 : 0xC200C200;
|
||||
|
||||
// {+31, +31}, fp16x2_t
|
||||
static constexpr int kRangeMax =
|
||||
std::is_same<T, half>::value ? 0x4FC04FC0 : 0x41F841F8;
|
||||
|
||||
// {+32, +32}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00200020;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ6(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
const int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q6 into int32_t + int16_t
|
||||
uint32_t q4w;
|
||||
uint16_t q2w = 0;
|
||||
q4w = (q[0] & 0x000F000F) | ((q[1] & 0x000F000F) << 4) |
|
||||
((q[2] & 0x000F000F) << 8) | ((q[3] & 0x000F000F) << 12);
|
||||
{
|
||||
int16_t* tw = reinterpret_cast<int16_t*>(&q);
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 8; i++) {
|
||||
q2w |= (tw[i] >> 4) << (i * 2);
|
||||
}
|
||||
}
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr =
|
||||
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr =
|
||||
reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(q4w, q4w_ptr);
|
||||
__builtin_nontemporal_store(q2w, q2w_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
uint32_t* q4w_ptr = reinterpret_cast<uint32_t*>(atom_ptr) + thread;
|
||||
uint16_t* q2w_ptr =
|
||||
reinterpret_cast<uint16_t*>(atom_ptr + kRankTileQ2Offset) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
uint32_t q4w = __builtin_nontemporal_load(q4w_ptr);
|
||||
uint16_t q2w = __builtin_nontemporal_load(q2w_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q6 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask000F = 0x000F000F;
|
||||
static uint constexpr kHalf2_1024 =
|
||||
0x64006400; // {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1056 =
|
||||
0xE420E420; // {-1056.0, -1056.0}, fp16x2_t
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
int32_t q4 = q4w & kMask000F;
|
||||
int32_t q2 = (q2w & 0x3) | ((q2w & 0xC) << 14);
|
||||
q4w >>= 4;
|
||||
q2w >>= 4;
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q6 = q4 | (q2 << 4) | kHalf2_1024;
|
||||
asm volatile("v_pk_add_f16 %0, %1, %2"
|
||||
: "=v"(w[i])
|
||||
: "v"(q6), "v"(kHalf2_1056));
|
||||
} else {
|
||||
int32_t int16_2 = q4 | (q2 << 4);
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
// That's pretty much it...
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Int8 symmetric quantization codec.
|
||||
// We quantize the FP16 data to block-scaled Int8 in blocks of 4 *
|
||||
// kThreadGroupSize.
|
||||
template <typename T, int world_size>
|
||||
struct CodecQ8 : public CodecBase {
|
||||
static constexpr int kWorldSize = world_size;
|
||||
|
||||
// Codec tile size process by this workgroup.
|
||||
// Each threads processes a fragment of f16x8_t (16B),
|
||||
// into a int8x8_t (8B) and a f16 scale shared among 32 values.
|
||||
static constexpr int kRankAtoms = kAtoms / kWorldSize;
|
||||
static constexpr int kRankTileStride = 2176;
|
||||
static constexpr int kRankTileScaleOffset = 2048;
|
||||
static constexpr int kRankTransmittedTileSize = kRankTileStride * kRankAtoms;
|
||||
static_assert(kRankTransmittedTileSize % 16 == 0,
|
||||
"kRankTileSize must be 16B aligned.");
|
||||
|
||||
static constexpr int kRankBufferTileStride =
|
||||
kRankTileStride / sizeof(int32x4_t);
|
||||
|
||||
// Total tile size for the collective communication.
|
||||
static constexpr int kTransmittedTileSize =
|
||||
kRankTransmittedTileSize * kWorldSize;
|
||||
|
||||
// Constants configuration
|
||||
|
||||
// {-1/128.0h, -1/128.0h}, f16x2_t
|
||||
static constexpr int kScaleFactor =
|
||||
std::is_same<T, half>::value ? 0xA000A000 : 0xBC00BC00;
|
||||
|
||||
// {1e-7, 1e-7}, f16x2_t
|
||||
static constexpr int kScaleEpsilon =
|
||||
std::is_same<T, half>::value ? 0x00010001 : 0x33D733D7;
|
||||
|
||||
// {-128, -128}, f16x2_t
|
||||
static constexpr int kRangeMin =
|
||||
std::is_same<T, half>::value ? 0xD800D800 : 0xC300C300;
|
||||
// {+127, +127}, f16x2_t
|
||||
static constexpr int kRangeMax =
|
||||
std::is_same<T, half>::value ? 0x57F057F0 : 0x42FE42FE;
|
||||
|
||||
// {+128, +128}, int16x2_t
|
||||
static constexpr int kRangeBias = 0x00800080;
|
||||
|
||||
__quickreduce_device_inline__ CodecQ8(int thread, int rank)
|
||||
: CodecBase(thread, rank) {}
|
||||
|
||||
__quickreduce_device_inline__ void send(int32x4_t* __restrict__ send_buffer,
|
||||
int32x4_t const* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
int32x4_t const atom = data[k];
|
||||
// Compute the absolute maximum of the atom in the thread group
|
||||
// In 2 blocks of values, upper/lower halves of the f16x2_t
|
||||
int wblockmax = group_abs_max<T>(atom);
|
||||
|
||||
// Derive scales
|
||||
int decoding_scale;
|
||||
int encoding_scale;
|
||||
decoding_scale = packed_mul<T>(wblockmax, kScaleFactor);
|
||||
encoding_scale = packed_add<T>(decoding_scale, kScaleEpsilon);
|
||||
encoding_scale = packed_rcp<T>(encoding_scale);
|
||||
|
||||
// Apply scales to get quantized values
|
||||
int32x4_t w;
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(atom[i], encoding_scale);
|
||||
w[i] = packed_max<T>(w[i], kRangeMin);
|
||||
w[i] = packed_min<T>(w[i], kRangeMax);
|
||||
}
|
||||
|
||||
// Convert from f16x2_t to uint16x2_t
|
||||
int32x4_t q;
|
||||
{
|
||||
int16_t* qi = reinterpret_cast<int16_t*>(&q);
|
||||
T* wh = reinterpret_cast<T*>(&w);
|
||||
for (int i = 0; i < 8; i++) qi[i] = (int16_t)rintf(T2float_cast(wh[i]));
|
||||
|
||||
for (int i = 0; i < 4; i++) {
|
||||
q[i] = packed_add<int16_t>(q[i], kRangeBias);
|
||||
}
|
||||
}
|
||||
|
||||
// Pack 8 x q8 into int32x2_t
|
||||
int32x2_t qw;
|
||||
qw[0] = q[0] | (q[1] << 8);
|
||||
qw[1] = q[2] | (q[3] << 8);
|
||||
|
||||
// Write quantized atom to send_buffer
|
||||
// note: only the group leader stores the scale
|
||||
uint8_t* atom_ptr =
|
||||
reinterpret_cast<uint8_t*>(send_buffer + k * kRankBufferTileStride);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
__builtin_nontemporal_store(qw, qw_ptr);
|
||||
if (threadIdx.x == group_leader) {
|
||||
__builtin_nontemporal_store(decoding_scale, qs_ptr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
__quickreduce_device_inline__ void recv(int32x4_t** __restrict__ recv_buffer,
|
||||
int32x4_t* __restrict__ data) {
|
||||
for (int k = 0; k < kRankAtoms; k++) {
|
||||
// Directly read quantized atom from recv_buffer
|
||||
uint8_t* atom_ptr = reinterpret_cast<uint8_t*>(*recv_buffer);
|
||||
int32x2_t* qw_ptr = reinterpret_cast<int32x2_t*>(atom_ptr) + thread;
|
||||
int* qs_ptr = reinterpret_cast<int*>(atom_ptr + kRankTileScaleOffset) +
|
||||
(thread / 8);
|
||||
|
||||
int32x2_t qw = __builtin_nontemporal_load(qw_ptr);
|
||||
int qs = __builtin_nontemporal_load(qs_ptr);
|
||||
|
||||
*recv_buffer += kRankBufferTileStride;
|
||||
|
||||
// Unpack q8 into fp16x8_t
|
||||
int32x4_t w;
|
||||
{
|
||||
static uint constexpr kMask00FF = 0x00FF00FF;
|
||||
|
||||
// {1024.0, 1024.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1024 = 0x64006400;
|
||||
|
||||
// {-1152.0, -1152.0}, fp16x2_t
|
||||
static uint constexpr kHalf2_1152 = 0xE480E480;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 4; i++) {
|
||||
if constexpr (std::is_same<T, half>::value) {
|
||||
int32_t q8 =
|
||||
((qw[i / 2] >> ((i % 2) * 8)) & kMask00FF) | kHalf2_1024;
|
||||
w[i] = packed_add<half>(q8, kHalf2_1152);
|
||||
} else {
|
||||
int32_t int16_2 = (qw[i / 2] >> ((i % 2) * 8)) & kMask00FF;
|
||||
int16_t low = static_cast<int16_t>(int16_2 & 0xFFFF);
|
||||
int16_t high = static_cast<int16_t>((int16_2 >> 16) & 0xFFFF);
|
||||
nv_bfloat16 bf_low = __float2bfloat16(static_cast<float>(low));
|
||||
nv_bfloat16 bf_high = __float2bfloat16(static_cast<float>(high));
|
||||
nv_bfloat162 bf2 = __halves2bfloat162(bf_low, bf_high);
|
||||
int32_t packed_bf16 = *reinterpret_cast<int32_t*>(&bf2);
|
||||
w[i] = packed_add<nv_bfloat16>(packed_bf16, kRangeMin);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Apply decoding scales
|
||||
for (int i = 0; i < 4; i++) {
|
||||
w[i] = packed_mul<T>(w[i], qs);
|
||||
}
|
||||
|
||||
data[k] = w;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Twoshot All Reduce
|
||||
template <typename T, class Codec, bool cast_bf2half>
|
||||
struct AllReduceTwoshot {
|
||||
static_assert(sizeof(T) == 2);
|
||||
|
||||
static constexpr int kWorldSize = Codec::kWorldSize;
|
||||
|
||||
__device__ static void run(
|
||||
T const* __restrict__ input, T* __restrict__ output,
|
||||
uint32_t const N, // number of elements
|
||||
int const block, // block index
|
||||
int const rank, // rank index
|
||||
uint8_t** __restrict__ buffer_list, // communication buffers
|
||||
uint32_t const data_offset, // offset to start of the data buffer
|
||||
uint32_t flag_color) {
|
||||
// Topology
|
||||
int thread = threadIdx.x + threadIdx.y * kWavefront;
|
||||
uint8_t* rank_buffer = buffer_list[rank];
|
||||
Codec codec(thread, rank);
|
||||
int block_id = blockIdx.x;
|
||||
int grid_size = gridDim.x;
|
||||
// --------------------------------------------------------
|
||||
// Read input into registers
|
||||
int32x4_t tA[kAtoms];
|
||||
|
||||
BufferResource src_buffer(const_cast<T*>(input), N * sizeof(T));
|
||||
uint32_t src_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
tA[i] = buffer_load_dwordx4(src_buffer.descriptor, src_offset, 0, 0);
|
||||
src_offset += kAtomStride * sizeof(int32x4_t);
|
||||
if constexpr (cast_bf2half) {
|
||||
const nv_bfloat162* bf_buf =
|
||||
reinterpret_cast<const nv_bfloat162*>(&tA[i]);
|
||||
half2 half_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __bfloat1622float2(bf_buf[j]);
|
||||
half_buf[j] = __float22half2_rn(f);
|
||||
}
|
||||
tA[i] = *reinterpret_cast<const int32x4_t*>(half_buf);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Phase-1A: Write segment data into the communication buffer of the target
|
||||
// rank responsible for this segment.
|
||||
uint32_t comm_data0_offset =
|
||||
data_offset + block_id * Codec::kTransmittedTileSize;
|
||||
uint32_t comm_data1_offset =
|
||||
grid_size * Codec::kTransmittedTileSize + comm_data0_offset;
|
||||
|
||||
uint32_t comm_flags0_offset = block_id * (kWorldSize * sizeof(uint32_t));
|
||||
uint32_t comm_flags1_offset =
|
||||
grid_size * (kWorldSize * sizeof(uint32_t)) + comm_flags0_offset;
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data0_offset +
|
||||
rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
|
||||
buffer_list[r] + comm_flags0_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
// --------------------------------------------------------
|
||||
// Phase-1B: Reduce the segment data from the communication buffers.
|
||||
int32x4_t tR[Codec::kRankAtoms] = {};
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer =
|
||||
reinterpret_cast<int32x4_t*>(rank_buffer + comm_data0_offset);
|
||||
uint32_t* flag_ptr =
|
||||
reinterpret_cast<uint32_t*>(rank_buffer + comm_flags0_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// note: we reuse tA as temp buffer here
|
||||
codec.recv(&recv_buffer, tA);
|
||||
|
||||
for (int i = 0; i < Codec::kRankAtoms; i++) {
|
||||
packed_assign_add<T>(&tR[i], &tA[i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Phase-2: Write the reduced segment to every other rank
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
int32x4_t* send_buffer =
|
||||
reinterpret_cast<int32x4_t*>(buffer_list[r] + comm_data1_offset +
|
||||
rank * Codec::kRankTransmittedTileSize);
|
||||
codec.send(send_buffer, tR);
|
||||
}
|
||||
|
||||
__syncthreads();
|
||||
if (thread < kWorldSize) {
|
||||
int r = thread;
|
||||
uint32_t* flag_ptr = reinterpret_cast<uint32_t*>(
|
||||
buffer_list[r] + comm_flags1_offset + rank * sizeof(uint32_t));
|
||||
set_sync_flag(flag_ptr, flag_color);
|
||||
}
|
||||
|
||||
// Phase-2: Read the gather segments from the rank's communication buffer.
|
||||
{
|
||||
// Read the data from the communication buffer.
|
||||
int32x4_t* recv_buffer =
|
||||
reinterpret_cast<int32x4_t*>(rank_buffer + comm_data1_offset);
|
||||
uint32_t* flag_ptr =
|
||||
reinterpret_cast<uint32_t*>(rank_buffer + comm_flags1_offset);
|
||||
|
||||
for (int r = 0; r < kWorldSize; r++) {
|
||||
// Wait for the flags to be set.
|
||||
if (thread == 0) {
|
||||
wait_sync_flag(&flag_ptr[r], flag_color);
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
// Gather all reduced and final rank segments into tA.
|
||||
codec.recv(&recv_buffer, &tA[r * Codec::kRankAtoms]);
|
||||
}
|
||||
}
|
||||
|
||||
// --------------------------------------------------------
|
||||
// Write the result to output.
|
||||
BufferResource dst_buffer(output, N * sizeof(T));
|
||||
uint32_t dst_offset = block * kTileSize + thread * sizeof(int32x4_t);
|
||||
|
||||
for (int i = 0; i < kAtoms; i++) {
|
||||
if constexpr (cast_bf2half) {
|
||||
const half2* half_buf = reinterpret_cast<const half2*>(&tA[i]);
|
||||
nv_bfloat162 bf16_buf[4];
|
||||
#pragma unroll
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
float2 f = __half22float2(half_buf[j]);
|
||||
bf16_buf[j] = __float22bfloat162_rn(f);
|
||||
}
|
||||
buffer_store_dwordx4(*reinterpret_cast<const int32x4_t*>(bf16_buf),
|
||||
dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
} else {
|
||||
buffer_store_dwordx4(tA[i], dst_buffer.descriptor, dst_offset, 0, 0);
|
||||
}
|
||||
dst_offset += kAtomStride * sizeof(int32x4_t);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace quickreduce
|
||||
@ -1598,7 +1598,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane2id = laneid % 2;
|
||||
const int lane4id = laneid % 4;
|
||||
const int lane16id = laneid % 16;
|
||||
const int rowid = laneid / 16;
|
||||
|
||||
@ -1745,7 +1744,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
|
||||
const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;
|
||||
|
||||
@ -2368,7 +2366,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const int warpid = threadIdx.x / WARP_SIZE;
|
||||
const int laneid = threadIdx.x % WARP_SIZE;
|
||||
const int lane2id = laneid % 2;
|
||||
const int lane4id = laneid % 4;
|
||||
const int lane16id = laneid % 16;
|
||||
const int rowid = laneid / 16;
|
||||
|
||||
@ -2514,7 +2511,6 @@ __launch_bounds__(NUM_THREADS, 3) void paged_attention_ll4mi_QKV_mfma16_kernel(
|
||||
const cache_t* k_ptr2 = k_ptr + kblock_number * kv_block_stride;
|
||||
const int klocal_token_idx =
|
||||
TOKENS_PER_WARP * warpid + token_depth * 16 + lane16id;
|
||||
const int kglobal_token_idx = partition_start_token_idx + klocal_token_idx;
|
||||
const int kphysical_block_offset = klocal_token_idx % BLOCK_SIZE;
|
||||
const cache_t* k_ptr3 = k_ptr2 + kphysical_block_offset * KX;
|
||||
|
||||
|
||||
@ -725,6 +725,24 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _custom_ar), custom_ar) {
|
||||
custom_ar.impl("open_mem_handle", torch::kCPU, &open_mem_handle);
|
||||
|
||||
custom_ar.def("free_shared_buffer", &free_shared_buffer);
|
||||
#ifdef USE_ROCM
|
||||
// Quick Reduce all-reduce kernels
|
||||
custom_ar.def(
|
||||
"qr_all_reduce(int fa, Tensor inp, Tensor out, int quant_level, bool "
|
||||
"cast_bf2half) -> ()");
|
||||
custom_ar.impl("qr_all_reduce", torch::kCUDA, &qr_all_reduce);
|
||||
|
||||
custom_ar.def("init_custom_qr", &init_custom_qr);
|
||||
custom_ar.def("qr_destroy", &qr_destroy);
|
||||
|
||||
custom_ar.def("qr_get_handle", &qr_get_handle);
|
||||
|
||||
custom_ar.def("qr_open_handles(int _fa, Tensor[](b!) handles) -> ()");
|
||||
custom_ar.impl("qr_open_handles", torch::kCPU, &qr_open_handles);
|
||||
|
||||
// Max input size in bytes
|
||||
custom_ar.def("qr_max_size", &qr_max_size);
|
||||
#endif
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
||||
|
||||
@ -6,30 +6,106 @@
|
||||
# docs/assets/contributing/dockerfile-stages-dependency.png
|
||||
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
|
||||
# By parameterizing the base images, we allow third-party to use their own
|
||||
# base images. One use case is hermetic builds with base images stored in
|
||||
# private registries that use a different repository naming conventions.
|
||||
#
|
||||
# Example:
|
||||
# docker build --build-arg BUILD_BASE_IMAGE=registry.acme.org/mirror/nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
ARG BUILD_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04
|
||||
ARG FINAL_BASE_IMAGE=nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04
|
||||
|
||||
# By parameterizing the Deadsnakes repository URL, we allow third-party to use
|
||||
# their own mirror. When doing so, we don't benefit from the transparent
|
||||
# installation of the GPG key of the PPA, as done by add-apt-repository, so we
|
||||
# also need a URL for the GPG key.
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
|
||||
# The PyPA get-pip.py script is a self contained script+zip file, that provides
|
||||
# both the installer script and the pip base85-encoded zip archive. This allows
|
||||
# bootstrapping pip in environment where a dsitribution package does not exist.
|
||||
#
|
||||
# By parameterizing the URL for get-pip.py installation script, we allow
|
||||
# third-party to use their own copy of the script stored in a private mirror.
|
||||
# We set the default value to the PyPA owned get-pip.py script.
|
||||
#
|
||||
# Reference: https://pip.pypa.io/en/stable/installation/#get-pip-py
|
||||
ARG GET_PIP_URL="https://bootstrap.pypa.io/get-pip.py"
|
||||
|
||||
# PIP supports fetching the packages from custom indexes, allowing third-party
|
||||
# to host the packages in private mirrors. The PIP_INDEX_URL and
|
||||
# PIP_EXTRA_INDEX_URL are standard PIP environment variables to override the
|
||||
# default indexes. By letting them empty by default, PIP will use its default
|
||||
# indexes if the build process doesn't override the indexes.
|
||||
#
|
||||
# Uv uses different variables. We set them by default to the same values as
|
||||
# PIP, but they can be overridden.
|
||||
ARG PIP_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL
|
||||
ARG UV_INDEX_URL=${PIP_INDEX_URL}
|
||||
ARG UV_EXTRA_INDEX_URL=${PIP_EXTRA_INDEX_URL}
|
||||
|
||||
# PyTorch provides its own indexes for standard and nightly builds
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL=https://download.pytorch.org/whl
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL=https://download.pytorch.org/whl/nightly
|
||||
|
||||
# PIP supports multiple authentication schemes, including keyring
|
||||
# By parameterizing the PIP_KEYRING_PROVIDER variable and setting it to
|
||||
# disabled by default, we allow third-party to use keyring authentication for
|
||||
# their private Python indexes, while not changing the default behavior which
|
||||
# is no authentication.
|
||||
#
|
||||
# Reference: https://pip.pypa.io/en/stable/topics/authentication/#keyring-support
|
||||
ARG PIP_KEYRING_PROVIDER=disabled
|
||||
ARG UV_KEYRING_PROVIDER=${PIP_KEYRING_PROVIDER}
|
||||
|
||||
#################### BASE BUILD IMAGE ####################
|
||||
# prepare basic build environment
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu20.04 AS base
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM ${BUILD_BASE_IMAGE} AS base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
ARG TARGETPLATFORM
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
# Install Python and other dependencies
|
||||
RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& echo 'tzdata tzdata/Zones/America select Los_Angeles' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl sudo \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done \
|
||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
||||
fi ; \
|
||||
else \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done ; \
|
||||
fi \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
@ -63,21 +139,25 @@ WORKDIR /workspace
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
"torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
--pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
fi
|
||||
|
||||
COPY requirements/common.txt requirements/common.txt
|
||||
COPY requirements/cuda.txt requirements/cuda.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/cuda.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# cuda arch list used by torch
|
||||
# can be useful for both `dev` and `test`
|
||||
# explicitly set the list to avoid issues with torch 2.2
|
||||
# see https://github.com/pytorch/pytorch/pull/123243
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
|
||||
ARG torch_cuda_arch_list='7.0 7.5 8.0 8.9 9.0 10.0 12.0'
|
||||
ENV TORCH_CUDA_ARCH_LIST=${torch_cuda_arch_list}
|
||||
# Override the arch list for flash-attn to reduce the binary size
|
||||
ARG vllm_fa_cmake_gpu_arches='80-real;90-real'
|
||||
@ -88,6 +168,10 @@ ENV VLLM_FA_CMAKE_GPU_ARCHES=${vllm_fa_cmake_gpu_arches}
|
||||
FROM base AS build
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
|
||||
# install build dependencies
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
|
||||
@ -98,7 +182,7 @@ ENV UV_INDEX_STRATEGY="unsafe-best-match"
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
COPY . .
|
||||
ARG GIT_REPO_CHECK=0
|
||||
@ -113,6 +197,8 @@ ARG nvcc_threads=8
|
||||
ENV NVCC_THREADS=$nvcc_threads
|
||||
|
||||
ARG USE_SCCACHE
|
||||
ARG SCCACHE_DOWNLOAD_URL=https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz
|
||||
ARG SCCACHE_ENDPOINT
|
||||
ARG SCCACHE_BUCKET_NAME=vllm-build-sccache
|
||||
ARG SCCACHE_REGION_NAME=us-west-2
|
||||
ARG SCCACHE_S3_NO_CREDENTIALS=0
|
||||
@ -121,10 +207,11 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
if [ "$USE_SCCACHE" = "1" ]; then \
|
||||
echo "Installing sccache..." \
|
||||
&& curl -L -o sccache.tar.gz https://github.com/mozilla/sccache/releases/download/v0.8.1/sccache-v0.8.1-x86_64-unknown-linux-musl.tar.gz \
|
||||
&& curl -L -o sccache.tar.gz ${SCCACHE_DOWNLOAD_URL} \
|
||||
&& tar -xzf sccache.tar.gz \
|
||||
&& sudo mv sccache-v0.8.1-x86_64-unknown-linux-musl/sccache /usr/bin/sccache \
|
||||
&& rm -rf sccache.tar.gz sccache-v0.8.1-x86_64-unknown-linux-musl \
|
||||
&& if [ ! -z ${SCCACHE_ENDPOINT} ] ; then export SCCACHE_ENDPOINT=${SCCACHE_ENDPOINT} ; fi \
|
||||
&& export SCCACHE_BUCKET=${SCCACHE_BUCKET_NAME} \
|
||||
&& export SCCACHE_REGION=${SCCACHE_REGION_NAME} \
|
||||
&& export SCCACHE_S3_NO_CREDENTIALS=${SCCACHE_S3_NO_CREDENTIALS} \
|
||||
@ -162,6 +249,10 @@ RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
|
||||
#################### DEV IMAGE ####################
|
||||
FROM base as dev
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
@ -176,21 +267,25 @@ COPY requirements/test.txt requirements/test.txt
|
||||
COPY requirements/dev.txt requirements/dev.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/dev.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
#################### DEV IMAGE ####################
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
# image with vLLM installed
|
||||
# TODO: Restore to base image after FlashInfer AOT wheel fixed
|
||||
FROM nvidia/cuda:${CUDA_VERSION}-devel-ubuntu22.04 AS vllm-base
|
||||
ARG CUDA_VERSION=12.8.1
|
||||
ARG PYTHON_VERSION=3.12
|
||||
FROM ${FINAL_BASE_IMAGE} AS vllm-base
|
||||
ARG CUDA_VERSION
|
||||
ARG PYTHON_VERSION
|
||||
WORKDIR /vllm-workspace
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
ARG DEADSNAKES_MIRROR_URL
|
||||
ARG DEADSNAKES_GPGKEY_URL
|
||||
ARG GET_PIP_URL
|
||||
|
||||
RUN PYTHON_VERSION_STR=$(echo ${PYTHON_VERSION} | sed 's/\.//g') && \
|
||||
echo "export PYTHON_VERSION_STR=${PYTHON_VERSION_STR}" >> /etc/environment
|
||||
|
||||
@ -200,17 +295,33 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y ccache software-properties-common git curl wget sudo vim python3-pip \
|
||||
&& apt-get install -y ffmpeg libsm6 libxext6 libgl1 \
|
||||
&& for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done \
|
||||
&& if [ ! -z ${DEADSNAKES_MIRROR_URL} ] ; then \
|
||||
if [ ! -z "${DEADSNAKES_GPGKEY_URL}" ] ; then \
|
||||
mkdir -p -m 0755 /etc/apt/keyrings ; \
|
||||
curl -L ${DEADSNAKES_GPGKEY_URL} | gpg --dearmor > /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
sudo chmod 644 /etc/apt/keyrings/deadsnakes.gpg ; \
|
||||
echo "deb [signed-by=/etc/apt/keyrings/deadsnakes.gpg] ${DEADSNAKES_MIRROR_URL} $(lsb_release -cs) main" > /etc/apt/sources.list.d/deadsnakes.list ; \
|
||||
fi ; \
|
||||
else \
|
||||
for i in 1 2 3; do \
|
||||
add-apt-repository -y ppa:deadsnakes/ppa && break || \
|
||||
{ echo "Attempt $i failed, retrying in 5s..."; sleep 5; }; \
|
||||
done ; \
|
||||
fi \
|
||||
&& apt-get update -y \
|
||||
&& apt-get install -y python${PYTHON_VERSION} python${PYTHON_VERSION}-dev python${PYTHON_VERSION}-venv libibverbs-dev \
|
||||
&& update-alternatives --install /usr/bin/python3 python3 /usr/bin/python${PYTHON_VERSION} 1 \
|
||||
&& update-alternatives --set python3 /usr/bin/python${PYTHON_VERSION} \
|
||||
&& ln -sf /usr/bin/python${PYTHON_VERSION}-config /usr/bin/python3-config \
|
||||
&& curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \
|
||||
&& curl -sS ${GET_PIP_URL} | python${PYTHON_VERSION} \
|
||||
&& python3 --version && python3 -m pip --version
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
ARG PYTORCH_CUDA_INDEX_BASE_URL
|
||||
ARG PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL
|
||||
ARG PIP_KEYRING_PROVIDER UV_KEYRING_PROVIDER
|
||||
|
||||
# Install uv for faster pip installs
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
python3 -m pip install uv
|
||||
@ -232,19 +343,23 @@ RUN ldconfig /usr/local/cuda-$(echo $CUDA_VERSION | cut -d. -f1,2)/compat/
|
||||
# after this step
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 "torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319"; \
|
||||
uv pip install --system --index-url https://download.pytorch.org/whl/nightly/cu128 --pre pytorch_triton==3.3.0+gitab727c40; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
"torch==2.8.0.dev20250318+cu128" "torchvision==0.22.0.dev20250319" ; \
|
||||
uv pip install --system \
|
||||
--index-url ${PYTORCH_CUDA_NIGHTLY_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \
|
||||
--pre pytorch_triton==3.3.0+gitab727c40 ; \
|
||||
fi
|
||||
|
||||
# Install vllm wheel first, so that torch etc will be installed.
|
||||
RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist \
|
||||
--mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system dist/*.whl --verbose \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
# If we need to build FlashInfer wheel before its release:
|
||||
# $ # Note we remove 7.0 from the arch list compared to the list below, since FlashInfer only supports sm75+
|
||||
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a'
|
||||
# $ export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
|
||||
# $ git clone https://github.com/flashinfer-ai/flashinfer.git --recursive
|
||||
# $ cd flashinfer
|
||||
# $ git checkout v0.2.6.post1
|
||||
@ -254,23 +369,49 @@ RUN --mount=type=bind,from=build,src=/workspace/dist,target=/vllm-workspace/dist
|
||||
# -rw-rw-r-- 1 mgoin mgoin 205M Jun 9 18:03 flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
|
||||
# $ # upload the wheel to a public location, e.g. https://wheels.vllm.ai/flashinfer/v0.2.6.post1/flashinfer_python-0.2.6.post1-cp39-abi3-linux_x86_64.whl
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
. /etc/environment && \
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then \
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then \
|
||||
uv pip install --system https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl; \
|
||||
else \
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a' && \
|
||||
git clone https://github.com/flashinfer-ai/flashinfer.git --single-branch --branch v0.2.6.post1 --recursive && \
|
||||
# Needed to build AOT kernels
|
||||
(cd flashinfer && \
|
||||
python3 -m flashinfer.aot && \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
) && \
|
||||
rm -rf flashinfer; \
|
||||
fi \
|
||||
fi
|
||||
# Allow specifying a version, Git revision or local .whl file
|
||||
ARG FLASHINFER_CUDA128_INDEX_URL="https://download.pytorch.org/whl/cu128/flashinfer"
|
||||
ARG FLASHINFER_CUDA128_WHEEL="flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl"
|
||||
ARG FLASHINFER_GIT_REPO="https://github.com/flashinfer-ai/flashinfer.git"
|
||||
ARG FLASHINFER_GIT_REF="v0.2.6.post1"
|
||||
RUN --mount=type=cache,target=/root/.cache/uv bash - <<'BASH'
|
||||
. /etc/environment
|
||||
if [ "$TARGETPLATFORM" != "linux/arm64" ]; then
|
||||
# FlashInfer already has a wheel for PyTorch 2.7.0 and CUDA 12.8. This is enough for CI use
|
||||
if [[ "$CUDA_VERSION" == 12.8* ]]; then
|
||||
uv pip install --system ${FLASHINFER_CUDA128_INDEX_URL}/${FLASHINFER_CUDA128_WHEEL}
|
||||
else
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0a 10.0a 12.0'
|
||||
git clone ${FLASHINFER_GIT_REPO} --single-branch --branch ${FLASHINFER_GIT_REF} --recursive
|
||||
# Needed to build AOT kernels
|
||||
(cd flashinfer && \
|
||||
python3 -m flashinfer.aot && \
|
||||
uv pip install --system --no-build-isolation . \
|
||||
)
|
||||
rm -rf flashinfer
|
||||
|
||||
# Default arches (skipping 10.0a and 12.0 since these need 12.8)
|
||||
# TODO: Update this to allow setting TORCH_CUDA_ARCH_LIST as a build arg.
|
||||
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9 9.0a"
|
||||
if [[ "${CUDA_VERSION}" == 11.* ]]; then
|
||||
TORCH_CUDA_ARCH_LIST="7.5 8.0 8.9"
|
||||
fi
|
||||
echo "🏗️ Building FlashInfer for arches: ${TORCH_CUDA_ARCH_LIST}"
|
||||
|
||||
git clone --depth 1 --recursive --shallow-submodules \
|
||||
--branch v0.2.6.post1 \
|
||||
https://github.com/flashinfer-ai/flashinfer.git flashinfer
|
||||
|
||||
pushd flashinfer
|
||||
python3 -m flashinfer.aot
|
||||
TORCH_CUDA_ARCH_LIST="${TORCH_CUDA_ARCH_LIST}" \
|
||||
uv pip install --system --no-build-isolation .
|
||||
popd
|
||||
|
||||
rm -rf flashinfer
|
||||
fi \
|
||||
fi
|
||||
BASH
|
||||
COPY examples examples
|
||||
COPY benchmarks benchmarks
|
||||
COPY ./vllm/collect_env.py .
|
||||
@ -286,7 +427,7 @@ uv pip list
|
||||
COPY requirements/build.txt requirements/build.txt
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system -r requirements/build.txt \
|
||||
--extra-index-url https://download.pytorch.org/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
--extra-index-url ${PYTORCH_CUDA_INDEX_BASE_URL}/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.')
|
||||
|
||||
#################### vLLM installation IMAGE ####################
|
||||
|
||||
@ -297,6 +438,11 @@ FROM vllm-base AS test
|
||||
|
||||
ADD . /vllm-workspace/
|
||||
|
||||
ARG PYTHON_VERSION
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
@ -307,7 +453,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
CUDA_MAJOR="${CUDA_VERSION%%.*}"; \
|
||||
if [ "$CUDA_MAJOR" -ge 12 ]; then \
|
||||
uv pip install --system -r requirements/dev.txt; \
|
||||
@ -323,7 +469,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
ENV HF_HUB_ENABLE_HF_TRANSFER 1
|
||||
|
||||
# Copy in the v1 package for testing (it isn't distributed yet)
|
||||
COPY vllm/v1 /usr/local/lib/python3.12/dist-packages/vllm/v1
|
||||
COPY vllm/v1 /usr/local/lib/python${PYTHON_VERSION}/dist-packages/vllm/v1
|
||||
|
||||
# doc requires source code
|
||||
# we hide them inside `test_docs/` , so that this source code
|
||||
@ -340,6 +486,9 @@ RUN mv mkdocs.yaml test_docs/
|
||||
FROM vllm-base AS vllm-openai-base
|
||||
ARG TARGETPLATFORM
|
||||
|
||||
ARG PIP_INDEX_URL UV_INDEX_URL
|
||||
ARG PIP_EXTRA_INDEX_URL UV_EXTRA_INDEX_URL
|
||||
|
||||
# This timeout (in seconds) is necessary when installing some dependencies via uv since it's likely to time out
|
||||
# Reference: https://github.com/astral-sh/uv/pull/1694
|
||||
ENV UV_HTTP_TIMEOUT=500
|
||||
@ -349,7 +498,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
if [ "$TARGETPLATFORM" = "linux/arm64" ]; then \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.42.0' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
else \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.45.3' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
uv pip install --system accelerate hf_transfer 'modelscope!=1.15.0' 'bitsandbytes>=0.46.1' 'timm==0.9.10' boto3 runai-model-streamer runai-model-streamer[s3]; \
|
||||
fi
|
||||
|
||||
ENV VLLM_USAGE_SOURCE production-docker-image
|
||||
|
||||
@ -8,6 +8,8 @@
|
||||
# Build arguments:
|
||||
# PYTHON_VERSION=3.12 (default)|3.11|3.10|3.9
|
||||
# VLLM_CPU_DISABLE_AVX512=false (default)|true
|
||||
# VLLM_CPU_AVX512BF16=false (default)|true
|
||||
# VLLM_CPU_AVX512VNNI=false (default)|true
|
||||
#
|
||||
|
||||
######################### BASE IMAGE #########################
|
||||
@ -25,7 +27,7 @@ RUN --mount=type=cache,target=/var/cache/apt,sharing=locked \
|
||||
--mount=type=cache,target=/var/lib/apt,sharing=locked \
|
||||
apt-get update -y \
|
||||
&& apt-get install -y --no-install-recommends ccache git curl wget ca-certificates \
|
||||
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 \
|
||||
gcc-12 g++-12 libtcmalloc-minimal4 libnuma-dev ffmpeg libsm6 libxext6 libgl1 jq lsof \
|
||||
&& update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-12 10 --slave /usr/bin/g++ g++ /usr/bin/g++-12 \
|
||||
&& curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
|
||||
@ -60,13 +62,19 @@ FROM base AS vllm-build
|
||||
|
||||
ARG GIT_REPO_CHECK=0
|
||||
# Support for building with non-AVX512 vLLM: docker build --build-arg VLLM_CPU_DISABLE_AVX512="true" ...
|
||||
ARG VLLM_CPU_DISABLE_AVX512
|
||||
ARG VLLM_CPU_DISABLE_AVX512=0
|
||||
ENV VLLM_CPU_DISABLE_AVX512=${VLLM_CPU_DISABLE_AVX512}
|
||||
# Support for building with AVX512BF16 ISA: docker build --build-arg VLLM_CPU_AVX512BF16="true" ...
|
||||
ARG VLLM_CPU_AVX512BF16=0
|
||||
ENV VLLM_CPU_AVX512BF16=${VLLM_CPU_AVX512BF16}
|
||||
# Support for building with AVX512VNNI ISA: docker build --build-arg VLLM_CPU_AVX512VNNI="true" ...
|
||||
ARG VLLM_CPU_AVX512VNNI=0
|
||||
ENV VLLM_CPU_AVX512VNNI=${VLLM_CPU_AVX512VNNI}
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,src=requirements/build.txt,target=requirements/build.txt \
|
||||
--mount=type=bind,src=requirements/cpu-build.txt,target=requirements/build.txt \
|
||||
uv pip install -r requirements/build.txt
|
||||
|
||||
COPY . .
|
||||
@ -79,6 +87,22 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
VLLM_TARGET_DEVICE=cpu python3 setup.py bdist_wheel
|
||||
|
||||
######################### TEST DEPS #########################
|
||||
FROM base AS vllm-test-deps
|
||||
|
||||
WORKDIR /workspace/vllm
|
||||
|
||||
RUN --mount=type=bind,src=requirements/test.in,target=requirements/test.in \
|
||||
cp requirements/test.in requirements/cpu-test.in && \
|
||||
sed -i '/mamba_ssm/d' requirements/cpu-test.in && \
|
||||
sed -i 's/torch==.*/torch==2.6.0/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchaudio.*/torchaudio/g' requirements/cpu-test.in && \
|
||||
sed -i 's/torchvision.*/torchvision/g' requirements/cpu-test.in && \
|
||||
uv pip compile requirements/cpu-test.in -o requirements/cpu-test.txt --index-strategy unsafe-best-match --torch-backend cpu
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
uv pip install -r requirements/cpu-test.txt
|
||||
|
||||
######################### DEV IMAGE #########################
|
||||
FROM vllm-build AS vllm-dev
|
||||
|
||||
@ -97,28 +121,19 @@ RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
VLLM_TARGET_DEVICE=cpu python3 setup.py develop
|
||||
|
||||
COPY --from=vllm-test-deps /workspace/vllm/requirements/cpu-test.txt requirements/test.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,src=requirements/test.in,target=requirements/test.in \
|
||||
cp requirements/test.in requirements/test-cpu.in && \
|
||||
sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
|
||||
uv pip compile requirements/test-cpu.in -o requirements/test.txt && \
|
||||
uv pip install -r requirements/dev.txt && \
|
||||
pre-commit install --hook-type pre-commit --hook-type commit-msg
|
||||
|
||||
ENTRYPOINT ["bash"]
|
||||
|
||||
######################### TEST IMAGE #########################
|
||||
FROM base AS vllm-test
|
||||
FROM vllm-test-deps AS vllm-test
|
||||
|
||||
WORKDIR /workspace/
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,src=requirements/test.in,target=requirements/test.in \
|
||||
cp requirements/test.in requirements/test-cpu.in && \
|
||||
sed -i '/mamba_ssm/d' requirements/test-cpu.in && \
|
||||
uv pip compile requirements/test-cpu.in -o requirements/cpu-test.txt && \
|
||||
uv pip install -r requirements/cpu-test.txt
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
--mount=type=bind,from=vllm-build,src=/workspace/vllm/dist,target=dist \
|
||||
uv pip install dist/*.whl
|
||||
@ -127,6 +142,7 @@ ADD ./tests/ ./tests/
|
||||
ADD ./examples/ ./examples/
|
||||
ADD ./benchmarks/ ./benchmarks/
|
||||
ADD ./vllm/collect_env.py .
|
||||
ADD ./.buildkite/ ./.buildkite/
|
||||
|
||||
# install development dependencies (for testing)
|
||||
RUN --mount=type=cache,target=/root/.cache/uv \
|
||||
|
||||
@ -12,7 +12,7 @@ ARG PYTORCH_REPO="https://github.com/pytorch/pytorch.git"
|
||||
ARG PYTORCH_VISION_REPO="https://github.com/pytorch/vision.git"
|
||||
ARG FA_BRANCH="1a7f4dfa"
|
||||
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
|
||||
ARG AITER_BRANCH="c1debd8"
|
||||
ARG AITER_BRANCH="6487649"
|
||||
ARG AITER_REPO="https://github.com/ROCm/aiter.git"
|
||||
|
||||
FROM ${BASE_IMAGE} AS base
|
||||
|
||||
@ -35,6 +35,7 @@ RUN --mount=type=bind,source=.git,target=.git \
|
||||
if [ "$GIT_REPO_CHECK" != 0 ]; then bash tools/check_repo.sh; fi
|
||||
|
||||
ENV VLLM_TARGET_DEVICE=xpu
|
||||
ENV VLLM_WORKER_MULTIPROC_METHOD=spawn
|
||||
|
||||
RUN --mount=type=cache,target=/root/.cache/pip \
|
||||
--mount=type=bind,source=.git,target=.git \
|
||||
|
||||
@ -39,6 +39,7 @@ nav:
|
||||
- models/generative_models.md
|
||||
- models/pooling_models.md
|
||||
- models/extensions
|
||||
- Hardware Supported Models: models/hardware_supported_models
|
||||
- Features:
|
||||
- features/compatibility_matrix.md
|
||||
- features/*
|
||||
@ -48,7 +49,12 @@ nav:
|
||||
- General:
|
||||
- glob: contributing/*
|
||||
flatten_single_child_sections: true
|
||||
- Model Implementation: contributing/model
|
||||
- Model Implementation:
|
||||
- contributing/model/README.md
|
||||
- contributing/model/basic.md
|
||||
- contributing/model/registration.md
|
||||
- contributing/model/tests.md
|
||||
- contributing/model/multimodal.md
|
||||
- Design Documents:
|
||||
- V0: design
|
||||
- V1: design/v1
|
||||
|
||||
@ -1,7 +1,8 @@
|
||||
# Welcome to vLLM
|
||||
|
||||
<figure markdown="span">
|
||||
{ align="center" alt="vLLM" class="no-scaled-link" width="60%" }
|
||||
{ align="center" alt="vLLM Light" class="logo-light" width="60%" }
|
||||
{ align="center" alt="vLLM Dark" class="logo-dark" width="60%" }
|
||||
</figure>
|
||||
|
||||
<p style="text-align:center">
|
||||
@ -40,7 +41,7 @@ vLLM is flexible and easy to use with:
|
||||
- OpenAI-compatible API server
|
||||
- Support NVIDIA GPUs, AMD CPUs and GPUs, Intel CPUs, Gaudi® accelerators and GPUs, IBM Power CPUs, TPU, and AWS Trainium and Inferentia Accelerators.
|
||||
- Prefix caching support
|
||||
- Multi-lora support
|
||||
- Multi-LoRA support
|
||||
|
||||
For more information, check out the following:
|
||||
|
||||
|
||||
@ -91,7 +91,7 @@ source to unblock the update process.
|
||||
### FlashInfer
|
||||
Here is how to build and install it from source with torch2.7.0+cu128 in vLLM [Dockerfile](https://github.com/vllm-project/vllm/blob/27bebcd89792d5c4b08af7a65095759526f2f9e1/docker/Dockerfile#L259-L271):
|
||||
|
||||
```
|
||||
```bash
|
||||
export TORCH_CUDA_ARCH_LIST='7.5 8.0 8.9 9.0 10.0+PTX'
|
||||
export FLASHINFER_ENABLE_SM90=1
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/flashinfer-ai/flashinfer@v0.2.6.post1"
|
||||
@ -105,14 +105,14 @@ team if you want to get the package published there.
|
||||
### xFormers
|
||||
Similar to FlashInfer, here is how to build and install xFormers from source:
|
||||
|
||||
```
|
||||
```bash
|
||||
export TORCH_CUDA_ARCH_LIST='7.0 7.5 8.0 8.9 9.0 10.0+PTX'
|
||||
MAX_JOBS=16 uv pip install --system --no-build-isolation "git+https://github.com/facebookresearch/xformers@v0.0.30"
|
||||
```
|
||||
|
||||
### Mamba
|
||||
|
||||
```
|
||||
```bash
|
||||
uv pip install --system --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
|
||||
```
|
||||
|
||||
|
||||
@ -16,35 +16,33 @@ vllm {chat,complete,serve,bench,collect-env,run-batch}
|
||||
|
||||
Start the vLLM OpenAI Compatible API server.
|
||||
|
||||
Examples:
|
||||
??? Examples
|
||||
|
||||
```bash
|
||||
# Start with a model
|
||||
vllm serve meta-llama/Llama-2-7b-hf
|
||||
```bash
|
||||
# Start with a model
|
||||
vllm serve meta-llama/Llama-2-7b-hf
|
||||
|
||||
# Specify the port
|
||||
vllm serve meta-llama/Llama-2-7b-hf --port 8100
|
||||
# Specify the port
|
||||
vllm serve meta-llama/Llama-2-7b-hf --port 8100
|
||||
|
||||
# Check with --help for more options
|
||||
# To list all groups
|
||||
vllm serve --help=listgroup
|
||||
# Check with --help for more options
|
||||
# To list all groups
|
||||
vllm serve --help=listgroup
|
||||
|
||||
# To view a argument group
|
||||
vllm serve --help=ModelConfig
|
||||
# To view a argument group
|
||||
vllm serve --help=ModelConfig
|
||||
|
||||
# To view a single argument
|
||||
vllm serve --help=max-num-seqs
|
||||
# To view a single argument
|
||||
vllm serve --help=max-num-seqs
|
||||
|
||||
# To search by keyword
|
||||
vllm serve --help=max
|
||||
```
|
||||
# To search by keyword
|
||||
vllm serve --help=max
|
||||
```
|
||||
|
||||
## chat
|
||||
|
||||
Generate chat completions via the running API server.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# Directly connect to localhost API without arguments
|
||||
vllm chat
|
||||
@ -60,8 +58,6 @@ vllm chat --quick "hi"
|
||||
|
||||
Generate text completions based on the given prompt via the running API server.
|
||||
|
||||
Examples:
|
||||
|
||||
```bash
|
||||
# Directly connect to localhost API without arguments
|
||||
vllm complete
|
||||
@ -73,6 +69,8 @@ vllm complete --url http://{vllm-serve-host}:{vllm-serve-port}/v1
|
||||
vllm complete --quick "The future of AI is"
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## bench
|
||||
|
||||
Run benchmark tests for latency online serving throughput and offline inference throughput.
|
||||
@ -89,8 +87,6 @@ vllm bench {latency, serve, throughput}
|
||||
|
||||
Benchmark the latency of a single batch of requests.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm bench latency \
|
||||
--model meta-llama/Llama-3.2-1B-Instruct \
|
||||
@ -104,8 +100,6 @@ vllm bench latency \
|
||||
|
||||
Benchmark the online serving throughput.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm bench serve \
|
||||
--model meta-llama/Llama-3.2-1B-Instruct \
|
||||
@ -120,8 +114,6 @@ vllm bench serve \
|
||||
|
||||
Benchmark offline inference throughput.
|
||||
|
||||
Example:
|
||||
|
||||
```bash
|
||||
vllm bench throughput \
|
||||
--model meta-llama/Llama-3.2-1B-Instruct \
|
||||
@ -143,7 +135,8 @@ vllm collect-env
|
||||
|
||||
Run batch prompts and write results to file.
|
||||
|
||||
Examples:
|
||||
<details>
|
||||
<summary>Examples</summary>
|
||||
|
||||
```bash
|
||||
# Running with a local file
|
||||
@ -159,6 +152,8 @@ vllm run-batch \
|
||||
--model meta-llama/Meta-Llama-3-8B-Instruct
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## More Help
|
||||
|
||||
For detailed options of any subcommand, use:
|
||||
|
||||
6
docs/community/contact_us.md
Normal file
6
docs/community/contact_us.md
Normal file
@ -0,0 +1,6 @@
|
||||
---
|
||||
title: Contact Us
|
||||
---
|
||||
[](){ #contactus }
|
||||
|
||||
--8<-- "README.md:contact-us"
|
||||
@ -57,19 +57,21 @@ By default, we optimize model inference using CUDA graphs which take up extra me
|
||||
|
||||
You can adjust `compilation_config` to achieve a better balance between inference speed and memory usage:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
??? Code
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
# By default, it goes up to max_num_seqs
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8, 16],
|
||||
),
|
||||
)
|
||||
```
|
||||
```python
|
||||
from vllm import LLM
|
||||
from vllm.config import CompilationConfig, CompilationLevel
|
||||
|
||||
llm = LLM(
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
compilation_config=CompilationConfig(
|
||||
level=CompilationLevel.PIECEWISE,
|
||||
# By default, it goes up to max_num_seqs
|
||||
cudagraph_capture_sizes=[1, 2, 4, 8, 16],
|
||||
),
|
||||
)
|
||||
```
|
||||
|
||||
You can disable graph capturing completely via the `enforce_eager` flag:
|
||||
|
||||
@ -127,18 +129,20 @@ reduce the size of the processed multi-modal inputs, which in turn saves memory.
|
||||
|
||||
Here are some examples:
|
||||
|
||||
```python
|
||||
from vllm import LLM
|
||||
??? Code
|
||||
|
||||
# Available for Qwen2-VL series models
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_kwargs={
|
||||
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
|
||||
})
|
||||
```python
|
||||
from vllm import LLM
|
||||
|
||||
# Available for InternVL series models
|
||||
llm = LLM(model="OpenGVLab/InternVL2-2B",
|
||||
mm_processor_kwargs={
|
||||
"max_dynamic_patch": 4, # Default is 12
|
||||
})
|
||||
```
|
||||
# Available for Qwen2-VL series models
|
||||
llm = LLM(model="Qwen/Qwen2.5-VL-3B-Instruct",
|
||||
mm_processor_kwargs={
|
||||
"max_pixels": 768 * 768, # Default is 1280 * 28 * 28
|
||||
})
|
||||
|
||||
# Available for InternVL series models
|
||||
llm = LLM(model="OpenGVLab/InternVL2-2B",
|
||||
mm_processor_kwargs={
|
||||
"max_dynamic_patch": 4, # Default is 12
|
||||
})
|
||||
```
|
||||
|
||||
@ -6,7 +6,7 @@ title: Engine Arguments
|
||||
Engine arguments control the behavior of the vLLM engine.
|
||||
|
||||
- For [offline inference][offline-inference], they are part of the arguments to [LLM][vllm.LLM] class.
|
||||
- For [online serving][openai-compatible-server], they are part of the arguments to `vllm serve`.
|
||||
- For [online serving][serving-openai-compatible-server], they are part of the arguments to `vllm serve`.
|
||||
|
||||
You can look at [EngineArgs][vllm.engine.arg_utils.EngineArgs] and [AsyncEngineArgs][vllm.engine.arg_utils.AsyncEngineArgs] to see the available engine arguments.
|
||||
|
||||
|
||||
@ -7,6 +7,8 @@ vLLM uses the following environment variables to configure the system:
|
||||
|
||||
All environment variables used by vLLM are prefixed with `VLLM_`. **Special care should be taken for Kubernetes users**: please do not name the service as `vllm`, otherwise environment variables set by Kubernetes might conflict with vLLM's environment variables, because [Kubernetes sets environment variables for each service with the capitalized service name as the prefix](https://kubernetes.io/docs/concepts/services-networking/service/#environment-variables).
|
||||
|
||||
```python
|
||||
--8<-- "vllm/envs.py:env-vars-definition"
|
||||
```
|
||||
??? Code
|
||||
|
||||
```python
|
||||
--8<-- "vllm/envs.py:env-vars-definition"
|
||||
```
|
||||
|
||||
@ -29,6 +29,8 @@ See <gh-file:LICENSE>.
|
||||
Depending on the kind of development you'd like to do (e.g. Python, CUDA), you can choose to build vLLM with or without compilation.
|
||||
Check out the [building from source][build-from-source] documentation for details.
|
||||
|
||||
For an optimized workflow when iterating on C++/CUDA kernels, see the [Incremental Compilation Workflow](./incremental_build.md) for recommendations.
|
||||
|
||||
### Building the docs with MkDocs
|
||||
|
||||
#### Introduction to MkDocs
|
||||
@ -93,25 +95,27 @@ For additional features and advanced configurations, refer to the official [MkDo
|
||||
|
||||
## Testing
|
||||
|
||||
```bash
|
||||
pip install -r requirements/dev.txt
|
||||
??? note "Commands"
|
||||
|
||||
# Linting, formatting and static type checking
|
||||
pre-commit install --hook-type pre-commit --hook-type commit-msg
|
||||
```bash
|
||||
pip install -r requirements/dev.txt
|
||||
|
||||
# You can manually run pre-commit with
|
||||
pre-commit run --all-files
|
||||
# Linting, formatting and static type checking
|
||||
pre-commit install --hook-type pre-commit --hook-type commit-msg
|
||||
|
||||
# To manually run something from CI that does not run
|
||||
# locally by default, you can run:
|
||||
pre-commit run mypy-3.9 --hook-stage manual --all-files
|
||||
# You can manually run pre-commit with
|
||||
pre-commit run --all-files
|
||||
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
# To manually run something from CI that does not run
|
||||
# locally by default, you can run:
|
||||
pre-commit run mypy-3.9 --hook-stage manual --all-files
|
||||
|
||||
# Run tests for a single test file with detailed output
|
||||
pytest -s -v tests/test_logger.py
|
||||
```
|
||||
# Unit tests
|
||||
pytest tests/
|
||||
|
||||
# Run tests for a single test file with detailed output
|
||||
pytest -s -v tests/test_logger.py
|
||||
```
|
||||
|
||||
!!! tip
|
||||
Since the <gh-file:docker/Dockerfile> ships with Python 3.12, all tests in CI (except `mypy`) are run with Python 3.12.
|
||||
@ -147,6 +151,14 @@ the terms of the DCO.
|
||||
|
||||
Using `-s` with `git commit` will automatically add this header.
|
||||
|
||||
!!! tip
|
||||
You can enable automatic sign-off via your IDE:
|
||||
|
||||
- **PyCharm**: Click on the `Show Commit Options` icon to the right of the `Commit and Push...` button in the `Commit` window.
|
||||
It will bring up a `git` window where you can modify the `Author` and enable `Sign-off commit`.
|
||||
- **VSCode**: Open the [Settings editor](https://code.visualstudio.com/docs/configure/settings)
|
||||
and enable the `Git: Always Sign Off` (`git.alwaysSignOff`) field.
|
||||
|
||||
### PR Title and Classification
|
||||
|
||||
Only specific types of PRs will be reviewed. The PR title is prefixed
|
||||
@ -186,6 +198,7 @@ The PR needs to meet the following code quality standards:
|
||||
|
||||
### Adding or Changing Kernels
|
||||
|
||||
When actively developing or modifying kernels, using the [Incremental Compilation Workflow](./incremental_build.md) is highly recommended for faster build times.
|
||||
Each custom kernel needs a schema and one or more implementations to be registered with PyTorch.
|
||||
|
||||
- Make sure custom ops are registered following PyTorch guidelines:
|
||||
|
||||
@ -37,14 +37,14 @@ multiple Y releases:
|
||||
- **Timeline**: A removal version is explicitly stated in the deprecation
|
||||
warning (e.g., "This will be removed in v0.10.0").
|
||||
- **Communication**: Deprecation is noted in the following, as applicable:
|
||||
- Help strings
|
||||
- Log output
|
||||
- API responses
|
||||
- `/metrics` output (for metrics features)
|
||||
- User-facing documentation
|
||||
- Release notes
|
||||
- GitHub Issue (RFC) for feedback
|
||||
- Documentation and use of the `@typing_extensions.deprecated` decorator for Python APIs
|
||||
- Help strings
|
||||
- Log output
|
||||
- API responses
|
||||
- `/metrics` output (for metrics features)
|
||||
- User-facing documentation
|
||||
- Release notes
|
||||
- GitHub Issue (RFC) for feedback
|
||||
- Documentation and use of the `@typing_extensions.deprecated` decorator for Python APIs
|
||||
|
||||
**2.Deprecated (Off By Default)**
|
||||
|
||||
|
||||
138
docs/contributing/incremental_build.md
Normal file
138
docs/contributing/incremental_build.md
Normal file
@ -0,0 +1,138 @@
|
||||
# Incremental Compilation Workflow
|
||||
|
||||
When working on vLLM's C++/CUDA kernels located in the `csrc/` directory, recompiling the entire project with `uv pip install -e .` for every change can be time-consuming. An incremental compilation workflow using CMake allows for faster iteration by only recompiling the necessary components after an initial setup. This guide details how to set up and use such a workflow, which complements your editable Python installation.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
Before setting up the incremental build:
|
||||
|
||||
1. **vLLM Editable Install:** Ensure you have vLLM installed from source in an editable mode. Using pre-compiled wheels for the initial editable setup can be faster, as the CMake workflow will handle subsequent kernel recompilations.
|
||||
|
||||
```console
|
||||
uv venv --python 3.12 --seed
|
||||
source .venv/bin/activate
|
||||
VLLM_USE_PRECOMPILED=1 uv pip install -U -e . --torch-backend=auto
|
||||
```
|
||||
|
||||
2. **CUDA Toolkit:** Verify that the NVIDIA CUDA Toolkit is correctly installed and `nvcc` is accessible in your `PATH`. CMake relies on `nvcc` to compile CUDA code. You can typically find `nvcc` in `$CUDA_HOME/bin/nvcc` or by running `which nvcc`. If you encounter issues, refer to the [official CUDA Toolkit installation guides](https://developer.nvidia.com/cuda-toolkit-archive) and vLLM's main [GPU installation documentation](../getting_started/installation/gpu/cuda.inc.md#troubleshooting) for troubleshooting. The `CMAKE_CUDA_COMPILER` variable in your `CMakeUserPresets.json` should also point to your `nvcc` binary.
|
||||
|
||||
3. **Build Tools:** It is highly recommended to install `ccache` for fast rebuilds by caching compilation results (e.g., `sudo apt install ccache` or `conda install ccache`). Also, ensure the core build dependencies like `cmake` and `ninja` are installed. These are installable through `requirements/build.txt` or your system's package manager.
|
||||
|
||||
```console
|
||||
uv pip install -r requirements/build.txt --torch-backend=auto
|
||||
```
|
||||
|
||||
## Setting up the CMake Build Environment
|
||||
|
||||
The incremental build process is managed through CMake. You can configure your build settings using a `CMakeUserPresets.json` file at the root of the vLLM repository.
|
||||
|
||||
### Generate `CMakeUserPresets.json` using the helper script
|
||||
|
||||
To simplify the setup, vLLM provides a helper script that attempts to auto-detect your system's configuration (like CUDA path, Python environment, and CPU cores) and generates the `CMakeUserPresets.json` file for you.
|
||||
|
||||
**Run the script:**
|
||||
|
||||
Navigate to the root of your vLLM clone and execute the following command:
|
||||
|
||||
```console
|
||||
python tools/generate_cmake_presets.py
|
||||
```
|
||||
|
||||
The script will prompt you if it cannot automatically determine certain paths (e.g., `nvcc` or a specific Python executable for your vLLM development environment). Follow the on-screen prompts. If an existing `CMakeUserPresets.json` is found, the script will ask for confirmation before overwriting it.
|
||||
|
||||
After running the script, a `CMakeUserPresets.json` file will be created in the root of your vLLM repository.
|
||||
|
||||
### Example `CMakeUserPresets.json`
|
||||
|
||||
Below is an example of what the generated `CMakeUserPresets.json` might look like. The script will tailor these values based on your system and any input you provide.
|
||||
|
||||
```json
|
||||
{
|
||||
"version": 6,
|
||||
"cmakeMinimumRequired": {
|
||||
"major": 3,
|
||||
"minor": 26,
|
||||
"patch": 1
|
||||
},
|
||||
"configurePresets": [
|
||||
{
|
||||
"name": "release",
|
||||
"generator": "Ninja",
|
||||
"binaryDir": "${sourceDir}/cmake-build-release",
|
||||
"cacheVariables": {
|
||||
"CMAKE_CUDA_COMPILER": "/usr/local/cuda/bin/nvcc",
|
||||
"CMAKE_C_COMPILER_LAUNCHER": "ccache",
|
||||
"CMAKE_CXX_COMPILER_LAUNCHER": "ccache",
|
||||
"CMAKE_CUDA_COMPILER_LAUNCHER": "ccache",
|
||||
"CMAKE_BUILD_TYPE": "Release",
|
||||
"VLLM_PYTHON_EXECUTABLE": "/home/user/venvs/vllm/bin/python",
|
||||
"CMAKE_INSTALL_PREFIX": "${sourceDir}",
|
||||
"CMAKE_CUDA_FLAGS": "",
|
||||
"NVCC_THREADS": "4",
|
||||
"CMAKE_JOB_POOLS": "compile=32"
|
||||
}
|
||||
}
|
||||
],
|
||||
"buildPresets": [
|
||||
{
|
||||
"name": "release",
|
||||
"configurePreset": "release",
|
||||
"jobs": 32
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**What do the various configurations mean?**
|
||||
- `CMAKE_CUDA_COMPILER`: Path to your `nvcc` binary. The script attempts to find this automatically.
|
||||
- `CMAKE_C_COMPILER_LAUNCHER`, `CMAKE_CXX_COMPILER_LAUNCHER`, `CMAKE_CUDA_COMPILER_LAUNCHER`: Setting these to `ccache` (or `sccache`) significantly speeds up rebuilds by caching compilation results. Ensure `ccache` is installed (e.g., `sudo apt install ccache` or `conda install ccache`). The script sets these by default.
|
||||
- `VLLM_PYTHON_EXECUTABLE`: Path to the Python executable in your vLLM development environment. The script will prompt for this, defaulting to the current Python environment if suitable.
|
||||
- `CMAKE_INSTALL_PREFIX: "${sourceDir}"`: Specifies that the compiled components should be installed back into your vLLM source directory. This is crucial for the editable install, as it makes the newly built kernels immediately available to your Python environment.
|
||||
- `CMAKE_JOB_POOLS` and `jobs` in build presets: Control the parallelism of the build. The script sets these based on the number of CPU cores detected on your system.
|
||||
- `binaryDir`: Specifies where the build artifacts will be stored (e.g., `cmake-build-release`).
|
||||
|
||||
## Building and Installing with CMake
|
||||
|
||||
Once your `CMakeUserPresets.json` is configured:
|
||||
|
||||
1. **Initialize the CMake build environment:**
|
||||
This step configures the build system according to your chosen preset (e.g., `release`) and creates the build directory at `binaryDir`
|
||||
|
||||
```console
|
||||
cmake --preset release
|
||||
```
|
||||
|
||||
2. **Build and install the vLLM components:**
|
||||
This command compiles the code and installs the resulting binaries into your vLLM source directory, making them available to your editable Python installation.
|
||||
|
||||
```console
|
||||
cmake --build --preset release --target install
|
||||
```
|
||||
|
||||
3. **Make changes and repeat!**
|
||||
Now you start using your editable install of vLLM, testing and making changes as needed. If you need to build again to update based on changes, simply run the CMake command again to build only the affected files.
|
||||
|
||||
```console
|
||||
cmake --build --preset release --target install
|
||||
```
|
||||
|
||||
## Verifying the Build
|
||||
|
||||
After a successful build, you will find a populated build directory (e.g., `cmake-build-release/` if you used the `release` preset and the example configuration).
|
||||
|
||||
```console
|
||||
> ls cmake-build-release/
|
||||
bin cmake_install.cmake _deps machete_generation.log
|
||||
build.ninja CPackConfig.cmake detect_cuda_compute_capabilities.cu marlin_generation.log
|
||||
_C.abi3.so CPackSourceConfig.cmake detect_cuda_version.cc _moe_C.abi3.so
|
||||
CMakeCache.txt ctest _flashmla_C.abi3.so moe_marlin_generation.log
|
||||
CMakeFiles cumem_allocator.abi3.so install_local_manifest.txt vllm-flash-attn
|
||||
```
|
||||
|
||||
The `cmake --build ... --target install` command copies the compiled shared libraries (like `_C.abi3.so`, `_moe_C.abi3.so`, etc.) into the appropriate `vllm` package directory within your source tree. This updates your editable installation with the newly compiled kernels.
|
||||
|
||||
## Additional Tips
|
||||
|
||||
- **Adjust Parallelism:** Fine-tune the `CMAKE_JOB_POOLS` in `configurePresets` and `jobs` in `buildPresets` in your `CMakeUserPresets.json`. Too many jobs can overload systems with limited RAM or CPU cores, leading to slower builds or system instability. Too few won't fully utilize available resources.
|
||||
- **Clean Builds When Necessary:** If you encounter persistent or strange build errors, especially after significant changes or switching branches, consider removing the CMake build directory (e.g., `rm -rf cmake-build-release`) and re-running the `cmake --preset` and `cmake --build` commands.
|
||||
- **Specific Target Builds:** For even faster iterations when working on a specific module, you can sometimes build a specific target instead of the full `install` target, though `install` ensures all necessary components are updated in your Python environment. Refer to CMake documentation for more advanced target management.
|
||||
@ -1,21 +1,23 @@
|
||||
---
|
||||
title: Adding a New Model
|
||||
title: Summary
|
||||
---
|
||||
[](){ #new-model }
|
||||
|
||||
This section provides more information on how to integrate a [PyTorch](https://pytorch.org/) model into vLLM.
|
||||
!!! important
|
||||
Many decoder language models can now be automatically loaded using the [Transformers backend][transformers-backend] without having to implement them in vLLM. See if `vllm serve <model>` works first!
|
||||
|
||||
Contents:
|
||||
vLLM models are specialized [PyTorch](https://pytorch.org/) models that take advantage of various [features][compatibility-matrix] to optimize their performance.
|
||||
|
||||
- [Basic](basic.md)
|
||||
- [Registration](registration.md)
|
||||
- [Tests](tests.md)
|
||||
- [Multimodal](multimodal.md)
|
||||
The complexity of integrating a model into vLLM depends heavily on the model's architecture.
|
||||
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
|
||||
However, this can be more complex for models that include new operators (e.g., a new attention mechanism).
|
||||
|
||||
!!! note
|
||||
The complexity of adding a new model depends heavily on the model's architecture.
|
||||
The process is considerably straightforward if the model shares a similar architecture with an existing model in vLLM.
|
||||
However, for models that include new operators (e.g., a new attention mechanism), the process can be a bit more complex.
|
||||
Read through these pages for a step-by-step guide:
|
||||
|
||||
- [Basic Model](basic.md)
|
||||
- [Registering a Model](registration.md)
|
||||
- [Unit Testing](tests.md)
|
||||
- [Multi-Modal Support](multimodal.md)
|
||||
|
||||
!!! tip
|
||||
If you are encountering issues while integrating your model into vLLM, feel free to open a [GitHub issue](https://github.com/vllm-project/vllm/issues)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Implementing a Basic Model
|
||||
title: Basic Model
|
||||
---
|
||||
[](){ #new-model-basic }
|
||||
|
||||
@ -27,33 +27,35 @@ All vLLM modules within the model must include a `prefix` argument in their cons
|
||||
|
||||
The initialization code should look like this:
|
||||
|
||||
```python
|
||||
from torch import nn
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.attention import Attention
|
||||
??? Code
|
||||
|
||||
class MyAttention(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
self.attn = Attention(prefix=f"{prefix}.attn")
|
||||
```python
|
||||
from torch import nn
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.attention import Attention
|
||||
|
||||
class MyDecoderLayer(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
self.self_attn = MyAttention(prefix=f"{prefix}.self_attn")
|
||||
class MyAttention(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
self.attn = Attention(prefix=f"{prefix}.attn")
|
||||
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)]
|
||||
)
|
||||
class MyDecoderLayer(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
self.self_attn = MyAttention(prefix=f"{prefix}.self_attn")
|
||||
|
||||
class MyModelForCausalLM(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.model = MyModel(vllm_config, prefix=f"{prefix}.model")
|
||||
```
|
||||
class MyModel(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)]
|
||||
)
|
||||
|
||||
class MyModelForCausalLM(nn.Module):
|
||||
def __init__(self, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
self.model = MyModel(vllm_config, prefix=f"{prefix}.model")
|
||||
```
|
||||
|
||||
### Computation Code
|
||||
|
||||
|
||||
@ -10,6 +10,22 @@ This document walks you through the steps to extend a basic model so that it acc
|
||||
It is assumed that you have already implemented the model in vLLM according to [these steps][new-model-basic].
|
||||
Further update the model as follows:
|
||||
|
||||
- Implement [get_placeholder_str][vllm.model_executor.models.interfaces.SupportsMultiModal.get_placeholder_str] to define the placeholder string which is used to represent the multi-modal item in the text prompt. This should be consistent with the chat template of the model.
|
||||
|
||||
??? Code
|
||||
|
||||
```python
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
|
||||
@classmethod
|
||||
def get_placeholder_str(cls, modality: str, i: int) -> Optional[str]:
|
||||
if modality.startswith("image"):
|
||||
return "<image>"
|
||||
|
||||
raise ValueError("Only image modality is supported")
|
||||
```
|
||||
|
||||
- Reserve a keyword parameter in [forward][torch.nn.Module.forward] for each input tensor that corresponds to a multi-modal input, as shown in the following example:
|
||||
|
||||
```diff
|
||||
@ -25,59 +41,63 @@ Further update the model as follows:
|
||||
|
||||
- Implement [get_multimodal_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_multimodal_embeddings] that returns the embeddings from running the multimodal inputs through the multimodal tokenizer of the model. Below we provide a boilerplate of a typical implementation pattern, but feel free to adjust it to your own needs.
|
||||
|
||||
```python
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
??? Code
|
||||
|
||||
def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
|
||||
```python
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
|
||||
assert self.vision_encoder is not None
|
||||
image_features = self.vision_encoder(image_input)
|
||||
return self.multi_modal_projector(image_features)
|
||||
def _process_image_input(self, image_input: YourModelImageInputs) -> torch.Tensor:
|
||||
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
assert self.vision_encoder is not None
|
||||
image_features = self.vision_encoder(image_input)
|
||||
return self.multi_modal_projector(image_features)
|
||||
|
||||
# Validate the multimodal input keyword arguments
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
def get_multimodal_embeddings(
|
||||
self, **kwargs: object) -> Optional[MultiModalEmbeddings]:
|
||||
|
||||
# Run multimodal inputs through encoder and projector
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
```
|
||||
# Validate the multimodal input keyword arguments
|
||||
image_input = self._parse_and_validate_image_input(**kwargs)
|
||||
if image_input is None:
|
||||
return None
|
||||
|
||||
# Run multimodal inputs through encoder and projector
|
||||
vision_embeddings = self._process_image_input(image_input)
|
||||
return vision_embeddings
|
||||
```
|
||||
|
||||
!!! important
|
||||
The returned `multimodal_embeddings` must be either a **3D [torch.Tensor][]** of shape `(num_items, feature_size, hidden_size)`, or a **list / tuple of 2D [torch.Tensor][]'s** of shape `(feature_size, hidden_size)`, so that `multimodal_embeddings[i]` retrieves the embeddings generated from the `i`-th multimodal data item (e.g, image) of the request.
|
||||
|
||||
- Implement [get_input_embeddings][vllm.model_executor.models.interfaces.SupportsMultiModal.get_input_embeddings] to merge `multimodal_embeddings` with text embeddings from the `input_ids`. If input processing for the model is implemented correctly (see sections below), then you can leverage the utility function we provide to easily merge the embeddings.
|
||||
|
||||
```python
|
||||
from .utils import merge_multimodal_embeddings
|
||||
??? Code
|
||||
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
```python
|
||||
from .utils import merge_multimodal_embeddings
|
||||
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
class YourModelForImage2Seq(nn.Module):
|
||||
...
|
||||
|
||||
# `get_input_embeddings` should already be implemented for the language
|
||||
# model as one of the requirements of basic vLLM model implementation.
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
def get_input_embeddings(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
multimodal_embeddings: Optional[MultiModalEmbeddings] = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_index)
|
||||
# `get_input_embeddings` should already be implemented for the language
|
||||
# model as one of the requirements of basic vLLM model implementation.
|
||||
inputs_embeds = self.language_model.get_input_embeddings(input_ids)
|
||||
|
||||
return inputs_embeds
|
||||
```
|
||||
if multimodal_embeddings is not None:
|
||||
inputs_embeds = merge_multimodal_embeddings(
|
||||
input_ids=input_ids,
|
||||
inputs_embeds=inputs_embeds,
|
||||
multimodal_embeddings=multimodal_embeddings,
|
||||
placeholder_token_id=self.config.image_token_index)
|
||||
|
||||
return inputs_embeds
|
||||
```
|
||||
|
||||
- Implement [get_language_model][vllm.model_executor.models.interfaces.SupportsMultiModal.get_language_model] getter to provide stable access to the underlying language model.
|
||||
|
||||
@ -135,42 +155,46 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
Looking at the code of HF's `LlavaForConditionalGeneration`:
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
??? Code
|
||||
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L530-L544
|
||||
n_image_tokens = (input_ids == self.config.image_token_index).sum().item()
|
||||
n_image_features = image_features.shape[0] * image_features.shape[1]
|
||||
|
||||
if n_image_tokens != n_image_features:
|
||||
raise ValueError(
|
||||
f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {n_image_features}"
|
||||
)
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
special_image_mask = (
|
||||
(input_ids == self.config.image_token_index)
|
||||
.unsqueeze(-1)
|
||||
.expand_as(inputs_embeds)
|
||||
.to(inputs_embeds.device)
|
||||
)
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
```
|
||||
image_features = image_features.to(inputs_embeds.device, inputs_embeds.dtype)
|
||||
inputs_embeds = inputs_embeds.masked_scatter(special_image_mask, image_features)
|
||||
```
|
||||
|
||||
The number of placeholder feature tokens per image is `image_features.shape[1]`.
|
||||
`image_features` is calculated inside the `get_image_features` method:
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
??? Code
|
||||
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
```
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/llava/modeling_llava.py#L290-L300
|
||||
image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
|
||||
|
||||
selected_image_feature = image_outputs.hidden_states[vision_feature_layer]
|
||||
if vision_feature_select_strategy == "default":
|
||||
selected_image_feature = selected_image_feature[:, 1:]
|
||||
elif vision_feature_select_strategy == "full":
|
||||
selected_image_feature = selected_image_feature
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature strategy: {self.config.vision_feature_select_strategy}")
|
||||
image_features = self.multi_modal_projector(selected_image_feature)
|
||||
return image_features
|
||||
```
|
||||
|
||||
We can infer that `image_features.shape[1]` is based on `image_outputs.hidden_states.shape[1]` from the vision tower
|
||||
(`CLIPVisionModel` for the [`llava-hf/llava-1.5-7b-hf`](https://huggingface.co/llava-hf/llava-1.5-7b-hf) model).
|
||||
@ -193,20 +217,22 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
To find the sequence length, we turn to the code of `CLIPVisionEmbeddings`:
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
??? Code
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
```
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/modeling_clip.py#L247-L257
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # shape = [*, width, grid, grid]
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
if interpolate_pos_encoding:
|
||||
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
||||
else:
|
||||
embeddings = embeddings + self.position_embedding(self.position_ids)
|
||||
return embeddings
|
||||
```
|
||||
|
||||
We can infer that `embeddings.shape[1] == self.num_positions`, where
|
||||
|
||||
@ -218,55 +244,59 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
Overall, the number of placeholder feature tokens for an image can be calculated as:
|
||||
|
||||
```python
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
hf_processor = self.get_hf_processor()
|
||||
??? Code
|
||||
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
```python
|
||||
def get_num_image_tokens(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> int:
|
||||
hf_config = self.get_hf_config()
|
||||
hf_processor = self.get_hf_processor()
|
||||
|
||||
num_image_tokens = (image_size // patch_size) ** 2 + 1
|
||||
if hf_processor.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
image_size = hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
|
||||
return num_image_tokens
|
||||
```
|
||||
num_image_tokens = (image_size // patch_size) ** 2 + 1
|
||||
if hf_processor.vision_feature_select_strategy == "default":
|
||||
num_image_tokens -= 1
|
||||
|
||||
return num_image_tokens
|
||||
```
|
||||
|
||||
Notice that the number of image tokens doesn't depend on the image width and height.
|
||||
We can simply use a dummy `image_size` to calculate the multimodal profiling data:
|
||||
|
||||
```python
|
||||
# NOTE: In actuality, this is usually implemented as part of the
|
||||
# model's subclass of `BaseProcessingInfo`, but we show it as is
|
||||
# here for simplicity.
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
hf_config = self.get_hf_config()
|
||||
width = height = hf_config.image_size
|
||||
return ImageSize(width=width, height=height)
|
||||
??? Code
|
||||
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
```python
|
||||
# NOTE: In actuality, this is usually implemented as part of the
|
||||
# model's subclass of `BaseProcessingInfo`, but we show it as is
|
||||
# here for simplicity.
|
||||
def get_image_size_with_most_features(self) -> ImageSize:
|
||||
hf_config = self.get_hf_config()
|
||||
width = height = hf_config.image_size
|
||||
return ImageSize(width=width, height=height)
|
||||
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
```
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
```
|
||||
|
||||
For the text, we simply expand the multimodal image token from the model config to match the desired number of images.
|
||||
|
||||
@ -284,21 +314,23 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
Looking at the code of HF's `FuyuForCausalLM`:
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322
|
||||
if image_patches is not None and past_key_values is None:
|
||||
patch_embeddings = [
|
||||
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
|
||||
.squeeze(0)
|
||||
.to(inputs_embeds.device)
|
||||
for patch in image_patches
|
||||
]
|
||||
inputs_embeds = self.gather_continuous_embeddings(
|
||||
word_embeddings=inputs_embeds,
|
||||
continuous_embeddings=patch_embeddings,
|
||||
image_patch_input_indices=image_patches_indices,
|
||||
)
|
||||
```
|
||||
??? Code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/modeling_fuyu.py#L311-L322
|
||||
if image_patches is not None and past_key_values is None:
|
||||
patch_embeddings = [
|
||||
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
|
||||
.squeeze(0)
|
||||
.to(inputs_embeds.device)
|
||||
for patch in image_patches
|
||||
]
|
||||
inputs_embeds = self.gather_continuous_embeddings(
|
||||
word_embeddings=inputs_embeds,
|
||||
continuous_embeddings=patch_embeddings,
|
||||
image_patch_input_indices=image_patches_indices,
|
||||
)
|
||||
```
|
||||
|
||||
The number of placeholder feature tokens for the `i`th item in the batch is `patch_embeddings[i].shape[0]`,
|
||||
which is the same as `image_patches[i].shape[0]`, i.e. `num_total_patches`.
|
||||
@ -312,92 +344,98 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
In `FuyuImageProcessor.preprocess`, the images are resized and padded to the target `FuyuImageProcessor.size`,
|
||||
returning the dimensions after resizing (but before padding) as metadata.
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544
|
||||
image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
|
||||
batch_images = image_encoding["images"]
|
||||
image_unpadded_heights = image_encoding["image_unpadded_heights"]
|
||||
image_unpadded_widths = image_encoding["image_unpadded_widths"]
|
||||
??? Code
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L
|
||||
if do_resize:
|
||||
batch_images = [
|
||||
[self.resize(image, size=size, input_data_format=input_data_format) for image in images]
|
||||
for images in batch_images
|
||||
]
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L541-L544
|
||||
image_encoding = self.image_processor.preprocess(images, **output_kwargs["images_kwargs"])
|
||||
batch_images = image_encoding["images"]
|
||||
image_unpadded_heights = image_encoding["image_unpadded_heights"]
|
||||
image_unpadded_widths = image_encoding["image_unpadded_widths"]
|
||||
|
||||
image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
|
||||
image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
|
||||
image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
|
||||
|
||||
if do_pad:
|
||||
batch_images = [
|
||||
[
|
||||
self.pad_image(
|
||||
image,
|
||||
size=size,
|
||||
mode=padding_mode,
|
||||
constant_values=padding_value,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L480-L
|
||||
if do_resize:
|
||||
batch_images = [
|
||||
[self.resize(image, size=size, input_data_format=input_data_format) for image in images]
|
||||
for images in batch_images
|
||||
]
|
||||
for images in batch_images
|
||||
]
|
||||
```
|
||||
|
||||
image_sizes = [get_image_size(images[0], channel_dim=input_data_format) for images in batch_images]
|
||||
image_unpadded_heights = [[image_size[0]] for image_size in image_sizes]
|
||||
image_unpadded_widths = [[image_size[1]] for image_size in image_sizes]
|
||||
|
||||
if do_pad:
|
||||
batch_images = [
|
||||
[
|
||||
self.pad_image(
|
||||
image,
|
||||
size=size,
|
||||
mode=padding_mode,
|
||||
constant_values=padding_value,
|
||||
input_data_format=input_data_format,
|
||||
)
|
||||
for image in images
|
||||
]
|
||||
for images in batch_images
|
||||
]
|
||||
```
|
||||
|
||||
In `FuyuImageProcessor.preprocess_with_tokenizer_info`, the images are split into patches based on this metadata:
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425
|
||||
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=tensor_batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_placeholder_id=image_placeholder_id,
|
||||
image_newline_id=image_newline_id,
|
||||
variable_sized=True,
|
||||
)
|
||||
??? Code
|
||||
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658
|
||||
image_height, image_width = image.shape[1], image.shape[2]
|
||||
if variable_sized: # variable_sized=True
|
||||
new_h = min(
|
||||
image_height,
|
||||
math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L425
|
||||
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=tensor_batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_placeholder_id=image_placeholder_id,
|
||||
image_newline_id=image_newline_id,
|
||||
variable_sized=True,
|
||||
)
|
||||
new_w = min(
|
||||
image_width,
|
||||
math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
|
||||
)
|
||||
image = image[:, :new_h, :new_w]
|
||||
image_height, image_width = new_h, new_w
|
||||
|
||||
num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
|
||||
tensor_of_image_ids = torch.full(
|
||||
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
|
||||
)
|
||||
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
|
||||
assert num_patches == patches.shape[0]
|
||||
```
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L638-L658
|
||||
image_height, image_width = image.shape[1], image.shape[2]
|
||||
if variable_sized: # variable_sized=True
|
||||
new_h = min(
|
||||
image_height,
|
||||
math.ceil(image_unpadded_h[batch_index, subseq_index] / patch_height) * patch_height,
|
||||
)
|
||||
new_w = min(
|
||||
image_width,
|
||||
math.ceil(image_unpadded_w[batch_index, subseq_index] / patch_width) * patch_width,
|
||||
)
|
||||
image = image[:, :new_h, :new_w]
|
||||
image_height, image_width = new_h, new_w
|
||||
|
||||
num_patches = self.get_num_patches(image_height=image_height, image_width=image_width)
|
||||
tensor_of_image_ids = torch.full(
|
||||
[num_patches], image_placeholder_id, dtype=torch.int32, device=image_input.device
|
||||
)
|
||||
patches = self.patchify_image(image=image.unsqueeze(0)).squeeze(0)
|
||||
assert num_patches == patches.shape[0]
|
||||
```
|
||||
|
||||
The number of patches is in turn defined by `FuyuImageProcessor.get_num_patches`:
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||
??? Code
|
||||
|
||||
if image_height % patch_height != 0:
|
||||
raise ValueError(f"{image_height=} must be divisible by {patch_height}")
|
||||
if image_width % patch_width != 0:
|
||||
raise ValueError(f"{image_width=} must be divisible by {patch_width}")
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/image_processing_fuyu.py#L552-L562
|
||||
patch_size = patch_size if patch_size is not None else self.patch_size
|
||||
patch_height, patch_width = self.patch_size["height"], self.patch_size["width"]
|
||||
|
||||
num_patches_per_dim_h = image_height // patch_height
|
||||
num_patches_per_dim_w = image_width // patch_width
|
||||
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
|
||||
```
|
||||
if image_height % patch_height != 0:
|
||||
raise ValueError(f"{image_height=} must be divisible by {patch_height}")
|
||||
if image_width % patch_width != 0:
|
||||
raise ValueError(f"{image_width=} must be divisible by {patch_width}")
|
||||
|
||||
num_patches_per_dim_h = image_height // patch_height
|
||||
num_patches_per_dim_w = image_width // patch_width
|
||||
num_patches = num_patches_per_dim_h * num_patches_per_dim_w
|
||||
```
|
||||
|
||||
These image patches correspond to placeholder tokens (`|SPEAKER|`). So, we just need to maximize the number of image patches. Since input images are first resized
|
||||
to fit within `image_processor.size`, we can maximize the number of image patches by inputting an image with size equal to `image_processor.size`.
|
||||
@ -419,23 +457,25 @@ Assuming that the memory usage increases with the number of tokens, the dummy in
|
||||
|
||||
For the multimodal image profiling data, the logic is very similar to LLaVA:
|
||||
|
||||
```python
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
??? Code
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
```
|
||||
```python
|
||||
def get_dummy_mm_data(
|
||||
self,
|
||||
seq_len: int,
|
||||
mm_counts: Mapping[str, int],
|
||||
) -> MultiModalDataDict:
|
||||
target_width, target_height = \
|
||||
self.info.get_image_size_with_most_features()
|
||||
num_images = mm_counts.get("image", 0)
|
||||
|
||||
return {
|
||||
"image":
|
||||
self._get_dummy_images(width=target_width,
|
||||
height=target_height,
|
||||
num_images=num_images)
|
||||
}
|
||||
```
|
||||
|
||||
## 4. Specify processing details
|
||||
|
||||
@ -455,6 +495,7 @@ return a schema of the tensors outputted by the HF processor that are related to
|
||||
The output of `CLIPImageProcessor` is a simple tensor with shape
|
||||
`(num_images, num_channels, image_height, image_width)`:
|
||||
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.47.1/src/transformers/models/clip/image_processing_clip.py#L339-L345
|
||||
images = [
|
||||
@ -505,40 +546,49 @@ return a schema of the tensors outputted by the HF processor that are related to
|
||||
In order to support the use of [MultiModalFieldConfig.batched][] like in LLaVA,
|
||||
we remove the extra batch dimension by overriding [BaseMultiModalProcessor._call_hf_processor][]:
|
||||
|
||||
```python
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
)
|
||||
??? Code
|
||||
|
||||
image_patches = processed_outputs.get("image_patches")
|
||||
if image_patches is not None:
|
||||
images = mm_data["images"]
|
||||
assert isinstance(images, list)
|
||||
```python
|
||||
def _call_hf_processor(
|
||||
self,
|
||||
prompt: str,
|
||||
mm_data: Mapping[str, object],
|
||||
mm_kwargs: Mapping[str, object],
|
||||
tok_kwargs: Mapping[str, object],
|
||||
) -> BatchFeature:
|
||||
processed_outputs = super()._call_hf_processor(
|
||||
prompt=prompt,
|
||||
mm_data=mm_data,
|
||||
mm_kwargs=mm_kwargs,
|
||||
tok_kwargs=tok_kwargs,
|
||||
)
|
||||
|
||||
# Original output: (1, num_images, Pn, Px * Py * C)
|
||||
# New output: (num_images, Pn, Px * Py * C)
|
||||
assert (isinstance(image_patches, list)
|
||||
and len(image_patches) == 1)
|
||||
assert (isinstance(image_patches[0], torch.Tensor)
|
||||
and len(image_patches[0]) == len(images))
|
||||
image_patches = processed_outputs.get("image_patches")
|
||||
if image_patches is not None:
|
||||
images = mm_data["images"]
|
||||
assert isinstance(images, list)
|
||||
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
# Original output: (1, num_images, Pn, Px * Py * C)
|
||||
# New output: (num_images, Pn, Px * Py * C)
|
||||
assert (isinstance(image_patches, list)
|
||||
and len(image_patches) == 1)
|
||||
assert (isinstance(image_patches[0], torch.Tensor)
|
||||
and len(image_patches[0]) == len(images))
|
||||
|
||||
return processed_outputs
|
||||
```
|
||||
processed_outputs["image_patches"] = image_patches[0]
|
||||
|
||||
return processed_outputs
|
||||
```
|
||||
|
||||
!!! note
|
||||
Our [actual code](gh-file:vllm/model_executor/models/fuyu.py) has special handling
|
||||
for text-only inputs to prevent unnecessary warnings from HF processor.
|
||||
|
||||
!!! note
|
||||
The `_call_hf_processor` method specifies both `mm_kwargs` and `tok_kwargs` for
|
||||
processing. `mm_kwargs` is used to both initialize and call the huggingface
|
||||
processor, whereas `tok_kwargs` is only used to call the huggingface processor.
|
||||
|
||||
This lets us override [_get_mm_fields_config][vllm.multimodal.processing.BaseMultiModalProcessor._get_mm_fields_config] as follows:
|
||||
|
||||
```python
|
||||
@ -573,35 +623,37 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
It simply repeats each input `image_token` a number of times equal to the number of placeholder feature tokens (`num_image_tokens`).
|
||||
Based on this, we override [_get_prompt_updates][vllm.multimodal.processing.BaseMultiModalProcessor._get_prompt_updates] as follows:
|
||||
|
||||
```python
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
??? Code
|
||||
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
```python
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
image_token_id = hf_config.image_token_index
|
||||
|
||||
image_size = images.get_image_size(item_idx)
|
||||
num_image_tokens = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
|
||||
return [image_token_id] * num_image_tokens
|
||||
image_size = images.get_image_size(item_idx)
|
||||
num_image_tokens = self.info.get_num_image_tokens(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
```
|
||||
return [image_token_id] * num_image_tokens
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[image_token_id],
|
||||
replacement=get_replacement,
|
||||
),
|
||||
]
|
||||
```
|
||||
|
||||
=== "Handling additional tokens: Fuyu"
|
||||
|
||||
@ -616,117 +668,90 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
|
||||
We define a helper function to return `ncols` and `nrows` directly:
|
||||
|
||||
```python
|
||||
def get_image_feature_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> tuple[int, int]:
|
||||
image_processor = self.get_image_processor()
|
||||
target_width = image_processor.size["width"]
|
||||
target_height = image_processor.size["height"]
|
||||
patch_width = image_processor.patch_size["width"]
|
||||
patch_height = image_processor.patch_size["height"]
|
||||
??? Code
|
||||
|
||||
if not (image_width <= target_width and image_height <= target_height):
|
||||
height_scale_factor = target_height / image_height
|
||||
width_scale_factor = target_width / image_width
|
||||
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||
```python
|
||||
def get_image_feature_grid_size(
|
||||
self,
|
||||
*,
|
||||
image_width: int,
|
||||
image_height: int,
|
||||
) -> tuple[int, int]:
|
||||
image_processor = self.get_image_processor()
|
||||
target_width = image_processor.size["width"]
|
||||
target_height = image_processor.size["height"]
|
||||
patch_width = image_processor.patch_size["width"]
|
||||
patch_height = image_processor.patch_size["height"]
|
||||
|
||||
image_height = int(image_height * optimal_scale_factor)
|
||||
image_width = int(image_width * optimal_scale_factor)
|
||||
if not (image_width <= target_width and image_height <= target_height):
|
||||
height_scale_factor = target_height / image_height
|
||||
width_scale_factor = target_width / image_width
|
||||
optimal_scale_factor = min(height_scale_factor, width_scale_factor)
|
||||
|
||||
ncols = math.ceil(image_width / patch_width)
|
||||
nrows = math.ceil(image_height / patch_height)
|
||||
return ncols, nrows
|
||||
```
|
||||
image_height = int(image_height * optimal_scale_factor)
|
||||
image_width = int(image_width * optimal_scale_factor)
|
||||
|
||||
ncols = math.ceil(image_width / patch_width)
|
||||
nrows = math.ceil(image_height / patch_height)
|
||||
return ncols, nrows
|
||||
```
|
||||
|
||||
Based on this, we can initially define our replacement tokens as:
|
||||
|
||||
```python
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
??? Code
|
||||
|
||||
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
```python
|
||||
def get_replacement(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
# `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|`
|
||||
# `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|`
|
||||
return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
|
||||
```
|
||||
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
|
||||
# `_IMAGE_TOKEN_ID` corresponds to `|SPEAKER|`
|
||||
# `_NEWLINE_TOKEN_ID` corresponds to `|NEWLINE|`
|
||||
return ([_IMAGE_TOKEN_ID] * ncols + [_NEWLINE_TOKEN_ID]) * nrows
|
||||
```
|
||||
|
||||
However, this is not entirely correct. After `FuyuImageProcessor.preprocess_with_tokenizer_info` is called,
|
||||
a BOS token (`<s>`) is also added to the promopt:
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435
|
||||
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=tensor_batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_placeholder_id=image_placeholder_id,
|
||||
image_newline_id=image_newline_id,
|
||||
variable_sized=True,
|
||||
)
|
||||
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
|
||||
tokenizer=self.tokenizer,
|
||||
prompts=prompts,
|
||||
scale_factors=scale_factors,
|
||||
max_tokens_to_generate=self.max_tokens_to_generate,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
add_BOS=True,
|
||||
add_beginning_of_answer_token=True,
|
||||
)
|
||||
```
|
||||
??? Code
|
||||
|
||||
```python
|
||||
# https://github.com/huggingface/transformers/blob/v4.48.3/src/transformers/models/fuyu/processing_fuyu.py#L417-L435
|
||||
model_image_input = self.image_processor.preprocess_with_tokenizer_info(
|
||||
image_input=tensor_batch_images,
|
||||
image_present=image_present,
|
||||
image_unpadded_h=image_unpadded_heights,
|
||||
image_unpadded_w=image_unpadded_widths,
|
||||
image_placeholder_id=image_placeholder_id,
|
||||
image_newline_id=image_newline_id,
|
||||
variable_sized=True,
|
||||
)
|
||||
prompt_tokens, prompts_length = _tokenize_prompts_with_image_and_batch(
|
||||
tokenizer=self.tokenizer,
|
||||
prompts=prompts,
|
||||
scale_factors=scale_factors,
|
||||
max_tokens_to_generate=self.max_tokens_to_generate,
|
||||
max_position_embeddings=self.max_position_embeddings,
|
||||
add_BOS=True,
|
||||
add_beginning_of_answer_token=True,
|
||||
)
|
||||
```
|
||||
|
||||
To assign the vision embeddings to only the image tokens, instead of a string
|
||||
you can return an instance of [PromptUpdateDetails][vllm.multimodal.processing.PromptUpdateDetails]:
|
||||
|
||||
```python
|
||||
hf_config = self.info.get_hf_config()
|
||||
bos_token_id = hf_config.bos_token_id # `<s>`
|
||||
assert isinstance(bos_token_id, int)
|
||||
??? Code
|
||||
|
||||
def get_replacement_fuyu(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=_IMAGE_TOKEN_ID,
|
||||
)
|
||||
```
|
||||
|
||||
Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt,
|
||||
we can search for it to conduct the replacement at the start of the string:
|
||||
|
||||
```python
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
```python
|
||||
hf_config = self.info.get_hf_config()
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
bos_token_id = hf_config.bos_token_id # `<s>`
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
eot_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(eot_token_id, int)
|
||||
|
||||
def get_replacement_fuyu(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
@ -742,15 +767,52 @@ Each [PromptUpdate][vllm.multimodal.processing.PromptUpdate] instance specifies
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=_IMAGE_TOKEN_ID,
|
||||
)
|
||||
```
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[eot_token_id],
|
||||
replacement=get_replacement_fuyu,
|
||||
)
|
||||
]
|
||||
```
|
||||
Finally, noticing that the HF processor removes the `|ENDOFTEXT|` token from the tokenized prompt,
|
||||
we can search for it to conduct the replacement at the start of the string:
|
||||
|
||||
??? Code
|
||||
|
||||
```python
|
||||
def _get_prompt_updates(
|
||||
self,
|
||||
mm_items: MultiModalDataItems,
|
||||
hf_processor_mm_kwargs: Mapping[str, object],
|
||||
out_mm_kwargs: MultiModalKwargs,
|
||||
) -> Sequence[PromptUpdate]:
|
||||
hf_config = self.info.get_hf_config()
|
||||
bos_token_id = hf_config.bos_token_id
|
||||
assert isinstance(bos_token_id, int)
|
||||
|
||||
tokenizer = self.info.get_tokenizer()
|
||||
eot_token_id = tokenizer.bos_token_id
|
||||
assert isinstance(eot_token_id, int)
|
||||
|
||||
def get_replacement_fuyu(item_idx: int):
|
||||
images = mm_items.get_items("image", ImageProcessorItems)
|
||||
image_size = images.get_image_size(item_idx)
|
||||
|
||||
ncols, nrows = self.info.get_image_feature_grid_size(
|
||||
image_width=image_size.width,
|
||||
image_height=image_size.height,
|
||||
)
|
||||
image_tokens = ([_IMAGE_TOKEN_ID] * ncols +
|
||||
[_NEWLINE_TOKEN_ID]) * nrows
|
||||
|
||||
return PromptUpdateDetails.select_token_id(
|
||||
image_tokens + [bos_token_id],
|
||||
embed_token_id=_IMAGE_TOKEN_ID,
|
||||
)
|
||||
|
||||
return [
|
||||
PromptReplacement(
|
||||
modality="image",
|
||||
target=[eot_token_id],
|
||||
replacement=get_replacement_fuyu,
|
||||
)
|
||||
]
|
||||
```
|
||||
|
||||
## 5. Register processor-related classes
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Registering a Model to vLLM
|
||||
title: Registering a Model
|
||||
---
|
||||
[](){ #new-model-registration }
|
||||
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Writing Unit Tests
|
||||
title: Unit Testing
|
||||
---
|
||||
[](){ #new-model-tests }
|
||||
|
||||
|
||||
@ -30,13 +30,21 @@ Refer to <gh-file:examples/offline_inference/simple_profiling.py> for an example
|
||||
#### OpenAI Server
|
||||
|
||||
```bash
|
||||
VLLM_TORCH_PROFILER_DIR=./vllm_profile python -m vllm.entrypoints.openai.api_server --model meta-llama/Meta-Llama-3-70B
|
||||
VLLM_TORCH_PROFILER_DIR=./vllm_profile \
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model meta-llama/Meta-Llama-3-70B
|
||||
```
|
||||
|
||||
benchmark_serving.py:
|
||||
|
||||
```bash
|
||||
python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Meta-Llama-3-70B --dataset-name sharegpt --dataset-path sharegpt.json --profile --num-prompts 2
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model meta-llama/Meta-Llama-3-70B \
|
||||
--dataset-name sharegpt \
|
||||
--dataset-path sharegpt.json \
|
||||
--profile \
|
||||
--num-prompts 2
|
||||
```
|
||||
|
||||
## Profile with NVIDIA Nsight Systems
|
||||
@ -64,7 +72,16 @@ For basic usage, you can just append `nsys profile -o report.nsys-rep --trace-fo
|
||||
The following is an example using the `benchmarks/benchmark_latency.py` script:
|
||||
|
||||
```bash
|
||||
nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node python benchmarks/benchmark_latency.py --model meta-llama/Llama-3.1-8B-Instruct --num-iters-warmup 5 --num-iters 1 --batch-size 16 --input-len 512 --output-len 8
|
||||
nsys profile -o report.nsys-rep \
|
||||
--trace-fork-before-exec=true \
|
||||
--cuda-graph-trace=node \
|
||||
python benchmarks/benchmark_latency.py \
|
||||
--model meta-llama/Llama-3.1-8B-Instruct \
|
||||
--num-iters-warmup 5 \
|
||||
--num-iters 1 \
|
||||
--batch-size 16 \
|
||||
--input-len 512 \
|
||||
--output-len 8
|
||||
```
|
||||
|
||||
#### OpenAI Server
|
||||
@ -73,10 +90,21 @@ To profile the server, you will want to prepend your `vllm serve` command with `
|
||||
|
||||
```bash
|
||||
# server
|
||||
nsys profile -o report.nsys-rep --trace-fork-before-exec=true --cuda-graph-trace=node --delay 30 --duration 60 vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||
nsys profile -o report.nsys-rep \
|
||||
--trace-fork-before-exec=true \
|
||||
--cuda-graph-trace=node \
|
||||
--delay 30 \
|
||||
--duration 60 \
|
||||
vllm serve meta-llama/Llama-3.1-8B-Instruct
|
||||
|
||||
# client
|
||||
python benchmarks/benchmark_serving.py --backend vllm --model meta-llama/Llama-3.1-8B-Instruct --num-prompts 1 --dataset-name random --random-input 1024 --random-output 512
|
||||
python benchmarks/benchmark_serving.py \
|
||||
--backend vllm \
|
||||
--model meta-llama/Llama-3.1-8B-Instruct \
|
||||
--num-prompts 1 \
|
||||
--dataset-name random \
|
||||
--random-input 1024 \
|
||||
--random-output 512
|
||||
```
|
||||
|
||||
In practice, you should set the `--duration` argument to a large value. Whenever you want the server to stop profiling, run:
|
||||
@ -97,26 +125,26 @@ to manually kill the profiler and generate your `nsys-rep` report.
|
||||
|
||||
You can view these profiles either as summaries in the CLI, using `nsys stats [profile-file]`, or in the GUI by installing Nsight [locally following the directions here](https://developer.nvidia.com/nsight-systems/get-started).
|
||||
|
||||
CLI example:
|
||||
??? CLI example
|
||||
|
||||
```bash
|
||||
nsys stats report1.nsys-rep
|
||||
...
|
||||
** CUDA GPU Kernel Summary (cuda_gpu_kern_sum):
|
||||
```bash
|
||||
nsys stats report1.nsys-rep
|
||||
...
|
||||
** CUDA GPU Kernel Summary (cuda_gpu_kern_sum):
|
||||
|
||||
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
|
||||
-------- --------------- --------- ----------- ----------- -------- --------- ----------- ----------------------------------------------------------------------------------------------------
|
||||
46.3 10,327,352,338 17,505 589,965.9 144,383.0 27,040 3,126,460 944,263.8 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_of…
|
||||
14.8 3,305,114,764 5,152 641,520.7 293,408.0 287,296 2,822,716 867,124.9 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize256x128x64_warpgroupsize2x1x1_execute_segment_k_of…
|
||||
12.1 2,692,284,876 14,280 188,535.4 83,904.0 19,328 2,862,237 497,999.9 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x1x1_execute_segment_k_off…
|
||||
9.5 2,116,600,578 33,920 62,399.8 21,504.0 15,326 2,532,285 290,954.1 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsize1x1x1_execute_segment_k_off_…
|
||||
5.0 1,119,749,165 18,912 59,208.4 9,056.0 6,784 2,578,366 271,581.7 void vllm::act_and_mul_kernel<c10::BFloat16, &vllm::silu_kernel<c10::BFloat16>, (bool)1>(T1 *, cons…
|
||||
4.1 916,662,515 21,312 43,011.6 19,776.0 8,928 2,586,205 199,790.1 void cutlass::device_kernel<flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<flash::CollectiveMa…
|
||||
2.6 587,283,113 37,824 15,526.7 3,008.0 2,719 2,517,756 139,091.1 std::enable_if<T2>(int)0&&vllm::_typeConvert<T1>::exists, void>::type vllm::fused_add_rms_norm_kern…
|
||||
1.9 418,362,605 18,912 22,121.5 3,871.0 3,328 2,523,870 175,248.2 void vllm::rotary_embedding_kernel<c10::BFloat16, (bool)1>(const long *, T1 *, T1 *, const T1 *, in…
|
||||
0.7 167,083,069 18,880 8,849.7 2,240.0 1,471 2,499,996 101,436.1 void vllm::reshape_and_cache_flash_kernel<__nv_bfloat16, __nv_bfloat16, (vllm::Fp8KVCacheDataType)0…
|
||||
...
|
||||
```
|
||||
Time (%) Total Time (ns) Instances Avg (ns) Med (ns) Min (ns) Max (ns) StdDev (ns) Name
|
||||
-------- --------------- --------- ----------- ----------- -------- --------- ----------- ----------------------------------------------------------------------------------------------------
|
||||
46.3 10,327,352,338 17,505 589,965.9 144,383.0 27,040 3,126,460 944,263.8 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize128x128x64_warpgroupsize1x1x1_execute_segment_k_of…
|
||||
14.8 3,305,114,764 5,152 641,520.7 293,408.0 287,296 2,822,716 867,124.9 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize256x128x64_warpgroupsize2x1x1_execute_segment_k_of…
|
||||
12.1 2,692,284,876 14,280 188,535.4 83,904.0 19,328 2,862,237 497,999.9 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x1x1_execute_segment_k_off…
|
||||
9.5 2,116,600,578 33,920 62,399.8 21,504.0 15,326 2,532,285 290,954.1 sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x64x64_warpgroupsize1x1x1_execute_segment_k_off_…
|
||||
5.0 1,119,749,165 18,912 59,208.4 9,056.0 6,784 2,578,366 271,581.7 void vllm::act_and_mul_kernel<c10::BFloat16, &vllm::silu_kernel<c10::BFloat16>, (bool)1>(T1 *, cons…
|
||||
4.1 916,662,515 21,312 43,011.6 19,776.0 8,928 2,586,205 199,790.1 void cutlass::device_kernel<flash::enable_sm90_or_later<flash::FlashAttnFwdSm90<flash::CollectiveMa…
|
||||
2.6 587,283,113 37,824 15,526.7 3,008.0 2,719 2,517,756 139,091.1 std::enable_if<T2>(int)0&&vllm::_typeConvert<T1>::exists, void>::type vllm::fused_add_rms_norm_kern…
|
||||
1.9 418,362,605 18,912 22,121.5 3,871.0 3,328 2,523,870 175,248.2 void vllm::rotary_embedding_kernel<c10::BFloat16, (bool)1>(const long *, T1 *, T1 *, const T1 *, in…
|
||||
0.7 167,083,069 18,880 8,849.7 2,240.0 1,471 2,499,996 101,436.1 void vllm::reshape_and_cache_flash_kernel<__nv_bfloat16, __nv_bfloat16, (vllm::Fp8KVCacheDataType)0…
|
||||
...
|
||||
```
|
||||
|
||||
GUI example:
|
||||
|
||||
|
||||
@ -10,7 +10,7 @@ title: Using Docker
|
||||
vLLM offers an official Docker image for deployment.
|
||||
The image can be used to run OpenAI compatible server and is available on Docker Hub as [vllm/vllm-openai](https://hub.docker.com/r/vllm/vllm-openai/tags).
|
||||
|
||||
```console
|
||||
```bash
|
||||
docker run --runtime nvidia --gpus all \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=<secret>" \
|
||||
@ -22,7 +22,7 @@ docker run --runtime nvidia --gpus all \
|
||||
|
||||
This image can also be used with other container engines such as [Podman](https://podman.io/).
|
||||
|
||||
```console
|
||||
```bash
|
||||
podman run --gpus all \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
@ -71,7 +71,7 @@ You can add any other [engine-args][engine-args] you need after the image tag (`
|
||||
|
||||
You can build and run vLLM from source via the provided <gh-file:docker/Dockerfile>. To build vLLM:
|
||||
|
||||
```console
|
||||
```bash
|
||||
# optionally specifies: --build-arg max_jobs=8 --build-arg nvcc_threads=2
|
||||
DOCKER_BUILDKIT=1 docker build . \
|
||||
--target vllm-openai \
|
||||
@ -97,26 +97,28 @@ of PyTorch Nightly and should be considered **experimental**. Using the flag `--
|
||||
flags to speed up build process. However, ensure your `max_jobs` is substantially larger than `nvcc_threads` to get the most benefits.
|
||||
Keep an eye on memory usage with parallel jobs as it can be substantial (see example below).
|
||||
|
||||
```console
|
||||
# Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB)
|
||||
python3 use_existing_torch.py
|
||||
DOCKER_BUILDKIT=1 docker build . \
|
||||
--file docker/Dockerfile \
|
||||
--target vllm-openai \
|
||||
--platform "linux/arm64" \
|
||||
-t vllm/vllm-gh200-openai:latest \
|
||||
--build-arg max_jobs=66 \
|
||||
--build-arg nvcc_threads=2 \
|
||||
--build-arg torch_cuda_arch_list="9.0 10.0+PTX" \
|
||||
--build-arg vllm_fa_cmake_gpu_arches="90-real"
|
||||
```
|
||||
??? Command
|
||||
|
||||
```bash
|
||||
# Example of building on Nvidia GH200 server. (Memory usage: ~15GB, Build time: ~1475s / ~25 min, Image size: 6.93GB)
|
||||
python3 use_existing_torch.py
|
||||
DOCKER_BUILDKIT=1 docker build . \
|
||||
--file docker/Dockerfile \
|
||||
--target vllm-openai \
|
||||
--platform "linux/arm64" \
|
||||
-t vllm/vllm-gh200-openai:latest \
|
||||
--build-arg max_jobs=66 \
|
||||
--build-arg nvcc_threads=2 \
|
||||
--build-arg torch_cuda_arch_list="9.0 10.0+PTX" \
|
||||
--build-arg vllm_fa_cmake_gpu_arches="90-real"
|
||||
```
|
||||
|
||||
!!! note
|
||||
If you are building the `linux/arm64` image on a non-ARM host (e.g., an x86_64 machine), you need to ensure your system is set up for cross-compilation using QEMU. This allows your host machine to emulate ARM64 execution.
|
||||
|
||||
Run the following command on your host machine to register QEMU user static handlers:
|
||||
|
||||
```console
|
||||
```bash
|
||||
docker run --rm --privileged multiarch/qemu-user-static --reset -p yes
|
||||
```
|
||||
|
||||
@ -126,7 +128,7 @@ DOCKER_BUILDKIT=1 docker build . \
|
||||
|
||||
To run vLLM with the custom-built Docker image:
|
||||
|
||||
```console
|
||||
```bash
|
||||
docker run --runtime nvidia --gpus all \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
-p 8000:8000 \
|
||||
|
||||
@ -15,7 +15,7 @@ It allows you to deploy a large language model (LLM) server with vLLM as the bac
|
||||
|
||||
- Start the vLLM server with the supported chat completion model, e.g.
|
||||
|
||||
```console
|
||||
```bash
|
||||
vllm serve Qwen/Qwen1.5-32B-Chat-AWQ --max-model-len 4096
|
||||
```
|
||||
|
||||
|
||||
@ -11,7 +11,7 @@ title: AutoGen
|
||||
|
||||
- Setup [AutoGen](https://microsoft.github.io/autogen/0.2/docs/installation/) environment
|
||||
|
||||
```console
|
||||
```bash
|
||||
pip install vllm
|
||||
|
||||
# Install AgentChat and OpenAI client from Extensions
|
||||
@ -23,58 +23,60 @@ pip install -U "autogen-agentchat" "autogen-ext[openai]"
|
||||
|
||||
- Start the vLLM server with the supported chat completion model, e.g.
|
||||
|
||||
```console
|
||||
```bash
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model mistralai/Mistral-7B-Instruct-v0.2
|
||||
```
|
||||
|
||||
- Call it with AutoGen:
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from autogen_core.models import UserMessage
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_core.models import ModelFamily
|
||||
??? Code
|
||||
|
||||
```python
|
||||
import asyncio
|
||||
from autogen_core.models import UserMessage
|
||||
from autogen_ext.models.openai import OpenAIChatCompletionClient
|
||||
from autogen_core.models import ModelFamily
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
# Create a model client
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
base_url="http://{your-vllm-host-ip}:{your-vllm-host-port}/v1",
|
||||
api_key="EMPTY",
|
||||
model_info={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"family": ModelFamily.MISTRAL,
|
||||
"structured_output": True,
|
||||
},
|
||||
)
|
||||
async def main() -> None:
|
||||
# Create a model client
|
||||
model_client = OpenAIChatCompletionClient(
|
||||
model="mistralai/Mistral-7B-Instruct-v0.2",
|
||||
base_url="http://{your-vllm-host-ip}:{your-vllm-host-port}/v1",
|
||||
api_key="EMPTY",
|
||||
model_info={
|
||||
"vision": False,
|
||||
"function_calling": False,
|
||||
"json_output": False,
|
||||
"family": ModelFamily.MISTRAL,
|
||||
"structured_output": True,
|
||||
},
|
||||
)
|
||||
|
||||
messages = [UserMessage(content="Write a very short story about a dragon.", source="user")]
|
||||
messages = [UserMessage(content="Write a very short story about a dragon.", source="user")]
|
||||
|
||||
# Create a stream.
|
||||
stream = model_client.create_stream(messages=messages)
|
||||
# Create a stream.
|
||||
stream = model_client.create_stream(messages=messages)
|
||||
|
||||
# Iterate over the stream and print the responses.
|
||||
print("Streamed responses:")
|
||||
async for response in stream:
|
||||
if isinstance(response, str):
|
||||
# A partial response is a string.
|
||||
print(response, flush=True, end="")
|
||||
else:
|
||||
# The last response is a CreateResult object with the complete message.
|
||||
print("\n\n------------\n")
|
||||
print("The complete response:", flush=True)
|
||||
print(response.content, flush=True)
|
||||
# Iterate over the stream and print the responses.
|
||||
print("Streamed responses:")
|
||||
async for response in stream:
|
||||
if isinstance(response, str):
|
||||
# A partial response is a string.
|
||||
print(response, flush=True, end="")
|
||||
else:
|
||||
# The last response is a CreateResult object with the complete message.
|
||||
print("\n\n------------\n")
|
||||
print("The complete response:", flush=True)
|
||||
print(response.content, flush=True)
|
||||
|
||||
# Close the client when done.
|
||||
await model_client.close()
|
||||
# Close the client when done.
|
||||
await model_client.close()
|
||||
|
||||
|
||||
asyncio.run(main())
|
||||
```
|
||||
asyncio.run(main())
|
||||
```
|
||||
|
||||
For details, see the tutorial:
|
||||
|
||||
|
||||
@ -11,14 +11,14 @@ vLLM can be run on a cloud based GPU machine with [Cerebrium](https://www.cerebr
|
||||
|
||||
To install the Cerebrium client, run:
|
||||
|
||||
```console
|
||||
```bash
|
||||
pip install cerebrium
|
||||
cerebrium login
|
||||
```
|
||||
|
||||
Next, create your Cerebrium project, run:
|
||||
|
||||
```console
|
||||
```bash
|
||||
cerebrium init vllm-project
|
||||
```
|
||||
|
||||
@ -34,75 +34,81 @@ vllm = "latest"
|
||||
|
||||
Next, let us add our code to handle inference for the LLM of your choice (`mistralai/Mistral-7B-Instruct-v0.1` for this example), add the following code to your `main.py`:
|
||||
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
??? Code
|
||||
|
||||
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1")
|
||||
```python
|
||||
from vllm import LLM, SamplingParams
|
||||
|
||||
def run(prompts: list[str], temperature: float = 0.8, top_p: float = 0.95):
|
||||
llm = LLM(model="mistralai/Mistral-7B-Instruct-v0.1")
|
||||
|
||||
sampling_params = SamplingParams(temperature=temperature, top_p=top_p)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
def run(prompts: list[str], temperature: float = 0.8, top_p: float = 0.95):
|
||||
|
||||
# Print the outputs.
|
||||
results = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
results.append({"prompt": prompt, "generated_text": generated_text})
|
||||
sampling_params = SamplingParams(temperature=temperature, top_p=top_p)
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
|
||||
return {"results": results}
|
||||
```
|
||||
# Print the outputs.
|
||||
results = []
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
results.append({"prompt": prompt, "generated_text": generated_text})
|
||||
|
||||
return {"results": results}
|
||||
```
|
||||
|
||||
Then, run the following code to deploy it to the cloud:
|
||||
|
||||
```console
|
||||
```bash
|
||||
cerebrium deploy
|
||||
```
|
||||
|
||||
If successful, you should be returned a CURL command that you can call inference against. Just remember to end the url with the function name you are calling (in our case`/run`)
|
||||
|
||||
```python
|
||||
curl -X POST https://api.cortex.cerebrium.ai/v4/p-xxxxxx/vllm/run \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: <JWT TOKEN>' \
|
||||
--data '{
|
||||
"prompts": [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is"
|
||||
]
|
||||
}'
|
||||
```
|
||||
??? Command
|
||||
|
||||
```python
|
||||
curl -X POST https://api.cortex.cerebrium.ai/v4/p-xxxxxx/vllm/run \
|
||||
-H 'Content-Type: application/json' \
|
||||
-H 'Authorization: <JWT TOKEN>' \
|
||||
--data '{
|
||||
"prompts": [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is"
|
||||
]
|
||||
}'
|
||||
```
|
||||
|
||||
You should get a response like:
|
||||
|
||||
```python
|
||||
{
|
||||
"run_id": "52911756-3066-9ae8-bcc9-d9129d1bd262",
|
||||
"result": {
|
||||
"result": [
|
||||
{
|
||||
"prompt": "Hello, my name is",
|
||||
"generated_text": " Sarah, and I'm a teacher. I teach elementary school students. One of"
|
||||
},
|
||||
{
|
||||
"prompt": "The president of the United States is",
|
||||
"generated_text": " elected every four years. This is a democratic system.\n\n5. What"
|
||||
},
|
||||
{
|
||||
"prompt": "The capital of France is",
|
||||
"generated_text": " Paris.\n"
|
||||
},
|
||||
{
|
||||
"prompt": "The future of AI is",
|
||||
"generated_text": " bright, but it's important to approach it with a balanced and nuanced perspective."
|
||||
}
|
||||
]
|
||||
},
|
||||
"run_time_ms": 152.53663063049316
|
||||
}
|
||||
```
|
||||
??? Response
|
||||
|
||||
```python
|
||||
{
|
||||
"run_id": "52911756-3066-9ae8-bcc9-d9129d1bd262",
|
||||
"result": {
|
||||
"result": [
|
||||
{
|
||||
"prompt": "Hello, my name is",
|
||||
"generated_text": " Sarah, and I'm a teacher. I teach elementary school students. One of"
|
||||
},
|
||||
{
|
||||
"prompt": "The president of the United States is",
|
||||
"generated_text": " elected every four years. This is a democratic system.\n\n5. What"
|
||||
},
|
||||
{
|
||||
"prompt": "The capital of France is",
|
||||
"generated_text": " Paris.\n"
|
||||
},
|
||||
{
|
||||
"prompt": "The future of AI is",
|
||||
"generated_text": " bright, but it's important to approach it with a balanced and nuanced perspective."
|
||||
}
|
||||
]
|
||||
},
|
||||
"run_time_ms": 152.53663063049316
|
||||
}
|
||||
```
|
||||
|
||||
You now have an autoscaling endpoint where you only pay for the compute you use!
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user